# SmallNORB

In [None]:
!pip install datasets
!pip install transformers

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.5-py3-none-any.whl (7.8 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub>=0.18.0 (from datasets)
  Downloading huggingface_hub-0.19.4-py3-none-any.whl (311 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.7/311.7 kB[0m [31m35.2 MB/s[0m eta [36m0:00:00[0m
Installing collect

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision as tv
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm
from torchvision import transforms
from torchvision.transforms import v2
from datasets import load_dataset
from transformers import ConvNextV2ForImageClassification
from transformers.models.convnextv2.modeling_convnextv2 import ConvNextV2Embeddings

sns.set_theme()

In [None]:
# Show the GPU (if any) attached to this runtime.
!nvidia-smi

Fri Nov 17 00:59:10 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P8     9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
# Load SmallNORB from the Hugging Face Hub; the "train" and "test" splits are used below.
smallnorb = load_dataset("Ramos-Ramos/smallnorb")

Downloading readme:   0%|          | 0.00/5.43k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/118M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/118M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/24300 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/24300 [00:00<?, ? examples/s]

In [None]:
class SmallNORBDataset(torch.utils.data.Dataset):
  """Map-style PyTorch dataset over one split of a split-keyed dataset mapping.

  Each item is the pair ``(image, label)`` built from the ``"image_lt"`` and
  ``"category"`` fields of the underlying record.

  Args:
    hf_dataset: Mapping from split name to an indexable collection of records.
    subset: Name of the split to expose (defaults to ``"train"``).
    transform: Optional callable applied to the image before it is returned.
  """

  def __init__(self, hf_dataset, subset="train", transform=None):
    self.hf_dataset = hf_dataset
    self.subset = subset
    self.transform = transform

  def __len__(self):
    """Return the number of records in the selected split."""
    split = self.hf_dataset[self.subset]
    return len(split)

  def __getitem__(self, idx):
    """Return ``(image, label)`` for record ``idx``, transforming the image if configured."""
    record = self.hf_dataset[self.subset][idx]
    image, label = record["image_lt"], record["category"]

    if not self.transform:
      return image, label
    return self.transform(image), label

## VGG16

In [None]:
# Training pipeline: random crop / rotation / horizontal-flip augmentation, then
# tensor conversion and single-channel normalisation.
# NOTE(review): 0.485 / 0.229 look like the first-channel ImageNet mean/std reused
# for the one input channel — confirm this is intended.
transforms_train = transforms.Compose([
    transforms.RandomResizedCrop(size=(224,224), scale=(0.8, 1.0)),
    transforms.RandomRotation(degrees=30),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229])
])

# Validation pipeline: deterministic resize only, no augmentation.
transforms_val = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229])
])

# Test pipeline: same steps as validation.
transforms_test = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229])
])

In [None]:
# Two views of the same "train" split: augmented for training, deterministic
# for validation.
train_full = SmallNORBDataset(smallnorb, "train", transform=transforms_train)
val_full = SmallNORBDataset(smallnorb, "train", transform=transforms_val)

train_size = int(0.8 * len(train_full))
val_size = len(train_full) - train_size

# Draw ONE permutation and slice it, so the training and validation index sets
# are disjoint. (Calling random_split twice draws two different permutations,
# which lets validation samples leak into the training set.) The dedicated,
# seeded generator makes the split reproducible across runs.
split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(train_full), generator=split_generator).tolist()
train_ds = torch.utils.data.Subset(train_full, shuffled_indices[:train_size])
val_ds = torch.utils.data.Subset(val_full, shuffled_indices[train_size:])

test_ds = SmallNORBDataset(smallnorb, "test", transform=transforms_test)

In [None]:
IN_CHANNELS = 1  # input channels of the replacement first convolution
N_CLASSES = 5  # output size of the new classification head

# VGG16 with ImageNet-pretrained weights.
vgg16_model = tv.models.vgg16(weights="IMAGENET1K_V1")

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:07<00:00, 75.1MB/s]


In [None]:
# Freeze every pretrained parameter; layers created after this loop stay trainable.
for param in vgg16_model.parameters():
    param.requires_grad = False

In [None]:
# 1-channel inputs
# The first conv is swapped for a fresh IN_CHANNELS-input layer: it is newly
# initialised (the pretrained first-layer weights are discarded) and trainable.
vgg16_model.features[0] = nn.Conv2d(IN_CHANNELS, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1))

# Add on classifier
# The last classifier layer is replaced by a small head ending in LogSoftmax,
# so the model emits log-probabilities (matching the nn.NLLLoss used in this section).
n_inputs = vgg16_model.classifier[6].in_features
vgg16_model.classifier[6] = nn.Sequential(
    nn.Linear(n_inputs, 256), nn.ReLU(), nn.Dropout(0.6),
    nn.Linear(256, N_CLASSES), nn.LogSoftmax(dim=1))

In [None]:
# Move the model to the selected device.
vgg16_model = vgg16_model.to(device)

In [None]:
# Count every parameter, then only those still trainable after freezing.
total_params = sum(map(torch.numel, vgg16_model.parameters()))
total_trainable_params = sum(
    weight.numel() for weight in vgg16_model.parameters() if weight.requires_grad)
print(f"{total_params:,} total parameters.")
print(f"{total_trainable_params:,} trainable parameters.")

135,309,509 total parameters.
1,050,757 trainable parameters.


In [None]:
# Hyperparameters for the VGG16 run.
BATCH_SIZE = 16  # samples per DataLoader batch
EPOCHS = 35  # passes over the training set
LEARNING_RATE = 1e-2  # initial SGD learning rate
MOMENTUM = 0.9  # SGD momentum
WEIGHT_DECAY = 5e-4  # SGD weight decay

In [None]:
# Batch the three datasets; only the training loader is shuffled.
train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = torch.utils.data.DataLoader(dataset=val_ds, batch_size=BATCH_SIZE)
test_dl = torch.utils.data.DataLoader(dataset=test_ds, batch_size=BATCH_SIZE)

In [None]:
# The model head ends in LogSoftmax, so NLLLoss is the matching criterion.
criterion = nn.NLLLoss()
# Frozen parameters never receive gradients, so SGD leaves them untouched.
optimizer = torch.optim.SGD(vgg16_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# The training loop steps this scheduler with the validation loss, a quantity
# that should decrease, so plateau detection must run in "min" mode. ("max"
# would treat a falling loss as "no improvement" and cut the learning rate
# while training is still progressing.)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, mode="min", factor=0.1, patience=11, verbose=True
)

In [None]:
# Train VGG16 for EPOCHS epochs, validating after each one and keeping the
# checkpoint with the lowest validation loss.
best_val_loss = 1e7  # sentinel larger than any expected summed validation loss
train_losses = list()  # per-epoch mean training loss
val_losses = list()  # per-epoch mean validation loss

for epoch in range(EPOCHS):
  # Train
  vgg16_model.train()
  train_loss = 0.0  # sum of per-batch losses for this epoch
  for batch in tqdm(train_dl, desc=f"Epoch {epoch + 1}/{EPOCHS}", ncols=100):
    inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    outputs = vgg16_model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()

  # Validate
  vgg16_model.eval()
  val_loss = 0.0  # sum of per-batch validation losses
  correct_predictions = 0
  total_predictions = 0
  with torch.no_grad():
    for batch in tqdm(val_dl, desc="Validation", ncols=100):
      inputs, labels = batch
      inputs = inputs.to(device)
      labels = labels.to(device)
      outputs = vgg16_model(inputs)
      loss = criterion(outputs, labels)
      val_loss += loss.item()
      # Predicted class = index of the largest output along the class dimension.
      _, predicted = torch.max(outputs, 1)
      total_predictions += labels.size(0)
      correct_predictions += (predicted == labels).sum().item()
  # The plateau scheduler is driven by the summed validation loss.
  lr_scheduler.step(val_loss)

  # Checkpoint whenever the summed validation loss improves.
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(vgg16_model.state_dict(), "vgg16-smallnorb.pt")

  # Store per-batch means so the curves are comparable between train and validation.
  train_losses.append(train_loss / len(train_dl))
  val_losses.append(val_loss / len(val_dl))

  accuracy = 100 * correct_predictions / total_predictions
  print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_loss / len(train_dl):.4f}, Validation Loss: {val_loss / len(val_dl):.4f}, Validation Accuracy: {accuracy:.2f}%")


In [None]:
# Predict on Test set
# Evaluate the checkpoint with the lowest validation loss (written by the
# training loop) instead of whatever weights the final epoch left in memory.
vgg16_model.load_state_dict(torch.load("vgg16-smallnorb.pt", map_location=device))
vgg16_model.eval()
test_loss = 0.0  # sum of per-batch test losses
correct_predictions = 0
total_predictions = 0
with torch.no_grad():
  for batch in tqdm(test_dl, desc="Testing", ncols=100):
    inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = vgg16_model(inputs)
    loss = criterion(outputs, labels)
    test_loss += loss.item()
    # Predicted class = index of the largest output along the class dimension.
    _, predicted = torch.max(outputs, 1)
    total_predictions += labels.size(0)
    correct_predictions += (predicted == labels).sum().item()

print(f"\nTest set accuracy = {(100 * correct_predictions / total_predictions):.4f}%")

Testing: 100%|██████████████████████████████████████████████████| 1519/1519 [01:37<00:00, 15.59it/s]


Test set accuracy = 91.2305%





## ConvNeXt

In [None]:
# Training pipeline: resize, RandAugment on the image, tensor conversion and
# single-channel normalisation, then random erasing on the tensor.
# NOTE(review): 0.485 / 0.229 look like the first-channel ImageNet mean/std reused
# for the one input channel — confirm this is intended.
transforms_train = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229]),
    transforms.RandomErasing(p=0.25)
])

# Validation pipeline: deterministic resize only, no augmentation.
transforms_val = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229]),
])

# Test pipeline: same steps as validation.
transforms_test = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229]),
])

In [None]:
# Two views of the same "train" split: augmented for training, deterministic
# for validation.
train_full = SmallNORBDataset(smallnorb, "train", transform=transforms_train)
val_full = SmallNORBDataset(smallnorb, "train", transform=transforms_val)

train_size = int(0.8 * len(train_full))
val_size = len(train_full) - train_size

# Draw ONE permutation and slice it, so the training and validation index sets
# are disjoint. (Calling random_split twice draws two different permutations,
# which lets validation samples leak into the training set.) The dedicated,
# seeded generator makes the split reproducible across runs.
split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(train_full), generator=split_generator).tolist()
train_ds = torch.utils.data.Subset(train_full, shuffled_indices[:train_size])
val_ds = torch.utils.data.Subset(val_full, shuffled_indices[train_size:])

test_ds = SmallNORBDataset(smallnorb, "test", transform=transforms_test)

In [None]:
IN_CHANNELS = 1  # input channels of the replacement first convolution
N_CLASSES = 5  # output size of the new classification head

# ConvNeXt-Base with ImageNet-pretrained weights.
convnext_model = tv.models.convnext_base(weights="IMAGENET1K_V1")

In [None]:
# Freeze every pretrained parameter; layers created after this loop stay trainable.
for param in convnext_model.parameters():
    param.requires_grad = False

In [None]:
# 1-channel inputs
# The stem conv is swapped for a fresh IN_CHANNELS-input layer: it is newly
# initialised (the pretrained stem weights are discarded) and trainable.
convnext_model.features[0][0] = nn.Conv2d(IN_CHANNELS, 128, kernel_size=(4,4), stride=(4,4))

# Add on classifier
# The head returns raw logits. This model is trained with nn.CrossEntropyLoss,
# which applies log-softmax internally; a trailing nn.Softmax would normalise
# twice, squashing the logits into [0, 1] and flattening the gradients.
# Arg-max predictions are unaffected by the removal.
n_inputs = convnext_model.classifier[2].in_features
convnext_model.classifier[2] = nn.Sequential(
    nn.Linear(n_inputs, 256), nn.GELU(), nn.Dropout(0.4),
    nn.Linear(256, N_CLASSES))

In [None]:
# Move the model to the selected device.
convnext_model = convnext_model.to(device)

In [None]:
# Count every parameter, then only those still trainable after freezing.
total_params = sum(map(torch.numel, convnext_model.parameters()))
total_trainable_params = sum(
    weight.numel() for weight in convnext_model.parameters() if weight.requires_grad)
print(f"{total_params:,} total parameters.")
print(f"{total_trainable_params:,} trainable parameters.")

87,826,053 total parameters.
265,861 trainable parameters.


In [None]:
# Hyperparameters for the ConvNeXt run.
BATCH_SIZE = 16  # samples per DataLoader batch
EPOCHS = 30  # passes over the training set
LEARNING_RATE = 5e-5  # initial AdamW learning rate
WEIGHT_DECAY = 1e-8  # AdamW weight decay

In [None]:
# Batch the three datasets; only the training loader is shuffled.
train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = torch.utils.data.DataLoader(dataset=val_ds, batch_size=BATCH_SIZE)
test_dl = torch.utils.data.DataLoader(dataset=test_ds, batch_size=BATCH_SIZE)

In [None]:
# Cross-entropy with label smoothing; it accepts both class-index targets and
# the per-class probability targets produced by CutMix / MixUp.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(convnext_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# NOTE(review): T_max=11 is shorter than EPOCHS, so the cosine schedule bottoms
# out after 11 epochs and then climbs again — confirm this cycling is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=11)

In [None]:
# Additional ConvNeXt augmentations
# Batch-level mixing: each call applies either CutMix or MixUp, picked at random.
cutmix = v2.CutMix(num_classes=N_CLASSES)
mixup = v2.MixUp(num_classes=N_CLASSES)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

In [None]:
# Train ConvNeXt for EPOCHS epochs, validating after each one and keeping the
# checkpoint with the lowest validation loss.
best_val_loss = 1e7  # sentinel larger than any expected summed validation loss
train_losses = list()  # per-epoch mean training loss
val_losses = list()  # per-epoch mean validation loss

for epoch in range(EPOCHS):
  # Train
  convnext_model.train()
  train_loss = 0.0  # sum of per-batch losses for this epoch
  for batch in tqdm(train_dl, desc=f"Epoch {epoch + 1}/{EPOCHS}", ncols=100):
    inputs, labels = batch
    # Mix the batch (CutMix or MixUp) before it is moved to the device.
    inputs, labels = cutmix_or_mixup(inputs, labels)
    inputs = inputs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    outputs = convnext_model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()

  # Validate
  # No batch mixing here: validation uses the labels as provided by the loader.
  convnext_model.eval()
  val_loss = 0.0  # sum of per-batch validation losses
  correct_predictions = 0
  total_predictions = 0
  with torch.no_grad():
    for batch in tqdm(val_dl, desc="Validation", ncols=100):
      inputs, labels = batch
      inputs = inputs.to(device)
      labels = labels.to(device)
      outputs = convnext_model(inputs)
      loss = criterion(outputs, labels)
      val_loss += loss.item()
      # Predicted class = index of the largest output along the class dimension.
      _, predicted = torch.max(outputs, 1)
      total_predictions += labels.size(0)
      correct_predictions += (predicted == labels).sum().item()
  # The cosine schedule advances once per epoch and takes no metric.
  lr_scheduler.step()

  # Checkpoint whenever the summed validation loss improves.
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(convnext_model.state_dict(), "convnext-smallnorb.pt")

  # Store per-batch means so the curves are comparable between train and validation.
  train_losses.append(train_loss / len(train_dl))
  val_losses.append(val_loss / len(val_dl))

  accuracy = 100 * correct_predictions / total_predictions
  print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_loss / len(train_dl):.4f}, Validation Loss: {val_loss / len(val_dl):.4f}, Validation Accuracy: {accuracy:.2f}%")


In [None]:
# Predict on Test set
# Evaluate the checkpoint with the lowest validation loss (written by the
# training loop) instead of whatever weights the final epoch left in memory.
convnext_model.load_state_dict(torch.load("convnext-smallnorb.pt", map_location=device))
convnext_model.eval()
test_loss = 0.0  # sum of per-batch test losses
correct_predictions = 0
total_predictions = 0
with torch.no_grad():
  for batch in tqdm(test_dl, desc="Testing", ncols=100):
    inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = convnext_model(inputs)
    loss = criterion(outputs, labels)
    test_loss += loss.item()
    # Predicted class = index of the largest output along the class dimension.
    _, predicted = torch.max(outputs, 1)
    total_predictions += labels.size(0)
    correct_predictions += (predicted == labels).sum().item()

print(f"\nTest set accuracy = {(100 * correct_predictions / total_predictions):.4f}%")

Testing: 100%|██████████████████████████████████████████████████| 1519/1519 [02:43<00:00,  9.28it/s]


Test set accuracy = 92.4979%





## ConvNeXt V2

In [None]:
# Training pipeline: resize, RandAugment on the image, tensor conversion and
# single-channel normalisation, then random erasing on the tensor.
# NOTE(review): 0.485 / 0.229 look like the first-channel ImageNet mean/std reused
# for the one input channel — confirm this is intended.
transforms_train = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229]),
    transforms.RandomErasing(p=0.25)
])

# Validation pipeline: deterministic resize only, no augmentation.
transforms_val = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229]),
])

# Test pipeline: same steps as validation.
transforms_test = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229]),
])

In [None]:
# Two views of the same "train" split: augmented for training, deterministic
# for validation.
train_full = SmallNORBDataset(smallnorb, "train", transform=transforms_train)
val_full = SmallNORBDataset(smallnorb, "train", transform=transforms_val)

train_size = int(0.8 * len(train_full))
val_size = len(train_full) - train_size

# Draw ONE permutation and slice it, so the training and validation index sets
# are disjoint. (Calling random_split twice draws two different permutations,
# which lets validation samples leak into the training set.) The dedicated,
# seeded generator makes the split reproducible across runs.
split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(train_full), generator=split_generator).tolist()
train_ds = torch.utils.data.Subset(train_full, shuffled_indices[:train_size])
val_ds = torch.utils.data.Subset(val_full, shuffled_indices[train_size:])

test_ds = SmallNORBDataset(smallnorb, "test", transform=transforms_test)

In [None]:
IN_CHANNELS = 1  # input channels of the replacement patch-embedding convolution
N_CLASSES = 5  # output size of the new classification head

# ConvNeXt V2 image classifier loaded from the "facebook/convnextv2-base-1k-224" checkpoint.
convnextv2_model = ConvNextV2ForImageClassification.from_pretrained("facebook/convnextv2-base-1k-224")

In [None]:
# Freeze every pretrained parameter; layers created after this loop stay trainable.
for param in convnextv2_model.parameters():
    param.requires_grad = False

In [None]:
# 1-channel inputs
# The patch-embedding conv is swapped for a fresh IN_CHANNELS-input layer: it is
# newly initialised (the pretrained weights are discarded) and trainable.
convnextv2_model.convnextv2.embeddings.patch_embeddings = nn.Conv2d(IN_CHANNELS, 128, kernel_size=(4,4), stride=(4,4))
# Keep the embeddings' recorded channel count in sync with the new conv
# (presumably validated against the input in the embeddings' forward pass).
convnextv2_model.convnextv2.embeddings.num_channels = IN_CHANNELS

# Add on classifier
# The head returns raw logits. This model is trained with nn.CrossEntropyLoss,
# which applies log-softmax internally; a trailing nn.Softmax would normalise
# twice, squashing the logits into [0, 1] and flattening the gradients.
# Arg-max predictions are unaffected by the removal.
n_inputs = convnextv2_model.classifier.in_features
convnextv2_model.classifier = nn.Sequential(
    nn.Linear(n_inputs, 256), nn.GELU(), nn.Dropout(0.4),
    nn.Linear(256, N_CLASSES))

In [None]:
# Move the model to the selected device.
convnextv2_model = convnextv2_model.to(device)

In [None]:
# Count every parameter, then only those still trainable after freezing.
total_params = sum(map(torch.numel, convnextv2_model.parameters()))
total_trainable_params = sum(
    weight.numel() for weight in convnextv2_model.parameters() if weight.requires_grad)
print(f"{total_params:,} total parameters.")
print(f"{total_trainable_params:,} trainable parameters.")

87,952,389 total parameters.
265,861 trainable parameters.


In [None]:
# Hyperparameters for the ConvNeXt V2 run.
BATCH_SIZE = 16  # samples per DataLoader batch
EPOCHS = 50  # passes over the training set
LEARNING_RATE = 6.25e-3  # initial AdamW learning rate
WEIGHT_DECAY = 0.05  # AdamW weight decay

In [None]:
# Batch the three datasets; only the training loader is shuffled.
train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = torch.utils.data.DataLoader(dataset=val_ds, batch_size=BATCH_SIZE)
test_dl = torch.utils.data.DataLoader(dataset=test_ds, batch_size=BATCH_SIZE)

In [None]:
# Cross-entropy with label smoothing; it accepts both class-index targets and
# the per-class probability targets produced by CutMix / MixUp.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(convnextv2_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# NOTE(review): T_max=11 is shorter than EPOCHS, so the cosine schedule bottoms
# out after 11 epochs and then climbs again — confirm this cycling is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=11)

In [None]:
# Additional ConvNeXt augmentations
# Batch-level mixing: each call applies either CutMix or MixUp (alpha=0.8), picked at random.
cutmix = v2.CutMix(num_classes=N_CLASSES)
mixup = v2.MixUp(alpha=0.8, num_classes=N_CLASSES)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

In [None]:
# Train ConvNeXt V2 for EPOCHS epochs, validating after each one and keeping
# the checkpoint with the lowest validation loss.
best_val_loss = 1e7  # sentinel larger than any expected summed validation loss
train_losses = list()  # per-epoch mean training loss
val_losses = list()  # per-epoch mean validation loss

for epoch in range(EPOCHS):
  # Train
  convnextv2_model.train()
  train_loss = 0.0  # sum of per-batch losses for this epoch
  for batch in tqdm(train_dl, desc=f"Epoch {epoch + 1}/{EPOCHS}", ncols=100):
    inputs, labels = batch
    # Mix the batch (CutMix or MixUp) before it is moved to the device.
    inputs, labels = cutmix_or_mixup(inputs, labels)
    inputs = inputs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    outputs = convnextv2_model(inputs)
    # The first element of the model output holds the classifier scores.
    loss = criterion(outputs[0], labels)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()

  # Validate
  # No batch mixing here: validation uses the labels as provided by the loader.
  convnextv2_model.eval()
  val_loss = 0.0  # sum of per-batch validation losses
  correct_predictions = 0
  total_predictions = 0
  with torch.no_grad():
    for batch in tqdm(val_dl, desc="Validation", ncols=100):
      inputs, labels = batch
      inputs = inputs.to(device)
      labels = labels.to(device)
      outputs = convnextv2_model(inputs)
      loss = criterion(outputs[0], labels)
      val_loss += loss.item()
      # Predicted class = index of the largest score along the class dimension.
      _, predicted = torch.max(outputs[0], 1)
      total_predictions += labels.size(0)
      correct_predictions += (predicted == labels).sum().item()
  # The cosine schedule advances once per epoch and takes no metric.
  lr_scheduler.step()

  # Checkpoint whenever the summed validation loss improves.
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(convnextv2_model.state_dict(), "convnextv2-smallnorb.pt")

  # Store per-batch means so the curves are comparable between train and validation.
  train_losses.append(train_loss / len(train_dl))
  val_losses.append(val_loss / len(val_dl))

  accuracy = 100 * correct_predictions / total_predictions
  print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_loss / len(train_dl):.4f}, Validation Loss: {val_loss / len(val_dl):.4f}, Validation Accuracy: {accuracy:.2f}%")


In [None]:
# Predict on Test set
# Evaluate the checkpoint with the lowest validation loss (written by the
# training loop) instead of whatever weights the final epoch left in memory.
convnextv2_model.load_state_dict(torch.load("convnextv2-smallnorb.pt", map_location=device))
convnextv2_model.eval()
test_loss = 0.0  # sum of per-batch test losses
correct_predictions = 0
total_predictions = 0
with torch.no_grad():
  for batch in tqdm(test_dl, desc="Testing", ncols=100):
    inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = convnextv2_model(inputs)
    # The first element of the model output holds the classifier scores.
    loss = criterion(outputs[0], labels)
    test_loss += loss.item()
    # Predicted class = index of the largest score along the class dimension.
    _, predicted = torch.max(outputs[0], 1)
    total_predictions += labels.size(0)
    correct_predictions += (predicted == labels).sum().item()

print(f"\nTest set accuracy = {(100 * correct_predictions / total_predictions):.4f}%")

Testing: 100%|██████████████████████████████████████████████████| 1519/1519 [03:43<00:00,  6.80it/s]


Test set accuracy = 92.5267%





## Improved ResNet50

In [None]:
# Training pipeline: random crop / rotation / horizontal-flip augmentation, then
# tensor conversion and single-channel normalisation.
# NOTE(review): 0.485 / 0.229 look like the first-channel ImageNet mean/std reused
# for the one input channel — confirm this is intended.
transforms_train = transforms.Compose([
    transforms.RandomResizedCrop(size=(224,224), scale=(0.8, 1.0)),
    transforms.RandomRotation(degrees=30),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229])
])

# Validation pipeline: deterministic resize only, no augmentation.
transforms_val = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229])
])

# Test pipeline: same steps as validation.
transforms_test = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485], [0.229])
])

In [None]:
# Two views of the same "train" split: augmented for training, deterministic
# for validation.
train_full = SmallNORBDataset(smallnorb, "train", transform=transforms_train)
val_full = SmallNORBDataset(smallnorb, "train", transform=transforms_val)

train_size = int(0.8 * len(train_full))
val_size = len(train_full) - train_size

# Draw ONE permutation and slice it, so the training and validation index sets
# are disjoint. (Calling random_split twice draws two different permutations,
# which lets validation samples leak into the training set.) The dedicated,
# seeded generator makes the split reproducible across runs.
split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(train_full), generator=split_generator).tolist()
train_ds = torch.utils.data.Subset(train_full, shuffled_indices[:train_size])
val_ds = torch.utils.data.Subset(val_full, shuffled_indices[train_size:])

test_ds = SmallNORBDataset(smallnorb, "test", transform=transforms_test)

In [None]:
IN_CHANNELS = 1  # input channels of the replacement first convolution
N_CLASSES = 5  # output size of the new classification head

# ResNet50 with ImageNet-pretrained weights.
resnet50_model = tv.models.resnet50(weights="IMAGENET1K_V1")

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 76.8MB/s]


In [None]:
# Freeze every pretrained parameter; layers created after this loop stay trainable.
for param in resnet50_model.parameters():
    param.requires_grad = False

In [None]:
# 1-channel inputs
# The first conv is swapped for a fresh IN_CHANNELS-input layer: it is newly
# initialised (the pretrained conv1 weights are discarded) and trainable.
resnet50_model.conv1 = nn.Conv2d(IN_CHANNELS, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)

# Add on classifier
# The final fully-connected layer is replaced by a small head ending in LogSoftmax,
# so the model emits log-probabilities (matching the nn.NLLLoss used in this section).
n_inputs = resnet50_model.fc.in_features
resnet50_model.fc = nn.Sequential(
    nn.Linear(n_inputs, 256), nn.ReLU(), nn.Dropout(0.6),
    nn.Linear(256, N_CLASSES), nn.LogSoftmax(dim=1))

In [None]:
# Move the model to the selected device.
resnet50_model = resnet50_model.to(device)

In [None]:
# Count every parameter, then only those still trainable after freezing.
total_params = sum(map(torch.numel, resnet50_model.parameters()))
total_trainable_params = sum(
    weight.numel() for weight in resnet50_model.parameters() if weight.requires_grad)
print(f"{total_params:,} total parameters.")
print(f"{total_trainable_params:,} trainable parameters.")


24,027,589 total parameters.
528,965 trainable parameters.


In [None]:
# Hyperparameters for the ResNet50 run.
BATCH_SIZE = 16  # samples per DataLoader batch
EPOCHS = 100  # passes over the training set
LEARNING_RATE = 5e-3  # initial SGD learning rate
MOMENTUM = 0.9  # SGD momentum
WEIGHT_DECAY = 1e-2  # SGD weight decay

In [None]:
# Batch the three datasets; only the training loader is shuffled.
train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = torch.utils.data.DataLoader(dataset=val_ds, batch_size=BATCH_SIZE)
test_dl = torch.utils.data.DataLoader(dataset=test_ds, batch_size=BATCH_SIZE)

In [None]:
# The model head ends in LogSoftmax, so NLLLoss is the matching criterion.
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(resnet50_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# NOTE(review): T_max=35 is shorter than EPOCHS, so the cosine schedule bottoms
# out after 35 epochs and then climbs again — confirm this cycling is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer, T_max=35
)

In [None]:
# Mount Google Drive (Colab only); the ResNet50 checkpoint below is written under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Train ResNet50 for EPOCHS epochs, validating after each one and keeping the
# checkpoint with the lowest validation loss.
best_val_loss = 1e7  # sentinel larger than any expected summed validation loss
train_losses = list()  # per-epoch mean training loss
val_losses = list()  # per-epoch mean validation loss

for epoch in range(EPOCHS):
  # Train
  resnet50_model.train()
  train_loss = 0.0  # sum of per-batch losses for this epoch
  for batch in tqdm(train_dl, desc=f"Epoch {epoch + 1}/{EPOCHS}", ncols=100):
    inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    outputs = resnet50_model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()

  # Validate
  resnet50_model.eval()
  val_loss = 0.0  # sum of per-batch validation losses
  correct_predictions = 0
  total_predictions = 0
  with torch.no_grad():
    for batch in tqdm(val_dl, desc="Validation", ncols=100):
      inputs, labels = batch
      inputs = inputs.to(device)
      labels = labels.to(device)
      outputs = resnet50_model(inputs)
      loss = criterion(outputs, labels)
      val_loss += loss.item()
      # Predicted class = index of the largest output along the class dimension.
      _, predicted = torch.max(outputs, 1)
      total_predictions += labels.size(0)
      correct_predictions += (predicted == labels).sum().item()
  # CosineAnnealingLR advances once per epoch and takes no metric. Its optional
  # positional argument is an epoch index, so passing the validation loss would
  # make the learning rate a function of the loss value instead of following
  # the cosine schedule.
  lr_scheduler.step()

  # Checkpoint whenever the summed validation loss improves.
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(resnet50_model.state_dict(), "/content/drive/MyDrive/UM_Project/checkpoints/resnet50-smallnorb.pt")

  # Store per-batch means so the curves are comparable between train and validation.
  train_losses.append(train_loss / len(train_dl))
  val_losses.append(val_loss / len(val_dl))

  accuracy = 100 * correct_predictions / total_predictions
  print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_loss / len(train_dl):.4f}, Validation Loss: {val_loss / len(val_dl):.4f}, Validation Accuracy: {accuracy:.2f}%")

Epoch 1/100: 100%|██████████████████████████████████████████████| 1215/1215 [07:39<00:00,  2.65it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:51<00:00,  5.87it/s]


Epoch 1/100, Train Loss: 0.8298, Validation Loss: 0.5647, Validation Accuracy: 76.44%


Epoch 2/100: 100%|██████████████████████████████████████████████| 1215/1215 [07:54<00:00,  2.56it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:52<00:00,  5.74it/s]


Epoch 2/100, Train Loss: 0.4987, Validation Loss: 0.2384, Validation Accuracy: 93.70%


Epoch 3/100: 100%|██████████████████████████████████████████████| 1215/1215 [07:39<00:00,  2.64it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:51<00:00,  5.92it/s]


Epoch 3/100, Train Loss: 0.6384, Validation Loss: 0.3047, Validation Accuracy: 88.62%


Epoch 4/100: 100%|██████████████████████████████████████████████| 1215/1215 [07:25<00:00,  2.72it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:51<00:00,  5.96it/s]


Epoch 4/100, Train Loss: 0.4083, Validation Loss: 0.2239, Validation Accuracy: 92.90%


Epoch 5/100: 100%|██████████████████████████████████████████████| 1215/1215 [07:13<00:00,  2.80it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:51<00:00,  5.88it/s]


Epoch 5/100, Train Loss: 0.6297, Validation Loss: 0.4736, Validation Accuracy: 83.35%


Epoch 6/100: 100%|██████████████████████████████████████████████| 1215/1215 [07:26<00:00,  2.72it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:51<00:00,  5.89it/s]


Epoch 6/100, Train Loss: 0.6775, Validation Loss: 0.3386, Validation Accuracy: 89.44%


Epoch 7/100: 100%|██████████████████████████████████████████████| 1215/1215 [07:22<00:00,  2.74it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:47<00:00,  6.38it/s]


Epoch 7/100, Train Loss: 0.5800, Validation Loss: 0.2450, Validation Accuracy: 93.93%


Epoch 8/100: 100%|██████████████████████████████████████████████| 1215/1215 [07:52<00:00,  2.57it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:50<00:00,  6.02it/s]


Epoch 8/100, Train Loss: 0.6826, Validation Loss: 0.3156, Validation Accuracy: 89.63%


Epoch 9/100: 100%|██████████████████████████████████████████████| 1215/1215 [07:35<00:00,  2.67it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:48<00:00,  6.22it/s]


Epoch 9/100, Train Loss: 0.4940, Validation Loss: 0.1502, Validation Accuracy: 96.36%


Epoch 10/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:28<00:00,  2.71it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:51<00:00,  5.91it/s]


Epoch 10/100, Train Loss: 0.4484, Validation Loss: 0.1324, Validation Accuracy: 96.56%


Epoch 11/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:34<00:00,  2.67it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:50<00:00,  6.06it/s]


Epoch 11/100, Train Loss: 0.4076, Validation Loss: 0.1122, Validation Accuracy: 96.93%


Epoch 12/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:33<00:00,  2.68it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:50<00:00,  6.08it/s]


Epoch 12/100, Train Loss: 0.4027, Validation Loss: 0.1111, Validation Accuracy: 97.04%


Epoch 13/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:29<00:00,  2.70it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:53<00:00,  5.68it/s]


Epoch 13/100, Train Loss: 0.3972, Validation Loss: 0.1039, Validation Accuracy: 97.55%


Epoch 14/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:25<00:00,  2.73it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:48<00:00,  6.32it/s]


Epoch 14/100, Train Loss: 0.3957, Validation Loss: 0.1108, Validation Accuracy: 97.08%


Epoch 15/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:29<00:00,  2.70it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:50<00:00,  6.00it/s]


Epoch 15/100, Train Loss: 0.3827, Validation Loss: 0.1070, Validation Accuracy: 97.16%


Epoch 16/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:33<00:00,  2.68it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:52<00:00,  5.75it/s]


Epoch 16/100, Train Loss: 0.3827, Validation Loss: 0.1028, Validation Accuracy: 97.26%


Epoch 17/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:33<00:00,  2.68it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:48<00:00,  6.23it/s]


Epoch 17/100, Train Loss: 0.3968, Validation Loss: 0.1030, Validation Accuracy: 97.26%


Epoch 18/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:22<00:00,  2.75it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:51<00:00,  5.94it/s]


Epoch 18/100, Train Loss: 0.3790, Validation Loss: 0.0952, Validation Accuracy: 97.55%


Epoch 19/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:34<00:00,  2.67it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:51<00:00,  5.88it/s]


Epoch 19/100, Train Loss: 0.3940, Validation Loss: 0.1149, Validation Accuracy: 96.54%


Epoch 20/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:30<00:00,  2.70it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:51<00:00,  5.87it/s]


Epoch 20/100, Train Loss: 0.3853, Validation Loss: 0.1052, Validation Accuracy: 97.22%


Epoch 21/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:39<00:00,  2.65it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:47<00:00,  6.35it/s]


Epoch 21/100, Train Loss: 0.3670, Validation Loss: 0.0941, Validation Accuracy: 97.41%


Epoch 22/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:31<00:00,  2.69it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:49<00:00,  6.14it/s]


Epoch 22/100, Train Loss: 0.3866, Validation Loss: 0.1307, Validation Accuracy: 95.66%


Epoch 23/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:21<00:00,  2.75it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:55<00:00,  5.53it/s]


Epoch 23/100, Train Loss: 0.3623, Validation Loss: 0.0923, Validation Accuracy: 97.43%


Epoch 24/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:38<00:00,  2.65it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:53<00:00,  5.65it/s]


Epoch 24/100, Train Loss: 0.3712, Validation Loss: 0.0984, Validation Accuracy: 97.35%


Epoch 25/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:29<00:00,  2.70it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:52<00:00,  5.76it/s]


Epoch 25/100, Train Loss: 0.3409, Validation Loss: 0.0992, Validation Accuracy: 97.06%


Epoch 26/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:30<00:00,  2.70it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:52<00:00,  5.83it/s]


Epoch 26/100, Train Loss: 0.3278, Validation Loss: 0.1082, Validation Accuracy: 96.69%


Epoch 27/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:43<00:00,  2.62it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:51<00:00,  5.85it/s]


Epoch 27/100, Train Loss: 0.3048, Validation Loss: 0.0764, Validation Accuracy: 97.65%


Epoch 28/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:24<00:00,  2.73it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:47<00:00,  6.39it/s]


Epoch 28/100, Train Loss: 0.3938, Validation Loss: 0.2038, Validation Accuracy: 93.02%


Epoch 29/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:33<00:00,  2.68it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:58<00:00,  5.20it/s]


Epoch 29/100, Train Loss: 0.7841, Validation Loss: 0.4648, Validation Accuracy: 82.14%


Epoch 30/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:27<00:00,  2.72it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:56<00:00,  5.39it/s]


Epoch 30/100, Train Loss: 0.7312, Validation Loss: 0.4731, Validation Accuracy: 81.77%


Epoch 31/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:26<00:00,  2.72it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:49<00:00,  6.08it/s]


Epoch 31/100, Train Loss: 0.8245, Validation Loss: 0.5369, Validation Accuracy: 78.07%


Epoch 32/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:24<00:00,  2.73it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:52<00:00,  5.81it/s]


Epoch 32/100, Train Loss: 0.6346, Validation Loss: 0.2710, Validation Accuracy: 92.90%


Epoch 33/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:42<00:00,  2.63it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:51<00:00,  5.85it/s]


Epoch 33/100, Train Loss: 0.7100, Validation Loss: 0.3661, Validation Accuracy: 89.05%


Epoch 34/100: 100%|█████████████████████████████████████████████| 1215/1215 [07:25<00:00,  2.73it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:50<00:00,  6.07it/s]


Epoch 34/100, Train Loss: 0.5972, Validation Loss: 0.2664, Validation Accuracy: 91.75%


Epoch 35/100: 100%|█████████████████████████████████████████████| 1215/1215 [03:34<00:00,  5.67it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.81it/s]


Epoch 35/100, Train Loss: 0.7303, Validation Loss: 0.4639, Validation Accuracy: 84.01%


Epoch 36/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.98it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.63it/s]


Epoch 36/100, Train Loss: 0.7998, Validation Loss: 0.3527, Validation Accuracy: 92.02%


Epoch 37/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  6.99it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.76it/s]


Epoch 37/100, Train Loss: 0.6584, Validation Loss: 0.2609, Validation Accuracy: 94.47%


Epoch 38/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.98it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.94it/s]


Epoch 38/100, Train Loss: 0.7204, Validation Loss: 0.4202, Validation Accuracy: 88.05%


Epoch 39/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.96it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.13it/s]


Epoch 39/100, Train Loss: 0.6237, Validation Loss: 0.3505, Validation Accuracy: 90.31%


Epoch 40/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:55<00:00,  6.94it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.14it/s]


Epoch 40/100, Train Loss: 0.5376, Validation Loss: 0.2622, Validation Accuracy: 93.00%


Epoch 41/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.95it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.13it/s]


Epoch 41/100, Train Loss: 0.6650, Validation Loss: 0.4082, Validation Accuracy: 85.08%


Epoch 42/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.96it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.20it/s]


Epoch 42/100, Train Loss: 0.6064, Validation Loss: 0.4327, Validation Accuracy: 83.58%


Epoch 43/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.97it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.09it/s]


Epoch 43/100, Train Loss: 0.6503, Validation Loss: 0.3921, Validation Accuracy: 85.33%


Epoch 44/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.97it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.05it/s]


Epoch 44/100, Train Loss: 0.5688, Validation Loss: 0.1743, Validation Accuracy: 95.74%


Epoch 45/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  7.00it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.04it/s]


Epoch 45/100, Train Loss: 0.5754, Validation Loss: 6.9181, Validation Accuracy: 38.89%


Epoch 46/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.98it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.89it/s]


Epoch 46/100, Train Loss: 0.7494, Validation Loss: 0.3658, Validation Accuracy: 88.33%


Epoch 47/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  6.99it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.90it/s]


Epoch 47/100, Train Loss: 0.5989, Validation Loss: 0.1682, Validation Accuracy: 97.02%


Epoch 48/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  7.01it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.84it/s]


Epoch 48/100, Train Loss: 0.6117, Validation Loss: 0.1870, Validation Accuracy: 96.05%


Epoch 49/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  7.00it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.88it/s]


Epoch 49/100, Train Loss: 0.6730, Validation Loss: 0.2585, Validation Accuracy: 91.15%


Epoch 50/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  7.00it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.96it/s]


Epoch 50/100, Train Loss: 0.7233, Validation Loss: 0.6901, Validation Accuracy: 76.15%


Epoch 51/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  6.99it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.16it/s]


Epoch 51/100, Train Loss: 0.7645, Validation Loss: 0.4320, Validation Accuracy: 87.26%


Epoch 52/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.98it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.16it/s]


Epoch 52/100, Train Loss: 0.6762, Validation Loss: 0.2805, Validation Accuracy: 91.58%


Epoch 53/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.98it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.20it/s]


Epoch 53/100, Train Loss: 0.5928, Validation Loss: 0.3645, Validation Accuracy: 90.08%


Epoch 54/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.98it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.20it/s]


Epoch 54/100, Train Loss: 0.4966, Validation Loss: 0.1755, Validation Accuracy: 95.12%


Epoch 55/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:54<00:00,  6.98it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.20it/s]


Epoch 55/100, Train Loss: 0.5443, Validation Loss: 0.2462, Validation Accuracy: 91.11%


Epoch 56/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  6.99it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.99it/s]


Epoch 56/100, Train Loss: 0.6872, Validation Loss: 0.3324, Validation Accuracy: 90.14%


Epoch 57/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  7.00it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.85it/s]


Epoch 57/100, Train Loss: 0.5568, Validation Loss: 0.1602, Validation Accuracy: 97.22%


Epoch 58/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  7.00it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.78it/s]


Epoch 58/100, Train Loss: 0.4953, Validation Loss: 0.1290, Validation Accuracy: 96.26%


Epoch 59/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  6.99it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 14.78it/s]


Epoch 59/100, Train Loss: 0.4310, Validation Loss: 0.1067, Validation Accuracy: 97.24%


Epoch 60/100: 100%|█████████████████████████████████████████████| 1215/1215 [02:53<00:00,  7.00it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [00:20<00:00, 15.10it/s]


Epoch 60/100, Train Loss: 0.4077, Validation Loss: 0.0931, Validation Accuracy: 97.70%


Epoch 61/100:  27%|████████████▎                                 | 325/1215 [00:46<02:02,  7.26it/s]

In [None]:
# Restore the ResNet-50 weights saved during training.
# map_location remaps the stored tensors onto the active `device`, so a
# checkpoint that contains CUDA tensors still loads on a CPU-only runtime.
checkpoint = torch.load(
    "drive/MyDrive/UM_Project/checkpoints/resnet50-smallnorb.pt",
    map_location=device,
)
resnet50_model.load_state_dict(checkpoint)

<All keys matched successfully>

In [None]:
# Score the ResNet-50 on the test loader: accumulate the summed batch loss
# and count how many arg-max predictions match the labels.
resnet50_model.eval()
test_loss = 0.0
correct_predictions = 0
total_predictions = 0
with torch.no_grad():  # inference only, no autograd bookkeeping
    for inputs, labels in tqdm(test_dl, desc="Testing", ncols=100):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet50_model(inputs)
        test_loss += criterion(outputs, labels).item()
        # Index of the highest score along the class dimension
        predicted = torch.max(outputs, 1).indices
        total_predictions += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

print(f"\nTest set accuracy = {(100 * correct_predictions / total_predictions):.4f}%")

Testing: 100%|██████████████████████████████████████████████████| 1519/1519 [01:50<00:00, 13.81it/s]


Test set accuracy = 89.1975%





## Swin Transformer

In [None]:
# Mount Google Drive; the checkpoint paths used below live under this mount.
from google.colab import drive

_DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(_DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Preprocessing pipelines for the Swin experiments.
# Training adds RandAugment and RandomErasing; validation and test share the
# same augmentation-free resize -> tensor -> normalize sequence.
_SWIN_INPUT_SIZE = (224, 224)
_NORM_MEAN = [0.485]
_NORM_STD = [0.229]


def _deterministic_pipeline():
    """Build the augmentation-free pipeline used for validation and test."""
    return transforms.Compose([
        transforms.Resize(size=_SWIN_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])


transforms_train = transforms.Compose([
    transforms.Resize(size=_SWIN_INPUT_SIZE),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
    transforms.RandomErasing(p=0.25),
])

transforms_val = _deterministic_pipeline()

transforms_test = _deterministic_pipeline()

In [None]:
# Two views of the same "train" split that differ only in their transform:
# the training view is augmented, the validation view is not.
train_full_ds = SmallNORBDataset(smallnorb, "train", transform=transforms_train)
val_full_ds = SmallNORBDataset(smallnorb, "train", transform=transforms_val)

train_size = int(0.8 * len(train_full_ds))
val_size = len(train_full_ds) - train_size

# Draw ONE seeded permutation and slice it for both views. Calling
# random_split independently on each view draws two different permutations,
# which lets validation samples also appear in the training subset.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(len(train_full_ds), generator=split_generator).tolist()
train_indices = shuffled_indices[:train_size]
val_indices = shuffled_indices[train_size:]

train_ds = torch.utils.data.Subset(train_full_ds, train_indices)
val_ds = torch.utils.data.Subset(val_full_ds, val_indices)

# Guard against train/validation leakage.
assert set(train_indices).isdisjoint(val_indices)
assert len(train_ds) == train_size and len(val_ds) == val_size

test_ds = SmallNORBDataset(smallnorb, "test", transform=transforms_test)

In [None]:
# Input channel count and number of output classes used when the stem and
# head of the network are rebuilt.
IN_CHANNELS, N_CLASSES = 1, 5

# Swin-B from torchvision, initialised with the IMAGENET1K_V1 weights.
swin_model = tv.models.swin_b(weights="IMAGENET1K_V1")

Downloading: "https://download.pytorch.org/models/swin_b-68c6b09e.pth" to /root/.cache/torch/hub/checkpoints/swin_b-68c6b09e.pth
100%|██████████| 335M/335M [00:23<00:00, 15.0MB/s]


In [None]:
# Turn off gradient tracking for every parameter currently in the model.
for frozen_param in swin_model.parameters():
    frozen_param.requires_grad = False

In [None]:
# 1-channel inputs
swin_model.features[0][0] = nn.Conv2d(IN_CHANNELS, 128, kernel_size=(4,4), stride=(4,4))

# Add on classifier
n_inputs = swin_model.head.in_features
swin_model.head = nn.Sequential(
    nn.Linear(n_inputs, 256), nn.GELU(), nn.Dropout(0.6),
    nn.Linear(256, N_CLASSES), nn.LogSoftmax(dim=1))

In [None]:
# Place the model's parameters and buffers on the selected device.
swin_model = swin_model.to(device=device)

In [None]:
# Report the overall parameter count and the subset that still requires
# gradients, gathered in a single pass over the model.
param_info = [(p.numel(), p.requires_grad) for p in swin_model.parameters()]

total_params = sum(count for count, needs_grad in param_info)
print(f"{total_params:,} total parameters.")

total_trainable_params = sum(count for count, needs_grad in param_info if needs_grad)
print(f"{total_trainable_params:,} trainable parameters.")

87,002,813 total parameters.
265,861 trainable parameters.


In [None]:
# Hyperparameters for the Swin run: loader batch size, epoch budget, and the
# SGD settings (learning rate, momentum, weight decay).
BATCH_SIZE, EPOCHS = 16, 80
LEARNING_RATE = 0.005
MOMENTUM = 0.9
WEIGHT_DECAY = 0.05

In [None]:
# Batch the three datasets; only the training loader shuffles.
loader_cls = torch.utils.data.DataLoader

train_dl = loader_cls(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = loader_cls(val_ds, batch_size=BATCH_SIZE)
test_dl = loader_cls(test_ds, batch_size=BATCH_SIZE)

In [None]:
# Negative log-likelihood loss; it expects log-probabilities as input.
criterion = nn.NLLLoss()

# SGD with momentum and weight decay over the model's parameters.
optimizer = torch.optim.SGD(
    swin_model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

# Cosine learning-rate schedule.
# NOTE(review): T_max=35 is shorter than EPOCHS; past T_max a cosine schedule
# starts rising again. Confirm this is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=35)

In [None]:
# Train the Swin model for EPOCHS epochs. After each epoch the model is
# validated, the learning-rate schedule is advanced, and the weights are
# checkpointed whenever the summed validation loss reaches a new minimum.
best_val_loss = float("inf")
train_losses = list()
val_losses = list()

for epoch in range(EPOCHS):
    # Train
    swin_model.train()
    train_loss = 0.0
    for batch in tqdm(train_dl, desc=f"Epoch {epoch + 1}/{EPOCHS}", ncols=100):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = swin_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validate
    swin_model.eval()
    val_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for batch in tqdm(val_dl, desc="Validation", ncols=100):
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = swin_model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_predictions += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

    # Advance the schedule by exactly one epoch. Do not pass val_loss here:
    # the positional argument of CosineAnnealingLR.step() is a deprecated
    # epoch index, not a metric, so passing the loss would make the learning
    # rate follow the validation loss instead of the epoch count.
    lr_scheduler.step()

    # Keep only the weights with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(swin_model.state_dict(), "/content/drive/MyDrive/UM_Project/checkpoints/swin-smallnorb.pt")

    # Per-batch averages for the loss history and the epoch summary line.
    avg_train_loss = train_loss / len(train_dl)
    avg_val_loss = val_loss / len(val_dl)
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    accuracy = 100 * correct_predictions / total_predictions
    print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")


Epoch 1/80: 100%|███████████████████████████████████████████████| 1215/1215 [09:11<00:00,  2.20it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [01:04<00:00,  4.75it/s]


Epoch 1/80, Train Loss: 0.9026, Validation Loss: 0.1791, Validation Accuracy: 95.19%


Epoch 2/80: 100%|███████████████████████████████████████████████| 1215/1215 [09:02<00:00,  2.24it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [01:04<00:00,  4.74it/s]


Epoch 2/80, Train Loss: 0.4143, Validation Loss: 0.1090, Validation Accuracy: 97.16%


Epoch 3/80: 100%|███████████████████████████████████████████████| 1215/1215 [09:03<00:00,  2.24it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [01:04<00:00,  4.73it/s]


Epoch 3/80, Train Loss: 0.3065, Validation Loss: 0.0911, Validation Accuracy: 98.00%


Epoch 4/80: 100%|███████████████████████████████████████████████| 1215/1215 [09:03<00:00,  2.24it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [01:04<00:00,  4.72it/s]


Epoch 4/80, Train Loss: 0.2968, Validation Loss: 0.0876, Validation Accuracy: 97.74%


Epoch 5/80: 100%|███████████████████████████████████████████████| 1215/1215 [09:02<00:00,  2.24it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [01:04<00:00,  4.73it/s]


Epoch 5/80, Train Loss: 0.2957, Validation Loss: 0.0808, Validation Accuracy: 97.98%


Epoch 6/80: 100%|███████████████████████████████████████████████| 1215/1215 [09:03<00:00,  2.23it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [01:03<00:00,  4.75it/s]


Epoch 6/80, Train Loss: 0.3241, Validation Loss: 0.0867, Validation Accuracy: 97.98%


Epoch 7/80: 100%|███████████████████████████████████████████████| 1215/1215 [09:03<00:00,  2.24it/s]
Validation: 100%|█████████████████████████████████████████████████| 304/304 [01:04<00:00,  4.74it/s]


Epoch 7/80, Train Loss: 0.3007, Validation Loss: 0.0873, Validation Accuracy: 98.40%


Epoch 8/80:  23%|██████████▉                                     | 278/1215 [02:04<07:05,  2.20it/s]

In [None]:
# Restore the Swin weights written by the training loop's checkpoint step.
# map_location remaps the stored tensors onto the active `device`, so a
# checkpoint that contains CUDA tensors still loads on a CPU-only runtime.
checkpoint = torch.load(
    "/content/drive/MyDrive/UM_Project/checkpoints/swin-smallnorb.pt",
    map_location=device,
)
swin_model.load_state_dict(checkpoint)

<All keys matched successfully>

In [None]:
# Predict on Test set
swin_model.eval()
test_loss = 0.0
correct_predictions = 0
total_predictions = 0
with torch.no_grad():
    for batch in tqdm(test_dl, desc="Testing", ncols=100):
      inputs, labels = batch
      inputs = inputs.to(device)
      labels = labels.to(device)
      outputs = swin_model(inputs)
      loss = criterion(outputs, labels)
      test_loss += loss.item()
      _, predicted = torch.max(outputs, 1)
      total_predictions += labels.size(0)
      correct_predictions += (predicted == labels).sum().item()

print(f"\nTest set accuracy = {(100 * correct_predictions / total_predictions):.4f}%")

Testing: 100%|██████████████████████████████████████████████████| 1519/1519 [10:57<00:00,  2.31it/s]


Test set accuracy = 91.7407%



