This notebook is for MobileNet development.

## initial model

In [1]:
# https://github.com/tonylins/pytorch-mobilenet-v2/blob/master/MobileNetV2.py
import torch.nn as nn
import math


def conv_bn(inp, oup, stride):
    """Build a 3x3 convolution -> batch norm -> ReLU6 stack.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the 3x3 convolution (padding is fixed at 1).

    Returns:
        An ``nn.Sequential`` holding ``[Conv2d, BatchNorm2d, ReLU6]``.
    """
    # The convolution carries no bias; the batch norm that follows has its own shift term.
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, norm, act)

def conv_1x1_bn(inp, oup):
    """Build a 1x1 (pointwise) convolution -> batch norm -> ReLU6 stack.

    Args:
        inp: number of input channels.
        oup: number of output channels.

    Returns:
        An ``nn.Sequential`` holding ``[Conv2d, BatchNorm2d, ReLU6]``.
    """
    # Stride 1 / padding 0: the spatial size of the input is preserved.
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.ReLU6(inplace=True)
    return nn.Sequential(pointwise, norm, act)

def make_divisible(x, divisible_by=8):
    """Round ``x`` up to the nearest multiple of ``divisible_by``.

    Args:
        x: value to round (int or float).
        divisible_by: the step the result must be a multiple of.

    Returns:
        The smallest ``int`` multiple of ``divisible_by`` that is >= ``x``.
    """
    # math.ceil already returns an int, so no numpy import is needed for this.
    return int(math.ceil(x / divisible_by) * divisible_by)


class InvertedResidual(nn.Module):
    """Inverted residual block: optional 1x1 expansion, 3x3 depthwise conv, 1x1 linear projection.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: the hidden width is ``int(inp * expand_ratio)``. When it equals 1
            the expansion stage is omitted.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expand_ratio)
        # The skip connection is only valid when input and output tensors have the same shape.
        self.use_res_connect = self.stride == 1 and inp == oup

        stages = []
        if expand_ratio != 1:
            # pw
            stages += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        # dw (groups == channels makes the 3x3 convolution depthwise)
        stages += [
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
        ]
        # pw-linear (no activation after the projection)
        stages += [
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block, adding the input back when the residual connection is enabled."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out


class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor followed by a linear classification head.

    Args:
        n_class: number of output classes of the final linear layer.
        input_size: expected spatial input size; only used to assert divisibility by 32.
        width_mult: multiplier applied to the output channels of every setting row with
            ``t > 1`` and, when greater than 1.0, to the width of the last 1x1 convolution.
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t (expand ratio), c (output channels), n (repeats), s (stride of first repeat)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        # input_channel = make_divisible(input_channel * width_mult)  # first channel is always 32!
        self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel
        # Collect layers in a local list so self.features is only ever an nn.Module.
        layers = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = make_divisible(c * width_mult) if t > 1 else c
            for i in range(n):
                # Only the first block of each group uses the configured stride.
                stride = s if i == 0 else 1
                layers.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        layers.append(conv_1x1_bn(input_channel, self.last_channel))
        self.features = nn.Sequential(*layers)

        # building classifier
        self.classifier = nn.Linear(self.last_channel, n_class)

        self._initialize_weights()

    def forward(self, x):
        """Run the feature stack, average over the two trailing (spatial) dims, then classify."""
        x = self.features(x)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std derived from kernel area times output channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


# Instantiate the network with a 10-way head; input_size=32 satisfies the constructor's
# divisible-by-32 assertion.
model = MobileNetV2(width_mult=1, n_class=10, input_size=32)

In [2]:
import torchvision

# Augmentation pipeline applied to every sample: random horizontal flip, random
# perspective warp, then conversion to a tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomPerspective(),
    torchvision.transforms.ToTensor()
])
# CIFAR10 training split stored under ./cifar10/ (download=True fetches it when needed).
dataset = torchvision.datasets.CIFAR10("cifar10/", train=True, transform=transform, download=True)
# Alternative dataset location (kept for reference):
# dataset = torchvision.datasets.CIFAR10("/mnt/cifar10/", train=True, transform=transform, download=True)

Files already downloaded and verified


In [3]:
from torch.utils.data import DataLoader
# Shuffled mini-batches of 32 samples drawn from the dataset loaded above.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [4]:
# Pull a single batch (inputs, labels) to use as an example input below.
X_ex, y_ex = next(iter(dataloader))

In [5]:
# Sanity check: run one forward pass on the example batch and display the raw outputs.
model(X_ex)

tensor([[ 2.4882e-02,  3.0258e-01, -5.8743e-01, -4.5475e-01, -2.8394e-02,
         -5.1983e-02, -5.7043e-01,  2.4971e-01, -6.0351e-02,  8.4419e-03],
        [ 6.0768e-02,  9.8650e-03,  2.4985e-01,  9.3584e-02, -1.2789e-01,
         -2.3547e-01,  1.2566e-01, -1.2906e-01,  1.2965e-01, -1.2443e-01],
        [-1.5583e-01,  2.7106e-01, -5.9658e-02,  5.3169e-02,  1.9452e-01,
         -2.9189e-01, -4.0493e-01, -8.0180e-02,  2.1773e-01,  4.1286e-02],
        [ 1.5030e-01, -1.9092e-01, -1.4071e-01, -3.7609e-01,  9.0936e-02,
         -9.5844e-02, -3.9342e-01, -1.0372e-01, -2.7239e-01, -2.7863e-01],
        [ 2.7413e-01,  2.7644e-01,  5.1805e-02,  1.5311e-01,  5.3775e-02,
         -7.5813e-02, -2.9641e-01, -1.7854e-01,  1.3940e-01, -1.9341e-01],
        [-6.1110e-02,  1.5751e-01, -5.9958e-02, -1.2210e-01,  2.0742e-01,
          1.5072e-02, -2.0363e-01,  2.7187e-01,  1.4726e-01,  3.5018e-01],
        [-2.0053e-01,  1.7566e-01, -4.2770e-01,  1.1256e-01, -1.7031e-01,
          6.0348e-03, -4.8407e-0

In [8]:
from torch import optim
import numpy as np

# Loss and optimizer used by train() below; Adam is created with its default
# hyperparameters over all parameters of the global `model`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

def train():
    """Run the training loop over the notebook-level ``model``, ``dataloader``,
    ``criterion`` and ``optimizer``.

    Both loops ``break`` after their first iteration, so a call currently performs a
    single optimisation step on one batch (apparently a quick smoke test). Remove the
    two ``break`` statements to train for the full ``NUM_EPOCHS``.

    Prints the loss every 200 batches and a per-epoch mean/median summary.
    """
    NUM_EPOCHS = 10
    for epoch in range(1, NUM_EPOCHS + 1):
        losses = []

        for i, (X_batch, y_cls) in enumerate(dataloader):
            optimizer.zero_grad()

            # CPU-only here; move X_batch / y_cls to the GPU with .cuda() if one is used.
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_cls)
            loss.backward()
            optimizer.step()

            curr_loss = loss.item()
            if i % 200 == 0:
                print(
                    f'Finished epoch {epoch}/{NUM_EPOCHS}, batch {i}. Loss: {curr_loss:.3f}.'
                )

            losses.append(curr_loss)
            break  # stop after one batch

        # The summary is labelled "median", so report the median (previously np.min).
        print(
            f'Finished epoch {epoch}. '
            f'avg loss: {np.mean(losses)}; median loss: {np.median(losses)}'
        )
        break  # stop after one epoch

# Kick off the training loop defined above.
train()

Finished epoch 1/10, batch 0. Loss: 2.296.
Finished epoch 1. avg loss: 2.296233892440796; median loss: 2.296233892440796


## quantized version

In [37]:
# Interactive help lookup (IPython `?` syntax); the output below is the nn.Sequential docstring.
nn.Sequential?

[0;31mInit signature:[0m [0mnn[0m[0;34m.[0m[0mSequential[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
A sequential container.
Modules will be added to it in the order they are passed in the constructor.
Alternatively, an ordered dict of modules can also be passed in.

To make it easier to understand, here is a small example::

    # Example of using Sequential
    model = nn.Sequential(
              nn.Conv2d(1,20,5),
              nn.ReLU(),
              nn.Conv2d(20,64,5),
              nn.ReLU()
            )

    # Example of using Sequential with OrderedDict
    model = nn.Sequential(OrderedDict([
              ('conv1', nn.Conv2d(1,20,5)),
              ('relu1', nn.ReLU()),
              ('conv2', nn.Conv2d(20,64,5)),
              ('relu2', nn.ReLU())
            ]))
[0;31mInit docstring:[0m Initializes internal Module state, shared by both nn.Module and ScriptModule.
[0;31mFile:[0m           ~/opt/minicon

In [46]:
# Forked from https://github.com/tonylins/pytorch-mobilenet-v2/blob/master/MobileNetV2.py
import torch.nn as nn
import torch.quantization
import math
from collections import OrderedDict

# NOTE(aleksey): assigning layers names makes them easier to reference in the fuse_module code later
# on. fuse_modules takes a list of lists of layers as input, without named layers we'd have to use
# something like ['features.0.1', 'features.0.2']. Needless to say, that's not exactly readable.

def conv_bn(inp, oup, stride):
    """Build a named 3x3 convolution -> batch norm -> ReLU6 stack wrapped in quant stubs.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the 3x3 convolution (padding is fixed at 1).

    Returns:
        An ``nn.Sequential`` whose children are named
        ``q, conv2d, batchnorm2d, relu6, dq`` in that order.
    """
    named_layers = OrderedDict()
    named_layers['q'] = torch.quantization.QuantStub()
    named_layers['conv2d'] = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    named_layers['batchnorm2d'] = nn.BatchNorm2d(oup)
    named_layers['relu6'] = nn.ReLU6(inplace=True)
    named_layers['dq'] = torch.quantization.DeQuantStub()
    return nn.Sequential(named_layers)

def conv_1x1_bn(inp, oup):
    """Build a named 1x1 convolution -> batch norm -> ReLU6 stack wrapped in quant stubs.

    Args:
        inp: number of input channels.
        oup: number of output channels.

    Returns:
        An ``nn.Sequential`` whose children are named
        ``q, conv2d, batchnorm2d, relu6, dq`` in that order.
    """
    named_layers = OrderedDict()
    named_layers['q'] = torch.quantization.QuantStub()
    named_layers['conv2d'] = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    named_layers['batchnorm2d'] = nn.BatchNorm2d(oup)
    named_layers['relu6'] = nn.ReLU6(inplace=True)
    named_layers['dq'] = torch.quantization.DeQuantStub()
    return nn.Sequential(named_layers)

def make_divisible(x, divisible_by=8):
    """Round ``x`` up to the nearest multiple of ``divisible_by``.

    Args:
        x: value to round (int or float).
        divisible_by: the step the result must be a multiple of.

    Returns:
        The smallest ``int`` multiple of ``divisible_by`` that is >= ``x``.
    """
    # math.ceil already returns an int, so no numpy import is needed for this.
    return int(math.ceil(x / divisible_by) * divisible_by)


class InvertedResidual(nn.Module):
    """Inverted residual block with named layers and quant/dequant stubs at its edges.

    Child names inside ``self.conv`` are ``q``, then ``<kind>_<position>`` for each
    layer (position counted from 1), then ``dq``. With ``expand_ratio == 1`` the
    expansion stage is omitted, giving ``conv2d_1 .. bnorm_5``; otherwise the names
    run ``conv2d_1 .. bnorm_8``.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: the hidden width is ``int(inp * expand_ratio)``.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expand_ratio)
        # The skip connection is only valid when input and output tensors have the same shape.
        self.use_res_connect = self.stride == 1 and inp == oup

        # (kind, module) pairs in execution order; position suffixes are added below.
        stages = []
        if expand_ratio != 1:
            # pw
            stages += [
                ('conv2d', nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False)),
                ('bnorm', nn.BatchNorm2d(hidden_dim)),
                ('relu6', nn.ReLU6(inplace=True)),
            ]
        # dw
        stages += [
            ('conv2d', nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False)),
            ('bnorm', nn.BatchNorm2d(hidden_dim)),
            ('relu6', nn.ReLU6(inplace=True)),
        ]
        # pw-linear
        stages += [
            ('conv2d', nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False)),
            ('bnorm', nn.BatchNorm2d(oup)),
        ]

        named = OrderedDict()
        named['q'] = torch.quantization.QuantStub()
        for position, (kind, layer) in enumerate(stages, start=1):
            named[f'{kind}_{position}'] = layer
        named['dq'] = torch.quantization.DeQuantStub()
        self.conv = nn.Sequential(named)

    def forward(self, x):
        """Apply the block, adding the input back when the residual connection is enabled."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out


class MobileNetV2(nn.Module):
    """MobileNetV2 with named stages, followed by a linear classification head.

    The children of ``self.features`` are named ``in_conv``, ``inv_conv_1`` ..
    ``inv_conv_N`` and ``out_conv`` so they can be addressed by dotted path
    (e.g. in ``fuse_modules``).

    Args:
        n_class: number of output classes of the final linear layer.
        input_size: expected spatial input size; only used to assert divisibility by 32.
        width_mult: multiplier applied to the output channels of every setting row with
            ``t > 1`` and, when greater than 1.0, to the width of the last 1x1 convolution.
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t (expand ratio), c (output channels), n (repeats), s (stride of first repeat)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        # input_channel = make_divisible(input_channel * width_mult)  # first channel is always 32!
        self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel
        # Collect layers in a local list so self.features is only ever an nn.Module.
        layers = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = make_divisible(c * width_mult) if t > 1 else c
            for i in range(n):
                # Only the first block of each group uses the configured stride.
                stride = s if i == 0 else 1
                layers.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        layers.append(conv_1x1_bn(input_channel, self.last_channel))
        # Derive the names from the number of layers actually built; a fixed-length name
        # list combined with zip() would silently drop layers if the table above changed.
        submodule_names = (
            ['in_conv']
            + [f'inv_conv_{i}' for i in range(1, len(layers) - 1)]
            + ['out_conv']
        )
        self.features = nn.Sequential(OrderedDict(zip(submodule_names, layers)))

        # building classifier
        self.classifier = nn.Linear(self.last_channel, n_class)

        self._initialize_weights()

    def forward(self, x):
        """Run the feature stack, average over the two trailing (spatial) dims, then classify."""
        x = self.features(x)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std derived from kernel area times output channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


# Instantiate the stub-wrapped network with a 10-way head; input_size=32 satisfies the
# constructor's divisible-by-32 assertion. This rebinds the notebook-level `model`.
model = MobileNetV2(width_mult=1, n_class=10, input_size=32)

In [47]:
# Display the module tree to check the child names that the fusion list refers to.
model

MobileNetV2(
  (features): Sequential(
    (in_conv): Sequential(
      (q): QuantStub()
      (conv2d): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (batchnorm2d): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu6): ReLU6(inplace=True)
      (dq): DeQuantStub()
    )
    (inv_conv_1): InvertedResidual(
      (conv): Sequential(
        (q): QuantStub()
        (conv2d_1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bnorm_2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu6_3): ReLU6(inplace=True)
        (conv2d_4): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bnorm_5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (dq): DeQuantStub()
      )
    )
    (inv_conv_2): InvertedResidual(
      (conv): Sequential(
        (q): QuantStub()
      

In [54]:
from torch import optim
import numpy as np
import torch.quantization

# Loss and Adam optimizer (default hyperparameters) bound to the current global `model`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

def eval_fn(model):
    """Run post-training static quantization on ``model`` and return the converted copy.

    Steps, in the order the eager-mode workflow requires:
    eval mode -> set qconfig -> fuse conv/batch-norm pairs -> insert observers
    (``prepare``) -> calibrate on one batch from the notebook-level ``dataloader`` ->
    ``convert``.

    Args:
        model: a MobileNetV2 built with the named layers defined above. It is switched
            to eval mode and gets a ``qconfig`` attribute; fusion, preparation and
            conversion operate on copies (default ``inplace=False``).

    Returns:
        The converted (quantized) model.
    """
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

    # NOTE(aleksey): 'features' is the attr containing the non-head layers.
    fusion_groups = [['features.in_conv.conv2d', 'features.in_conv.batchnorm2d']]
    for i in range(1, 18):
        prefix = f'features.inv_conv_{i}.conv'
        fusion_groups.append([f'{prefix}.conv2d_1', f'{prefix}.bnorm_2'])
        fusion_groups.append([f'{prefix}.conv2d_4', f'{prefix}.bnorm_5'])
        if i > 1:
            # Only blocks 2..17 contain the third conv/batch-norm pair.
            fusion_groups.append([f'{prefix}.conv2d_7', f'{prefix}.bnorm_8'])
    # NOTE(review): features.out_conv's conv/batch-norm pair is not fused - confirm
    # whether that omission is intentional.

    # Fuse before prepare: observers must be attached to the fused modules, otherwise
    # the modules that fusion replaces take their observers with them.
    model = torch.quantization.fuse_modules(model, fusion_groups)
    model = torch.quantization.prepare(model)

    # Calibration only needs forward passes, so no gradients and no optimizer calls.
    with torch.no_grad():
        for X_batch, _ in dataloader:
            model(X_batch)
            break  # a single batch is used for calibration

    return torch.quantization.convert(model)

# Run the quantization pipeline once on the current global `model`.
eval_fn(model)

## finished scripts

In [56]:
# -p makes the cell safe to re-run when the directory already exists.
!mkdir -p ../models/

In [4]:
%%writefile ../models/model_1.py
"""
Initial MobileNet model. Trained on CIFAR10.
"""

# Forked from https://github.com/tonylins/pytorch-mobilenet-v2/blob/master/MobileNetV2.py

####################
# MODEL DEFINITION #
####################

import torch.nn as nn
import math
import torch
import os


def conv_bn(inp, oup, stride):
    """Build a 3x3 convolution -> batch norm -> ReLU6 stack.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the 3x3 convolution (padding is fixed at 1).

    Returns:
        An ``nn.Sequential`` holding ``[Conv2d, BatchNorm2d, ReLU6]``.
    """
    # The convolution carries no bias; the batch norm that follows has its own shift term.
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, norm, act)

def conv_1x1_bn(inp, oup):
    """Build a 1x1 (pointwise) convolution -> batch norm -> ReLU6 stack.

    Args:
        inp: number of input channels.
        oup: number of output channels.

    Returns:
        An ``nn.Sequential`` holding ``[Conv2d, BatchNorm2d, ReLU6]``.
    """
    # Stride 1 / padding 0: the spatial size of the input is preserved.
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.ReLU6(inplace=True)
    return nn.Sequential(pointwise, norm, act)

def make_divisible(x, divisible_by=8):
    """Round ``x`` up to the nearest multiple of ``divisible_by``.

    Args:
        x: value to round (int or float).
        divisible_by: the step the result must be a multiple of.

    Returns:
        The smallest ``int`` multiple of ``divisible_by`` that is >= ``x``.
    """
    # math.ceil already returns an int, so no numpy import is needed for this.
    return int(math.ceil(x / divisible_by) * divisible_by)


class InvertedResidual(nn.Module):
    """Inverted residual block: optional 1x1 expansion, 3x3 depthwise conv, 1x1 linear projection.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: the hidden width is ``int(inp * expand_ratio)``. When it equals 1
            the expansion stage is omitted.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expand_ratio)
        # The skip connection is only valid when input and output tensors have the same shape.
        self.use_res_connect = self.stride == 1 and inp == oup

        stages = []
        if expand_ratio != 1:
            # pw
            stages += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        # dw (groups == channels makes the 3x3 convolution depthwise)
        stages += [
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
        ]
        # pw-linear (no activation after the projection)
        stages += [
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block, adding the input back when the residual connection is enabled."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out


class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor followed by a linear classification head.

    Args:
        n_class: number of output classes of the final linear layer.
        input_size: expected spatial input size; only used to assert divisibility by 32.
        width_mult: multiplier applied to the output channels of every setting row with
            ``t > 1`` and, when greater than 1.0, to the width of the last 1x1 convolution.
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t (expand ratio), c (output channels), n (repeats), s (stride of first repeat)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        # input_channel = make_divisible(input_channel * width_mult)  # first channel is always 32!
        self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel
        # Collect layers in a local list so self.features is only ever an nn.Module.
        layers = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = make_divisible(c * width_mult) if t > 1 else c
            for i in range(n):
                # Only the first block of each group uses the configured stride.
                stride = s if i == 0 else 1
                layers.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        layers.append(conv_1x1_bn(input_channel, self.last_channel))
        self.features = nn.Sequential(*layers)

        # building classifier
        self.classifier = nn.Linear(self.last_channel, n_class)

        self._initialize_weights()

    def forward(self, x):
        """Run the feature stack, average over the two trailing (spatial) dims, then classify."""
        x = self.features(x)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std derived from kernel area times output channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


# Instantiate the network with a 10-way head; input_size=32 satisfies the constructor's
# divisible-by-32 assertion.
model = MobileNetV2(width_mult=1, n_class=10, input_size=32)


############
# TRAINING #
############

import torchvision
from torch.utils.data import DataLoader
# Augmentation pipeline applied to every sample: random horizontal flip, random
# perspective warp, then conversion to a tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomPerspective(),
    torchvision.transforms.ToTensor()
])
# CIFAR10 training split stored under /mnt/cifar10/ (download=True fetches it when needed),
# served as shuffled mini-batches of 32.
dataset = torchvision.datasets.CIFAR10("/mnt/cifar10/", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

from torch import optim
import numpy as np

# Loss and Adam optimizer (default hyperparameters). The optimizer is created here,
# before train() moves the model to `device`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# Use CUDA when at least one device is reported, otherwise fall back to the CPU.
# `is_cpu` mirrors that choice as a boolean.
if torch.cuda.device_count() >= 1:
    device = torch.device('cuda')
    is_cpu = False
else:
    device = torch.device('cpu')
    is_cpu = True

def train(model):
    """Train ``model`` for ``NUM_EPOCHS`` epochs and checkpoint every fifth epoch.

    Uses the module-level ``dataloader``, ``criterion``, ``optimizer`` and ``device``.
    Prints the loss every 200 batches and a per-epoch mean/median summary. State dicts
    are written to ``/spell/checkpoints/model_<epoch>.pth``.

    Args:
        model: the network to train; it is moved to ``device`` in place.
    """
    model.to(device)

    # Create the checkpoint directory once up front instead of re-checking every epoch.
    checkpoint_dir = "/spell/checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)

    NUM_EPOCHS = 10
    for epoch in range(1, NUM_EPOCHS + 1):
        losses = []

        for i, (X_batch, y_cls) in enumerate(dataloader):
            optimizer.zero_grad()

            # .to(device) covers both the CPU and CUDA cases.
            X_batch = X_batch.to(device)
            y = y_cls.to(device)

            y_pred = model(X_batch)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()

            curr_loss = loss.item()
            if i % 200 == 0:
                print(
                    f'Finished epoch {epoch}/{NUM_EPOCHS}, batch {i}. Loss: {curr_loss:.3f}.'
                )

            losses.append(curr_loss)

        # The summary is labelled "median", so report the median (previously np.min).
        print(
            f'Finished epoch {epoch}. '
            f'avg loss: {np.mean(losses)}; median loss: {np.median(losses)}'
        )

        if epoch % 5 == 0:
            # f-string so the epoch number is interpolated; the original literal
            # "model_{epoch}.pth" made every save overwrite the same file.
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"model_{epoch}.pth"))

# Script entry point: training starts as soon as this file is executed.
train(model)

Overwriting ../models/model_1.py


```bash
spell run --github-url https://github.com/ResidentMario/mobilenet-cifar10.git \
  --machine-type t4 --github-ref dev \
  python models/model_1.py
```

In [60]:
# -p makes the cell safe to re-run when the directory already exists.
!mkdir -p ../servers/

In [62]:
%%writefile ../servers/eval_quantized.py
"""
MobileNet model with quantization-aware training (QAT) enabled. Trained on CIFAR10.

This file batches training and evaluation into the same script. Since QAT requires careful
management of the training loop, it's easiest to do both in the same run.
"""
# Forked from https://github.com/tonylins/pytorch-mobilenet-v2/blob/master/MobileNetV2.py

# TODO: can we train on CUDA, then quantize on CPU? This bears investigating.

####################
# MODEL DEFINITION #
####################

import torch.nn as nn
import torch.quantization
import math
from collections import OrderedDict
import time

# NOTE(aleksey): assigning layers names makes them easier to reference in the fuse_module code later
# on. fuse_modules takes a list of lists of layers as input, without named layers we'd have to use
# something like ['features.0.1', 'features.0.2']. Needless to say, that's not exactly readable.

def conv_bn(inp, oup, stride):
    """Build a named 3x3 convolution -> batch norm -> ReLU6 stack wrapped in quant stubs.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the 3x3 convolution (padding is fixed at 1).

    Returns:
        An ``nn.Sequential`` whose children are named
        ``q, conv2d, batchnorm2d, relu6, dq`` in that order.
    """
    named_layers = OrderedDict()
    named_layers['q'] = torch.quantization.QuantStub()
    named_layers['conv2d'] = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    named_layers['batchnorm2d'] = nn.BatchNorm2d(oup)
    named_layers['relu6'] = nn.ReLU6(inplace=True)
    named_layers['dq'] = torch.quantization.DeQuantStub()
    return nn.Sequential(named_layers)

def conv_1x1_bn(inp, oup):
    """Build a named 1x1 convolution -> batch norm -> ReLU6 stack wrapped in quant stubs.

    Args:
        inp: number of input channels.
        oup: number of output channels.

    Returns:
        An ``nn.Sequential`` whose children are named
        ``q, conv2d, batchnorm2d, relu6, dq`` in that order.
    """
    named_layers = OrderedDict()
    named_layers['q'] = torch.quantization.QuantStub()
    named_layers['conv2d'] = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    named_layers['batchnorm2d'] = nn.BatchNorm2d(oup)
    named_layers['relu6'] = nn.ReLU6(inplace=True)
    named_layers['dq'] = torch.quantization.DeQuantStub()
    return nn.Sequential(named_layers)

def make_divisible(x, divisible_by=8):
    """Round ``x`` up to the nearest multiple of ``divisible_by``.

    Args:
        x: value to round (int or float).
        divisible_by: the step the result must be a multiple of.

    Returns:
        The smallest ``int`` multiple of ``divisible_by`` that is >= ``x``.
    """
    # math.ceil already returns an int, so no numpy import is needed for this.
    return int(math.ceil(x / divisible_by) * divisible_by)


class InvertedResidual(nn.Module):
    """Inverted residual block with named layers and quant/dequant stubs at its edges.

    Child names inside ``self.conv`` are ``q``, then ``<kind>_<position>`` for each
    layer (position counted from 1), then ``dq``. With ``expand_ratio == 1`` the
    expansion stage is omitted, giving ``conv2d_1 .. bnorm_5``; otherwise the names
    run ``conv2d_1 .. bnorm_8``.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: the hidden width is ``int(inp * expand_ratio)``.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expand_ratio)
        # The skip connection is only valid when input and output tensors have the same shape.
        self.use_res_connect = self.stride == 1 and inp == oup

        # (kind, module) pairs in execution order; position suffixes are added below.
        stages = []
        if expand_ratio != 1:
            # pw
            stages += [
                ('conv2d', nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False)),
                ('bnorm', nn.BatchNorm2d(hidden_dim)),
                ('relu6', nn.ReLU6(inplace=True)),
            ]
        # dw
        stages += [
            ('conv2d', nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False)),
            ('bnorm', nn.BatchNorm2d(hidden_dim)),
            ('relu6', nn.ReLU6(inplace=True)),
        ]
        # pw-linear
        stages += [
            ('conv2d', nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False)),
            ('bnorm', nn.BatchNorm2d(oup)),
        ]

        named = OrderedDict()
        named['q'] = torch.quantization.QuantStub()
        for position, (kind, layer) in enumerate(stages, start=1):
            named[f'{kind}_{position}'] = layer
        named['dq'] = torch.quantization.DeQuantStub()
        self.conv = nn.Sequential(named)

    def forward(self, x):
        """Apply the block, adding the input back when the residual connection is enabled."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out


class MobileNetV2(nn.Module):
    """MobileNetV2 with named stages, followed by a linear classification head.

    The children of ``self.features`` are named ``in_conv``, ``inv_conv_1`` ..
    ``inv_conv_N`` and ``out_conv`` so they can be addressed by dotted path
    (e.g. in ``fuse_modules``).

    Args:
        n_class: number of output classes of the final linear layer.
        input_size: expected spatial input size; only used to assert divisibility by 32.
        width_mult: multiplier applied to the output channels of every setting row with
            ``t > 1`` and, when greater than 1.0, to the width of the last 1x1 convolution.
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t (expand ratio), c (output channels), n (repeats), s (stride of first repeat)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        # input_channel = make_divisible(input_channel * width_mult)  # first channel is always 32!
        self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel
        # Collect layers in a local list so self.features is only ever an nn.Module.
        layers = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = make_divisible(c * width_mult) if t > 1 else c
            for i in range(n):
                # Only the first block of each group uses the configured stride.
                stride = s if i == 0 else 1
                layers.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        layers.append(conv_1x1_bn(input_channel, self.last_channel))
        # Derive the names from the number of layers actually built; a fixed-length name
        # list combined with zip() would silently drop layers if the table above changed.
        submodule_names = (
            ['in_conv']
            + [f'inv_conv_{i}' for i in range(1, len(layers) - 1)]
            + ['out_conv']
        )
        self.features = nn.Sequential(OrderedDict(zip(submodule_names, layers)))

        # building classifier
        self.classifier = nn.Linear(self.last_channel, n_class)

        self._initialize_weights()

    def forward(self, x):
        """Run the feature stack, average over the two trailing (spatial) dims, then classify."""
        x = self.features(x)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std derived from kernel area times output channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def get_model():
    """Build and return a fresh MobileNetV2 with a 10-way head.

    ``input_size=32`` satisfies the constructor's divisible-by-32 assertion.

    Returns:
        The newly constructed ``MobileNetV2`` instance.
    """
    model = MobileNetV2(width_mult=1, n_class=10, input_size=32)
    # The original fell off the end here, so callers received None.
    return model


############
# TRAINING #
############

from torch import optim
import numpy as np

import torchvision
from torch.utils.data import DataLoader
# Augmentation pipeline applied to every sample: random horizontal flip, random
# perspective warp, then conversion to a tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomPerspective(),
    torchvision.transforms.ToTensor()
])
# CIFAR10 training split stored under /mnt/cifar10/ (download=True fetches it when needed),
# served as shuffled mini-batches of 32. Both train() and eval_fn() iterate this loader.
dataset = torchvision.datasets.CIFAR10("/mnt/cifar10/", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

def prepare_model(model):
    """Fuse conv/batch-norm pairs and prepare ``model`` for quantization-aware training.

    Args:
        model: a MobileNetV2 built with the named layers defined above. A ``qconfig``
            attribute is set on it; fusion and preparation use the library default
            ``inplace=False``.

    Returns:
        The model returned by ``torch.quantization.prepare_qat``. Callers must use this
        return value (and build any optimizer from its parameters) rather than the
        object that was passed in.
    """
    # No loss/optimizer is built here: they were never used or returned, and an
    # optimizer created before fusion would hold the pre-fusion parameters.
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

    # NOTE(aleksey): 'features' is the attr containing the non-head layers.
    fusion_groups = [['features.in_conv.conv2d', 'features.in_conv.batchnorm2d']]
    for i in range(1, 18):
        prefix = f'features.inv_conv_{i}.conv'
        fusion_groups.append([f'{prefix}.conv2d_1', f'{prefix}.bnorm_2'])
        fusion_groups.append([f'{prefix}.conv2d_4', f'{prefix}.bnorm_5'])
        if i > 1:
            # Only blocks 2..17 contain the third conv/batch-norm pair.
            fusion_groups.append([f'{prefix}.conv2d_7', f'{prefix}.bnorm_8'])
    # NOTE(review): features.out_conv's conv/batch-norm pair is not fused - confirm
    # whether that omission is intentional.
    # NOTE(review): the model is fused while in training mode; recent PyTorch releases
    # may require fuse_modules_qat for that - confirm against the installed version.

    model = torch.quantization.fuse_modules(model, fusion_groups)
    model = torch.quantization.prepare_qat(model)
    return model


def train(model):
    """Train ``model`` in place for ``NUM_EPOCHS`` epochs on the module-level ``dataloader``.

    The loss and optimizer are created here from ``model``'s own parameters, so the
    optimizer always updates the object that is actually being trained. Prints the loss
    every 200 batches, a per-epoch mean/median summary, and the total wall time.

    Args:
        model: the network to train (CPU tensors are used throughout).
    """
    # These names were previously read as globals that this script never defines.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    print(f"Training the model...")
    start_time = time.time()
    NUM_EPOCHS = 10
    for epoch in range(1, NUM_EPOCHS + 1):
        losses = []

        for i, (X_batch, y_cls) in enumerate(dataloader):
            optimizer.zero_grad()

            # CPU-only here; move X_batch / y_cls to the GPU with .cuda() if one is used.
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_cls)
            loss.backward()
            optimizer.step()

            curr_loss = loss.item()
            if i % 200 == 0:
                print(
                    f'Finished epoch {epoch}/{NUM_EPOCHS}, batch {i}. Loss: {curr_loss:.3f}.'
                )

            losses.append(curr_loss)

        # The summary is labelled "median", so report the median (previously np.min).
        print(
            f'Finished epoch {epoch}. '
            f'avg loss: {np.mean(losses)}; median loss: {np.median(losses)}'
        )
    print(f"Training done in {str(time.time() - start_time)} seconds.")


def eval_fn(model):
    """Convert ``model`` to its quantized form and measure accuracy over ``dataloader``.

    Args:
        model: a model already prepared for quantization. It is switched to eval mode;
            conversion uses the library default ``inplace=False``.

    Returns:
        Fraction of correctly classified samples (``nan`` if the loader yields nothing).
    """
    model.eval()

    print(f"Quantizing the model (post-training)...")
    start_time = time.time()
    model = torch.quantization.convert(model)
    print(f"Quantization done in {str(time.time() - start_time)} seconds.")

    print(f"Evaluating the model...")
    start_time = time.time()
    correct = 0
    total = 0
    # NOTE(review): this iterates the same augmented training loader used by train();
    # confirm whether a held-out split should be used instead.
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        for X_batch, y_cls in dataloader:
            y_pred = model(X_batch)
            correct += (y_pred.argmax(dim=1) == y_cls).sum().item()
            total += y_cls.size(0)
    print(f"Evaluation done in {str(time.time() - start_time)} seconds.")

    accuracy = correct / total if total else float('nan')
    print(f"Accuracy: {accuracy:.4f} ({correct}/{total}).")
    return accuracy

if __name__ == "__main__":
    model = get_model()
    # prepare_model returns the fused/prepared model; rebind so that training and
    # conversion operate on it instead of the untouched original.
    model = prepare_model(model)
    train(model)
    eval_fn(model)

Overwriting ../servers/eval_quantized.py
