This file trains the baseline teacher and student models without label smoothing.

In [None]:
# Mount Google Drive at /content/gdrive. The training cells below write their
# loss curves and model checkpoints under '/content/gdrive/My Drive/'.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
# Device that models and batches are moved to. Prefer the GPU, but fall back
# to the CPU so the notebook still runs (slowly) on a runtime without CUDA
# instead of failing on the first .to(device) call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Mini-batch size of a single forward/backward pass. Gradients are accumulated
# over iter_batch mini-batches, so the optimizer steps once every
# batch_size * iter_batch = 60 samples while only 4 samples are processed at a
# time to save memory.
batch_size = 4
# Number of mini-batches whose gradients are accumulated per optimizer step.
iter_batch = 15
# Number of passes over the training set, for both the teacher and the student.
epochs = 80
# Output size of the replacement fully connected layer built in make_model().
num_classes = 37
# Weight of the teacher (distillation) loss term in the student objective.
kd = 1
# Initial AdamW learning rate; also the per-update increment during warm-up.
base_lr = 0.00002
# Weight decay passed to AdamW.
weight_decay = 0.001
# Default dropout probability used by make_model().
dropout_p = 0.1
# Prefix of the curve and checkpoint files written to Google Drive.
name = 'classicnosmooth'

In [None]:
# Use the preprocessing pipeline that ships with the pretrained ResNet-50 weights.
weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
T = weights.transforms()

# Download Oxford-IIIT Pet and carve out a fixed (seeded) 3000/680 train/val split.
data = torchvision.datasets.OxfordIIITPet(root='OxfordIIITPet', transform=T, download=True)
train, val = torch.utils.data.random_split(
    dataset=data,
    lengths=[3000, 680],
    generator=torch.Generator().manual_seed(42),
)

# Both loaders shuffle and share the same mini-batch size.
train_loader, val_loader = (
    torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True)
    for subset in (train, val)
)

In [None]:
def dropout(p):
    """Return a ReLU followed by dropout with drop probability ``p``."""
    return nn.Sequential(nn.ReLU(), nn.Dropout(p=p))

def set_dropout(model, p):
    """Recursively replace every ``nn.ReLU`` in ``model`` by ReLU + dropout.

    Each ReLU child is swapped in place, under its own attribute name, for the
    module returned by ``dropout(p)``. All other children are visited
    recursively. The model is modified in place and nothing is returned.

    Args:
        model: The ``nn.Module`` whose activations are rewritten.
        p: Dropout probability passed to ``nn.Dropout``.
    """
    # Snapshot the children so that re-registering a module below can never
    # disturb the iteration.
    for child_name, child in list(model.named_children()):
        if isinstance(child, nn.ReLU):
            # Replace under the child's real name, so this also works when the
            # activation is not stored in an attribute literally called 'relu'
            # (e.g. numbered entries of an nn.Sequential).
            setattr(model, child_name, dropout(p))
        else:
            set_dropout(child, p)
def make_model(model_name='', p=dropout_p):
    """Build the ResNet-50 classifier used for both teacher and student.

    Starts from the 'IMAGENET1K_V1' weights, swaps the final layer for a
    bias-free linear layer with ``num_classes`` outputs, inserts dropout after
    every ReLU via ``set_dropout`` and moves the network to ``device``.

    Args:
        model_name: File name of a state dict under '/content/gdrive/My Drive/'
            to load into the network; the empty string skips loading.
        p: Dropout probability handed to ``set_dropout``.

    Returns:
        The network, on ``device`` and switched to evaluation mode.
    """
    net = torchvision.models.resnet50(weights='IMAGENET1K_V1')
    net.fc = nn.Linear(in_features=2048, out_features=num_classes, bias=False)
    set_dropout(net, p)
    net = net.to(device)

    if model_name != '':
        checkpoint = torch.load(f'/content/gdrive/My Drive/{model_name}')
        net.load_state_dict(checkpoint)

    net.eval()
    return net

In [None]:
# Train the teacher
#
# Fine-tunes a fresh model with cross-entropy on the hard labels. Gradients are
# accumulated over iter_batch mini-batches per optimizer step. After every
# epoch the mean train/validation loss is appended to the curve file and a
# checkpoint is saved to Google Drive.

model = make_model()
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
loss_fn = torch.nn.CrossEntropyLoss()
num_update = 1  # 1-based index of the next optimizer step; drives the LR schedule
curve = []      # one [train loss, validation loss] row per epoch
optimizer.zero_grad()

for epoch in range(epochs):
    train_likelihood = 0.0
    val_likelihood = 0.0

    model.train()
    for j, (batch, label) in enumerate(train_loader):
        batch = batch.to(device)
        label = label.to(device)
        loss = loss_fn(model(batch), label)
        # Scale so the accumulated gradient is the mean, not the sum, over the
        # iter_batch mini-batches that make up one optimizer step.
        (loss / iter_batch).backward()
        train_likelihood += float(loss)

        if (j + 1) % iter_batch == 0:
            # Linear warm-up for the first 49 updates, then inverse-square-root
            # decay starting from 0.001.
            if num_update < 50:
                lr = base_lr * num_update
            else:
                lr = 0.001 / np.sqrt(num_update - 49)
            for g in optimizer.param_groups:
                g['lr'] = lr
            optimizer.step()
            optimizer.zero_grad()
            num_update += 1

    model.eval()
    # Validation needs no gradients; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for batch, label in val_loader:
            batch = batch.to(device)
            label = label.to(device)
            val_likelihood += float(loss_fn(model(batch), label))

    # Mean per-mini-batch losses, normalised by the actual loader lengths
    # rather than hard-coded batch counts.
    stat = [train_likelihood / len(train_loader), val_likelihood / len(val_loader)]
    curve.append(stat)

    np.savetxt(f'/content/gdrive/My Drive/{name}_curve.txt', np.array(curve))
    torch.save(model.state_dict(), f'/content/gdrive/My Drive/{name}_{epoch}.mdl')

In [None]:
# Compute the teacher predictions for self-distillation

def teacher_loss_fn(outputs, teacher_outputs):
    """Return the mean cross-entropy between teacher and student distributions.

    Both arguments are logits; softmax is taken over the last dimension.

    Args:
        outputs: Student logits.
        teacher_outputs: Teacher logits, same shape as ``outputs``.

    Returns:
        Scalar tensor ``-mean(sum(softmax(teacher) * log_softmax(student)))``.
    """
    p = torch.softmax(teacher_outputs, dim=-1)
    # log_softmax instead of log(softmax(...)): a student probability that
    # underflows to 0 would otherwise give log(0) = -inf and a NaN loss
    # through 0 * -inf.
    log_q = torch.log_softmax(outputs, dim=-1)
    return -torch.mean(torch.sum(p * log_q, dim=-1))

# Reload the teacher checkpoint from the last epoch; make_model() returns it
# in evaluation mode.
model = make_model(f'{name}_{epochs-1}.mdl')
next_data = []

# Cache the teacher logits for every training image. This is inference only,
# so run it under no_grad to avoid building an autograd graph per image.
with torch.no_grad():
    for img, label in train:
        # Add a batch dimension for the forward pass, then strip it again.
        output = model(img.to(device)[None, ...]).to('cpu')[0]
        next_data.append((img, output, label))
next_train_dataloader = torch.utils.data.DataLoader(next_data, batch_size=batch_size, shuffle=True)

In [None]:
# Train the student
#
# Same optimisation recipe as the teacher, but the objective adds kd times the
# teacher loss to the hard-label cross-entropy. From epoch 75 onwards only the
# hard-label loss is optimised. After every epoch the mean losses are appended
# to the curve file and a checkpoint is saved to Google Drive.

model = make_model()
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
loss_fn = torch.nn.CrossEntropyLoss()
num_update = 1  # 1-based index of the next optimizer step; drives the LR schedule
curve = []      # one [train loss, teacher loss, validation loss] row per epoch
optimizer.zero_grad()

for epoch in range(epochs):
    train_likelihood = 0.0
    train_teacher = 0.0
    val_likelihood = 0.0

    model.train()
    for j, (img, teacher, label) in enumerate(next_train_dataloader):
        img = img.to(device)
        teacher = teacher.to(device)
        label = label.to(device)
        output = model(img)
        intrinsic_loss = loss_fn(output, label)
        teacher_loss = teacher_loss_fn(output, teacher)
        train_likelihood += float(intrinsic_loss)
        train_teacher += float(teacher_loss)
        if epoch >= 75:
            loss = intrinsic_loss
        else:
            loss = intrinsic_loss + kd * teacher_loss
        # Scale so the accumulated gradient is the mean over the iter_batch
        # mini-batches of one optimizer step, as in the teacher loop.
        (loss / iter_batch).backward()

        if (j + 1) % iter_batch == 0:
            # Linear warm-up for the first 49 updates, then inverse-square-root
            # decay starting from 0.001.
            if num_update < 50:
                lr = base_lr * num_update
            else:
                lr = 0.001 / np.sqrt(num_update - 49)
            for g in optimizer.param_groups:
                g['lr'] = lr
            optimizer.step()
            optimizer.zero_grad()
            num_update += 1

    model.eval()
    # Validation needs no gradients; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for batch, label in val_loader:
            batch = batch.to(device)
            label = label.to(device)
            val_likelihood += float(loss_fn(model(batch), label))

    # Report every loss as a mean per mini-batch. Both training losses use the
    # same normaliser, so the teacher loss is on the same scale as the
    # hard-label loss.
    n_train = len(next_train_dataloader)
    n_val = len(val_loader)
    stat = [train_likelihood / n_train, train_teacher / n_train, val_likelihood / n_val]
    curve.append(stat)

    np.savetxt(f'/content/gdrive/My Drive/{name}r1_curve.txt', np.array(curve))
    torch.save(model.state_dict(), f'/content/gdrive/My Drive/{name}r1_{epoch}.mdl')