In [None]:
import torch
import torch.nn as nn

In [None]:
class ConvBlock(nn.Module):
    """A single 2-D convolution followed by an in-place ReLU.

    Args:
        in_channel: Number of channels of the input tensor.
        out_channel: Number of channels produced by the convolution.
        k_s: Kernel size forwarded to ``nn.Conv2d``.
        s: Stride forwarded to ``nn.Conv2d``.
        p: Padding forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channel, out_channel, k_s, s, p):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=k_s,
            stride=s,
            padding=p,
        )
        # In-place activation: overwrites the convolution output buffer.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution to ``x`` and return the ReLU of the result."""
        conv_out = self.conv(x)
        return self.relu(conv_out)

class FireBlock(nn.Module):
    """Squeeze convolution feeding two parallel expand convolutions.

    The input goes through a 1x1 "squeeze" ConvBlock; its output is fed to
    both a 1x1 and a 3x3 "expand" ConvBlock, and the two results are
    concatenated along the channel dimension (``dim=1``), giving
    ``e1_1 + e3_3`` output channels.

    Args:
        in_channel: Number of channels of the input tensor.
        s1_1: Output channels of the 1x1 squeeze convolution.
        e1_1: Output channels of the 1x1 expand convolution.
        e3_3: Output channels of the 3x3 expand convolution.
    """

    def __init__(self, in_channel, s1_1, e1_1, e3_3):
        super().__init__()
        # Settings shared by both 1x1 convolutions: stride 1, no padding.
        pointwise = dict(k_s=(1, 1), s=1, p=(0, 0))
        self.squeeze1_1 = ConvBlock(in_channel, s1_1, **pointwise)
        self.expand1_1 = ConvBlock(s1_1, e1_1, **pointwise)
        # 3x3 kernel with stride 1 and padding 1 keeps the spatial size, so
        # both expand branches can be concatenated on the channel axis.
        self.expand3_3 = ConvBlock(s1_1, e3_3, k_s=(3, 3), s=1, p=(1, 1))

    def forward(self, x):
        """Return the channel-wise concatenation of the two expand branches."""
        squeezed = self.squeeze1_1(x)
        branches = [self.expand1_1(squeezed), self.expand3_3(squeezed)]
        return torch.cat(branches, dim=1)

# Smoke test: push a random batch through one fire block and print the
# resulting shape (channel count should be e1_1 + e3_3 = 128).
test_block = FireBlock(in_channel=96, s1_1=16, e1_1=64, e3_3=64)
x = torch.randn((8, 96, 55, 55))
print(test_block(x).size())

In [None]:
class SqueezeNet(nn.Module):
    """SqueezeNet-style classifier assembled from ConvBlock and FireBlock.

    Layer order: 7x7 stride-2 conv -> max-pool -> fire2..fire4 -> max-pool ->
    fire5..fire8 -> max-pool -> fire9 -> 1x1 conv to ``num_classes`` channels
    -> global average pool.

    Args:
        img_channel: Number of channels of the input images (default 3).
        num_classes: Number of output channels of the final 1x1 convolution,
            i.e. the number of class scores per image (default 1000).

    ``forward`` returns a tensor of shape ``(N, num_classes, 1, 1)``.
    """

    def __init__(self, img_channel=3, num_classes=1000):
        super(SqueezeNet, self).__init__()
        self.conv1 = ConvBlock(img_channel, 96, (7, 7), 2, (2, 2))
        self.maxpool1 = nn.MaxPool2d((3, 3), 2)
        # Each FireBlock outputs e1_1 + e3_3 channels, which is the input
        # width of the block that follows it.
        self.fire2 = FireBlock(96, 16, 64, 64)       # -> 128 channels
        self.fire3 = FireBlock(128, 16, 64, 64)      # -> 128 channels
        self.fire4 = FireBlock(128, 32, 128, 128)    # -> 256 channels
        self.maxpool4 = nn.MaxPool2d((3, 3), 2)
        self.fire5 = FireBlock(256, 32, 128, 128)    # -> 256 channels
        self.fire6 = FireBlock(256, 48, 192, 192)    # -> 384 channels
        self.fire7 = FireBlock(384, 48, 192, 192)    # -> 384 channels
        self.fire8 = FireBlock(384, 64, 256, 256)    # -> 512 channels
        self.maxpool8 = nn.MaxPool2d((3, 3), 2)
        self.fire9 = FireBlock(512, 64, 256, 256)    # -> 512 channels
        self.conv10 = ConvBlock(512, num_classes, (1, 1), 1, (0, 0))
        # Global average pool. An adaptive pool reduces whatever spatial size
        # reaches this layer to 1x1; a fixed 13x13 window only does that when
        # the input image is 224x224 (where the feature map here is 13x13).
        self.avgpool10 = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Run the network on ``x`` and return per-image class scores.

        The result keeps its two trailing singleton spatial dimensions:
        shape ``(N, num_classes, 1, 1)``.
        """
        x = self.maxpool1(self.conv1(x))
        x = self.maxpool4(self.fire4(self.fire3(self.fire2(x))))
        x = self.maxpool8(self.fire8(self.fire7(self.fire6(self.fire5(x)))))
        x = self.avgpool10(self.conv10(self.fire9(x)))
        return x

# Smoke test: run a random batch of 224x224 RGB images through the full
# network and print the output shape.
model = SqueezeNet(img_channel=3)
x = torch.randn((8, 3, 224, 224))
print(model(x).size())
        