# TTT Sweep: Setup & Imports

In [1]:
import os
import sys
import torch
import pandas as pd

CSIRO_CODE_DIR = "/notebooks/CSIRO"
sys.path.insert(0, CSIRO_CODE_DIR)
os.environ["DEFAULT_DINO_REPO_DIR"]="/notebooks/dinov3"
os.environ["DEFAULT_DATA_ROOT"]="/notebooks/kaggle/csiro"
os.environ["DINO_WEIGHTS_PATH"] = "/notebooks/kaggle/csiro/weights/dinov3/dinov3_vitb16_pretrain.pth"
os.environ["DINO_B_WEIGHTS_PATH"] = "/notebooks/kaggle/csiro/weights/dinov3/dinov3_vitb16_pretrain.pth"
os.environ["DINO_L_WEIGHTS_PATH"] = "/notebooks/kaggle/csiro/weights/dinov3/dinov3_vitl16_pretrain.pth"

from csiro.config import DEFAULTS, DEFAULT_DATA_ROOT, DEFAULT_DINO_REPO_DIR, DINO_WEIGHTS_PATH, dino_hub_name
from csiro.data import load_train_wide, BiomassTiledCached, TiledTransformView
from csiro.eval import ttt_sweep_cv
from csiro.transforms import post_tfms

# --- paths / config ---
TRAIN_CSV = "/notebooks/kaggle/csiro/train.csv"  # e.g. "/notebooks/kaggle/csiro/train.csv"
DATA_ROOT = DEFAULT_DATA_ROOT
DINO_REPO = DEFAULT_DINO_REPO_DIR
DINO_WEIGHTS = DINO_WEIGHTS_PATH
PT_PATHS = [
     "/notebooks/kaggle/csiro/output/v7_n_models2_f2c.pt",
]
CV_PARAMS = dict(mode="gkf", cv_seed=126015, n_splits=5)

IMG_SIZE = int(DEFAULTS.get("img_size", 512))
CACHE_IMAGES = True


# TTT Sweep: Load Data & Model

In [2]:
# --- data ---
wide_df = load_train_wide(TRAIN_CSV, root=DATA_ROOT)
base_ds = BiomassTiledCached(wide_df, img_size=IMG_SIZE, cache_images=CACHE_IMAGES)
dataset = TiledTransformView(
    base_ds,
    post_tfms(),
    tile_swap=False,
)

# --- backbone ---
backbone = torch.hub.load(
    DINO_REPO,
    dino_hub_name(model_size=str(DEFAULTS.get("backbone_size", "b")), plus=str(DEFAULTS.get("plus", ""))),
    source="local",
    weights=DINO_WEIGHTS,
)


# TTT Sweep: Define Tasks + Run

In [None]:
import torchvision.transforms.functional as TF
from csiro.config import IMAGENET_MEAN, IMAGENET_STD

class RDropPredMSE(torch.nn.Module):
    # Example task: MSE between two dropout passes on final preds.
    def __init__(self, name: str = "rdrop_pred_mse"):
        super().__init__()
        self.name = name

    def forward(self, model, x, ctx):
        if hasattr(model, "set_train"):
            model.set_train(True)
        model.train()
        p1 = model(x)
        p2 = model(x)
        return ((p1.float() - p2.float()) ** 2).mean()


class JitterRotPredMSE(torch.nn.Module):
    # Color jitter + rot90 invariance on final preds.
    def __init__(self, bcs_val: float = 0.2, hue_val: float = 0.02, rot_k: int = 1, name: str = "jitter_rot_pred_mse"):
        super().__init__()
        self.name = name
        self.bcs_val = float(bcs_val)
        self.hue_val = float(hue_val)
        self.rot_k = int(rot_k)

    def _apply_jitter_batch(self, x: torch.Tensor) -> torch.Tensor:
        if self.bcs_val <= 0.0 and self.hue_val <= 0.0:
            return x
        if self.bcs_val > 0.0:
            b = float(self.bcs_val)
            brightness = float(1.0 + (torch.rand((), device=x.device) * 2.0 - 1.0) * b)
            contrast = float(1.0 + (torch.rand((), device=x.device) * 2.0 - 1.0) * b)
            saturation = float(1.0 + (torch.rand((), device=x.device) * 2.0 - 1.0) * b)
            x = TF.adjust_brightness(x, brightness)
            x = TF.adjust_contrast(x, contrast)
            x = TF.adjust_saturation(x, saturation)
        if self.hue_val > 0.0:
            hue = float((torch.rand((), device=x.device) * 2.0 - 1.0) * self.hue_val)
            x = TF.adjust_hue(x, hue)
        return x

    def forward(self, model, x, ctx):
        if hasattr(model, "set_train"):
            model.set_train(True)
        model.train()
        if x.ndim != 5 or x.size(1) != 2:
            raise ValueError(f"Expected tiled input [B,2,C,H,W], got {tuple(x.shape)}")

        x = x.float()
        mean = torch.tensor(IMAGENET_MEAN, device=x.device).view(1, 1, 3, 1, 1)
        std = torch.tensor(IMAGENET_STD, device=x.device).view(1, 1, 3, 1, 1)
        x_unn = (x * std + mean).clamp(0.0, 1.0)

        b, tiles, c, h, w = x_unn.shape
        x_flat = x_unn.view(b * tiles, c, h, w)
        x1 = self._apply_jitter_batch(x_flat)
        x2 = self._apply_jitter_batch(x_flat)
        x2 = torch.rot90(x2, k=self.rot_k, dims=(-2, -1))

        mean_f = mean.view(1, 3, 1, 1)
        std_f = std.view(1, 3, 1, 1)
        x1 = (x1 - mean_f) / std_f
        x2 = (x2 - mean_f) / std_f

        x1 = x1.view(b, tiles, c, h, w)
        x2 = x2.view(b, tiles, c, h, w)
        p1 = model(x1)
        p2 = model(x2)
        return ((p1.float() - p2.float()) ** 2).mean()


TASKS = [
    #RDropPredMSE(),
    JitterRotPredMSE(),
]

# Param specs: strings can be module names ("head", "neck", "norm") or exact parameter names.
PARAM_SPECS = [
    ["head"],
]

SWEEPS = [
    dict(
        name="rdrop_head",
        task=TASKS[0],
        params=PARAM_SPECS[0],
        steps=4,
        lr=1e-4,
        beta=0.0,
        batch_size=1,
    ),
]

results = ttt_sweep_cv(
    dataset=dataset,
    wide_df=wide_df,
    backbone=backbone,
    pt_paths=PT_PATHS,
    cv_params=CV_PARAMS,
    sweeps=SWEEPS,
    batch_size=32,
    num_workers=DEFAULTS.get("num_workers", 4),
    device=DEFAULTS.get("device", "cuda"),
    inner_agg="mean",
    outer_agg="mean",
)

results


TTT sweeps:   0%|          | 0/1 [00:00<?, ?it/s]

rdrop_head folds:   0%|          | 0/5 [00:00<?, ?it/s]

[{'name': 'rdrop_head',
  'steps': 4,
  'lr': 0.001,
  'beta': 0.0,
  'batch_size': 1,
  'inner_agg': 'mean',
  'outer_agg': 'mean',
  'fold_base': [0.7638000249862671,
   0.791303277015686,
   0.8119305372238159,
   0.7919535636901855,
   0.6741432547569275],
  'fold_ttt': [0.765116810798645,
   0.7909208536148071,
   0.8113319277763367,
   0.7917848229408264,
   0.6740978956222534],
  'fold_delta': [0.0013167858123779297,
   -0.00038242340087890625,
   -0.000598609447479248,
   -0.00016874074935913086,
   -4.5359134674072266e-05],
  'mean_base': 0.7666261315345764,
  'mean_ttt': 0.7666504621505738,
  'mean_delta': 2.4330615997314452e-05}]