In [25]:
import torch
import torch.nn.functional as F 
import torchvision.datasets as datasets 
import torchvision.transforms as transforms 
from torch import optim
from torch import nn
from torch.utils.data import DataLoader 
from tqdm import tqdm 
import numpy as np
import matplotlib.pyplot as plt

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

in_channels = 1
num_classes = 10
learning_rate = 3e-4
batch_size = 64
num_epochs = 15

In [27]:
train_dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root="dataset/", train=False, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

In [28]:
class Teacher(nn.Module):
    def __init__(self, num_classes):
        super(Teacher, self).__init__()
        self.conv1 = nn.Conv2d(1, 256, 3, 1, 1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(256, 512, 3, 1, 1)
        self.fc = nn.Linear(512*7*7, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool(x)

        x = x.view(-1, 512*7*7)
        x = self.fc(x)
        return x

    
class Student(nn.Module):
    def __init__(self, num_classes):
        super(Student, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, 3, 1, 1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 8, 3, 1, 1)
        self.fc = nn.Linear(8*7*7, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool(x)
        
        x = x.view(-1, 8*7*7)
        x = self.fc(x)
        return x

In [29]:
def train_without_kd(model, optimizer, num_epochs):
    for epoch in range(num_epochs):
        loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
        for batch_idx, (data, targets) in loop:
            data = data.to(device=device)
            targets = targets.to(device=device)
            logits = model(data)
            loss = nn.CrossEntropyLoss()(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")

In [30]:
def train_with_kd(model, optimizer, num_epochs, T=1):
    for epoch in range(num_epochs):
        loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
        for batch_idx, (data, targets) in loop:
            data = data.to(device=device)
            targets = targets.to(device=device)
            logits = model(data)
            loss = 0.85*nn.CrossEntropyLoss()(logits, targets)+0.15*(T**2)*nn.KLDivLoss(reduction="batchmean")(torch.softmax(logits/T, 1), get_probs(data, teacher))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")

In [31]:
def get_accuracy(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            logits = model(x)
            _, predictions = logits.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
    model.train()
    return num_correct/num_samples

In [32]:
def get_probs(batch_features, model, T=1):
    model.eval()
    with torch.no_grad():
        x = batch_features.to(device=device)
        probs = torch.softmax(model(x)/T, 1)
    model.train()
    return probs

In [33]:
teacher = Teacher(num_classes).to(device)
torch.manual_seed(42)
student1 = Student(num_classes).to(device)
torch.manual_seed(42)
student2 = Student(num_classes).to(device)

teacher_optimizer = optim.Adam(teacher.parameters(), lr=learning_rate)
student1_optimizer = optim.Adam(student1.parameters(), lr=learning_rate)
student2_optimizer = optim.Adam(student2.parameters(), lr=learning_rate)

In [34]:
print("Norm of 1st layer of student 1:", torch.norm(student1.conv1.weight).item())
print("Norm of 1st layer of student 2:", torch.norm(student1.conv1.weight).item())

Norm of 1st layer of student 1: 1.6051220893859863
Norm of 1st layer of student 2: 1.6051220893859863


In [35]:
print("Teacher's Parameters: ", sum(p.numel() for p in teacher.parameters()))
print("Student's Parameters: ", sum(p.numel() for p in student1.parameters()), sum(p.numel() for p in student2.parameters()))

Teacher's Parameters:  1433610
Student's Parameters:  4594 4594


In [36]:
train_without_kd(teacher, teacher_optimizer, num_epochs)
train_without_kd(student1, student1_optimizer, num_epochs)
train_with_kd(student2, student2_optimizer, num_epochs)

Epoch [13/15]:  28%|██▊       | 262/938 [00:02<00:06, 101.36it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [37]:
print("__________________________________________________Teacher Performance______________________________________")
print(f"Accuracy on training set: {get_accuracy(train_loader, teacher)*100:.2f}")
print(f"Accuracy on test set: {get_accuracy(test_loader, teacher)*100:.2f}")
print("__________________________________________________Student Training Without Knowledge Distillation Performance______________________________________")
print(f"Accuracy on training set: {get_accuracy(train_loader, student1)*100:.2f}")
print(f"Accuracy on test set: {get_accuracy(test_loader, student1)*100:.2f}")
print("__________________________________________________Student Training With Knowledge Distillation Performance______________________________________")
print(f"Accuracy on training set: {get_accuracy(train_loader, student2)*100:.2f}")
print(f"Accuracy on test set: {get_accuracy(test_loader, student2)*100:.2f}")


__________________________________________________Teacher Performance______________________________________
Accuracy on training set: 99.80
Accuracy on test set: 99.09
__________________________________________________Student Training Without Knowledge Distillation Performance______________________________________
Accuracy on training set: 98.22
Accuracy on test set: 98.28
__________________________________________________Student Training With Knowledge Distillation Performance______________________________________
Accuracy on training set: 98.40
Accuracy on test set: 98.44
