# Food-101 Training: Scratch, Transfer (Linear Probe), Fine-tune

This notebook loads the config, builds dataloaders with augmentation and imbalance handling, and trains:
- Scratch: SimpleCNN / ResNet-18 from scratch
- Transfer: ResNet-50 or ResNet-18 frozen backbone (linear probe)
- Fine-tune: Unfreeze last stage(s)

It also computes accuracy, macro-F1, and shows confusion matrices and augmentation impact.


In [None]:
import os, json, math, random
from pathlib import Path
from typing import List, Tuple, Optional, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as T
import torchvision.models as tvm
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns

# Project root: set the FOOD101_PROJECT_ROOT environment variable to point at your
# checkout instead of editing this path (the literal below is only the fallback).
PROJECT_ROOT = Path(os.environ.get(
    "FOOD101_PROJECT_ROOT",
    r"C:/Users/Name/OneDrive - The University of Texas at Austin/UT Austin-DEVICE/Deep L/Mini Project/Image_Classification_DL",
))
CONFIG_PATH = PROJECT_ROOT / "config.json"

with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    cfg = json.load(f)

FOOD101_ROOT = Path(cfg["dataset_root"])  # expected to contain images/ and meta/
CLASS_NAMES: List[str] = cfg["class_names"]
IMAGE_SIZE: int = int(cfg.get("image_size", 128))
SPLIT_PROTOCOL: str = cfg.get("split_protocol", "standard")
SEED: int = int(cfg.get("seed", 42))

# Seed every RNG the notebook relies on (Python `random`, NumPy, torch on all devices)
# so model initialisation and any unseeded shuffling are reproducible run to run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("Loaded config:", cfg)


Loaded config: {'dataset_root': 'C:\\Users\\Name\\OneDrive - The University of Texas at Austin\\UT Austin-DEVICE\\Deep L\\Mini Project\\Image_Classification_DL\\data\\food-101', 'class_names': ['apple_pie', 'beef_carpaccio', 'beef_tartare', 'caesar_salad', 'caprese_salad', 'carrot_cake', 'cheesecake', 'club_sandwich', 'creme_brulee', 'croque_madame', 'cup_cakes', 'donuts', 'escargots', 'hamburger', 'hot_and_sour_soup', 'hummus', 'miso_soup', 'oysters', 'paella', 'pho', 'pork_chop', 'ramen', 'samosa', 'sashimi', 'shrimp_and_grits', 'spaghetti_bolognese', 'strawberry_shortcake', 'tacos', 'takoyaki', 'tiramisu'], 'image_size': 128, 'split_protocol': 'standard', 'seed': 42}


Image transforms: resize to a square, optional training-time augmentation, and ImageNet normalization.

In [5]:
# Reusable transforms with augmentation toggle

def build_transforms(image_size: int, train: bool, use_aug: bool):
    """Return the torchvision preprocessing pipeline for one split.

    Every pipeline ends with a square Resize, ToTensor and ImageNet mean/std
    normalisation. Only when ``train`` and ``use_aug`` are both true are the
    random augmentations (resized crop, horizontal flip, colour jitter, small
    rotation) prepended; ``use_aug`` is ignored for evaluation pipelines.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = []
    if train and use_aug:
        steps += [
            T.RandomResizedCrop(image_size, scale=(0.7, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.ColorJitter(brightness=0.2, contrast=0.2),
            T.RandomRotation(degrees=10),
        ]
    steps.append(T.Resize((image_size, image_size)))
    steps.extend([T.ToTensor(), T.Normalize(imagenet_mean, imagenet_std)])
    return T.Compose(steps)


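Dataset utilities: either the official Food-101 split (750 train / 250 test images per class) or a stratified 80/20 split (800 train / 200 val images per class, given Food-101's 1,000 images per class)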

In [6]:
# Dataset utilities (standard or stratified split)
class FoodDataset(torch.utils.data.Dataset):
    """Map-style dataset over ``(image_path, label)`` pairs.

    Args:
        items: list of ``(path, class_index)`` tuples.
        transform: optional callable applied to the decoded RGB PIL image.
        class_names: optional list of class names, stored for reference only.
    """

    def __init__(self, items: List[Tuple[str, int]], transform=None, class_names: Optional[List[str]] = None):
        self.items = items
        self.transform = transform
        self.class_names = class_names

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx):
        path, label = self.items[idx]
        # Close the underlying file deterministically: `.convert()` returns a new,
        # fully loaded image, so the opened handle can be released right away instead
        # of lingering until garbage collection (matters with many DataLoader workers).
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label


def _read_standard_split_lists(root: Path) -> Dict[str, List[str]]:
    meta = root / "meta"
    with open(meta / "train.txt") as f:
        train_stems = [l.strip() for l in f if l.strip()]
    with open(meta / "test.txt") as f:
        test_stems = [l.strip() for l in f if l.strip()]
    return {"train": train_stems, "test": test_stems}


def build_items_standard(root: Path, class_names: List[str], split: str) -> List[Tuple[str, int]]:
    """Map the official split list to ``(image_path, label)`` pairs.

    ``split == "train"`` selects meta/train.txt; any other value selects
    meta/test.txt. Entries whose class is not in ``class_names`` are skipped,
    so a class subset can be trained. Labels are indices into ``class_names``.
    """
    split_key = "train" if split == "train" else "test"
    stems = _read_standard_split_lists(root)[split_key]
    label_of = {name: idx for idx, name in enumerate(class_names)}
    images_dir = root / "images"
    return [
        (str(images_dir / cls / f"{stem}.jpg"), label_of[cls])
        for cls, stem in (entry.split("/") for entry in stems)
        if cls in label_of
    ]


def build_items_stratified(root: Path, class_names: List[str], train_ratio: float = 0.8, seed: int = 42):
    """Per-class (stratified) random train/val split over ``root/images/<class>/*.jpg``.

    For each class in ``class_names`` (in order), the class's .jpg files are
    shuffled with a single ``random.Random(seed)``. The first
    ``int(len * train_ratio)`` files go to train and the rest go to val.

    Args:
        root: dataset root containing an ``images/`` directory with one subfolder per class.
        class_names: classes to include; labels are their indices in this list.
        train_ratio: fraction of each class assigned to train, in [0, 1].
        seed: seed for the shuffle, making the split reproducible.

    Returns:
        ``(train_items, val_items)``: lists of ``(path, label)`` tuples.

    Raises:
        ValueError: if ``train_ratio`` is outside [0, 1].
        FileNotFoundError: if a class directory is missing.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    rng = random.Random(seed)
    class_to_idx = {c: i for i, c in enumerate(class_names)}
    images_dir = root / "images"

    train_items, val_items = [], []
    for cls in class_names:
        cls_dir = images_dir / cls
        # os.listdir order is filesystem-dependent; sort first so the seeded
        # shuffle yields the same split on every machine / filesystem.
        paths = sorted(str(cls_dir / fn) for fn in os.listdir(cls_dir) if fn.lower().endswith(".jpg"))
        rng.shuffle(paths)
        k = int(len(paths) * train_ratio)
        label = class_to_idx[cls]
        train_items.extend((p, label) for p in paths[:k])
        val_items.extend((p, label) for p in paths[k:])
    return train_items, val_items


def build_loaders(root: Path, class_names: List[str], image_size: int, batch_size: int, num_workers: int,
                  split_protocol: str, use_aug: bool, use_imbalance_sampler: bool,
                  seed: Optional[int] = None):
    """Build the train/val DataLoaders for one experimental setting.

    Args:
        root: Food-101 root (``images/`` and, for the standard split, ``meta/``).
        class_names: classes to use; labels are indices into this list.
        image_size: square side length fed to the model.
        batch_size, num_workers: forwarded to both DataLoaders.
            NOTE(review): on Windows, workers are spawned processes and must be able
            to unpickle the dataset; classes defined inside a notebook may fail there.
            Use num_workers=0 if workers crash.
        split_protocol: ``"standard"`` uses meta/train.txt + meta/test.txt; any other
            value uses an 80/20 stratified split.
        use_aug: enable training-time augmentation (validation is never augmented).
        use_imbalance_sampler: draw training samples with probability inversely
            proportional to their class frequency (with replacement).
        seed: seed for the stratified split and for the training-order generator.
            Defaults to the notebook-level ``SEED``.

    Returns:
        ``(train_loader, val_loader, num_classes)``.
    """
    seed = SEED if seed is None else seed

    if split_protocol == "standard":
        train_items = build_items_standard(root, class_names, split="train")
        val_items   = build_items_standard(root, class_names, split="test")
    else:
        train_items, val_items = build_items_stratified(root, class_names, train_ratio=0.8, seed=seed)

    t_train = build_transforms(image_size, train=True, use_aug=use_aug)
    t_val   = build_transforms(image_size, train=False, use_aug=False)
    ds_train = FoodDataset(train_items, transform=t_train, class_names=class_names)
    ds_val   = FoodDataset(val_items,   transform=t_val,   class_names=class_names)

    # Seeded generator -> the shuffling / weighted-sampling order is reproducible.
    generator = torch.Generator().manual_seed(seed)
    # Pinned host memory only speeds up host->GPU copies; on CPU-only machines it
    # just emits a warning.
    pin = torch.cuda.is_available()

    if use_imbalance_sampler:
        counts = np.zeros(len(class_names), dtype=np.int64)
        for _, y in train_items:
            counts[y] += 1
        # Inverse-frequency weights; max(.., 1) avoids dividing by zero for empty classes.
        class_weights = 1.0 / np.maximum(counts, 1)
        sample_weights = np.array([class_weights[y] for _, y in train_items], dtype=np.float32)
        sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights),
                                        replacement=True, generator=generator)
        train_loader = DataLoader(ds_train, batch_size=batch_size, sampler=sampler,
                                  num_workers=num_workers, pin_memory=pin)
    else:
        train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                  pin_memory=pin, generator=generator)

    val_loader = DataLoader(ds_val, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin)
    return train_loader, val_loader, len(class_names)


Training and validation utilities shared by all three regimes (scratch, transfer, fine-tune)

In [7]:
# Training utilities

def accuracy(logits, targets):
    """Fraction of rows whose argmax over dim 1 equals the target, as a Python float."""
    predicted = torch.argmax(logits, dim=1)
    hits = predicted.eq(targets)
    return hits.float().mean().item()


def train_one_epoch(model, loader, optimizer, device, criterion):
    """Run one optimisation pass over ``loader`` with the model in train mode.

    Returns ``{"loss": ..., "acc": ...}``: per-sample averages over the epoch.
    Each batch's mean loss and accuracy are weighted by that batch's size.
    """
    model.train()
    total_loss = 0
    total_acc = 0
    seen = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        batch_loss = criterion(outputs, labels)
        batch_loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        total_loss += batch_loss.item() * batch_size
        total_acc += accuracy(outputs, labels) * batch_size
        seen += batch_size
    return {"loss": total_loss / seen, "acc": total_acc / seen}


@torch.no_grad()
def validate(model, loader, device, criterion):
    """Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

    Returns:
        dict with ``loss`` and ``acc`` (per-sample averages), ``macro_f1``
        (sklearn macro F1 over the labels present), and ``targets`` / ``preds``
        (lists of int labels in loader order, e.g. for a confusion matrix).

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    n = 0
    all_targets, all_preds = [], []
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = criterion(logits, yb)
        preds = logits.argmax(1)
        all_targets.extend(yb.cpu().tolist())
        all_preds.extend(preds.cpu().tolist())
        b = xb.size(0)
        loss_sum += loss.item() * b
        # Reuse `preds` (no second argmax) and count hits exactly.
        correct += (preds == yb).sum().item()
        n += b
    if n == 0:
        raise ValueError("validate(): loader yielded no samples")
    # zero_division=0 reproduces sklearn's default value for classes that are
    # never predicted (common in early epochs) but without an
    # UndefinedMetricWarning on every epoch.
    macro_f1 = f1_score(all_targets, all_preds, average="macro", zero_division=0)
    return {"loss": loss_sum / n, "acc": correct / n, "macro_f1": macro_f1, "targets": all_targets, "preds": all_preds}


def plot_confusion(labels_true, labels_pred, class_names: List[str], title: str):
    """Plot a confusion-matrix heatmap (rows = true class, columns = predicted class).

    Args:
        labels_true, labels_pred: sequences of integer labels (indices into ``class_names``).
        class_names: names used both to fix the matrix order (0..n-1) and to label the axes.
        title: figure title.
    """
    n_classes = len(class_names)
    cm = confusion_matrix(labels_true, labels_pred, labels=list(range(n_classes)))
    fig, ax = plt.subplots(figsize=(10, 8))
    # Label ticks with class names instead of bare indices so the figure stands alone.
    sns.heatmap(cm, annot=False, cmap="Blues", fmt="d",
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    plt.show()


In [None]:
# Models: SimpleCNN and ResNet18/50 wrappers
class SimpleCNN(nn.Module):
    """Minimal from-scratch baseline: one conv stage, global average pooling, linear head.

    Pipeline: conv3x3 (3 -> 64 channels, no bias) -> BatchNorm -> ReLU ->
    2x2 max-pool -> adaptive average pool to 1x1 -> Linear(64, num_classes).
    """

    def __init__(self, num_classes: int):
        super().__init__()
        width = 64
        conv_stage = [
            nn.Conv2d(3, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        ]
        self.features = nn.Sequential(*conv_stage)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(width, num_classes)

    def forward(self, x):
        pooled = self.pool(self.features(x))
        return self.classifier(pooled.flatten(start_dim=1))


class ResNetWrapper(nn.Module):
    """torchvision ResNet-18/50 with a fresh ``num_classes`` head and optional freezing.

    Args:
        backbone: ``"resnet18"`` or ``"resnet50"``.
        num_classes: output width of the replacement ``fc`` layer.
        pretrained: load torchvision's ``*_Weights.DEFAULT`` (ImageNet) weights.
        freeze_backbone: if True, only ``fc`` is trainable (linear probe).
        unfreeze_from_layer: name of a top-level ResNet child (e.g. ``"layer4"``).
            When given, the backbone is frozen first, then that child, every child
            after it, and ``fc`` are made trainable. This happens regardless of
            ``freeze_backbone``, so "fine-tune the last stage" means exactly that.

    Raises:
        ValueError: unknown ``backbone`` or ``unfreeze_from_layer``.
    """

    def __init__(self, backbone: str, num_classes: int, pretrained: bool, freeze_backbone: bool, unfreeze_from_layer: Optional[str] = None):
        super().__init__()
        if backbone == "resnet18":
            base = tvm.resnet18(weights=tvm.ResNet18_Weights.DEFAULT if pretrained else None)
        elif backbone == "resnet50":
            base = tvm.resnet50(weights=tvm.ResNet50_Weights.DEFAULT if pretrained else None)
        else:
            raise ValueError("backbone must be 'resnet18' or 'resnet50'")
        in_feats = base.fc.in_features
        base.fc = nn.Linear(in_feats, num_classes)
        self.backbone = base

        # Partial unfreezing only makes sense on top of a frozen backbone; without
        # this, freeze_backbone=False + unfreeze_from_layer left every layer trainable.
        if freeze_backbone or unfreeze_from_layer is not None:
            for name, p in self.backbone.named_parameters():
                p.requires_grad = name.startswith("fc")

        if unfreeze_from_layer is not None:
            child_names = [name for name, _ in self.backbone.named_children()]
            if unfreeze_from_layer not in child_names:
                raise ValueError(
                    f"unfreeze_from_layer={unfreeze_from_layer!r} is not a child of {backbone}; "
                    f"expected one of {child_names}"
                )
            passed = False
            # named_children() follows definition order (conv1 ... layer4, avgpool, fc),
            # so everything from the named stage onward is unfrozen.
            for name, module in self.backbone.named_children():
                if name == unfreeze_from_layer:
                    passed = True
                if passed or name == "fc":
                    for p in module.parameters():
                        p.requires_grad = True

    def train(self, mode: bool = True):
        """Set train/eval mode, keeping fully frozen BatchNorm layers in eval mode.

        In train mode BatchNorm normalises with batch statistics and keeps updating
        its running statistics even when its affine parameters are frozen. A
        "frozen" backbone would therefore still change. Keeping frozen BN layers in
        eval mode makes them use the pretrained statistics.
        """
        super().train(mode)
        if mode:
            for m in self.backbone.modules():
                if isinstance(m, nn.BatchNorm2d):
                    bn_params = list(m.parameters())
                    if bn_params and not any(p.requires_grad for p in bn_params):
                        m.eval()
        return self

    def forward(self, x):
        return self.backbone(x)


# custom optimizer for the backbone and head
def build_optimizer(model: nn.Module, lr_head: float = 1e-3, lr_backbone: float = 1e-4, weight_decay: float = 1e-4):
    """Create an AdamW optimizer with separate learning rates for head and backbone.

    Trainable parameters whose name contains ``"fc"`` or ``"classifier"`` form the
    first param group (``lr_head``). All other trainable parameters form the
    second group (``lr_backbone``). Parameters with ``requires_grad=False`` are
    left out. Both groups share ``weight_decay``.
    """
    def _is_head(param_name: str) -> bool:
        return "fc" in param_name or "classifier" in param_name

    trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    head_params = [p for name, p in trainable if _is_head(name)]
    backbone_params = [p for name, p in trainable if not _is_head(name)]

    return optim.AdamW(
        [
            {"params": head_params, "lr": lr_head},
            {"params": backbone_params, "lr": lr_backbone},
        ],
        lr=lr_head,
        weight_decay=weight_decay,
    )


In [9]:
# Train runners for the three regimes

def run_training(model: nn.Module, train_loader, val_loader, epochs: int, lr_head: float, lr_backbone: float, label: str):
    """Train ``model`` for ``epochs`` epochs with cross-entropy and AdamW, validating each epoch.

    Runs on CUDA when available, else CPU. Prints one progress line per epoch.

    Returns:
        ``(model, history)``. ``model`` has been moved to the chosen device.
        ``history`` holds one dict per epoch: ``"epoch"`` (0-based), then
        ``train_*`` keys from ``train_one_epoch`` and ``val_*`` keys from
        ``validate``, including ``val_targets`` / ``val_preds``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = build_optimizer(model, lr_head=lr_head, lr_backbone=lr_backbone)

    history = []
    for epoch_idx in range(epochs):
        train_stats = train_one_epoch(model, train_loader, optimizer, device, criterion)
        val_stats = validate(model, val_loader, device, criterion)

        record = {"epoch": epoch_idx}
        record.update((f"train_{key}", value) for key, value in train_stats.items())
        record.update((f"val_{key}", value) for key, value in val_stats.items())
        history.append(record)

        print(f"[{label}] Epoch {epoch_idx+1}/{epochs} | train_loss={train_stats['loss']:.4f} acc={train_stats['acc']:.3f} | val_loss={val_stats['loss']:.4f} acc={val_stats['acc']:.3f} macroF1={val_stats['macro_f1']:.3f}")
    return model, history


def train_regimes(train_loader, val_loader, num_classes: int, image_size: int, epochs: int = 10):
    """Train the three regimes on the same pair of loaders.

    - scratch: SimpleCNN, random init, lr 1e-3 for every parameter.
    - transfer: pretrained ResNet-50 with only ``fc`` trainable (linear probe).
    - finetune: pretrained ResNet-50 with ``layer4`` and ``fc`` trainable.

    Args:
        train_loader, val_loader: DataLoaders shared by all regimes.
        num_classes: classifier output width.
        image_size: currently unused; kept for interface compatibility.
        epochs: epochs per regime (default 10).

    Returns:
        ``{"scratch" | "transfer" | "finetune": (model, history)}`` as returned by ``run_training``.
    """
    results = {}

    # Scratch: SimpleCNN trained from random initialisation
    scratch_model = SimpleCNN(num_classes)
    results['scratch'] = run_training(scratch_model, train_loader, val_loader, epochs=epochs, lr_head=1e-3, lr_backbone=1e-3, label="scratch")

    # Transfer (Linear Probe): freeze backbone, train only the new head
    transfer_model = ResNetWrapper(backbone="resnet50", num_classes=num_classes, pretrained=True, freeze_backbone=True)
    results['transfer'] = run_training(transfer_model, train_loader, val_loader, epochs=epochs, lr_head=1e-3, lr_backbone=1e-5, label="transfer")

    # Fine-tune: freeze everything first, then unfreeze layer4 (+ fc), so only the
    # last stage and the head are updated, not the entire network.
    finetune_model = ResNetWrapper(backbone="resnet50", num_classes=num_classes, pretrained=True, freeze_backbone=True, unfreeze_from_layer="layer4")
    results['finetune'] = run_training(finetune_model, train_loader, val_loader, epochs=epochs, lr_head=1e-3, lr_backbone=1e-4, label="finetune")

    return results


In [10]:
# Build loaders for two settings: (1) no augmentation, plain shuffling;
# (2) augmentation + inverse-frequency class-balanced sampling.
# Both settings use the same validation split and (non-augmented) validation transform.
BATCH_SIZE = 32
NUM_WORKERS = 2

train_loader_noaug, val_loader_noaug, num_classes = build_loaders(
    root=FOOD101_ROOT,
    class_names=CLASS_NAMES,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    split_protocol=SPLIT_PROTOCOL,
    use_aug=False,
    use_imbalance_sampler=False,
)

train_loader_aug, val_loader_aug, _ = build_loaders(
    root=FOOD101_ROOT,
    class_names=CLASS_NAMES,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    split_protocol=SPLIT_PROTOCOL,
    use_aug=True,
    use_imbalance_sampler=True,
)

print("num_classes:", num_classes)


num_classes: 30


Quick test: train the SimpleCNN baseline for 3 epochs (no augmentation) and plot its confusion matrix

In [None]:
# Quick start: train only SimpleCNN first (few epochs) and plot results
EPOCHS = 3
simple_model = SimpleCNN(num_classes)
simple_model, simple_hist = run_training(
    simple_model,
    train_loader_noaug,  # start without augmentation for speed/determinism
    val_loader_noaug,
    epochs=EPOCHS,
    lr_head=1e-3,
    lr_backbone=1e-3,
    label="simplecnn",
)

last = simple_hist[-1]
print({k: last[k] for k in ["epoch", "train_loss", "train_acc", "val_loss", "val_acc", "val_macro_f1"]})
# run_training prefixes every validate() output with "val_", so the label lists
# are stored under val_targets / val_preds (plain "targets" raises KeyError).
plot_confusion(last["val_targets"], last["val_preds"], CLASS_NAMES, title="Confusion Matrix: SimpleCNN (no aug)")



Train all three regimes under both settings (no-aug, and aug + class-balanced sampling), then plot the results. **Warning: this takes a long time.**

In [None]:
# Run: scratch / transfer / finetune for both settings and compare macro-F1
results_noaug = train_regimes(train_loader_noaug, val_loader_noaug, num_classes, IMAGE_SIZE)
results_aug   = train_regimes(train_loader_aug,   val_loader_aug,   num_classes, IMAGE_SIZE)

# Collect macro-F1 from the last epoch of each regime (each result is (model, history)).
regimes = ["scratch", "transfer", "finetune"]
macroF1_noaug = {k: results_noaug[k][1][-1]["val_macro_f1"] for k in regimes}
macroF1_aug   = {k: results_aug[k][1][-1]["val_macro_f1"]   for k in regimes}

print("Macro-F1 (no aug):", macroF1_noaug)
print("Macro-F1 (aug):", macroF1_aug)

# Plot comparison: paired bars per regime (no aug vs aug)
labels = regimes
x = np.arange(len(labels))
width = 0.35
plt.figure(figsize=(8,4))
plt.bar(x - width/2, [macroF1_noaug[k] for k in labels], width, label='No Aug')
plt.bar(x + width/2, [macroF1_aug[k] for k in labels], width, label='With Aug')
plt.xticks(x, labels)
plt.ylabel('Macro-F1')
plt.title('Augmentation Impact by Regime')
plt.legend()
plt.tight_layout()
plt.show()

# Confusion matrix for the best (setting, regime) by final-epoch validation macro-F1,
# selected from the results rather than assumed.
all_runs = {("no aug", k): results_noaug[k] for k in regimes}
all_runs.update({("aug", k): results_aug[k] for k in regimes})
best_setting, best_regime = max(all_runs, key=lambda key: all_runs[key][1][-1]["val_macro_f1"])
va_best = all_runs[(best_setting, best_regime)][1][-1]
# History keys carry the "val_" prefix added by run_training.
plot_confusion(va_best["val_targets"], va_best["val_preds"], CLASS_NAMES,
               title=f"Confusion Matrix: {best_regime} ({best_setting})")


