In [None]:
%config IPCompleter.greedy=True

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchsummary import summary

from torch.optim import Adam

import numpy as np
import time
import torchvision.models as models

from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [None]:
cifar_class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

train_transforms = [
    torchvision.transforms.Pad(4, padding_mode='reflect'),
    torchvision.transforms.RandomCrop(32),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
]

test_transforms = [
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
]

cifar_train = torchvision.datasets.CIFAR10('./datasets', train=True, download=True,
  transform=torchvision.transforms.Compose(train_transforms))

cifar_test = torchvision.datasets.CIFAR10('./datasets', train=False, download=True,
  transform=torchvision.transforms.Compose(test_transforms))

batch_size = 256
train_dataset = torch.utils.data.DataLoader(cifar_train, batch_size=batch_size, shuffle=True, num_workers=2)
test_dataset = torch.utils.data.DataLoader(cifar_test, batch_size=batch_size, shuffle=True, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
def weight_reset(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.reset_parameters()

def calc_accuracy(model):
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_dataset):
            imgs, true_labels = data[0].to(device), data[1].to(device)
            preds = model(imgs)
            pred_class = torch.argmax(preds, dim=1)
            incorrect = torch.count_nonzero(pred_class - true_labels)
            correct += len(true_labels) - incorrect

    print(f'test accuracy: {(correct * 100) / len(test):.3f}%')

# ResNet-18 Model

In [None]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

teacher_model = ResNet(BasicBlock, [2, 2, 2, 2]).to(device)

In [None]:
teacher_model.load_state_dict(torch.load('./models/cifar_teacher_0.pt'))
#print(summary(teacher_model, (3, 32, 32)))
calc_accuracy(teacher_model)

test accuracy: 83.210%


In [None]:
optimizer = torch.optim.Adam(teacher_model.parameters(), lr=0.01)
teacher_loss_fn = nn.CrossEntropyLoss()

epochs = 20

time_s = lambda: time.time()
for ep in range(epochs):
    start_time = time_s()
    ep_loss = 0.0
    correct = 0
    for i, (imgs, labels) in enumerate(train_dataset):
        optimizer.zero_grad()
        preds = teacher_model(imgs)
        loss = teacher_loss_fn(preds, labels)
        loss.backward()
        optimizer.step()
        
        pred_class = torch.argmax(preds, dim=1).to(device)
        correct += len(labels) - torch.count_nonzero(pred_class - labels)
        
        ep_loss += loss.detach().item()

    print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.4f}, train acc: {(correct * 100.0) / len(cifar_train):.3f}%, time: {(time_s() - start_time):.2f}s')

calc_accuracy(teacher_model)

torch.save(teacher_model.state_dict(), './models/cifar_teacher.pt')

# Training loop

In [None]:
def kd_train(dataset, teacher, student, lr=0.001, student_optimizer = None, lr_scheduler = None):
    student.apply(weight_reset)
    student_optimizer = torch.optim.Adam(student.parameters(), lr=0.001)
    student_lr_sch = torch.optim.lr_scheduler.MultiStepLR(student_optimizer, [10, 25, 40], gamma=0.5)
    student_loss_fn = nn.CrossEntropyLoss()
    distillation_loss_fn = torch.nn.KLDivLoss(reduction='mean')

    kd_epochs = 30
    time_s = lambda: time.time()
    for ep in range(kd_epochs):
        ep_loss = 0.0
        correct = 0
        start_time = time_s()
        for i, (imgs, labels) in enumerate(train_dataset):
            imgs, labels = imgs.to(device), labels.to(device)
            student_optimizer.zero_grad()

            # Forward pass of the teacher with input
            with torch.no_grad():
                teacher_output = teacher(imgs).to(device)

            # Forward pass of the student
            student_output = student(imgs).to(device)

            # Calculate loss
            student_loss = student_loss_fn(student_output, labels)
            distill_loss = distillation_loss_fn(teacher_output, student_output)
            loss = (alpha * student_loss + (1 - alpha) * distill_loss) * temperature * temperature

            loss.backward()
            student_optimizer.step()

            pred_class = torch.argmax(student_output, dim=1).to(device)
            correct += len(labels) - torch.count_nonzero(pred_class - labels)

            ep_loss += loss.detach().item()
        student_lr_sch.step()
        print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.2e}, train acc: {(correct * 100.0) / len(cifar_train):.3f}%, time: {(time_s() - start_time):.2f}s')

        if (ep + 1) % 10 == 0:
            calc_accuracy(student_model)
    calc_accuracy(student_model)

    torch.save(student_model.state_dict(), './models/cifar_student.pt')

# Building Teacher and Student models

In [None]:
# Softmax with temperature
# -- Adapted from PyTorch Softmax layer
# -- See: https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#Softmax
class SoftmaxT(nn.Module):
    def __init__(self, temperature, dim = 1) -> None:
        super(SoftmaxT, self).__init__()
        self.temperature = temperature
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input):
        return torch.nn.functional.softmax(input / self.temperature, self.dim)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)

In [None]:
alpha = 0.1
temperature = 4

# Create a new model with softmax temperature
teacher_model_w_temperature = torch.nn.Sequential(
    teacher_model,
    SoftmaxT(temperature)
).to(device)
#print(summary(teacher_model_w_temperature, (3, 32, 32)))
print(calc_accuracy(teacher_model_w_temperature))

# Create the student model
student_model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=(3, 3)),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(2, 2)),
    nn.Conv2d(32, 64, kernel_size=(3, 3)),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(2, 2)),
    nn.Flatten(),
    nn.Linear(2304, 128),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(64, 10),
    SoftmaxT(temperature)
).to(device)
print(summary(student_model, (3, 32, 32)))

test accuracy: 83.220%
None
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 30, 30]             896
       BatchNorm2d-2           [-1, 32, 30, 30]              64
              ReLU-3           [-1, 32, 30, 30]               0
         MaxPool2d-4           [-1, 32, 15, 15]               0
            Conv2d-5           [-1, 64, 13, 13]          18,496
       BatchNorm2d-6           [-1, 64, 13, 13]             128
              ReLU-7           [-1, 64, 13, 13]               0
         MaxPool2d-8             [-1, 64, 6, 6]               0
           Flatten-9                 [-1, 2304]               0
           Linear-10                  [-1, 128]         295,040
             ReLU-11                  [-1, 128]               0
          Dropout-12                  [-1, 128]               0
           Linear-13                   [-1, 64]           8,256
           

In [None]:
# KD from teacher using whole dataset
student_model.apply(weight_reset)
student_optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)
student_lr_sch = torch.optim.lr_scheduler.MultiStepLR(student_optimizer, [10, 25, 40], gamma=0.5)
student_loss_fn = nn.CrossEntropyLoss()
distillation_loss_fn = torch.nn.KLDivLoss(reduction='mean')

kd_epochs = 30
time_s = lambda: time.time()
for ep in range(kd_epochs):
    ep_loss = 0.0
    correct = 0
    start_time = time_s()
    for i, (imgs, labels) in enumerate(train_dataset):
        imgs, labels = imgs.to(device), labels.to(device)
        student_optimizer.zero_grad()

        # Forward pass of the teacher with input
        with torch.no_grad():
            teacher_output = teacher_model_w_temperature(imgs).to(device)

        # Forward pass of the student
        student_output = student_model(imgs).to(device)

        # Calculate loss
        student_loss = student_loss_fn(student_output, labels)
        distill_loss = distillation_loss_fn(teacher_output, student_output)
        loss = (alpha * student_loss + (1 - alpha) * distill_loss) * temperature * temperature

        loss.backward()
        student_optimizer.step()

        pred_class = torch.argmax(student_output, dim=1).to(device)
        correct += len(labels) - torch.count_nonzero(pred_class - labels)

        ep_loss += loss.detach().item()
    student_lr_sch.step()
    print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.2e}, train acc: {(correct * 100.0) / len(cifar_train):.3f}%, time: {(time_s() - start_time):.2f}s')

    if (ep + 1) % 10 == 0:
        calc_accuracy(student_model)

calc_accuracy(student_model)

torch.save(student_model.state_dict(), './models/cifar_student.pt')



epoch: 1, loss: 1.58e-01, train acc: 53.153%, time: 26.22s
epoch: 2, loss: 1.12e-01, train acc: 67.860%, time: 26.54s
epoch: 3, loss: 9.48e-02, train acc: 73.130%, time: 27.23s
epoch: 4, loss: 8.47e-02, train acc: 75.842%, time: 26.85s
epoch: 5, loss: 7.74e-02, train acc: 77.908%, time: 26.94s
epoch: 6, loss: 7.17e-02, train acc: 79.855%, time: 27.37s
epoch: 7, loss: 6.71e-02, train acc: 80.795%, time: 26.89s
epoch: 8, loss: 6.30e-02, train acc: 81.985%, time: 26.73s
epoch: 9, loss: 5.95e-02, train acc: 82.932%, time: 26.93s
epoch: 10, loss: 5.59e-02, train acc: 83.812%, time: 26.72s
test accuracy: 71.220%
epoch: 11, loss: 4.69e-02, train acc: 85.825%, time: 26.97s
epoch: 12, loss: 4.46e-02, train acc: 86.725%, time: 27.11s
epoch: 13, loss: 4.36e-02, train acc: 86.950%, time: 27.01s
epoch: 14, loss: 4.25e-02, train acc: 87.147%, time: 27.09s
epoch: 15, loss: 4.09e-02, train acc: 87.592%, time: 27.04s
epoch: 16, loss: 3.86e-02, train acc: 88.170%, time: 27.03s
epoch: 17, loss: 3.85e-02,

# Training Student model from scratch

In [None]:
# Create a new model with the last layer removed, provides access to model logits
student_model_wo_temperature = torch.nn.Sequential(
    *(list(student_model.children())[:-1]),
    nn.Softmax(dim=1)
).to(device)
print(summary(student_model_wo_temperature, (3, 32, 32)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 30, 30]             896
       BatchNorm2d-2           [-1, 32, 30, 30]              64
              ReLU-3           [-1, 32, 30, 30]               0
         MaxPool2d-4           [-1, 32, 15, 15]               0
            Conv2d-5           [-1, 64, 13, 13]          18,496
       BatchNorm2d-6           [-1, 64, 13, 13]             128
              ReLU-7           [-1, 64, 13, 13]               0
         MaxPool2d-8             [-1, 64, 6, 6]               0
           Flatten-9                 [-1, 2304]               0
           Linear-10                  [-1, 128]         295,040
             ReLU-11                  [-1, 128]               0
          Dropout-12                  [-1, 128]               0
           Linear-13                   [-1, 64]           8,256
             ReLU-14                   

In [None]:
# Training student from scratch using whole dataset
student_model_wo_temperature.apply(weight_reset)
student_optimizer = torch.optim.Adam(student_model_wo_temperature.parameters(), lr=0.001)
student_loss_fn = nn.CrossEntropyLoss()
student_lr_sch = torch.optim.lr_scheduler.MultiStepLR(student_optimizer, [10, 25, 40], gamma=0.5)

kd_epochs = 30
time_s = lambda: time.time()
for ep in range(kd_epochs):
    ep_loss = 0.0
    correct = 0
    start_time = time_s()
    for i, data in enumerate(train_dataset):
        imgs, true_labels = data[0].to(device), data[1].to(device)

        student_optimizer.zero_grad()

        # Forward pass of the student
        student_output = student_model_wo_temperature(imgs).to(device)

        # Calculate loss
        student_loss = student_loss_fn(student_output, true_labels)

        student_loss.backward()
        student_optimizer.step()

        pred_class = torch.argmax(student_output, dim=1).to(device)
        correct += len(true_labels) - torch.count_nonzero(pred_class - true_labels)
        
        ep_loss += student_loss.detach().item()

    student_lr_sch.step()
    print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.2e}, train acc: {(correct * 100.0) / len(cifar_train):.3f}%, time: {(time_s() - start_time):.2f}s')

    if (ep + 1) % 10 == 0:
        calc_accuracy(student_model_wo_temperature)

calc_accuracy(student_model_wo_temperature)

torch.save(student_model_wo_temperature.state_dict(), './models/cifar_student_scratch.pt')

epoch: 1, loss: 2.09e+00, train acc: 46.627%, time: 23.87s
epoch: 2, loss: 1.97e+00, train acc: 61.412%, time: 23.57s
epoch: 3, loss: 1.93e+00, train acc: 66.478%, time: 23.73s
epoch: 4, loss: 1.90e+00, train acc: 70.132%, time: 23.86s
epoch: 5, loss: 1.89e+00, train acc: 71.497%, time: 23.58s
epoch: 6, loss: 1.87e+00, train acc: 73.070%, time: 23.59s
epoch: 7, loss: 1.86e+00, train acc: 74.332%, time: 23.82s
epoch: 8, loss: 1.85e+00, train acc: 75.885%, time: 23.79s
epoch: 9, loss: 1.85e+00, train acc: 76.777%, time: 23.73s
epoch: 10, loss: 1.83e+00, train acc: 78.232%, time: 23.87s
test accuracy: 65.430%
epoch: 11, loss: 1.81e+00, train acc: 80.832%, time: 23.79s
epoch: 12, loss: 1.81e+00, train acc: 81.892%, time: 23.79s
epoch: 13, loss: 1.80e+00, train acc: 82.397%, time: 23.76s
epoch: 14, loss: 1.80e+00, train acc: 82.780%, time: 23.78s
epoch: 15, loss: 1.79e+00, train acc: 83.277%, time: 23.65s
epoch: 16, loss: 1.79e+00, train acc: 83.517%, time: 23.57s
epoch: 17, loss: 1.79e+00,

# KD with ~3% of dataset

In [None]:
small_cifar_train_subset = list(range(0, int(len(cifar_train)*0.03)))
small_cifar_train_dataset = torch.utils.data.DataLoader(torch.utils.data.Subset(cifar_train, small_cifar_train_subset), batch_size=16, shuffle=True)

In [None]:
# KD from teacher using ~3% of dataset
student_model.apply(weight_reset)
student_optimizer = torch.optim.Adam(student_model.parameters(), lr=0.003)
student_lr_sch = torch.optim.lr_scheduler.MultiStepLR(student_optimizer, [10, 25, 40], gamma=0.3)
student_loss_fn = nn.CrossEntropyLoss()
distillation_loss_fn = torch.nn.KLDivLoss(reduction='mean')

kd_epochs = 30
time_s = lambda: time.time()
for ep in range(kd_epochs):
    ep_loss = 0.0
    correct = 0
    start_time = time_s()
    for i, (imgs, labels) in enumerate(small_cifar_train_dataset):
        imgs, labels = imgs.to(device), labels.to(device)
        student_optimizer.zero_grad()

        # Forward pass of the teacher with input
        with torch.no_grad():
            teacher_output = teacher_model_w_temperature(imgs).to(device)

        # Forward pass of the student
        student_output = student_model(imgs).to(device)

        # Calculate loss
        student_loss = student_loss_fn(student_output, labels)
        distill_loss = distillation_loss_fn(teacher_output, student_output)
        loss = (alpha * student_loss + (1 - alpha) * distill_loss) * temperature * temperature

        loss.backward()
        student_optimizer.step()

        pred_class = torch.argmax(student_output, dim=1).to(device)
        correct += len(labels) - torch.count_nonzero(pred_class - labels)

        ep_loss += loss.detach().item()
    student_lr_sch.step()
    print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.2e}, train acc: {(correct * 100.0) / len(small_cifar_train_subset):.3f}%, time: {(time_s() - start_time):.2f}s')

    if (ep + 1) % 10 == 0:
        calc_accuracy(student_model)

calc_accuracy(student_model)

torch.save(student_model.state_dict(), './models/small_cifar_student.pt')



epoch: 1, loss: 1.00e-01, train acc: 22.667%, time: 1.75s
epoch: 2, loss: 9.63e-02, train acc: 27.333%, time: 1.67s
epoch: 3, loss: 9.14e-02, train acc: 34.000%, time: 1.65s
epoch: 4, loss: 8.72e-02, train acc: 37.533%, time: 1.68s
epoch: 5, loss: 8.42e-02, train acc: 39.733%, time: 1.65s
epoch: 6, loss: 8.42e-02, train acc: 40.267%, time: 1.62s
epoch: 7, loss: 8.20e-02, train acc: 41.733%, time: 1.68s
epoch: 8, loss: 7.88e-02, train acc: 43.267%, time: 1.66s
epoch: 9, loss: 7.68e-02, train acc: 45.267%, time: 1.62s
epoch: 10, loss: 7.56e-02, train acc: 45.600%, time: 1.68s
test accuracy: 45.550%
epoch: 11, loss: 6.94e-02, train acc: 47.067%, time: 1.72s
epoch: 12, loss: 6.40e-02, train acc: 51.333%, time: 1.65s
epoch: 13, loss: 6.15e-02, train acc: 52.333%, time: 1.66s
epoch: 14, loss: 6.21e-02, train acc: 52.133%, time: 1.63s
epoch: 15, loss: 5.90e-02, train acc: 54.933%, time: 1.65s
epoch: 16, loss: 5.86e-02, train acc: 52.533%, time: 1.64s
epoch: 17, loss: 5.72e-02, train acc: 54.7

In [None]:
# Training student from scratch using small dataset
student_model_wo_temperature.apply(weight_reset)
student_optimizer = torch.optim.Adam(student_model.parameters(), lr=0.003)
student_lr_sch = torch.optim.lr_scheduler.MultiStepLR(student_optimizer, [10, 25, 40], gamma=0.3)
student_loss_fn = nn.CrossEntropyLoss()

kd_epochs = 30
time_s = lambda: time.time()
for ep in range(kd_epochs):
    ep_loss = 0.0
    correct = 0
    start_time = time_s()
    for i, data in enumerate(small_cifar_train_dataset):
        imgs, true_labels = data[0].to(device), data[1].to(device)

        student_optimizer.zero_grad()

        # Forward pass of the student
        student_output = student_model_wo_temperature(imgs).to(device)

        # Calculate loss
        student_loss = student_loss_fn(student_output, true_labels)

        student_loss.backward()
        student_optimizer.step()

        pred_class = torch.argmax(student_output, dim=1).to(device)
        correct += len(true_labels) - torch.count_nonzero(pred_class - true_labels)
        
        ep_loss += student_loss.detach().item()

    student_lr_sch.step()
    print(f'epoch: {ep+1}, loss: {ep_loss / len(train_dataset):.2e}, train acc: {(correct * 100.0) / len(small_cifar_train_subset):.3f}%, time: {(time_s() - start_time):.2f}s')

    if (ep + 1) % 10 == 0:
        calc_accuracy(student_model_wo_temperature)

calc_accuracy(student_model_wo_temperature)

torch.save(student_model_wo_temperature.state_dict(), './models/small_cifar_student_scratch.pt')

epoch: 1, loss: 1.08e+00, train acc: 19.067%, time: 1.30s
epoch: 2, loss: 1.06e+00, train acc: 23.000%, time: 1.28s
epoch: 3, loss: 1.05e+00, train acc: 25.067%, time: 1.28s
epoch: 4, loss: 1.05e+00, train acc: 26.267%, time: 1.27s
epoch: 5, loss: 1.03e+00, train acc: 31.467%, time: 1.28s
epoch: 6, loss: 1.03e+00, train acc: 30.067%, time: 1.34s
epoch: 7, loss: 1.04e+00, train acc: 29.000%, time: 1.27s
epoch: 8, loss: 1.04e+00, train acc: 29.467%, time: 1.31s
epoch: 9, loss: 1.03e+00, train acc: 30.067%, time: 1.30s
epoch: 10, loss: 1.03e+00, train acc: 31.400%, time: 1.30s
test accuracy: 33.560%
epoch: 11, loss: 1.01e+00, train acc: 35.400%, time: 1.32s
epoch: 12, loss: 1.01e+00, train acc: 35.733%, time: 1.34s
epoch: 13, loss: 1.00e+00, train acc: 36.000%, time: 1.30s
epoch: 14, loss: 1.00e+00, train acc: 36.533%, time: 1.29s
epoch: 15, loss: 9.93e-01, train acc: 38.533%, time: 1.28s
epoch: 16, loss: 9.92e-01, train acc: 39.000%, time: 1.30s
epoch: 17, loss: 9.83e-01, train acc: 40.8

# Linearly increase α in KD from teacher

In [None]:
# KD from teacher using whole dataset
alpha = 0.1
student_model.apply(weight_reset)
student_optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)
student_lr_sch = torch.optim.lr_scheduler.MultiStepLR(student_optimizer, [10, 25, 40], gamma=0.5)
student_loss_fn = nn.CrossEntropyLoss()
distillation_loss_fn = torch.nn.KLDivLoss(reduction='mean')

kd_epochs = 30
time_s = lambda: time.time()
for ep in range(kd_epochs):
    ep_loss = 0.0
    correct = 0
    start_time = time_s()
    alpha = max(0.1, 0.5 * (ep / kd_epochs))
    for i, (imgs, labels) in enumerate(train_dataset):
        imgs, labels = imgs.to(device), labels.to(device)
        student_optimizer.zero_grad()

        # Forward pass of the teacher with input
        with torch.no_grad():
            teacher_output = teacher_model_w_temperature(imgs).to(device)

        # Forward pass of the student
        student_output = student_model(imgs).to(device)

        # Calculate loss
        student_loss = student_loss_fn(student_output, labels)
        distill_loss = distillation_loss_fn(teacher_output, student_output)
        loss = (alpha * student_loss + (1 - alpha) * distill_loss) * temperature * temperature

        loss.backward()
        student_optimizer.step()

        pred_class = torch.argmax(student_output, dim=1).to(device)
        correct += len(labels) - torch.count_nonzero(pred_class - labels)

        ep_loss += loss.detach().item()

    student_lr_sch.step()
    print(f'epoch: {ep+1}, alpha: {alpha:.3e}, loss: {ep_loss / len(train_dataset):.2e}, train acc: {(correct * 100.0) / len(cifar_train):.3f}%, time: {(time_s() - start_time):.2f}s')

    if (ep + 1) % 10 == 0:
        calc_accuracy(student_model)

calc_accuracy(student_model)

torch.save(student_model.state_dict(), './models/cifar_student_linear_alpha.pt')

# Looks like a schedule for increasing alpha doesn't have that much of an effect
# By choosing alpha carefully you can still get equivalent results to schedule

epoch: 1, alpha: 2.000e-01, loss: 4.12e+00, train acc: 41.376%, time: 26.10s
epoch: 2, alpha: 2.000e-01, loss: 3.97e+00, train acc: 54.336%, time: 26.15s
epoch: 3, alpha: 2.000e-01, loss: 3.91e+00, train acc: 58.612%, time: 26.07s
epoch: 4, alpha: 2.000e-01, loss: 3.88e+00, train acc: 61.132%, time: 26.00s
epoch: 5, alpha: 2.000e-01, loss: 3.85e+00, train acc: 62.912%, time: 26.09s
epoch: 6, alpha: 2.000e-01, loss: 3.83e+00, train acc: 64.582%, time: 26.03s
epoch: 7, alpha: 2.000e-01, loss: 3.81e+00, train acc: 65.692%, time: 26.23s
epoch: 8, alpha: 2.000e-01, loss: 3.80e+00, train acc: 66.420%, time: 26.03s
epoch: 9, alpha: 2.000e-01, loss: 3.79e+00, train acc: 67.308%, time: 25.82s
epoch: 10, alpha: 2.000e-01, loss: 3.78e+00, train acc: 67.666%, time: 25.88s
test accuracy: 71.630%
epoch: 11, alpha: 2.000e-01, loss: 3.75e+00, train acc: 69.904%, time: 26.09s
epoch: 12, alpha: 2.000e-01, loss: 3.75e+00, train acc: 69.962%, time: 26.28s
epoch: 13, alpha: 2.000e-01, loss: 3.74e+00, train