In [1]:
import copy
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_epochs = 1
img_size = 224
batch_size = 64
num_classes = 10
lr = 3e-4
alpha = 0.1
temp = 3
T = transforms.Compose(
    [
     transforms.Resize((img_size, img_size)),
     transforms.ToTensor()
    ]
)
print(device)

cuda


In [3]:
train_data = datasets.CIFAR10("data/", train=True, download=True, transform=T)
val_data = datasets.CIFAR10("data/", train=False, download=True, transform=T)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
x, y = next(iter(train_loader))
print(len(train_data), x.shape, y.shape)

Files already downloaded and verified
Files already downloaded and verified
50000 torch.Size([64, 3, 224, 224]) torch.Size([64])


In [4]:
teacher = models.resnet50(num_classes=num_classes).to(device) # big model
student = models.resnet18(num_classes=num_classes).to(device) # small model
init_student = copy.deepcopy(student)

In [5]:
ce_loss_fn = nn.CrossEntropyLoss()
kld_loss_fn = nn.KLDivLoss(reduction='batchmean')
def get_accuracy(preds, y):
    preds = preds.argmax(dim=1, keepdim=True)
    correct = preds.squeeze(1).eq(y)
    acc = correct.sum() / torch.FloatTensor([y.shape[0]]).to(device)
    return acc

In [6]:
def pretrain_loop(net, loader, is_train, optimizer=None):
    net.train(is_train)
    losses = []
    accs = []
    pbar = tqdm(loader, total=len(loader))
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
        with torch.set_grad_enabled(is_train):
            preds = net(x)
            loss = ce_loss_fn(preds, y)
            acc = get_accuracy(preds, y)
            losses.append(loss.item())
            accs.append(acc.item())

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        pbar.set_description(f'epoch={epoch}, train={int(is_train)}')
        pbar.set_postfix(loss=f'{np.mean(losses):.4f}', acc=f'{np.mean(accs):.4f}')

def distill_loop(teacher, student, loader, is_train, optimizer=None):
    teacher.eval()
    student.train(is_train)
    losses = []
    accs = []
    pbar = tqdm(loader, total=len(loader))
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            teacher_preds = teacher(x)
            
        with torch.set_grad_enabled(is_train):
            student_preds = student(x)
            student_loss = ce_loss_fn(student_preds, y)
            acc = get_accuracy(student_preds, y)
            student_preds = (student_preds / temp).softmax(-1)
            teacher_preds = (teacher_preds / temp).softmax(-1)
            distillation_loss = kld_loss_fn(student_preds, teacher_preds)
            loss = alpha * student_loss + (1 - alpha) * distillation_loss
            losses.append(loss.item())
            accs.append(acc.item())

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        pbar.set_description(f'epoch={epoch}, train={int(is_train)}')
        pbar.set_postfix(loss=f'{np.mean(losses):.4f}', acc=f'{np.mean(accs):.4f}')

In [7]:
# training the teacher model
teacher_optimizer = torch.optim.Adam(teacher.parameters(), lr=lr)
for epoch in range(n_epochs):
    pretrain_loop(teacher, train_loader, True, teacher_optimizer)
    pretrain_loop(teacher, val_loader, False)

epoch=0, train=1: 100%|██████████| 782/782 [08:32<00:00,  1.53it/s, acc=0.4597, loss=1.4821]
epoch=0, train=0: 100%|██████████| 157/157 [00:31<00:00,  4.92it/s, acc=0.5373, loss=1.3888]


In [8]:
# training a clone of the student model, just to compare this with the distilled one
init_student_optimizer = torch.optim.Adam(init_student.parameters(), lr=lr)
for epoch in range(n_epochs):
    pretrain_loop(init_student, train_loader, True, init_student_optimizer)
    pretrain_loop(init_student, val_loader, False)

epoch=0, train=1: 100%|██████████| 782/782 [02:33<00:00,  5.11it/s, acc=0.5580, loss=1.2202]
epoch=0, train=0: 100%|██████████| 157/157 [00:16<00:00,  9.34it/s, acc=0.6309, loss=1.0518]


In [9]:
# freezing the teacher model's paramters for distillation process
for params in teacher.parameters():
    params.requires_grad = False

# distilling the teacher model
distill_optimizer = torch.optim.Adam(student.parameters(), lr=lr)
for epoch in range(n_epochs):
    distill_loop(teacher, student, train_loader, True, distill_optimizer)
    distill_loop(teacher, student, val_loader, False)

epoch=0, train=1: 100%|██████████| 782/782 [05:10<00:00,  2.52it/s, acc=0.5448, loss=-1.7678]
epoch=0, train=0: 100%|██████████| 157/157 [00:41<00:00,  3.74it/s, acc=0.6291, loss=-1.8027]


In [None]:
# now the distilled student model is just as good as the pretrained teacher model