# Implementation of the paper "SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <1MB model size" - https://arxiv.org/pdf/1602.07360.pdf

# In [6]:
import torch
import torch.nn as nn
import math

class FireBlock(nn.Module):
    """SqueezeNet "fire" module.

    A 1x1 *squeeze* convolution reduces the channel count to
    ``squeeze_planes``. Its output then feeds two parallel *expand* branches,
    a 1x1 conv and a 3x3 conv (``padding=1``), each producing
    ``expand_planes`` channels. The two branch outputs are concatenated along
    the channel dimension. Every conv is followed by BatchNorm2d and an
    in-place ReLU.

    Args:
        inplanes: number of channels in the input feature map.
        squeeze_planes: output channels of the 1x1 squeeze conv.
        expand_planes: output channels of *each* expand branch; the module
            therefore emits ``2 * expand_planes`` channels at the input's
            spatial size.
    """

    def __init__(self, inplanes, squeeze_planes, expand_planes):
        super().__init__()
        # The attribute names and the (conv, bn, relu) order inside each
        # Sequential define the state_dict keys, so saved checkpoints rely on them.
        self.squeeze = self._conv_bn_relu(inplanes, squeeze_planes, kernel_size=1)
        self.expand1x1 = self._conv_bn_relu(squeeze_planes, expand_planes, kernel_size=1)
        self.expand3x3 = self._conv_bn_relu(
            squeeze_planes, expand_planes, kernel_size=3, padding=1)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, **conv_kwargs):
        """Build ``Conv2d -> BatchNorm2d -> ReLU(inplace)`` as one Sequential."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        squeezed = self.squeeze(x)
        branches = (self.expand1x1(squeezed), self.expand3x3(squeezed))
        # Channel-wise concatenation: [1x1 branch | 3x3 branch].
        return torch.cat(branches, dim=1)

class SqueezeNet(nn.Module):
    """SqueezeNet classifier built from ``FireBlock`` modules.

    The stem is a 3x3 stride-1 conv (96 channels) with BatchNorm and ReLU.
    It is followed by three 2x2 max-pools interleaved with eight fire
    modules, ending at 512 channels. The classifier is Dropout, then a 1x1
    conv to ``num_classes`` channels, then global average pooling, so the
    output is a ``(batch, num_classes)`` tensor of raw (unnormalised) scores.

    NOTE(review): the small stride-1 stem presumably targets small inputs such
    as the 32x32 CIFAR-10 images used by the training script in this file.

    Args:
        num_classes: number of output classes (default 10, the original
            hard-coded value, so existing checkpoints still load).
    """

    def __init__(self, num_classes=10):
        super(SqueezeNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            FireBlock(96, 16, 64),
            FireBlock(128, 16, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            FireBlock(128, 32, 128),
            FireBlock(256, 32, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            FireBlock(256, 48, 192),
            FireBlock(384, 48, 192),
            FireBlock(384, 64, 256),
            FireBlock(512, 64, 256)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv2d(512, num_classes, kernel_size=1),
            # Global average pool: one score per class regardless of spatial size.
            nn.AdaptiveAvgPool2d(1)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        # (N, num_classes, 1, 1) -> (N, num_classes)
        return torch.flatten(x, 1)


# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import argparse
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import torch.nn.functional as F
import matplotlib.pyplot as plt
from IPython import embed

# Run configuration (kept as a plain dict so it is easy to tweak in a notebook).
args = {
    'batch_size': 64,
    'epoch': 55,                 # number of training epochs in the main loop
    'learning_rate': 0.001,      # initial SGD learning rate
    'momentum': 0.9,
    'no_cuda': False,            # force CPU even if CUDA is available
    'log_schedule': 10,          # print/plot the loss every N training batches
    'seed': 1,
    'model_name': None,          # optional path to a state_dict to load into `net`
    'want_to_test': False,       # True: only run test(); False: train + validate
    'epoch_55': False,           # True: train() applies the per-epoch LR/WD schedule
    'num_classes': 10
}

args['cuda'] = not args['no_cuda'] and torch.cuda.is_available()

torch.manual_seed(args['seed'])
if args['cuda']:
    torch.cuda.manual_seed(args['seed'])

# Per-channel CIFAR-10 mean/std normalisation shared by both loaders.
cifar10_normalize = transforms.Normalize(
    (0.491399689874, 0.482158419622, 0.446530924224),
    (0.247032237587, 0.243485133253, 0.261587846975))

kwargs = {'num_workers': 1, 'pin_memory': True} if args['cuda'] else {}
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
                         cifar10_normalize,
                     ])),
    batch_size=args['batch_size'], shuffle=True, **kwargs)
# val() uses the first batches of this loader and test() the remaining ones,
# selected by batch index. The loader must therefore iterate in a fixed order
# (shuffle=False) so the two splits are stable and disjoint. Evaluation data
# is also not augmented.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        cifar10_normalize,
    ])),
    batch_size=args['batch_size'], shuffle=False, **kwargs)

net = SqueezeNet()
if args['model_name'] is not None:
    # map_location='cpu' lets a checkpoint saved from a GPU run load on a
    # CPU-only machine; load_state_dict copies values onto the net's device.
    pretrained_weights = torch.load(args['model_name'], map_location='cpu')
    net.load_state_dict(pretrained_weights)

if args['cuda']:
    net.cuda()


def params_for_epoch(epoch):
    """Look up the learning-rate / weight-decay regime for a 1-based epoch.

    Args:
        epoch: epoch number.

    Returns:
        ``{'learning_rate': lr, 'weight_decay': wd}`` for the regime whose
        inclusive ``[first, last]`` range contains ``epoch``. If no regime
        matches (e.g. ``epoch < 1``), an empty dict is returned.
    """
    # (first_epoch, last_epoch, learning_rate, weight_decay); ranges are inclusive.
    schedule = (
        (1, 18, 5e-3, 5e-4),
        (19, 29, 1e-3, 5e-4),
        (30, 43, 5e-4, 5e-4),
        (44, 52, 1e-4, 0),
        (53, 1e8, 1e-5, 0),
    )
    matches = [(lr, wd) for first, last, lr, wd in schedule
               if first <= epoch <= last]
    if not matches:
        return {}
    # If ranges ever overlapped, the last matching regime would win.
    lr, wd = matches[-1]
    return {'learning_rate': lr, 'weight_decay': wd}


# Per-iteration training losses; train() appends to it and plots it to loss.jpg.
avg_loss = list()
# Best validation accuracy so far; val() checkpoints the model when it improves.
best_accuracy = 0.0
fig1, ax1 = plt.subplots()

# Use the configured momentum rather than a duplicated literal.
optimizer = optim.SGD(net.parameters(), lr=args['learning_rate'],
                      momentum=args['momentum'], weight_decay=5e-4)


def adjust_lr_wd(params, opt=None):
    """Set the learning rate and weight decay on every param group of an optimizer.

    Args:
        params: dict with ``'learning_rate'`` and ``'weight_decay'`` keys
            (as returned by ``params_for_epoch``).
        opt: optimizer to update; defaults to the module-level ``optimizer``.

    Raises:
        KeyError: if ``params`` lacks either key (``params_for_epoch`` returns
            an empty dict for epochs outside its schedule).
    """
    if opt is None:
        opt = optimizer
    # Mutate opt.param_groups directly. opt.state_dict()['param_groups'] holds
    # freshly built copies, so writing there never reached the optimizer.
    for group in opt.param_groups:
        group['lr'] = params['learning_rate']
        group['weight_decay'] = params['weight_decay']


def train(epoch):
    """Run one training epoch over ``train_loader`` and return accuracy in percent.

    If ``args['epoch_55']`` is set, the optimizer's learning rate and weight
    decay are first set from ``params_for_epoch(epoch)``. Every
    ``args['log_schedule']`` batches the current loss is printed and the loss
    history in ``avg_loss`` is redrawn into ``loss.jpg``.

    Args:
        epoch: 1-based epoch number (used for logging and the LR schedule).

    Returns:
        Training accuracy over the epoch, as a percentage (float).
    """
    if args['epoch_55']:
        adjust_lr_wd(params_for_epoch(epoch))

    correct = 0
    seen = 0
    num_train = len(train_loader.dataset)
    net.train()
    for b_idx, (data, targets) in enumerate(train_loader):
        if args['cuda']:
            data, targets = data.cuda(), targets.cuda()

        optimizer.zero_grad()
        # net returns (N, num_classes) raw scores. No reshape to a fixed
        # batch_size: the last batch of the epoch can be smaller.
        scores = net(data)
        # The model has no log_softmax, so use cross_entropy. nll_loss on raw
        # scores is unbounded below and does not train a proper classifier.
        loss = F.cross_entropy(scores, targets)
        loss.backward()
        optimizer.step()

        pred = scores.detach().argmax(dim=1)
        correct += pred.eq(targets).sum().item()
        seen += len(data)
        avg_loss.append(loss.item())

        if b_idx % args['log_schedule'] == 0:
            print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, num_train, 100. * seen / num_train, loss.item()))

            # Clear first. Re-plotting the whole history on the same axes
            # would otherwise add a new line object at every log step.
            ax1.cla()
            ax1.plot(avg_loss)
            fig1.savefig("loss.jpg")

    train_accuracy = correct / float(num_train)
    print("Train accuracy ({:.2f}%)".format(100*train_accuracy))
    return (train_accuracy*100.0)


def val(num_batches=73):
    """Evaluate on the first ``num_batches`` batches of ``test_loader``.

    Those batches act as a validation split; ``test()`` skips the same number
    of leading batches. Whenever accuracy beats ``best_accuracy``, the model
    weights are saved to ``bsqueezenet_onfulldata.pth``.

    NOTE(review): the validation and test splits are only stable and disjoint
    if ``test_loader`` iterates in a fixed order (no shuffling). Confirm this
    against the loader configuration.

    Args:
        num_batches: number of leading batches to use (default 73, as before).

    Returns:
        Validation accuracy as a percentage (float); 0.0 if no examples were seen.
    """
    global best_accuracy
    correct = 0
    total = 0
    net.eval()
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            if idx == num_batches:
                break

            if args['cuda']:
                data, target = data.cuda(), target.cuda()

            pred = net(data).argmax(dim=1)
            correct += pred.eq(target).sum().item()
            # Count real examples rather than assuming 64 per batch.
            total += len(target)

    print("Predicted {} out of {}".format(correct, total))
    val_accuracy = 100.0 * correct / total if total else 0.0
    print("Accuracy: {:.2f}".format(val_accuracy))

    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(net.state_dict(), 'bsqueezenet_onfulldata.pth')
    return val_accuracy


def test(num_val_batches=73):
    """Load the best checkpoint and evaluate it on the held-out test batches.

    The weights are loaded from ``bsqueezenet_onfulldata.pth``, the file
    written by ``val()``. The first ``num_val_batches`` batches of
    ``test_loader`` are skipped because ``val()`` uses them, and accuracy is
    measured on the rest.

    Args:
        num_val_batches: number of leading batches reserved for validation
            (default 73, matching ``val()``'s default).

    Returns:
        Test accuracy as a percentage (float).

    Raises:
        ValueError: if no test examples remain after skipping the validation batches.
    """
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
    # load_state_dict then copies the values onto the net's own device.
    weights = torch.load('bsqueezenet_onfulldata.pth', map_location='cpu')
    net.load_state_dict(weights)
    net.eval()

    test_correct = 0
    total_examples = 0
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            if idx < num_val_batches:
                continue  # validation batches, already used by val()
            if args['cuda']:
                data, target = data.cuda(), target.cuda()

            pred = net(data).argmax(dim=1)
            test_correct += pred.eq(target).sum().item()
            total_examples += len(target)

    print("Predicted {} out of {} correctly".format(
        test_correct, total_examples))
    if total_examples == 0:
        raise ValueError(
            "no test examples left after skipping {} validation batches".format(
                num_val_batches))
    return 100.0 * test_correct / float(total_examples)


# Either train + validate for args['epoch'] epochs (plotting both accuracy
# curves to train_val_accuracy.jpg), or only evaluate the saved best model.
if not args['want_to_test']:
    fig2, ax2 = plt.subplots()
    train_acc, val_acc = list(), list()
    for i in range(1, args['epoch']+1):
        train_acc.append(train(i))
        val_acc.append(val())
        # Redraw from scratch each epoch. Without this, two new line objects
        # are stacked on the axes on every iteration.
        ax2.cla()
        ax2.plot(train_acc, 'g', label='train')
        ax2.plot(val_acc, 'b', label='val')
        ax2.legend()
        fig2.savefig('train_val_accuracy.jpg')
else:
    test_acc = test()
    print("Test accuracy: {:.2f}%".format(test_acc))
