# Kaggle CSIRO Biomass Submission

Self-contained inference notebook that mirrors the repository's two-stream DINO logic.

In [None]:
# ================================
# USER-EDITABLE CONFIG
# (Edit these paths only)
# ================================
# Long-format test CSV. The code below reads its "sample_id", "image_path"
# and "target_name" columns.
TEST_CSV_PATH = "/kaggle/input/csiro-biomass/test.csv"
# Directory that each CSV "image_path" value is joined onto when loading images.
# NOTE(review): if "image_path" already carries a "test/" prefix, this should
# point at the parent directory instead — confirm against the CSV.
TEST_IMAGE_DIR = "/kaggle/input/csiro-biomass/test"

# Checkpoint paths (Kaggle dataset paths). Used for any ensemble group whose
# own "ckpts" list is empty, and for the default group when ENSEMBLE_GROUPS
# is empty.
CHECKPOINTS = [
    # e.g. "/kaggle/input/my-csiro-weights/fold1.pth",
    #      "/kaggle/input/my-csiro-weights/fold2.pth",
    #      ...
]

# Optional: several independently weighted ensembles.
# Each group may also carry a scalar "weight" key (read with a default of 1.0)
# that scales the whole group when groups are blended.
ENSEMBLE_GROUPS = [
    {
        "name": "ens1",
        "weights": [1.0],  # keep one weight per checkpoint the group ends up using
        "ckpts": [
            # "/kaggle/input/my-csiro-weights/fold1.pth",
            # "/kaggle/input/my-csiro-weights/fold2.pth",
        ],
    },
    # More groups can be appended here.
]

# Output CSV path plus DataLoader settings used for every inference pass.
SUBMISSION_FILE = "submission.csv"
BATCH_SIZE = 1
NUM_WORKERS = 0

# (rows, cols) grid passed to the tiled model variants; should match the grid
# used during training (e.g. (2, 2)).
GRID = (2, 2)


In [None]:
# Optional: install packages if missing on local runs (Kaggle images already include these)
# !pip install -q timm albumentations

import os, gc, math, types
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Dict, Any
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm

# Use the GPU when CUDA is available, otherwise fall back to the CPU; report the choice.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)


In [None]:
# Dataset utilities --------------------------------------------------
class TestBiomassDataset(Dataset):
    """Serve every test image as a ``(left, right)`` pair of transformed halves.

    Each row's ``image_path`` is resolved relative to ``image_dir``, the image
    is read with OpenCV, converted from BGR to RGB, cut at the horizontal
    midpoint and both halves are passed through ``transform`` independently.
    """

    def __init__(self, df: pd.DataFrame, image_dir: str, transform: A.BasicTransform, input_res: int) -> None:
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform
        # Kept for reference only; nothing in this class reads it.
        self.input_res = input_res

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return the transformed left and right halves of image ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        record = self.df.iloc[idx]
        img_path = Path(self.image_dir) / record["image_path"]
        bgr = cv2.imread(str(img_path))
        if bgr is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # For an odd width the extra column ends up in the right half.
        _, width, _ = rgb.shape
        split_at = width // 2
        halves = (rgb[:, :split_at], rgb[:, split_at:])

        left, right = (self.transform(image=half)["image"] for half in halves)
        return left, right


In [None]:
# Model & backbone builder -------------------------------------------
def _infer_input_res(backbone_name: str) -> int:
    """Look up the default square input resolution of a timm backbone.

    Instantiates the model without pretrained weights and reads
    ``default_cfg["input_size"]``; the middle entry of a 3-element size is
    returned. Any failure, or a missing/odd-shaped size, yields 224.
    """
    default_res = 224
    try:
        probe = timm.create_model(backbone_name, pretrained=False)
        input_size = probe.default_cfg.get("input_size", None)
        if input_size is None or len(input_size) != 3:
            return default_res
        return int(input_size[1])
    except Exception:
        # Best effort: unknown backbones simply use the default resolution.
        return default_res


def _build_dino_by_name(backbone_name: str, input_res: int):
    """Create a headless timm backbone with average pooling.

    The model is built without pretrained weights and with ``num_classes=0``.
    When the backbone exposes ``patch_embed.img_size`` it is overwritten with
    ``(input_res, input_res)``; a failure to do so is ignored.
    """
    backbone = timm.create_model(backbone_name, pretrained=False, num_classes=0, global_pool="avg")

    patch_embed = getattr(backbone, "patch_embed", None)
    if hasattr(patch_embed, "img_size"):
        try:
            patch_embed.img_size = (input_res, input_res)
        except Exception:
            # Best effort only: keep whatever size the backbone was built with.
            pass
    return backbone


class FiLM(nn.Module):
    """Self-conditioned feature modulation: ``x * (1 + gamma(x)) + beta(x)``.

    ``gamma`` and ``beta`` are separate two-layer MLPs
    (``in_dim -> hidden -> in_dim`` with a GELU in between).
    """

    def __init__(self, in_dim: int, hidden: int = 512) -> None:
        super().__init__()

        def bottleneck() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(in_dim, hidden),
                nn.GELU(),
                nn.Linear(hidden, in_dim),
            )

        # gamma is created before beta so parameter ordering stays unchanged.
        self.gamma = bottleneck()
        self.beta = bottleneck()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = 1 + self.gamma(x)
        shift = self.beta(x)
        return x * scale + shift


class TwoStreamDINOBase(nn.Module):
    """Common parts of the two-stream models: one backbone and three heads.

    The heads (``head_total``, ``head_gdm``, ``head_green``) each take a vector
    of size ``2 * feature_dim`` and output a single value. Subclasses implement
    ``forward``.
    """

    def __init__(self, backbone_name: str, input_res: int, grid: Tuple[int, int] = (1, 1)) -> None:
        super().__init__()
        self.backbone_name = backbone_name
        self.backbone = _build_dino_by_name(backbone_name, input_res)
        self.grid = grid
        # Falls back to 1024 when the backbone does not report its width.
        self.feature_dim = getattr(self.backbone, "num_features", 1024)
        fused_dim = self.feature_dim * 2

        def make_head() -> nn.Sequential:
            return nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.GELU(), nn.Linear(fused_dim, 1))

        # Creation order (total, gdm, green) is kept so parameter order is unchanged.
        self.head_total = make_head()
        self.head_gdm = make_head()
        self.head_green = make_head()
        self.softplus = nn.Softplus()

    def _encode_backbone(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone; if it returns a list/tuple, keep the last element."""
        feats = self.backbone(x)
        return feats[-1] if isinstance(feats, (list, tuple)) else feats

    def _tile_image(self, x: torch.Tensor) -> torch.Tensor:
        """Cut ``x`` of shape (B, C, H, W) into ``rows * cols`` row-major tiles.

        Returns a tensor of shape (B, T, C, H // rows, W // cols); pixels beyond
        the last full tile are dropped.
        """
        rows, cols = self.grid
        _, _, height, width = x.shape
        tile_h, tile_w = height // rows, width // cols
        tiles = [
            x[:, :, r * tile_h : (r + 1) * tile_h, c * tile_w : (c + 1) * tile_w]
            for r in range(rows)
            for c in range(cols)
        ]
        return torch.stack(tiles, dim=1)

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor):
        raise NotImplementedError


class TwoStreamDINOPlain(TwoStreamDINOBase):
    """Variant that encodes each half as a whole image (no tiling)."""

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor):
        """Return softplus-activated ``(total, gdm, green)`` predictions."""
        fused = torch.cat(
            [self._encode_backbone(x_left), self._encode_backbone(x_right)],
            dim=1,
        )
        heads = (self.head_total, self.head_gdm, self.head_green)
        total, gdm, green = (self.softplus(head(fused)) for head in heads)
        return total, gdm, green


class TwoStreamDINOTiled(TwoStreamDINOBase):
    """Variant that encodes each half tile by tile and mean-pools the tile features."""

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor):
        """Return softplus-activated ``(total, gdm, green)`` predictions."""
        left_tiles, right_tiles = self._tile_image(x_left), self._tile_image(x_right)

        # Both halves are flattened with the left half's tile geometry.
        batch, n_tiles, chans, tile_h, tile_w = left_tiles.shape
        flat_shape = (batch * n_tiles, chans, tile_h, tile_w)

        pooled = [
            self._encode_backbone(tiles.view(flat_shape)).view(batch, n_tiles, -1).mean(dim=1)
            for tiles in (left_tiles, right_tiles)
        ]
        fused = torch.cat(pooled, dim=1)

        heads = (self.head_total, self.head_gdm, self.head_green)
        total, gdm, green = (self.softplus(head(fused)) for head in heads)
        return total, gdm, green


class TwoStreamDINOTiledFiLM(TwoStreamDINOBase):
    """Tiled variant that applies a per-stream FiLM block to tile features before pooling."""

    def __init__(self, backbone_name: str, input_res: int, grid: Tuple[int, int]) -> None:
        super().__init__(backbone_name, input_res, grid)
        self.film_left = FiLM(self.feature_dim)
        self.film_right = FiLM(self.feature_dim)

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor):
        """Return softplus-activated ``(total, gdm, green)`` predictions."""
        left_tiles, right_tiles = self._tile_image(x_left), self._tile_image(x_right)

        # Both halves are flattened with the left half's tile geometry.
        batch, n_tiles, chans, tile_h, tile_w = left_tiles.shape
        flat_shape = (batch * n_tiles, chans, tile_h, tile_w)

        # Encode both streams first, then modulate, matching the original call order.
        left_feats = self._encode_backbone(left_tiles.view(flat_shape)).view(batch, n_tiles, -1)
        right_feats = self._encode_backbone(right_tiles.view(flat_shape)).view(batch, n_tiles, -1)
        left_pooled = self.film_left(left_feats).mean(dim=1)
        right_pooled = self.film_right(right_feats).mean(dim=1)

        fused = torch.cat([left_pooled, right_pooled], dim=1)
        heads = (self.head_total, self.head_gdm, self.head_green)
        total, gdm, green = (self.softplus(head(fused)) for head in heads)
        return total, gdm, green


In [None]:
# Weight loading & auto-detection ------------------------------------

def _clean_state_dict(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Normalise a loaded checkpoint into a flat, model-ready state dict.

    Unwraps a ``"state_dict"`` entry and then a dict-valued ``"model"`` entry,
    strips one leading ``"module."`` and then one leading ``"student."`` from
    each key, and drops every key that mentions a text/CLIP-related component.
    """
    state = state["state_dict"] if "state_dict" in state else state
    if "model" in state and isinstance(state["model"], dict):
        state = state["model"]

    prefixes = ("module.", "student.")  # stripped in this order, once each
    skip_markers = ("txt_enc", "text_encoder", "text_model", "img_proj", "txt_proj", "clip", "language")

    cleaned = {}
    for raw_key, tensor in state.items():
        key = raw_key
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
        if not any(marker in key for marker in skip_markers):
            cleaned[key] = tensor
    return cleaned


def _detect_variant(clean_keys: List[str]) -> str:
    """Guess the model variant from substrings of the checkpoint's keys.

    Returns ``"tiled_film"`` when any key mentions ``film_left``/``film_right``,
    ``"tiled"`` when any key mentions ``_tile``/``grid``, otherwise ``"plain"``.

    NOTE(review): the tiled model in this file registers no parameter whose
    name contains ``_tile`` or ``grid`` outside the backbone, so "tiled" can
    only be hit via backbone or training-side key names — confirm.
    """

    def mentions(markers) -> bool:
        return any(marker in key for key in clean_keys for marker in markers)

    if mentions(("film_left", "film_right")):
        return "tiled_film"
    if mentions(("_tile", "grid")):
        return "tiled"
    return "plain"


def _attempt_load(model: nn.Module, state: Dict[str, torch.Tensor]) -> bool:
    """Try to load ``state`` into ``model`` and report whether it plausibly fits.

    Args:
        model: Module to receive the weights (modified in place).
        state: Cleaned state dict.

    Returns:
        ``True`` when loading succeeded and fewer model keys were missing than
        the checkpoint has entries; ``False`` otherwise, including when a
        tensor shape does not match the model.
    """
    try:
        missing, _unexpected = model.load_state_dict(state, strict=False)
    except RuntimeError:
        # strict=False tolerates missing/unexpected keys, but a shape mismatch
        # still raises. Treat it as "this architecture does not fit" so the
        # caller can move on to its next candidate.
        return False
    return len(missing) < len(state)


def load_fold_model_auto(path: str, grid: Tuple[int, int] = GRID):
    """Load a checkpoint, trying each known backbone until the weights fit.

    The model variant (plain / tiled / tiled_film) is detected from the cleaned
    checkpoint keys; every candidate backbone is then instantiated and tried in
    turn.

    Args:
        path: Checkpoint file to load.
        grid: (rows, cols) tiling grid passed to the tiled variants.

    Returns:
        ``(model, variant, backbone_name, input_res)`` with the model on
        ``DEVICE`` and in eval mode.

    Raises:
        RuntimeError: if no candidate backbone accepts the checkpoint.
    """
    candidates = [
        "vit_large_patch14_dinov2.lvd142m",
        "vit_base_patch14_dinov2.lvd142m",
        "vit_small_patch14_dinov2.lvd142m",
        "convnextv2_base.fcmae_ft_in22k_in1k",
    ]

    try:
        ckpt = torch.load(path, map_location=DEVICE, weights_only=True)
    except TypeError:
        # Older torch releases do not accept the weights_only keyword.
        ckpt = torch.load(path, map_location=DEVICE)
    clean_state = _clean_state_dict(ckpt)
    variant = _detect_variant(list(clean_state.keys()))

    for backbone_name in candidates:
        input_res = _infer_input_res(backbone_name)
        if variant == "plain":
            model = TwoStreamDINOPlain(backbone_name, input_res)
        elif variant == "tiled_film":
            model = TwoStreamDINOTiledFiLM(backbone_name, input_res, grid)
        else:
            model = TwoStreamDINOTiled(backbone_name, input_res, grid)

        try:
            loaded = _attempt_load(model, clean_state)
        except RuntimeError:
            # load_state_dict raises on shape mismatches even with
            # strict=False; that only means this backbone is the wrong size.
            loaded = False

        if loaded:
            # Only the accepted candidate is moved to the target device;
            # rejected ones never leave the CPU.
            model.to(DEVICE)
            model.eval()
            return model, variant, backbone_name, input_res
    raise RuntimeError(f"Failed to load checkpoint {path}")


In [None]:
# TTA transforms ------------------------------------------------------

def get_tta_transforms(img_size: int):
    """Build the three test-time-augmentation pipelines.

    Returns a list of ``A.Compose`` objects in the order: identity, horizontal
    flip, vertical flip. Each one ends with resize to ``img_size`` x
    ``img_size``, normalisation with the mean/std below and tensor conversion.
    """
    shared_tail = [
        A.Resize(img_size, img_size),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    leading_ops = (
        [],
        [A.HorizontalFlip(p=1.0)],
        [A.VerticalFlip(p=1.0)],
    )
    return [A.Compose(lead + shared_tail) for lead in leading_ops]


In [None]:
# Inference helpers ---------------------------------------------------
@torch.no_grad()
def predict_one_view(models: List[nn.Module], loader: DataLoader):
    """Run every model over ``loader`` and return their averaged predictions.

    For each batch the models' outputs ``(total, gdm, green)`` are expanded to
    the column order ``[green, dead, clover, gdm, total]`` with
    ``dead = total - gdm`` and ``clover = gdm - green``, averaged across models
    and clamped at zero. Returns a NumPy array with all batches concatenated
    along the first axis.
    """
    batch_outputs = []
    for left, right in tqdm(loader, desc="Predict", leave=False):
        left, right = left.to(DEVICE), right.to(DEVICE)

        per_model = []
        for net in models:
            total, gdm, green = net(left, right)
            per_model.append(torch.cat([green, total - gdm, gdm - green, gdm, total], dim=1))

        # Clamp after averaging: derived columns can go negative.
        ensemble = torch.stack(per_model, dim=0).mean(dim=0).clamp(min=0.0)
        batch_outputs.append(ensemble.cpu())
    return torch.cat(batch_outputs, dim=0).numpy()


def _predict_with_tta(model, input_res: int, test_df: pd.DataFrame) -> np.ndarray:
    """Return one model's predictions averaged over every TTA view."""
    per_view = []
    for transform in get_tta_transforms(input_res):
        dataset = TestBiomassDataset(test_df, TEST_IMAGE_DIR, transform, input_res)
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
        per_view.append(predict_one_view([model], loader))
    return np.stack(per_view, axis=0).mean(axis=0)


def run_ensemble_prediction(ckpt_paths: List[str], weights: List[float], grid: Tuple[int, int], test_df: pd.DataFrame):
    """Predict with every checkpoint and blend the results.

    Each checkpoint is evaluated once over all TTA views, at the input
    resolution inferred for that checkpoint, and the per-model predictions are
    combined with a weighted average.

    Args:
        ckpt_paths: Checkpoints forming the ensemble (at least one).
        weights: One weight per checkpoint; normalised internally. Ignored
            when there is only a single checkpoint.
        grid: (rows, cols) tiling grid forwarded to the model loader.
        test_df: One row per test image, with an ``image_path`` column.

    Returns:
        Array of blended predictions, one row per row of ``test_df``.

    Raises:
        ValueError: if no checkpoint is given, or (for more than one
            checkpoint) the weights do not match the checkpoints or sum to 0.
    """
    if not ckpt_paths:
        raise ValueError("run_ensemble_prediction requires at least one checkpoint path.")

    multi_model = len(ckpt_paths) > 1
    if multi_model:
        # Validate before any model is loaded so a bad config fails fast
        # rather than after a full inference pass.
        weights = np.array(weights, dtype=np.float32)
        if weights.shape != (len(ckpt_paths),):
            raise ValueError(
                f"Expected {len(ckpt_paths)} weights (one per checkpoint), got shape {weights.shape}."
            )
        weight_sum = weights.sum()
        if weight_sum == 0:
            raise ValueError("Ensemble weights must not sum to zero.")
        weights = weights / weight_sum

    # Load everything up front so a broken checkpoint is reported early.
    loaded = []
    for path in ckpt_paths:
        model, variant, backbone, input_res = load_fold_model_auto(path, grid)
        print(f"Loaded {path} as {variant} with {backbone} @ {input_res}")
        loaded.append((model, input_res))

    # One inference pass per model; each model keeps its own resolution.
    per_model_preds = [_predict_with_tta(model, res, test_df) for model, res in loaded]

    if not multi_model:
        return per_model_preds[0]
    return np.average(np.stack(per_model_preds, axis=0), axis=0, weights=weights)


In [None]:
# Submission creation -------------------------------------------------

def create_submission(final_5: np.ndarray, test_long: pd.DataFrame, test_unique: pd.DataFrame, submission_path: str):
    """Turn per-image predictions into the long-format submission CSV.

    Args:
        final_5: Array with one row per row of ``test_unique`` and columns in
            the order green, dead, clover, gdm, total.
        test_long: Long-format frame with ``sample_id``, ``image_path`` and
            ``target_name`` columns; defines the rows of the submission.
        test_unique: One row per image; its ``image_path`` column labels the
            rows of ``final_5``.
        submission_path: Where the CSV is written.

    Returns:
        The submission frame with columns ``sample_id`` and ``target``.
        Non-finite and negative predictions, and rows with no matching
        prediction, are written as 0.0.
    """
    target_columns = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]

    # Sanitise all columns at once: NaN/inf -> 0, then clip negatives to 0.
    values = np.nan_to_num(final_5, nan=0.0, posinf=0.0, neginf=0.0)
    values = np.maximum(0.0, values)

    wide = pd.DataFrame({"image_path": test_unique["image_path"]})
    for col_idx, column in enumerate(target_columns):
        wide[column] = values[:, col_idx]

    long_preds = wide.melt(
        id_vars=["image_path"],
        value_vars=target_columns,
        var_name="target_name",
        value_name="target",
    )

    # Left merge keeps the row order and every sample_id of test_long.
    sub = pd.merge(
        test_long[["sample_id", "image_path", "target_name"]],
        long_preds,
        on=["image_path", "target_name"],
        how="left",
    )[["sample_id", "target"]]

    sub["target"] = np.nan_to_num(sub["target"], nan=0.0, posinf=0.0, neginf=0.0)
    sub.to_csv(submission_path, index=False)
    print(f"Saved submission to: {submission_path}")

    # `display` only exists inside IPython/Jupyter; fall back to print so a
    # plain `python` run does not fail right after writing the file.
    try:
        show = display
    except NameError:
        show = print
    show(sub.head())
    return sub


In [None]:
# Main execution -----------------------------------------------------
if __name__ == "__main__":
    # The CSV is long-format (one row per image/target pair); inference runs
    # once per unique image.
    test_long = pd.read_csv(TEST_CSV_PATH)
    test_unique = test_long.drop_duplicates(subset=["image_path"]).reset_index(drop=True)

    ensembles = ENSEMBLE_GROUPS if ENSEMBLE_GROUPS else [
        {"name": "default", "weights": [1.0] * len(CHECKPOINTS), "ckpts": CHECKPOINTS}
    ]

    final_preds = None
    total_weight = 0.0
    for group in ensembles:
        ckpts = group.get("ckpts", [])
        used_fallback = not ckpts
        if used_fallback:
            # A group without its own checkpoints uses the global list.
            ckpts = CHECKPOINTS
        if not ckpts:
            raise ValueError("No checkpoints provided. Please update CHECKPOINTS or ENSEMBLE_GROUPS.")

        weights = group.get("weights")
        if weights is None or (used_fallback and len(weights) != len(ckpts)):
            # The group's weights were written for its own (empty) "ckpts"
            # list, not for CHECKPOINTS; a stale list such as [1.0] would make
            # a multi-fold blend fail, so weight the folds uniformly instead.
            if weights is not None:
                print(
                    f"Group {group.get('name', '?')}: {len(weights)} weights for "
                    f"{len(ckpts)} checkpoints; using uniform weights."
                )
            weights = [1.0] * len(ckpts)

        group_pred = run_ensemble_prediction(ckpts, weights, GRID, test_unique)

        # Optional per-group scalar used when blending several groups.
        w = group.get("weight", 1.0)
        final_preds = group_pred * w if final_preds is None else final_preds + group_pred * w
        total_weight += w

    final_preds = final_preds / max(total_weight, 1e-6)
    _ = create_submission(final_preds, test_long, test_unique, SUBMISSION_FILE)

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
