## Flowers102 Hyperparameter Tuning and Model Comparison Overview

Building on the baseline and few-shot models presented earlier, this document covers systematic hyperparameter tuning and result analysis for the Flowers102 classification task. It has three core objectives:
- **Unified Training Interface**:
    - A lightweight wrapper around the original training function, so that the same training and evaluation logic is reused across tuning strategies (greedy search / Optuna).
- **Hyperparameter Search**:
    - **Staged greedy search (Greedy Tuning)**, which tunes the learning rate, weight decay, cosine classifier scaling factor, label smoothing factor, and batch size in a fixed order.
    - **Optuna (TPE)** for joint hyperparameter optimization, which samples and evaluates candidate configurations over a continuous learning-rate range and discrete choices for the other parameters.
- **Final Model Selection and Test Set Evaluation**:
    - Looking up and loading the best checkpoint from the greedy or Optuna search results.
    - Comparing the best tuned model against the baseline model on the TEST set, and saving the comparison as JSON for experiment reproduction and report writing.

Because every tuning strategy shares one training code path, their results can be compared side by side, which gives a consistent basis for the final model selection in the Flowers102 project.

# Baseline Flowers102 pipeline setup
 - Initialize the baseline Flowers102 classification experiment using shared utilities from `flowers_common`.
 - Import helpers for seeding, device selection, data loading, model construction, training, evaluation, and inference.
 - Fix random seeds to ensure deterministic behavior across Python, NumPy, and PyTorch.
 - Select the compute device (CPU, CUDA, or MPS) via `get_device_config`.
 - Build deterministic train/validation/test `DataLoader` objects for the Flowers102 dataset.

In [None]:
from flowers_common import seed_all, get_device_config, get_dataloaders
from flowers_common import build_resnet50_cosine
from flowers_common import train_model

seed_all(1029)
dc = get_device_config()
device = dc.device
train_loader, val_loader, test_loader = get_dataloaders(root="data", batch_size=32, img_size=224)

In [None]:
import os, json
import torch, torch.nn as nn, torch.nn.functional as F
from tqdm import tqdm
LOAD_WEIGHTS = True
CKPT_EXISTING = "ckpt/resnet50_imagenet_finetuned_v1.pth"
CKPT_NEW      = "ckpt/best_model_new.pth"

## 3.1.1: Extended train_model guard for compatibility

This section installs an extended `train_model` when the imported one lacks the signature required for hyperparameter tuning:

- Uses `inspect.signature` to check whether the current `train_model` already accepts `weight_decay`; if it does, the guard returns without changing anything.

- Otherwise, it defines `train_model_extended` as a replacement, with a training loop built on Adam + ReduceLROnPlateau + AMP.

- The replacement supports `weight_decay` and `label_smoothing`, and computes training and validation loss and accuracy at every epoch.

- It tracks the best model by validation loss, using validation accuracy only to break ties, saves a checkpoint whenever that criterion improves, and stops early after `patience` epochs without improvement.

- It passes each epoch's metrics, together with the `trial_id`, to the optional `epoch_log_cb` callback so that they can be written to an external log for later tuning analysis.

Finally, the global `train_model` is rebound to `train_model_extended`, so existing calls that use only the old arguments keep working.
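
The signature check can be tried in isolation on toy functions. The following is a minimal sketch, not the notebook's code: `needs_extension`, `old_train`, and `new_train` are hypothetical names introduced only for illustration, and the toy functions return dictionaries instead of training anything.

```python
import inspect


def needs_extension(fn, required=("weight_decay",)):
    """Return True when `fn` lacks any of the required parameter names."""
    params = inspect.signature(fn).parameters
    return any(name not in params for name in required)


# Hypothetical stand-ins for an old and a new training function.
def old_train(model, epochs=30, lr=1e-4):
    return {"epochs": epochs, "lr": lr}


def new_train(model, epochs=30, lr=1e-4, *, weight_decay=0.0):
    return {"epochs": epochs, "lr": lr, "weight_decay": weight_decay}


print(needs_extension(old_train))  # True: a replacement would be installed
print(needs_extension(new_train))  # False: the guard would return early
```

Checking the signature rather than a version number keeps the guard independent of how `flowers_common` is packaged.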

In [None]:
import math
from typing import Callable, Optional, Dict, Any


def _extended_train_model_guard():
    import inspect
    sig = inspect.signature(train_model)
    if "weight_decay" in sig.parameters:
        return
    orig_train_model = train_model

    def train_model_extended(
        model,
        train_loader,
        val_loader,
        epochs=30,
        lr=1e-4,
        patience=5,
        ckpt_path="ckpt/best.pth",
        *,
        weight_decay: float = 0.0,
        label_smoothing: float = 0.0,
        epoch_log_cb: Optional[Callable[[Dict[str, Any]], None]] = None,
        trial_id: Optional[str] = None,
    ):
        import os, time
        from torch.optim import Adam
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        from torch.cuda.amp import autocast, GradScaler
        import torch.nn as nn
        import torch

        # Create the directory that will actually hold the checkpoint.
        os.makedirs(os.path.dirname(ckpt_path) or ".", exist_ok=True)
        model = model.to(device)

        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
        scaler = GradScaler(enabled=True)

        best_val_loss = float("inf")
        best_epoch = -1
        best_val_acc = -1.0
        patience_counter = 0

        for epoch in range(1, epochs + 1):
            # ---- Train ----
            model.train()
            total_loss = 0.0
            correct, total = 0, 0
            t0 = time.time()

            for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
                imgs = imgs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                optimizer.zero_grad(set_to_none=True)
                with autocast(enabled=True):
                    logits = model(imgs)
                    loss = criterion(logits, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                total_loss += loss.item()
                preds = logits.argmax(1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

            train_loss = total_loss / max(1, len(train_loader))
            train_acc = correct / max(1, total)

            # ---- Val ----
            model.eval()
            val_loss, val_correct, val_total = 0.0, 0, 0
            with torch.no_grad():
                for imgs, labels in val_loader:
                    imgs = imgs.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)
                    with autocast(enabled=True):
                        logits = model(imgs)
                        loss = criterion(logits, labels)
                    val_loss += loss.item()
                    preds = logits.argmax(1)
                    val_correct += (preds == labels).sum().item()
                    val_total += labels.size(0)

            val_loss /= max(1, len(val_loader))
            val_acc = val_correct / max(1, val_total)
            scheduler.step(val_loss)

            print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

            # ---- Optional epoch log callback ----
            # Runs before the early-stopping check so the final epoch is logged too.
            if epoch_log_cb is not None:
                epoch_log_cb({
                    "trial_id": trial_id,
                    "epoch": epoch,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                    "lr": float(optimizer.param_groups[0]["lr"]),
                    "elapsed_sec": time.time() - t0,
                })

            # ---- Early stopping & best tracking ----
            # Primary criterion: lower val_loss; tie-break: higher val_acc.
            improved = val_loss < best_val_loss or (
                math.isclose(val_loss, best_val_loss, rel_tol=0.0, abs_tol=1e-12)
                and val_acc > best_val_acc
            )
            if improved:
                best_val_loss = val_loss
                best_val_acc = val_acc
                best_epoch = epoch
                patience_counter = 0
                torch.save(model.state_dict(), ckpt_path)
                print(f"Saved best model to {ckpt_path}")
            else:
                patience_counter += 1
                print(f" EarlyStopping counter {patience_counter}/{patience}")
                if patience_counter >= patience:
                    print(" Early stopping triggered.")
                    break

        return {
            "best_epoch": best_epoch,
            "best_val_loss": float(best_val_loss),
            "best_val_acc": float(best_val_acc),
            "ckpt_path": ckpt_path,
        }

    globals()["train_model"] = train_model_extended

_extended_train_model_guard()
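
The best-model rule used in the training loop (lower validation loss first, higher validation accuracy only as a tie-break, stop after `patience` epochs without improvement) can be replayed on a made-up validation curve. This is a sketch under stated assumptions: `run_early_stopping` is a hypothetical helper and the `history` values are synthetic, not measured results.

```python
import math


def run_early_stopping(history, patience=5):
    """Replay the rule on (val_loss, val_acc) pairs.

    Returns (best_epoch, last_epoch_run), with epochs counted from 1.
    """
    best_loss, best_acc, best_epoch, bad_epochs = float("inf"), -1.0, -1, 0
    for epoch, (val_loss, val_acc) in enumerate(history, start=1):
        tie = math.isclose(val_loss, best_loss, rel_tol=0.0, abs_tol=1e-12)
        if val_loss < best_loss or (tie and val_acc > best_acc):
            best_loss, best_acc, best_epoch, bad_epochs = val_loss, val_acc, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return best_epoch, epoch
    return best_epoch, len(history)


# Synthetic curve: loss bottoms out at 0.80 while accuracy keeps creeping up.
history = [(1.00, 0.60), (0.80, 0.70), (0.80, 0.72),
           (0.85, 0.74), (0.90, 0.75), (0.95, 0.75)]
print(run_early_stopping(history, patience=3))  # (3, 6)
```

Epoch 3 wins over epoch 2 through the accuracy tie-break, while epochs 4 to 6 are not selected despite higher accuracy because their loss is worse. The reported `best_val_acc` is therefore the accuracy at the lowest-loss epoch, which is not necessarily the highest accuracy seen during training.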


## 3.1.2: Tuning utilities

This section implements reusable tools for the hyperparameter tuning phase:

- `JSONLLogger`: Appends one JSON object per line (the per-epoch metrics of each trial) to a JSON Lines log file for later visualization and analysis.

- `set_cosine_scale`: Sets the scale factor `s` of the `CosineClassifier` head (`model[1]`); if the model has no such head, the call silently does nothing.

- `build_model_with_safe_load`: Rebuilds the ResNet50 + CosineClassifier model and, when the fine-tuned checkpoint `CKPT_EXISTING` exists, loads the backbone weights non-strictly and copies a shape-compatible FC weight into the cosine head.

- `ensure_included`: Prepends the baseline value to a candidate list when it is missing, so that every search stage can be compared directly with the untuned configuration.
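
Because each record is a single JSON object on its own line, the process log can be read back with standard-library code alone. The sketch below illustrates this with synthetic rows written to a temporary directory; `read_jsonl` and `best_val_acc_per_trial` are hypothetical helper names that do not appear in the notebook.

```python
import json
import os
import tempfile


def read_jsonl(path):
    """Load a JSON Lines file into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def best_val_acc_per_trial(rows):
    """Map each trial_id to the highest val_acc logged for it."""
    best = {}
    for row in rows:
        tid = row["trial_id"]
        best[tid] = max(best.get(tid, float("-inf")), row["val_acc"])
    return best


# Synthetic rows that mimic a subset of the per-epoch log keys.
rows = [
    {"trial_id": "t1", "epoch": 1, "val_acc": 0.5},
    {"trial_id": "t1", "epoch": 2, "val_acc": 0.7},
    {"trial_id": "t2", "epoch": 1, "val_acc": 0.6},
]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "logs", "train_process.jsonl")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "a", encoding="utf-8") as f:  # append mode, as in JSONLLogger
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    summary = best_val_acc_per_trial(read_jsonl(path))

print(summary)  # {'t1': 0.7, 't2': 0.6}
```

Since the logger opens the file in append mode, rerunning a search adds to the existing log; removing or renaming the old file first keeps runs separate.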

In [None]:
# =============================================================================
# 3.1.2. Tuning utilities
# =============================================================================

class JSONLLogger:
    """Append JSON objects per line for later analysis."""
    def __init__(self, path: str):
        self.path = path
        # dirname is "" for a bare filename, which os.makedirs rejects.
        os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
    def log(self, obj: Dict[str, Any]):
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(json.dumps(obj, ensure_ascii=False) + "\n")

def set_cosine_scale(model: nn.Module, s: float) -> None:
    """Set the CosineClassifier scale if present."""
    try:
        model[1].s = float(s)
    except Exception:
        pass

def build_model_with_safe_load() -> nn.Module:
    """Rebuild a fresh model and reuse the same safe-load logic as baseline."""
    m = build_resnet50_cosine(num_classes=102, pretrained=True)
    # Reuse the transplant logic if the file exists.
    if 'LOAD_WEIGHTS' in globals() and LOAD_WEIGHTS and 'CKPT_EXISTING' in globals() and os.path.exists(CKPT_EXISTING):
        try:
            state = torch.load(CKPT_EXISTING, map_location=device)
            if isinstance(state, dict) and "state_dict" in state:
                state = state["state_dict"]
            m[0].load_state_dict(state, strict=False)
            # Attempt FC to cosine transplant
            def first_present(d, keys):
                for k in keys:
                    if k in d:
                        return k
                return None
            w_key = first_present(state, ["fc.1.weight", "fc.weight"])
            W = state[w_key] if w_key else None
            if W is not None and W.shape == m[1].weight.shape:
                with torch.no_grad():
                    m[1].weight.copy_(F.normalize(W, dim=1))
                    m[1].s = 30.0
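        except Exception as e:
            # Keep the ImageNet-pretrained weights, but report the failed load.
            print(f"[WARN] Could not load {CKPT_EXISTING}: {e}")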
    return m

def ensure_included(seq, base):
    """Ensure baseline value is included while preserving readable order."""
    if base in seq:
        return list(seq)
    return [base] + list(seq)


## 3.1.3: Greedy tuner

This section implements a staged greedy hyperparameter search:

- The search order is fixed: learning rate (`lr`) → weight decay (`weight_decay`) → cosine head scale (`s`) → label smoothing (`label_smoothing`) → batch size (`batch_size`).

- Each stage starts from the best configuration found so far and changes only the hyperparameter being tuned, training once per candidate value and evaluating on the validation set.

- Candidates are ranked by validation accuracy; when accuracies are equal, the lower validation loss wins.

- Per-epoch training records are appended to `logs/train_process.jsonl` via `JSONLLogger`, and the summary of every trial is written to `logs/greedy_results.json`.

- The result is the best configuration found by the greedy search, which feeds the subsequent visualization and test evaluation.
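
The stage-by-stage logic can be sketched on a toy objective that needs no training. Everything below is illustrative: `greedy_search` and `toy_evaluate` are hypothetical names, the objective is a synthetic formula rather than a real validation run, and ties are compared exactly instead of with `math.isclose`.

```python
def greedy_search(space, order, baseline, evaluate):
    """Tune one parameter per stage, holding earlier winners fixed."""
    best_cfg = dict(baseline)
    for name in order:
        stage_best = None  # ((val_acc, -val_loss), value)
        for cand in space[name]:
            cfg = {**best_cfg, name: cand}
            val_acc, val_loss = evaluate(cfg)
            key = (val_acc, -val_loss)  # higher accuracy first, then lower loss
            if stage_best is None or key > stage_best[0]:
                stage_best = (key, cand)
        best_cfg[name] = stage_best[1]  # commit the stage winner
    return best_cfg


def toy_evaluate(cfg):
    """Synthetic stand-in for a training run; peaks at lr=3e-4 and s=24."""
    acc = 0.9 - abs(cfg["lr"] - 3e-4) * 100 - abs(cfg["s"] - 24) * 0.001
    return round(acc, 6), round(1.0 - acc, 6)


space = {"lr": [1e-4, 3e-5, 2e-4, 3e-4, 5e-4], "s": [30.0, 10, 16, 24, 40, 64]}
result = greedy_search(space, ["lr", "s"], {"lr": 1e-4, "s": 30.0}, toy_evaluate)
print(result)  # {'lr': 0.0003, 's': 24}
```

The toy objective is separable, so the greedy result coincides with the global optimum. When hyperparameters interact, a stage winner chosen early may no longer be optimal after later stages change the configuration, which is one reason to complement the greedy search with a joint search.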

In [None]:
# =============================================================================
# 3.1.3. Greedy tuner
# =============================================================================
def greedy_tune():
    """
    Greedy search over discrete grids in fixed order:
    lr -> weight_decay -> s -> label_smoothing -> batch_size.
    Use validation accuracy as primary metric and validation loss as tie-breaker.
    Output files:
      - logs/train_process.jsonl : per-epoch records across trials
      - logs/greedy_results.json : trial summaries and final best config
    """
    # 1) Define search spaces (baseline included)
    BASELINE = {"lr": 1e-4, "weight_decay": 0.0, "s": 30.0, "label_smoothing": 0.0, "batch_size": 32}
    SPACE = {
        "lr":               ensure_included([3e-5, 2e-4, 3e-4, 5e-4], BASELINE["lr"]),
        "weight_decay":     ensure_included([1e-6, 1e-5, 5e-5, 1e-4, 3e-4, 0], BASELINE["weight_decay"]),
        "s":                ensure_included([10, 16, 24, 40, 64], BASELINE["s"]),
        "label_smoothing":  ensure_included([0.05, 0.1, 0.15], BASELINE["label_smoothing"]),
        "batch_size":       ensure_included([16, 48, 64], BASELINE["batch_size"]),
    }
    ORDER = ["lr", "weight_decay", "s", "label_smoothing", "batch_size"]

    # 2) Common training knobs
    EPOCHS = 30
    PATIENCE = 5
    IMG_SIZE = 224
    NUM_WORKERS = 0
    PIN_MEMORY = False
    AUGMENT = True
    SEED = 1029

    # 3) Initialize loggers
    process_logger = JSONLLogger("logs/train_process.jsonl")
    results: Dict[str, Any] = {
        "baseline": BASELINE,
        "order": ORDER,
        "search_space": SPACE,
        "trials": [],
    }

    # 4) Current best config (start with baseline)
    best_cfg = dict(BASELINE)
    best_of_stage: Dict[str, Any] = {}

    # Prepare initial loaders; reuse when batch size is unchanged
    train_loader_g, val_loader_g, test_loader_g = get_dataloaders(
        batch_size=best_cfg["batch_size"], img_size=IMG_SIZE,
        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
        augment=AUGMENT, seed=SEED
    )

    def build_loaders_for(bs: int):
        nonlocal train_loader_g, val_loader_g, test_loader_g
        if bs == train_loader_g.batch_size:
            return train_loader_g, val_loader_g, test_loader_g
        # Rebuild deterministically for a new batch size
        return get_dataloaders(
            batch_size=bs, img_size=IMG_SIZE,
            num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
            augment=AUGMENT, seed=SEED
        )

    trial_counter = 0

    for stage_param in ORDER:
        stage_best = {"val_acc": -1.0, "val_loss": float("inf"), "value": None, "summary": None}
        print(f"\n[GREEDY] Stage: {stage_param} | Candidates: {SPACE[stage_param]}")
        for cand in SPACE[stage_param]:
            cfg = dict(best_cfg)
            cfg[stage_param] = cand
            trial_counter += 1
            trial_id = f"stage_{stage_param}_trial_{trial_counter}_val_{str(cand)}"

            # Build loaders (only if batch size changes)
            try:
                train_loader, val_loader, _ = build_loaders_for(cfg["batch_size"])
            except RuntimeError as e:
                # e.g., DataLoader worker init issues; record and continue
                results["trials"].append({
                    "trial_id": trial_id, "status": "failed", "reason": str(e), "config": cfg
                })
                continue

            # Fresh model per trial; same initialization and weight-loading policy as baseline
            model_t = build_model_with_safe_load()
            set_cosine_scale(model_t, cfg["s"])

            ckpt_dir = os.path.join("ckpt", "greedy", stage_param, str(cand))
            os.makedirs(ckpt_dir, exist_ok=True)
            ckpt_path = os.path.join(ckpt_dir, "best.pth")

            print(f"[GREEDY] Trial {trial_id} | cfg={cfg}")
            status = "ok"
            summary = None
            try:
                summary = train_model(
                    model_t,
                    train_loader,
                    val_loader,
                    epochs=EPOCHS,
                    lr=cfg["lr"],
                    patience=PATIENCE,
                    ckpt_path=ckpt_path,
                    weight_decay=cfg["weight_decay"],
                    label_smoothing=cfg["label_smoothing"],
                    epoch_log_cb=lambda row: process_logger.log({
                        **row,
                        "param": stage_param,
                        "value": cand,
                        "config": cfg,
                    }),
                    trial_id=trial_id,
                )
                # Track stage best: prioritize higher val_acc; tie-break on lower val_loss
                cur_va, cur_vl = summary["best_val_acc"], summary["best_val_loss"]
                better = (cur_va > stage_best["val_acc"]) or (
                    math.isclose(cur_va, stage_best["val_acc"], rel_tol=0.0, abs_tol=1e-12) and cur_vl < stage_best["val_loss"]
                )
                if better:
                    stage_best.update({"val_acc": cur_va, "val_loss": cur_vl, "value": cand, "summary": summary})
                results["trials"].append({
                    "trial_id": trial_id, "status": status, "config": cfg,
                    "best_epoch": summary["best_epoch"],
                    "best_val_acc": summary["best_val_acc"],
                    "best_val_loss": summary["best_val_loss"],
                    "ckpt_path": summary["ckpt_path"],
                })
            except RuntimeError as e:
                # Handle OOM gracefully for large batch sizes, etc.
                status = "failed_oom" if "out of memory" in str(e).lower() else "failed"
                print(f"[WARN] Trial {trial_id} failed: {e}")
                results["trials"].append({
                    "trial_id": trial_id, "status": status, "reason": str(e), "config": cfg
                })
            finally:
                # Free CUDA memory between trials
                try:
                    del model_t
                    torch.cuda.empty_cache()
                except Exception:
                    pass

        # Commit stage winner
        if stage_best["value"] is not None:
            best_cfg[stage_param] = stage_best["value"]
            best_of_stage[stage_param] = stage_best
        print(f"[GREEDY] Stage {stage_param} best: {stage_best['value']} "
              f"(val_acc={stage_best['val_acc']:.4f}, val_loss={stage_best['val_loss']:.4f})")

    results["best_config"] = best_cfg
    results["stage_best"] = best_of_stage

    os.makedirs("logs", exist_ok=True)
    with open("logs/greedy_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print("\n[GREEDY] Done.")
    print(f"[GREEDY] Final best config: {best_cfg}")
    print('[GREEDY] Files written: "logs/train_process.jsonl", "logs/greedy_results.json"')


RUN_GREEDY = True  # set to False to disable
if RUN_GREEDY:
    greedy_tune()
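
The cost of the greedy search follows directly from the candidate lists in `greedy_tune`: one training run per candidate per stage. The short calculation below reproduces those lists and compares the greedy budget with an exhaustive grid over the same values. It assumes every candidate is trained exactly once and counts failed trials as well.

```python
from math import prod


def ensure_included(seq, base):
    """Same behavior as the notebook helper: prepend base if it is missing."""
    return list(seq) if base in seq else [base] + list(seq)


space = {
    "lr": ensure_included([3e-5, 2e-4, 3e-4, 5e-4], 1e-4),
    "weight_decay": ensure_included([1e-6, 1e-5, 5e-5, 1e-4, 3e-4, 0], 0.0),
    "s": ensure_included([10, 16, 24, 40, 64], 30.0),
    "label_smoothing": ensure_included([0.05, 0.1, 0.15], 0.0),
    "batch_size": ensure_included([16, 48, 64], 32),
}

greedy_trials = sum(len(v) for v in space.values())  # one stage per parameter
grid_trials = prod(len(v) for v in space.values())   # every combination

print(greedy_trials, grid_trials)  # 25 2880
```

Note that `weight_decay` keeps six candidates rather than seven, because `0` already equals the baseline `0.0` and is therefore not prepended again.
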

## 3.1.4: Optuna-based Hyperparameter Optimization

This section uses Optuna for joint hyperparameter optimization, complementing the greedy search:

- Imports Optuna and `TPESampler` to sample hyperparameter combinations automatically from continuous and discrete spaces.

- `run_optuna_tuning`:

    - Defines a search space compatible with the greedy search (a log-scale range for the learning rate and the same discrete choices for the other parameters) and enqueues the baseline configuration as a trial for a fair comparison.

    - Rebuilds the model and data loaders for each trial, trains with the extended `train_model`, and logs each epoch's metrics to `logs/optuna_train_process.jsonl`.

    - Writes the best epoch, validation performance, and checkpoint path of each trial to `logs/optuna_results.json`.

- The log format matches that of the greedy tuner, so both runs can share visualization and post-processing code.

- The `RUN_OPTUNA` switch controls whether the search actually runs, which avoids an unintended long search in the notebook.
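
A log-scale range for the learning rate means that candidates are spread evenly in `log(lr)` rather than in `lr` itself. The sketch below illustrates that idea with plain random sampling from the standard library; it is not TPE and does not use Optuna, and `sample_log_uniform` is a hypothetical helper name.

```python
import math
import random


def sample_log_uniform(rng, low, high):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(1029)
low, high = 3e-5, 5e-4  # the learning-rate bounds used in the search space
draws = [sample_log_uniform(rng, low, high) for _ in range(10_000)]

# On a log scale the midpoint is the geometric mean, about 1.22e-4,
# so roughly half of the draws should fall below it.
geo_mid = math.sqrt(low * high)
share_below = sum(d < geo_mid for d in draws) / len(draws)
print(f"geometric midpoint = {geo_mid:.3e}, share below = {share_below:.2f}")
```

With a linear-scale range, only about 20% of the draws would fall below that midpoint, so the small learning rates near the baseline value of `1e-4` would be sampled far less often.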

In [None]:
# =============================================================================
# 3.1.4. Optuna-based hyperparameter optimization
# =============================================================================
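try:
    import optuna
    from optuna.samplers import TPESampler
except ImportError:
    # Optuna is optional; the `optuna is not None` checks below rely on this fallback.
    optuna = None
    TPESampler = None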

def run_optuna_tuning():
    """
    Joint HPO with Optuna (TPE). Logging schema matches the greedy version:
      - Per-epoch: logs/optuna_train_process.jsonl (JSON Lines)
        Keys: trial_id, epoch, train_loss, train_acc, val_loss, val_acc, lr, elapsed_sec,
               and for schema parity: param='optuna', value=None, config=dict
      - Per-trial summary: logs/optuna_results.json (JSON)
        Keys: trial_id, status, config, best_epoch, best_val_acc, best_val_loss, ckpt_path
    """
    assert optuna is not None, "Optuna not installed."

    # ---- (1) Search space with baseline included ----
    BASELINE = {"lr": 1e-4, "weight_decay": 0.0, "s": 30.0, "label_smoothing": 0.0, "batch_size": 32}
    SPACE = {
        "lr":               {"type": "log_float", "low": 3e-5, "high": 5e-4},
        "weight_decay":     {"type": "categorical", "choices": [0.0, 1e-6, 1e-5, 5e-5, 1e-4, 3e-4]},
        "s":                {"type": "categorical", "choices": [10, 16, 24, 30, 40, 64]},
        "label_smoothing":  {"type": "categorical", "choices": [0.0, 0.05, 0.1, 0.15]},
        "batch_size":       {"type": "categorical", "choices": [16, 32, 48, 64]},
    }

    # ---- (2) Common knobs: keep identical to greedy search ----
    EPOCHS = 30
    PATIENCE = 5
    IMG_SIZE = 224
    NUM_WORKERS = 0
    PIN_MEMORY = False
    AUGMENT = True
    SEED = 1029

    # ---- (3) Loggers and results collector ----
    process_logger = JSONLLogger("logs/optuna_train_process.jsonl")
    os.makedirs("logs", exist_ok=True)
    results = {
        "baseline": BASELINE,
        "search_space": SPACE,
        "sampler": "TPE",
        "order": ["joint"],
        "trials": [],
    }

    # ---- (4) Objective definition ----
    def objective(trial: "optuna.trial.Trial") -> float: # type: ignore
        # Sample params (include baseline later via enqueue_trial)
        if SPACE["lr"]["type"] == "log_float":
            lr = trial.suggest_float("lr", SPACE["lr"]["low"], SPACE["lr"]["high"], log=True)
        else:
            lr = trial.suggest_float("lr", SPACE["lr"]["low"], SPACE["lr"]["high"])

        weight_decay = trial.suggest_categorical("weight_decay", SPACE["weight_decay"]["choices"])
        s = trial.suggest_categorical("s", SPACE["s"]["choices"])
        label_smoothing = trial.suggest_categorical("label_smoothing", SPACE["label_smoothing"]["choices"])
        batch_size = trial.suggest_categorical("batch_size", SPACE["batch_size"]["choices"])

        cfg = {
            "lr": float(lr),
            "weight_decay": float(weight_decay),
            "s": float(s),
            "label_smoothing": float(label_smoothing),
            "batch_size": int(batch_size),
        }

        # Data loaders (deterministic build)
        try:
            train_loader, val_loader, _ = get_dataloaders(
                batch_size=cfg["batch_size"], img_size=IMG_SIZE,
                num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                augment=AUGMENT, seed=SEED
            )
        except RuntimeError as e:
            # Worker init or other loader errors; mark failed trial
            results["trials"].append({
                "trial_id": f"optuna_trial_{trial.number}",
                "status": "failed",
                "reason": str(e),
                "config": cfg
            })
            return 0.0

        # Fresh model per trial; reuse baseline safe-load/transplant policy
        model_t = build_model_with_safe_load()
        set_cosine_scale(model_t, cfg["s"])

        ckpt_dir = os.path.join("ckpt", "optuna", f"trial_{trial.number}")
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = os.path.join(ckpt_dir, "best.pth")

        trial_id = f"optuna_trial_{trial.number}"

        # Per-epoch logger to match schema; no mid-epoch pruning to keep training identical
        def _epoch_cb(row: Dict[str, Any]):
            # Keep keys consistent with greedy logger
            process_logger.log({
                **row,
                "param": "optuna",
                "value": None,
                "config": cfg,
            })
            # Optional: report to Optuna for visualization; no pruning to keep parity
            try:
                trial.report(float(row["val_acc"]), step=int(row["epoch"]))
            except Exception:
                pass

        status = "ok"
        try:
            summary = train_model(
                model_t,
                train_loader,
                val_loader,
                epochs=EPOCHS,
                lr=cfg["lr"],
                patience=PATIENCE,
                ckpt_path=ckpt_path,
                weight_decay=cfg["weight_decay"],
                label_smoothing=cfg["label_smoothing"],
                epoch_log_cb=_epoch_cb,
                trial_id=trial_id,
            )
            results["trials"].append({
                "trial_id": trial_id,
                "status": status,
                "config": cfg,
                "best_epoch": summary["best_epoch"],
                "best_val_acc": float(summary["best_val_acc"]),
                "best_val_loss": float(summary["best_val_loss"]),
                "ckpt_path": summary["ckpt_path"],
            })
            return float(summary["best_val_acc"])
        except RuntimeError as e:
            status = "failed_oom" if "out of memory" in str(e).lower() else "failed"
            print(f"[WARN] Trial {trial_id} failed: {e}")
            results["trials"].append({
                "trial_id": trial_id,
                "status": status,
                "reason": str(e),
                "config": cfg
            })
            return 0.0
        finally:
            try:
                del model_t
                torch.cuda.empty_cache()
            except Exception:
                pass

    # ---- (5) Create study and run ----
    study = optuna.create_study(direction="maximize", sampler=TPESampler(seed=SEED))
    # Ensure baseline is evaluated
    study.enqueue_trial(BASELINE)

    OPTUNA_TRIALS = 16  # adjust as needed; includes the enqueued baseline
    study.optimize(objective, n_trials=OPTUNA_TRIALS, n_jobs=1)

    # ---- (6) Persist study-level results with same schema keys ----
    results.update({
        "study_best_value": float(study.best_value) if study.best_trial is not None else None,
        "best_config": dict(study.best_params) if study.best_trial is not None else None,
        "n_trials": len(study.trials),
    })
    with open("logs/optuna_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print("\n[OPTUNA] Done.")
    if study.best_trial is not None:
        print(f"[OPTUNA] Best val_acc={study.best_value:.4f} with params={study.best_params}")
    print('[OPTUNA] Files written: "logs/optuna_train_process.jsonl", "logs/optuna_results.json"')

# Entry point switch
RUN_OPTUNA = True  # set to False to disable Optuna run
if optuna is not None and RUN_OPTUNA:
    run_optuna_tuning()



## 3.1.5: Display best configurations and compare on test set

This section summarizes and compares the final performance of the baseline, greedy, and Optuna configurations on the TEST set:

- Use `_safe_load_json` to read `logs/greedy_results.json` and `logs/optuna_results.json` to obtain the best configuration and summary information for each trial;

- Use `_cfg_equal` and `_find_trial_ckpt` to find checkpoints matching the target configuration in the results. If a strict match fails, fall back to the trial with the highest validation accuracy;

- `_load_model_from_ckpt` reconstructs the model based on the configuration, sets the scaling factor `s` of the `CosineClassifier`, and loads the corresponding checkpoint weights;

- `_build_test_loader` reuses the construction logic of Flowers102, constructing a TEST DataLoader according to the given batch size, maintaining the same preprocessing strategy as the training/parameter tuning phase;

- `_evaluate_on_test` computes accuracy together with macro- and weighted-average precision, recall, and F1 for a quantitative comparison of model performance;

- `compare_best_and_baseline` evaluates baseline, greedy_best, and optuna_best in turn, prints the key results, and writes the full comparison to `logs/final_compare.json`;

- `RUN_FINAL_COMPARE` is a switch that controls whether this final comparison runs automatically when the script or notebook is executed.
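
The checkpoint lookup has two steps: an exact, float-tolerant configuration match, then a fallback to the successful trial with the highest validation accuracy. The condensed sketch below shows both paths on synthetic trial records. The function names, checkpoint paths, and accuracies are made up for illustration; the logic mirrors `_cfg_equal` and `_find_trial_ckpt` in simplified form.

```python
def cfg_equal(a, b, eps=1e-8):
    """Compare configs key by key, treating 24 and 24.0 as equal."""
    if a.keys() != b.keys():
        return False
    for k in a:
        va, vb = a[k], b[k]
        if isinstance(va, float) or isinstance(vb, float):
            if abs(float(va) - float(vb)) > eps:
                return False
        elif va != vb:
            return False
    return True


def find_ckpt(trials, target_cfg):
    """Exact match among successful trials, else the best by best_val_acc."""
    ok = [t for t in trials if t.get("status") == "ok" and t.get("config")]
    for t in ok:
        if cfg_equal(t["config"], target_cfg):
            return t["ckpt_path"]
    if not ok:
        return None
    return max(ok, key=lambda t: float(t.get("best_val_acc", 0.0)))["ckpt_path"]


# Synthetic trial summaries (paths and accuracies are made up).
trials = [
    {"status": "ok", "config": {"lr": 1e-4, "s": 30.0, "batch_size": 32},
     "best_val_acc": 0.95, "ckpt_path": "ckpt/a/best.pth"},
    {"status": "failed_oom", "config": {"lr": 2e-4, "s": 24.0, "batch_size": 64}},
    {"status": "ok", "config": {"lr": 2e-4, "s": 24.0, "batch_size": 48},
     "best_val_acc": 0.93, "ckpt_path": "ckpt/b/best.pth"},
]

exact = find_ckpt(trials, {"lr": 2e-4, "s": 24, "batch_size": 48})
fallback = find_ckpt(trials, {"lr": 5e-4, "s": 64, "batch_size": 16})
print(exact, fallback)  # ckpt/b/best.pth ckpt/a/best.pth
```

The float tolerance matters because a best configuration can carry `s` as an integer while the logged trial configurations store it as a float. The fallback means the returned checkpoint may belong to a different configuration than the one requested, so it is worth checking which path was taken before reporting results.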

In [None]:
# =============================================================================
# 3.1.5. Display best configs and compare on test set
# =============================================================================
# This block reads best configs from greedy/optuna result files, loads their checkpoints,
# evaluates on the test set, compares with baseline, and writes logs/final_compare.json.

from typing import Dict, Any, Optional
from copy import deepcopy

def _safe_load_json(path: str) -> Optional[Dict[str, Any]]:
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"[INFO] File not found: {path}")
        return None

def _cfg_equal(a: Dict[str, Any], b: Dict[str, Any], eps: float = 1e-8) -> bool:
    """Float-tolerant equality for configs."""
    if a.keys() != b.keys():
        return False
    for k in a:
        va, vb = a[k], b[k]
        if isinstance(va, float) or isinstance(vb, float):
            if not (abs(float(va) - float(vb)) <= eps):
                return False
        else:
            if va != vb:
                return False
    return True

def _find_trial_ckpt(results_json: Dict[str, Any], target_cfg: Dict[str, Any]) -> Optional[str]:
    """Find ckpt_path for the exact target config in a results JSON (greedy/optuna)."""
    trials = results_json.get("trials", [])
    for t in trials:
        cfg = t.get("config")
        if not cfg or t.get("status") not in ("ok",):
            continue
        if _cfg_equal(cfg, target_cfg):
            return t.get("ckpt_path")
    # Fallback: best across trials, if exact match not found
    best = None
    for t in trials:
        if t.get("status") != "ok":
            continue
        va = float(t.get("best_val_acc", 0.0))
        if best is None or va > best[0]:
            best = (va, t.get("ckpt_path"))
    return best[1] if best else None

def _load_model_from_ckpt(cfg: Dict[str, Any], ckpt_path: str):
    """Rebuild model, set cosine scale, and load weights."""
    m = build_model_with_safe_load()
    set_cosine_scale(m, float(cfg.get("s", 30.0)))
    assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
    state = torch.load(ckpt_path, map_location=device)
    m.load_state_dict(state, strict=True)
    return m.to(device)

def _evaluate_on_test(model, test_loader) -> Dict[str, Any]:
    """Evaluate on test set and return accuracy + macro/weighted averages."""
    from sklearn.metrics import classification_report, accuracy_score
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc="Testing(best)"):
            imgs = imgs.to(device, non_blocking=True)
            logits = model(imgs)
            preds = logits.argmax(1).detach().cpu().numpy().tolist()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy().tolist())
    acc = float(accuracy_score(all_labels, all_preds))
    rep = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
    return {
        "accuracy": acc,
        "macro_avg": {
            "precision": float(rep.get("macro avg", {}).get("precision", 0.0)),
            "recall": float(rep.get("macro avg", {}).get("recall", 0.0)),
            "f1": float(rep.get("macro avg", {}).get("f1-score", 0.0)),
        },
        "weighted_avg": {
            "precision": float(rep.get("weighted avg", {}).get("precision", 0.0)),
            "recall": float(rep.get("weighted avg", {}).get("recall", 0.0)),
            "f1": float(rep.get("weighted avg", {}).get("f1-score", 0.0)),
        },
    }

def _build_test_loader(batch_size: int):
    """Reuse the deterministic loader builder to get the test loader."""
    _, _, test_loader = get_dataloaders(
        batch_size=batch_size,
        img_size=224,
        num_workers=0,
        pin_memory=False,
        augment=True,   # keep identical preprocessing policy
        seed=1029
    )
    return test_loader

def compare_best_and_baseline():
    os.makedirs("logs", exist_ok=True)
    out = {"baseline": None, "greedy_best": None, "optuna_best": None}

    # ---- Baseline config and checkpoint ----
    baseline_cfg = {"lr": 1e-4, "weight_decay": 0.0, "s": 30.0, "label_smoothing": 0.0, "batch_size": 32}
    baseline_ckpt = "ckpt/best_cosine_cuda.pth"  # saved earlier by the baseline training run
    if os.path.exists(baseline_ckpt):
        try:
            test_loader = _build_test_loader(baseline_cfg["batch_size"])
            # Reuse the shared loader so the baseline model is also moved to `device`,
            # matching how the greedy/Optuna checkpoints are loaded.
            model_b = _load_model_from_ckpt(baseline_cfg, baseline_ckpt)
            metrics_b = _evaluate_on_test(model_b, test_loader)
            out["baseline"] = {
                "config": deepcopy(baseline_cfg),
                "ckpt_path": baseline_ckpt,
                "test_metrics": metrics_b,
            }
            print(f"[COMPARE] Baseline test accuracy: {metrics_b['accuracy']:.4f}")
        except Exception as e:
            print(f"[WARN] Baseline evaluation failed: {e}")
    else:
        print(f"[INFO] Baseline checkpoint not found: {baseline_ckpt}")

    # ---- Greedy best ----
    greedy_json = _safe_load_json("logs/greedy_results.json")
    if greedy_json and "best_config" in greedy_json:
        g_cfg = greedy_json["best_config"]
        g_ckpt = _find_trial_ckpt(greedy_json, g_cfg)
        if g_ckpt and os.path.exists(g_ckpt):
            try:
                test_loader = _build_test_loader(int(g_cfg.get("batch_size", 32)))
                model_g = _load_model_from_ckpt(g_cfg, g_ckpt)
                metrics_g = _evaluate_on_test(model_g, test_loader)
                out["greedy_best"] = {
                    "config": deepcopy(g_cfg),
                    "ckpt_path": g_ckpt,
                    "test_metrics": metrics_g,
                }
                print(f"[COMPARE] Greedy-best test accuracy: {metrics_g['accuracy']:.4f} | cfg={g_cfg}")
            except Exception as e:
                print(f"[WARN] Greedy-best evaluation failed: {e}")
        else:
            print("[INFO] Greedy-best checkpoint not found or missing exact match.")
    else:
        print("[INFO] Greedy results JSON not found or missing best_config.")

    # ---- Optuna best ----
    optuna_json = _safe_load_json("logs/optuna_results.json")
    if optuna_json and "best_config" in optuna_json:
        o_cfg = optuna_json["best_config"]
        # Ensure full schema with defaults in case sampler omitted some keys
        for k, v in {"weight_decay": 0.0, "s": 30.0, "label_smoothing": 0.0, "batch_size": 32}.items():
            o_cfg.setdefault(k, v)
        o_ckpt = _find_trial_ckpt(optuna_json, o_cfg)
        if o_ckpt and os.path.exists(o_ckpt):
            try:
                test_loader = _build_test_loader(int(o_cfg.get("batch_size", 32)))
                model_o = _load_model_from_ckpt(o_cfg, o_ckpt)
                metrics_o = _evaluate_on_test(model_o, test_loader)
                out["optuna_best"] = {
                    "config": deepcopy(o_cfg),
                    "ckpt_path": o_ckpt,
                    "test_metrics": metrics_o,
                }
                print(f"[COMPARE] Optuna-best test accuracy: {metrics_o['accuracy']:.4f} | cfg={o_cfg}")
            except Exception as e:
                print(f"[WARN] Optuna-best evaluation failed: {e}")
        else:
            print("[INFO] Optuna-best checkpoint not found or missing exact match.")
    else:
        print("[INFO] Optuna results JSON not found or missing best_config.")

    # ---- Write comparison file ----
    with open("logs/final_compare.json", "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)
    print('[COMPARE] Written "logs/final_compare.json".')

# Entry point switch
RUN_FINAL_COMPARE = True  # set to False to skip the final comparison
if RUN_FINAL_COMPARE:
    compare_best_and_baseline()

