In [None]:
!pip install torch torchvision tqdm timm detectors
import csv
import datetime
import os

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import detectors  # registers CIFAR models (e.g. "resnet18_cifar10") with timm


### Retrieve Data and Model Objects

In [None]:
def get_cifar10_data(batch_size, root='./data', num_workers=2, image_size=224):
  """Build CIFAR-10 train/test datasets and their DataLoaders.

  Every image is resized to `image_size` (224 by default, matching the
  vit_tiny_patch16_224 model used in this notebook), converted to a tensor and
  normalized with per-channel CIFAR-10 mean/std. Train and test share the same
  transform, so no data augmentation is applied.

  Args:
    batch_size: samples per batch for both loaders.
    root: directory the dataset is downloaded to / read from.
    num_workers: worker processes per DataLoader.
    image_size: target size passed to transforms.Resize.

  Returns:
    (trainset, testset, trainloader, testloader). Only the train loader shuffles.
  """
  transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.49139968, 0.48215827, 0.44653124],
                         std=[0.24703233, 0.24348505, 0.26158768])
  ])

  trainset = torchvision.datasets.CIFAR10(
    root=root, train=True, download=True, transform=transform
  )
  testset = torchvision.datasets.CIFAR10(
    root=root, train=False, download=True, transform=transform
  )

  trainloader = DataLoader(trainset, batch_size=batch_size,
                           shuffle=True, num_workers=num_workers)
  testloader = DataLoader(testset, batch_size=batch_size,
                          shuffle=False, num_workers=num_workers)

  return trainset, testset, trainloader, testloader



In [None]:
def init_model(device, model_name='vit_tiny_patch16_224', num_classes=10):
  '''
  Create a randomly initialized (pretrained=False) timm classifier on `device`.

  Args:
    device: torch.device the model is moved to.
    model_name: timm architecture name (default 'vit_tiny_patch16_224').
    num_classes: number of outputs of the classification head (10 for CIFAR-10).

  Returns:
    The model, already moved to `device`.
  '''
  model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
  # nn.Module.to moves parameters in place and returns the same module.
  return model.to(device)

def get_teacher_model(device):
  '''
  Load the pretrained ResNet-18 CIFAR-10 teacher as a frozen, eval-mode model.

  The 'resnet18_cifar10' name is resolved through timm's model registry; it is
  made available by the `detectors` package imported at the top of the notebook.

  Args:
    device: torch.device the teacher is moved to.

  Returns:
    The teacher model with all parameters frozen (requires_grad=False) and in
    eval mode.
  '''
  resnet18 = timm.create_model("resnet18_cifar10", pretrained=True)
  # The teacher only supplies soft targets; it must never receive gradient updates.
  resnet18.requires_grad_(False)
  # eval(): BatchNorm uses its pretrained running statistics instead of per-batch
  # statistics, and those running statistics are not overwritten when the teacher
  # is run (even under torch.no_grad) on student training batches.
  resnet18.eval()
  return resnet18.to(device)

### Define Evaluation and Training Functions

In [None]:
def eval_model(model, testloader, device):
  '''
  Compute the top-1 classification accuracy of `model` over `testloader`.

  Args:
    model: classifier whose outputs are per-class scores; the prediction is the
      argmax over the last dimension.
    testloader: iterable of (inputs, labels) batches.
    device: torch.device the batches are moved to.

  Returns:
    Accuracy as a plain Python float in [0, 1]. It is a float rather than a
    device tensor so it can be formatted or written to CSV directly.

  Raises:
    ValueError: if `testloader` yields no samples.

  The model's previous train/eval mode is restored before returning.
  '''
  was_training = model.training
  model.eval()
  correct = 0
  total = 0
  try:
    with torch.no_grad():
      for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model(inputs).argmax(dim=-1)
        # Accumulate on-device; a single host sync happens at the end.
        correct += preds.eq(labels).sum()
        total += labels.size(0)
  finally:
    model.train(was_training)
  if total == 0:
    raise ValueError('testloader yielded no samples')
  return float(correct) / total

In [None]:
def _distillation_loss(student_logits, teacher_logits, labels, temperature, alpha):
  '''
  Knowledge-distillation loss.

  Returns alpha * T^2 * KL(softmax(teacher/T) || softmax(student/T))
          + (1 - alpha) * CE(student, labels).
  The T^2 factor is the conventional scaling (Hinton et al., 2015).
  '''
  soft_targets = F.softmax(teacher_logits / temperature, dim=1)
  log_soft_preds = F.log_softmax(student_logits / temperature, dim=1)
  kld_loss = F.kl_div(log_soft_preds, soft_targets, reduction='batchmean') * (temperature ** 2)
  ce_loss = F.cross_entropy(student_logits, labels)
  return alpha * kld_loss + (1 - alpha) * ce_loss


def train(
  model,
  trainloader,
  testloader,
  device,
  learning_rate,
  weight_decay,
  num_epochs,
  teacher_model=None,
  temperature=0.5,
  alpha=0.5,
  premature_stop=None):
  '''
  Train `model` with AdamW and a per-epoch cosine-annealed learning rate,
  optionally distilling from a frozen `teacher_model`.

  Args:
    model: classifier to train; its outputs are treated as logits.
    trainloader, testloader: iterables of (inputs, labels) batches.
    device: torch.device the batches are moved to.
    learning_rate, weight_decay: AdamW hyperparameters.
    num_epochs: number of passes over `trainloader`; also the cosine T_max.
    teacher_model: if not None, the loss is the distillation loss
      alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE. Otherwise it is
      plain cross-entropy. The teacher is switched to eval mode so its
      BatchNorm layers neither use nor update batch statistics.
    temperature: softmax temperature T for the distillation term.
    alpha: weight of the distillation term.
    premature_stop: if not None, return as soon as batch index
      `premature_stop` of the first epoch has been processed, i.e. after
      premature_stop + 1 batches. Used for quick hyperparameter sweeps.

  Returns:
    If `premature_stop` triggers: a tuple (summed batch loss, training
    accuracy as a 0-dim tensor fraction) over the processed batches.
    Otherwise a dict of plain-float metrics:
      'train_epoch_losses': mean batch loss per epoch
      'train_epoch_acc': training accuracy per epoch, in percent
      'test_epoch_acc': test accuracy per epoch, as a fraction in [0, 1]
      'epoch_values', 'iters': epoch / batch index of each per-batch entry
      'train_losses': loss of each batch
      'train_acc': running training accuracy after each batch, as a fraction

  Raises:
    ValueError: if `trainloader` yields no batches.
  '''
  optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

  use_teacher = teacher_model is not None
  if use_teacher:
    # Teacher is frozen: eval mode keeps BatchNorm on its pretrained running
    # statistics instead of overwriting them with student training batches.
    teacher_model.eval()

  train_epoch_losses, train_epoch_acc, test_epoch_acc = [], [], []
  epoch_values, iters, train_losses, train_acc = [], [], [], []

  for epoch in range(num_epochs):
    total_loss = 0.0
    correct = 0
    total = 0
    model.train()

    pbar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for batch_idx, (inputs, labels) in enumerate(pbar):
      # Forward and backward pass.
      inputs, labels = inputs.to(device), labels.to(device)
      outputs = model(inputs)
      if use_teacher:
        # NOTE(review): the teacher sees exactly the inputs the student sees
        # (resized by the data pipeline). Confirm the teacher was trained at
        # this resolution; if it expects 32x32, downsample before this call.
        with torch.no_grad():
          teacher_logits = teacher_model(inputs)
        loss = _distillation_loss(outputs, teacher_logits, labels, temperature, alpha)
      else:
        loss = F.cross_entropy(outputs, labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Update running statistics.
      loss_value = loss.item()
      total_loss += loss_value
      correct += outputs.argmax(dim=-1).eq(labels).sum()
      total += labels.size(0)
      if premature_stop is not None and batch_idx >= premature_stop:
        return total_loss, correct / float(total)

      running_acc = (correct / float(total)).item()
      pbar.set_postfix({
        'Loss': f'{loss_value:.4f}',
        'Acc': f'{100. * running_acc:.2f}%'
      })
      epoch_values.append(epoch)
      iters.append(batch_idx)
      train_losses.append(loss_value)
      train_acc.append(running_acc)

    if total == 0:
      raise ValueError('trainloader yielded no batches')

    # Update learning rate once per epoch (cosine schedule with T_max=num_epochs).
    scheduler.step()

    # Epoch metrics, stored as plain floats so they serialize cleanly.
    epoch_loss = total_loss / len(trainloader)
    epoch_acc = (100. * correct / total).item()
    test_acc = float(eval_model(model, testloader, device))

    train_epoch_losses.append(epoch_loss)
    train_epoch_acc.append(epoch_acc)
    test_epoch_acc.append(test_acc)

    # test_acc is a fraction; scale it to percent for display like the train accuracy.
    print(f'Epoch {epoch+1}: Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%, Test Acc: {100. * test_acc:.2f}%')

  results = {
    'train_epoch_losses': train_epoch_losses,
    'train_epoch_acc': train_epoch_acc,
    'test_epoch_acc': test_epoch_acc,
    'epoch_values': epoch_values,
    'iters': iters,
    'train_losses': train_losses,
    'train_acc': train_acc
  }
  return results

In [None]:
def _as_number(value):
  '''Return a plain Python scalar for tensor-like values (via .item()); pass others through.'''
  return value.item() if hasattr(value, 'item') else value


def _write_columns_csv(path, columns):
  '''Write equal-length column sequences to `path` as header-less CSV rows.'''
  lengths = {len(column) for column in columns}
  if len(lengths) > 1:
    raise ValueError(f'column length mismatch for {path}: {sorted(lengths)}')
  with open(path, 'w', newline='') as f:
    writer = csv.writer(f, lineterminator='\n')
    for row in zip(*columns):
      writer.writerow([_as_number(v) for v in row])


def save_stats(results, output_file_prefix, ts, output_dir='/content/drive/My Drive/ML + Robotics'):
  '''
  Write the metrics dict returned by `train` to two header-less CSV files.

  Files written to `output_dir`:
    {prefix}_batchstats_{ts}.csv: epoch, batch index, batch loss, running train acc
    {prefix}_epochstats_{ts}.csv: epoch, mean train loss, train acc, test acc

  0-dim tensors (e.g. accuracies) are written as plain numbers. Their repr,
  such as "tensor(0.5, device='cuda:0')", contains a comma that would
  otherwise break the column layout.

  Args:
    results: dict with keys 'epoch_values', 'iters', 'train_losses',
      'train_acc', 'train_epoch_losses', 'train_epoch_acc', 'test_epoch_acc'.
    output_file_prefix: filename prefix (e.g. 'vit_normal').
    ts: timestamp string embedded in the filenames.
    output_dir: destination directory. It must already exist; the default is a
      Google Drive folder, so presumably Drive must be mounted in Colab first.

  Raises:
    ValueError: if the per-batch or per-epoch lists differ in length.
  '''
  # Save batch-level stats.
  batch_columns = [results[key] for key in ('epoch_values', 'iters', 'train_losses', 'train_acc')]
  _write_columns_csv(
    os.path.join(output_dir, f'{output_file_prefix}_batchstats_{ts}.csv'), batch_columns)

  # Save epoch-level stats; the first column is the epoch index.
  epoch_metrics = [results[key] for key in ('train_epoch_losses', 'train_epoch_acc', 'test_epoch_acc')]
  epoch_columns = [range(len(results['train_epoch_losses']))] + epoch_metrics
  _write_columns_csv(
    os.path.join(output_dir, f'{output_file_prefix}_epochstats_{ts}.csv'), epoch_columns)

### Training Code

In [None]:
# Seed torch's global RNG so the hyperparameter sweeps (torch.rand), model
# initialization and DataLoader shuffling are repeatable across runs
# (bit-exact GPU determinism may additionally require deterministic kernels).
SEED = 0
torch.manual_seed(SEED)

# Compute device: CUDA if available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('running using', device)

# CIFAR-10 datasets/loaders with batch size 128.
trainset, testset, trainloader, testloader = get_cifar10_data(128)

# Frozen pretrained ResNet-18 teacher used for distillation.
resnet_teacher = get_teacher_model(device)

In [None]:
# Random search for the initial learning rate of the plain (non-distilled) ViT.
# Each trial trains a fresh model on the first 51 batches of one epoch
# (premature_stop=50) and records (lr, summed batch loss, running train accuracy).
# Learning rates are sampled log-uniformly in [1e-5, 1e-2). Uniform sampling on
# [0, 1e-2) would put ~90% of trials above 1e-3 and can draw lr close to 0.
N_NORMAL_LR_TRIALS = 20
NORMAL_LOG10_LR_RANGE = (-5.0, -2.0)

lr_tests = []
for log10_lr in torch.empty(N_NORMAL_LR_TRIALS).uniform_(*NORMAL_LOG10_LR_RANGE):
  lr = 10 ** log10_lr.item()
  model = init_model(device)
  loss, acc = train(model, trainloader, testloader,
                    device, learning_rate=lr, weight_decay=1e-2,
                    num_epochs=1, premature_stop=50)
  # float() accepts both a 0-dim tensor and a plain float.
  lr_tests.append((lr, loss, float(acc)))

# Trials ordered from highest to lowest accuracy.
sorted(lr_tests, key=lambda trial: trial[2], reverse=True)

In [None]:
# Random search over learning rate, distillation temperature and alpha for the
# student ViT. Each trial trains a fresh model against `resnet_teacher` on the
# first 51 batches of one epoch (premature_stop=50).
# lr is log-uniform in [1e-5, 1e-2); temperature and alpha are uniform in [0, 1).
# NOTE(review): temperatures below 1 sharpen rather than soften the teacher's
# distribution; consider also searching T > 1.
N_STUDENT_TRIALS = 20
STUDENT_LOG10_LR_RANGE = (-5.0, -2.0)

lr_tests = []
for _ in range(N_STUDENT_TRIALS):
  model = init_model(device)
  lr = 10 ** torch.empty(1).uniform_(*STUDENT_LOG10_LR_RANGE).item()
  temp, alpha = torch.rand(1).item(), torch.rand(1).item()
  loss, acc = train(model, trainloader, testloader,
                    device, learning_rate=lr, weight_decay=1e-2,
                    num_epochs=1, teacher_model=resnet_teacher,
                    temperature=temp, alpha=alpha, premature_stop=50)
  # float() accepts both a 0-dim tensor and a plain float.
  lr_tests.append((lr, temp, alpha, loss, float(acc)))

# Trials ordered from highest to lowest accuracy.
sorted(lr_tests, key=lambda trial: trial[-1], reverse=True)

In [None]:
# Two fresh, randomly initialized ViTs: one for plain training, one for distillation.
vit_normal = init_model(device)
vit_student = init_model(device)
# A fresh copy of the pretrained teacher (it was also loaded in the setup cell).
resnet_teacher = get_teacher_model(device)

# Report model sizes as total element counts over all parameter tensors.
vit_param_count = sum(param.numel() for param in vit_normal.parameters())
resnet_param_count = sum(param.numel() for param in resnet_teacher.parameters())
print(f'vit_normal and vit_student parameter count: {vit_param_count:,}')
print(f'resnet18 teacher parameter count: {resnet_param_count:,}')

In [None]:
# Baseline before any training: an untrained 10-class model is expected to
# score near chance (~10%).
print('evaluating accuracy of default parameters on CIFAR10 test set...')
acc = eval_model(vit_normal, testloader, device)
# acc is a fraction in [0, 1]; '.2%' renders it as a percentage
# (works for a float or a 0-dim tensor).
print(f'initial accuracy {acc:.2%}')

In [None]:
# Hyperparameters shared by both full training runs (vit_normal and vit_student).
NUM_EPOCHS = 5         # passes over the training set; also the cosine schedule's T_max
WEIGHT_DECAY = 1e-2    # AdamW weight decay

In [None]:
# Baseline run: plain cross-entropy training of vit_normal, then persist its
# metrics (CSV) and weights (state_dict) under a shared timestamp.
# NOTE(review): the output paths assume Google Drive is mounted at /content/drive.
learning_rate = 1e-3  # 4e-3 was used previously

print('training vit using normal pre-training approach')
vit_normal_results = train(vit_normal, trainloader, testloader, device,
                           learning_rate=learning_rate,
                           weight_decay=WEIGHT_DECAY,
                           num_epochs=NUM_EPOCHS)

ts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
save_stats(vit_normal_results, 'vit_normal', ts)
torch.save(vit_normal.state_dict(),
           f'/content/drive/My Drive/ML + Robotics/vit_normal_{ts}.pt')


In [None]:
# Distillation run: train the fresh vit_student against the ResNet-18 teacher,
# then persist its metrics and weights. The hyperparameters below were
# presumably picked from the sweep above.
student_learning_rate = 1e-4
temperature = 0.212
alpha = 0.62

# Train vit_student with the teacher model. It must be the fresh student, not
# vit_normal, which was already trained in the previous cell.
print('training vit using distillation/student-teacher approach')
vit_student_results = train(
  vit_student,
  trainloader,
  testloader,
  device,
  learning_rate=student_learning_rate,
  weight_decay=WEIGHT_DECAY,
  num_epochs=NUM_EPOCHS,
  teacher_model=resnet_teacher,
  temperature=temperature,
  alpha=alpha)
ts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
save_stats(vit_student_results, 'vit_student', ts)
# Save under the student's own name so it is not confused with the baseline checkpoint.
torch.save(vit_student.state_dict(), f'/content/drive/My Drive/ML + Robotics/vit_student_{ts}.pt')