In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision
import numpy as np

import matplotlib.pyplot as plt

device = torch.device('xpu')

In [2]:
resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
resnet18.to(device)
for param in resnet18.parameters(): #set requires grad to false: not training
    param.requires_grad = False

In [3]:
resnet18

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
#create dataset with pytorch datset and dataloaders
transform = torchvision.transforms.Compose(
  torchvision.transforms.ToTensor(),
  torchvision.transforms.Resize(224, 224),
  torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
)
TRAIN_BATCH_SIZE = 128

trainset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=True,
    download=True, 
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset, 
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True, 
    num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=False,
    download=True, 
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset, 
    batch_size=200,
    shuffle=False, 
    num_workers=2
)

KeyError: 224

In [None]:
#accuracy of classification
def eval_model(model, testloader, device):
  acc_list = []
  denom = 0
  for i, data in enumerate(testloader, 0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    out = model(inputs)
    preds = out.argmax(dim=-1)
    acc = preds.eq(labels).sum()
    denom += inputs.shape[0]
    acc_list.append(acc)
  return sum(acc_list) / denom

In [None]:
#replace final output layer of resnet18 with a new linear layer, out_features=10 for CIFAR10

in_features = resnet18.fc.in_features
resnet18.fc = nn.Linear(in_features, 10, bias=False) #new classification projection
resnet18.fc.weight.requires_grad = True #ensure requires grad is set to true

In [None]:
#test the initial model's perf on cifar10 test set
eval_model(resnet18, testloader, device)

tensor(0.3781, device='xpu:0')

In [None]:
optimizer = optim.Adam(params=resnet18.fc.parameters(), lr=0.002, weight_decay=0.001)
criterion = nn.CrossEntropyLoss()

In [None]:
for epoch in range(5):
  resnet18.train()
  for i, data in enumerate(trainloader, 0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = resnet18(inputs)
    optimizer.zero_grad()
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    if i % 200 == 99:
      print(f'epoch {epoch}, batch {i}: loss = {loss.item()}')

  #eval model at end of epoch
  resnet18.eval()
  acc = eval_model(resnet18, testloader, device).item()
  print('test accuracy: ', acc, '\n')

epoch 0, batch 99: loss = 1.6583102941513062
epoch 0, batch 299: loss = 1.508906364440918
test accuracy:  0.4423999786376953 

epoch 1, batch 99: loss = 1.5073816776275635
epoch 1, batch 299: loss = 1.6249878406524658
test accuracy:  0.45319998264312744 

epoch 2, batch 99: loss = 1.5966134071350098
epoch 2, batch 299: loss = 1.5739390850067139
test accuracy:  0.4372999966144562 

epoch 3, batch 99: loss = 1.6781584024429321
epoch 3, batch 299: loss = 1.7415069341659546
test accuracy:  0.4412999749183655 

epoch 4, batch 99: loss = 1.503389835357666
epoch 4, batch 299: loss = 1.4510746002197266
test accuracy:  0.4481000006198883 

