Commit c708ce7 (parent 5edd724)

https://arxiv.org/pdf/1302.4389v4.pdf — .get() the arg to keep old models alive. Includes some comments on accuracy with maxout.

Showing 5 changed files with 84 additions and 7 deletions.
""" | ||
A layer which implements maxout from the "Maxout Networks" paper | ||
https://arxiv.org/pdf/1302.4389v4.pdf | ||
Goodfellow, Warde-Farley, Mirza, Courville, Bengio | ||
or a simpler explanation here: | ||
https://stats.stackexchange.com/questions/129698/what-is-maxout-in-neural-network/298705#298705 | ||
The implementation here: | ||
for k layers of maxout, in -> out channels, we make a single linear | ||
map of size in -> out*k | ||
then we reshape the end to be (..., k, out) | ||
and return the max over the k layers | ||
""" | ||
|
||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
class MaxoutLinear(nn.Module): | ||
def __init__(self, in_channels, out_channels, maxout_k): | ||
super().__init__() | ||
|
||
self.in_channels = in_channels | ||
self.out_channels = out_channels | ||
self.maxout_k = maxout_k | ||
|
||
self.linear = nn.Linear(in_channels, out_channels * maxout_k) | ||
|
||
def forward(self, inputs): | ||
""" | ||
Use the oversized linear as the repeated linear, then take the max | ||
One large linear map makes the implementation simpler and easier for pytorch to make parallel | ||
""" | ||
outputs = self.linear(inputs) | ||
outputs = outputs.view(*outputs.shape[:-1], self.maxout_k, self.out_channels) | ||
outputs = torch.max(outputs, dim=-2)[0] | ||
return outputs | ||
|
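A minimal usage sketch of the layer above, checking that the output shape drops from `out_channels * maxout_k` back to `out_channels` after the max. The class body is re-declared here verbatim so the sketch is self-contained; the batch/sequence sizes are illustrative, not from the commit.

```python
import torch
import torch.nn as nn


class MaxoutLinear(nn.Module):
    # mirrors the class added in this commit
    def __init__(self, in_channels, out_channels, maxout_k):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.maxout_k = maxout_k
        self.linear = nn.Linear(in_channels, out_channels * maxout_k)

    def forward(self, inputs):
        outputs = self.linear(inputs)
        # (..., out*k) -> (..., k, out), then max over the k pieces
        outputs = outputs.view(*outputs.shape[:-1], self.maxout_k, self.out_channels)
        return torch.max(outputs, dim=-2)[0]


layer = MaxoutLinear(in_channels=8, out_channels=4, maxout_k=3)
x = torch.randn(2, 5, 8)   # (batch, seq, in_channels)
y = layer(x)
print(y.shape)             # torch.Size([2, 5, 4])
```

Because the reshape keeps the leading dimensions untouched, the layer works for any input rank, which is why the docstring writes the shapes as `(..., k, out)`.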