This file trains the intra-distillation teacher and student models without 
label smoothing.

In [None]:
from google.colab import drive
# Mount Google Drive; the cells below write curves and checkpoints under
# '/content/gdrive/My Drive/'.
drive.mount('/content/gdrive')

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing on the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Hyperparameters shared by the teacher and student training cells.

# Number of forward passes made on every training batch.
K = 3
# p and q shape the warm-up of the intra-distillation weight in _get_alpha:
# the weight reaches `alpha` once num_update >= N / p.
p = 5
q = 10
# Final weight of the intra-distillation term in the training loss.
alpha = 5
# Passed to _get_alpha as max_update.
N = 4000
# Output size of the replacement fully connected layer.
num_classes = 37
batch_size = 4 # Gradients are updated every 60 samples (batch_size * iter_batch), but only 4 samples are run at a time to save memory.
# Number of mini-batches whose gradients are accumulated per optimizer step.
iter_batch = 15
epochs = 80
# Weight of the teacher (knowledge-distillation) term in the student loss.
kd = 1
# Initial AdamW learning rate; also the slope of the linear warm-up.
base_lr = 0.00002
weight_decay = 0.001
# Default dropout probability used by make_model.
dropout_p = 0.1
# Prefix for the curve and checkpoint files written to Google Drive.
name = 'intranosmooth'

In [None]:
# Use the preprocessing pipeline that belongs to the pretrained ResNet-50 weights.
weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
data = torchvision.datasets.OxfordIIITPet(
    'OxfordIIITPet', transform=weights.transforms(), download=True)
# A fixed generator seed keeps the 3000/680 train/validation split identical
# across runs.
train, val = torch.utils.data.random_split(
    data, [3000, 680], generator=torch.Generator().manual_seed(42))
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
# The validation loss is a plain sum over batches, so order is irrelevant;
# not shuffling avoids consuming global RNG state during evaluation.
val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=False)

In [None]:
def dropout(p):
    """Return a ReLU followed by dropout with probability ``p``."""
    return nn.Sequential(nn.ReLU(), nn.Dropout(p=p))


def set_dropout(model, p):
    """Recursively replace every ``nn.ReLU`` child of ``model`` with ReLU + dropout.

    The module tree is modified in place.

    Args:
        model: the ``nn.Module`` whose ReLU children should be replaced.
        p: dropout probability passed to ``nn.Dropout``.
    """
    # Snapshot the children so replacing entries cannot disturb the iteration.
    for child_name, child in list(model.named_children()):
        if isinstance(child, nn.ReLU):
            # Replace the child under its own registered name rather than a
            # hard-coded `relu` attribute, so ReLUs stored under any name
            # (e.g. numeric keys inside nn.Sequential) are handled as well.
            setattr(model, child_name, dropout(p))
        else:
            set_dropout(child, p)
def make_model(model_name='', p=dropout_p):
    """Build the ResNet-50 classifier used for both teacher and student.

    Starts from ImageNet-pretrained weights, replaces the final layer with a
    bias-free linear layer of ``num_classes`` outputs, inserts dropout after
    every ReLU and moves the model to ``device``.

    Args:
        model_name: file name of a saved state dict under
            '/content/gdrive/My Drive/'. When empty, no checkpoint is loaded.
        p: dropout probability inserted after each ReLU.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    model = torchvision.models.resnet50(weights='IMAGENET1K_V1')
    model.fc = torch.nn.Linear(2048, num_classes, bias=False)
    set_dropout(model, p)
    model = model.to(device)
    if model_name:
        # map_location lets a checkpoint saved on one device be restored on
        # whichever device is in use now.
        state = torch.load(f'/content/gdrive/My Drive/{model_name}', map_location=device)
        model.load_state_dict(state)
    model.eval()
    return model
# Build an initial model; the training cells below rebuild `model` before using it.
model = make_model()

In [None]:
def X_loss(distributions, target_mask):
    """Average symmetric divergence between each distribution and their mean.

    With ``m`` the element-wise mean of ``distributions`` over dim 0, this
    accumulates ``sum((l - m) * (log l - log m))`` for every slice ``l`` along
    dim 0, restricted to the entries selected by ``target_mask``, and divides
    the total by the number of slices.

    Args:
        distributions: tensor of probabilities whose first dimension indexes
            the distributions being compared and whose last dimension is the
            class dimension.
        target_mask: boolean mask used to index each slice after it has been
            reshaped to ``(-1, num_classes)``.

    Returns:
        A scalar tensor with the averaged divergence.
    """
    dict_size = distributions.shape[-1]
    # Probabilities that underflowed to exactly 0 would give log(0) = -inf and
    # turn the loss into NaN; clamping to the smallest positive normal float
    # leaves every other value untouched.
    eps = torch.finfo(torch.float32).tiny

    mean = torch.sum(distributions, dim=0) / distributions.shape[0]
    mean = mean.float().view(-1, dict_size)[target_mask]
    log_mean = torch.log(torch.clamp(mean, min=eps))

    kl_all = 0
    for dist in distributions:
        dist = dist.float().view(-1, dict_size)[target_mask]
        log_dist = torch.log(torch.clamp(dist, min=eps))
        kl_all += ((dist - mean) * (log_dist - log_mean)).sum()
    return kl_all / distributions.shape[0]

def _get_alpha(alpha, num_update, max_update, p, q):
    if num_update >= max_update / p or alpha <= 1:
        return alpha
    else:
        alpha = torch.tensor([alpha])
        gamma = torch.log(1/alpha) / torch.log(torch.tensor([p/q])) # log_(p/q)(1/alpha)
        new_alpha = ( p**gamma * alpha * num_update ** gamma) / (max_update ** gamma)
        return new_alpha.item()

def IntraDistillationLoss(N, p, q, alpha, label_smoothing=0.0):
    """Build the training loss: cross-entropy plus a weighted intra-distillation term.

    Args:
        N: total number of updates, forwarded to ``_get_alpha`` as ``max_update``.
        p, q: schedule shape parameters forwarded to ``_get_alpha``.
        alpha: final weight of the intra-distillation term.
        label_smoothing: label smoothing for the cross-entropy term. Defaults
            to 0.0 because this file is the variant without label smoothing.

    Returns:
        A function ``loss_function(outputs, true, num_update)`` where
        ``outputs`` holds the logits of several forward passes stacked along
        dim 0 and ``true`` holds the labels of one pass. It returns the tuple
        ``(total_loss, likelihood_loss, intra_loss)``.
    """
    cross_entropy = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    def loss_function(outputs, true, num_update):
        num_passes = outputs.shape[0]
        num_classes = outputs.shape[-1]
        # Select every entry. Build the mask from the actual batch shape and
        # on the same device as the logits instead of relying on the global
        # batch size and a CPU tensor.
        mask = torch.ones(outputs.shape[1:], dtype=torch.bool, device=outputs.device)
        probabilities = torch.softmax(outputs, -1)
        intra_loss = X_loss(probabilities, mask)
        # The labels are the same for every pass, so tile them to line up with
        # the passes flattened into one batch.
        expanded_true = true.repeat(num_passes)
        expanded_outputs = outputs.reshape(-1, num_classes)
        likelihood_loss = cross_entropy(expanded_outputs, expanded_true)
        adaptive_alpha = _get_alpha(alpha, num_update, N, p, q)
        total_loss = likelihood_loss + adaptive_alpha * intra_loss
        return total_loss, likelihood_loss, intra_loss

    return loss_function

In [None]:
# Train the teacher

model = make_model()
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
intra_distillation_loss_fn = IntraDistillationLoss(N, p, q, alpha)
val_loss_fn = torch.nn.CrossEntropyLoss()
num_update = 1
curve = []
optimizer.zero_grad()

# Normalizers that turn the accumulated sums into per-batch means; derived
# from the loaders so they stay correct if the batch or split sizes change.
train_norm = len(train_loader) / iter_batch
val_norm = len(val_loader)

for epoch in range(epochs):
    train_likelihood = 0
    train_intra = 0
    val_likelihood = 0

    model.train()
    for j, (batch, label) in enumerate(train_loader):
        batch = batch.to(device)
        label = label.to(device)
        # K forward passes of the same batch in train mode (dropout active),
        # stacked along a new leading dimension.
        outputs = torch.stack([model(batch) for _ in range(K)])
        loss, likelihood_loss, intra_loss = intra_distillation_loss_fn(outputs, label, num_update)
        # Scale so that the gradients accumulated over iter_batch mini-batches
        # average rather than sum.
        (loss/iter_batch).backward()
        train_likelihood += float(likelihood_loss)/iter_batch
        train_intra += float(intra_loss)/iter_batch
        if (j+1) % iter_batch == 0:
            # Linear warm-up for the first updates, then inverse-square-root decay.
            for g in optimizer.param_groups:
                if num_update < 50:
                    g['lr'] = base_lr * num_update
                else:
                    g['lr'] = 0.001 / np.sqrt(num_update-49)
            optimizer.step()
            optimizer.zero_grad()
            num_update += 1

    model.eval()
    # No gradients are needed for validation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for batch, label in val_loader:
            batch = batch.to(device)
            label = label.to(device)
            output = model(batch)
            loss = val_loss_fn(output, label)
            val_likelihood += float(loss)

    stat = [train_likelihood/train_norm, train_intra/train_norm, val_likelihood/val_norm]
    curve.append(stat)

    # Rewrite the full curve and save a checkpoint after every epoch.
    np.savetxt(f'/content/gdrive/My Drive/{name}_curve.txt', np.array(curve))
    torch.save(model.state_dict(), f'/content/gdrive/My Drive/{name}_{epoch}.mdl')

In [None]:
# Compute the teacher predictions for self-distillation

def teacher_loss_fn(outputs, teacher_outputs):
    """Cross-entropy of the student predictions against the teacher's soft targets.

    Args:
        outputs: student logits; the last dimension is the class dimension.
        teacher_outputs: teacher logits, broadcastable against ``outputs``.

    Returns:
        Scalar tensor: the mean over all leading dimensions of
        ``-sum(softmax(teacher) * log_softmax(student))``.
    """
    teacher_probs = torch.softmax(teacher_outputs, dim=-1)
    # log_softmax is the numerically stable form of log(softmax(x)); the naive
    # form returns -inf once a probability underflows to zero.
    student_log_probs = torch.log_softmax(outputs, dim=-1)
    return -torch.mean(torch.sum(teacher_probs * student_log_probs, dim=-1))

# Checkpoints are saved as '{name}_{epoch}.mdl' with epoch in range(epochs),
# so the final teacher checkpoint has index epochs - 1.
model = make_model(f'{name}_{epochs - 1}.mdl')
next_data = []

# make_model returns the network in eval mode; no gradients are needed to
# record the teacher's logits.
with torch.no_grad():
    for img, label in train:
        # Add a batch dimension for the forward pass and strip it again.
        output = model(img.to(device)[None, ...]).to('cpu')[0]
        next_data.append((img, output, label))
next_train_dataloader = torch.utils.data.DataLoader(next_data, batch_size=batch_size, shuffle=True)

In [None]:
# Train the student

model = make_model()
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
intra_distillation_loss_fn = IntraDistillationLoss(N, p, q, alpha)
val_loss_fn = torch.nn.CrossEntropyLoss()
num_update = 1
curve = []
optimizer.zero_grad()

# Normalizers that turn the accumulated sums into per-batch means; derived
# from the loaders so they stay correct if the batch or split sizes change.
train_norm = len(next_train_dataloader) / iter_batch
val_norm = len(val_loader)

for epoch in range(epochs):
    train_likelihood = 0
    train_intra = 0
    train_teacher = 0
    val_likelihood = 0

    model.train()
    for j, (img, teacher, label) in enumerate(next_train_dataloader):
        img = img.to(device)
        teacher = teacher.to(device)
        label = label.to(device)
        # K forward passes of the same batch in train mode (dropout active),
        # stacked along a new leading dimension.
        outputs = torch.stack([model(img) for _ in range(K)])
        intrinsic_loss, likelihood_loss, intra_loss = intra_distillation_loss_fn(outputs, label, num_update)
        # Always computed so it can be logged, even when it is not optimized.
        teacher_loss = teacher_loss_fn(outputs, teacher)
        if epoch >= 75:
            # From epoch 75 on the teacher term is dropped from the objective.
            loss = intrinsic_loss
        else:
            loss = intrinsic_loss + kd * teacher_loss
        # Scale so that the gradients accumulated over iter_batch mini-batches
        # average rather than sum.
        (loss/iter_batch).backward()
        train_likelihood += float(likelihood_loss)/iter_batch
        train_intra += float(intra_loss)/iter_batch
        train_teacher += float(teacher_loss)/iter_batch
        if (j+1) % iter_batch == 0:
            # Linear warm-up for the first updates, then inverse-square-root decay.
            for g in optimizer.param_groups:
                if num_update < 50:
                    g['lr'] = base_lr * num_update
                else:
                    g['lr'] = 0.001 / np.sqrt(num_update-49)
            optimizer.step()
            optimizer.zero_grad()
            num_update += 1

    model.eval()
    # No gradients are needed for validation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for batch, label in val_loader:
            batch = batch.to(device)
            label = label.to(device)
            output = model(batch)
            loss = val_loss_fn(output, label)
            val_likelihood += float(loss)

    stat = [train_likelihood/train_norm, train_intra/train_norm, train_teacher/train_norm, val_likelihood/val_norm]
    curve.append(stat)

    # Rewrite the full curve and save a checkpoint after every epoch.
    np.savetxt(f'/content/gdrive/My Drive/{name}r1_curve.txt', np.array(curve))
    torch.save(model.state_dict(), f'/content/gdrive/My Drive/{name}r1_{epoch}.mdl')