# TTT Sweep: Setup & Imports

In [2]:
import os
import sys
import torch
import pandas as pd

# Put the local CSIRO checkout at the front of the import path so it wins
# over any other copy that may be importable.
CSIRO_CODE_DIR = "/notebooks/CSIRO"
sys.path.insert(0, CSIRO_CODE_DIR)

# Export repo / data / weight locations before the `csiro` imports that follow
# (presumably csiro.config reads these at import time -- TODO confirm).
# Note: the generic DINO_WEIGHTS_PATH and the "B" entry point at the same
# ViT-B/16 checkpoint; only the "L" entry differs.
os.environ.update(
    {
        "DEFAULT_DINO_REPO_DIR": "/notebooks/dinov3",
        "DEFAULT_DATA_ROOT": "/notebooks/kaggle/csiro",
        "DINO_WEIGHTS_PATH": "/notebooks/kaggle/csiro/weights/dinov3/dinov3_vitb16_pretrain.pth",
        "DINO_B_WEIGHTS_PATH": "/notebooks/kaggle/csiro/weights/dinov3/dinov3_vitb16_pretrain.pth",
        "DINO_L_WEIGHTS_PATH": "/notebooks/kaggle/csiro/weights/dinov3/dinov3_vitl16_pretrain.pth",
    }
)

from csiro.config import DEFAULTS, DEFAULT_DATA_ROOT, DEFAULT_DINO_REPO_DIR, DINO_WEIGHTS_PATH, dino_hub_name
from csiro.data import load_train_wide, BiomassTiledCached, TiledTransformView
from csiro.eval import ttt_sweep_cv
from csiro.transforms import post_tfms

# --- paths / config ---
# Locations taken from csiro.config, rebound to short notebook-local names.
DATA_ROOT = DEFAULT_DATA_ROOT
DINO_REPO = DEFAULT_DINO_REPO_DIR
DINO_WEIGHTS = DINO_WEIGHTS_PATH

# Training CSV and the checkpoint path(s) handed to the sweep as `pt_paths`.
TRAIN_CSV = "/notebooks/kaggle/csiro/train.csv"
PT_PATHS = ["/notebooks/kaggle/csiro/output/v7_rdrop001head_drop02_0f5.pt"]

# Cross-validation settings, forwarded to ttt_sweep_cv as `cv_params`.
# ("gkf" presumably selects a group k-fold split -- verify in csiro.eval.)
CV_PARAMS = {"mode": "gkf", "cv_seed": 126015, "n_splits": 5}

# Image size given to BiomassTiledCached: the project default when present,
# otherwise 512. CACHE_IMAGES is passed through as its `cache_images` flag.
IMG_SIZE = int(DEFAULTS.get("img_size", 512))
CACHE_IMAGES = True


# TTT Sweep: Load Data & Model

In [3]:
# --- data ---
# Training table -> cached tiled dataset -> transform view (tile swapping off).
wide_df = load_train_wide(TRAIN_CSV, root=DATA_ROOT)
base_ds = BiomassTiledCached(wide_df, img_size=IMG_SIZE, cache_images=CACHE_IMAGES)
dataset = TiledTransformView(base_ds, post_tfms(), tile_swap=False)

# --- backbone ---
# Load the backbone through torch.hub from the local repo checkout
# (source="local"). The hub entry name is derived from the project DEFAULTS,
# falling back to size "b" and an empty "plus" suffix when those keys are unset.
backbone = torch.hub.load(
    repo_or_dir=DINO_REPO,
    model=dino_hub_name(
        model_size=str(DEFAULTS.get("backbone_size", "b")),
        plus=str(DEFAULTS.get("plus", "")),
    ),
    source="local",
    weights=DINO_WEIGHTS,
)


# TTT Sweep: Define Tasks + Run

In [5]:
class RDropPredMSE(torch.nn.Module):
    """Example TTT task: MSE between two dropout passes on the final predictions.

    The task carries only a ``name`` label; all computation happens in
    ``forward``, which receives the model to adapt.
    """

    def __init__(self, name: str = "rdrop_pred_mse"):
        """Store the task's display name."""
        super().__init__()
        self.name = name

    def forward(self, model, x, ctx):
        """Return the mean squared difference between two forward passes.

        Args:
            model: Callable model exposing ``train()``; if it also exposes
                ``set_train``, that is called with ``True`` first.
            x: Input passed unchanged to ``model`` twice.
            ctx: Accepted for interface compatibility; not used here.

        Returns:
            Scalar tensor: mean of the squared element-wise difference of the
            two predictions, each cast with ``.float()`` beforehand.
        """
        # Switch to training mode before predicting. The mode is left on when
        # this method returns.
        if hasattr(model, "set_train"):
            model.set_train(True)
        model.train()

        # Same input, two independent passes (evaluated left to right).
        first_pass, second_pass = model(x), model(x)
        gap = first_pass.float() - second_pass.float()
        return gap.pow(2).mean()


# Task objects that sweeps can refer to by index.
TASKS = [RDropPredMSE()]

# Param specs: strings can be module names ("head", "neck", "norm") or exact parameter names.
PARAM_SPECS = [["head"]]

# One entry per sweep configuration.
# NOTE(review): this entry sets batch_size=1 while the ttt_sweep_cv call below
# passes batch_size=32 -- presumably one is the per-step adaptation batch and
# the other the evaluation batch; confirm against csiro.eval.
SWEEPS = [
    {
        "name": "rdrop_head",
        "task": TASKS[0],
        "params": PARAM_SPECS[0],
        "steps": 4,
        "lr": 1e-3,
        "beta": 0.0,
        "batch_size": 1,
    },
]

# Run every sweep over the CV folds for the listed checkpoints. Worker count
# and device come from the project DEFAULTS (4 and "cuda" when unset).
results = ttt_sweep_cv(
    dataset=dataset,
    wide_df=wide_df,
    backbone=backbone,
    pt_paths=PT_PATHS,
    cv_params=CV_PARAMS,
    sweeps=SWEEPS,
    batch_size=32,
    num_workers=DEFAULTS.get("num_workers", 4),
    device=DEFAULTS.get("device", "cuda"),
    inner_agg="mean",
    outer_agg="mean",
)

# Bare last expression so the notebook renders the result.
results

TTT sweeps:   0%|          | 0/1 [00:00<?, ?it/s]

rdrop_head folds:   0%|          | 0/5 [00:00<?, ?it/s]

[{'name': 'rdrop_head',
  'steps': 4,
  'lr': 0.001,
  'beta': 0.0,
  'batch_size': 1,
  'inner_agg': 'mean',
  'outer_agg': 'mean',
  'fold_base': [0.7874088287353516,
   0.7932122945785522,
   0.7990223169326782,
   0.7873789668083191,
   0.6804995536804199],
  'fold_ttt': [0.7874916195869446,
   0.7926286458969116,
   0.798236072063446,
   0.7873626351356506,
   0.6795005798339844],
  'fold_delta': [8.279085159301758e-05,
   -0.000583648681640625,
   -0.0007862448692321777,
   -1.633167266845703e-05,
   -0.0009989738464355469],
  'mean_base': 0.7695043921470642,
  'mean_ttt': 0.7690439105033875,
  'mean_delta': -0.00046048164367675783}]