In [11]:
import torch
import torch.nn as nn

In [55]:
class ClassicFire(nn.Module):
    """Fire block: a 1x1 "squeeze" convolution feeding two parallel "expand" convolutions.

    Args:
        in_channels: number of channels in the incoming feature map.
        s1x1: output channels of the 1x1 squeeze convolution.
        e1x1: output channels of the 1x1 expand convolution.
        e3x3: output channels of the 3x3 expand convolution.

    ``forward`` returns the two ReLU-activated expand outputs concatenated along
    the channel axis, so the result has ``e1x1 + e3x3`` channels.
    """

    def __init__(self, in_channels, s1x1, e1x1, e3x3):
        super().__init__()
        # Squeeze stage: reduce the channel count before the expand branches.
        self.squeeze_conv = nn.Conv2d(
            in_channels=in_channels, out_channels=s1x1, kernel_size=1
        )
        # A single ReLU module is shared by the squeeze stage and both branches.
        self.relu = nn.ReLU()
        # Expand stage: two branches that read the same squeezed tensor.
        self.expand_layer1x1 = nn.Conv2d(
            in_channels=s1x1, out_channels=e1x1, kernel_size=1
        )
        self.expand_layer3x3 = nn.Conv2d(
            in_channels=s1x1, out_channels=e3x3, kernel_size=3, padding=1
        )

    def forward(self, x):
        """Squeeze ``x``, run both expand branches, and join them on dim 1."""
        squeezed = self.relu(self.squeeze_conv(x))
        branch_outputs = [
            self.relu(expand(squeezed))
            for expand in (self.expand_layer1x1, self.expand_layer3x3)
        ]
        return torch.cat(branch_outputs, dim=1)

In [59]:
"""
fire = ClassicFire(1, 16, 64, 64)
if __name__ == "__main__":
    input = torch.randn(1, 1, 16, 16) # batch, channels, img_width, img_height
    out = fire(input)
    print(out.shape)
    print(out)
"""

'\nfire = ClassicFire(1, 16, 64, 64)\nif __name__ == "__main__":\n    input = torch.randn(1, 1, 16, 16) # batch, channels, img_width, img_height\n    out = fire(input)\n    print(out.shape)\n    print(out)\n'

In [121]:
class SqueezeNet(nn.Module):
    """SqueezeNet-style classifier assembled from ``ClassicFire`` blocks.

    Args:
        in_channels: number of channels of the input image (default 1).
        out_channels: number of output channels of ``conv10``, i.e. the number
            of classes the softmax runs over (default 1000).
        global_pool: selects how the classifier head pools ``conv10``'s output.
            ``False`` (default) keeps the original behaviour: ``avgpool10`` is a
            13x13, stride-1 window with padding 6, which does NOT collapse the
            spatial axes, so ``forward`` returns one softmax distribution per
            spatial location, shaped ``[batch, out_channels, H, W]``.
            ``True`` averages over the whole feature map instead and returns a
            single distribution per image, shaped ``[batch, out_channels]``.

    Returns (from ``forward``):
        Softmax probabilities over dim 1, shaped as described for ``global_pool``.
    """

    def __init__(self, in_channels=1, out_channels=1000, global_pool=False):
        super(SqueezeNet, self).__init__()
        self.global_pool = global_pool

        # Stem: strided 7x7 convolution followed by an overlapping max-pool.
        self.conv1 = nn.Conv2d(in_channels, out_channels=96, kernel_size=(7,7), stride=(2,2))
        self.maxpool1 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2))
        self.relu1 = nn.ReLU()

        # Each Fire block's in_channels equals e1x1 + e3x3 of the previous block.
        self.fire2 = ClassicFire(96, 16, 64, 64)
        self.fire3 = ClassicFire(128, 16, 64, 64)
        self.fire4 = ClassicFire(128, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2))

        self.fire5 = ClassicFire(256, 32, 128, 128)
        self.fire6 = ClassicFire(256, 48, 192, 192)
        self.fire7 = ClassicFire(384, 48, 192, 192)
        self.fire8 = ClassicFire(384, 64, 256, 256)
        self.maxpool8 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2))

        # Classifier head: 1x1 convolution to class channels, pooling, softmax.
        self.fire9 = ClassicFire(512, 64, 256, 256)
        self.conv10 = nn.Conv2d(512, out_channels, kernel_size=1, stride=1)
        # Only used when global_pool is False (kept for backward compatibility).
        self.avgpool10 = nn.AvgPool2d(kernel_size=13, stride=1, padding=6)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Run the network on ``x``; shape comments assume a [1, 1, 224, 224] input
        and the default ``out_channels=1000``."""
        x = self.conv1(x) # [1, 96, 109, 109]
        x = self.maxpool1(x) # [1, 96, 54, 54]
        x = self.relu1(x)

        x = self.fire2(x) # [1, 128, 54, 54]
        x = self.fire3(x) # [1, 128, 54, 54]
        x = self.fire4(x) # [1, 256, 54, 54]
        x = self.maxpool4(x) # [1, 256, 26, 26]

        x = self.fire5(x) # [1, 256, 26, 26]
        x = self.fire6(x) # [1, 384, 26, 26]
        x = self.fire7(x) # [1, 384, 26, 26]
        x = self.fire8(x) # [1, 512, 26, 26]
        x = self.maxpool8(x) # [1, 512, 12, 12]

        x = self.fire9(x) # [1, 512, 12, 12]
        x = self.conv10(x) # [1, 1000, 12, 12]
        if self.global_pool:
            # True global average pooling: one score per class per image.
            x = x.mean(dim=(2, 3)) # [1, 1000]
        else:
            # Sliding 13x13 window; spatial size is preserved.
            x = self.avgpool10(x) # [1, 1000, 12, 12]
        x = self.softmax(x)
        return x

In [122]:
squeeze = SqueezeNet()
if __name__ == "__main__":
    # Build the input and run the model inside this cell so it works on a fresh
    # kernel; previously both lines were commented out and `out` was undefined.
    sample = torch.randn(1, 1, 224, 224) # batch, channels, img_height, img_width
    out = squeeze(sample)
    print(out.shape)
    print(out)

torch.Size([1, 1000, 12, 12])
tensor([[[[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          ...,
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],

         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          ...,
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],

         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0