# CSIRO Kaggle Submission Notebook

This notebook assumes offline submission (no internet). Edit the flags below to match your Kaggle inputs.


In [None]:
import os
import sys
import glob

INPUT_ROOT = "/kaggle/input"


def find_path(patterns, must_exist=True):
    """Return the first path under INPUT_ROOT matching any glob in *patterns*.

    Patterns are tried in the order given; the first pattern that matches
    anything wins. Within a pattern the matches are sorted, so the result is
    deterministic (``glob.glob`` itself returns matches in arbitrary order).

    Args:
        patterns: A glob pattern or an iterable of glob patterns, each joined
            onto INPUT_ROOT.
        must_exist: When True, raise if no pattern matches; when False,
            return None instead.

    Returns:
        The matching path as a string, or None when nothing matched and
        ``must_exist`` is False.

    Raises:
        FileNotFoundError: Nothing matched and ``must_exist`` is True.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(patterns, str):
        patterns = [patterns]
    for pat in patterns:
        hits = sorted(glob.glob(os.path.join(INPUT_ROOT, pat)))
        if hits:
            return hits[0]
    if must_exist:
        raise FileNotFoundError(f"No matches for {patterns} under {INPUT_ROOT}")
    return None

# Probe the attached inputs for the project code and the DINOv3 repo.
# With must_exist=False a miss yields None instead of raising.
DETECTED_CSIRO_INPUT = find_path(
    ["*csiro*code*", "*csiro*repo*", "*biomass*code*", "*vikstrand*ai*", "*csiro*"],
    must_exist=False,
)
DETECTED_DINO_REPO = find_path(
    ["*dinov3*", "*dino*v3*", "*dino*repo*"],
    must_exist=False,
)

# Echo what was found, plus up to 30 entries of the input root, so the flags
# in the next cell can be checked against what is actually attached.
print("DETECTED_CSIRO_INPUT:", DETECTED_CSIRO_INPUT)
print("DETECTED_DINO_REPO:", DETECTED_DINO_REPO)
print("INPUT DIRS:", os.listdir(INPUT_ROOT)[:30])


In [None]:
# Global flags (edit these).
# The "<dataset>" / "<weights-dataset>" paths are placeholders; the detection
# code further down in this cell overrides TEST_CSV, IMAGE_ROOT, WEIGHTS_PATH
# and DINO_WEIGHTS when it finds matching files under /kaggle/input.

# Directory prepended to sys.path so that `import csiro` resolves (may be None).
CSIRO_INPUT = DETECTED_CSIRO_INPUT
# Local DINOv3 repo; prepended to sys.path and given to torch.hub.load(source="local").
DINO_REPO = DETECTED_DINO_REPO or "/kaggle/input/dinov3"
# CSV read with pandas to obtain the list of test images.
TEST_CSV = "/kaggle/input/<dataset>/test.csv"
# Directory joined onto relative image paths from TEST_CSV; None/"" disables the join.
IMAGE_ROOT = "/kaggle/input/<dataset>/test_images"
# Column of TEST_CSV holding the image path (required).
IMAGE_PATH_COL = "image_path"
# Column holding the image id; derived from the file-name stem when absent.
IMAGE_ID_COL = "image_id"
# Checkpoint loaded with torch.load; either the states object itself or a dict with a "states" key.
WEIGHTS_PATH = "/kaggle/input/<weights-dataset>/ensemble_states.pt"
# Backbone weights file passed as `weights=` to torch.hub.load.
DINO_WEIGHTS = "/kaggle/input/dinov3/weights/dinov3_vitb16_pretrain.pth"
# MODEL_SIZE and PLUS are passed (as strings) to csiro.config.dino_hub_name to pick the hub entry.
MODEL_SIZE = "b"
PLUS = ""
# Side length of the square the padded image is resized to.
IMG_SIZE = 512
# The remaining inference settings are forwarded unchanged to predict_ensemble.
BATCH_SIZE = 64
NUM_WORKERS = 2
# Replaced by "cpu" later in the notebook when CUDA is unavailable.
DEVICE = "cuda"
TTA_ROT90 = True
TTA_AGG = "mean"
ENS_AGG = "mean"
# Destination of the submission CSV.
OUTPUT_PATH = "/kaggle/working/submission.csv"

# Treat an empty string as "no image root".
if IMAGE_ROOT == "":
    IMAGE_ROOT = None

# Locate the competition data folder; when it holds test.csv and/or a
# test_images directory, those replace the placeholder flags above.
DETECTED_COMP = find_path(
    ["*csiro*image2biomass*", "*image2biomass*", "*biomass*"],
    must_exist=False,
)
if DETECTED_COMP:
    cand_test_csv = os.path.join(DETECTED_COMP, "test.csv")
    cand_test_img = os.path.join(DETECTED_COMP, "test_images")
    # Each flag is overridden independently, only if its candidate exists.
    TEST_CSV = cand_test_csv if os.path.exists(cand_test_csv) else TEST_CSV
    IMAGE_ROOT = cand_test_img if os.path.exists(cand_test_img) else IMAGE_ROOT

# Locate ensemble_states.pt among the attached inputs.
# Every directory matching the patterns is probed, in pattern-priority order,
# rather than only the first match: an unrelated input whose name happens to
# contain e.g. "weights" must not hide the folder that really holds the
# ensemble checkpoint.
_WEIGHT_DIR_PATTERNS = ["*weights*", "*ensemble*", "*ckpt*", "*checkpoints*"]
DETECTED_W = find_path(_WEIGHT_DIR_PATTERNS, must_exist=False)
if DETECTED_W:
    _weight_dirs = []
    for _pat in _WEIGHT_DIR_PATTERNS:
        # Sorted so the choice is reproducible (glob order is arbitrary).
        for _cand in sorted(glob.glob(os.path.join(INPUT_ROOT, _pat))):
            if _cand not in _weight_dirs:
                _weight_dirs.append(_cand)
    for _cand in _weight_dirs:
        hits = sorted(
            glob.glob(os.path.join(_cand, "**", "ensemble_states.pt"), recursive=True)
        )
        if hits:
            # Record the directory that actually contains the checkpoint.
            DETECTED_W = _cand
            WEIGHTS_PATH = hits[0]
            break

# Fallback for the backbone weights: when the configured DINO_WEIGHTS file is
# unset or missing, search the DINO repo for *.pth files. A unique hit is
# adopted; several hits are only listed (first 10) so one can be chosen by hand.
if DINO_REPO is not None:
    if DINO_WEIGHTS is None or not os.path.exists(DINO_WEIGHTS):
        hits = glob.glob(os.path.join(DINO_REPO, "**", "*.pth"), recursive=True)
        if len(hits) == 1:
            DINO_WEIGHTS = hits[0]
        elif hits:
            print("Found DINO weights candidates:", hits[:10])


In [None]:
import os

# Fail fast, before any heavy imports or model loading, if a required input
# is missing. The checks run in a fixed order; the first failure raises.
if DINO_REPO is None:
    raise RuntimeError("DINO_REPO not found; attach the dinov3 dataset input.")

# These three must exist on disk; the offending path is the error payload.
for required_path in (DINO_REPO, TEST_CSV, WEIGHTS_PATH):
    if not os.path.exists(required_path):
        raise FileNotFoundError(required_path)

# DINO_WEIGHTS may still be None, hence the str() in the error.
if DINO_WEIGHTS is None or not os.path.exists(DINO_WEIGHTS):
    raise FileNotFoundError(str(DINO_WEIGHTS))

# IMAGE_ROOT is optional (None means image paths are used as given).
if IMAGE_ROOT is not None:
    if not os.path.exists(IMAGE_ROOT):
        raise FileNotFoundError(IMAGE_ROOT)


In [None]:
import os
import sys
import torch
import pandas as pd
import torchvision.transforms as T
from PIL import Image

# Put the attached code folders at the front of sys.path. DINO_REPO is
# inserted last and therefore ends up ahead of CSIRO_INPUT.
for extra_path in (CSIRO_INPUT, DINO_REPO):
    if extra_path is not None:
        sys.path.insert(0, extra_path)

# Import the project package, turning any failure into a hint about the
# missing dataset while keeping the original exception as the cause.
try:
    import csiro
    print("csiro import OK from:", csiro.__file__)
except Exception as import_error:
    hint = "Could not import csiro package. Attach the dataset containing the csiro/ folder."
    raise RuntimeError(hint) from import_error

from csiro.config import TARGETS, dino_hub_name
from csiro.transforms import PadToSquare, post_tfms
from csiro.train import predict_ensemble

# Fall back to CPU when a CUDA device was requested but none is available.
if DEVICE.startswith("cuda"):
    if not torch.cuda.is_available():
        DEVICE = "cpu"


In [None]:
# Load the ensemble checkpoint onto the CPU. The file may be either the
# states object itself or a dict that wraps it under the "states" key.
ckpt = torch.load(WEIGHTS_PATH, map_location="cpu")
if isinstance(ckpt, dict) and "states" in ckpt:
    states = ckpt["states"]
else:
    states = ckpt

# Build the backbone from the local DINO repo (no network access); the hub
# entry name is derived from MODEL_SIZE and PLUS.
hub_entry = dino_hub_name(model_size=str(MODEL_SIZE), plus=str(PLUS))
backbone = torch.hub.load(DINO_REPO, hub_entry, source="local", weights=DINO_WEIGHTS)


In [None]:
class TestDataset(torch.utils.data.Dataset):
    """Dataset yielding one preprocessed test image per row of a DataFrame.

    Each item is the image named in column ``img_col``, converted to RGB,
    padded to a square (fill 0), resized to ``img_size`` x ``img_size`` and
    finally passed through ``post_tfms()``. No label is returned.
    """

    def __init__(self, df, root, img_col, img_size):
        """Store the frame (re-indexed from 0) and build both transform stages.

        Args:
            df: DataFrame with one row per image.
            root: Directory joined onto relative paths; falsy to disable.
            img_col: Name of the column holding the image path.
            img_size: Side length of the resized square image.
        """
        side = int(img_size)
        pil_steps = [
            T.Lambda(lambda im: im.convert("RGB")),
            PadToSquare(fill=0),
            T.Resize((side, side), antialias=True),
        ]
        self.df = df.reset_index(drop=True)
        self.root = root
        self.img_col = img_col
        self.pre = T.Compose(pil_steps)
        self.post = post_tfms()

    def __len__(self):
        """Return the number of rows (images)."""
        return len(self.df)

    def __getitem__(self, i):
        """Load, preprocess and return the image for row ``i``."""
        path = self.df.loc[i, self.img_col]
        # Only relative paths are resolved against the root.
        if self.root and not os.path.isabs(path):
            path = os.path.join(self.root, path)
        with Image.open(path) as handle:
            # Copy so the result stays usable after the file is closed.
            prepared = self.pre(handle).copy()
        return self.post(prepared)

# Load the test listing and make sure the image-path column is present.
df = pd.read_csv(TEST_CSV)
if IMAGE_PATH_COL not in df.columns:
    raise KeyError(f"Missing column: {IMAGE_PATH_COL}")

# Keep one row per image. A listing with several rows per image (for example
# one row per image/target pair) would otherwise run the same image through
# the ensemble repeatedly and inflate the row counts checked later on.
# Only the first row of each image is kept; the other columns of the dropped
# rows are not needed for inference.
n_listed = len(df)
df = df.drop_duplicates(subset=IMAGE_PATH_COL).reset_index(drop=True)
if len(df) != n_listed:
    print(f"Collapsed {n_listed} rows to {len(df)} unique images")

# Derive the image id from the file-name stem when the CSV has no id column.
if IMAGE_ID_COL not in df.columns:
    df[IMAGE_ID_COL] = df[IMAGE_PATH_COL].apply(
        lambda p: os.path.splitext(os.path.basename(p))[0]
    )

print("test_df columns:", list(df.columns))
print(df.head(3))

# Run the ensemble (with optional TTA) over every test image; the result has
# one row per row of df.
dataset = TestDataset(df, IMAGE_ROOT, IMAGE_PATH_COL, IMG_SIZE)
preds = predict_ensemble(
    dataset,
    states,
    backbone,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    device=DEVICE,
    tta_rot90=TTA_ROT90,
    tta_agg=TTA_AGG,
    ens_agg=ENS_AGG,
)
print("preds grams min/max:", float(preds.min()), float(preds.max()))


In [None]:
import numpy as np

# Bring the predictions to host memory as float32 and floor them at zero.
preds_np = preds.detach().float().cpu().numpy()
preds_np = preds_np.clip(min=0)
assert preds_np.shape[0] == len(df), (preds_np.shape, len(df))
assert preds_np.shape[1] == len(TARGETS), (preds_np.shape, len(TARGETS))

# Map "<image_id>__<target>" -> prediction. Building it once serves both
# branches below, and the dict keys collapse repeated image ids so the
# fallback branch can never emit duplicate sample_id rows.
ids = df[IMAGE_ID_COL].astype(str).tolist()
pred_map = {
    f"{img_id}__{t_name}": float(preds_np[i, t_idx])
    for i, img_id in enumerate(ids)
    for t_idx, t_name in enumerate(TARGETS)
}

SAMPLE_SUB = os.path.join(os.path.dirname(TEST_CSV), "sample_submission.csv")
if os.path.exists(SAMPLE_SUB):
    # Preferred: fill the provided template so row order matches it exactly.
    sub = pd.read_csv(SAMPLE_SUB)
    sub["target"] = sub["sample_id"].map(pred_map).astype(float)
    if sub["target"].isna().any():
        raise ValueError("sample_submission mapping produced NaNs")
else:
    # No template: emit every (image, target) pair in first-seen order.
    sub = pd.DataFrame(list(pred_map.items()), columns=["sample_id", "target"])

sub.to_csv(OUTPUT_PATH, index=False)
print(sub.head())
print("rows:", len(sub))

# Re-read the written file and sanity-check it. The expected row count is one
# row per distinct (image id, target) pair, not a hard-coded multiple of the
# number of rows in df.
sub = pd.read_csv(OUTPUT_PATH)
assert list(sub.columns) == ["sample_id", "target"]
assert len(sub) == len(pred_map), (len(sub), len(pred_map))
assert np.isfinite(sub["target"].values).all()
assert not sub["target"].isna().any()
print(sub.head(10))
