<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/Knowledge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [56]:
# Mini-batch size shared by the training and test loaders.
batch_size = 64

# MNIST training split stored under ./data (download requested), images
# passed through transforms.ToTensor().
train_dataset = datasets.MNIST('data', train=True, download=True,
                               transform=transforms.ToTensor())

# MNIST test split read from the same root directory; no download is
# requested for this split.
test_dataset = datasets.MNIST('data', train=False,
                              transform=transforms.ToTensor())

# Training batches are reshuffled; test batches keep the dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [55]:
# Create the teacher
class Teacher(nn.Module):
  """Wide convolutional classifier that plays the teacher role.

  The linear head is sized for 28x28 single-channel inputs, so `forward`
  expects tensors of shape [batch_size, 1, 28, 28] and returns raw
  (un-normalised) class scores of shape [batch_size, num_classes].

  Args:
    num_classes: Number of scores produced by the final linear layer.
  """

  def __init__(self, num_classes=10):
    super(Teacher, self).__init__()
    # [batch_size, 1, 28, 28] -> [batch_size, 256, 14, 14]
    self.conv_1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    # Element-wise activation, shape unchanged: [batch_size, 256, 14, 14]
    self.lr_1 = nn.LeakyReLU(inplace=True)
    # Stride-1 pooling with padding grows the map: [batch_size, 256, 15, 15]
    self.pool_1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    # -> [batch_size, 512, 8, 8]
    self.conv_2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    # -> [batch_size, num_classes]; the head honours the constructor argument.
    self.lin_1 = nn.Linear(in_features=512 * 8 * 8, out_features=num_classes)

  def forward(self, x):
    """Return class scores for a batch of images.

    Args:
      x: Tensor of shape [batch_size, 1, 28, 28].

    Returns:
      Tensor of shape [batch_size, num_classes] holding the raw scores.
    """
    x = self.conv_1(x)
    x = self.lr_1(x)
    x = self.pool_1(x)
    x = self.conv_2(x)
    # Collapse channel and spatial dimensions into one feature vector per sample.
    x = x.view(x.size(0), -1)
    x = self.lin_1(x)
    return x

In [None]:
# Create the student
# Compact counterpart of the teacher: same layer layout, far fewer channels.
student = nn.Sequential(
    # [batch_size, 1, 28, 28] -> [batch_size, 16, 14, 14]
    nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
    # Element-wise activation, shape unchanged.
    nn.LeakyReLU(inplace=True),
    # -> [batch_size, 16, 15, 15]
    nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1), padding=(1, 1)),
    # -> [batch_size, 32, 8, 8]
    nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
    # -> [batch_size, 32 * 8 * 8]
    nn.Flatten(),
    # -> [batch_size, 10]
    nn.Linear(32 * 8 * 8, 10),
)

In [57]:
# Build the teacher and place both networks on the selected device.
teacher = Teacher(num_classes=10).to(device)
student = student.to(device)

# Count every parameter element in each network to compare their sizes.
total_params_t = sum(map(torch.numel, teacher.parameters()))
total_params_s = sum(map(torch.numel, student.parameters()))

print("Total Parameters in Teacher: ", total_params_t)
print("Total Parameters in Student: ", total_params_s)

Total Parameters in Teacher:  1510410
Total Parameters in Student:  25290
