In [None]:
from google.colab import drive
# Mount Google Drive so the training cells below can write loss curves and
# checkpoints under '/content/gdrive/My Drive/'.
drive.mount(mountpoint='/content/gdrive')

Mounted at /content/gdrive


In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
# Prefer the GPU; fall back to CPU so the notebook still runs (slowly) without one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Hyper-parameters shared by every cell below.
SEED = 42           # global seed: fc-head init, dropout masks, DataLoader shuffling
torch.manual_seed(SEED)
np.random.seed(SEED)

K = 3               # stochastic (dropout) forward passes per batch, compared by X_loss
p = 5               # _get_alpha schedule: weight reaches `alpha` at update N / p
q = 10              # _get_alpha schedule: weight reaches 1 at update N / q
alpha = 1.5         # final weight of the X_loss consistency term
N = 4000            # schedule horizon; matches 80 epochs * 3000 / (batch_size * iter_batch) updates
num_classes = 37    # outputs of the classifier head
batch_size = 4
iter_batch = 15     # mini-batches whose gradients are accumulated per optimizer update
mask_type = 'full'  # X_loss mask: 'full', 'one_hot' or 'none'

In [None]:
# Data: the dataset's default split is re-split into train/val with a fixed
# generator (identical split every run); the 'test' split is held out.
weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
T = weights.transforms()  # preprocessing that matches the pretrained weights
data = torchvision.datasets.OxfordIIITPet('OxfordIIITPet', transform=T, download=True)
n_train = 3000
train, val = torch.utils.data.random_split(
    data, [n_train, len(data) - n_train], generator=torch.Generator().manual_seed(42)
)
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
# Evaluation only sums per-batch losses, so there is no reason to shuffle.
val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=False)
test = torchvision.datasets.OxfordIIITPet('OxfordIIITPet', transform=T, download=True, split='test')
test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)

# ImageNet-pretrained ResNet-50 with a fresh bias-free head of num_classes outputs.
model = torchvision.models.resnet50(weights=weights)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes, bias=False)

In [None]:
def mixup(x, y, alpha=0.2, n_classes=None):
    """Blend a batch with a randomly permuted copy of itself.

    For each sample i a weight h_i ~ Beta(alpha, 1) is drawn and
        x_i' = h_i * x_i + (1 - h_i) * x_perm[i]
        y_i' = h_i * onehot(y_i) + (1 - h_i) * onehot(y_perm[i])

    Args:
        x: input batch of shape (B, ...).
        y: integer class labels of shape (B,).
        alpha: first concentration parameter of the Beta distribution.
        n_classes: one-hot width; defaults to the notebook-level ``num_classes``.

    Returns:
        (mixed_x, mixed_y): mixed_x has x's shape; mixed_y is (B, n_classes)
        with rows summing to 1.
    """
    if n_classes is None:
        n_classes = num_classes
    n = x.shape[0]  # real batch size; a final short batch can be smaller than batch_size
    # NOTE(review): canonical mixup samples Beta(alpha, alpha); Beta(alpha, 1) is kept as written.
    beta = torch.distributions.Beta(
        torch.tensor(float(alpha), device=x.device),
        torch.tensor(1.0, device=x.device),
    )
    h = beta.sample((n,))
    index = torch.randperm(n, device=x.device)

    # Broadcast the per-sample weight over all non-batch dims of x.
    h_x = h.view(n, *([1] * (x.dim() - 1)))
    mixed_x = h_x * x + (1 - h_x) * x[index]

    # (B,) -> (B, 1) so the weight scales each sample's one-hot row, not its columns.
    y_onehot = torch.nn.functional.one_hot(y, num_classes=n_classes).to(h)
    h_y = h[:, None]
    mixed_y = h_y * y_onehot + (1 - h_y) * y_onehot[index]
    return mixed_x, mixed_y

In [None]:
def dropout(p):
    """Return ReLU followed by Dropout(p): a drop-in replacement for an nn.ReLU module."""
    return nn.Sequential(nn.ReLU(), nn.Dropout(p=p))


def _is_relu_dropout(module):
    """True if ``module`` has the Sequential(ReLU, Dropout) layout produced by ``dropout``."""
    return (
        isinstance(module, nn.Sequential)
        and len(module) == 2
        and isinstance(module[0], nn.ReLU)
        and isinstance(module[1], nn.Dropout)
    )


def set_dropout(model, p):
    """Recursively replace every nn.ReLU submodule of ``model`` with ``dropout(p)``, in place.

    Each ReLU is replaced under its own attribute name. The call is idempotent:
    wrappers added by a previous call are not wrapped again; only their dropout
    probability is updated to ``p``.

    Args:
        model: module to modify (mutated in place; nothing is returned).
        p: dropout probability.
    """
    # Snapshot the children so replacing entries does not disturb the iteration.
    for name, child in list(model.named_children()):
        if _is_relu_dropout(child):
            child[1].p = p
        elif isinstance(child, nn.ReLU):
            setattr(model, name, dropout(p))
        else:
            set_dropout(child, p)
# Insert dropout (p=0.1) after every ReLU of the network, then move it to the training device.
set_dropout(model, p=0.1)
model = model.to(device=device)

In [None]:
def X_loss(distributions, target_mask):
    """Average divergence of K probability tensors from their mean.

    With m the mean over dim 0, each of the K entries l contributes
    sum((l - m) * (log l - log m)) over the masked positions. Over a full
    distribution this equals KL(l||m) + KL(m||l). The result is the average
    over K, and it is 0 when all K entries agree on the masked positions.

    Args:
        distributions: tensor of shape (K, ..., C). All dims between the first
            and the last are flattened into rows of C values.
        target_mask: boolean mask over the flattened (rows, C) view. Use shape
            (rows, C) to select entries or (rows,) to select whole rows.

    Returns:
        Scalar float32 tensor.

    Raises:
        ValueError: if ``distributions`` has an empty first dimension.
    """
    num_samples = distributions.shape[0]
    if num_samples == 0:
        raise ValueError("X_loss needs at least one distribution along dim 0")
    dict_size = distributions.shape[-1]

    flat = distributions.float().reshape(num_samples, -1, dict_size)
    selected = flat[:, target_mask]          # same mask applied to every one of the K entries
    m = selected.mean(dim=0)

    # Softmax can underflow to exactly 0; log(0) would turn (0 - 0) * (-inf + inf)
    # into NaN, so logs are taken of values clamped to the smallest normal float.
    tiny = torch.finfo(selected.dtype).tiny
    log_l = selected.clamp(min=tiny).log()
    log_m = m.clamp(min=tiny).log()

    return ((selected - m) * (log_l - log_m)).sum() / num_samples

def _get_alpha(alpha, num_update, max_update, p, q):
    if num_update >= max_update / p or alpha <= 1:
        return alpha
    else:
        alpha = torch.tensor([alpha])
        gamma = torch.log(1/alpha) / torch.log(torch.tensor([p/q])) # log_(p/q)(1/alpha)
        new_alpha = ( p**gamma * alpha * num_update ** gamma) / (max_update ** gamma)
        return new_alpha.item()

def build_loss_function(N, p, q, alpha, label_smoothing=0.1, mask_type='full'):
    """Build the training objective: cross-entropy on every stochastic pass plus a
    scheduled X_loss consistency penalty between passes.

    Args:
        N: update horizon, forwarded to ``_get_alpha`` as ``max_update``.
        p, q: schedule shape parameters, forwarded to ``_get_alpha``.
        alpha: target weight of the X_loss term.
        label_smoothing: passed to ``torch.nn.CrossEntropyLoss``.
        mask_type: which probability entries X_loss compares:
            ``'full'``: every class of every sample;
            ``'one_hot'``: only classes with positive target mass (the true
            class for integer labels);
            ``'none'``: nothing, so the penalty is 0.

    Returns:
        ``loss_function(outputs, true, num_update)``. It returns
        ``(total, likelihood, intra)``, where ``outputs`` is (K, B, C) logits of
        K passes over one batch. ``true`` is either integer labels (B,) or
        float class-probability targets (B, C).

    Raises:
        ValueError: if ``mask_type`` is not one of the values above.
    """
    if mask_type not in ('full', 'one_hot', 'none'):
        raise ValueError(f"unknown mask_type {mask_type!r}; expected 'full', 'one_hot' or 'none'")
    cross_entropy = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    def _target_mask(true, rows, classes, device, soft_targets):
        """Boolean (rows, classes) mask of the entries X_loss compares."""
        if mask_type == 'full':
            return torch.ones((rows, classes), dtype=torch.bool, device=device)
        if mask_type == 'none':
            return torch.zeros((rows, classes), dtype=torch.bool, device=device)
        # 'one_hot': keep the entries that carry target mass.
        if soft_targets:
            return (true.reshape(-1, classes) > 0).to(device)
        return torch.nn.functional.one_hot(true.reshape(-1), num_classes=classes).bool().to(device)

    def loss_function(outputs, true, num_update):
        """Return (likelihood + scheduled_alpha * intra, likelihood, intra) for one batch."""
        K = outputs.shape[0]
        classes = outputs.shape[-1]
        rows = outputs.shape[1:-1].numel()  # samples per pass (B for a (K, B, C) input)
        soft_targets = true.is_floating_point()
        mask = _target_mask(true, rows, classes, outputs.device, soft_targets)

        probabilities = torch.softmax(outputs, -1)
        intra_loss = X_loss(probabilities, mask)

        # Pass-major flattening: row k*rows + b is pass k, sample b, so the
        # targets are tiled K times in the same order.
        expanded_outputs = outputs.reshape(-1, classes)
        if soft_targets:
            expanded_true = true.reshape(-1, classes).repeat(K, 1)
        else:
            expanded_true = true.reshape(-1).repeat(K)
        likelihood_loss = cross_entropy(expanded_outputs, expanded_true)

        adaptive_alpha = _get_alpha(alpha, num_update, N, p, q)
        return likelihood_loss + adaptive_alpha * intra_loss, likelihood_loss, intra_loss

    return loss_function

In [None]:
# Train with the K-pass consistency objective (cross-entropy on every pass plus
# the scheduled X_loss penalty). Gradients are accumulated over iter_batch
# mini-batches per optimizer update. After each epoch, the mean per-batch train
# likelihood / intra losses and the mean validation cross-entropy are appended to
# `curve` and written to Drive together with a model checkpoint.
base_lr = 0.00002       # warm-up slope: lr = base_lr * num_update while warming up
peak_lr = 0.001         # lr at the end of warm-up, then decays as 1 / sqrt(step)
warmup_updates = 50
num_epochs = 80
run_name = f'{alpha}{mask_type}'  # e.g. '1.5full'
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.001)
loss_fn = build_loss_function(N, p, q, alpha, mask_type=mask_type)
val_loss_fn = torch.nn.CrossEntropyLoss()
num_update = 1
curve = []
optimizer.zero_grad()

for epoch in range(num_epochs):
    train_likelihood = 0.0
    train_intra = 0.0
    val_likelihood = 0.0

    model.train()
    for j, (batch, label) in enumerate(train_loader):
        batch = batch.to(device)
        label = label.to(device)
        # K forward passes over the same batch; dropout makes them differ. Shape (K, B, C).
        outputs = torch.stack([model(batch) for _ in range(K)])
        loss, likelihood_loss, intra_loss = loss_fn(outputs, label, num_update)
        (loss / iter_batch).backward()  # accumulate gradients across iter_batch mini-batches
        train_likelihood += float(likelihood_loss)
        train_intra += float(intra_loss)
        # NOTE: if len(train_loader) is not a multiple of iter_batch, the leftover
        # gradients carry over into the first update of the next epoch.
        if (j + 1) % iter_batch == 0:
            print(f'epoch {epoch} update {num_update}: loss {float(loss):.4f}')
            if num_update < warmup_updates:
                lr = base_lr * num_update
            else:
                lr = peak_lr / np.sqrt(num_update - warmup_updates + 1)
            for g in optimizer.param_groups:
                g['lr'] = lr
            optimizer.step()
            optimizer.zero_grad()
            num_update += 1

    model.eval()
    with torch.no_grad():  # evaluation needs no autograd graph
        for batch, label in val_loader:
            batch = batch.to(device)
            label = label.to(device)
            val_likelihood += float(val_loss_fn(model(batch), label))

    stat = [
        train_likelihood / len(train_loader),
        train_intra / len(train_loader),
        val_likelihood / len(val_loader),
    ]
    curve.append(stat)

    np.savetxt(f'/content/gdrive/My Drive/{run_name}_curve.txt', np.array(curve))
    torch.save(model.state_dict(), f'/content/gdrive/My Drive/{run_name}_{epoch}.mdl')

In [None]:
# Baseline: plain cross-entropy fine-tuning with the same data, dropout, optimizer
# and lr schedule as the X_loss run. The model is rebuilt from the ImageNet
# weights so the baseline does not continue from the previous cell's trained model.
model = torchvision.models.resnet50(weights='IMAGENET1K_V1')
model.fc = torch.nn.Linear(model.fc.in_features, num_classes, bias=False)
set_dropout(model, 0.1)
model = model.to(device)

base_lr = 0.00002       # warm-up slope: lr = base_lr * num_update while warming up
peak_lr = 0.001         # lr at the end of warm-up, then decays as 1 / sqrt(step)
warmup_updates = 50
num_epochs = 80
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
val_loss_fn = torch.nn.CrossEntropyLoss()  # defined here so this cell runs on its own
num_update = 1
curve = []
optimizer.zero_grad()

for epoch in range(num_epochs):
    train_likelihood = 0.0
    val_likelihood = 0.0

    model.train()
    for j, (batch, label) in enumerate(train_loader):
        batch = batch.to(device)
        label = label.to(device)
        loss = loss_fn(model(batch), label)
        (loss / iter_batch).backward()  # accumulate gradients across iter_batch mini-batches
        train_likelihood += float(loss)
        if (j + 1) % iter_batch == 0:
            print(f'epoch {epoch} update {num_update}: loss {float(loss):.4f}')
            if num_update < warmup_updates:
                lr = base_lr * num_update
            else:
                lr = peak_lr / np.sqrt(num_update - warmup_updates + 1)
            for g in optimizer.param_groups:
                g['lr'] = lr
            optimizer.step()
            optimizer.zero_grad()
            num_update += 1

    model.eval()
    with torch.no_grad():  # evaluation needs no autograd graph
        for batch, label in val_loader:
            batch = batch.to(device)
            label = label.to(device)
            val_likelihood += float(val_loss_fn(model(batch), label))

    # Mean per-batch losses; dividing by the real batch counts fixes the old /85,
    # which doubled the validation loss (val_loader has 170 batches).
    stat = [train_likelihood / len(train_loader), val_likelihood / len(val_loader)]
    curve.append(stat)

    np.savetxt('/content/gdrive/My Drive/classic_curve.txt', np.array(curve))
    torch.save(model.state_dict(), f'/content/gdrive/My Drive/classic_{epoch}.mdl')