In [None]:
!pip install torchinfo

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm

In [None]:
# Training split of MNIST stored under dataset/ (download requested);
# each sample is passed through ToTensor.
train_dataset = torchvision.datasets.MNIST(
    root='dataset/',
    download=True,
    train=True,
    transform=transforms.ToTensor(),
)

In [None]:
# Held-out (train=False) split of MNIST, same root directory and transform
# as the training split.
test_dataset = torchvision.datasets.MNIST(
    root='dataset/',
    download=True,
    train=False,
    transform=transforms.ToTensor(),
)

In [None]:
# Shuffled mini-batches of 32 samples for training.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)

In [None]:
# Mini-batches of 32 samples in fixed (unshuffled) order for evaluation.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False,
)

In [None]:
class TeacherModel(nn.Module):
  """Larger CNN used as the teacher.

  Two stages of (3x3 convolution -> ReLU -> 2x2 max-pool) with 64 and 256
  output channels, followed by a single linear layer. The linear layer takes
  256 * 7 * 7 features, which is what two 2x2 poolings leave of a 28x28 input.
  """

  def __init__(self, in_channels=1, num_classes=10):
    super().__init__()
    # 3x3 kernel, stride 1, padding 1: height and width are preserved.
    same_conv = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.conv1 = nn.Conv2d(in_channels, 64, **same_conv)
    self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
    self.conv2 = nn.Conv2d(64, 256, **same_conv)
    self.fc1 = nn.Linear(256 * 7 * 7, num_classes)

  def forward(self, x):
    """Return raw class scores of shape (batch, num_classes)."""
    # The same pooling layer is applied after each convolution stage.
    for conv in (self.conv1, self.conv2):
      x = self.pool(F.relu(conv(x)))
    # Collapse everything except the batch dimension before the classifier.
    flat = x.reshape(x.shape[0], -1)
    return self.fc1(flat)

In [None]:
# Untrained teacher with the default constructor arguments.
teacher_model = TeacherModel()

In [None]:
# Layer-by-layer summary of the teacher for a (32, 1, 28, 28) input,
# i.e. one batch of 32 single-channel 28x28 images.
summary(teacher_model, (32, 1, 28, 28))

In [None]:
class StudentModel(nn.Module):
  """Small CNN used as the student.

  Same layout as the teacher -- two stages of (3x3 convolution -> ReLU ->
  2x2 max-pool) followed by one linear layer -- but with only 8 and 16
  convolution channels. The linear layer takes 16 * 7 * 7 features, which is
  what two 2x2 poolings leave of a 28x28 input.
  """

  def __init__(self, in_channels=1, num_classes=10):
    super().__init__()
    # 3x3 kernel, stride 1, padding 1: height and width are preserved.
    same_conv = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.conv1 = nn.Conv2d(in_channels, 8, **same_conv)
    self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
    self.conv2 = nn.Conv2d(8, 16, **same_conv)
    self.fc1 = nn.Linear(16 * 7 * 7, num_classes)

  def forward(self, x):
    """Return raw class scores of shape (batch, num_classes)."""
    # The same pooling layer is applied after each convolution stage.
    for conv in (self.conv1, self.conv2):
      x = self.pool(F.relu(conv(x)))
    # Collapse everything except the batch dimension before the classifier.
    flat = x.reshape(x.shape[0], -1)
    return self.fc1(flat)

In [None]:
# Untrained student with the default constructor arguments.
student_model = StudentModel()

In [None]:
# Layer-by-layer summary of the student for a (32, 1, 28, 28) input,
# i.e. one batch of 32 single-channel 28x28 images.
summary(student_model, (32, 1, 28, 28))

In [None]:
def check_accuracy(loader, model, device):
  """Return the fraction of samples in `loader` that `model` classifies correctly.

  Args:
    loader: iterable yielding `(inputs, labels)` batches.
    model: module mapping a batch of inputs to per-class scores; the predicted
      class is the index of the largest score along dimension 1.
    device: device each batch is moved to before the forward pass.

  Returns:
    The accuracy as a Python float, or 0.0 if the loader yields no samples.

  The model is switched to eval mode while the check runs. Afterwards its
  previous train/eval mode is restored (also if an exception is raised), so
  evaluating a model that was in eval mode does not flip it to train mode.
  """
  num_correct = 0
  num_samples = 0
  was_training = model.training
  model.eval()

  try:
    with torch.no_grad():
      for x, y in loader:
        x = x.to(device)
        y = y.to(device)

        scores = model(x)
        _, predictions = scores.max(1)
        # Accumulate plain ints so the final ratio is an exact Python division.
        num_correct += (predictions == y).sum().item()
        num_samples += predictions.size(0)
  finally:
    # Put the model back in whichever mode the caller had it in.
    model.train(was_training)

  if num_samples == 0:
    return 0.0
  return num_correct / num_samples


In [None]:
def train_teacher(epochs):
  """Train a freshly constructed TeacherModel and return it.

  The model is trained on `train_dataloader` for `epochs` epochs with
  cross-entropy loss and Adam (lr=1e-4), on CUDA when available and on the
  CPU otherwise. After every epoch the mean training loss and the accuracy on
  `test_dataloader` are printed.
  """
  use_cuda = torch.cuda.is_available()
  device = torch.device('cuda' if use_cuda else 'cpu')

  model = TeacherModel().to(device)
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

  for epoch in range(epochs):
    model.train()
    batch_losses = []
    progress = tqdm(
        train_dataloader,
        total=len(train_dataloader),
        position=0,
        leave=True,
        desc='Epoch {}'.format(epoch),
    )
    for images, labels in progress:
      images, labels = images.to(device), labels.to(device)

      # Clear stale gradients, then forward / backward / update.
      optimizer.zero_grad()
      loss = criterion(model(images), labels)
      loss.backward()
      optimizer.step()

      batch_losses.append(loss.item())

    avg_loss = sum(batch_losses) / len(batch_losses)
    acc = check_accuracy(test_dataloader, model, device)
    print('Loss {:.2f}\tAccuracy {:.2f}'.format(avg_loss, acc))

  return model

In [None]:
# Train a new teacher for 3 epochs; train_teacher builds its own TeacherModel,
# so this rebinds teacher_model to the trained instance.
teacher_model = train_teacher(epochs=3)

In [None]:
def train_step(
    teacher,
    student,
    optimizer,
    student_loss_fn,
    divergence_loss_fn,
    temp,
    alpha,
    epoch,
    device
):
  """Run one distillation epoch over `train_dataloader` and return the mean loss.

  Args:
    teacher: model whose outputs serve as soft targets; it is only run under
      `torch.no_grad()`, so it receives no gradients.
    student: model being trained.
    optimizer: optimizer that is stepped once per batch.
    student_loss_fn: called as `fn(student_logits, targets)` (hard-label loss).
    divergence_loss_fn: called as `fn(student_log_probs, teacher_probs)`,
      the calling convention of `nn.KLDivLoss`.
    temp: temperature both sets of logits are divided by before the softmax.
    alpha: weight of the hard-label loss; the distillation loss is weighted
      by `1 - alpha`.
    epoch: epoch index, used only for the progress-bar label.
    device: device each batch is moved to.

  Returns:
    The mean of the per-batch combined losses as a Python float.
  """
  losses = []
  pbar = tqdm(train_dataloader, total=len(train_dataloader), position=0, leave=True, desc='Epoch {}'.format(epoch))
  for data, targets in pbar:
    data = data.to(device)
    targets = targets.to(device)

    # forward
    with torch.no_grad():
      teacher_preds = teacher(data)

    student_preds = student(data)
    student_loss = student_loss_fn(student_preds, targets)

    # KLDivLoss expects its first argument in log-space and its second as
    # plain probabilities, so the student side must use log_softmax. Passing
    # softmax probabilities there does not compute a KL divergence.
    # NOTE(review): the distillation term is not rescaled by temp ** 2 as some
    # formulations do; left unchanged here -- confirm whether that is wanted.
    distillation_loss = divergence_loss_fn(
        F.log_softmax(student_preds / temp, dim=1),
        F.softmax(teacher_preds / temp, dim=1)
    )
    loss = alpha * student_loss + (1 - alpha) * distillation_loss
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  avg_loss = sum(losses) / len(losses)
  return avg_loss

In [None]:
def main(epochs, teacher, student, temp=7, alpha=0.3):
  """Train `student` against `teacher` for `epochs` epochs.

  Both models are moved to CUDA when available (CPU otherwise). The teacher is
  put in eval mode and the student in train mode, and the student's parameters
  are optimised with Adam (lr=1e-3). Each epoch calls `train_step` with a
  cross-entropy loss, a batch-mean KL-divergence loss, and the given `temp`
  and `alpha`, then prints that epoch's loss together with the student's
  accuracy on `test_dataloader`. Nothing is returned.
  """
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  teacher, student = teacher.to(device), student.to(device)

  # Everything train_step needs that does not change from epoch to epoch.
  step_kwargs = dict(
      teacher=teacher,
      student=student,
      optimizer=torch.optim.Adam(student.parameters(), lr=1e-3),
      student_loss_fn=nn.CrossEntropyLoss(),
      divergence_loss_fn=nn.KLDivLoss(reduction="batchmean"),
      temp=temp,
      alpha=alpha,
      device=device,
  )

  teacher.eval()
  student.train()
  for epoch in range(epochs):
    loss = train_step(epoch=epoch, **step_kwargs)
    acc = check_accuracy(test_dataloader, student, device)
    print(f"Loss:{loss:.2f}\tAccuracy:{acc:.2f}")

In [None]:
# Re-create the student so the distillation run below starts from a newly
# constructed (untrained) model.
student_model = StudentModel()

In [None]:
# Distil the trained teacher into the student for 3 epochs.
# NOTE(review): temp=0.3 is below 1, so dividing the logits by it sharpens
# rather than softens both softmax distributions (main's default is temp=7,
# alpha=0.3) -- confirm these values are intended.
main(epochs=3, teacher=teacher_model, student=student_model, temp=0.3, alpha=0.1)