## Training a ResNet18 Student using Knowledge Distillation from Teacher ResNet50 on CIFAR-10 dataset

Knowledge Distillation is a technique where a smaller "Student" model (ResNet18) learns to mimic the behavior of a pre-trained, larger "Teacher" model (ResNet50). Instead of the Student learning only from hard labels e.g. "this is a cat", it learns from the Teacher’s full probability distribution e.g. "this is 90% a cat, but it has 8% features of a dog and 2% features of a truck".

Model was trained for **2 hours, 13 minutes, and 6 seconds** using **L4** graphics card on Google Colaboratory and achieved accuracy is **94.2%**

#### Core Components of the Process
- **Soft Targets & Temperature ($T$)**: In standard training, the final layer uses a Softmax function to produce probabilities. By introducing Temperature ($T > 1$), we "soften" the output. This reveals the relative similarities between classes that are otherwise hidden in a standard 0 or 1 classification.
- **The Teacher Model (ResNet50)**: This model is pre-trained and its weights are frozen `(teacher.eval())`. It acts as a guide, providing high-quality "soft targets" for the Student to mimic.
- **The Student Model (ResNet18)**: This is the model being trained. It is significantly smaller and faster, making it ideal for deployment on mobile devices or edge hardware.

#### The Two-Part Loss Function
The Student model doesn't just listen to the Teacher; it also looks at the actual labels. The total loss is a weighted sum:
- **Distillation Loss**: KL Divergence between the Student's and Teacher's soft outputs. Forces the Student to learn the internal logic and "mistakes" of the Teacher.
- **Student Loss**: Standard Cross-Entropy between Student predictions and ground-truth labels. Ensures the Student still learns the correct final answers from the dataset.

**Alpha** ($\alpha = 0.5$): This balances the two losses. A value of 0.5 means the Student cares equally about the Teacher's advice and the actual ground truth.

**Temperature** ($T = 3.0$): This "stretches" the probability distribution. A higher $T$ makes the distribution flatter, helping the Student see which "wrong" classes the Teacher thinks are somewhat similar to the correct one.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import time

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
epochs = 200
lr = 0.1
temperature = 3.0 # Softens the teacher's probability distribution
alpha = 0.5 # Balances the loss between teacher guidance and true labels

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

100%|██████████| 170M/170M [00:03<00:00, 44.5MB/s]


In [7]:
teacher = torchvision.models.resnet50(num_classes=10)

# Load pre-trained weights
teacher.load_state_dict(torch.load('saved_models/teacher_resnet50.pt', map_location=device))

# Modify architecture for CIFAR-10: 3x3 kernel and no maxpool to preserve spatial info
teacher.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
teacher.maxpool = nn.Identity()

teacher.to(device)
teacher.eval() # Teacher is in evaluation mode all the time

student = torchvision.models.resnet18(num_classes=10)
student.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
student.maxpool = nn.Identity()
student.to(device)

optimizer = optim.SGD(student.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

def distillation_loss(student_logits, teacher_logits, labels, T, alpha):
    student_ce_loss = F.cross_entropy(student_logits, labels)

    soft_teacher = F.softmax(teacher_logits / T, dim=1)
    soft_student = F.log_softmax(student_logits / T, dim=1)
    distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T * T)

    return alpha * distill_loss + (1 - alpha) * student_ce_loss

In [8]:
start_time = time.time()

for epoch in range(epochs):
    student.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        student_outputs = student(inputs) # Forward pass through the student

        with torch.no_grad():
            teacher_outputs = teacher(inputs) # Forward pass through the teacher

        # Calculate the distillation loss
        loss = distillation_loss(student_outputs, teacher_outputs, targets, temperature, alpha)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() # training statistics
        _, predicted = student_outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    scheduler.step() # Update the learning rate

    student.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = student(inputs)
            _, predicted = outputs.max(1)
            test_total += targets.size(0)
            test_correct += predicted.eq(targets).sum().item()

    acc = 100. * test_correct / test_total
    print(f'Epoch {epoch+1} | Current lr: {scheduler.get_last_lr()[0]:.4f} | Loss: {train_loss/(batch_idx+1):.4f} | Test Acc: {acc:.2f}%')

torch.save(student.state_dict(), "saved_models/student_resnet18.pt")

total_time = time.time() - start_time
print(f"\nTotal training time: {total_time:.2f} seconds")

Epoch 1 | Current lr: 0.1000 | Loss: 1.4895 | Test Acc: 23.32%
Epoch 2 | Current lr: 0.1000 | Loss: 1.2001 | Test Acc: 39.96%
Epoch 3 | Current lr: 0.0999 | Loss: 1.0930 | Test Acc: 50.43%
Epoch 4 | Current lr: 0.0999 | Loss: 1.0191 | Test Acc: 58.48%
Epoch 5 | Current lr: 0.0998 | Loss: 0.9681 | Test Acc: 63.07%
Epoch 6 | Current lr: 0.0998 | Loss: 0.9293 | Test Acc: 65.32%
Epoch 7 | Current lr: 0.0997 | Loss: 0.8929 | Test Acc: 67.18%
Epoch 8 | Current lr: 0.0996 | Loss: 0.8620 | Test Acc: 71.76%
Epoch 9 | Current lr: 0.0995 | Loss: 0.8441 | Test Acc: 71.43%
Epoch 10 | Current lr: 0.0994 | Loss: 0.8330 | Test Acc: 71.55%
Epoch 11 | Current lr: 0.0993 | Loss: 0.8227 | Test Acc: 69.86%
Epoch 12 | Current lr: 0.0991 | Loss: 0.8162 | Test Acc: 76.87%
Epoch 13 | Current lr: 0.0990 | Loss: 0.8082 | Test Acc: 66.65%
Epoch 14 | Current lr: 0.0988 | Loss: 0.8029 | Test Acc: 72.38%
Epoch 15 | Current lr: 0.0986 | Loss: 0.7950 | Test Acc: 73.06%
Epoch 16 | Current lr: 0.0984 | Loss: 0.7904 | Te