In [1]:
%config IPCompleter.greedy=True

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchsummary import summary

from torch.optim import Adam

import numpy as np
import time
import torchvision.models as models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cpu


In [2]:
cifar_class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

cifar = torchvision.datasets.CIFAR10('./datasets', train=True, download=True,
                                     transform=torchvision.transforms.Compose([
                                         torchvision.transforms.ToTensor(),
                                         torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                                     ]))

train, test = torch.utils.data.random_split(cifar, [40000, 10000])
batch_size = 256
train_dataset = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
test_dataset = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./datasets/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./datasets/cifar-10-python.tar.gz to ./datasets


In [3]:
def weight_reset(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.reset_parameters()

def calc_accuracy(model):
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_dataset):
            imgs, true_labels = data[0].to(device), data[1].to(device)
            preds = model(imgs)
            pred_class = torch.argmax(preds, dim=1)
            incorrect = torch.count_nonzero(pred_class - true_labels)
            correct += len(true_labels) - incorrect

    print(f'test accuracy: {(correct * 100) / len(test):.3f}%')

# ResNet-18 Model

In [65]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

teacher_model = ResNet(BasicBlock, [2, 2, 2, 2]).to(device)

In [66]:
teacher_model.load_state_dict(torch.load('./models/cifar_teacher_0.pt'))
print(summary(teacher_model, (3, 32, 32)))
calc_accuracy(teacher_model)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 32, 32]          36,864
       BatchNorm2d-4           [-1, 64, 32, 32]             128
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
        BasicBlock-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
           Conv2d-10           [-1, 64, 32, 32]          36,864
      BatchNorm2d-11           [-1, 64, 32, 32]             128
       BasicBlock-12           [-1, 64, 32, 32]               0
           Conv2d-13          [-1, 128, 16, 16]          73,728
      BatchNorm2d-14          [-1, 128,

In [None]:
optimizer = torch.optim.Adam(teacher_model.parameters(), lr=0.01)
teacher_loss_fn = nn.CrossEntropyLoss()

epochs = 20

time_s = lambda: time.time()
for ep in range(epochs):
    start_time = time_s()
    ep_loss = 0.0
    correct = 0
    for i, (imgs, labels) in enumerate(train_dataset):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        preds = teacher_model(imgs)
        loss = teacher_loss_fn(preds, labels)
        loss.backward()
        optimizer.step()
        
        pred_class = torch.argmax(preds, dim=1).to(device)
        correct += len(labels) - torch.count_nonzero(pred_class - labels)
        
        ep_loss += loss.detach().item()

    print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.4f}, train acc: {(correct * 100.0) / len(train):.3f}%, time: {(time_s() - start_time):.2f}s')

calc_accuracy(teacher_model)

torch.save(teacher_model.state_dict(), './models/cifar_teacher.pt')

# Building Teacher and Student models

In [57]:
# Softmax with temperature
# -- Adapted from PyTorch Softmax layer
# -- See: https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#Softmax
class SoftmaxT(nn.Module):
    def __init__(self, temperature, dim = 1) -> None:
        super(SoftmaxT, self).__init__()
        self.temperature = temperature
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input):
        return torch.nn.functional.softmax(input / self.temperature, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)

In [84]:
alpha = 0.1
temperature = 5

# Create a new model with softmax temperature
teacher_model_w_temperature = torch.nn.Sequential(
    teacher_model,
    SoftmaxT(temperature)
).to(device)
print(summary(teacher_model_w_temperature, (3, 32, 32)))

# Create the student model
student_model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=(3, 3)),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(2, 2)),
    nn.Conv2d(16, 32, kernel_size=(3, 3)),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(2, 2)),
    nn.Flatten(),
    nn.Linear(1152, 32),
    nn.ReLU(),
    nn.Dropout(p=0.1),
    nn.Linear(32, 10),
    SoftmaxT(temperature)
).to(device)
print(summary(student_model, (3, 32, 32)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 32, 32]          36,864
       BatchNorm2d-4           [-1, 64, 32, 32]             128
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
        BasicBlock-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
           Conv2d-10           [-1, 64, 32, 32]          36,864
      BatchNorm2d-11           [-1, 64, 32, 32]             128
       BasicBlock-12           [-1, 64, 32, 32]               0
           Conv2d-13          [-1, 128, 16, 16]          73,728
      BatchNorm2d-14          [-1, 128,

In [85]:
# KD from teacher using whole dataset
student_model.apply(weight_reset)
student_optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)
student_loss_fn = nn.CrossEntropyLoss()
distillation_loss_fn = torch.nn.KLDivLoss()

kd_epochs = 20
time_s = lambda: time.time()
for ep in range(kd_epochs):
    ep_loss = 0.0
    correct = 0
    start_time = time_s()
    for i, data in enumerate(train_dataset):
        imgs, true_labels = data[0].to(device), data[1].to(device)

        student_optimizer.zero_grad()

        # Forward pass of the teacher with input
        with torch.no_grad():
            teacher_output = teacher_model_w_temperature(imgs).to(device)

        # Forward pass of the student
        student_output = student_model(imgs).to(device)

        # Calculate loss
        student_loss = student_loss_fn(student_output, true_labels)
        distill_loss = distillation_loss_fn(teacher_output, student_output)
        loss = alpha * student_loss + (1 - alpha) * distill_loss

        loss.backward()
        student_optimizer.step()

        pred_class = torch.argmax(student_output, dim=1).to(device)
        correct += len(true_labels) - torch.count_nonzero(pred_class - true_labels)
        
        ep_loss += loss.detach().item()
    print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.2e}, train acc: {(correct * 100.0) / len(train):.3f}%, time: {(time_s() - start_time):.2f}s')

calc_accuracy(student_model)

torch.save(student_model.state_dict(), './models/cifar_student.pt')



epoch: 1, loss: 1.21e-02, train acc: 33.000%, time: 16.03s
epoch: 2, loss: 1.02e-02, train acc: 44.335%, time: 15.96s
epoch: 3, loss: 9.12e-03, train acc: 48.727%, time: 15.83s
epoch: 4, loss: 8.24e-03, train acc: 51.757%, time: 15.87s
epoch: 5, loss: 7.42e-03, train acc: 54.970%, time: 15.99s
epoch: 6, loss: 6.86e-03, train acc: 56.885%, time: 16.03s
epoch: 7, loss: 6.33e-03, train acc: 58.537%, time: 16.09s
epoch: 8, loss: 5.96e-03, train acc: 59.702%, time: 15.85s
epoch: 9, loss: 5.67e-03, train acc: 60.822%, time: 15.97s
epoch: 10, loss: 5.36e-03, train acc: 61.745%, time: 16.06s
epoch: 11, loss: 5.16e-03, train acc: 62.618%, time: 15.98s
epoch: 12, loss: 4.95e-03, train acc: 63.170%, time: 16.13s
epoch: 13, loss: 4.71e-03, train acc: 64.010%, time: 16.12s
epoch: 14, loss: 4.53e-03, train acc: 64.955%, time: 16.00s
epoch: 15, loss: 4.35e-03, train acc: 65.185%, time: 16.45s
epoch: 16, loss: 4.25e-03, train acc: 65.675%, time: 16.41s
epoch: 17, loss: 4.06e-03, train acc: 66.228%, ti

# Training Student model from scratch

In [88]:
# Create a new model with the last layer removed, provides access to model logits
student_model_wo_temperature = torch.nn.Sequential(
    *(list(student_model.children())[:-1]),
    nn.Softmax(dim=1)
).to(device)
print(summary(student_model_wo_temperature, (3, 32, 32)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 30, 30]             448
              ReLU-2           [-1, 16, 30, 30]               0
         MaxPool2d-3           [-1, 16, 15, 15]               0
            Conv2d-4           [-1, 32, 13, 13]           4,640
              ReLU-5           [-1, 32, 13, 13]               0
         MaxPool2d-6             [-1, 32, 6, 6]               0
           Flatten-7                 [-1, 1152]               0
            Linear-8                   [-1, 32]          36,896
              ReLU-9                   [-1, 32]               0
          Dropout-10                   [-1, 32]               0
           Linear-11                   [-1, 10]             330
          Softmax-12                   [-1, 10]               0
Total params: 42,314
Trainable params: 42,314
Non-trainable params: 0
---------------------------------

In [90]:
# Training student from scratch using whole dataset
student_model_wo_temperature.apply(weight_reset)
student_optimizer = torch.optim.Adam(student_model_wo_temperature.parameters(), lr=0.001)
student_loss_fn = nn.CrossEntropyLoss()

kd_epochs = 20
time_s = lambda: time.time()
for ep in range(kd_epochs):
    ep_loss = 0.0
    correct = 0
    start_time = time_s()
    for i, data in enumerate(train_dataset):
        imgs, true_labels = data[0].to(device), data[1].to(device)

        student_optimizer.zero_grad()

        # Forward pass of the student
        student_output = student_model_wo_temperature(imgs).to(device)

        # Calculate loss
        student_loss = student_loss_fn(student_output, true_labels)

        student_loss.backward()
        student_optimizer.step()

        pred_class = torch.argmax(student_output, dim=1).to(device)
        correct += len(true_labels) - torch.count_nonzero(pred_class - true_labels)
        
        ep_loss += student_loss.detach().item()
    print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.2e}, train acc: {(correct * 100.0) / len(train):.3f}%, time: {(time_s() - start_time):.2f}s')

calc_accuracy(student_model_wo_temperature)

torch.save(student_model_wo_temperature.state_dict(), './models/cifar_student_scratch.pt')

epoch: 1, loss: 2.16e+00, train acc: 29.550%, time: 7.88s
epoch: 2, loss: 2.06e+00, train acc: 40.065%, time: 7.98s
epoch: 3, loss: 2.01e+00, train acc: 44.480%, time: 8.01s
epoch: 4, loss: 1.99e+00, train acc: 47.342%, time: 8.02s
epoch: 5, loss: 1.97e+00, train acc: 49.642%, time: 8.01s
epoch: 6, loss: 1.95e+00, train acc: 51.280%, time: 7.89s
epoch: 7, loss: 1.93e+00, train acc: 53.015%, time: 7.88s
epoch: 8, loss: 1.92e+00, train acc: 54.405%, time: 7.99s
epoch: 9, loss: 1.91e+00, train acc: 55.540%, time: 7.81s
epoch: 10, loss: 1.89e+00, train acc: 56.842%, time: 7.95s
epoch: 11, loss: 1.89e+00, train acc: 57.697%, time: 7.91s
epoch: 12, loss: 1.88e+00, train acc: 58.707%, time: 8.05s
epoch: 13, loss: 1.87e+00, train acc: 59.577%, time: 7.92s
epoch: 14, loss: 1.86e+00, train acc: 60.487%, time: 7.94s
epoch: 15, loss: 1.85e+00, train acc: 60.912%, time: 7.98s
epoch: 16, loss: 1.84e+00, train acc: 61.837%, time: 7.91s
epoch: 17, loss: 1.84e+00, train acc: 62.145%, time: 7.90s
epoch: