In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Write a sample file to Google Drive
with open('/content/drive/My Drive/ML + Robotics/sample.txt', 'w') as f:
  f.write('Hello, World!')

Mounted at /content/drive


In [2]:
!pip install torch torchvision tqdm timm detectors
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm

import timm
import detectors
import datetime


Collecting detectors
  Downloading detectors-0.1.11-py3-none-any.whl.metadata (9.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collect

### Retrieve Data and Model Objects

In [3]:
def get_cifar10_data(batch_size):
  """Get CIFAR-10 data loaders with appropriate transforms"""
  transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.49139968, 0.48215827 ,0.44653124],
                         std=[0.24703233, 0.24348505, 0.26158768])
  ])

  trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
  )
  testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
  )

  trainloader = DataLoader(trainset, batch_size=batch_size,
                           shuffle=True, num_workers=2)
  testloader = DataLoader(testset, batch_size=batch_size,
                          shuffle=False, num_workers=2)

  return trainset, testset, trainloader, testloader



In [4]:
def init_model(device):
  '''
  initialize vit tiny patch16_224 model from timm, no pretraining
  '''
  model = timm.create_model('vit_tiny_patch16_224',
                            pretrained=False, num_classes=10)
  model.to(device)
  return model

def get_teacher_model(device):
  '''
  get resnet18 teacher model from timm, with pretrained weights
  '''
  resnet18 = timm.create_model("resnet18_cifar10", pretrained=True)
  #don't want to accidentally update params during training
  for p in resnet18.parameters():
    p.requires_grad = False
  resnet18.to(device)
  return resnet18

### Define eval and train functions

In [5]:
#accuracy of classification
def eval_model(model, testloader, device):
  model.eval()
  correct = 0
  total = 0
  with torch.no_grad():
    for inputs, labels in testloader:
      inputs, labels = inputs.to(device), labels.to(device)
      outputs = model(inputs)
      preds = outputs.argmax(dim=-1)
      correct += preds.eq(labels).sum()
      total += inputs.shape[0]
  return correct / float(total)

In [6]:
def train(
  model,
  trainloader,
  testloader,
  device,
  learning_rate,
  weight_decay,
  num_epochs,
  teacher_model=None,
  temperature=0.5,
  alpha=0.5,
  premature_stop=None):
  optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)
  ce_loss_criterion = nn.CrossEntropyLoss()
  kld_loss_criterion = nn.KLDivLoss(reduction='batchmean')
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

  train_epoch_losses, train_epoch_acc, test_epoch_acc = [], [], []
  epoch_values, iters, train_losses, train_acc = [], [], [], []

  for epoch in range(num_epochs):
    total_loss = 0.0
    correct = 0
    total = 0
    model.train()

    pbar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for batch_idx, (inputs, labels) in enumerate(pbar):
      #forward and backward pass
      inputs, labels = inputs.to(device), labels.to(device)
      outputs = model(inputs)
      if teacher_model:
        with torch.no_grad():
          soft_labels = F.softmax(teacher_model(inputs) / temperature, dim=1) #get soft labels from teacher model
        soft_preds = F.log_softmax(outputs / temperature, dim=1)
        kld_loss = kld_loss_criterion(soft_preds, soft_labels) * (temperature ** 2)
        ce_loss = ce_loss_criterion(outputs, labels)
        loss = alpha * kld_loss + (1 - alpha) * ce_loss #combine CELoss and KL-Divergence Loss
      else:
        loss = ce_loss_criterion(outputs, labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      #update running statistics
      total_loss += loss.item()
      preds = outputs.argmax(dim=-1)
      correct += preds.eq(labels).sum()
      total += labels.size(0)
      if premature_stop:
        if batch_idx >= premature_stop: return total_loss, correct / float(total)

      #update progress bar
      pbar.set_postfix({
        'Loss': f'{loss.item():.4f}',
        'Acc': f'{100.*correct/total:.2f}%'
      })
      epoch_values.append(epoch)
      iters.append(batch_idx)
      train_losses.append(loss.item())
      train_acc.append(correct / float(total))

    # update learning rate
    scheduler.step()

    # Calculate epoch metrics
    epoch_loss = total_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    test_acc = eval_model(model, testloader, device)

    train_epoch_losses.append(epoch_loss)
    train_epoch_acc.append(epoch_acc)
    test_epoch_acc.append(test_acc)

    print(f'Epoch {epoch+1}: Train CELoss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%, Test Acc: {test_acc:.2f}%')

  results = {
    'train_epoch_losses': train_epoch_losses,
    'train_epoch_acc': train_epoch_acc,
    'test_epoch_acc': test_epoch_acc,
    'epoch_values': epoch_values,
    'iters': iters,
    'train_losses': train_losses,
    'train_acc': train_acc
  }
  return results

In [7]:
def save_stats(results, output_file_prefix, ts):
  #save batch-level stats
  with open(f'/content/drive/My Drive/ML + Robotics/{output_file_prefix}_batchstats_{ts}.csv', 'w') as f:
    for i in range(len(results['train_losses'])):
      row = f'''{results['epoch_values'][i]},{results['iters'][i]},{results['train_losses'][i]},{results['train_acc'][i]}\n'''
      f.write(row)

  #save epoch-level stats
  with open(f'/content/drive/My Drive/ML + Robotics/{output_file_prefix}_epochstats_{ts}.csv', 'w') as f:
    for i in range(len(results['train_epoch_losses'])):
      row = f'''{i},{results['train_epoch_losses'][i]},{results['train_epoch_acc'][i]},{results['test_epoch_acc'][i]}\n'''
      f.write(row)

### Training Code

In [8]:
#define compute device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('running using', device)

#get data
trainset, testset, trainloader, testloader = get_cifar10_data(128)

#get teacher model
resnet_teacher = get_teacher_model(device)

running using cuda


100%|██████████| 170M/170M [00:03<00:00, 43.3MB/s]
Downloading: "https://huggingface.co/edadaltocg/resnet18_cifar10/resolve/main/pytorch_model.bin" to /root/.cache/torch/hub/checkpoints/resnet18_cifar10.pth
100%|██████████| 42.7M/42.7M [00:00<00:00, 168MB/s]


In [None]:
#find best initial learning rate for vit_normal
lr_tests = []
for lr in (torch.rand(20) * 1e-2):
  model = init_model(device)
  loss, acc = train(model, trainloader, testloader,
                    device, learning_rate=lr.item(), weight_decay=1e-2,
                    num_epochs=1, premature_stop=50)
  lr_tests.append((lr.item(), loss, acc))

Epoch 1/10:  13%|█▎        | 50/391 [00:24<02:48,  2.02it/s, Loss=2.1464, Acc=18.22%]
Epoch 1/10:  13%|█▎        | 50/391 [00:25<02:53,  1.96it/s, Loss=2.1424, Acc=19.28%]
Epoch 1/10:  13%|█▎        | 50/391 [00:24<02:45,  2.06it/s, Loss=2.1786, Acc=15.70%]
Epoch 1/10:  13%|█▎        | 50/391 [00:24<02:46,  2.05it/s, Loss=2.0698, Acc=19.06%]
Epoch 1/10:  13%|█▎        | 50/391 [00:24<02:43,  2.08it/s, Loss=2.1889, Acc=18.56%]
Epoch 1/10:  13%|█▎        | 50/391 [00:24<02:47,  2.04it/s, Loss=2.0553, Acc=19.00%]
Epoch 1/10:  13%|█▎        | 50/391 [00:24<02:45,  2.06it/s, Loss=2.0185, Acc=17.66%]
Epoch 1/10:  13%|█▎        | 50/391 [00:24<02:49,  2.01it/s, Loss=2.2569, Acc=15.70%]
Epoch 1/10:  13%|█▎        | 50/391 [00:24<02:45,  2.06it/s, Loss=2.1306, Acc=18.20%]
Epoch 1/10:  13%|█▎        | 50/391 [00:24<02:45,  2.06it/s, Loss=2.0313, Acc=19.72%]
Epoch 1/10:  13%|█▎        | 50/391 [00:24<02:46,  2.05it/s, Loss=2.1607, Acc=17.38%]
Epoch 1/10:  13%|█▎        | 50/391 [00:24<02:46,  2.0

In [None]:
lr_tests

[(0.005193037446588278, 110.90308237075806, tensor(0.1814, device='cuda:0')),
 (0.00609348900616169, 109.90362119674683, tensor(0.1932, device='cuda:0')),
 (0.008939031511545181, 116.20408749580383, tensor(0.1572, device='cuda:0')),
 (0.0035522496327757835, 110.400723695755, tensor(0.1903, device='cuda:0')),
 (0.004252366255968809, 110.75607967376709, tensor(0.1873, device='cuda:0')),
 (0.0046906135976314545, 110.44552218914032, tensor(0.1890, device='cuda:0')),
 (0.005321366712450981, 110.75724804401398, tensor(0.1768, device='cuda:0')),
 (0.009965403936803341, 115.35032510757446, tensor(0.1573, device='cuda:0')),
 (0.0027509015053510666, 111.67822575569153, tensor(0.1820, device='cuda:0')),
 (0.0017359828343614936, 109.7705659866333, tensor(0.1972, device='cuda:0')),
 (0.0054545593447983265, 111.92306506633759, tensor(0.1728, device='cuda:0')),
 (0.004323772620409727, 110.27522850036621, tensor(0.1838, device='cuda:0')),
 (0.0034549450501799583, 109.8525071144104, tensor(0.1898, devi

In [None]:
#best normal LR: 4e-3

0.003

In [None]:
torch.rand(1).item()

0.45901137590408325

In [None]:
#find best initial learning rate, temp, and alpha for vit_student
lr_tests = []
for i in range(20):
  model = init_model(device)
  lr = torch.rand(1).item() * 1e-2
  temp, alpha = torch.rand(1).item(), torch.rand(1).item()
  loss, acc = train(model, trainloader, testloader,
                    device, learning_rate=lr, weight_decay=1e-2,
                    num_epochs=1, teacher_model=resnet_teacher,
                    temperature=temp, alpha=alpha, premature_stop=50)
  lr_tests.append((lr, temp, alpha, loss, acc))

lr_tests

Epoch 1/1:  13%|█▎        | 50/391 [01:39<11:21,  2.00s/it, Loss=1.6324, Acc=15.91%]
Epoch 1/1:  13%|█▎        | 50/391 [01:41<11:31,  2.03s/it, Loss=0.9067, Acc=18.11%]
Epoch 1/1:  13%|█▎        | 50/391 [01:41<11:28,  2.02s/it, Loss=0.8176, Acc=23.42%]
Epoch 1/1:  13%|█▎        | 50/391 [01:40<11:28,  2.02s/it, Loss=0.6902, Acc=12.86%]
Epoch 1/1:  13%|█▎        | 50/391 [01:41<11:29,  2.02s/it, Loss=1.8412, Acc=14.67%]
Epoch 1/1:  13%|█▎        | 50/391 [01:40<11:28,  2.02s/it, Loss=2.1729, Acc=17.14%]
Epoch 1/1:  13%|█▎        | 50/391 [01:41<11:31,  2.03s/it, Loss=1.6873, Acc=16.53%]
Epoch 1/1:  13%|█▎        | 50/391 [01:40<11:28,  2.02s/it, Loss=1.7559, Acc=11.17%]
Epoch 1/1:  13%|█▎        | 50/391 [01:41<11:29,  2.02s/it, Loss=1.5145, Acc=19.47%]
Epoch 1/1:  13%|█▎        | 50/391 [01:40<11:28,  2.02s/it, Loss=1.2250, Acc=19.44%]
Epoch 1/1:  13%|█▎        | 50/391 [01:40<11:28,  2.02s/it, Loss=1.2611, Acc=15.00%]
Epoch 1/1:  13%|█▎        | 50/391 [01:40<11:28,  2.02s/it, Loss=

[(0.009176899194717408,
  0.1093798279762268,
  0.2829812169075012,
  83.22928214073181,
  tensor(0.1602, device='cuda:0')),
 (0.0023594468832015994,
  0.9462930560112,
  0.7444887757301331,
  48.420181930065155,
  tensor(0.1818, device='cuda:0')),
 (0.0001073598861694336,
  0.2125471830368042,
  0.6206677556037903,
  42.446245431900024,
  tensor(0.2339, device='cuda:0')),
 (0.007556940913200379,
  0.8694360256195068,
  0.9084545969963074,
  40.27882248163223,
  tensor(0.1288, device='cuda:0')),
 (0.008993620276451111,
  0.4444413185119629,
  0.19185703992843628,
  98.52953815460205,
  tensor(0.1478, device='cuda:0')),
 (0.00610435426235199,
  0.1768862009048462,
  0.06488102674484253,
  105.27413606643677,
  tensor(0.1717, device='cuda:0')),
 (0.008500646948814392,
  0.6295808553695679,
  0.293637216091156,
  88.76784038543701,
  tensor(0.1659, device='cuda:0')),
 (0.00948402762413025,
  0.7624981999397278,
  0.30410802364349365,
  94.28015768527985,
  tensor(0.1112, device='cuda:0'))

In [None]:
'''
best: (0.0001073598861694336,
  0.2125471830368042,
  0.6206677556037903,
  42.446245431900024,
  tensor(0.2339, device='cuda:0')),
'''

In [21]:
#initialize models
vit_normal = init_model(device)
vit_student = init_model(device)
resnet_teacher = get_teacher_model(device)

vit_param_count = sum([p.numel() for p in vit_normal.parameters()])
resnet_param_count = sum([p.numel() for p in resnet_teacher.parameters()])
print(f'vit_normal and vit_student parameter count: {vit_param_count:,}')
print(f'resnet18 teacher parameter count: {resnet_param_count:,}')
print('\n')

vit_normal and vit_student parameter count: 5,526,346
resnet18 teacher parameter count: 11,173,962




In [14]:
#show initial performance of model before training
print('evaluating accuracy of default parameters on CIFAR10 test set...')
acc = eval_model(vit_normal, testloader, device)
print(f'initial accuracy {acc:,}')

evaluating accuracy of default parameters on CIFAR10 test set...
initial accuracy 0.11169999837875366


In [20]:
#define hyperparams/training constants
WEIGHT_DECAY = 1e-2
NUM_EPOCHS = 5

In [18]:
#learning_rate = 4e-3
learning_rate = 1e-3

#train vit_normal, store results
print('training vit using normal pre-training approach')
vit_normal_results = train(
  vit_normal,
  trainloader,
  testloader,
  device,
  learning_rate=learning_rate,
  weight_decay=WEIGHT_DECAY,
  num_epochs=NUM_EPOCHS)
ts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
save_stats(vit_normal_results, 'vit_normal', ts)
torch.save(vit_normal.state_dict(), f'/content/drive/My Drive/ML + Robotics/vit_normal_{ts}.pt')


training vit using normal pre-training approach


Epoch 1/10: 100%|██████████| 391/391 [03:06<00:00,  2.10it/s, Loss=1.5480, Acc=30.55%]


Epoch 1: Train CELoss: 1.8496, Train Acc: 30.55%, Test Acc: 0.38%


Epoch 2/10: 100%|██████████| 391/391 [03:09<00:00,  2.06it/s, Loss=1.4876, Acc=42.53%]


Epoch 2: Train CELoss: 1.5649, Train Acc: 42.53%, Test Acc: 0.45%


Epoch 3/10: 100%|██████████| 391/391 [03:09<00:00,  2.06it/s, Loss=1.3978, Acc=48.14%]


Epoch 3: Train CELoss: 1.4213, Train Acc: 48.14%, Test Acc: 0.50%


Epoch 4/10: 100%|██████████| 391/391 [03:09<00:00,  2.06it/s, Loss=1.1354, Acc=52.10%]


Epoch 4: Train CELoss: 1.3242, Train Acc: 52.10%, Test Acc: 0.52%


Epoch 5/10: 100%|██████████| 391/391 [03:09<00:00,  2.06it/s, Loss=1.2215, Acc=54.95%]


Epoch 5: Train CELoss: 1.2436, Train Acc: 54.95%, Test Acc: 0.53%


Epoch 6/10: 100%|██████████| 391/391 [03:09<00:00,  2.06it/s, Loss=1.3412, Acc=57.88%]


Epoch 6: Train CELoss: 1.1639, Train Acc: 57.88%, Test Acc: 0.57%


Epoch 7/10: 100%|██████████| 391/391 [03:09<00:00,  2.06it/s, Loss=1.2831, Acc=61.15%]


Epoch 7: Train CELoss: 1.0798, Train Acc: 61.15%, Test Acc: 0.58%


Epoch 8/10: 100%|██████████| 391/391 [03:09<00:00,  2.07it/s, Loss=1.0087, Acc=64.20%]


Epoch 8: Train CELoss: 0.9962, Train Acc: 64.20%, Test Acc: 0.60%


Epoch 9/10: 100%|██████████| 391/391 [03:09<00:00,  2.06it/s, Loss=1.0264, Acc=67.51%]


Epoch 9: Train CELoss: 0.9134, Train Acc: 67.51%, Test Acc: 0.61%


Epoch 10/10: 100%|██████████| 391/391 [03:08<00:00,  2.07it/s, Loss=0.6833, Acc=69.92%]


Epoch 10: Train CELoss: 0.8489, Train Acc: 69.92%, Test Acc: 0.62%


In [22]:
'''(0.0001073598861694336,
  0.2125471830368042,
  0.6206677556037903,
  42.446245431900024,
  tensor(0.2339, device='cuda:0')),'''

student_learning_rate = 1e-4
temperature = 0.212
alpha=0.62

#use train vit_student with teacher model
print('training vit using distillation/student-teacher approach')
vit_student_results = train(
  vit_normal,
  trainloader,
  testloader,
  device,
  learning_rate=student_learning_rate,
  weight_decay=WEIGHT_DECAY,
  num_epochs=NUM_EPOCHS,
  teacher_model=resnet_teacher,
  temperature=temperature,
  alpha=alpha)
ts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
save_stats(vit_student_results, 'vit_student', ts)
torch.save(vit_student.state_dict(), f'/content/drive/My Drive/ML + Robotics/vit_normal_{ts}.pt')

training vit using distillation/student-teacher approach


Epoch 1/5: 100%|██████████| 391/391 [13:00<00:00,  2.00s/it, Loss=0.7125, Acc=32.40%]


Epoch 1: Train CELoss: 0.7524, Train Acc: 32.40%, Test Acc: 0.38%


Epoch 2/5: 100%|██████████| 391/391 [12:59<00:00,  1.99s/it, Loss=0.6433, Acc=43.66%]


Epoch 2: Train CELoss: 0.6698, Train Acc: 43.66%, Test Acc: 0.46%


Epoch 3/5: 100%|██████████| 391/391 [12:59<00:00,  1.99s/it, Loss=0.6622, Acc=50.65%]


Epoch 3: Train CELoss: 0.6234, Train Acc: 50.65%, Test Acc: 0.52%


Epoch 4/5: 100%|██████████| 391/391 [12:59<00:00,  1.99s/it, Loss=0.6161, Acc=55.14%]


Epoch 4: Train CELoss: 0.5888, Train Acc: 55.14%, Test Acc: 0.56%


Epoch 5/5: 100%|██████████| 391/391 [12:58<00:00,  1.99s/it, Loss=0.5758, Acc=59.14%]


Epoch 5: Train CELoss: 0.5618, Train Acc: 59.14%, Test Acc: 0.58%


### Comparing Performance metrics