In [1]:
!pip install timm detectors

Collecting detectors
  Downloading detectors-0.1.11-py3-none-any.whl.metadata (9.3 kB)
Collecting optuna (from detectors)
  Downloading optuna-4.4.0-py3-none-any.whl.metadata (17 kB)
Collecting wilds (from detectors)
  Downloading wilds-2.0.0-py3-none-any.whl.metadata (22 kB)
Collecting faiss-cpu (from detectors)
  Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->timm)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-man

In [2]:
import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm
import timm
# never referenced by name below; imported for its import-time side effects
# (presumably makes "resnet18_cifar10" available to timm.create_model -- confirm)
import detectors

#define hyperparams/training constants (kept after the imports, not between them)
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-2
NUM_EPOCHS = 10


### Retrieve Data and Model Objects

In [3]:
def get_cifar10_data(batch_size):
  """Build the CIFAR-10 train/test datasets and a DataLoader for each.

  Every image is resized to 224, converted to a tensor and normalized per
  channel. Both splits are stored under ./data (download=True is passed).

  Returns the tuple (trainset, testset, trainloader, testloader); only the
  training loader shuffles.
  """
  channel_mean = [0.49139968, 0.48215827, 0.44653124]
  channel_std = [0.24703233, 0.24348505, 0.26158768]
  preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
  ])

  def _load_split(is_train):
    # same root / transform for both splits; only the `train` flag differs
    return torchvision.datasets.CIFAR10(
      root='./data', train=is_train, download=True, transform=preprocess
    )

  trainset = _load_split(True)
  testset = _load_split(False)

  shared_loader_args = dict(batch_size=batch_size, num_workers=2)
  trainloader = DataLoader(trainset, shuffle=True, **shared_loader_args)
  testloader = DataLoader(testset, shuffle=False, **shared_loader_args)

  return trainset, testset, trainloader, testloader



In [4]:
def init_model(device):
  '''
  build a timm vit_tiny_patch16_224 with a 10-class head and no pretrained
  weights (pretrained=False), move it to `device` and return it
  '''
  vit = timm.create_model(
    'vit_tiny_patch16_224',
    num_classes=10,
    pretrained=False,
  )
  # nn.Module.to() moves the module in place and returns the module itself
  return vit.to(device)

def get_teacher_model(device):
  '''
  get the "resnet18_cifar10" teacher model through timm, with pretrained weights

  the returned model is frozen (requires_grad=False on every parameter),
  switched to eval mode and moved to `device`
  '''
  resnet18 = timm.create_model("resnet18_cifar10", pretrained=True)
  #don't want to accidentally update params during training
  for p in resnet18.parameters():
    p.requires_grad = False
  #freezing parameters is not enough: in train mode any normalization layers
  #would still use per-batch statistics and update their running buffers on
  #every forward pass (and dropout would stay active), so the teacher's
  #outputs would drift while it is used for distillation
  resnet18.eval()
  resnet18.to(device)
  return resnet18

### Define eval and train functions

In [5]:
#accuracy of classification
def eval_model(model, testloader, device):
  """Return the classification accuracy of `model` over `testloader`.

  Args:
    model: module mapping a batch of inputs to per-class scores; the
      prediction is the argmax over the last dimension.
    testloader: iterable yielding (inputs, labels) batches.
    device: device each batch is moved to before the forward pass.

  Returns:
    The fraction of correctly classified samples (0.0 - 1.0) as a Python float.

  Raises:
    ZeroDivisionError: if `testloader` yields no samples.

  Note: the model is switched to eval mode and left in that mode.
  """
  model.eval()
  correct = 0
  total = 0
  with torch.no_grad():
    for inputs, labels in testloader:
      inputs, labels = inputs.to(device), labels.to(device)
      outputs = model(inputs)
      preds = outputs.argmax(dim=-1)
      # .item() keeps the counter an exact Python int; accumulating the tensor
      # made the function return a float32 0-dim tensor living on `device`
      correct += preds.eq(labels).sum().item()
      total += inputs.shape[0]
  return correct / total

In [6]:
def train(model, trainloader, testloader, device, teacher_model=None, temperature=0.5, alpha=0.5,
          num_epochs=None, learning_rate=None, weight_decay=None):
  """Train `model` with AdamW + cosine LR annealing, evaluating after every epoch.

  Without a teacher the loss is plain cross-entropy. With `teacher_model` the
  loss is the distillation mix
    alpha * T^2 * KL(student/T || teacher/T) + (1 - alpha) * CE(student, labels)
  where T is `temperature`.

  Args:
    model: the network to train (updated in place).
    trainloader: iterable of (inputs, labels) training batches.
    testloader: iterable of (inputs, labels) batches passed to eval_model.
    device: device the batches are moved to.
    teacher_model: optional teacher; when given it is put in eval mode and
      only used under torch.no_grad().
    temperature: softmax temperature applied to both student and teacher logits.
    alpha: weight of the KL term; (1 - alpha) weights the cross-entropy term.
    num_epochs, learning_rate, weight_decay: optional overrides; when None the
      module-level NUM_EPOCHS / LEARNING_RATE / WEIGHT_DECAY are used.

  Returns:
    dict of lists of plain Python numbers:
      per epoch: 'train_epoch_losses' (mean batch loss), 'train_epoch_acc'
        (in percent), 'test_epoch_acc' (fraction, as returned by eval_model)
      per batch: 'epoch_values' (0-based epoch), 'iters' (batch index),
        'train_losses', 'train_acc' (running fraction within the epoch)
  """
  # resolve the defaults at call time so later edits to the constants are honoured
  num_epochs = NUM_EPOCHS if num_epochs is None else num_epochs
  learning_rate = LEARNING_RATE if learning_rate is None else learning_rate
  weight_decay = WEIGHT_DECAY if weight_decay is None else weight_decay

  optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)
  ce_loss_criterion = nn.CrossEntropyLoss()
  kld_loss_criterion = nn.KLDivLoss(reduction='batchmean')
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

  # explicit None check: module truthiness is not a reliable "was a teacher passed" test
  use_teacher = teacher_model is not None
  if use_teacher:
    # the teacher must produce deterministic, inference-mode outputs
    teacher_model.eval()

  train_epoch_losses, train_epoch_acc, test_epoch_acc = [], [], []
  epoch_values, iters, train_losses, train_acc = [], [], [], []

  for epoch in range(num_epochs):
    total_loss = 0.0
    correct = 0
    total = 0
    # eval_model() leaves the model in eval mode, so re-enable train mode every epoch
    model.train()

    pbar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for batch_idx, (inputs, labels) in enumerate(pbar):
      #forward and backward pass
      inputs, labels = inputs.to(device), labels.to(device)
      outputs = model(inputs)
      if use_teacher:
        with torch.no_grad():
          soft_labels = F.softmax(teacher_model(inputs) / temperature, dim=1) #get soft labels from teacher model
        # KLDivLoss expects log-probabilities as input and probabilities as target
        soft_preds = F.log_softmax(outputs / temperature, dim=1)
        kld_loss = kld_loss_criterion(soft_preds, soft_labels) * (temperature ** 2)
        ce_loss = ce_loss_criterion(outputs, labels)
        loss = alpha * kld_loss + (1 - alpha) * ce_loss #combine CELoss and KL-Divergence Loss
      else:
        loss = ce_loss_criterion(outputs, labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      #update running statistics (as Python numbers, not device tensors)
      batch_loss = loss.item()
      total_loss += batch_loss
      preds = outputs.argmax(dim=-1)
      correct += preds.eq(labels).sum().item()
      total += labels.size(0)

      #update progress bar
      pbar.set_postfix({
        'Loss': f'{batch_loss:.4f}',
        'Acc': f'{100.*correct/total:.2f}%'
      })
      epoch_values.append(epoch)
      iters.append(batch_idx)
      train_losses.append(batch_loss)
      train_acc.append(correct / total)

    # update learning rate
    scheduler.step()

    # Calculate epoch metrics
    epoch_loss = total_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    # float() accepts both a Python float and a 0-dim tensor
    test_acc = float(eval_model(model, testloader, device))

    train_epoch_losses.append(epoch_loss)
    train_epoch_acc.append(epoch_acc)
    test_epoch_acc.append(test_acc)

    # test_acc is a fraction, so scale it to match the '%' suffix; the loss is
    # the combined loss when distilling, hence the neutral "Train Loss" label
    print(f'Epoch {epoch+1}: Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%, Test Acc: {100. * test_acc:.2f}%')

  results = {
    'train_epoch_losses': train_epoch_losses,
    'train_epoch_acc': train_epoch_acc,
    'test_epoch_acc': test_epoch_acc,
    'epoch_values': epoch_values,
    'iters': iters,
    'train_losses': train_losses,
    'train_acc': train_acc
  }
  return results

In [7]:
def save_stats(results, output_file_prefix, ts):
  """Dump the training statistics in `results` to two header-less CSV files.

  {output_file_prefix}_batchstats_{ts}.csv: one row per batch with
    epoch value, iteration, training loss, training accuracy
  {output_file_prefix}_epochstats_{ts}.csv: one row per epoch with
    epoch index, training loss, training accuracy, test accuracy
  """
  def _csv_line(fields):
    # each field is rendered with its default formatting
    return ','.join(f'{field}' for field in fields) + '\n'

  #batch-level stats
  batch_path = f'{output_file_prefix}_batchstats_{ts}.csv'
  with open(batch_path, 'w') as out:
    for idx, batch_loss in enumerate(results['train_losses']):
      out.write(_csv_line((
        results['epoch_values'][idx],
        results['iters'][idx],
        batch_loss,
        results['train_acc'][idx],
      )))

  #epoch-level stats
  epoch_path = f'{output_file_prefix}_epochstats_{ts}.csv'
  with open(epoch_path, 'w') as out:
    for idx, epoch_loss in enumerate(results['train_epoch_losses']):
      out.write(_csv_line((
        idx,
        epoch_loss,
        results['train_epoch_acc'][idx],
        results['test_epoch_acc'][idx],
      )))

### Training Code

In [8]:
#pick the compute device: GPU when one is available, CPU otherwise
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('running using', device)

#CIFAR-10 datasets and loaders, batch size 128
trainset, testset, trainloader, testloader = get_cifar10_data(128)

#two separately initialized ViTs (baseline and distillation student) plus the teacher
vit_normal = init_model(device)
vit_student = init_model(device)
resnet_teacher = get_teacher_model(device)

#report model sizes (both ViTs come from init_model, so one count covers both)
vit_param_count = sum(p.numel() for p in vit_normal.parameters())
resnet_param_count = sum(p.numel() for p in resnet_teacher.parameters())
print(f'vit_normal and vit_student parameter count: {vit_param_count:,}')
print(f'resnet18 teacher parameter count: {resnet_param_count:,}')
print('\n')

running using cpu


100%|██████████| 170M/170M [00:01<00:00, 91.2MB/s]
Downloading: "https://huggingface.co/edadaltocg/resnet18_cifar10/resolve/main/pytorch_model.bin" to /root/.cache/torch/hub/checkpoints/resnet18_cifar10.pth
100%|██████████| 42.7M/42.7M [00:00<00:00, 111MB/s]

vit_normal and vit_student parameter count: 5,526,346
resnet18 teacher parameter count: 11,173,962







In [9]:
#show initial performance of model before training
#(sanity check on the untrained vit_normal; eval_model returns the fraction of
# test samples classified correctly, not a percentage)
print('evaluating accuracy of default parameters on CIFAR10 test set...')
acc = eval_model(vit_normal, testloader, device)
print(f'initial accuracy {acc:,}')

evaluating accuracy of default parameters on CIFAR10 test set...
initial accuracy 0.1234000027179718


In [10]:
#define hyperparams/training constants
#(module-level globals that train() reads when it is called)
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-2
NUM_EPOCHS = 10

In [None]:
#train vit_normal, store results
#(no teacher_model is passed, so train() optimizes plain cross-entropy)
print('training vit using normal pre-training approach')
vit_normal_results = train(vit_normal, trainloader, testloader, device)
#one timestamp shared by the stats CSVs and the checkpoint so they can be matched up
ts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
save_stats(vit_normal_results, 'vit_normal', ts)
torch.save(vit_normal.state_dict(), f'vit_normal_{ts}.pt')


training vit using normal pre-training approach


Epoch 1/10:  92%|█████████▏| 358/391 [1:29:57<08:01, 14.58s/it, Loss=1.6489, Acc=31.22%]

In [None]:
#train vit_student with the teacher model (distillation), store results
#NOTE: the model passed to train() must be vit_student -- passing vit_normal
#would keep training the baseline and leave the saved student checkpoint untrained
print('training vit using distillation/student-teacher approach')
vit_student_results = train(vit_student, trainloader, testloader, device, teacher_model=resnet_teacher)
#one timestamp shared by the stats CSVs and the checkpoint so they can be matched up
ts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
save_stats(vit_student_results, 'vit_student', ts)
torch.save(vit_student.state_dict(), f'vit_student_{ts}.pt')

### Comparing Performance Metrics