In [41]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

In [42]:
BATCH_SIZE = 256
LR = 1e-3
EPOCHS = 3
DEVICE = torch.device('mps')

In [43]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x if (x.ndim == 3) and (x.shape[0] == 3) else x.repeat(3, 1, 1)),
    transforms.Resize((32, 32), interpolation=transforms.InterpolationMode.NEAREST_EXACT),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset_train = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataset_test = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)


train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True)

In [44]:
class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [45]:
class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(3 * 32 * 32, 500)
        self.fc2 = nn.Linear(500, 100)
        self.fc3 = nn.Linear(100, 10)
    
    def forward(self, x):
        x = self.flat(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

In [46]:
model = CNN().to(DEVICE)
optimiser = torch.optim.Adam(model.parameters(), lr=LR)
loss = nn.CrossEntropyLoss()
prog_bar = tqdm(total=EPOCHS * len(train_loader))

for epoch in range(EPOCHS):
    for step, (x, y_hat) in enumerate(train_loader):
        x, y_hat = x.to(DEVICE), y_hat.to(DEVICE)
        y = model(x)
        l = loss(y, y_hat)
        optimiser.zero_grad()
        l.backward()
        optimiser.step()
        prog_bar.update(1)
        if step % 100 == 0:
            prog_bar.set_description(f'epoch: {epoch} loss: {l.detach().cpu().item()}')
    

  0%|          | 0/705 [00:00<?, ?it/s]

In [47]:
accuracies = []
for x, y_hat in test_loader:
    x, y_hat = x.to(DEVICE), y_hat.to(DEVICE)
    with torch.no_grad():
        y = model(x)
    pred = torch.argmax(y, dim=-1)
    accuracy = (y_hat == pred).to(torch.float).mean().cpu().item()
    accuracies.append(accuracy)
print(f'Accuracy: {sum(accuracies)/len(accuracies)}')

Accuracy: 0.982421875


## Train student model

In [55]:
class SmallMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(3 * 32 * 32, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 10)
    
    def forward(self, x):
        x = self.flat(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x
    
student = SmallMLP().to(DEVICE)

In [56]:
TEMPERATURE = 10
ALPHA = 0.5

In [57]:
optimiser = torch.optim.Adam(student.parameters(), lr=LR)
loss_ce = nn.CrossEntropyLoss()
loss_kld = nn.KLDivLoss()
prog_bar = tqdm(total=EPOCHS * len(train_loader))

for epoch in range(EPOCHS):
    for step, (x, y_hat) in enumerate(train_loader):
        x, y_hat = x.to(DEVICE), y_hat.to(DEVICE)

        with torch.no_grad():
            teacher_logits = model(x)
            teacher_dist = F.softmax(teacher_logits / TEMPERATURE)

        student_logits = student(x)
        student_dist = F.softmax(student_logits / TEMPERATURE)

        l = ALPHA*loss_ce(student_logits, y_hat) + (1-ALPHA)*loss_kld(student_dist, teacher_dist)

        optimiser.zero_grad()
        l.backward()
        optimiser.step()
        prog_bar.update(1)
        if step % 100 == 0:
            prog_bar.set_description(f'epoch: {epoch} loss: {l.detach().cpu().item()}')
    

  0%|          | 0/705 [00:00<?, ?it/s]

  teacher_dist = F.softmax(teacher_logits / TEMPERATURE)
  student_dist = F.softmax(student_logits / TEMPERATURE)


In [58]:
accuracies = []
for x, y_hat in test_loader:
    x, y_hat = x.to(DEVICE), y_hat.to(DEVICE)
    with torch.no_grad():
        y = student(x)
    pred = torch.argmax(y, dim=-1)
    accuracy = (y_hat == pred).to(torch.float).mean().cpu().item()
    accuracies.append(accuracy)
print(f'Accuracy: {sum(accuracies)/len(accuracies)}')

Accuracy: 0.86103515625


In [48]:
def n_params(model):
    n_params = 0
    for param in model.parameters():
        n_params += param.nelement()
    return n_params

In [54]:
print(n_params(model))
print(n_params(student))

62006
30950
