<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/Knowledge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Shell out to `nvidia-smi` to show which NVIDIA GPU (if any) is attached to this runtime.
! nvidia-smi

In [2]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
# Mini-batch size shared by the training and evaluation loaders.
batch_size = 64

# Both MNIST splits are converted to tensors; only the training split requests a download.
train_dataset = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())

# Training batches are reshuffled on every pass; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Report the number of batches per split.
print("Length of train_loader: ", len(train_loader))
print("Length of test_loader: ", len(test_loader))

Length of train_loader:  938
Length of test_loader:  157


In [7]:
# Create the teacher
class Teacher(nn.Module):
  """Large convolutional classifier used as the teacher network.

  Layer stack: Conv2d(1 -> 256) -> LeakyReLU -> MaxPool2d -> Conv2d(256 -> 512)
  -> flatten -> Linear. The shape comments in `__init__` are for inputs of
  shape [batch_size, 1, 28, 28]; the linear layer's `in_features`
  (512 * 8 * 8) is only valid for that spatial size.

  Args:
    num_classes: number of logits produced by the final linear layer.
  """

  def __init__(self, num_classes=10):
    super(Teacher, self).__init__()
    # input_shape ~ [batch_size, 1, 28, 28]
    # shape ~ [batch_size, 256, 14, 14]
    self.conv_1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    # shape ~ [batch_size, 256, 14, 14]
    self.lr_1 = nn.LeakyReLU(inplace=True)
    # shape ~ [batch_size, 256, 15, 15]
    self.pool_1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    # shape ~ [batch_size, 512, 8, 8]
    self.conv_2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    # shape ~ [batch_size, num_classes]
    # The head honours `num_classes` instead of a hard-coded 10.
    self.lin_1 = nn.Linear(in_features=512 * 8 * 8, out_features=num_classes)

  def forward(self, x):
    """Return raw (un-normalized) class logits of shape [batch_size, num_classes]."""
    x = self.conv_1(x)
    # Use the registered activation module (same default slope as F.leaky_relu);
    # it rewrites the fresh conv output in place, which nothing else references.
    x = self.lr_1(x)
    x = self.pool_1(x)
    x = self.conv_2(x)
    # Flatten everything except the batch dimension for the linear head.
    x = x.view(x.size(0), -1)
    x = self.lin_1(x)
    return x

In [4]:
class Student(nn.Module):
  """Small convolutional classifier used as the student network.

  Same layer layout as the teacher but with far fewer channels:
  Conv2d(1 -> 16) -> leaky ReLU -> MaxPool2d -> Conv2d(16 -> 32) -> flatten
  -> Linear. The linear layer's `in_features` (32 * 8 * 8) matches inputs of
  shape [batch_size, 1, 28, 28].

  Args:
    num_classes: number of logits produced by the final linear layer.
  """

  def __init__(self, num_classes=10):
    super(Student, self).__init__()
    # For 28x28 inputs: shape ~ [batch_size, 16, 14, 14]
    self.conv_1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

    # shape ~ [batch_size, 16, 15, 15]
    self.pool_1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))

    # shape ~ [batch_size, 32, 8, 8]
    self.conv_2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

    # shape ~ [batch_size, num_classes]
    # The head honours `num_classes` instead of a hard-coded 10.
    self.lin_1 = nn.Linear(in_features=32 * 8 * 8, out_features=num_classes)

  def forward(self, x):
    """Return raw (un-normalized) class logits of shape [batch_size, num_classes]."""
    x = self.conv_1(x)
    x = F.leaky_relu(x)
    x = self.pool_1(x)
    x = self.conv_2(x)
    # Flatten everything except the batch dimension for the linear head.
    x = x.view(x.size(0), -1)
    x = self.lin_1(x)
    return x

In [8]:
# Build both networks (teacher first, then student) and place them on the target device.
teacher = Teacher(num_classes=10).to(device)
student = Student(num_classes=10).to(device)

# Total number of parameter elements in each network.
total_params_t = sum(param.numel() for param in teacher.parameters())
total_params_s = sum(param.numel() for param in student.parameters())

print("Total Parameters in Teacher: ", total_params_t)
print("Total Parameters in Student: ", total_params_s)

Total Parameters in Teacher:  1510410
Total Parameters in Student:  25290


In [None]:
# Train Teacher
# Cross-entropy on the raw logits is the objective, minimised with Adam at a 1e-4 step size.
loss_t = nn.CrossEntropyLoss()
optimizer_t = torch.optim.Adam(params=teacher.parameters(), lr=1e-4)

def compute_loss(model, data_loader, device, loss_fn=None):
  """Return the average per-batch loss of `model` over `data_loader`.

  The model is switched to eval mode and left in eval mode on return.
  Autograd is disabled for the whole pass, so the function is safe to call
  outside a no-grad context.

  Args:
    model: module called as `model(features)` to produce logits.
    data_loader: sized iterable yielding `(features, targets)` batches.
    device: device each batch is moved to before the forward pass.
    loss_fn: callable `loss_fn(logits, targets)` returning a scalar tensor.
      Defaults to the notebook-level `loss_t` when omitted.

  Returns:
    Sum of the per-batch losses divided by `len(data_loader)`. Every batch
    carries the same weight, including a smaller final batch.

  Raises:
    ZeroDivisionError: if `data_loader` has length 0.
  """
  if loss_fn is None:
    # Backward-compatible default: the criterion defined alongside the training loop.
    loss_fn = loss_t
  tot = 0.
  model.eval()
  # No gradients are needed for evaluation; skip building autograd graphs.
  with torch.no_grad():
    for features, targets in data_loader:
      features = features.to(device)
      targets = targets.to(device)
      logits = model(features)
      tot += loss_fn(logits, targets).item()
  return tot/len(data_loader)


EPOCHS = 5
start_time = time.time()
for epoch in range(EPOCHS):
  # --- optimisation pass over the training set ---
  teacher.train()
  for batch_idx, (features, targets) in enumerate(train_loader):
    features, targets = features.to(device), targets.to(device)

    # Clear stale gradients, then run the forward pass and compute the loss.
    optimizer_t.zero_grad()
    logits = teacher(features)
    loss = loss_t(logits, targets)

    # LOGGING
    if batch_idx % 200 == 0:
      print("Batch: %03d/%03d" % (batch_idx, len(train_loader)))

    # Backpropagate and apply the parameter update.
    loss.backward()
    optimizer_t.step()

  # --- evaluation pass: average loss on both splits, gradients disabled ---
  teacher.eval()
  with torch.no_grad():
    train_average_loss = compute_loss(teacher, train_loader, device)
    test_average_loss = compute_loss(teacher, test_loader, device)
  print("Epoch: %03d/%03d | Train Loss: %.3f | Test Loss: %.3f" % (epoch+1, EPOCHS, train_average_loss, test_average_loss))

  # Measured from the start of training, so the value grows across epochs.
  epoch_elapsed_time = time.time() - start_time
  print("Epoch Elapsed Time: ", epoch_elapsed_time)

total_training_time = time.time() - start_time
print("Total Training Time: ", total_training_time)

Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 001/005 | Train Loss: 0.089 | Test Loss: 0.081
Epoch Elapsed Time:  18.70784282684326
Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 002/005 | Train Loss: 0.061 | Test Loss: 0.057
Epoch Elapsed Time:  37.91933751106262
Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 003/005 | Train Loss: 0.054 | Test Loss: 0.052
Epoch Elapsed Time:  57.7604284286499
Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 004/005 | Train Loss: 0.054 | Test Loss: 0.060
Epoch Elapsed Time:  77.10402655601501
Batch: 000/938
Batch: 200/938
Batch: 400/938
Batch: 600/938
Batch: 800/938
Epoch: 005/005 | Train Loss: 0.042 | Test Loss: 0.051
Epoch Elapsed Time:  95.94168615341187
Total Training Time:  95.94187355041504
