In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
import copy

# Set device: use the first CUDA GPU when one is usable, otherwise the CPU.
# Later cells call `.to(device)` and read `device.type`, so the name has to be
# bound here for the notebook to survive Restart Kernel -> Run All.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print(torch.cuda.get_device_name(0))
else:
    print("GPU Not Avaliable")

NVIDIA GeForce RTX 3050 Laptop GPU


In [6]:
# Training schedule: passes over the training set for the teacher and for
# each of the two student runs.
TEACHER_EPOCHS, STUDENT_EPOCHS = 15, 20

# Settings shared by every model: DataLoader batch size and Adam step size.
BATCH_SIZE = 128
LEARNING_RATE = 1e-3

# Distillation settings.
TEMPERATURE = 4.0  # passed as T: divides the logits before the softmax
ALPHA = 0.7        # weight on the soft loss; (1 - ALPHA) goes to the hard loss

In [7]:
# Teacher Model (ResNet-18 pre-trained on ImageNet)
class TeacherModel(nn.Module):
    """ResNet-18 wrapper whose final ``fc`` layer is replaced by a new linear classifier.

    Args:
        num_classes: Output width of the replacement ``fc`` layer. Defaults to
            10, the number of CIFAR-10 classes.
        weights: Passed straight through to ``torchvision.models.resnet18``.
            Defaults to ``'IMAGENET1K_V1'``; pass ``None`` to build the
            backbone without pretrained weights.
    """

    def __init__(self, num_classes=10, weights='IMAGENET1K_V1'):
        super().__init__()
        # Keep the attribute name `model`: every state_dict key is prefixed
        # with it, so renaming would invalidate saved checkpoints.
        self.model = torchvision.models.resnet18(weights=weights)
        # The replacement head reuses the backbone's feature width and starts
        # from a fresh nn.Linear initialisation.
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return the classifier outputs (logits) of the wrapped ResNet-18 for ``x``."""
        # NOTE(review): the notebook feeds 32x32 crops without resizing;
        # confirm that is intended for a backbone pretrained on ImageNet.
        return self.model(x)


# Student Model (small CNN)
class StudentModel(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a two-layer classifier with 10 outputs.

    ``fc1`` expects ``32 * 8 * 8`` flattened features, i.e. 32 channels on an
    8x8 map after the two stride-2 poolings (a 32x32 input).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys and the order in which weights are initialised.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=32 * 8 * 8, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        """Return the raw class logits for the batch ``x``."""
        features = x
        stages = ((self.conv1, self.pool1), (self.conv2, self.pool2))
        for conv, pool in stages:
            features = pool(F.relu(conv(features)))
        # Collapse everything except the batch dimension before the dense layers.
        hidden = F.relu(self.fc1(features.flatten(start_dim=1)))
        logits = self.fc2(hidden)
        return logits

In [8]:
# Per-channel statistics handed to transforms.Normalize for both splits.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: padded random crop and horizontal flip for augmentation,
# then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: same conversion and normalisation, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Arguments common to both splits, stated once.
dataset_kwargs = dict(root='./data', download=True)
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)

trainset = torchvision.datasets.CIFAR10(
    train=True, transform=transform_train, **dataset_kwargs)
trainloader = DataLoader(trainset, shuffle=True, **loader_kwargs)

testset = torchvision.datasets.CIFAR10(
    train=False, transform=transform_test, **dataset_kwargs)
testloader = DataLoader(testset, shuffle=False, **loader_kwargs)

print("Data loaded successfully.")

100%|██████████| 170M/170M [08:55<00:00, 318kB/s]  


Data loaded successfully.


In [None]:
def train_model(model, trainloader, optimizer, criterion, epoch, device=None):
    """Train ``model`` for one pass over ``trainloader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        trainloader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss tensor.
        epoch: Zero-based epoch index; only used for the progress-bar label.
        device: Where each batch is moved before the forward pass. Defaults to
            the device of the model's first parameter (CPU if it has none), so
            the function does not depend on a notebook-level global.

    Returns:
        float: Mean per-batch loss over the epoch, or 0.0 when the loader
        yielded no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    running_loss = 0.0
    batches_seen = 0
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}")
    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        batches_seen += 1
        # Average over the batches processed so far. Dividing by the total
        # number of batches would under-report the loss until the last step.
        progress_bar.set_postfix(loss=running_loss / batches_seen)

    return running_loss / batches_seen if batches_seen else 0.0


def evaluate_model(model, testloader, device=None):
    """Return the top-1 accuracy of ``model`` over ``testloader`` as a percentage.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        testloader: Iterable yielding ``(inputs, labels)`` batches.
        device: Where each batch is moved before the forward pass. Defaults to
            the device of the model's first parameter (CPU if it has none).

    Returns:
        float: ``100 * correct / total``, or 0.0 when the loader yields no
        samples (instead of raising ZeroDivisionError).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # The predicted class is the index of the largest output along dim 1.
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0
    return 100 * correct / total


def train_distilled_student(teacher, student, trainloader, optimizer, epoch, T, alpha, device=None):
    """Train ``student`` for one epoch against labels and the frozen ``teacher``'s outputs.

    Per batch the loss is ``alpha * soft + (1 - alpha) * hard`` where

    * ``soft`` is the batch-mean KL divergence between the teacher's softmax
      and the student's log-softmax, both taken on logits divided by ``T``,
      multiplied by ``T * T``;
    * ``hard`` is the cross-entropy between the student logits and ``labels``.

    Args:
        teacher: Model providing the soft targets; set to eval mode and run
            under ``torch.no_grad()``.
        student: Model being optimised; set to training mode.
        trainloader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        epoch: Zero-based epoch index; only used for the progress-bar label.
        T: Temperature dividing both sets of logits.
        alpha: Weight on the soft loss; the hard loss gets ``1 - alpha``.
        device: Where each batch is moved. Defaults to the device of the
            student's first parameter (CPU if it has none); the teacher is
            expected to live on the same device.

    Returns:
        float: Mean per-batch total loss over the epoch, or 0.0 when the
        loader yielded no batches.
    """
    if device is None:
        first_param = next(student.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    teacher.eval()
    student.train()
    running_loss = 0.0
    batches_seen = 0
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1} (Distill)")
    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        # The teacher only supplies targets, so no graph is built for it.
        with torch.no_grad():
            teacher_logits = teacher(inputs)

        student_logits = student(inputs)

        # kl_div takes log-probabilities as input and probabilities as target.
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction='batchmean',
        ) * (T * T)

        hard_loss = F.cross_entropy(student_logits, labels)
        total_loss = (alpha * soft_loss) + ((1 - alpha) * hard_loss)

        total_loss.backward()
        optimizer.step()
        running_loss += total_loss.item()
        batches_seen += 1
        # Average over the batches processed so far. Dividing by the total
        # number of batches would under-report the loss until the last step.
        progress_bar.set_postfix(loss=running_loss / batches_seen)

    return running_loss / batches_seen if batches_seen else 0.0

In [None]:
def count_parameters(model):
    """Return the total number of elements across ``model``'s parameters that require gradients."""
    trainable_elements = 0
    for param in model.parameters():
        # Parameters with requires_grad disabled are left out of the count.
        if param.requires_grad:
            trainable_elements += param.numel()
    return trainable_elements

def measure_inference_speed(model, testloader, device=None):
    """Return the average forward-pass time per image, in seconds.

    Only the ``model(inputs)`` call is timed; loading a batch and moving it to
    the device are excluded.

    Args:
        model: Network to time; it is switched to eval mode and left there.
        testloader: Iterable yielding ``(inputs, labels)`` batches. The image
            count is taken from ``labels.size(0)``.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU if it has none).

    Returns:
        float: Total timed seconds divided by the number of images, or 0.0
        when the loader yields no images (instead of raising
        ZeroDivisionError).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    # Accept strings such as "cuda" as well as torch.device objects.
    device = torch.device(device)
    on_cuda = device.type == 'cuda'

    model.eval()
    total_time = 0.0
    num_images = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            if on_cuda:
                # Drain work queued before the timer starts (e.g. the copy
                # above) so it is not billed to the forward pass.
                torch.cuda.synchronize()
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start = time.perf_counter()
            _ = model(inputs)
            if on_cuda:
                # CUDA kernels launch asynchronously; wait for them to finish.
                torch.cuda.synchronize()
            total_time += time.perf_counter() - start
            num_images += labels.size(0)

    if num_images == 0:
        return 0.0
    return total_time / num_images

In [None]:
# Fine-tune the teacher and keep the weights from its best-scoring epoch.
teacher_model = TeacherModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_teacher = optim.Adam(teacher_model.parameters(), lr=LEARNING_RATE)

best_teacher_acc = 0.0
best_teacher_state = None

# NOTE(review): the checkpoint is selected with `testloader`, so the reported
# figure is the maximum over epochs on the test split and is optimistic.
for epoch in range(TEACHER_EPOCHS):
    train_model(teacher_model, trainloader, optimizer_teacher, criterion, epoch)
    acc = evaluate_model(teacher_model, testloader)
    print(f"Teacher Accuracy after Epoch {epoch+1}: {acc:.2f}%")
    if acc > best_teacher_acc:
        best_teacher_acc = acc
        # Deep copy: state_dict() returns references to the live tensors,
        # which later epochs would otherwise overwrite.
        best_teacher_state = copy.deepcopy(teacher_model.state_dict())

if best_teacher_state is None:
    # No epoch beat the initial 0.0 (e.g. TEACHER_EPOCHS == 0). Fall back to
    # the current weights so the file holds a real state dict and
    # load_state_dict() is not handed None.
    best_teacher_state = copy.deepcopy(teacher_model.state_dict())

torch.save(best_teacher_state, "teacher_best.pth")
teacher_model.load_state_dict(best_teacher_state)
print(f"✅ Final Teacher Accuracy: {best_teacher_acc:.2f}%")

In [None]:
# Baseline: train the student on hard labels only.
student_conventional = StudentModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_student_conv = optim.Adam(student_conventional.parameters(), lr=LEARNING_RATE)

best_student_conv_acc = 0.0
best_student_conv_state = None

for epoch in range(STUDENT_EPOCHS):
    train_model(student_conventional, trainloader, optimizer_student_conv, criterion, epoch)
    acc = evaluate_model(student_conventional, testloader)
    print(f"Conventional Student Accuracy after Epoch {epoch+1}: {acc:.2f}%")
    if acc > best_student_conv_acc:
        best_student_conv_acc = acc
        # Snapshot the weights that produced the best score. Without this the
        # "_best" file would hold whatever the last epoch left behind while
        # the printed accuracy refers to an earlier epoch.
        best_student_conv_state = copy.deepcopy(student_conventional.state_dict())

if best_student_conv_state is None:
    # No epoch beat the initial 0.0 (e.g. STUDENT_EPOCHS == 0): keep the
    # current weights so a valid state dict is still written.
    best_student_conv_state = copy.deepcopy(student_conventional.state_dict())

torch.save(best_student_conv_state, "student_conventional_best.pth")
# Restore the best weights so the in-memory model matches the reported accuracy.
student_conventional.load_state_dict(best_student_conv_state)
print(f"✅ Final Conventional Student Accuracy: {best_student_conv_acc:.2f}%")

In [None]:
# Distillation: train a fresh student against labels and the teacher's outputs.
student_distilled = StudentModel().to(device)
optimizer_student_distill = optim.Adam(student_distilled.parameters(), lr=LEARNING_RATE)

best_student_distill_acc = 0.0
best_student_distill_state = None

for epoch in range(STUDENT_EPOCHS):
    train_distilled_student(
        teacher=teacher_model,
        student=student_distilled,
        trainloader=trainloader,
        optimizer=optimizer_student_distill,
        epoch=epoch,
        T=TEMPERATURE,
        alpha=ALPHA
    )
    acc = evaluate_model(student_distilled, testloader)
    print(f"Distilled Student Accuracy after Epoch {epoch+1}: {acc:.2f}%")
    if acc > best_student_distill_acc:
        best_student_distill_acc = acc
        # Snapshot the weights that produced the best score. Without this the
        # "_best" file would hold whatever the last epoch left behind while
        # the printed accuracy refers to an earlier epoch.
        best_student_distill_state = copy.deepcopy(student_distilled.state_dict())

if best_student_distill_state is None:
    # No epoch beat the initial 0.0 (e.g. STUDENT_EPOCHS == 0): keep the
    # current weights so a valid state dict is still written.
    best_student_distill_state = copy.deepcopy(student_distilled.state_dict())

torch.save(best_student_distill_state, "student_distilled_best.pth")
# Restore the best weights so the in-memory model matches the reported accuracy.
student_distilled.load_state_dict(best_student_distill_state)
print(f"✅ Final Distilled Student Accuracy: {best_student_distill_acc:.2f}%")

In [None]:
# Model sizes. Both students are StudentModel instances, so the count taken
# from the conventional one is reused for the distilled row.
student_params = count_parameters(student_conventional)
teacher_params = count_parameters(teacher_model)

# Average forward-pass time per test image, measured in the same order as before.
teacher_speed = measure_inference_speed(teacher_model, testloader)
student_conv_speed = measure_inference_speed(student_conventional, testloader)
student_distill_speed = measure_inference_speed(student_distilled, testloader)

# Comparison table, assembled line by line and then printed.
heavy_rule = "="*80
light_rule = "-"*80
table_lines = [
    "\n" + heavy_rule,
    f"{'Model':<25} | {'Accuracy (%)':<15} | {'Parameters':<15} | {'Avg. Inference (s/img)':<25}",
    light_rule,
    f"{'1. Teacher Model':<25} | {best_teacher_acc:<15.2f} | {teacher_params:<15,} | {teacher_speed:<25.8f}",
    f"{'2. Student (Conventional)':<25} | {best_student_conv_acc:<15.2f} | {student_params:<15,} | {student_conv_speed:<25.8f}",
    f"{'3. Student (Distilled)':<25} | {best_student_distill_acc:<15.2f} | {student_params:<15,} | {student_distill_speed:<25.8f}",
    heavy_rule,
]
for line in table_lines:
    print(line)

# Success criterion: the distilled student must strictly beat the baseline.
if not (best_student_distill_acc > best_student_conv_acc):
    print(f"\n❌ Distilled Student ({best_student_distill_acc:.2f}%) did not outperform Conventional ({best_student_conv_acc:.2f}%)")
    print("Try tuning Alpha, Temperature, or learning rate.")
else:
    print(f"\n✅ Success: Distilled Student ({best_student_distill_acc:.2f}%) > Conventional ({best_student_conv_acc:.2f}%)")