In [1]:
## Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import mobilenet_v2, squeezenet1_0
from torch.utils.data import DataLoader
from torchvision.datasets import Flowers102
from tqdm import tqdm
import warnings

warnings.simplefilter("ignore", UserWarning)

In [2]:

## Define another Teacher Model (SqueezeNet)
class TeacherModel(nn.Module):
    def __init__(self, num_classes=102):
        super(TeacherModel, self).__init__()
        self.model = squeezenet1_0(pretrained=True)
        self.model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1))
        self.model.num_classes = num_classes

    def forward(self, x):
        return self.model(x)

## Define Student Model (Lightweight CNN)
class StudentModel(nn.Module):
    def __init__(self, num_classes=102):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 56 * 56, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [3]:
## Device Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Hyperparameters
batch_size = 32
learning_rate = 0.001
num_epochs = 10
temperature = 5.0  # Temperature for soft targets
alpha = 0.5  # Balance between cross-entropy and KD loss

## Data Transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

## Load Flowers-102 Dataset
train_dataset = Flowers102(root="./data", split="train", transform=transform, download=True)
val_dataset = Flowers102(root="./data", split="val", transform=transform, download=True)
test_dataset = Flowers102(root="./data", split="test", transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
## Instantiate Models
teacher = TeacherModel().to(device)
student = StudentModel().to(device)

## Define Loss Functions
criterion_ce = nn.CrossEntropyLoss()
criterion_kd = nn.KLDivLoss(reduction="batchmean")

## Optimizers
optimizer_teacher = optim.Adam(teacher.parameters(), lr=learning_rate)
optimizer_student = optim.Adam(student.parameters(), lr=learning_rate)

## Train Teacher Model
def train_teacher(model, train_loader, criterion, optimizer, epochs):
    model.train()
    for epoch in tqdm(range(epochs)):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")

    print("Teacher training complete.")

## Train Student Model with Knowledge Distillation
def train_student(teacher, student, train_loader, criterion_ce, criterion_kd, optimizer, epochs, alpha, temperature):
    teacher.eval()  # Freeze teacher model
    student.train()

    for epoch in tqdm(range(epochs)):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            with torch.no_grad():
                teacher_outputs = teacher(images)

            student_outputs = student(images)

            ## Compute Hard Target Loss (Cross-Entropy)
            loss_ce = criterion_ce(student_outputs, labels)

            ## Compute Soft Target Loss (KL Divergence with Temperature Scaling)
            soft_teacher_outputs = torch.nn.functional.log_softmax(teacher_outputs / temperature, dim=1)
            soft_student_outputs = torch.nn.functional.softmax(student_outputs / temperature, dim=1)
            loss_kd = criterion_kd(soft_teacher_outputs, soft_student_outputs) * (temperature ** 2)

            ## Total Loss (Weighted Sum)
            loss = alpha * loss_ce + (1 - alpha) * loss_kd
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")

    print("Student training complete.")

## Evaluation Function
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    accuracy = 100 * correct / total
    print(f"Accuracy: {accuracy:.2f}%")
    return accuracy

Downloading: "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth" to /root/.cache/torch/hub/checkpoints/squeezenet1_0-b66bff10.pth
100%|██████████| 4.78M/4.78M [00:00<00:00, 27.0MB/s]


In [5]:
## Training Process
print("Training Teacher Model...")
train_teacher(teacher, train_loader, criterion_ce, optimizer_teacher, num_epochs)

Training Teacher Model...


 10%|█         | 1/10 [03:11<28:46, 191.80s/it]

Epoch [1/10], Loss: 4.6603


 20%|██        | 2/10 [06:15<24:56, 187.08s/it]

Epoch [2/10], Loss: 4.6279


 30%|███       | 3/10 [09:16<21:30, 184.32s/it]

Epoch [3/10], Loss: 4.6253


 40%|████      | 4/10 [12:24<18:35, 185.85s/it]

Epoch [4/10], Loss: 4.6253


 50%|█████     | 5/10 [15:27<15:23, 184.72s/it]

Epoch [5/10], Loss: 4.6252


 60%|██████    | 6/10 [18:31<12:17, 184.43s/it]

Epoch [6/10], Loss: 4.6253


 70%|███████   | 7/10 [21:43<09:20, 186.95s/it]

Epoch [7/10], Loss: 4.6252


 80%|████████  | 8/10 [24:47<06:12, 186.04s/it]

Epoch [8/10], Loss: 4.6253


 90%|█████████ | 9/10 [27:50<03:05, 185.01s/it]

Epoch [9/10], Loss: 4.6253


100%|██████████| 10/10 [30:53<00:00, 185.38s/it]

Epoch [10/10], Loss: 4.6252
Teacher training complete.





In [6]:
# Evaluate Models
print("Evaluating Teacher Model...")
evaluate_model(teacher, test_loader)

Evaluating Teacher Model...
Accuracy: 0.37%


0.37404456009107173

In [7]:
print("Training Student Model with Knowledge Distillation...")
train_student(teacher, student, train_loader, criterion_ce, criterion_kd, optimizer_student, num_epochs, alpha, temperature)

Training Student Model with Knowledge Distillation...


 10%|█         | 1/10 [03:26<31:01, 206.85s/it]

Epoch [1/10], Loss: 3.1119


 20%|██        | 2/10 [06:52<27:27, 205.98s/it]

Epoch [2/10], Loss: 2.2959


 30%|███       | 3/10 [10:17<24:00, 205.76s/it]

Epoch [3/10], Loss: 2.2130


 40%|████      | 4/10 [13:41<20:28, 204.80s/it]

Epoch [4/10], Loss: 2.1133


 50%|█████     | 5/10 [17:07<17:06, 205.26s/it]

Epoch [5/10], Loss: 1.9612


 60%|██████    | 6/10 [20:33<13:41, 205.49s/it]

Epoch [6/10], Loss: 1.6952


 70%|███████   | 7/10 [23:58<10:16, 205.48s/it]

Epoch [7/10], Loss: 1.3524


 80%|████████  | 8/10 [27:23<06:50, 205.30s/it]

Epoch [8/10], Loss: 1.0044


 90%|█████████ | 9/10 [30:48<03:25, 205.18s/it]

Epoch [9/10], Loss: 0.7707


100%|██████████| 10/10 [34:12<00:00, 205.28s/it]

Epoch [10/10], Loss: 0.6359
Student training complete.





In [8]:
print("Evaluating Student Model...")
evaluate_model(student, test_loader)

Evaluating Student Model...
Accuracy: 11.27%


11.270125223613595