# CIFAR-10 Image Classification with ConvNeXt and Optuna

This notebook trains a ConvNeXt classifier on the CIFAR-10 dataset, using Optuna for hyperparameter tuning. The dataset is exported from the original CIFAR-10 batch files into a folder structure with separate `train`, `val`, and `test` splits. During both the tuning trials and the final training run, per-epoch metrics are logged: loss and accuracy for every split, plus macro-averaged precision, recall, and F1 (and specificity, where available) for the evaluation splits. Loss/accuracy curves and confusion matrices are generated automatically and stored under the `./artifacts` directory.


## Install Dependencies

Install the Python packages used in this notebook. This cell runs `uv pip install` through the `%uv` magic so that the dependencies are available to the rest of the notebook.


In [1]:
%uv pip install timm optuna scikit-learn torchmetrics jiwer opencv-python tqdm matplotlib

[2mUsing Python 3.12.6 environment at: /usr/local[0m
[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mResolving dependencies...                                                     [0m[2K[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mtimm==1.0.21                                                                  [0m[2K[37m⠙[0m [2moptuna==4.5.0                                                                 [0m[2K[37m⠙[0m [2mscikit-learn==1.7.1                                                           [0m[2K[37m⠙[0m [2mtorchmetrics==1.8.2                                                           [0m[2K[37m⠙[0m [2mjiwer==4.0.0                                                                  [0m[2K[37m⠙[0m [2mopencv-pyth

## Imports and Basic Utilities

Import required libraries and define helper functions for seeding and device selection.


In [1]:
import os
import time
import json
import csv
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import optuna
import timm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# Mixed precision scaler (PyTorch 2.5+)
from torch import amp
SCALER = amp.GradScaler('cuda', enabled=True)

def seed_everything(seed: int = 42) -> None:
    """Seed the global RNGs of Python, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        seed: value passed unchanged to each seeding function.
    """
    import random

    # Same order as random.seed -> np.random.seed -> torch CPU -> all CUDA devices.
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def device_auto() -> torch.device:
    """Pick the compute device: the default CUDA device when one is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


## Dataset and Transform Classes

Define a simple classification transform and a dataset class that reads from a folder structure organized as `<root>/<split>/<class>/<images>`.


In [2]:
import cv2
from PIL import Image

class ClassificationTransform:
    """Turn an OpenCV BGR uint8 image into a normalised (3, size, size) float tensor.

    The image is converted to RGB first. In training mode it then gets:
    * a random crop covering 70-100% of each side,
    * with probability 0.5, a contrast (x0.8-1.2) and brightness (+-20) jitter,
    * with probability 0.5, a horizontal flip.
    Both modes then resize to ``size`` x ``size``, scale to [0, 1] and normalise
    with the ImageNet mean/std below.
    """

    # ImageNet channel statistics, shaped (3, 1, 1) to broadcast over a CHW tensor.
    # Built once here instead of on every call; read-only, so safe to share.
    _MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    _STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def __init__(self, size: int = 224, train: bool = True) -> None:
        self.size = size
        self.train = train

    def __call__(self, img_bgr: np.ndarray) -> torch.Tensor:
        """Apply the (optional) augmentation, resize and normalisation to ``img_bgr``."""
        # Convert BGR (OpenCV) to RGB to match the channel order of the statistics.
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        if self.train:
            h, w = img.shape[:2]
            # Random scale crop
            scale = np.random.uniform(0.7, 1.0)
            nh, nw = int(h * scale), int(w * scale)
            y0 = np.random.randint(0, max(1, h - nh + 1))
            x0 = np.random.randint(0, max(1, w - nw + 1))
            img = img[y0:y0 + nh, x0:x0 + nw]
            # Random brightness/contrast
            if np.random.rand() < 0.5:
                alpha = np.random.uniform(0.8, 1.2)  # contrast
                beta = np.random.randint(-20, 21)    # brightness, inclusive +-20
                # Clip into [0, 255] rather than using cv2.convertScaleAbs, which
                # takes |alpha*x + beta| and so turns too-dark pixels bright.
                jittered = img.astype(np.float32) * alpha + beta
                img = np.clip(np.rint(jittered), 0, 255).astype(np.uint8)
            # Random horizontal flip (copied so the array has positive strides).
            if np.random.rand() < 0.5:
                img = np.ascontiguousarray(img[:, ::-1, :])
        # Resize to target size
        img = cv2.resize(img, (self.size, self.size), interpolation=cv2.INTER_LINEAR)
        # HWC uint8 -> CHW float in [0, 1], then normalise.
        chw = np.ascontiguousarray(img.transpose(2, 0, 1))
        t = torch.from_numpy(chw).to(torch.float32) / 255.0
        return (t - self._MEAN) / self._STD

class ClassificationDataset(Dataset):
    """A dataset for image classification organised as <root>/<split>/<class>/<images>."""
    def __init__(self, root: str, split: str, transform: ClassificationTransform, class_to_idx: Dict[str, int] = None) -> None:
        self.root = Path(root)
        self.split = split
        self.transform = transform
        self.samples: List[Tuple[str, int]] = []
        # Build class mapping
        if class_to_idx is None:
            classes = sorted([p.name for p in (self.root / split).iterdir() if p.is_dir()])
            self.class_to_idx = {c: i for i, c in enumerate(classes)}
        else:
            self.class_to_idx = dict(class_to_idx)
            # Ensure directory exists for every class
            for c in self.class_to_idx.keys():
                (self.root / split / c).mkdir(parents=True, exist_ok=True)
        # Collect samples (image path, label)
        for cls, idx in self.class_to_idx.items():
            class_dir = self.root / split / cls
            if not class_dir.exists():
                continue
            for imgp in class_dir.glob("*.*"):
                if imgp.suffix.lower() in [".jpg", ".jpeg", ".png", ".bmp", ".webp"]:
                    self.samples.append((str(imgp), idx))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        path, label = self.samples[idx]
        img_bgr = cv2.imread(path, cv2.IMREAD_COLOR)
        if img_bgr is None:
            raise FileNotFoundError(path)
        x = self.transform(img_bgr)
        return x, label


## Training and Evaluation Functions

Define functions to train the model for one epoch and to evaluate it on a dataset, computing common metrics.


In [3]:
@torch.no_grad()
def eval_cls_metrics(model: nn.Module, loader: DataLoader, criterion: nn.Module, device: torch.device) -> Dict[str, float]:
    """Evaluate a classifier on ``loader`` and return loss plus classification metrics.

    Returns a dict with keys:
      loss        -- criterion value averaged over samples,
      acc         -- accuracy,
      precision / recall / f1 -- macro averages (zero_division=0),
      specificity -- 2 classes: TN / (TN + FP) with class 1 as the positive class;
                     >2 classes: macro mean of the per-class one-vs-rest TN / (TN + FP);
                     otherwise None.
    An empty loader gives loss 0.0, zero metrics and specificity None.
    float16 autocast is used only when ``device`` is CUDA.
    """
    model.eval()
    use_amp = device.type == "cuda"
    all_preds = []
    all_labels = []
    losses = 0.0
    nobs = 0
    n_classes = 0
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = torch.as_tensor(y, dtype=torch.long, device=device)
        with amp.autocast('cuda', dtype=torch.float16, enabled=use_amp):
            logits = model(x)
            loss = criterion(logits, y)
        losses += float(loss.item()) * x.size(0)
        nobs += x.size(0)
        # Taken from the logits so the confusion matrix always covers every class,
        # even when some class never occurs in this loader.
        n_classes = logits.size(1)
        all_preds.append(logits.argmax(1).cpu().numpy())
        all_labels.append(y.cpu().numpy())
    y_pred = np.concatenate(all_preds) if all_preds else np.array([])
    y_true = np.concatenate(all_labels) if all_labels else np.array([])
    val_loss = losses / max(1, nobs)

    if y_true.size == 0:
        return {"loss": val_loss, "acc": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0, "specificity": None}

    acc  = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average="macro", zero_division=0)
    rec  = recall_score(y_true, y_pred, average="macro", zero_division=0)
    f1   = f1_score(y_true, y_pred, average="macro", zero_division=0)
    cm   = confusion_matrix(y_true, y_pred, labels=np.arange(n_classes))
    spec = None
    if n_classes == 2:
        tn, fp = cm[0, 0], cm[0, 1]
        spec = tn / (tn + fp + 1e-12)
    elif n_classes > 2:
        # One-vs-rest counts per class, then macro average of TN / (TN + FP).
        tp = np.diag(cm)
        fp = cm.sum(axis=0) - tp
        fn = cm.sum(axis=1) - tp
        tn = cm.sum() - tp - fp - fn
        spec = float(np.mean(tn / (tn + fp + 1e-12)))
    return {"loss": val_loss, "acc": acc, "precision": prec, "recall": rec, "f1": f1, "specificity": spec}


def train_one_epoch_cls(model: nn.Module, loader: DataLoader, opt: torch.optim.Optimizer, criterion: nn.Module,
                        device: torch.device, epoch: int, epochs: int) -> Tuple[float, float]:
    """Train the model for a single epoch and return average loss and accuracy."""
    model.train()
    total = 0
    correct = 0
    loss_sum = 0.0
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = torch.as_tensor(y, dtype=torch.long, device=device)
        opt.zero_grad(set_to_none=True)
        with amp.autocast('cuda', dtype=torch.float16, enabled=True):
            logits = model(x)
            loss = criterion(logits, y)
        SCALER.scale(loss).backward()
        SCALER.step(opt)
        SCALER.update()
        loss_sum += float(loss.item()) * x.size(0)
        preds = logits.argmax(1)
        correct += int((preds == y).sum().item())
        total += x.size(0)
    avg_loss = loss_sum / max(1, total)
    acc = correct / max(1, total)
    return avg_loss, acc


## Build ConvNeXt Model

Select a ConvNeXt variant and define a helper to instantiate it with a custom dropout rate and number of classes.


In [4]:
# ConvNeXt variant used throughout the notebook, e.g. 'convnext_tiny',
# 'convnext_small' or 'convnext_base' (any timm model name works here).
MANUAL_CONVNEXT_VARIANT = 'convnext_tiny'

def build_convnext_classifier(variant: str, num_classes: int, dropout: float) -> nn.Module:
    """Create timm model ``variant`` with pretrained weights and a ``num_classes``-way head.

    ``dropout`` is forwarded to timm as ``drop_rate``.
    """
    return timm.create_model(
        variant,
        pretrained=True,
        num_classes=num_classes,
        drop_rate=dropout,
    )


## Export CIFAR-10 to Train/Val/Test Folders

The CIFAR-10 dataset is distributed as batch files (`cifar-10-batches-py`). This cell downloads them if needed and exports every image as a PNG into a folder structure with separate `train`, `val`, and `test` splits. The validation set is created by randomly splitting the original 10,000-image test set in half. If the export folder already exists, the export is skipped; set `overwrite=True` to regenerate it.


In [5]:
# CIFAR-10 export settings
from pathlib import Path
from tqdm import tqdm
from torchvision import datasets
from PIL import Image
import shutil
import numpy as np

# Paths
batches_dir = Path("/root/data/cifar10/cifar-10-batches-py")  # CIFAR-10 binary batches location
export_root = Path("/root/data/cifar10_extracted")
overwrite = False  # set True to re-export
val_fraction = 0.5  # fraction of test set to use for validation
split_seed = 42  # seeds the val/test permutation so every re-export yields the same split

# Ensure parent folder exists
batches_dir.parent.mkdir(parents=True, exist_ok=True)

# ✅ If CIFAR-10 not found locally, download it first
if not batches_dir.exists():
    print("[INFO] CIFAR-10 not found locally. Downloading...")
    datasets.CIFAR10(root=str(batches_dir.parent), train=True, download=True)
    datasets.CIFAR10(root=str(batches_dir.parent), train=False, download=True)
else:
    print("[OK] CIFAR-10 already exists locally.")

root_for_torchvision = batches_dir.parent  # torchvision will find 'cifar-10-batches-py' under this root

# Export if needed
if export_root.exists() and not overwrite:
    print(f"[SKIP] {export_root} already exists. Set overwrite=True to re-export.")
else:
    # Write into a staging folder and move it into place only once every split is
    # complete. An interrupted run therefore never leaves a half-filled export_root
    # that the skip check above would later accept as a finished export.
    staging_root = export_root.with_name(export_root.name + ".partial")
    if staging_root.exists():
        shutil.rmtree(staging_root)

    # Load CIFAR-10 from existing or freshly downloaded batches
    ds_train = datasets.CIFAR10(root=str(root_for_torchvision), train=True, download=False)
    ds_test  = datasets.CIFAR10(root=str(root_for_torchvision), train=False, download=False)
    classes = ds_train.classes
    print("CIFAR-10 classes:", classes)

    # Split test dataset into val/test using a dedicated seeded generator, so the
    # split does not depend on whatever state the global NumPy RNG happens to be in.
    indices = np.random.default_rng(split_seed).permutation(len(ds_test))
    val_size = int(len(ds_test) * val_fraction)
    val_indices = indices[:val_size]
    test_indices = indices[val_size:]

    def export_dataset(dataset, indices, split_name: str):
        """Save dataset[idx] for each idx as <staging>/<split>/<class>/<split>_<class>_<idx:05d>.png."""
        for c in classes:
            (staging_root / split_name / c).mkdir(parents=True, exist_ok=True)
        for idx in tqdm(indices, desc=f"Exporting {split_name}"):
            img, label = dataset[idx]
            cls = classes[label]
            out_path = staging_root / split_name / cls / f"{split_name}_{cls}_{idx:05d}.png"
            img.save(out_path, format="PNG", optimize=True)

    # Export train (all images), then the validation and test halves
    export_dataset(ds_train, range(len(ds_train)), "train")
    export_dataset(ds_test, val_indices, "val")
    export_dataset(ds_test, test_indices, "test")

    # Swap the finished export into place, replacing an old one when overwrite=True.
    if export_root.exists():
        shutil.rmtree(export_root)
    staging_root.rename(export_root)

print(f"[DONE] CIFAR-10 exported to: {export_root.resolve()}")


[INFO] CIFAR-10 not found locally. Downloading...


  0%|                                                                               | 0.00/170M [00:00<?, ?B/s]  0%|▏                                                                      | 459k/170M [00:00<00:37, 4.51MB/s]  3%|██▎                                                                   | 5.64M/170M [00:00<00:05, 32.1MB/s]  6%|███▉                                                                  | 9.67M/170M [00:00<00:04, 35.8MB/s]  8%|█████▋                                                                | 13.9M/170M [00:00<00:04, 38.3MB/s] 10%|███████▎                                                              | 17.8M/170M [00:00<00:04, 35.6MB/s] 13%|████████▉                                                             | 21.8M/170M [00:00<00:04, 37.1MB/s] 15%|██████████▋                                                           | 26.0M/170M [00:00<00:03, 38.5MB/s] 18%|████████████▎                                                         | 29.9M/170M [00:00<00:03, 3

CIFAR-10 classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


Exporting train:   0%|                                                               | 0/50000 [00:00<?, ?it/s]Exporting train:   0%|▏                                                  | 145/50000 [00:00<00:34, 1442.18it/s]Exporting train:   1%|▎                                                  | 310/50000 [00:00<00:31, 1560.08it/s]Exporting train:   1%|▍                                                  | 467/50000 [00:00<00:32, 1513.21it/s]Exporting train:   1%|▋                                                  | 619/50000 [00:00<00:32, 1502.98it/s]Exporting train:   2%|▊                                                  | 779/50000 [00:00<00:32, 1534.46it/s]Exporting train:   2%|▉                                                  | 948/50000 [00:00<00:30, 1583.83it/s]Exporting train:   2%|█                                                 | 1117/50000 [00:00<00:30, 1617.40it/s]Exporting train:   3%|█▎                                                | 1280/50000 [00:00<00:30, 1620

[DONE] CIFAR-10 exported to: /root/data/cifar10_extracted





## Prepare DataLoaders

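Define the dataset root and load the training, validation, and test datasets using the custom `ClassificationDataset` class. The number of data-loading workers is set to 4 when a GPU is available and to 0 otherwise. The `DataLoader` objects themselves are created later, once the batch size has been chosen.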


In [6]:
# Path to your dataset root directory. This folder contains 'train', 'val', and 'test' subfolders.
data_root = str(Path("/root/data/cifar10_extracted").resolve())

# Image size used for resizing. ConvNeXt models typically expect 224×224 inputs.
img_size = 224

# Number of worker processes for data loading. Adjust based on your CPU cores and GPU. A fallback of 0 for CPU-only.
num_workers = 4 if torch.cuda.is_available() else 0

# Define transforms for training and validation/test using our custom transform
transform_train = ClassificationTransform(size=img_size, train=True)
transform_val   = ClassificationTransform(size=img_size, train=False)

# Load training dataset and determine class mapping
ds_train = ClassificationDataset(data_root, 'train', transform_train)
class_to_idx = ds_train.class_to_idx
num_classes = len(class_to_idx)
if num_classes < 2:
    raise RuntimeError('Need at least two classes in the training data.')

# Choose which split to use for validation: we explicitly have a 'val' folder
val_split_name = 'val'

# Load validation and test datasets with the same class mapping
ds_val  = ClassificationDataset(data_root, val_split_name, transform_val, class_to_idx=class_to_idx)
ds_test = ClassificationDataset(data_root, 'test', transform_val, class_to_idx=class_to_idx)

# Fail fast on empty splits: with no training images nothing is learned, and an
# empty validation set makes every Optuna trial report accuracy 0 without error.
for split_name, split_ds in (('train', ds_train), (val_split_name, ds_val)):
    if len(split_ds) == 0:
        raise RuntimeError(f"No images found for split '{split_name}' under {data_root}.")

print(f"Classes ({num_classes}): {list(class_to_idx.keys())}")
print(f"Train: {len(ds_train)}, Val: {len(ds_val)}, Test: {len(ds_test)}")


Classes (10): ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Train: 50000, Val: 5000, Test: 5000


## Logging Helpers

Utilities to save per-epoch metrics into CSV files and to generate line plots and confusion matrices.  All artifacts are written under the `./artifacts` directory.


In [7]:
# Directory to store artifacts (CSV logs, curves, checkpoints)
ARTIFACT_ROOT = Path("./artifacts/cifar10_optuna")
ARTIFACT_ROOT.mkdir(parents=True, exist_ok=True)

# Flag to evaluate test set each epoch (be cautious of leakage).  Set True to log test metrics per epoch.
EVAL_TEST_EACH_EPOCH = True

def _now():
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())


def save_epoch_logcsv(
    out_dir: Path,
    trial_num: int,
    split: str,
    epoch: int,
    metrics: Dict[str, float],
    extra: Dict[str, str] = None
) -> None:
    """
    Append an epoch line into out_dir/'epoch_logs.csv' with columns:
    time, trial, split, epoch, loss, acc, precision, recall, f1, specificity, <extras...>
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    fpath = out_dir / "epoch_logs.csv"

    row = {
        "time": _now(),
        "trial": trial_num,
        "split": split,
        "epoch": epoch,
        "loss": metrics.get("loss", None),
        "acc": metrics.get("acc", None),
        "precision": metrics.get("precision", None),
        "recall": metrics.get("recall", None),
        "f1": metrics.get("f1", None),
        "specificity": metrics.get("specificity", None),
    }
    if extra:
        row.update(extra)

    write_header = not fpath.exists()
    with open(fpath, "a", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(row.keys()))
        if write_header:
            w.writeheader()
        w.writerow(row)


def plot_lines(line_dict: Dict[str, List[float]], title: str, x_label: str, y_label: str, out_path: Path):
    """Draw one line per ``line_dict`` entry and save the figure to ``out_path``.

    The x axis counts from 1; a series shorter than the longest one is drawn
    over the first ``len(series)`` positions. Nothing is saved if ``line_dict``
    is empty.
    """
    import matplotlib.pyplot as plt
    if not line_dict:
        return
    plt.figure(figsize=(8, 5))
    longest = max(map(len, line_dict.values()))
    x_positions = range(1, longest + 1)
    for series_name, values in line_dict.items():
        plt.plot(x_positions[:len(values)], values, label=series_name, linewidth=2)
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def plot_confusion_matrix_matplotlib(cm: np.ndarray, class_names: List[str], out_path: Path, title: str = "Confusion Matrix"):
    """Save ``cm`` as an annotated heat map at ``out_path``.

    Rows are true labels and columns predicted labels. Every cell shows its
    integer count, in white when above half the largest count, black otherwise.
    ``class_names`` label both axes in order. An empty matrix is skipped.
    """
    import matplotlib.pyplot as plt
    if cm.size == 0:
        return
    plt.figure(figsize=(6, 5))
    plt.imshow(cm, interpolation='nearest')
    plt.title(title)
    plt.colorbar()
    positions = np.arange(len(class_names))
    plt.xticks(positions, class_names, rotation=45)
    plt.yticks(positions, class_names)
    peak = cm.max()
    thresh = peak / 2.0 if peak > 0 else 0
    # np.ndindex walks (row, col) in row-major order, like two nested loops.
    for row, col in np.ndindex(cm.shape):
        count = cm[row, col]
        text_color = "white" if count > thresh else "black"
        plt.text(col, row, format(int(count)), ha="center", va="center", color=text_color)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


## Optuna Objective Function

Define the objective function for Optuna.  For each trial, the model is trained for a small number of epochs with sampled hyperparameters.  Metrics for train, validation, and test (if enabled) are logged to CSV files, and loss/accuracy curves along with confusion matrices are generated for each trial.


In [8]:
def objective(trial: optuna.Trial) -> float:
    """Optuna objective: train one sampled configuration, return 1 - best val accuracy.

    Samples lr, weight decay, dropout, optimizer and batch size (epochs fixed at 3).
    It then trains a fresh ConvNeXt for that many epochs and logs per-epoch
    train/val (and test, if EVAL_TEST_EACH_EPOCH) metrics to
    ``ARTIFACT_ROOT/trial_XXX/epoch_logs.csv``. The weights with the best
    validation accuracy are saved to ``best_model.pth``. Loss/accuracy curves and
    confusion matrices of the final-epoch model are written to the same folder.

    Relies on notebook globals: ds_train, ds_val, ds_test, num_workers, num_classes,
    class_to_idx, ARTIFACT_ROOT, EVAL_TEST_EACH_EPOCH, MANUAL_CONVNEXT_VARIANT.

    Raises:
        optuna.TrialPruned: when the study's pruner stops the trial after an epoch.
    """
    # ===== hyperparameters (suggestion order kept stable for the seeded sampler) =====
    lr        = trial.suggest_float('lr', 1e-5, 5e-3, log=True)
    wd        = trial.suggest_float('weight_decay', 1e-6, 5e-3, log=True)
    dropout   = trial.suggest_float('dropout', 0.0, 0.4)
    opt_name  = trial.suggest_categorical('optimizer', ['adamw', 'sgd'])
    batch_size= trial.suggest_categorical('batch_size', [8, 16])
    # Fixed range, but suggested so 'epochs' is recorded in trial.params.
    epochs    = trial.suggest_int('epochs', 3, 3)  # small for tuning

    # Print the sampled hyperparameters for this trial
    print(f"Starting Trial {trial.number}: lr={lr:.5f}, wd={wd:.6f}, dropout={dropout:.2f}, optimizer={opt_name}, batch_size={batch_size}, epochs={epochs}", flush=True)

    device = device_auto()

    # ===== DataLoaders (uses global num_workers) =====
    dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True,  num_workers=num_workers, pin_memory=True)
    dl_val   = DataLoader(ds_val,   batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    dl_test  = DataLoader(ds_test,  batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    # ===== Model & Optimizer =====
    model = build_convnext_classifier(MANUAL_CONVNEXT_VARIANT, num_classes, dropout).to(device)
    optimizer = (torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
                 if opt_name == 'adamw'
                 else torch.optim.SGD(model.parameters(), lr=lr, weight_decay=wd, momentum=0.9, nesterov=True))
    criterion = nn.CrossEntropyLoss()

    # ===== Logs for curves =====
    tr_losses, tr_accs = [], []
    va_losses, va_accs = [], []
    te_losses, te_accs = [], []

    # Directory for this trial
    out_trial_dir = ARTIFACT_ROOT / f"trial_{trial.number:03d}"
    out_trial_dir.mkdir(parents=True, exist_ok=True)

    best_acc = -1.0
    best_ckpt = out_trial_dir / "best_model.pth"

    for epoch in range(1, epochs+1):
        # ---- train ----
        tr_loss, tr_acc = train_one_epoch_cls(model, dl_train, optimizer, criterion, device, epoch-1, epochs)
        tr_losses.append(tr_loss); tr_accs.append(tr_acc)
        save_epoch_logcsv(out_trial_dir, trial.number, "train", epoch, {
            "loss": tr_loss, "acc": tr_acc, "precision": None, "recall": None, "f1": None, "specificity": None
        })

        # ---- val ----
        val_metrics = eval_cls_metrics(model, dl_val, criterion, device)
        va_losses.append(val_metrics['loss']); va_accs.append(val_metrics['acc'])
        save_epoch_logcsv(out_trial_dir, trial.number, "val", epoch, val_metrics)

        # Print train and validation metrics for this epoch
        print(f"Trial {trial.number} Epoch {epoch}/{epochs} - Train loss: {tr_loss:.4f}, acc: {tr_acc:.4f}; Val loss: {val_metrics['loss']:.4f}, acc: {val_metrics['acc']:.4f}", flush=True)

        # ---- test (optional; logged only, never used for selection) ----
        if EVAL_TEST_EACH_EPOCH:
            test_metrics = eval_cls_metrics(model, dl_test, criterion, device)
            te_losses.append(test_metrics['loss']); te_accs.append(test_metrics['acc'])
            save_epoch_logcsv(out_trial_dir, trial.number, "test", epoch, test_metrics)
            # Print test metrics for this epoch
            print(f"Trial {trial.number} Epoch {epoch}/{epochs} - Test loss: {test_metrics['loss']:.4f}, acc: {test_metrics['acc']:.4f}", flush=True)

        # Report to Optuna (minimize 1 - val_acc)
        trial.report(1.0 - val_metrics['acc'], step=epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        # Save best by val acc
        if val_metrics['acc'] > best_acc:
            best_acc = val_metrics['acc']
            torch.save({
                "model": model.state_dict(),
                "variant": MANUAL_CONVNEXT_VARIANT,
                "num_classes": num_classes,
                "classes": list(class_to_idx.keys())
            }, best_ckpt)

    # curves
    plot_lines({"Train Loss": tr_losses, "Val Loss": va_losses},
               "Train vs Val Loss", "Epoch", "Loss", out_trial_dir / "loss_curve.png")
    plot_lines({"Train Acc": tr_accs, "Val Acc": va_accs},
               "Train vs Val Accuracy", "Epoch", "Accuracy", out_trial_dir / "acc_curve.png")
    if EVAL_TEST_EACH_EPOCH and te_losses:
        plot_lines({"Test Loss": te_losses}, "Test Loss", "Epoch", "Loss", out_trial_dir / "test_loss_curve.png")
    if EVAL_TEST_EACH_EPOCH and te_accs:
        plot_lines({"Test Acc": te_accs}, "Test Accuracy", "Epoch", "Accuracy", out_trial_dir / "test_acc_curve.png")

    # Confusion matrices of the final-epoch model on VAL (and TEST if enabled)
    class_names = list(class_to_idx.keys())
    _save_confusion_matrix(model, dl_val, device, class_names,
                           out_trial_dir / "val_confusion_matrix.png", "Validation Confusion Matrix")
    if EVAL_TEST_EACH_EPOCH:
        _save_confusion_matrix(model, dl_test, device, class_names,
                               out_trial_dir / "test_confusion_matrix.png", "Test Confusion Matrix")

    # Print final result of trial
    print(f"Finished Trial {trial.number}. Best validation accuracy: {best_acc:.4f}", flush=True)

    # objective: minimize 1 - best val acc
    return 1.0 - float(best_acc)


@torch.no_grad()
def _predict_labels(model: nn.Module, loader: DataLoader, device: torch.device) -> Tuple[np.ndarray, np.ndarray]:
    """Return (y_true, y_pred) over every sample of ``loader``, with ``model`` in eval mode."""
    model.eval()  # explicit, instead of relying on the last eval call having set it
    pred_chunks, label_chunks = [], []
    for x, y in loader:
        logits = model(x.to(device, non_blocking=True))
        pred_chunks.append(logits.argmax(1).cpu().numpy())
        label_chunks.append(torch.as_tensor(y).cpu().numpy())
    if not label_chunks:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    return np.concatenate(label_chunks), np.concatenate(pred_chunks)


def _save_confusion_matrix(model: nn.Module, loader: DataLoader, device: torch.device,
                           class_names: List[str], out_path: Path, title: str) -> None:
    """Plot ``model``'s confusion matrix on ``loader`` to ``out_path``; empty loaders are skipped."""
    y_true, y_pred = _predict_labels(model, loader, device)
    if y_true.size == 0:
        return
    # Pin labels so the matrix is always len(class_names) square and lines up
    # with the tick labels, even if some class never appears in this split.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    plot_confusion_matrix_matplotlib(cm, class_names, out_path, title)


## Hyperparameter Optimization with Optuna

Create an Optuna study, run the trials, and report the best hyperparameters found. `n_trials` is set to 1 by default to keep the runtime short; increase it for a meaningful search.


In [9]:
# Reset Optuna logging level to show information
optuna.logging.set_verbosity(optuna.logging.INFO)

# Callback to print results at the end of each trial

def print_callback(study, trial):
    """Print a short summary after each finished trial.

    Optuna also invokes callbacks for trials that end without a value (e.g. pruned
    trials, whose ``trial.value`` is None). Those are reported by state instead of
    being formatted with ``:.4f``, which would raise TypeError. ``study.best_value``
    raises ValueError until some trial has completed; that case is reported as well.
    """
    if trial.value is None:
        print(f"[Optuna] Trial {trial.number} finished without a value (state: {trial.state.name}), params: {trial.params}", flush=True)
    else:
        print(f"[Optuna] Trial {trial.number} finished with value: {trial.value:.4f}, params: {trial.params}", flush=True)
    try:
        best_value = study.best_value
    except ValueError:
        print("[Optuna]     No completed trials yet", flush=True)
    else:
        print(f"[Optuna]     Best value so far: {best_value:.4f}", flush=True)

# Seed every RNG, then run a TPE-driven study that minimises 1 - val accuracy.
seed_everything(42)
tpe_sampler = optuna.samplers.TPESampler(seed=42)
study = optuna.create_study(direction='minimize', sampler=tpe_sampler)
study.optimize(objective, n_trials=1, timeout=None, callbacks=[print_callback])

all_trials = study.trials
print('Number of finished trials:', len(all_trials))
best_trial = study.best_trial

# The objective value is 1 - best val accuracy, so undo that for the report.
print('Best trial:')
print(f"Validation accuracy: {1.0 - best_trial.value:.4f}")
print('Hyperparameters:')
for param_name, param_value in best_trial.params.items():
    print(f"    {param_name}: {param_value}")


[I 2025-10-25 02:24:33,433] A new study created in memory with name: no-name-0cc273d4-4ae7-4b77-b666-689035e63f97


Starting Trial 0: lr=0.00010, wd=0.003286, dropout=0.29, optimizer=adamw, batch_size=8, epochs=3


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

Trial 0 Epoch 1/3 - Train loss: 0.6321, acc: 0.7788; Val loss: 0.2591, acc: 0.9102
Trial 0 Epoch 1/3 - Test loss: 0.2839, acc: 0.8996
Trial 0 Epoch 2/3 - Train loss: 0.2731, acc: 0.9069; Val loss: 0.2488, acc: 0.9158
Trial 0 Epoch 2/3 - Test loss: 0.2458, acc: 0.9164
Trial 0 Epoch 3/3 - Train loss: 0.2122, acc: 0.9278; Val loss: 0.1794, acc: 0.9396
Trial 0 Epoch 3/3 - Test loss: 0.1853, acc: 0.9392
Finished Trial 0. Best validation accuracy: 0.9396


[I 2025-10-25 02:42:14,768] Trial 0 finished with value: 0.06040000000000001 and parameters: {'lr': 0.0001025350969016849, 'weight_decay': 0.0032859708169642424, 'dropout': 0.292797576724562, 'optimizer': 'adamw', 'batch_size': 8, 'epochs': 3}. Best is trial 0 with value: 0.06040000000000001.


[Optuna] Trial 0 finished with value: 0.0604, params: {'lr': 0.0001025350969016849, 'weight_decay': 0.0032859708169642424, 'dropout': 0.292797576724562, 'optimizer': 'adamw', 'batch_size': 8, 'epochs': 3}
[Optuna]     Best value so far: 0.0604
Number of finished trials: 1
Best trial:
Validation accuracy: 0.9396
Hyperparameters:
    lr: 0.0001025350969016849
    weight_decay: 0.0032859708169642424
    dropout: 0.292797576724562
    optimizer: adamw
    batch_size: 8
    epochs: 3


## Final Training with Best Hyperparameters

Train a fresh model using the best hyperparameters discovered by Optuna.  Record metrics for train, validation, and test sets at each epoch, produce curves, and generate confusion matrices for the final epoch.


In [11]:
# Retrieve the best hyperparameters
best_params = study.best_trial.params
# Number of epochs for the final training run
final_epochs    = int(best_params.get('epochs', 5))
final_batchsize = int(best_params.get('batch_size', 16))
final_lr        = float(best_params['lr'])
final_wd        = float(best_params['weight_decay'])
final_dropout   = float(best_params['dropout'])
final_opt_name  = best_params['optimizer']

# Set up directory for final run
final_dir = ARTIFACT_ROOT / "final_best_run"
final_dir.mkdir(parents=True, exist_ok=True)

# DataLoaders
device = device_auto()
final_dl_train = DataLoader(ds_train, batch_size=final_batchsize, shuffle=True,  num_workers=num_workers, pin_memory=True)
final_dl_val   = DataLoader(ds_val,   batch_size=final_batchsize, shuffle=False, num_workers=num_workers, pin_memory=True)
final_dl_test  = DataLoader(ds_test,  batch_size=final_batchsize, shuffle=False, num_workers=num_workers, pin_memory=True)

# Model & optimizer
final_model = build_convnext_classifier(MANUAL_CONVNEXT_VARIANT, num_classes, final_dropout).to(device)
final_optimizer = (torch.optim.AdamW(final_model.parameters(), lr=final_lr, weight_decay=final_wd)
                   if final_opt_name == 'adamw'
                   else torch.optim.SGD(final_model.parameters(), lr=final_lr, weight_decay=final_wd, momentum=0.9, nesterov=True))
criterion = nn.CrossEntropyLoss()

# Containers to store metrics
train_losses, train_accs = [], []
val_losses, val_accs     = [], []
test_losses, test_accs   = [], []

best_val_acc = -1.0
best_ckpt = final_dir / "best_model.pth"

for epoch in range(1, final_epochs+1):
    # Train
    tr_loss, tr_acc = train_one_epoch_cls(final_model, final_dl_train, final_optimizer, criterion, device, epoch-1, final_epochs)
    train_losses.append(tr_loss); train_accs.append(tr_acc)
    save_epoch_logcsv(final_dir, -1, "train", epoch, {
        "loss": tr_loss, "acc": tr_acc, "precision": None, "recall": None, "f1": None, "specificity": None
    })

    # Validate
    val_metrics = eval_cls_metrics(final_model, final_dl_val, criterion, device)
    val_losses.append(val_metrics['loss']); val_accs.append(val_metrics['acc'])
    save_epoch_logcsv(final_dir, -1, "val", epoch, val_metrics)

    # Test
    test_metrics = None
    if EVAL_TEST_EACH_EPOCH:
        test_metrics = eval_cls_metrics(final_model, final_dl_test, criterion, device)
        test_losses.append(test_metrics['loss']); test_accs.append(test_metrics['acc'])
        save_epoch_logcsv(final_dir, -1, "test", epoch, test_metrics)

    # Print metrics for this epoch
    msg = f"Epoch {epoch}/{final_epochs} - Train loss: {tr_loss:.4f}, acc: {tr_acc:.4f}; Val loss: {val_metrics['loss']:.4f}, acc: {val_metrics['acc']:.4f}"
    if test_metrics is not None:
        msg += f"; Test loss: {test_metrics['loss']:.4f}, acc: {test_metrics['acc']:.4f}"
    print(msg, flush=True)

    # Save best based on validation accuracy
    if val_metrics['acc'] > best_val_acc:
        best_val_acc = val_metrics['acc']
        torch.save({
            "model": final_model.state_dict(),
            "variant": MANUAL_CONVNEXT_VARIANT,
            "num_classes": num_classes,
            "classes": list(class_to_idx.keys())
        }, best_ckpt)

# Generate curves for the final run
plot_lines({"Train Loss": train_losses, "Val Loss": val_losses}, "Final: Train vs Val Loss", "Epoch", "Loss", final_dir/"loss_curve.png")
plot_lines({"Train Acc": train_accs, "Val Acc": val_accs}, "Final: Train vs Val Accuracy", "Epoch", "Accuracy", final_dir/"acc_curve.png")
if EVAL_TEST_EACH_EPOCH and test_losses:
    plot_lines({"Test Loss": test_losses}, "Final: Test Loss", "Epoch", "Loss", final_dir/"test_loss_curve.png")
if EVAL_TEST_EACH_EPOCH and test_accs:
    plot_lines({"Test Acc": test_accs}, "Final: Test Accuracy", "Epoch", "Accuracy", final_dir/"test_acc_curve.png")

# Confusion matrices of the final-epoch model (val, and test when enabled).
# labels= pins every matrix to all classes so it always lines up with the
# class-name tick labels, even if some class never appears in a split.
cm_class_names = list(class_to_idx.keys())
cm_jobs = [("val", final_dl_val, "Final Validation Confusion Matrix")]
if EVAL_TEST_EACH_EPOCH:
    cm_jobs.append(("test", final_dl_test, "Final Test Confusion Matrix"))

final_model.eval()  # explicit, instead of relying on the last eval call having set it
with torch.no_grad():
    for cm_split, cm_loader, cm_title in cm_jobs:
        pred_chunks, label_chunks = [], []
        for xb, yb in cm_loader:
            logits = final_model(xb.to(device, non_blocking=True))
            pred_chunks.append(logits.argmax(1).cpu().numpy())
            label_chunks.append(torch.as_tensor(yb).cpu().numpy())
        if label_chunks:
            cm = confusion_matrix(np.concatenate(label_chunks), np.concatenate(pred_chunks),
                                  labels=list(range(len(cm_class_names))))
            plot_confusion_matrix_matplotlib(cm, cm_class_names,
                                             final_dir / f"{cm_split}_confusion_matrix.png", cm_title)

print(f"[DONE] Final training complete. Logs and plots are stored in {final_dir}")


Epoch 1/3 - Train loss: 0.4552, acc: 0.8492; Val loss: 0.2387, acc: 0.9202; Test loss: 0.2589, acc: 0.9176
Epoch 2/3 - Train loss: 0.2710, acc: 0.9093; Val loss: 0.2354, acc: 0.9170; Test loss: 0.2327, acc: 0.9238
Epoch 3/3 - Train loss: 0.2115, acc: 0.9289; Val loss: 0.1774, acc: 0.9424; Test loss: 0.1947, acc: 0.9382
[DONE] Final training complete. Logs and plots are stored in artifacts/cifar10_optuna/final_best_run
