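# TTT (Test-Time Training) Sweep: Setup & Imports

Evaluates the test-time-training configurations listed in `SWEEPS` on cross-validation folds for the checkpoints listed in `PT_PATHS`.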

In [None]:
import os
import torch
import pandas as pd

from csiro.config import DEFAULTS, DEFAULT_DATA_ROOT, DEFAULT_DINO_REPO_DIR, DINO_WEIGHTS_PATH, dino_hub_name
from csiro.data import load_train_wide, BiomassTiledCached, TiledTTADataset
from csiro.eval import ttt_sweep_cv

# TTT Sweep: Load Data & Model

In [None]:
# --- paths / config ---
# TRAIN_CSV=None is passed straight to load_train_wide together with DATA_ROOT;
# set an explicit path (e.g. "/notebooks/kaggle/csiro/train.csv") to choose the CSV.
TRAIN_CSV = None
DATA_ROOT = DEFAULT_DATA_ROOT
DINO_REPO = DEFAULT_DINO_REPO_DIR
DINO_WEIGHTS = DINO_WEIGHTS_PATH

# Checkpoints the sweep is evaluated on; fill these in before running.
PT_PATHS = [
    # "/notebooks/kaggle/csiro/output/run.pt",
]
CV_PARAMS = {"mode": "gkf", "cv_seed": 0, "n_splits": 5}

IMG_SIZE = int(DEFAULTS.get("img_size", 512))
CACHE_IMAGES = True

# Test-time augmentation used during the sweep.
# Keep these identical to the validation/inference pipeline so scores are comparable.
TTA_N = 4
TTA_BCS = 0.0
TTA_HUE = 0.0

# --- data ---
wide_df = load_train_wide(TRAIN_CSV, root=DATA_ROOT)
base_ds = BiomassTiledCached(wide_df, img_size=IMG_SIZE, cache_images=CACHE_IMAGES)

tta_kwargs = {
    "tta_n": int(TTA_N),
    "bcs_val": float(TTA_BCS),
    "hue_val": float(TTA_HUE),
    "apply_post_tfms": True,
}
dataset = TiledTTADataset(base_ds, **tta_kwargs)

# --- backbone ---
# Load the DINO backbone from a local hub checkout. The hub entry-point name is
# derived from the backbone size and "plus" settings in DEFAULTS.
backbone_size = str(DEFAULTS.get("backbone_size", "b"))
backbone_plus = str(DEFAULTS.get("plus", ""))
hub_model_name = dino_hub_name(model_size=backbone_size, plus=backbone_plus)
backbone = torch.hub.load(DINO_REPO, hub_model_name, source="local", weights=DINO_WEIGHTS)

# TTT Sweep: Define Tasks and Run the Sweep

In [None]:
class RDropPredMSE(torch.nn.Module):
    """Example TTT task: consistency loss between two dropout passes.

    The model is run twice on the same batch with dropout enabled (train mode).
    The loss is the mean squared difference between the two prediction tensors.

    Args:
        name: Identifier stored on the instance as ``self.name``.
    """

    def __init__(self, name: str = "rdrop_pred_mse"):
        super().__init__()
        self.name = name

    def forward(self, model, x, ctx):
        """Return the mean squared gap between two stochastic forward passes.

        Args:
            model: Module being evaluated. If it has a ``set_train`` attribute,
                ``model.set_train(True)`` is called before ``model.train()``.
            x: Input batch, passed unchanged to both forward calls.
            ctx: Not used by this task. It is accepted presumably because the
                sweep calls every task as ``task(model, x, ctx)``.

        Returns:
            Scalar tensor: the mean of the squared elementwise difference,
            with both predictions cast to float32 before subtracting.

        Note:
            ``model`` is left in training mode; its previous mode is not restored.
        """
        if hasattr(model, "set_train"):
            model.set_train(True)
        model.train()

        # Two separate calls so each pass samples its own dropout mask.
        first_pass, second_pass = [model(x) for _ in range(2)]
        gap = first_pass.float() - second_pass.float()
        return gap.pow(2).mean()


# Tasks are nn.Modules called as task(model, x, ctx) -> scalar loss
# (see RDropPredMSE.forward); each sweep below references one of them.
TASKS = [
    RDropPredMSE(),
]

# Param specs: strings can be module names ("head", "neck", "norm") or exact parameter names.
# Presumably these select the parameters updated during TTT -- confirm in ttt_sweep_cv.
PARAM_SPECS = [
    ["head"],
]

# One dict per sweep configuration; the list is passed to ttt_sweep_cv as `sweeps`.
# NOTE(review): the per-sweep batch_size (16) differs from the outer batch_size (32)
# passed below; check in ttt_sweep_cv which loader each one controls.
SWEEPS = [
    dict(
        name="rdrop_head",
        task=TASKS[0],
        params=PARAM_SPECS[0],
        steps=1,
        lr=1e-5,
        beta=0.0,
        batch_size=16,
    ),
]

# RDropPredMSE relies on dropout, so each TTT step is stochastic. Seed right before
# the sweep so re-running this cell reproduces the same results.
SWEEP_SEED = 0
torch.manual_seed(SWEEP_SEED)

# An explicit DEFAULTS["device"] still wins. Without one, use CUDA only when it is
# actually available, so the notebook does not hard-fail on CPU-only machines.
SWEEP_DEVICE = DEFAULTS.get("device", "cuda" if torch.cuda.is_available() else "cpu")

results = ttt_sweep_cv(
    dataset=dataset,
    wide_df=wide_df,
    backbone=backbone,
    pt_paths=PT_PATHS,
    cv_params=CV_PARAMS,
    sweeps=SWEEPS,
    batch_size=32,
    num_workers=DEFAULTS.get("num_workers", 4),
    device=SWEEP_DEVICE,
    inner_agg="mean",
    outer_agg="mean",
)

results