In [1]:
# -------------------------
# 0) CONFIG (edit these)
# -------------------------
import os, sys


# Project code (must contain the csiro package)
CSIRO_CODE_DIR = "/notebooks/CSIRO"

# DINOv3 repo/code dir (must contain hubconf.py)
DINO_REPO = "/notebooks/dinov3"

# DINOv3 pretrained backbone weights (.pth)
DINO_WEIGHTS = "/notebooks/dinov3/weights/dinov3_vitb16_pretrain.pth"

# Ensemble checkpoint produced by CV (dict with key 'states')
WEIGHTS_PATH = "/notebooks/kaggle/csiro/weights/cv5_v1_f0a.pt"

# Competition data
COMP_ROOT = "/notebooks/kaggle/csiro/"
# os.path.join avoids the "//" an f-string produces with the trailing "/" above.
TEST_CSV = os.path.join(COMP_ROOT, "test.csv")
# IMPORTANT: test.csv image_path values look like "test/IDxxxx.jpg", so IMAGE_ROOT should be COMP_ROOT
IMAGE_ROOT = COMP_ROOT

# Inference params
IMG_SIZE = 512
BATCH_SIZE = 64
NUM_WORKERS = 2
DEVICE = "cuda"  # or "cpu"

# TTA / ensemble knobs
TTA_ROT90 = True
TTA_AGG = "mean"
ENS_AGG = "mean"

OUTPUT_PATH = "/notebooks/kaggle/csiro/sub/submission.csv"

# --- Env vars expected by csiro.config (no defaults) ---
# Must be absolute: a relative TORCH_HOME resolves against the kernel's CWD, so
# torch.hub's cache (used by the symlink cell) would end up in an unexpected place.
os.environ["TORCH_HOME"] = "/notebooks/kaggle/working/torch_home"
# The DINO_* values are pre-signed download URLs, i.e. credentials, and must not
# be committed in the notebook. Export them in the environment before starting
# the kernel; setdefault keeps any exported value and otherwise only defines the
# variable. This notebook passes the local DINO_WEIGHTS file to torch.hub.load,
# so the URLs are presumably not needed for inference (TODO: confirm against csiro.config).
os.environ.setdefault("DINO_WB", "")
os.environ.setdefault("DINO_WL", "")
os.environ.setdefault("DINO_WL_plus", "")
os.environ["DEFAULT_DINO_REPO_DIR"] = DINO_REPO
os.environ["DEFAULT_DATA_ROOT"] = COMP_ROOT

# Guarded so re-running this cell does not stack duplicate sys.path entries.
if CSIRO_CODE_DIR not in sys.path:
    sys.path.insert(0, CSIRO_CODE_DIR)


In [2]:
# -------------------------
# 1) Imports
# -------------------------
# Third-party
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset

# Project code (importable because the config cell put CSIRO_CODE_DIR on sys.path)
import csiro
from csiro.config import DEFAULT_MODEL_SIZE, DEFAULT_PLUS, TARGETS, dino_hub_name
from csiro.transforms import PadToSquare, post_tfms
from csiro.train import predict_ensemble

print(f"TARGETS: {TARGETS}")

TARGETS: ['Dry_Green_g', 'Dry_Clover_g', 'Dry_Dead_g', 'GDM_g', 'Dry_Total_g']


In [3]:
from pathlib import Path

# torch.hub looks for cached checkpoints under <hub_dir>/checkpoints, where
# hub_dir is derived from TORCH_HOME (set in the config cell).
ckpt_dir = Path(torch.hub.get_dir()) / "checkpoints"
ckpt_dir.mkdir(parents=True, exist_ok=True)

# Link the local DINOv3 weight file into the hub cache instead of copying it.
src = Path(DINO_WEIGHTS)
dst = ckpt_dir / src.name

# Without this check, a missing weights file would silently produce a dangling link.
if not src.is_file():
    raise FileNotFoundError(f"DINOv3 weights not found: {src}")

# Path.exists() follows symlinks, so a dangling link left by an earlier run reports
# False and symlink_to() would then fail with FileExistsError. Drop it first.
if dst.is_symlink() and not dst.exists():
    dst.unlink()
if not dst.exists():
    # Absolute target, so the link stays valid regardless of the CWD.
    dst.symlink_to(src.resolve())

In [4]:
# -------------------------
# 2) Load checkpoint + backbone
# -------------------------
# SECURITY: weights_only=False unpickles arbitrary Python objects and can execute
# code on load. Only point WEIGHTS_PATH at checkpoints you produced yourself.
ckpt = torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=False)
# The CV run saves a dict with key "states"; any other object is used as-is.
if isinstance(ckpt, dict) and "states" in ckpt:
    states = ckpt["states"]
else:
    states = ckpt
if isinstance(states, (list, tuple)) and len(states) == 0:
    raise ValueError(f"Checkpoint {WEIGHTS_PATH} contains no model states")

# Make the DINOv3 repo importable. The guard stops re-runs from adding duplicate entries.
if DINO_REPO not in sys.path:
    sys.path.insert(0, DINO_REPO)
# NOTE(review): str(DEFAULT_PLUS) turns a bool False into the truthy string "False".
# Confirm that dino_hub_name parses this string rather than testing its truthiness.
backbone = torch.hub.load(
    DINO_REPO,
    dino_hub_name(model_size=str(DEFAULT_MODEL_SIZE), plus=str(DEFAULT_PLUS)),
    source="local",
    weights=DINO_WEIGHTS,
)

print("Loaded states", type(states))


Loaded states <class 'list'>


In [5]:
# -------------------------
# 3) Read test.csv (long format)
# -------------------------
df = pd.read_csv(TEST_CSV)
print("test.csv columns:", list(df.columns))
print(df.head(3))

IMAGE_PATH_COL = "image_path"
TARGET_NAME_COL = "target_name"
SAMPLE_ID_COL = "sample_id"  # optional: rebuilt from image_path + target_name if absent

for required_col in (IMAGE_PATH_COL, TARGET_NAME_COL):
    if required_col not in df.columns:
        raise ValueError(f"Expected column {required_col} in test.csv")

# Fail fast, before the expensive inference step, on target names the model does
# not predict. Otherwise the submission step would only reject them after inference.
unknown_targets = sorted(set(df[TARGET_NAME_COL].astype(str)) - set(TARGETS))
if unknown_targets:
    raise ValueError(f"Unknown target_name(s) in test.csv: {unknown_targets}")

# One row per image: inference runs once per image and yields every target for it.
df_img = df.drop_duplicates(subset=[IMAGE_PATH_COL]).reset_index(drop=True)
print("rows (long):", len(df), "unique images:", len(df_img))


test.csv columns: ['sample_id', 'image_path', 'target_name']
                    sample_id             image_path   target_name
0  ID1001187975__Dry_Clover_g  test/ID1001187975.jpg  Dry_Clover_g
1    ID1001187975__Dry_Dead_g  test/ID1001187975.jpg    Dry_Dead_g
2   ID1001187975__Dry_Green_g  test/ID1001187975.jpg   Dry_Green_g
rows (long): 5 unique images: 1


In [6]:
# -------------------------
# 4) Dataset + inference
# -------------------------
class TestDataset(Dataset):
    """Image-only dataset for test-time inference.

    Each item is the preprocessed image for one row of ``df``. No targets are
    returned.

    Args:
        df: Frame with one row per image. It is re-indexed to 0..n-1 internally.
        root: Directory joined in front of each image path. If falsy, the
            path from ``df`` is used unchanged.
        img_col: Name of the column in ``df`` that holds the image paths.
        img_size: Side length of the square ``T.Resize`` output.
    """

    def __init__(self, df, root, img_col, img_size):
        self.df = df.reset_index(drop=True)
        self.root = root
        self.img_col = img_col
        side = int(img_size)
        # RGB -> PadToSquare (project transform, fill=0) -> square resize -> post_tfms()
        self.pre = T.Compose(
            [
                T.Lambda(lambda image: image.convert("RGB")),
                PadToSquare(fill=0),
                T.Resize((side, side), antialias=True),
                post_tfms(),
            ]
        )

    def __len__(self):
        return self.df.shape[0]

    def _path(self, idx):
        # Resolve the on-disk path of row ``idx`` (relative to ``root`` when one is set).
        rel_path = self.df.loc[idx, self.img_col]
        if not self.root:
            return rel_path
        return os.path.join(self.root, rel_path)

    def __getitem__(self, i):
        # Transform inside the ``with`` so the file is still open while pixels are read.
        with Image.open(self._path(i)) as image:
            return self.pre(image)

# If CUDA was requested but is unavailable, fall back to CPU (slow but runnable)
# instead of failing inside predict_ensemble.
device = DEVICE
if str(device).startswith("cuda") and not torch.cuda.is_available():
    print(f"CUDA unavailable; running inference on CPU instead of {DEVICE!r}")
    device = "cpu"

ds = TestDataset(df_img, IMAGE_ROOT, IMAGE_PATH_COL, IMG_SIZE)
preds = predict_ensemble(
    ds,
    states=states,
    backbone=backbone,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    device=device,
    tta_rot90=TTA_ROT90,
    tta_agg=TTA_AGG,
    ens_agg=ENS_AGG,
)

preds = preds.detach().cpu()
# The submission step indexes preds by df_img row and by position in TARGETS,
# so make that alignment an explicit invariant here.
expected_shape = (len(df_img), len(TARGETS))
if tuple(preds.shape) != expected_shape:
    raise ValueError(
        f"predict_ensemble returned shape {tuple(preds.shape)}, expected {expected_shape}"
    )
print("preds shape:", tuple(preds.shape))


preds shape: (1, 5)


In [7]:
# -------------------------
# 5) Build submission (long format)
# -------------------------
target_to_idx = {t: i for i, t in enumerate(TARGETS)}
preds_np = preds.numpy()

# Row i of preds belongs to row i of df_img (the dataset was built from df_img in order).
path_to_row = {p: i for i, p in enumerate(df_img[IMAGE_PATH_COL])}

if SAMPLE_ID_COL in df.columns:
    sample_ids = df[SAMPLE_ID_COL].astype(str)
else:
    # Rebuild "<image stem>__<target_name>", the sample_id format seen in test.csv.
    image_ids = df[IMAGE_PATH_COL].apply(lambda p: os.path.splitext(os.path.basename(p))[0])
    sample_ids = image_ids + "__" + df[TARGET_NAME_COL].astype(str)

unknown = ~df[TARGET_NAME_COL].isin(list(target_to_idx))
if unknown.any():
    raise ValueError(f"Unknown target_name: {df.loc[unknown, TARGET_NAME_COL].iloc[0]}")

# Vectorized lookup: (image row, target column) for every long-format row.
row_idx = df[IMAGE_PATH_COL].map(path_to_row).to_numpy(dtype=np.int64)
col_idx = df[TARGET_NAME_COL].map(target_to_idx).to_numpy(dtype=np.int64)
targets = preds_np[row_idx, col_idx].astype(np.float32)

sub = pd.DataFrame({
    "sample_id": sample_ids.values,
    "target": targets,
})

assert len(sub) == len(df)
assert list(sub.columns) == ["sample_id", "target"]
assert np.isfinite(sub["target"].values).all()

# Actually persist the submission. Create the output directory if it is missing.
os.makedirs(os.path.dirname(OUTPUT_PATH) or ".", exist_ok=True)
sub.to_csv(OUTPUT_PATH, index=False)
print("Wrote", OUTPUT_PATH)
print(sub.head(10))


Wrote /notebooks/kaggle/csiro/sub/submission.csv
                    sample_id     target
0  ID1001187975__Dry_Clover_g   0.422129
1    ID1001187975__Dry_Dead_g  26.266594
2   ID1001187975__Dry_Green_g  28.523952
3   ID1001187975__Dry_Total_g  59.529186
4         ID1001187975__GDM_g  27.927700
