## Training a ResNet18 with Self-Distillation on the CIFAR-10 Dataset
This Jupyter Notebook implements Self-Distillation, a form of Knowledge Distillation in which the "Student" and the "Teacher" share the same architecture (ResNet18). Instead of a larger model teaching a smaller one, the deepest classifier of the network acts as the teacher for auxiliary classifiers attached to its own earlier layers, with the aim of better performance and generalization on the CIFAR-10 dataset. The model was trained for **2 hours, 13 minutes, and 52 seconds** on an **L4** GPU in Google Colaboratory and reached an accuracy of **94.4%**.

#### Model Architecture
- **Internal Classifiers**: The model attaches a bottleneck and a fully connected layer after each of the four main ResNet layers (layer1 through layer4).
- **Bottlenecks**: The output of each layer is projected to a 256-channel feature map before classification, so that features from all depths have the same channel count and can be compared during distillation.
- **Outputs**: The forward pass returns two lists: 4 sets of logits (one prediction per classifier) and 4 bottleneck feature maps (hidden states).

#### Training Logic
- **Teacher**: The final output (`logits[-1]`) and final features (`feats[-1]`) serve as the "Teacher".
- **Student**: The classifiers and feature maps attached to layers 1, 2, and 3 act as "Students".

#### Loss Components
- **Cross-Entropy (CE)**: Standard classification loss against the real labels.
- **KL Divergence**: Forces the early layers' probability distributions to match the final layer's distribution (soft targets).
- **MSE (L2) Loss**: Aligns the bottleneck feature maps of the earlier layers with the final layer's feature map, which is bilinearly resized to each student's spatial size.
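
Read off the `train_step` function defined later in this notebook, the three terms appear to combine as shown below. This is a sketch in notation introduced only for this formula: $z_i$ are the logits of classifier $i$, $h_i$ the bottleneck feature maps, $\sigma$ the softmax, and $\mathrm{up}(\cdot)$ the bilinear resize to the student's spatial size.

$$
\mathcal{L} = \mathrm{CE}(z_4, y) + \sum_{i=1}^{3} \Big[ \mathrm{CE}(z_i, y) + \alpha \, T^2 \, \mathrm{KL}\big(\sigma(z_4 / T) \,\|\, \sigma(z_i / T)\big) + \beta \, \mathrm{MSE}\big(h_i, \mathrm{up}(h_4)\big) \Big]
$$

With the values used in the notebook, $T = 4$, $\alpha = 0.7$, and $\beta = 0.3$. The final classifier ($i = 4$) contributes only its cross-entropy term.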

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
import torchvision
import torchvision.transforms as transforms
import time

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

100%|██████████| 170M/170M [00:03<00:00, 44.1MB/s]


#### The `Bottleneck` Class
This is a helper module that standardizes the features before they reach the internal classifiers. It projects feature maps with different channel counts (64, 128, 256, 512) into a uniform space of 256 channels. It uses a convolution with kernel size 1, which changes the channel depth without altering the spatial height or width. Batch Normalization then stabilizes the learning process.
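
The effect of the 1x1 projection can be checked in isolation. The snippet below is a minimal sketch, not the notebook's own code: the `stage_shapes` list is a hypothetical stand-in for the outputs of `layer1` to `layer4`, with spatial sizes assumed from a stride-1 stem on 32x32 inputs.

```python
import torch
import torch.nn as nn

# Hypothetical (channels, spatial size) pairs for the four stages,
# assuming a stride-1 stem on 32x32 CIFAR-10 images.
stage_shapes = [(64, 32), (128, 16), (256, 8), (512, 4)]

projected_shapes = []
for in_ch, size in stage_shapes:
    # 1x1 convolution followed by batch norm, mirroring the description above
    proj = nn.Sequential(nn.Conv2d(in_ch, 256, kernel_size=1), nn.BatchNorm2d(256))
    x = torch.randn(2, in_ch, size, size)  # dummy batch of 2
    y = proj(x)
    projected_shapes.append(tuple(y.shape))
    print(f"{tuple(x.shape)} -> {tuple(y.shape)}")
```

Every stage ends up with 256 channels while its height and width stay unchanged, so only the spatial size still differs between stages.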

#### The `SelfDistillResNet18` Class
This class builds the main backbone and the four distillation branches. A custom stem (a single 3x3 convolution with stride 1) replaces the standard ResNet stem so that the small 32x32 CIFAR-10 images are not downsampled before the first residual stage. Layers 1-4 are the standard ResNet residual stages taken from the torchvision base model. The classifiers (`fc1` to `fc4`) are four linear layers that map the pooled 256-dimensional features to the 10 classes of CIFAR-10.

#### Forward Pass (`forward`)
The forward pass captures intermediate state at every stage. After the stem, the input $x$ passes through the four layers in sequence (f1 $\rightarrow$ f2 $\rightarrow$ f3 $\rightarrow$ f4), and each output is projected by its bottleneck. Global Average Pooling, `F.adaptive_avg_pool2d(h, 1)`, reduces the spatial dimensions to $1 \times 1$ before the result is flattened for the linear layer. The method returns a list of 4 predictions (logits) and a list of 4 feature maps, which `train_step` later uses to compute the distillation loss.
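
The pooling-and-flatten step can be tried on a dummy tensor. In this small sketch, `h` is a hypothetical bottleneck output and not a value taken from the model.

```python
import torch
import torch.nn.functional as F

h = torch.randn(8, 256, 4, 4)  # hypothetical bottleneck output for a batch of 8

pooled = F.adaptive_avg_pool2d(h, 1)  # shape (8, 256, 1, 1)
flat = pooled.flatten(1)              # shape (8, 256), ready for nn.Linear(256, 10)

# Global average pooling is the mean over height and width
same = torch.allclose(flat, h.mean(dim=(2, 3)), atol=1e-6)
print(pooled.shape, flat.shape, same)
```

Because the pooling target is a fixed $1 \times 1$ output, the same call works for all four stages regardless of their spatial size.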

In [6]:
class Bottleneck(nn.Module):
    def __init__(self, in_ch, out_ch=256):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        return self.bn(self.conv(x))

class SelfDistillResNet18(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        base = resnet18(weights=None)

        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.layer1 = base.layer1 # 64
        self.layer2 = base.layer2 # 128
        self.layer3 = base.layer3 # 256
        self.layer4 = base.layer4 # 512

        # Bottlenecks
        self.b1 = Bottleneck(64, 256)
        self.b2 = Bottleneck(128, 256)
        self.b3 = Bottleneck(256, 256)
        self.b4 = Bottleneck(512, 256)

        # Classifiers
        self.fc1 = nn.Linear(256, num_classes)
        self.fc2 = nn.Linear(256, num_classes)
        self.fc3 = nn.Linear(256, num_classes)
        self.fc4 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.stem(x)

        f1 = self.layer1(x)
        h1 = self.b1(f1)
        p1 = self.fc1(F.adaptive_avg_pool2d(h1,1).flatten(1))

        f2 = self.layer2(f1)
        h2 = self.b2(f2)
        p2 = self.fc2(F.adaptive_avg_pool2d(h2,1).flatten(1))

        f3 = self.layer3(f2)
        h3 = self.b3(f3)
        p3 = self.fc3(F.adaptive_avg_pool2d(h3,1).flatten(1))

        f4 = self.layer4(f3)
        h4 = self.b4(f4)
        p4 = self.fc4(F.adaptive_avg_pool2d(h4,1).flatten(1))

        return [p1, p2, p3, p4], [h1, h2, h3, h4]

In [4]:
CE = nn.CrossEntropyLoss()
KL = nn.KLDivLoss(reduction="batchmean")
MSE = nn.MSELoss()

T = 4.0
alpha = 0.7
beta = 0.3

def train_step(model, images, labels, optimizer):
    logits, feats = model(images)

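    teacher_logits = logits[-1] # The deepest output (index -1) acts as the Teacher
    # Note: teacher_logits is not detached, so the KL term also backpropagates into the deepest classifier.
    # Use logits[-1].detach() instead to treat the teacher distribution as a fixed target.
    teacher_feat = feats[-1].detach() # detach() stops the feature (MSE) loss from updating the teacher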

    total_loss = 0

    for i in range(4):
        student_logits = logits[i]
        student_feat = feats[i]

        loss_ce = CE(student_logits, labels)

        if i == 3: # The final layer only learns from labels
            total_loss += loss_ce
            continue

        # Forces early layers to mimic the teacher's probability distribution
        log_p = F.log_softmax(student_logits / T, dim=1)
        q = F.softmax(teacher_logits / T, dim=1)
        loss_kl = KL(log_p, q) * (T*T)

        # Forces early feature maps to align with the teacher's feature map
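        # The teacher map is spatially smaller than the student maps, so it is upsampled (align_corners=False is the default behaviour)
        teacher_resized = F.interpolate(teacher_feat, size=student_feat.shape[2:], mode="bilinear", align_corners=False)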
        loss_l2 = MSE(student_feat, teacher_resized)

        # Total loss is a weighted sum of CE, KL (alpha=0.7), and MSE (beta=0.3)
        total_loss += loss_ce + alpha * loss_kl + beta * loss_l2

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    return total_loss.item()
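
The role of the temperature `T` in the KL term can be seen on a toy example. The logits below are made up for illustration, and the manual computation is only a cross-check of what `nn.KLDivLoss(reduction="batchmean")` computes when given student log-probabilities and teacher probabilities.

```python
import torch
import torch.nn.functional as F

T = 4.0
teacher_logits = torch.tensor([[8.0, 2.0, 0.0]])  # made-up values
student_logits = torch.tensor([[5.0, 3.0, 1.0]])  # made-up values

sharp = F.softmax(teacher_logits, dim=1)     # T = 1: nearly one-hot
soft = F.softmax(teacher_logits / T, dim=1)  # T = 4: softer targets

log_p = F.log_softmax(student_logits / T, dim=1)
kl = F.kl_div(log_p, soft, reduction="batchmean") * (T * T)

# Manual KL(teacher || student), averaged over the batch, for comparison
manual = (soft * (soft.log() - log_p)).sum(dim=1).mean() * (T * T)

print(sharp, soft, kl.item(), manual.item())
```

Dividing by `T` spreads probability mass onto the non-maximal classes, which is the "soft target" signal the students learn from. Multiplying by `T*T` is the usual way to keep the gradient scale of the softened term comparable to the cross-entropy term.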

In [7]:
lr = 0.1
epochs = 200
model = SelfDistillResNet18().to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

start_time = time.time()

for epoch in range(epochs):

    model.train()
    total_epoch_loss = 0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        loss = train_step(model, images, labels, optimizer)
        total_epoch_loss += loss

    avg_loss = total_epoch_loss / len(trainloader)

    scheduler.step()
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            logits, _ = model(images)
            final_outputs = logits[-1]
            _, predicted = torch.max(final_outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100 * correct / total

    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch+1} | Current lr: {current_lr:.4f} | Loss: {avg_loss:.4f} | Accuracy: {acc:.2f}%")

total_time = time.time() - start_time
print(f"\nTotal training time: {total_time:.2f} seconds")

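import os
os.makedirs("saved_models", exist_ok=True)  # torch.save fails if the target directory does not exist
torch.save(model.state_dict(), "saved_models/self_distillation_resnet18.pt")
print("Model saved as saved_models/self_distillation_resnet18.pt")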


Epoch 1 | Current lr: 0.1000 | Loss: 8.0706 | Accuracy: 47.22%
Epoch 2 | Current lr: 0.1000 | Loss: 6.4416 | Accuracy: 47.32%
Epoch 3 | Current lr: 0.0999 | Loss: 5.6949 | Accuracy: 61.76%
Epoch 4 | Current lr: 0.0999 | Loss: 5.2692 | Accuracy: 65.15%
Epoch 5 | Current lr: 0.0998 | Loss: 4.9584 | Accuracy: 53.73%
Epoch 6 | Current lr: 0.0998 | Loss: 4.7351 | Accuracy: 45.69%
Epoch 7 | Current lr: 0.0997 | Loss: 4.6063 | Accuracy: 63.52%
Epoch 8 | Current lr: 0.0996 | Loss: 4.4620 | Accuracy: 66.91%
Epoch 9 | Current lr: 0.0995 | Loss: 4.3883 | Accuracy: 65.11%
Epoch 10 | Current lr: 0.0994 | Loss: 4.2821 | Accuracy: 73.42%
Epoch 11 | Current lr: 0.0993 | Loss: 4.2452 | Accuracy: 63.83%
Epoch 12 | Current lr: 0.0991 | Loss: 4.1969 | Accuracy: 70.49%
Epoch 13 | Current lr: 0.0990 | Loss: 4.1165 | Accuracy: 68.85%
Epoch 14 | Current lr: 0.0988 | Loss: 4.1103 | Accuracy: 71.05%
Epoch 15 | Current lr: 0.0986 | Loss: 4.0584 | Accuracy: 71.69%
Epoch 16 | Current lr: 0.0984 | Loss: 4.0301 | Ac