# Intro
This notebook implements the method from the paper at https://arxiv.org/pdf/2212.10717, applied to a ResNet-34 model trained on CIFAR-100.

## Load Data

In [None]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
import torchvision.transforms as transforms

# Convert PIL images to float tensors in [0, 1], then shift every channel to
# [-1, 1] via (x - 0.5) / 0.5. `imshow` below undoes exactly this mapping.
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

def get_dataset(root: str = 'data', download: bool = True):
  """
  Load the CIFAR-100 training and test splits, both using the module-level `transform`.

  Args:
    root: Directory the dataset files live in (or are downloaded to). Defaults to 'data'.
    download: Forwarded to torchvision's `CIFAR100`; when True the files are fetched
      if they are not already present under `root`.

  Returns:
    `(trainset, testset)` as torchvision `CIFAR100` datasets.
  """
  def load_split(train: bool):
    # Both splits share every setting except `train`.
    return tv.datasets.CIFAR100(root=root, train=train,
                                download=download, transform=transform)

  return load_split(True), load_split(False)

# Load both splits once at notebook level; the preview cell below reads from them.
trainset, testset = get_dataset()

# Maps an integer label to its human-readable class name (used as classes[label]).
classes = trainset.classes

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def imshow(img: t.Tensor):
  """
  Display one normalized image tensor with matplotlib.

  Undoes the Normalize(0.5, 0.5) step (maps [-1, 1] back to [0, 1]) and moves
  channels last, because `plt.imshow` expects (H, W, C) while the tensor is
  channels-first (C, H, W).

  The tensor is detached and copied to the CPU first, so images that live on a
  GPU/MPS device or that require grad can be shown too; `.numpy()` would raise
  on those.
  """
  img = img.detach().cpu() / 2 + 0.5  # unnormalize
  # Clamp so slightly out-of-range values don't trigger matplotlib's clipping warning.
  img = img.clamp(0, 1)
  plt.imshow(img.permute(1, 2, 0).numpy())
  plt.show()

# Preview one training sample. The index is fixed (not random), so the output
# is the same on every run.
image, label = trainset[3]

# Display the image, then print the name of its class
imshow(image)
print(classes[label])

## Train Loop

In [None]:
from dataclasses import dataclass
from tqdm import tqdm

@dataclass()
class ResNetTrainingArgs:
  """Hyperparameters consumed by `train_and_validate`; every field has a default."""

  batch_size: int = 64          # samples per mini-batch, for both train and test loaders
  num_workers: int = 2          # DataLoader worker processes
  epochs: int = 20              # number of full passes over the training set
  learning_rate: float = 0.001  # AdamW learning rate
  n_classes: int = 100          # number of target classes (CIFAR-100 has 100)


def _train_one_epoch(model: nn.Module, loader: t.utils.data.DataLoader,
                     optimizer: t.optim.Optimizer, device: t.device,
                     epoch: int, n_epochs: int) -> list[float]:
  """Run one optimization pass over `loader` and return the loss of every batch."""
  model.train()
  batch_losses: list[float] = []
  pbar = tqdm(loader)
  for imgs, labels in pbar:
    imgs, labels = imgs.to(device), labels.to(device)

    # Clear gradients *before* backward so stale grads (e.g. left on a model
    # that was trained earlier and passed in again) never leak into a step.
    optimizer.zero_grad()
    loss = F.cross_entropy(model(imgs), labels)
    loss.backward()
    optimizer.step()

    # One host sync per batch; the Python float is reused for the log and the bar.
    loss_value = loss.item()
    batch_losses.append(loss_value)
    pbar.set_postfix(epoch=f"{epoch+1}/{n_epochs}", loss=f"{loss_value:.3f}")
  return batch_losses


def _evaluate(model: nn.Module, loader: t.utils.data.DataLoader, device: t.device) -> float:
  """Return the top-1 accuracy of `model` over `loader` (0.0 if the loader yields nothing)."""
  model.eval()
  correct, total = 0, 0
  # A single inference_mode context covers the whole pass (no autograd tracking).
  with t.inference_mode():
    for imgs, labels in loader:
      imgs, labels = imgs.to(device), labels.to(device)
      predicted = model(imgs).argmax(dim=-1)
      correct += (predicted == labels).sum().item()
      total += labels.size(0)
  return correct / total if total else 0.0


def train_and_validate(
    args: ResNetTrainingArgs,
    model: nn.Module,
    device: t.device,
    datasets: "tuple[t.utils.data.Dataset, t.utils.data.Dataset] | None" = None,
) -> tuple[list[float], list[float], nn.Module]:
  """
  Train all parameters of `model` on CIFAR-100 and measure test accuracy after each epoch.

  Every parameter in `model.parameters()` is optimized with AdamW. Nothing is
  frozen, so this is full training / fine-tuning, not feature extraction.

  Args:
    args: Hyperparameters. `batch_size`, `num_workers`, `epochs` and
      `learning_rate` are read; `n_classes` is not used here.
    model: Network mapping an image batch to class logits. It is not moved
      here, so the caller should already have placed it on `device`.
    device: Device every batch is moved to before the forward pass.
    datasets: Optional `(trainset, testset)` pair to reuse, e.g. the splits the
      notebook has already loaded. When omitted, `get_dataset()` is called.

  Returns:
    `(loss_list, accuracy_list, model)`:
      - `loss_list` holds the training loss of every batch across all epochs.
      - `accuracy_list` holds the test accuracy (a fraction in [0, 1]) after each epoch.
      - `model` is the trained model, left in eval mode.
  """
  trainset, testset = get_dataset() if datasets is None else datasets
  trainloader = t.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                        shuffle=True, num_workers=args.num_workers)
  testloader = t.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                       shuffle=False, num_workers=args.num_workers)

  optimizer = t.optim.AdamW(model.parameters(), lr=args.learning_rate)

  loss_list: list[float] = []
  accuracy_list: list[float] = []

  for epoch in range(args.epochs):
    loss_list.extend(
        _train_one_epoch(model, trainloader, optimizer, device, epoch, args.epochs))

    accuracy = _evaluate(model, testloader, device)
    accuracy_list.append(accuracy)
    print(f"Epoch Accuracy: {accuracy}")

  return loss_list, accuracy_list, model

In [None]:
from model import ResNet34

# Prefer Apple-silicon MPS, then CUDA, and fall back to the CPU.
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")
training_args = ResNetTrainingArgs(batch_size=128, num_workers=2, epochs=10, learning_rate=0.001, n_classes=100)

# Take the class count from the training config so the model and the config
# cannot drift apart.
model = ResNet34(n_classes=training_args.n_classes).to(device)

# Bare expression: echoes the selected device as the cell output.
device

In [None]:
# Run training + per-epoch evaluation. Rebinding `model` keeps the notebook's
# handle pointing at the trained network.
loss_list, accuracy_list, model = train_and_validate(
  args=training_args, model=model, device=device,
)