<a href="https://colab.research.google.com/github/raki-rankawat/stm32/blob/main/CIFAR10_KD_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Load Trained Teacher Model

In [40]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [41]:
from google.colab import drive

# Attach Google Drive at the path used by the checkpoint load/save cells.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [42]:
# Data Loaders
batch_size = 64

# Tail shared by both pipelines: convert the image to a tensor, then apply
# (x - 0.5) / 0.5 on each of the three channels.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Training images get random augmentation before the shared tail.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    *_to_normalized_tensor,
])

# Test images are only converted and normalized (no augmentation).
test_transform = transforms.Compose(list(_to_normalized_tensor))

# Both splits live under the same root and are downloaded on demand.
_cifar_kwargs = dict(root='./data', download=True)
train_data = datasets.CIFAR10(train=True, transform=train_transform, **_cifar_kwargs)
test_data = datasets.CIFAR10(train=False, transform=test_transform, **_cifar_kwargs)

# Only the training split is shuffled.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

In [43]:
# Teacher Model
class CIFARConvNet(nn.Module):
    """Teacher network: four conv/BN/ReLU/max-pool stages and a two-layer head.

    Each stage uses a 3x3 convolution with padding 1 followed by a 2x2 max
    pool, so a 32x32 input is reduced to 2x2 with 128 channels before the
    fully connected layers. The forward pass returns raw logits for 10 classes.
    """

    def __init__(self):
        super().__init__()

        # Convolutional stages; channel widths 3 -> 16 -> 32 -> 64 -> 128.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        # Classifier head on the flattened 128 x 2 x 2 feature map.
        self.fc1 = nn.Linear(128 * 2 * 2, 256)
        self.fc2 = nn.Linear(256, 10)

        self.dropout = nn.Dropout(p=0.25)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for a batch of images."""
        stages = (
            (self.conv1, self.bn1),  # 32 -> 16
            (self.conv2, self.bn2),  # 16 -> 8
            (self.conv3, self.bn3),  # 8 -> 4
            (self.conv4, self.bn4),  # 4 -> 2
        )
        for conv, bn in stages:
            x = F.max_pool2d(F.relu(bn(conv(x))), 2)

        features = x.flatten(start_dim=1)

        hidden = F.relu(self.fc1(features))
        # nn.Dropout is already an identity in eval mode, so no explicit
        # self.training check is required here.
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

In [44]:
# Load weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = CIFARConvNet().to(device)

# The teacher is only used as a fixed reference during distillation:
# eval mode makes BatchNorm use its running statistics and turns dropout off,
# and freezing the parameters stops autograd from tracking the teacher forward.
teacher.eval()
for param in teacher.parameters():
    param.requires_grad_(False)

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
# map_location=device places the tensors directly on the device the model uses.
teacher.load_state_dict(torch.load("/content/drive/My Drive/Colab Notebooks/stm_cifar10_model.pth", map_location=device))

<All keys matched successfully>

### Student Model (Shallower: 2 Conv Blocks vs. the Teacher's 4)

In [45]:
class StudentNet(nn.Module):
  """Student network: two conv/BN/ReLU/max-pool stages and a two-layer head.

  Each stage uses a 3x3 convolution with padding 1 followed by a 2x2 max pool,
  so a 32x32 input is reduced to 8x8 with 32 channels before being flattened
  into the 2048-wide input of fc1. The forward pass returns raw logits for
  10 classes.
  """

  def __init__(self):
    super().__init__()

    # Convolutional stages; channel widths 3 -> 16 -> 32.
    self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
    self.bn1 = nn.BatchNorm2d(16)
    self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm2d(32)

    # Classifier head on the flattened 32 x 8 x 8 feature map.
    self.fc1 = nn.Linear(32 * 8 * 8, 256)
    self.fc2 = nn.Linear(256, 10)

  def forward(self, x):
    """Return class logits of shape (batch, 10) for a batch of images."""
    stages = (
      (self.conv1, self.bn1),  # 32 -> 16
      (self.conv2, self.bn2),  # 16 -> 8
    )
    for conv, bn in stages:
      x = F.max_pool2d(F.relu(bn(conv(x))), 2)

    features = x.flatten(start_dim=1)

    hidden = F.relu(self.fc1(features))
    return self.fc2(hidden)

In [46]:
# Build the student with freshly initialized weights, then move it to the
# same device the teacher was placed on.
student = StudentNet()
student = student.to(device)

In [47]:
# Optimizer: Adam over the student's parameters only, so optimizer.step()
# never modifies the teacher.
optimizer = torch.optim.Adam(params=student.parameters(), lr=1e-3)

In [48]:
# KD Loss
def kd_loss(student_logits, teacher_logits, labels, T=4, alpha=0.7):
  """Knowledge-distillation loss: hard-label cross-entropy plus a softened KL term.

  Args:
    student_logits: Raw student outputs; softmax is taken over dim 1.
    teacher_logits: Raw teacher outputs with the same shape as
      `student_logits`. They are treated as a fixed target and detached, so
      no gradient flows back into the teacher.
    labels: Ground-truth class indices, passed to `F.cross_entropy`.
    T: Temperature that divides both sets of logits before the softmax.
    alpha: Weight of the hard-label cross-entropy; the KL term is weighted
      by `1 - alpha`.

  Returns:
    Scalar tensor `alpha * CE + (1 - alpha) * T^2 * KL(teacher || student)`.
  """
  # Soft targets (KL divergence). F.kl_div expects log-probabilities as its
  # input (student) and probabilities as its target (teacher).
  student_log_probs = F.log_softmax(student_logits / T, dim=1)
  teacher_probs = F.softmax(teacher_logits.detach() / T, dim=1)
  # Multiplying by T^2 offsets the 1/T^2 scaling that the temperature
  # introduces into the soft-target gradients.
  loss_kd = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T * T)

  # Hard-label loss on the unsoftened student logits.
  loss_ce = F.cross_entropy(student_logits, labels)

  # Combined loss
  return alpha * loss_ce + (1 - alpha) * loss_kd

In [54]:
def evaluate(model, loader):
    """Compute top-1 accuracy (in percent) of `model` over `loader`.

    The model is switched to eval mode and left in that mode on return, so
    callers that continue training must call `model.train()` again.

    Args:
        model: An `nn.Module` mapping a batch of inputs to per-class scores
            of shape (batch, num_classes).
        loader: Iterable yielding `(inputs, labels)` tensor pairs.

    Returns:
        Accuracy as a float in [0, 100]; 0.0 if `loader` yields no samples.
    """
    model.eval()

    # Run on whichever device the model's parameters live on, so the function
    # does not depend on a module-level `device` variable.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    correct, total = 0, 0
    with torch.no_grad():
        for X, y in loader:
            if target_device is not None:
                X, y = X.to(target_device), y.to(target_device)

            pred = model(X).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    # Guard against an empty loader instead of failing on an unbound variable.
    if total == 0:
        return 0.0
    return 100 * correct / total

In [55]:
# Baseline: accuracy of the loaded teacher on the test loader.
acc_teacher = evaluate(teacher, test_loader)
print(f"🎓 Teacher Model Accuracy: {acc_teacher:.2f}%")

🎓 Teacher Model Accuracy: 78.25%


In [51]:
# Student Model Training
epochs = 10
start_time = time.time()

# The teacher only provides soft targets: keep it in eval mode so BatchNorm
# uses its running statistics and dropout is disabled.
teacher.eval()

for epoch in range(epochs):
  # evaluate() leaves the student in eval mode, so re-enable training mode
  # at the start of every epoch.
  student.train()
  train_losses = 0.0

  for X, y in train_loader:
    X, y = X.to(device), y.to(device)

    student_logits = student(X)
    # No autograd graph is needed for the teacher; it is never updated.
    with torch.no_grad():
      teacher_logits = teacher(X)

    loss = kd_loss(student_logits, teacher_logits, y)

    # Back-propagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_losses += loss.item()

  # Accuracy is measured on the test loader, so it is reported as test accuracy.
  acc = evaluate(student, test_loader)
  print(f"Epoch {epoch + 1} / {epochs} | Loss: {train_losses / len(train_loader):.4f} | Student Test Acc: {acc:.2f}%")

print(f"Training time: {(time.time() - start_time)/60:.2f} minutes")

Epoch 1 / 10 | Loss: 2.0224 | Student Train Acc: 59.66%
Epoch 2 / 10 | Loss: 1.4310 | Student Train Acc: 63.87%
Epoch 3 / 10 | Loss: 1.2140 | Student Train Acc: 67.26%
Epoch 4 / 10 | Loss: 1.0991 | Student Train Acc: 68.09%
Epoch 5 / 10 | Loss: 1.0199 | Student Train Acc: 71.36%
Epoch 6 / 10 | Loss: 0.9709 | Student Train Acc: 70.62%
Epoch 7 / 10 | Loss: 0.9149 | Student Train Acc: 72.49%
Epoch 8 / 10 | Loss: 0.8853 | Student Train Acc: 72.02%
Epoch 9 / 10 | Loss: 0.8630 | Student Train Acc: 73.44%
Epoch 10 / 10 | Loss: 0.8358 | Student Train Acc: 74.33%
Training time: 8.15 minutes


In [52]:
# Final accuracy of the distilled student on the test loader.
acc_student = evaluate(student, test_loader)
print(f"🎓 Student Accuracy: {acc_student:.2f}%")

🎓 Student Accuracy: 74.33%


In [53]:
# Persist only the student's state_dict (parameters and buffers) to Drive.
torch.save(student.state_dict(), "/content/drive/My Drive/Colab Notebooks/stm_cifar10_kd_model.pth")
print("✅ Model saved as stm_cifar10_kd_model.pth")

✅ Model saved as stm_cifar10_kd_model.pth
