# ResNet-50 Training on Critter Capture

Train a ResNet-50 classifier on the cached observation images and CSV metadata.


## Notebook Overview

- Load the observations CSV and align it with images cached in `data/raw/images`.
- Prepare stratified train/validation/test splits and construct PyTorch datasets.
- Configure a ResNet-50 model, optimizer, scheduler, and class-balanced loss.
- Run the training loop with mixed precision and early stopping on validation macro F1, then evaluate the best checkpoint on the held-out test split.


In [7]:
import copy
import math
import random
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import ResNet50_Weights, resnet50
from tqdm.auto import tqdm


In [8]:
# Project paths and reproducibility helpers
# NOTE(review): PROJECT_ROOT is the parent of the working directory, so this
# assumes the notebook is launched from a direct subdirectory of the repo
# (e.g. notebooks/) — confirm.
PROJECT_ROOT = Path.cwd().resolve().parent
DATA_DIR = PROJECT_ROOT / "data"
CSV_PATH = DATA_DIR / "data.csv"
IMAGE_DIR = DATA_DIR / "raw" / "images"

# Fail fast if the metadata CSV or the local image cache has not been prepared.
if not CSV_PATH.exists():
    raise FileNotFoundError(f"CSV file not found at {CSV_PATH}")
if not IMAGE_DIR.exists():
    raise FileNotFoundError(
        f"Expected image directory at {IMAGE_DIR}. Run the data ingestion step first."
    )

# Shared seed for set_seed() and the random_state of the sklearn splits.
SEED = 42


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    This does not switch cuDNN into deterministic mode, so GPU runs may
    still differ slightly between executions.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


set_seed(SEED)

# Use the current default CUDA device rather than pinning an ordinal such as
# "cuda:1", which fails with "invalid device ordinal" on single-GPU machines.
# Pick a specific GPU via CUDA_VISIBLE_DEVICES instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [9]:
# Central experiment configuration; later cells read these values.
config = {
    "data": {
        "validation_size": 0.15,  # fraction of the filtered rows held out for validation
        "test_size": 0.15,  # fraction of the filtered rows held out for testing
        "image_size": 224,  # crop side length used by the train/eval transforms
        "normalize_mean": [0.485, 0.456, 0.406],  # per-channel mean for transforms.Normalize
        "normalize_std": [0.229, 0.224, 0.225],  # per-channel std for transforms.Normalize
        "num_workers": 4,  # DataLoader worker processes
    },
    "training": {
        "epochs": 20,
        "batch_size": 32,
        "learning_rate": 5e-4,  # AdamW lr and OneCycleLR max_lr
        "weight_decay": 0.01,  # AdamW weight decay
        "gradient_clip_norm": 1.0,  # max global grad norm; None skips clipping
        "early_stopping_patience": 5,  # epochs without val macro-F1 improvement
        "amp": True,  # mixed precision; only takes effect on CUDA devices
        "scheduler": {
            # NOTE(review): these keys are not read by the OneCycleLR scheduler
            # built later in this notebook — confirm whether they are still needed.
            "t_max": 10,
            "min_lr": 1e-6,
        },
    },
}
config


{'data': {'validation_size': 0.15,
  'test_size': 0.15,
  'image_size': 224,
  'normalize_mean': [0.485, 0.456, 0.406],
  'normalize_std': [0.229, 0.224, 0.225],
  'num_workers': 4},
 'training': {'epochs': 20,
  'batch_size': 32,
  'learning_rate': 0.0005,
  'weight_decay': 0.01,
  'gradient_clip_norm': 1.0,
  'early_stopping_patience': 5,
  'amp': True,
  'scheduler': {'t_max': 10, 'min_lr': 1e-06}}}

In [None]:
# Load the observation metadata and keep only rows that can actually be trained on.
df = pd.read_csv(CSV_PATH)
required_cols = {"uuid", "taxon_id", "common_name"}
missing_cols = required_cols - set(df.columns)
if missing_cols:
    raise ValueError(f"CSV is missing columns: {missing_cols}")

df = df.dropna(subset=["uuid", "taxon_id", "common_name"])
df["uuid"] = df["uuid"].astype(str)
df["taxon_id"] = df["taxon_id"].astype(int)

# Drop rows without a cached image *before* counting classes. Filtering images
# after the frequency check can push a class back below the stratification
# minimum (making train_test_split fail). It can also leave a label index with
# no samples at all, which would still get an output unit in the classifier head.
df["image_path"] = df["uuid"].apply(lambda uuid: IMAGE_DIR / f"{uuid}.jpg")
# astype(bool) keeps the mask usable for indexing even when df is empty.
available_mask = df["image_path"].apply(lambda path: path.exists()).astype(bool)

missing_count = (~available_mask).sum()
if missing_count:
    print(
        f"Warning: {missing_count} images are missing from {IMAGE_DIR}. They will be skipped."
    )

df = df[available_mask].reset_index(drop=True)

if df.empty:
    raise RuntimeError(
        "No images were found in the cache. Populate data/raw/images before training."
    )

# Require enough samples per class that the combined validation+test share is
# expected to contain at least two of them (needed for the second stratified split).
val_test_size = config["data"]["validation_size"] + config["data"]["test_size"]
min_class_frequency = 2
if val_test_size > 0:
    min_class_frequency = max(min_class_frequency, int(math.ceil(2.0 / val_test_size)))
class_counts = df["taxon_id"].value_counts()
rare_taxa = class_counts[class_counts < min_class_frequency].index.tolist()
if rare_taxa:
    print(
        f"Filtering out {len(rare_taxa)} classes with fewer than {min_class_frequency} samples."
    )
    df = df[~df["taxon_id"].isin(rare_taxa)].reset_index(drop=True)

if df.empty:
    raise RuntimeError(
        "No observations remain after filtering low-frequency classes. Adjust the threshold."
    )

# Build a contiguous 0..K-1 label space (ordered by taxon_id) from the taxa that
# survived both filters, so every output unit has training data behind it.
label_lookup = (
    df[["taxon_id", "common_name"]]
    .drop_duplicates(subset=["taxon_id"])
    .set_index("taxon_id")["common_name"]
    .to_dict()
)

sorted_label_ids = sorted(label_lookup.keys())
id_to_index = {label_id: idx for idx, label_id in enumerate(sorted_label_ids)}
index_to_name = [label_lookup[label_id] for label_id in sorted_label_ids]

df["label_index"] = df["taxon_id"].map(id_to_index)

print(
    f"Loaded {len(df)} records across {df['label_index'].nunique()} classes after filtering."
)


In [None]:
# Split the filtered frame into stratified train / validation / test subsets.
val_size = config["data"]["validation_size"]
test_size = config["data"]["test_size"]
val_test_size = val_size + test_size
if val_test_size <= 0 or val_test_size >= 1:
    raise ValueError("Validation and test sizes must sum to a value between 0 and 1.")

# First split: carve the combined validation+test hold-out off the full frame.
train_df, temp_df = train_test_split(
    df,
    test_size=val_test_size,
    stratify=df["label_index"],
    random_state=SEED,
)

# The second stratified split needs at least 2 samples per class in the
# hold-out; rows of classes that fall short are moved back into training.
temp_counts = temp_df["label_index"].value_counts()
too_small = temp_counts[temp_counts < 2].index.tolist()
if too_small:
    print(
        "Reassigning hold-out samples for classes with fewer than 2 instances to the training split."
    )
    move_mask = temp_df["label_index"].isin(too_small)
    train_df = pd.concat([train_df, temp_df[move_mask]], ignore_index=True)
    temp_df = temp_df[~move_mask].reset_index(drop=True)

if temp_df.empty:
    raise RuntimeError(
        "Hold-out split is empty after rebalancing. Reduce validation/test sizes or relax filtering."
    )

# Second split: give the test set its configured share of the hold-out,
# i.e. test_size / (val_size + test_size); the remainder becomes validation.
test_fraction = test_size / val_test_size
val_df, test_df = train_test_split(
    temp_df,
    test_size=test_fraction,
    stratify=temp_df["label_index"],
    random_state=SEED,
)

print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")


In [None]:
class LocalObservationDataset(Dataset):
    """Map-style dataset over a DataFrame of locally cached observation images.

    Each row must provide ``image_path``, ``label_index`` and ``uuid``. Items
    are dicts with the (optionally transformed) RGB ``"image"``, a
    ``torch.long`` scalar ``"target"`` and the row's ``"uuid"``.
    """

    def __init__(self, frame: pd.DataFrame, transform, label_names):
        # Drop the caller's index; rows are addressed positionally via iloc.
        self.frame = frame.reset_index(drop=True)
        self.transform = transform
        # Kept for callers; not used when building items.
        self.label_names = label_names

    def __len__(self) -> int:
        return len(self.frame)

    def __getitem__(self, idx: int):
        row = self.frame.iloc[idx]
        image_path = row["image_path"]
        try:
            source = Image.open(image_path)
        except FileNotFoundError as exc:
            raise FileNotFoundError(f"Image missing at {image_path}") from exc

        # Close the file deterministically. Pillow only auto-closes single-frame
        # images after loading, so multi-frame JPEGs (e.g. MPO files from phone
        # cameras) would otherwise keep a descriptor open in every worker.
        with source:
            image = source.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        target = torch.tensor(row["label_index"], dtype=torch.long)
        return {
            "image": image,
            "target": target,
            "uuid": row["uuid"],
        }


In [None]:
# Build the torchvision pipelines from the configured size and normalisation stats.
image_size = config["data"]["image_size"]
normalize_mean = config["data"]["normalize_mean"]
normalize_std = config["data"]["normalize_std"]

# Training: random crop covering 80-100% of the image area, horizontal flip and
# mild colour jitter for augmentation, then tensor conversion + normalisation.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=normalize_mean, std=normalize_std),
    ]
)

# Evaluation: deterministic resize of the shorter side to int(image_size * 1.14)
# (255 px for 224) followed by a centre crop, then the same normalisation.
eval_transforms = transforms.Compose(
    [
        transforms.Resize(int(image_size * 1.14)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=normalize_mean, std=normalize_std),
    ]
)


In [None]:
# Wrap each split; only the training set gets the augmenting transforms.
train_dataset = LocalObservationDataset(train_df, train_transforms, index_to_name)
val_dataset = LocalObservationDataset(val_df, eval_transforms, index_to_name)
test_dataset = LocalObservationDataset(test_df, eval_transforms, index_to_name)


def compute_class_weights(frame: pd.DataFrame, num_classes: int) -> torch.Tensor:
    """Return inverse-frequency class weights rescaled to a mean of 1.

    A class seen ``c`` times among ``N`` labels gets ``N / (num_classes * c)``
    before rescaling. Classes absent from ``frame`` start at 1.0.
    """
    label_counts = Counter(frame["label_index"].tolist())
    n_samples = sum(label_counts.values())
    balanced = torch.ones(num_classes, dtype=torch.float32)
    if label_counts:
        seen = list(label_counts)
        balanced[seen] = torch.tensor(
            [n_samples / (num_classes * label_counts[label]) for label in seen],
            dtype=torch.float32,
        )
    return balanced / balanced.mean()


# Per-class loss weights derived from the training split's label frequencies.
class_weights = compute_class_weights(train_df, len(index_to_name))

num_workers = config["data"]["num_workers"]
# Pinned host memory only benefits host-to-GPU copies.
pin_memory = device.type == "cuda"

# Only the training loader shuffles; val/test iterate in a fixed order.
# persistent_workers keeps worker processes alive across epochs; DataLoader
# rejects it when num_workers == 0, hence the guard.
train_loader = DataLoader(
    train_dataset,
    batch_size=config["training"]["batch_size"],
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    persistent_workers=num_workers > 0,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config["training"]["batch_size"],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
    persistent_workers=num_workers > 0,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=config["training"]["batch_size"],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
    persistent_workers=num_workers > 0,
)

len(train_dataset), len(val_dataset), len(test_dataset), class_weights


In [None]:
# ImageNet-pretrained ResNet-50 with a fresh classification head sized to our label space.
weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=weights)
model.fc = nn.Linear(model.fc.in_features, len(index_to_name))
model = model.to(device)

# Class-weighted cross-entropy counters label imbalance in the training split.
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config["training"]["learning_rate"],
    weight_decay=config["training"]["weight_decay"],
)

# train_model() calls scheduler.step() once per epoch, so the one-cycle schedule
# is sized in epochs. Sizing it in batches (epochs * len(train_loader)) would
# leave the learning rate stuck near its warm-up start (max_lr / div_factor)
# for the entire run.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config["training"]["learning_rate"],
    total_steps=config["training"]["epochs"],
    pct_start=0.2,
    div_factor=10,
    final_div_factor=100,
)


In [None]:
def train_one_epoch(
    model, dataloader, optimizer, criterion, device, scaler, grad_clip, use_amp
):
    """Run one optimisation pass over ``dataloader`` and summarise it.

    Args:
        model: network to train; switched to train mode.
        dataloader: yields dicts with ``"image"`` and ``"target"`` tensors.
        optimizer: stepped once per batch through ``scaler``.
        criterion: loss applied to ``(logits, targets)``.
        device: device the batch tensors are moved to.
        scaler: gradient scaler used for the backward pass and optimizer step.
        grad_clip: max global gradient norm, or ``None`` to skip clipping.
        use_amp: request autocast; it is only enabled on CUDA devices.

    Returns:
        Dict with the sample-weighted mean ``loss``, ``accuracy`` and macro
        ``precision``/``recall``/``f1`` over the epoch.

    Raises:
        RuntimeError: if ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0
    all_preds = []
    all_targets = []
    # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast and
    # pairs with the torch.amp.GradScaler that train_model() creates.
    autocast_enabled = use_amp and device.type == "cuda"

    progress = tqdm(dataloader, desc="train", leave=False)
    for batch in progress:
        images = batch["image"].to(device)
        targets = batch["target"].to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device.type, enabled=autocast_enabled):
            outputs = model(images)
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()

        if grad_clip is not None:
            # Gradients must be unscaled before their norm can be clipped.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * targets.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.size(0)

        all_preds.append(preds.detach().cpu())
        all_targets.append(targets.detach().cpu())

        progress.set_postfix(
            loss=running_loss / max(total, 1),
            acc=correct / max(total, 1),
        )

    if not all_targets:
        raise RuntimeError("train_one_epoch received a dataloader that yielded no batches.")

    history = {
        "loss": running_loss / max(total, 1),
        "accuracy": correct / max(total, 1),
    }

    y_true = torch.cat(all_targets).numpy()
    y_pred = torch.cat(all_preds).numpy()
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="macro", zero_division=0
    )
    history.update({"precision": precision, "recall": recall, "f1": f1})
    return history


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on every batch of ``dataloader`` with gradients disabled.

    Args:
        model: network to evaluate; switched to eval mode.
        dataloader: yields dicts with ``"image"`` and ``"target"`` tensors.
        criterion: loss applied to ``(logits, targets)``.
        device: device the batch tensors are moved to.

    Returns:
        Dict with the sample-weighted mean ``loss``, ``accuracy`` and macro
        ``precision``/``recall``/``f1`` (``zero_division=0``).
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_chunks, target_chunks = [], []

    for batch in tqdm(dataloader, desc="eval", leave=False):
        inputs = batch["image"].to(device)
        labels = batch["target"].to(device)

        logits = model(inputs)
        batch_size = labels.size(0)
        loss_sum += criterion(logits, labels).item() * batch_size
        predicted = logits.argmax(dim=1)
        n_correct += (predicted == labels).sum().item()
        n_seen += batch_size

        pred_chunks.append(predicted.detach().cpu())
        target_chunks.append(labels.detach().cpu())

    denominator = max(n_seen, 1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        torch.cat(target_chunks).numpy(),
        torch.cat(pred_chunks).numpy(),
        average="macro",
        zero_division=0,
    )
    return {
        "loss": loss_sum / denominator,
        "accuracy": n_correct / denominator,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


In [None]:
def train_model(
    model, train_loader, val_loader, optimizer, scheduler, criterion, config, device
):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Each epoch runs ``train_one_epoch``, then ``evaluate`` on ``val_loader``,
    then steps ``scheduler`` (when given) once. The epoch with the highest
    validation macro F1 is snapshotted. Training stops after
    ``config["training"]["early_stopping_patience"]`` consecutive epochs
    without improvement.

    Returns:
        ``(best_state, history)``. ``best_state`` holds independent copies of
        the model and optimizer state from the best epoch, plus the scheduler
        state, epoch number, config and the notebook-level ``index_to_name``.
        It is ``None`` if no epoch ran. ``history`` is a list of
        ``{"epoch", "train", "val"}`` metric dicts, one per completed epoch.
    """
    num_epochs = config["training"]["epochs"]
    patience = config["training"]["early_stopping_patience"]
    use_amp = config["training"]["amp"] and device.type == "cuda"
    grad_clip = config["training"]["gradient_clip_norm"]

    scaler = torch.amp.GradScaler(enabled=use_amp)
    best_state = None
    best_metric = -math.inf
    patience_counter = 0
    history = []

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        train_metrics = train_one_epoch(
            model,
            train_loader,
            optimizer,
            criterion,
            device,
            scaler,
            grad_clip,
            use_amp,
        )
        val_metrics = evaluate(model, val_loader, criterion, device)

        # Stepped once per epoch: the scheduler's step budget must be counted in epochs.
        if scheduler is not None:
            scheduler.step()

        history.append(
            {
                "epoch": epoch,
                "train": train_metrics,
                "val": val_metrics,
            }
        )

        val_score = val_metrics["f1"]
        if val_score > best_metric:
            best_metric = val_score
            # state_dict() returns references to the live parameter and optimizer
            # tensors, which later epochs update in place. Deep-copy them so the
            # snapshot really holds this epoch's weights, not the final ones.
            best_state = {
                "model": copy.deepcopy(model.state_dict()),
                "optimizer": copy.deepcopy(optimizer.state_dict()),
                "scheduler": scheduler.state_dict() if scheduler is not None else None,
                "epoch": epoch,
                "config": config,
                "label_names": index_to_name,
            }
            patience_counter = 0
            print(f"New best validation macro F1: {best_metric:.4f}")
        else:
            patience_counter += 1
            print(
                f"Validation macro F1 did not improve. Patience {patience_counter}/{patience}."
            )
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    return best_state, history


In [None]:
# Run the full training loop. best_checkpoint is the snapshot dict of the best
# validation-F1 epoch (None if no epoch ran); training_history has per-epoch metrics.
best_checkpoint, training_history = train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    criterion,
    config,
    device,
)


In [None]:
# Restore the best-validation weights, then score them on the test split, which
# was not used for training or for checkpoint selection.
if best_checkpoint is not None:
    model.load_state_dict(best_checkpoint["model"])
    test_metrics = evaluate(model, test_loader, criterion, device)
    print("Test metrics:", test_metrics)
else:
    print("Training did not produce a checkpoint. Check earlier logs for issues.")


In [None]:
# Flatten the per-epoch metric dicts into one row per epoch with train_*/val_*
# columns. This is built at top level because Jupyter only renders a cell's final
# expression when it is not nested inside an `if`.
history_df = pd.DataFrame(
    [
        {
            "epoch": item["epoch"],
            **{f"train_{k}": v for k, v in item["train"].items()},
            **{f"val_{k}": v for k, v in item["val"].items()},
        }
        for item in training_history
    ]
)
history_df


In [None]:
# Persist the best checkpoint (model/optimizer/scheduler state, config, label names).
if best_checkpoint is not None:
    output_dir = PROJECT_ROOT / "outputs"
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / "resnet50_best.pth"
    torch.save(best_checkpoint, checkpoint_path)
    # A bare expression inside an `if` is not echoed by Jupyter, so report explicitly.
    print(f"Saved best checkpoint to {checkpoint_path}")
