In [1]:
%config IPCompleter.greedy=True

import torch
import torch.nn as nn
import torchvision
from torchsummary import summary

from torch.optim import Adam

import numpy as np
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Utils

In [2]:
def weight_reset(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.reset_parameters()

def calc_accuracy(model):
    correct = 0
    with torch.no_grad():
    for i, data in enumerate(test_dataset):
        imgs, true_labels = data[0].to(device), data[1].to(device)
        preds = model(imgs)
        pred_class = torch.argmax(preds, dim=1)
        incorrect = torch.count_nonzero(pred_class - true_labels)
        correct += len(true_labels) - incorrect

    print(f'test accuracy: {(correct * 100) / len(test):.3f}%')

In [3]:
mnist = torchvision.datasets.MNIST('./datasets', train=True, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize((0.1307,), (0.3081,))
                                   ]))

train, test = torch.utils.data.random_split(mnist, [50000, 10000])
batch_size = 256
train_dataset = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
test_dataset = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=True)

small_mnist_subset = list(range(0, int(len(mnist)*0.03)))
small_train_dataset = torch.utils.data.DataLoader(torch.utils.data.Subset(mnist, small_mnist_subset), batch_size=16, shuffle=True)

# Teacher Model

In [4]:
# Teacher model
teacher_model = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=(3, 3)),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(2, 2)),
    nn.Conv2d(32, 64, kernel_size=(3, 3)),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(2, 2)),
    nn.Conv2d(64, 128, kernel_size=(2, 2)),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(2048, 128),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(128, 10),
    nn.Softmax(dim=1)
)

teacher_model.to(device)
print(summary(teacher_model, (1, 28, 28)))

Using device: cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 26, 26]             320
              ReLU-2           [-1, 32, 26, 26]               0
            Conv2d-3           [-1, 64, 24, 24]          18,496
              ReLU-4           [-1, 64, 24, 24]               0
         MaxPool2d-5           [-1, 64, 12, 12]               0
         Dropout2d-6           [-1, 64, 12, 12]               0
           Flatten-7                 [-1, 9216]               0
            Linear-8                  [-1, 128]       1,179,776
              ReLU-9                  [-1, 128]               0
          Dropout-10                  [-1, 128]               0
           Linear-11                   [-1, 10]           1,290
          Softmax-12                   [-1, 10]               0
Total params: 1,199,882
Trainable params: 1,199,882
Non-trainable params: 0
--------

In [5]:
optimizer = torch.optim.Adam(teacher_model.parameters(), lr=0.0001)
teacher_loss_fn = nn.CrossEntropyLoss()

epochs = 5

time_s = lambda: time.time()
for ep in range(epochs):
    start_time = time_s()
    ep_loss = 0.
    for i, data in enumerate(train_dataset):
        imgs, labels = data
        
        optimizer.zero_grad()
        preds = teacher_model(imgs.to(device))
        loss = teacher_loss_fn(preds, labels.to(device))
        loss.backward()
        optimizer.step()
        
        ep_loss += loss.detach().item()
    print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.4f}, time: {(time_s() - start_time):.2f}s')

calc_accuracy(teacher_model)

torch.save(teacher_model.state_dict(), './models/teacher.pt')

epoch: 1, loss: 1.7583, time: 10.49s
epoch: 2, loss: 1.5801, time: 10.93s
epoch: 3, loss: 1.5544, time: 11.07s
epoch: 4, loss: 1.5379, time: 10.62s
epoch: 5, loss: 1.5270, time: 10.50s
test accuracy: 94.460%


In [6]:
# Load teacher model after training
teacher_model.load_state_dict(torch.load('./models/teacher.pt', map_location=device))

<All keys matched successfully>

# Knowledge Distillation

In [7]:
# Softmax with temperature
# -- Adapted from PyTorch Softmax layer
# -- See: https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#Softmax
class SoftmaxT(nn.Module):
    def __init__(self, temperature, dim = 1) -> None:
        super(SoftmaxT, self).__init__()
        self.temperature = temperature
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input):
        return torch.nn.functional.softmax(input / self.temperature, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)

In [8]:
alpha = 0.1
temperature = 5

# Create a new model with the last layer removed, provides access to model logits
teacher_model_w_temperature = torch.nn.Sequential(
    *(list(teacher_model.children())[:-1]),
    SoftmaxT(temperature)
)
teacher_model_w_temperature.to(device)
print(summary(teacher_model_w_temperature, (1, 28, 28)))

# Create the student model
student_model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=(3, 3)),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(2, 2)),
    nn.Conv2d(16, 32, kernel_size=(3, 3)),
    nn.Flatten(),
    nn.Linear(3872, 10),
    SoftmaxT(temperature)
)
student_model.to(device)
print(summary(student_model, (1, 28, 28)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 26, 26]             320
              ReLU-2           [-1, 32, 26, 26]               0
            Conv2d-3           [-1, 64, 24, 24]          18,496
              ReLU-4           [-1, 64, 24, 24]               0
         MaxPool2d-5           [-1, 64, 12, 12]               0
         Dropout2d-6           [-1, 64, 12, 12]               0
           Flatten-7                 [-1, 9216]               0
            Linear-8                  [-1, 128]       1,179,776
              ReLU-9                  [-1, 128]               0
          Dropout-10                  [-1, 128]               0
           Linear-11                   [-1, 10]           1,290
         SoftmaxT-12                   [-1, 10]               0
Total params: 1,199,882
Trainable params: 1,199,882
Non-trainable params: 0
---------------------------

In [13]:
# KD from teacher using whole dataset
student_model.apply(weight_reset)
student_optimizer = torch.optim.Adam(student_model.parameters(), lr=0.0001)
student_loss_fn = nn.CrossEntropyLoss()
distillation_loss_fn = torch.nn.KLDivLoss()

kd_epochs = 3
time_s = lambda: time.time()
for ep in range(kd_epochs):
    ep_loss = 0.
    start_time = time_s()
    for i, data in enumerate(train_dataset):
        imgs, true_labels = data[0].to(device), data[1].to(device)

        student_optimizer.zero_grad()

        # Forward pass of the teacher with input
        with torch.no_grad():
            teacher_output = teacher_model_w_temperature(imgs).to(device)

        # Forward pass of the student
        student_output = student_model(imgs).to(device)

        # Calculate loss
        student_loss = student_loss_fn(student_output, true_labels)
        distill_loss = distillation_loss_fn(teacher_output, student_output)
        loss = alpha * student_loss + (1 - alpha) * distill_loss

        loss.backward()
        student_optimizer.step()
        
        ep_loss += loss.detach().item()
    print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.2e}, time: {(time_s() - start_time):.2f}s')

calc_accuracy(student_model)



epoch: 1, loss: 5.22e-03, time: 9.17s
epoch: 2, loss: 1.24e-03, time: 9.21s
epoch: 3, loss: 1.58e-04, time: 9.32s
test accuracy: 92.000%


In [9]:
# KD from teacher using only ~3% of original dataset
student_model.apply(weight_reset)
student_optimizer = torch.optim.Adam(student_model.parameters(), lr=0.0001)
student_loss_fn = nn.CrossEntropyLoss()
distillation_loss_fn = torch.nn.KLDivLoss()

kd_epochs = 3
time_s = lambda: time.time()
for ep in range(kd_epochs):
    ep_loss = 0.
    start_time = time_s()
    for i, data in enumerate(small_train_dataset):
        imgs, true_labels = data[0].to(device), data[1].to(device)

        student_optimizer.zero_grad()

        # Forward pass of the teacher with input
        with torch.no_grad():
            teacher_output = teacher_model_w_temperature(imgs).to(device)

        # Forward pass of the student
        student_output = student_model(imgs).to(device)

        # Calculate loss
        student_loss = student_loss_fn(student_output, true_labels)
        distill_loss = distillation_loss_fn(teacher_output, student_output)
        loss = alpha * student_loss + (1 - alpha) * distill_loss

        loss.backward()
        student_optimizer.step()
        
        ep_loss += loss.detach().item()
    print(f'epoch: {ep+1}, loss: {ep_loss / len(small_train_dataset):.2e}, time: {(time_s() - start_time):.2f}s')

calc_accuracy(student_model)



epoch: 1, loss: 1.01e-02, time: 0.64s
epoch: 2, loss: 3.70e-03, time: 0.64s
epoch: 3, loss: 2.03e-03, time: 0.65s
test accuracy: 85.300%


# Training student network from scratch

In [10]:
# Create a new model with the last layer removed, provides access to model logits
student_model_wo_temperature = torch.nn.Sequential(
    *(list(student_model.children())[:-1]),
    nn.Softmax(dim=1)
)
student_model_wo_temperature.to(device)
print(summary(student_model_wo_temperature, (1, 28, 28)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 26, 26]             160
              ReLU-2           [-1, 16, 26, 26]               0
         MaxPool2d-3           [-1, 16, 13, 13]               0
            Conv2d-4           [-1, 32, 11, 11]           4,640
           Flatten-5                 [-1, 3872]               0
            Linear-6                   [-1, 10]          38,730
           Softmax-7                   [-1, 10]               0
Total params: 43,530
Trainable params: 43,530
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.24
Params size (MB): 0.17
Estimated Total Size (MB): 0.41
----------------------------------------------------------------
None


In [11]:
# Using whole dataset
epochs = 3
student_wo_temp_optimizer = torch.optim.Adam(student_model_wo_temperature.parameters(), lr=0.0001)
student_wo_temp_loss_fn = nn.CrossEntropyLoss()

time_s = lambda: time.time()
for ep in range(epochs):
    start_time = time_s()
    ep_loss = 0.
    for i, data in enumerate(train_dataset):
        imgs, labels = data
        
        student_wo_temp_optimizer.zero_grad()
        preds = student_model_wo_temperature(imgs.to(device))
        loss = student_wo_temp_loss_fn(preds, labels.to(device))
        loss.backward()
        student_wo_temp_optimizer.step()
        
        ep_loss += loss.detach().item()
    print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.4f}, time: {(time_s() - start_time):.2f}s')

calc_accuracy(student_model_wo_temperature)

torch.save(student_model_wo_temperature.state_dict(), './models/student_wo_temp.pt')

epoch: 1, loss: 1.5971, time: 8.09s
epoch: 2, loss: 1.5595, time: 7.98s
epoch: 3, loss: 1.5448, time: 7.93s
test accuracy: 93.000%


In [12]:
# Using ~3% dataset
# Note: Need to reinit the student_wo_temp model first!
student_model_wo_temperature.apply(weight_reset)
epochs = 3
student_wo_temp_optimizer = torch.optim.Adam(student_model_wo_temperature.parameters(), lr=0.0001)
student_wo_temp_loss_fn = nn.CrossEntropyLoss()

time_s = lambda: time.time()
for ep in range(epochs):
    start_time = time_s()
    ep_loss = 0.
    for i, data in enumerate(small_train_dataset):
        imgs, labels = data
        
        student_wo_temp_optimizer.zero_grad()
        preds = student_model_wo_temperature(imgs.to(device))
        loss = student_wo_temp_loss_fn(preds, labels.to(device))
        loss.backward()
        student_wo_temp_optimizer.step()
        
        ep_loss += loss.detach().item() / batch_size

    print(f'epoch: {ep+1}, loss: {ep_loss / len(small_train_dataset):.4f}, time: {(time_s() - start_time):.2f}s')

calc_accuracy(student_model_wo_temperature)

torch.save(student_model_wo_temperature.state_dict(), './models/student_wo_temp.pt')

epoch: 1, loss: 0.0082, time: 0.55s
epoch: 2, loss: 0.0068, time: 0.53s
epoch: 3, loss: 0.0064, time: 0.53s
test accuracy: 83.510%
