<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/Knowledge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import copy
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
batch_size = 64
train_dataset = datasets.MNIST(root='data', 
                               train=True, 
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data', 
                              train=False, 
                              transform=transforms.ToTensor())


train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset, 
                         batch_size=batch_size, 
                         shuffle=False)
print("Length of train_loader: ", len(train_loader))
print("Length of test_loader: ", len(test_loader))

Length of train_loader:  938
Length of test_loader:  157


In [3]:
# Create the teacher
class Teacher(nn.Module):
  def __init__(self, num_classes=10):
    super(Teacher, self).__init__()
    # input_shape ~ [batch_size, 1, 28, 28]
    # shape ~ [batch_size, 256, 14, 14]
    self.conv_1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    # shape ~ [batch_size, 256, 14, 14]
    self.lr_1 = nn.LeakyReLU(inplace=True)
    # shape ~ [batch_size, 256, 15, 15]
    self.pool_1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    # shape ~ [batch_size, 512, 8, 8]
    self.conv_2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    # shape ~ [batch_size, 10]
    self.lin_1 = nn.Linear(in_features=512 * 8 * 8, out_features=10)

  def forward(self, x):
    x = self.conv_1(x)
    x = F.leaky_relu(x)
    x = self.pool_1(x)
    x = self.conv_2(x)
    x = x.view(x.size(0), -1)
    x = self.lin_1(x)
    return x

In [4]:
# Create Student
class Student(nn.Module):
  def __init__(self, num_classes=10):
    super(Student, self).__init__()
    self.conv_1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    
    self.pool_1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))

    self.conv_2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    
    self.lin_1 = nn.Linear(in_features=32 * 8 * 8, out_features=10)
  
  def forward(self, x):
    x = self.conv_1(x)
    x = F.leaky_relu(x)
    x = self.pool_1(x)
    x = self.conv_2(x)
    x = x.view(x.size(0), -1)
    x = self.lin_1(x)
    return x

In [5]:
teacher = Teacher(num_classes=10)
student = Student(num_classes=10)
student = student.to(device)
teacher = teacher.to(device)
total_params_t = sum(p.numel() for p in teacher.parameters())
total_params_s = sum(p.numel() for p in student.parameters())
print("Total Parameters in Teacher: ", total_params_t)
print("Total Parameters in Student: ", total_params_s)

Total Parameters in Teacher:  1510410
Total Parameters in Student:  25290


In [6]:
# Creating a clone of student
student_clone = Student(num_classes=10)

In [7]:
# Train Teacher
optimizer_t = torch.optim.Adam(teacher.parameters(), lr=0.0001)
loss_t = nn.CrossEntropyLoss()

def compute_loss(model, data_loader, loss_generic, device):
  tot = 0.
  model.eval()
  for features, targets in data_loader:
    features = features.to(device)
    targets = targets.to(device)
    logits = model(features)
    loss = loss_generic(logits, targets)
    tot += loss.item()
  return tot/len(data_loader)

def compute_accuracy(model, data_loader, device):
    model.eval()
    correct_pred, num_examples = 0, 0
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)
            logits = model(features)
            probas = F.softmax(logits, dim=1)
            _, predicted_labels = torch.max(probas, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
        return correct_pred.float()/num_examples * 100


EPOCHS = 4
start_time = time.time()
for epoch in range(EPOCHS):
  teacher.train()
  for batch_idx, (features, targets) in enumerate(train_loader):
    features = features.to(device)
    targets = targets.to(device)

    optimizer_t.zero_grad()
    logits = teacher(features)
    loss = loss_t(logits, targets)
    
    # LOGGING
    if batch_idx % 200 == 0:
      print("Batch: %03d/%03d" % (batch_idx, len(train_loader)))

    loss.backward()
    optimizer_t.step()
  teacher.eval()
  with torch.set_grad_enabled(False):
    train_average_loss = compute_loss(teacher, train_loader, loss_t, device)
    test_average_loss = compute_loss(teacher, test_loader, loss_t, device)
    test_accuracy = compute_accuracy(teacher, test_loader, device)
    print("Epoch: %03d/%03d | Teacher Train Loss: %.3f | Teacher Test Loss: %.3f | Teacher Test Accuracy: %.2f" % (epoch+1, EPOCHS, train_average_loss, test_average_loss, test_accuracy))
  epoch_elapsed_time = time.time() - start_time
  print("Epoch Elapsed Time: ", epoch_elapsed_time)
total_training_time = time.time() - start_time
print("Total Training Time: ", total_training_time)

Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 001/004 | Teacher Train Loss: 0.095 | Teacher Test Loss: 0.087 | Teacher Test Accuracy: 97.40
Epoch Elapsed Time:  17.808791637420654
Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 002/004 | Teacher Train Loss: 0.062 | Teacher Test Loss: 0.059 | Teacher Test Accuracy: 98.20
Epoch Elapsed Time:  35.56360054016113
Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 003/004 | Teacher Train Loss: 0.062 | Teacher Test Loss: 0.062 | Teacher Test Accuracy: 97.92
Epoch Elapsed Time:  53.29674530029297
Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 004/004 | Teacher Train Loss: 0.048 | Teacher Test Loss: 0.053 | Teacher Test Accuracy: 98.22
Epoch Elapsed Time:  71.29234600067139
Total Training Time:  71.29257702827454


In [8]:
# Create distil training
optimizer_distil = torch.optim.Adam(student.parameters(), lr=0.001)
loss_s = nn.CrossEntropyLoss()
loss_distil = nn.KLDivLoss(reduction='batchmean')
alpha = 0.1
temperature = 10
EPOCHS = 3

# Freezing the layers of Teacher
teacher.eval()
for parameter in teacher.parameters():
    parameter.requires_grad = False

start_time = time.time()
for epoch in range(EPOCHS):
  student.train()
  for batch_idx, (features, targets) in enumerate(train_loader):
    features = features.to(device)
    targets = targets.to(device)

    teacher_logits = teacher(features)
    
    optimizer_distil.zero_grad()
    student_logits = student(features)

    student_loss = loss_s(student_logits, targets)
    distillation_loss = loss_distil(F.log_softmax(input=(student_logits/temperature), dim=1), F.log_softmax(input=(teacher_logits/temperature), dim=1))
    loss = alpha * student_loss + (1 - alpha) * distillation_loss

    # LOGGING
    if batch_idx % 200 == 0:
      print("Batch: %03d/%03d" % (batch_idx, len(train_loader)))

    loss.backward()
    optimizer_distil.step()

  student.eval()
  with torch.set_grad_enabled(False):
    train_average_loss = compute_loss(student, train_loader, loss_s, device)
    test_average_loss = compute_loss(student, test_loader, loss_s, device)
    test_accuracy = compute_accuracy(student, test_loader, device)
    print("Epoch: %03d/%03d | Student Train Loss: %.3f | Student Test Loss: %.3f | Student Test Accuracy: %.2f" % (epoch+1, EPOCHS, train_average_loss, test_average_loss, test_accuracy))
  epoch_elapsed_time = time.time() - start_time
  print("Epoch Elapsed Time: ", epoch_elapsed_time)
total_training_time = time.time() - start_time
print("Total Training Time: ", total_training_time)

Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 001/003 | Student Train Loss: 0.120 | Student Test Loss: 0.108 | Student Test Accuracy: 96.83
Epoch Elapsed Time:  10.960257053375244
Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 002/003 | Student Train Loss: 0.078 | Student Test Loss: 0.072 | Student Test Accuracy: 97.69
Epoch Elapsed Time:  22.495341300964355
Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 003/003 | Student Train Loss: 0.067 | Student Test Loss: 0.063 | Student Test Accuracy: 98.00
Epoch Elapsed Time:  33.67755627632141
Total Training Time:  33.67793083190918


In [9]:
# Train Student clone from scratch for comparison
student_clone = student_clone.to(device)
optimizer_clone = torch.optim.Adam(student_clone.parameters(), lr=0.001)
loss_clone = nn.CrossEntropyLoss()

EPOCHS = 3
start_time = time.time()
for epoch in range(EPOCHS):
  student_clone.train()
  for batch_idx, (features, targets) in enumerate(train_loader):
    features = features.to(device)
    targets = targets.to(device)

    optimizer_clone.zero_grad()
    logits = student_clone(features)
    loss = loss_clone(logits, targets)
    
    # LOGGING
    if batch_idx % 200 == 0:
      print("Batch: %03d/%03d" % (batch_idx, len(train_loader)))

    loss.backward()
    optimizer_clone.step()
  student_clone.eval()
  with torch.set_grad_enabled(False):
    train_average_loss = compute_loss(student_clone, train_loader, loss_clone, device)
    test_average_loss = compute_loss(student_clone, test_loader, loss_clone, device)
    test_accuracy = compute_accuracy(student_clone, test_loader, device)
    print("Epoch: %03d/%03d | Clone_Student Train Loss: %.3f | Clone_Student Test Loss: %.3f | Clone_Student Test Accuracy: %.2f" % (epoch+1, EPOCHS, train_average_loss, test_average_loss, test_accuracy))
  epoch_elapsed_time = time.time() - start_time
  print("Epoch Elapsed Time: ", epoch_elapsed_time)
total_training_time = time.time() - start_time
print("Total Training Time: ", total_training_time)

Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 001/003 | Clone_Student Train Loss: 0.118 | Clone_Student Test Loss: 0.110 | Clone_Student Test Accuracy: 96.74
Epoch Elapsed Time:  10.838661193847656
Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 002/003 | Clone_Student Train Loss: 0.079 | Clone_Student Test Loss: 0.072 | Clone_Student Test Accuracy: 97.69
Epoch Elapsed Time:  21.47327733039856
Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 003/003 | Clone_Student Train Loss: 0.069 | Clone_Student Test Loss: 0.068 | Clone_Student Test Accuracy: 97.83
Epoch Elapsed Time:  31.9901065826416
Total Training Time:  31.99079704284668
