# Knowledge Distillation on simple Convolutional Neural Networks (CNNs)

## Library Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

## Dataset Loading
Dataset: CIFAR-10

In [2]:
# Pre-process the dataset by converting it to Tensor as well as normalize using mean and std
my_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Downloading the CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=my_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=my_transform)

# DataLoader ready to be used
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 43.3MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## Model Definitions

#### Teacher: A simple CNN with 4 Conv2D layers, 2 MaxPool2D, 1 Flatten, and 2 Linear layers

#### Student: A simple CNN with 2 Conv2D layers, 2 MaxPool2D, 1 Flatten, and 2 Linear layers

In [12]:
class TeacherModel(nn.Module):
  def __init__(self, num_classes=10):
    super(TeacherModel, self).__init__()
    self.features = nn.Sequential(
        nn.Conv2d(3, 128, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(128, 64, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(64, 64, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 32, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(32 * 8 * 8, 512),
        nn.ReLU(),
        nn.Linear(512, num_classes)
    )

  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x)
    return x


class StudentModel(nn.Module):
  def __init__(self, num_classes=10):
    super(StudentModel, self).__init__()
    self.features = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(16, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(16 * 8 * 8, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes)
    )

  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x)
    return x

## Individual Training
First, we train each teacher and student model to calculate their accuracy without using knowledge distillation

In [13]:
def train_individual(model, train_loader, epochs, lr, device):
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.CrossEntropyLoss()

  model.train() # train mode

  for epoch in range(epochs):
    running_loss = 0.0

    for inputs, labels in train_loader:
      inputs, labels = inputs.to(device), labels.to(device) # transfer to GPU

      optimizer.zero_grad() # reset gradient
      outputs = model(inputs) # generate response (probability distribution)
      loss = criterion(outputs, labels) # calculate loss
      loss.backward()
      optimizer.step()
      running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
  model.to(device) # transfer to GPU
  model.eval() # eval mode

  correct = 0
  total = 0

  with torch.no_grad(): # no gradient update
    for inputs, labels in test_loader:
      inputs, labels = inputs.to(device), labels.to(device)

      outputs = model(inputs) # generate response (probability distribution)
      _, predicted = torch.max(outputs.data, 1) # find the class with max probability
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  accuracy = 100 * correct / total
  print(f"Test Accuracy: {accuracy:.2f}%")
  return accuracy

In [14]:
# Model Creation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

teacher_model = TeacherModel().to(device)
student_model = StudentModel().to(device)

total_params_deep = "{:,}".format(sum(p.numel() for p in teacher_model.parameters()))
print(f"Teacher Model parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in student_model.parameters()))
print(f"Student Model parameters: {total_params_light}")

Teacher Model parameters: 1,186,986
Student Model parameters: 267,738


In [15]:
# Teacher Model Training and Testing
torch.manual_seed(42)
epoch = 10
lr = 0.001
train_individual(teacher_model, train_loader, epoch, lr, device)
test(teacher_model, test_loader, device)

Epoch 1/10, Loss: 1.3853567789887529
Epoch 2/10, Loss: 0.9018769300807162
Epoch 3/10, Loss: 0.6984720657701078
Epoch 4/10, Loss: 0.5410691216168806
Epoch 5/10, Loss: 0.4124810284818225
Epoch 6/10, Loss: 0.28769126050460064
Epoch 7/10, Loss: 0.1918012785827717
Epoch 8/10, Loss: 0.1417765292574835
Epoch 9/10, Loss: 0.10655551437107498
Epoch 10/10, Loss: 0.09177767301021177
Test Accuracy: 74.43%


74.43

In [16]:
# Student Model Training and Testing
torch.manual_seed(42)
epoch = 10
lr = 0.001
train_individual(student_model, train_loader, epoch, lr, device)
test(student_model, test_loader, device)

Epoch 1/10, Loss: 1.4606203031356988
Epoch 2/10, Loss: 1.1367454313866012
Epoch 3/10, Loss: 0.9992197162050116
Epoch 4/10, Loss: 0.9032106946801286
Epoch 5/10, Loss: 0.8217958803372005
Epoch 6/10, Loss: 0.7537446095967841
Epoch 7/10, Loss: 0.6937782188967976
Epoch 8/10, Loss: 0.6249525134673204
Epoch 9/10, Loss: 0.5771157086810188
Epoch 10/10, Loss: 0.522296984253637
Test Accuracy: 69.51%


69.51

## Results (no Knowledge Distillation)

### Teacher Model

*   Number of parameters: 1,186,986
*   Time to train: ~3 minutes
*   Train loss (last epoch): 0.09178
*   Test accuracy: **74.43%**

### Student Model

*   Number of parameters: 267,738
*   Time to train: ~2 minutes
*   Train loss (last epoch): 0.5223
*   Test accuracy: **69.51%**

## Train using Knowledge Distillation

Now, our goal is to use the predictions from the Teacher model (teacher logits) to improve the performance of our Student model

To do so, we use an additional loss function, called **KL-divergence**



In [17]:
def train_with_kd(teacher, student, train_loader, epochs, lr, device, alpha=0.5, T=10):
  ce_loss = nn.CrossEntropyLoss() # traditional cross-entropy loss
  optimizer = optim.Adam(student.parameters(), lr=lr)

  teacher.eval() # eval mode (no update)
  student.train() # train mode

  for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()
      with torch.no_grad(): # no need to calculate gradient for teacher
        teacher_logits = teacher(inputs)

      student_logits = student(inputs)
      ce_loss_value = ce_loss(student_logits, labels)

      soft_targets = F.softmax(teacher_logits / T, dim=-1)
      soft_prob = F.log_softmax(student_logits / T, dim=-1)

      kd_loss_value = F.kl_div(soft_prob, soft_targets.detach(), reduction='batchmean') * (T * T) # KL-divergence loss
      loss = ce_loss_value + kd_loss_value # combined loss (CE + KL)
      loss.backward()
      optimizer.step()
      running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

In [18]:
# Student Model Training with Knowledge Distillation and Testing

torch.manual_seed(42)
epoch = 10
lr = 0.001
train_with_kd(teacher_model, student_model, train_loader, 10, 0.001, device)
test(student_model, test_loader, device)

Epoch 1/10, Loss: 12.14585326943556
Epoch 2/10, Loss: 10.29619171552341
Epoch 3/10, Loss: 9.561673729011165
Epoch 4/10, Loss: 8.949190325139428
Epoch 5/10, Loss: 8.506584990664821
Epoch 6/10, Loss: 8.1307039590138
Epoch 7/10, Loss: 7.781981824304137
Epoch 8/10, Loss: 7.4728119635520995
Epoch 9/10, Loss: 7.217940778073753
Epoch 10/10, Loss: 6.9759596583178585
Test Accuracy: 71.38%


71.38

## Results

### Teacher Model

*   Number of parameters: 1,186,986
*   Time to train: ~3 minutes
*   Test accuracy: **74.43%**

### Student Model (Without Knowledge Distillation)

*   Number of parameters: 267,738
*   Time to train: ~2 minutes
*   Test accuracy: **69.51%**

### Student Model (With Knowledge Distillation)

*   Number of parameters: 267,738
*   Time to train: ~2 minutes
*   Test accuracy: **71.38%**