# Implementation of the paper "SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" - https://arxiv.org/pdf/1602.07360.pdf

# In [None]:  (notebook cell marker from the original .ipynb export)
import torch
import torch.nn as nn
import math

class FireBlock(nn.Module):
    """Fire module: a 1x1 "squeeze" conv followed by parallel 1x1/3x3 "expand" convs.

    The squeezed tensor is fed to BOTH expand branches, and their outputs are
    concatenated along the channel dimension. Every convolution is followed
    by BatchNorm2d and ReLU.

    Args:
        inplanes: Number of input channels.
        squeeze_planes: Output channels of the 1x1 squeeze convolution.
        expand_planes: Output channels of EACH expand branch, so the block
            outputs ``2 * expand_planes`` channels.

    Spatial size is preserved: all convs use stride 1, and the 3x3 branch
    uses padding 1.
    """

    def __init__(self, inplanes, squeeze_planes, expand_planes):
        super(FireBlock, self).__init__()
        self.squeeze = nn.Sequential(
            nn.Conv2d(inplanes, squeeze_planes, kernel_size=1, stride=1),
            nn.BatchNorm2d(squeeze_planes),
            nn.ReLU(inplace=True),
        )
        # Both expand branches read the squeezed tensor, so both take
        # squeeze_planes input channels. Chaining them sequentially would
        # break whenever squeeze_planes != expand_planes.
        self.expand1x1 = nn.Sequential(
            nn.Conv2d(squeeze_planes, expand_planes, kernel_size=1, stride=1),
            nn.BatchNorm2d(expand_planes),
            nn.ReLU(inplace=True),
        )
        self.expand3x3 = nn.Sequential(
            nn.Conv2d(squeeze_planes, expand_planes, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(expand_planes),
            nn.ReLU(inplace=True),
        )

        # He-style init: N(0, sqrt(2 / fan_in)), where fan_in = k_h * k_w * C_in.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                nn.init.normal_(m.weight, 0.0, math.sqrt(2.0 / fan_in))

    def forward(self, x):
        """Return ``cat([expand1x1(s), expand3x3(s)], dim=1)`` with ``s = squeeze(x)``."""
        squeezed = self.squeeze(x)
        return torch.cat([self.expand1x1(squeezed), self.expand3x3(squeezed)], dim=1)


class SqueezeNet(nn.Module):
    """SqueezeNet-style classifier built from FireBlock modules.

    Layout:
      - a 3x3 stem conv (3 -> 96 channels),
      - three 2x2 max-pools interleaved with eight Fire blocks,
      - a 1x1 conv head to ``num_classes`` channels,
      - a 4x4 average pool and a LogSoftmax over the channel dimension.

    The three max-pools reduce spatial size by 8x, and the head then averages
    4x4 windows. For a 32x32 input the output is therefore log-probabilities
    of shape (N, num_classes, 1, 1).

    NOTE(review): the defaults (10 classes, 4x4 head pool) look tuned for
    CIFAR-10-sized 3x32x32 inputs. Confirm before using other input sizes.
    Callers using nn.NLLLoss typically need to flatten the output to
    (N, num_classes) first.

    Args:
        num_classes: Number of output classes (default 10).
    """

    def __init__(self, num_classes=10):
        super(SqueezeNet, self).__init__()
        # Each FireBlock outputs 2 * expand_planes channels, which is the
        # inplanes value of the block that follows it.
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            FireBlock(96, 16, 64),
            FireBlock(128, 16, 64),
            FireBlock(128, 32, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            FireBlock(256, 32, 128),
            FireBlock(256, 48, 192),
            FireBlock(384, 48, 192),
            FireBlock(384, 64, 256),
            nn.MaxPool2d(kernel_size=2, stride=2),
            FireBlock(512, 64, 256)
        )
        self.classifier = nn.Sequential(
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1),
            nn.AvgPool2d(kernel_size=4, stride=4),
            nn.LogSoftmax(dim=1)
        )

        # He-style init N(0, sqrt(2 / fan_in)) for every conv, including those
        # inside the FireBlocks. BatchNorm starts as the identity (gamma=1, beta=0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                nn.init.normal_(m.weight, 0.0, math.sqrt(2.0 / fan_in))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the feature extractor, then the classifier head; return log-probabilities."""
        x = self.features(x)
        x = self.classifier(x)
        return x