In [None]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [None]:
class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 conv + BatchNorm layers plus a skip connection.

    The skip path is the identity when the block keeps both the spatial size
    (stride == 1) and the channel count. Otherwise it is a 1x1 strided
    convolution followed by BatchNorm, so both paths produce the same shape.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        stride: stride of the first 3x3 conv (and of the projection shortcut),
            used to downsample the feature map.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path. The convs have no bias because each one feeds a BatchNorm,
        # whose learned shift already plays that role.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: an empty Sequential passes its input through unchanged.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply conv-BN-ReLU, then conv-BN, add the skip path, and apply a final ReLU."""
        hidden = torch.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return torch.relu(hidden + self.shortcut(x))

class CNN(nn.Module):
    """ResNet-style image classifier.

    The network is a conv stem, three downsampling residual stages, global
    average pooling and a linear head.

    Args:
        num_classes: number of output logits (10 for CIFAR-10).
        in_channels: channels of the input images (3 for RGB).
        base_width: channels of the stem conv. The three residual stages use
            2x, 4x and 8x this width. The defaults reproduce the original
            64 -> 128 -> 256 -> 512 network with identical attribute names, so
            existing ``state_dict`` checkpoints still load. A smaller value gives
            a lighter model, e.g. a distillation student.
    """

    def __init__(self, num_classes=10, in_channels=3, base_width=64):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, base_width, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(base_width)
        # Each stage halves the spatial size (stride=2) and doubles the channels.
        self.layer1 = ResidualBlock(base_width, base_width * 2, stride=2)
        self.layer2 = ResidualBlock(base_width * 2, base_width * 4, stride=2)
        self.layer3 = ResidualBlock(base_width * 4, base_width * 8, stride=2)
        # Adaptive pooling to 1x1 makes the head independent of the input resolution.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(base_width * 8, num_classes)

    def forward(self, x):
        """Map a batched image tensor (batch, in_channels, H, W) to logits (batch, num_classes)."""
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

In [None]:
def softmax_with_temperature(logits, T):
    """Return class probabilities softened by the temperature ``T``.

    The logits are divided by ``T`` and then normalised with a softmax over
    dim 1 (the class dimension). A larger ``T`` spreads probability mass more
    evenly across classes.
    """
    scaled_logits = logits.div(T)
    return F.softmax(scaled_logits, dim=1)

class DistillationLoss(nn.Module):
    """Temperature-scaled KL-divergence loss for knowledge distillation.

    The loss is KL(teacher || student). The student's distribution is computed
    at temperature ``T``; the teacher's is supplied already softened. The
    divergence is averaged over the batch and multiplied by ``T**2``.

    Args:
        T: softmax temperature applied to the student logits. It should match
            the temperature used to produce ``teacher_probs``.
    """

    def __init__(self, T):
        super().__init__()
        self.T = T

    def forward(self, student_logits, teacher_probs):
        """Compute the distillation loss for one batch.

        Args:
            student_logits: raw student outputs with classes on dim 1.
            teacher_probs: teacher *probabilities* (not logits or
                log-probabilities) with the same shape, because ``F.kl_div``
                is called with its default ``log_target=False``.
        """
        temperature = self.T
        log_student = F.log_softmax(student_logits / temperature, dim=1)
        divergence = F.kl_div(log_student, teacher_probs, reduction="batchmean")
        # The T**2 factor is the usual Hinton et al. rescaling that keeps
        # soft-target gradient magnitudes comparable across temperatures.
        return divergence * temperature ** 2

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1].
# The single mean/std value is broadcast across all three RGB channels.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)

# Only the training loader is shuffled; evaluation keeps the fixed test order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
# Seed torch's global RNG before any model is built. This makes weight
# initialisation reproducible, and also the shuffle order of train_loader,
# which has no dedicated generator and so draws from the global RNG each
# time it is iterated. cuDNN kernels may still be nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model = CNN().to(device)
optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [None]:
num_epochs = 30  # also reused by the student training cell
print("Training Teacher Model...")
for epoch in range(num_epochs):
    teacher_model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = teacher_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per sample, even when the last batch is short.
        running_loss += loss.item() * inputs.size(0)
    print(f"Epoch {epoch + 1}/{num_epochs} - loss: {running_loss / len(train_loader.dataset):.4f}")

# Held-out accuracy. Without it the teacher's quality, and therefore the
# quality of the soft labels it produces, is never checked.
teacher_model.eval()
teacher_correct = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        teacher_correct += (teacher_model(inputs).argmax(dim=1) == labels).sum().item()
print(f"Teacher test accuracy: {teacher_correct / len(test_loader.dataset):.2%}")

# Save to the Kaggle output directory when it exists. Otherwise use the current
# directory, because torch.save fails if the parent directory is missing.
save_dir = Path("/kaggle/working")
if not save_dir.is_dir():
    save_dir = Path(".")
torch.save(teacher_model.state_dict(), save_dir / "teacher_model.pth")
print("Teacher Model Trained & Saved.")

In [None]:
T = 5.0  # High temperature for soft labels
teacher_model.eval()

# train_loader shuffles, so iterating it would store soft labels in a random,
# unrecoverable order. Use a non-shuffled loader over the same dataset so that
# soft_labels[i] is the teacher's output for train_dataset[i].
ordered_train_loader = DataLoader(train_dataset, batch_size=train_loader.batch_size, shuffle=False)

soft_labels_list = []
with torch.no_grad():
    for inputs, _ in ordered_train_loader:
        inputs = inputs.to(device)
        logits = teacher_model(inputs)
        soft_labels_list.append(softmax_with_temperature(logits, T))

soft_labels = torch.cat(soft_labels_list)  # one row per training image, in dataset order
assert soft_labels.shape[0] == len(train_dataset), "soft labels must cover every training image"
print("Soft Labels Generated.")

In [None]:
student_model = CNN().to(device)
optimizer = optim.Adam(student_model.parameters(), lr=0.001)
criterion = DistillationLoss(T)

# The teacher is only used for inference here. Eval mode makes BatchNorm use
# running statistics instead of updating them.
teacher_model.eval()

print("Training Student Model...")
for epoch in range(num_epochs):
    student_model.train()
    running_loss = 0.0
    for inputs, _ in train_loader:
        inputs = inputs.to(device)
        # Compute the teacher's soft targets for *this* batch so they always
        # match the inputs. Zipping train_loader with a stacked soft-label
        # tensor yields one 10-vector per batch (not a batch of targets), taken
        # from a different shuffle order than the images.
        with torch.no_grad():
            soft_targets = softmax_with_temperature(teacher_model(inputs), T)
        optimizer.zero_grad()
        student_logits = student_model(inputs)
        loss = criterion(student_logits, soft_targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    print(f"Epoch {epoch + 1}/{num_epochs} - distillation loss: {running_loss / len(train_loader.dataset):.4f}")

# Held-out accuracy of the distilled student, for comparison with the teacher.
student_model.eval()
student_correct = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        student_correct += (student_model(inputs).argmax(dim=1) == labels).sum().item()
print(f"Student test accuracy: {student_correct / len(test_loader.dataset):.2%}")

# Save to the Kaggle output directory when it exists, else the current directory.
save_dir = Path("/kaggle/working")
if not save_dir.is_dir():
    save_dir = Path(".")
torch.save(student_model.state_dict(), save_dir / "student_model.pth")
print("Student Model Trained & Saved.")