In [4]:
import os
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [6]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Training-time preprocessing: random 32x32 crop after 4-pixel padding, random
# horizontal flip, conversion to a tensor, then (x - 0.5) / 0.5 normalisation.
# The single-element mean/std are expected to broadcast over all channels.
transform = transforms.Compose(
    [
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# CIFAR-10 training split, downloaded into ./data when it is not there yet.
trainset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 128 samples, loaded by two worker processes.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

In [7]:
class Bottleneck(nn.Module):
    """Project a feature map to ``out_ch`` channels with a 1x1 conv and batch norm.

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels produced by the projection (default 256).
    """

    def __init__(self, in_ch, out_ch=256):
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable.
        self.conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1)
        self.bn = nn.BatchNorm2d(num_features=out_ch)

    def forward(self, x):
        """Return the batch-normalised 1x1 projection of ``x``."""
        projected = self.conv(x)
        normalised = self.bn(projected)
        return normalised

class SelfDistillResNet18(nn.Module):
    """ResNet-18 residual stages with one auxiliary classifier per stage.

    The input goes through a 3x3 convolutional stem and then through the four
    residual stages taken from torchvision's ``resnet18``. After every stage a
    ``Bottleneck`` projects the stage output to 256 channels; the projection is
    globally average-pooled and classified by a dedicated linear layer.

    Args:
        num_classes: Number of classes predicted by every classifier head.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Only layer1..layer4 of the torchvision model are reused; the rest of
        # `base` is discarded. `pretrained=` has been deprecated since
        # torchvision 0.13 in favour of `weights=`; fall back to the old
        # keyword on releases that do not know `weights`.
        try:
            base = resnet18(weights=None)
        except TypeError:
            base = resnet18(pretrained=False)

        # Stride-1 3x3 stem with no pooling, replacing the torchvision stem.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.layer1 = base.layer1 # 64
        self.layer2 = base.layer2 # 128
        self.layer3 = base.layer3 # 256
        self.layer4 = base.layer4 # 512

        # Bottlenecks: bring every stage output to a common 256 channels.
        self.b1 = Bottleneck(64, 256)
        self.b2 = Bottleneck(128, 256)
        self.b3 = Bottleneck(256, 256)
        self.b4 = Bottleneck(512, 256)

        # Classifiers: one linear head per exit.
        self.fc1 = nn.Linear(256, num_classes)
        self.fc2 = nn.Linear(256, num_classes)
        self.fc3 = nn.Linear(256, num_classes)
        self.fc4 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Run all four exits.

        Returns:
            A pair ``(logits, feats)`` of lists ordered from the shallowest to
            the deepest exit: ``logits[i]`` are the class scores of exit ``i``
            and ``feats[i]`` is the 256-channel bottleneck output that exit
            pooled and classified.
        """
        exits = (
            (self.layer1, self.b1, self.fc1),
            (self.layer2, self.b2, self.fc2),
            (self.layer3, self.b3, self.fc3),
            (self.layer4, self.b4, self.fc4),
        )

        feature = self.stem(x)
        logits, feats = [], []
        for stage, bottleneck, classifier in exits:
            # Each stage consumes the raw output of the previous stage, not
            # the bottleneck projection.
            feature = stage(feature)
            hint = bottleneck(feature)
            pooled = F.adaptive_avg_pool2d(hint, 1).flatten(1)
            logits.append(classifier(pooled))
            feats.append(hint)

        return logits, feats

In [8]:
# Loss criteria shared by every training step.
CE = nn.CrossEntropyLoss()
KL = nn.KLDivLoss(reduction="batchmean")
MSE = nn.MSELoss()

T = 4.0      # softmax temperature applied to student and teacher logits
alpha = 0.7  # weight of the KL (soft-label) distillation term
beta = 0.3   # weight of the feature-matching MSE term

def train_step(model, images, labels, optimizer):
    """Run one self-distillation optimisation step and return the scalar loss.

    ``model(images)`` must return ``(logits, feats)``, two equally long lists
    ordered from the shallowest to the deepest exit. The deepest exit acts as
    the teacher and is trained with cross-entropy only. Every shallower exit
    is trained with cross-entropy, plus ``alpha`` times the temperature-scaled
    KL divergence to the teacher's softened predictions, plus ``beta`` times
    the MSE between its feature map and the teacher's feature map resized to
    the same spatial size.

    Args:
        model: Module returning ``(logits, feats)`` as described above.
        images: Input batch passed to ``model``.
        labels: Ground-truth class indices for ``images``.
        optimizer: Optimizer holding the parameters of ``model``.

    Returns:
        The summed loss over all exits as a Python float.
    """
    logits, feats = model(images)

    # The teacher provides fixed targets. Both its feature map and its logits
    # are detached so the distillation terms only update the student exits
    # and never pull the teacher towards its students.
    teacher_feat = feats[-1].detach()
    soft_targets = F.softmax(logits[-1].detach() / T, dim=1)

    teacher_idx = len(logits) - 1
    total_loss = 0

    for i, (student_logits, student_feat) in enumerate(zip(logits, feats)):
        loss_ce = CE(student_logits, labels)

        if i == teacher_idx:
            # Deepest exit: supervised by the labels only.
            total_loss = total_loss + loss_ce
            continue

        # T*T compensates for the 1/T^2 gradient scaling of softened logits.
        log_p = F.log_softmax(student_logits / T, dim=1)
        loss_kl = KL(log_p, soft_targets) * (T * T)

        # Match spatial sizes before comparing feature maps; align_corners is
        # spelled out to pin the default behaviour.
        teacher_resized = F.interpolate(
            teacher_feat,
            size=student_feat.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        loss_l2 = MSE(student_feat, teacher_resized)

        total_loss = total_loss + loss_ce + alpha * loss_kl + beta * loss_l2

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    return total_loss.item()

In [9]:
lr = 0.05
epochs = 100
save_path = "pytorch/saved_models/cifar_10.pt"

# Create the checkpoint directory before training so a missing folder cannot
# make torch.save fail after the whole run has finished.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

model = SelfDistillResNet18().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
start_time = time.time()

for epoch in range(epochs):

    # Step schedule: halve the learning rate at the start of epochs 40 and 80.
    if epoch in (40, 80):
        lr /= 2
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print("Learning rate changed to ", lr)

    # ---- optimisation pass ----
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        running_loss += train_step(model, images, labels, optimizer)
        num_batches += 1
    # Report the mean loss over the epoch rather than the last batch's loss.
    loss = running_loss / max(num_batches, 1)

    # ---- accuracy of the deepest exit ----
    # NOTE(review): this iterates `trainloader`, so the figure is accuracy on
    # the randomly augmented training split, not on held-out data.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            logits, _ = model(images)
            predicted = logits[-1].argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100 * correct / max(total, 1)

    print(f"Epoch {epoch} | Loss: {loss:.4f} | Accuracy: {acc:.2f}%")

total_time = time.time() - start_time
print(f"\nTotal training time: {total_time:.2f} seconds")

torch.save(model.state_dict(), save_path)
print(f"Model saved as {save_path}")




KeyboardInterrupt: 