In [1]:
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
import torchvision.transforms.v2 as v2
from torchvision.models import resnet50, ResNet50_Weights
import matplotlib.pyplot as plt
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


Load data as well as split into loaders

In [4]:
ds = load_dataset("tanganke/stanford_cars")        # {train, test}

split = ds["train"].train_test_split(test_size=0.10,
                                     stratify_by_column="label",
                                     seed=42)
ds["train"], ds["val"] = split["train"], split["test"]

# base preprocessing
base = v2.Compose([
    v2.Resize((224, 224), antialias=True),
    v2.ToImage(),
    v2.ConvertImageDtype(torch.float32),
    v2.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
])

# corruption pool (always applied)
corruption_pool = v2.RandomChoice([
    v2.GaussianBlur(kernel_size=(15,15), sigma=(8,8)),
    v2.RandomAdjustSharpness(sharpness_factor=2.5),
    v2.RandomAutocontrast(),
    v2.RandomEqualize(),
    v2.RandomPerspective(distortion_scale=0.6, p=1.0),
    v2.RandomInvert(p=1.0),
    v2.RandomSolarize(threshold=0.3, p=1.0),
])

train_transform      = v2.Compose([base, corruption_pool])   # always corrupt
val_test_transform   = base                                  # keep clean

def apply_tf(example, tf):
    example["pixel_values"] = tf(example["image"].convert("RGB"))
    return example

#MATERIALISE 


ds["train_clean"]   = ds["train"].map(apply_tf, fn_kwargs={"tf": val_test_transform},
                                      desc="clean train → tensors")
ds["train_corrupt"] = ds["train"].map(apply_tf, fn_kwargs={"tf": train_transform},
                                      desc="corrupt train → tensors")
ds["val"]  = ds["val"].map(apply_tf,  fn_kwargs={"tf": val_test_transform},
                           desc="val → tensors")
ds["test"] = ds["test"].map(apply_tf, fn_kwargs={"tf": val_test_transform},
                           desc="test → tensors")

# keep only needed columns
for split in ("train_clean","train_corrupt","val","test"):
    ds[split].set_format(type="torch", columns=["pixel_values","label"])

#CONCAT clean+corrupt for training
train_all = ConcatDataset([ds["train_clean"], ds["train_corrupt"]])

#  DATALOADERS
BATCH = 32
train_loader = DataLoader(train_all, batch_size=BATCH, shuffle=True,
                          num_workers=4, pin_memory=True)
val_loader   = DataLoader(ds["val"],   batch_size=BATCH, shuffle=True,
                          num_workers=4, pin_memory=True)
test_loader  = DataLoader(ds["test"],  batch_size=BATCH, shuffle=True,
                          num_workers=4, pin_memory=True)

clean train → tensors: 100%|██████████| 7329/7329 [08:58<00:00, 13.62 examples/s]  
corrupt train → tensors: 100%|██████████| 7329/7329 [08:54<00:00, 13.71 examples/s]  
val → tensors: 100%|██████████| 815/815 [01:12<00:00, 11.18 examples/s] 
test → tensors: 100%|██████████| 8041/8041 [11:21<00:00, 11.81 examples/s]  


Set up model

In [5]:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)   # pretrained=True
model.fc = nn.Linear(model.fc.in_features, 196)            # 196 classes
model.to(device)

criterion  = nn.CrossEntropyLoss()
optimizer  = optim.Adam(model.parameters(), lr=1e-4)

Training and Validation

In [6]:
EPOCHS = 10
tr_loss_hist, tr_acc_hist = [], []
va_loss_hist, va_acc_hist = [], []

for ep in range(1, EPOCHS + 1):
    # training 
    model.train()
    run_loss = correct = total = 0
    for batch in train_loader:
        x = batch["pixel_values"].to(device, non_blocking=True)
        y = batch["label"].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        out   = model(x)
        loss  = criterion(out, y)
        loss.backward()
        optimizer.step()

        run_loss += loss.item()
        total    += y.size(0)
        correct  += out.argmax(1).eq(y).sum().item()

    tr_loss = run_loss / len(train_loader)
    tr_acc  = 100 * correct / total
    tr_loss_hist.append(tr_loss)
    tr_acc_hist.append(tr_acc)

    #Validation
    model.eval()
    val_loss = val_correct = val_total = 0
    with torch.no_grad():
        for batch in val_loader:
            x = batch["pixel_values"].to(device, non_blocking=True)
            y = batch["label"].to(device, non_blocking=True)
            out = model(x)
            val_loss   += criterion(out, y).item()
            val_total  += y.size(0)
            val_correct += out.argmax(1).eq(y).sum().item()

    va_loss = val_loss / len(val_loader)
    va_acc  = 100 * val_correct / val_total
    va_loss_hist.append(va_loss)
    va_acc_hist.append(va_acc)

    print(f"Epoch {ep:2d}/{EPOCHS} ─ "
          f"Train L {tr_loss:.4f}  Acc {tr_acc:5.2f}%   "
          f"Val L {va_loss:.4f}  Acc {va_acc:5.2f}%")

Epoch  1/10 ─ Train L 3.6161  Acc 25.59%   Val L 1.9263  Acc 49.57%
Epoch  2/10 ─ Train L 1.3296  Acc 70.94%   Val L 1.0712  Acc 71.90%
Epoch  3/10 ─ Train L 0.4818  Acc 90.35%   Val L 0.8995  Acc 74.97%
Epoch  4/10 ─ Train L 0.1894  Acc 96.83%   Val L 0.8608  Acc 77.30%
Epoch  5/10 ─ Train L 0.1071  Acc 98.40%   Val L 0.8288  Acc 76.32%
Epoch  6/10 ─ Train L 0.0782  Acc 98.81%   Val L 0.8159  Acc 77.55%
Epoch  7/10 ─ Train L 0.0933  Acc 98.29%   Val L 1.1577  Acc 72.02%
Epoch  8/10 ─ Train L 0.1285  Acc 97.28%   Val L 0.9978  Acc 74.72%
Epoch  9/10 ─ Train L 0.0810  Acc 98.24%   Val L 0.9749  Acc 76.20%
Epoch 10/10 ─ Train L 0.0771  Acc 98.38%   Val L 1.0069  Acc 75.34%


Plotting

In [None]:
plt.figure()
plt.plot(range(1, EPOCHS + 1), tr_acc_hist, label="Train Acc")
plt.plot(range(1, EPOCHS + 1), va_acc_hist, label="Val Acc")
plt.xlabel("Epoch"); plt.ylabel("Accuracy (%)"); plt.title("Accuracy Curves")
plt.legend(); plt.grid(True)
plt.gca().invert_yaxis()
plt.show()

NameError: name 'plt' is not defined

Testing

In [8]:
model.eval()
test_correct = test_total = 0
with torch.no_grad():
    for batch in test_loader:
        x = batch["pixel_values"].to(device, non_blocking=True)
        y = batch["label"].to(device, non_blocking=True)
        out = model(x)
        test_total   += y.size(0)
        test_correct += out.argmax(1).eq(y).sum().item()

print(f"\nTest Accuracy: {100 * test_correct / test_total:.2f}%")


Test Accuracy: 74.44%


In [9]:
from pathlib import Path

run_name  = "resnet_aug3"                 # pick any tag
save_dir  = Path("runs") / run_name
save_dir.mkdir(parents=True, exist_ok=True)

torch.save({
    "epoch": EPOCHS,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),   # optional
    "val_acc": va_acc_hist[-1],                  # optional
}, save_dir / "checkpoint.pt")
print("Saved to", save_dir / "checkpoint.pt")

Saved to runs\resnet_aug3\checkpoint.pt
