# Intro:
This notebook implements the method from the paper at https://arxiv.org/pdf/2212.10717 on a ResNet34 model, using the CIFAR-100 dataset.

## Load Data

In [None]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
import torchvision.transforms as transforms

# Preprocessing pipelines for the train and test splits.

# Per-channel (R, G, B) mean and standard deviation used to normalize CIFAR-100.
mean = (0.5071, 0.4867, 0.4408)
std = (0.2675, 0.2565, 0.2761)

# Random augmentations, applied to training images only.
_train_augmentations = [
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
]

# Steps common to both splits: convert to a tensor, then normalize per channel.
_tensor_and_normalize = [
  transforms.ToTensor(),
  transforms.Normalize(mean, std),
]

# Training: augment first, then tensorize and normalize.
transform_train = transforms.Compose(_train_augmentations + _tensor_and_normalize)

# Testing: deterministic, no augmentation.
transform_test = transforms.Compose(list(_tensor_and_normalize))


def get_dataset(root: str = 'data', download: bool = True):
  """
  Build the CIFAR-100 training and test datasets.

  The training split uses `transform_train` (augmentation + normalization) and
  the test split uses `transform_test` (normalization only).

  Args:
    root: Directory where the dataset is stored (or downloaded to).
    download: Forwarded to `tv.datasets.CIFAR100`; when True the dataset is
      downloaded into `root` if it is not already there.

  Returns:
    A `(trainset, testset)` tuple of `tv.datasets.CIFAR100` objects.
  """
  trainset = tv.datasets.CIFAR100(root=root, train=True,
                      download=download, transform=transform_train)

  testset = tv.datasets.CIFAR100(root=root, train=False,
                      download=download, transform=transform_test)
  return trainset, testset

# Module-level copies of the datasets, used by the inspection cells in this notebook.
trainset, testset = get_dataset()


# Class names of the dataset; indexed by an integer label to get a readable name.
classes = trainset.classes

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def imshow(
  img: t.Tensor,
  mean: tuple[float, ...] = (0.5071, 0.4867, 0.4408),
  std: tuple[float, ...] = (0.2675, 0.2565, 0.2761),
):
  """
  Display a normalized (C, H, W) image tensor with matplotlib.

  The image is un-normalized with the same per-channel statistics that were
  used to normalize it (`x * std + mean`). The defaults mirror the CIFAR-100
  `mean` / `std` passed to `transforms.Normalize` in this notebook; pass other
  values if the image was normalized differently.

  Args:
    img: Normalized image tensor in channel-first layout.
    mean: Per-channel mean that was subtracted during normalization.
    std: Per-channel standard deviation the image was divided by.
  """
  img = img.detach().cpu()  # .numpy() below needs a CPU tensor without grad
  channel_mean = t.tensor(mean, dtype=img.dtype).view(-1, 1, 1)
  channel_std = t.tensor(std, dtype=img.dtype).view(-1, 1, 1)

  # Invert Normalize, then clamp: float round-off can leave values slightly
  # outside [0, 1], which matplotlib would otherwise clip with a warning.
  img = (img * channel_std + channel_mean).clamp(0, 1)

  npimg = img.numpy()
  # matplotlib expects channel-last (H, W, C).
  plt.imshow(np.transpose(npimg, (1, 2, 0)))
  plt.show()

# Take the sample at a fixed index (3) from the trainset. The index is not random,
# but the trainset applies random crop/flip transforms, so the exact pixels can
# differ between runs.
image, label = trainset[3]

# Display the image and print its class name
imshow(image)
print(classes[label])

## Train Loop

In [None]:
from dataclasses import dataclass, field
from tqdm import tqdm

@dataclass
class ResNetTrainingArgs:
  """
  Hyperparameters consumed by `train_and_validate`.

  `optimizer` and `scheduler` are classes (not instances): the training loop
  instantiates them as `optimizer(model.parameters(), **optimizer_args)` and
  `scheduler(optimizer, **scheduler_args)`.
  """
  # Mini-batch size used for both the train and the test DataLoader.
  batch_size: int = 64
  # Number of full passes over the training set.
  epochs: int = 5
  # Number of target classes.
  n_classes: int = 100
  # Optimizer class to instantiate.
  optimizer: type[t.optim.Optimizer] = t.optim.SGD
  # Keyword arguments for the optimizer; values are not only floats
  # (e.g. booleans such as `nesterov`), hence `object`.
  optimizer_args: dict[str, object] = field(default_factory=lambda: {"lr": 0.1})
  # Learning-rate scheduler class to instantiate.
  scheduler: type[t.optim.lr_scheduler.LRScheduler] = t.optim.lr_scheduler.CosineAnnealingLR
  # Keyword arguments for the scheduler (may hold lists such as `milestones`).
  # NOTE(review): the default T_max=150 is larger than the default epochs=5 — confirm this is intended.
  scheduler_args: dict[str, object] = field(default_factory=lambda: {"T_max": 150})


def _evaluate_accuracy(model: nn.Module, loader, device: t.device) -> float:
  """Return the top-1 accuracy (in percent) of `model` over `loader`, or 0.0 if the loader yields no samples."""
  model.eval()
  correct = 0
  total = 0

  with t.inference_mode():
    for imgs, labels in loader:
      imgs, labels = imgs.to(device), labels.to(device)
      predicted = t.argmax(model(imgs), dim=-1)
      correct += (predicted == labels).sum().item()
      total += labels.size(0)

  return 100 * correct / total if total else 0.0


def train_and_validate(args: ResNetTrainingArgs, model: nn.Module, device: t.device) -> tuple[list[float], list[float], list[float]]:
  """
  Train `model` on the CIFAR-100 train split and validate it on the test split after every epoch.

  The model is trained in place; it is not returned. The caller keeps its own
  reference to the (now trained) model.

  Args:
    args: Training configuration (batch size, epochs, optimizer/scheduler classes and their kwargs).
    model: Network to train; expected to already live on `device`.
    device: Device that every batch is moved to.

  Returns:
    A tuple `(loss_list, train_accuracy_list, val_accuracy_list)`:
      - loss_list: training loss of every batch, across all epochs.
      - train_accuracy_list: running training accuracy (percent) at the end of each epoch.
      - val_accuracy_list: validation accuracy (percent) after each epoch.
  """
  optimizer = args.optimizer(model.parameters(), **args.optimizer_args)
  scheduler = args.scheduler(optimizer, **args.scheduler_args)

  # Load dataset
  trainset, testset = get_dataset()
  trainloader = t.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
  testloader = t.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=2)

  # Lists to store metrics
  loss_list = []
  train_accuracy_list = []
  val_accuracy_list = []

  for epoch in range(args.epochs):
    model.train()
    pbar_train = tqdm(trainloader, desc=f"Epoch {epoch+1}/{args.epochs}")

    total_train = 0
    correct_train = 0
    # Defined up front so an empty loader cannot leave it unbound below.
    train_accuracy = 0.0

    for imgs, labels in pbar_train:
      # Move data to device
      imgs, labels = imgs.to(device), labels.to(device)

      # Forward pass
      logits = model(imgs)
      loss = F.cross_entropy(logits, labels)

      # Clear gradients *before* backward so that stale gradients already
      # present on the model cannot leak into the first update.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Read the scalar once; it is used for both the log and the progress bar.
      loss_value = loss.item()
      loss_list.append(loss_value)

      # Running train accuracy over the batches seen so far in this epoch
      predicted = t.argmax(logits, dim=-1)
      correct_train += (predicted == labels).sum().item()
      total_train += labels.size(0)
      train_accuracy = 100 * correct_train / total_train

      # Update progress bar
      pbar_train.set_postfix(loss=f"{loss_value:.3f}", train_acc=f"{train_accuracy:.2f}%")

    # Record train accuracy for the epoch
    train_accuracy_list.append(train_accuracy)

    # Validation pass (switches the model to eval mode)
    val_accuracy = _evaluate_accuracy(model, testloader, device)
    val_accuracy_list.append(val_accuracy)

    print(f"Epoch [{epoch+1}/{args.epochs}] - Train Acc: {train_accuracy:.2f}% | Val Acc: {val_accuracy:.2f}%")

    # The scheduler is stepped once per epoch.
    scheduler.step()

  return loss_list, train_accuracy_list, val_accuracy_list


In [None]:
from widenet import WideNet

# Choose the compute device: MPS if available, otherwise CUDA, otherwise the CPU.
if t.backends.mps.is_available():
  device = t.device("mps")
elif t.cuda.is_available():
  device = t.device("cuda")
else:
  device = t.device("cpu")

# SGD with momentum and weight decay.
sgd_settings = {"lr": 0.1, "momentum": 0.9, "weight_decay": 5e-4}

# Multiply the learning rate by 0.2 at epochs 60, 120 and 160.
multistep_settings = {"milestones": [60, 120, 160], "gamma": 0.2}

training_args = ResNetTrainingArgs(
  batch_size=128,
  epochs=200,
  n_classes=100,
  optimizer=t.optim.SGD,
  optimizer_args=sgd_settings,
  scheduler=t.optim.lr_scheduler.MultiStepLR,
  scheduler_args=multistep_settings,
)

# Build the model for 100 classes and move it to the selected device.
widenet28 = WideNet(n_classes=100).to(device)

# Bare expression: makes the notebook cell display the selected device.
device

In [None]:
# Run training; returns the loss list plus the train and validation accuracy lists.
loss_list, train_accuracy_list, val_accuracy_list = train_and_validate(training_args, widenet28, device)