<a href="https://colab.research.google.com/github/rrankawat/pytorch-cnn/blob/main/CIFAR_10_Knowledge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Knowledge distillation

Knowledge distillation is a machine learning technique where a large, complex "teacher" model transfers its knowledge to a smaller, simpler "student" model. This process compresses the model, allowing a more efficient student model to perform nearly as well as the large teacher model, making it suitable for deployment on devices with limited resources, such as mobile phones.

* KD is a model compression / performance enhancement technique.

* Idea: Train a smaller or same-size “student” model to mimic a larger or pretrained “teacher” model.

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time

In [2]:
# Mount Google Drive: the teacher checkpoint is read from it and the distilled student is saved to it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
def totalTime(start_time):
  """Return the minutes elapsed since `start_time` (a `time.time()` timestamp), rounded to 2 decimals."""
  elapsed_minutes = (time.time() - start_time) / 60
  return round(elapsed_minutes, 2)

###### Teacher Model (Pretrained)

In [4]:
class CIFARConvNet(nn.Module):
    """Four-block CNN used as the pretrained CIFAR-10 teacher.

    Each block is conv(3x3, padding=1) -> BatchNorm -> ReLU -> 2x2 max-pool, so the
    spatial size halves per block (32 -> 16 -> 8 -> 4 -> 2). The resulting 128x2x2
    feature map is flattened and passed through a two-layer classifier that emits
    10 logits.

    Attribute names must not change: they are the state_dict keys that
    `load_state_dict` matches against the pretrained checkpoint.
    """

    def __init__(self):
        super().__init__()
        # Shape comments are conv outputs (C x H x W) for a 3x32x32 input, before pooling.
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)   # -> 16x32x32
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)  # -> 32x16x16
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)  # -> 64x8x8
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1) # -> 128x4x4
        self.bn4 = nn.BatchNorm2d(128)

        self.fc1 = nn.Linear(128*2*2, 256)
        self.fc2 = nn.Linear(256, 10)

        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Map a batch of images (N, 3, 32, 32) to class logits (N, 10)."""
        # Block 1
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)  # 32 -> 16

        # Block 2
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)  # 16 -> 8

        # Block 3
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2, 2)  # 8 -> 4

        # Block 4
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.max_pool2d(x, 2, 2)  # 4 -> 2

        # Flatten all but the batch dimension: (N, 128, 2, 2) -> (N, 512).
        # Unlike reshape(-1, 512), a wrongly sized input cannot be silently
        # folded into extra "batch" rows; it fails loudly in fc1 instead.
        x = torch.flatten(x, 1)

        # Fully connected
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [5]:
TEACHER_CHECKPOINT = "/content/drive/My Drive/Colab Notebooks/model_cifar10.pth"

teacher = CIFARConvNet()
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when loaded.
teacher.load_state_dict(torch.load(TEACHER_CHECKPOINT, map_location="cpu", weights_only=True))
teacher.eval()  # Freeze teacher: BatchNorm uses running stats, dropout is disabled

# The teacher is never optimized; disabling grads keeps autograd from tracking
# its parameters while the student trains.
teacher.requires_grad_(False)
print("✅ Teacher model loaded.")

✅ Teacher model loaded.


###### Student Model (Smaller)

In [6]:
class StudentNet(nn.Module):
    """Two-block CNN student distilled from the CIFAR-10 teacher.

    Each block is conv(3x3, padding=1) -> BatchNorm -> ReLU -> 2x2 max-pool
    (32 -> 16 -> 8). The 32x8x8 feature map is flattened and passed to a small
    two-layer classifier that emits 10 logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)   # -> 16x32x32
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)  # -> 32x16x16
        self.bn2 = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32*8*8, 128)  # smaller FC
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images (N, 3, 32, 32) to class logits (N, 10)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)  # 32 -> 16

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)  # 16 -> 8

        # (N, 32, 8, 8) -> (N, 2048). Keeps the batch dimension intact so a
        # wrongly sized input errors in fc1 instead of producing extra rows.
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [7]:
# Seed torch's global RNG so the student's weight init and the later
# augmentation / shuffling draws are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

student = StudentNet()

###### CIFAR10 Data Loading

In [8]:
# Shared per-channel normalization: maps each RGB channel from [0, 1] to [-1, 1].
# NOTE(review): this must match the normalization the teacher was trained with — confirm.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Random augmentations are applied to the training split only.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    normalize,
])

test_transform = transforms.Compose([transforms.ToTensor(), normalize])

In [9]:
DATA_ROOT = './data'
BATCH_SIZE = 128

train_data = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)
test_data = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=test_transform)

# Only the training split is shuffled; evaluation order does not affect accuracy.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

100%|██████████| 170M/170M [00:03<00:00, 44.5MB/s]


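###### Distillation Loss Function

The loss blends hard-label cross-entropy with a temperature-softened KL term. Here $z_s$ and $z_t$ are the student and teacher logits, $y$ the labels, $\sigma$ the softmax, $T$ the temperature, and $\alpha$ the weight on the hard-label term:

$$\mathcal{L} = \alpha\,\mathrm{CE}(z_s, y) + (1-\alpha)\,T^2\,\mathrm{KL}\big(\sigma(z_t/T)\,\|\,\sigma(z_s/T)\big)$$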

In [14]:
def distillation_loss(student_logits, teacher_logits, labels, T=4, alpha=0.7):
    """Blend hard-label cross-entropy with a temperature-softened KD term.

    loss = alpha * CE(student_logits, labels)
           + (1 - alpha) * T^2 * KL(softmax(teacher_logits / T) || softmax(student_logits / T))

    Args:
        student_logits: raw student outputs; dim=1 is treated as the class axis.
        teacher_logits: raw teacher outputs, same shape as student_logits.
        labels: class-index targets for the cross-entropy term.
        T: softmax temperature applied to both sets of logits.
        alpha: weight on the hard-label loss; (1 - alpha) weights the KD term.

    Returns:
        Scalar loss tensor.
    """
    hard_loss = F.cross_entropy(student_logits, labels)

    # F.kl_div takes log-probabilities as input (student) and probabilities as
    # target (teacher); 'batchmean' averages the summed KL over the batch.
    student_log_probs = F.log_softmax(student_logits / T, dim=1)
    teacher_probs = F.softmax(teacher_logits / T, dim=1)
    # The T^2 factor rescales the softened term's gradients back toward the
    # magnitude of the hard-label gradients (Hinton et al., 2015).
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T * T)

    return alpha * hard_loss + (1 - alpha) * soft_loss

###### Optimizer

In [11]:
# Only the student's parameters are optimized; the teacher stays frozen.
optimizer = torch.optim.Adam(params=student.parameters(), lr=1e-3)

###### Training Loop with KD

In [15]:
epochs = 20
KD_TEMPERATURE = 4   # softmax temperature for the soft-target term
KD_ALPHA = 0.7       # weight on hard-label CE; (1 - KD_ALPHA) goes to the KD term
start_time = time.time()

for epoch in range(epochs):
    student.train()
    teacher.eval()  # re-assert inference mode so teacher BatchNorm/dropout never switch to training
    train_losses = 0.0
    for images, labels in train_loader:
        # The teacher only supplies soft targets, so no autograd graph is needed for it.
        with torch.no_grad():
            teacher_logits = teacher(images)
        student_logits = student(images)
        loss = distillation_loss(student_logits, teacher_logits, labels,
                                 T=KD_TEMPERATURE, alpha=KD_ALPHA)

        # Back-propagation & Update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses += loss.item()

    print(f"Epoch {epoch+1}/{epochs} | Loss: {train_losses/len(train_loader):.4f}")

print(f'Time taken: {totalTime(start_time)} minutes!')

Epoch 1/20 | Loss: 1.8922
Epoch 2/20 | Loss: 1.6399
Epoch 3/20 | Loss: 1.5200
Epoch 4/20 | Loss: 1.4339
Epoch 5/20 | Loss: 1.3555
Epoch 6/20 | Loss: 1.2890
Epoch 7/20 | Loss: 1.2432
Epoch 8/20 | Loss: 1.1998
Epoch 9/20 | Loss: 1.1732
Epoch 10/20 | Loss: 1.1391
Epoch 11/20 | Loss: 1.1180
Epoch 12/20 | Loss: 1.0833
Epoch 13/20 | Loss: 1.0685
Epoch 14/20 | Loss: 1.0534
Epoch 15/20 | Loss: 1.0370
Epoch 16/20 | Loss: 1.0218
Epoch 17/20 | Loss: 1.0070
Epoch 18/20 | Loss: 0.9924
Epoch 19/20 | Loss: 0.9795
Epoch 20/20 | Loss: 0.9675
Time taken: 26.29 minutes!


###### Evaluate Student

In [17]:
def evaluate(model, loader):
    """Return the top-1 accuracy (in percent) of `model` over every batch in `loader`.

    Puts the model in eval mode and leaves it there. Each batch is moved to the
    device of the model's first parameter (CPU if the model has no parameters).

    Args:
        model: a module mapping an image batch to per-class scores (class axis = dim 1).
        loader: an iterable of (images, labels) batches.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if `loader` yields no samples, since accuracy is undefined.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("evaluate() received a loader with no samples; accuracy is undefined")
    return 100 * correct / total

In [18]:
# Top-1 accuracy of the distilled student on the held-out CIFAR-10 test split.
acc_student = evaluate(model=student, loader=test_loader)
print(f"🎓 Student Accuracy: {acc_student:.2f}%")

🎓 Student Accuracy: 75.05%


In [19]:
# Persist only the learned weights (state_dict), alongside the teacher checkpoint on Drive.
student_checkpoint = "/content/drive/My Drive/Colab Notebooks/model_cifar10_kd.pth"
torch.save(student.state_dict(), student_checkpoint)
print("✅ Model saved as model_cifar10_kd.pth")

✅ Model saved as model_cifar10_kd.pth
