# CSIRO Image2Biomass v3 — Flattened Kaggle Submission (Inference Only)


In [None]:

# Cell 1: imports + paths/config (edit WEIGHTS_ROOT)
import glob
import os
from typing import List

import numpy as np
import pandas as pd
import timm
import torch
from PIL import Image, ImageOps
from torch.utils.data import DataLoader
import torchvision.transforms as T
from tqdm import tqdm

try:
    import cv2
except ImportError:
    # OpenCV is optional; code that depends on it checks `cv2 is None`.
    cv2 = None

# === Mandatory path settings for offline Kaggle inference ===
DATA_ROOT = "/kaggle/input/csiro-biomass"
# !!! Update this to point to the attached Kaggle Dataset that contains fold*_best.pth files
WEIGHTS_ROOT = "/kaggle/input/<your-weights-dataset>/v3_weights"

# Output locations: a per-run directory (name overridable through the
# RUN_NAME environment variable) plus a copy directly under /kaggle/working.
RUN_NAME = os.environ.get("RUN_NAME", "v3_flat_inference")
OUTPUT_ROOT = "/kaggle/working/outputs"
RUN_DIR = os.path.join(OUTPUT_ROOT, RUN_NAME)
SUBMISSION_PATH = os.path.join(RUN_DIR, "submission", "submission.csv")
WORKING_SUBMISSION = "/kaggle/working/submission.csv"

# Preprocessing + split metadata for reproducibility
CROP_BOTTOM = 0.1  # crop away bottom 10% of image height; set to 0.0 to disable
USE_CLAHE = True   # contrast normalization; falls back to PIL equalize if OpenCV is unavailable
CV_SPLIT_STRATEGY = "group_date_state"  # training-time CV grouping; shown here for reference

# Create the run's submission directory up front (no error if it already exists).
os.makedirs(os.path.join(RUN_DIR, "submission"), exist_ok=True)

# The three targets predicted directly by the model, followed by the full
# list that also includes the two derived columns.
TARGET_COLUMNS = ["Dry_Green_g", "Dry_Clover_g", "Dry_Dead_g"]
ALL_TARGET_COLUMNS = TARGET_COLUMNS + ["GDM_g", "Dry_Total_g"]

# Inference settings shared by the cells below.
# NOTE(review): backbone and image_size presumably have to match the values
# used when the fold checkpoints were trained — confirm against training config.
cfg = {
    "backbone": "efficientnet_b2",  # same as v3/src config
    "image_size": 456,
    "batch_size": 32,
    "num_workers": 2,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

# Echo the resolved environment so path mistakes are visible in the cell output.
device = torch.device(cfg["device"])
print(f"Using device: {device}")
print(f"DATA_ROOT: {DATA_ROOT}")
print(f"WEIGHTS_ROOT: {WEIGHTS_ROOT}")
print("cv2 available:", cv2 is not None)


def expand_targets(primary: np.ndarray) -> np.ndarray:
    """Append the two derived columns to a block of primary predictions.

    ``primary`` is indexed as ``primary[:, 0..2]``; the local names treat those
    columns as green, clover and dead, in that order.

    The result stacks five columns along axis 1 in this order:
    green, dead, clover, green + clover, green + clover + dead.
    Note that dead and clover swap places relative to the input.
    """
    green, clover, dead = (primary[:, col] for col in range(3))
    green_plus_clover = green + clover
    everything = green_plus_clover + dead
    return np.stack([green, dead, clover, green_plus_clover, everything], axis=1)


In [None]:

# Cell 2: quick file/column sanity check

def load_long_dataframe(csv_path: str) -> pd.DataFrame:
    """Read a long-format CSV and add a ``sample_id_prefix`` column.

    The prefix is the part of ``sample_id`` that precedes the first ``"__"``.

    Raises:
        ValueError: if ``sample_id`` or ``image_path`` is not a column of the CSV.
    """
    frame = pd.read_csv(csv_path)

    absent = sorted({"sample_id", "image_path"}.difference(frame.columns))
    if absent:
        raise ValueError(f"Missing columns in {csv_path}: {absent}")

    id_parts = frame["sample_id"].str.split("__")
    frame["sample_id_prefix"] = id_parts.str[0]
    return frame


def to_wide(df: pd.DataFrame, include_targets: bool = True) -> pd.DataFrame:
    """Collapse a long-format frame to one row per (sample_id_prefix, image_path).

    With ``include_targets=False`` only the de-duplicated key columns are
    returned. Otherwise ``target`` is pivoted so that every distinct
    ``target_name`` becomes a column (first value wins on duplicates).

    Raises:
        ValueError: if a key column is missing, if ``target_name``/``target``
            are missing while targets are requested, or if any entry of
            ``TARGET_COLUMNS`` is absent from the pivoted result.
    """
    key_columns = ["sample_id_prefix", "image_path"]
    absent_keys = [name for name in key_columns if name not in df.columns]
    if absent_keys:
        raise ValueError(f"Missing required aggregation columns: {absent_keys}")

    if include_targets:
        if not ("target_name" in df.columns and "target" in df.columns):
            raise ValueError("target_name/target columns are required when include_targets=True")

        pivoted = df.pivot_table(
            index=key_columns,
            columns="target_name",
            values="target",
            aggfunc="first",
        )
        pivoted = pivoted.reset_index()

        absent_targets = [name for name in TARGET_COLUMNS if name not in pivoted.columns]
        if absent_targets:
            raise ValueError(f"Missing target columns after pivot: {absent_targets}")
        return pivoted

    unique_keys = df[key_columns].drop_duplicates()
    return unique_keys.reset_index(drop=True)


# Locations of the competition files under DATA_ROOT.
test_csv = os.path.join(DATA_ROOT, "test.csv")
sample_submission_csv = os.path.join(DATA_ROOT, "sample_submission.csv")

# Report which input files are present before doing any work.
print("Test CSV exists:", os.path.exists(test_csv), "-", test_csv)
print("sample_submission.csv exists:", os.path.exists(sample_submission_csv), "-", sample_submission_csv)

# test.csv is mandatory; a missing sample_submission.csv is only reported above.
if not os.path.exists(test_csv):
    raise FileNotFoundError(f"test.csv not found at {test_csv}")

# Load test set and derive wide format (no targets)
test_long = load_long_dataframe(test_csv)
print(f"Test rows: {len(test_long)}, columns: {test_long.columns.tolist()}")

# One row per unique (sample_id_prefix, image_path) pair.
test_wide = to_wide(test_long, include_targets=False)
print("test_wide shape:", test_wide.shape)
print(test_wide.head())


In [None]:

# Cell 3: dataset & transforms (inference-only)

class RegressionDataset(torch.utils.data.Dataset):
    """Inference dataset yielding preprocessed image tensors.

    Each item is a 3-tuple: the transformed image, an all-zero float32
    placeholder of length ``len(TARGET_COLUMNS)``, and the row's
    ``sample_id_prefix``.

    Per-image pipeline: open as RGB, optionally drop the bottom
    ``crop_bottom`` fraction of rows, optionally equalize contrast, then
    resize to ``image_size`` x ``image_size``, convert to a tensor and
    normalize with fixed per-channel mean/std.
    """

    def __init__(self, df: pd.DataFrame, image_root: str, image_size: int, crop_bottom: float = 0.0, use_clahe: bool = False):
        # Positional row access via iloc needs a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)
        self.image_root = image_root
        self.crop_bottom = crop_bottom
        self.use_clahe = use_clahe

        resize = T.Resize((image_size, image_size))
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.transform = T.Compose([resize, T.ToTensor(), normalize])

    def __len__(self):
        """Number of rows in the backing frame."""
        return len(self.df)

    def __getitem__(self, idx: int):
        """Load, preprocess and return the sample at position ``idx``."""
        record = self.df.iloc[idx]
        full_path = os.path.normpath(os.path.join(self.image_root, record["image_path"]))

        # convert() produces a loaded copy, so the file handle can be closed right away.
        with Image.open(full_path) as handle:
            picture = handle.convert("RGB")

        if self.crop_bottom > 0:
            picture = self._trim_bottom(picture)
        if self.use_clahe:
            picture = self._apply_clahe(picture)

        placeholder = torch.zeros(len(TARGET_COLUMNS), dtype=torch.float32)
        return self.transform(picture), placeholder, record["sample_id_prefix"]

    def _trim_bottom(self, image: "Image.Image") -> "Image.Image":
        """Keep the top ``1 - crop_bottom`` share of rows (never fewer than one)."""
        width, height = image.size
        keep_height = max(1, int(round(height * (1 - self.crop_bottom))))
        return image.crop((0, 0, width, keep_height))

    def _apply_clahe(self, image: Image.Image) -> Image.Image:
        """Equalize contrast: CLAHE on the LAB L channel, or PIL equalize without OpenCV."""
        if cv2 is None:
            return ImageOps.equalize(image)

        lab_pixels = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2LAB)
        lightness, chan_a, chan_b = cv2.split(lab_pixels)

        equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        rebuilt = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))

        return Image.fromarray(cv2.cvtColor(rebuilt, cv2.COLOR_LAB2RGB))


def get_inference_loader(test_df: pd.DataFrame) -> DataLoader:
    """Wrap ``test_df`` in a RegressionDataset and return an ordered DataLoader.

    Image root, image size, crop and CLAHE settings come from the module-level
    configuration. Shuffling is off so batches follow the row order of ``test_df``.
    """
    dataset = RegressionDataset(
        test_df,
        DATA_ROOT,
        cfg["image_size"],
        crop_bottom=CROP_BOTTOM,
        use_clahe=USE_CLAHE,
    )
    loader_options = {
        "batch_size": cfg["batch_size"],
        "shuffle": False,
        "num_workers": cfg["num_workers"],
        "pin_memory": True,
    }
    return DataLoader(dataset, **loader_options)


# Build the loader over the de-duplicated test rows and report how many
# batches one pass will take.
inference_loader = get_inference_loader(test_wide)
print("Inference batches:", len(inference_loader))


In [None]:

# Cell 4: model definition (pretrained=False, weights loaded from dataset)

def build_model(backbone: str, num_outputs: int = len(TARGET_COLUMNS)) -> torch.nn.Module:
    """Create an untrained timm model named ``backbone`` with ``num_outputs`` outputs.

    ``pretrained=False`` keeps timm from downloading weights, which would fail
    in an offline Kaggle session; weights are expected to be loaded afterwards.
    """
    return timm.create_model(
        backbone,
        pretrained=False,
        num_classes=num_outputs,
    )


In [None]:

# Cell 5: checkpoint load + inference across folds (mean ensemble)

def list_checkpoints(
    weights_root: str,
    expected_count: int = 5,
    pattern: str = "fold*_best.pth",
) -> List[str]:
    """Return the sorted checkpoint paths found directly under ``weights_root``.

    Args:
        weights_root: Directory to search. It is treated as a literal path,
            so glob metacharacters in it (``[``, ``*``, ``?``) are escaped.
        expected_count: Number of checkpoints expected; a different count only
            prints a warning.
        pattern: Glob pattern for the checkpoint file names.

    Returns:
        Matching paths in lexicographic order.

    Raises:
        FileNotFoundError: if nothing matches ``pattern`` under ``weights_root``.
    """
    # Only the file-name part is a pattern; escaping the directory prevents a
    # name such as "weights[v3]" from being read as a character class and
    # silently matching nothing.
    search = os.path.join(glob.escape(weights_root), pattern)
    ckpts = sorted(glob.glob(search))
    if not ckpts:
        raise FileNotFoundError(f"No {pattern} checkpoints found under {weights_root}")
    if len(ckpts) != expected_count:
        print(f"Warning: expected {expected_count} fold checkpoints, found {len(ckpts)}")
    return ckpts


def load_model(checkpoint_path: str) -> torch.nn.Module:
    """Build the configured backbone, load weights from ``checkpoint_path``, return it in eval mode.

    The object returned by ``torch.load`` is passed straight to
    ``load_state_dict``, so the file must contain a plain state dict.
    Tensors are mapped onto the module-level ``device`` and the model is
    moved there as well.
    """
    network = build_model(cfg["backbone"], num_outputs=len(TARGET_COLUMNS))
    weights = torch.load(checkpoint_path, map_location=device)
    network.load_state_dict(weights)
    # Module.to() and Module.eval() both return the module itself.
    return network.to(device).eval()


def predict_wide(loader: DataLoader) -> np.ndarray:
    """Run every checkpoint under WEIGHTS_ROOT over ``loader`` and average the outputs.

    ``loader`` must yield 3-tuples whose first element is the image batch; the
    other two elements are ignored. Each checkpoint's batch outputs are
    concatenated, and the per-checkpoint arrays are averaged element-wise.
    """
    per_fold: List[np.ndarray] = []

    for weights_file in list_checkpoints(WEIGHTS_ROOT):
        net = load_model(weights_file)
        progress = tqdm(loader, desc=f"Predict {os.path.basename(weights_file)}")
        # Gradients are not needed for inference.
        with torch.no_grad():
            batch_outputs = [net(batch.to(device)).cpu().numpy() for batch, _, _ in progress]
        per_fold.append(np.concatenate(batch_outputs))

    averaged = np.mean(per_fold, axis=0)
    print("Preds shape (mean ensemble):", averaged.shape)
    return averaged


# Mean-ensemble predictions for the primary targets, one row per loader sample.
preds_primary = predict_wide(inference_loader)


In [None]:

# Cell 6: submission build + validation + save

def build_submission(test_long_df: pd.DataFrame, test_wide_df: pd.DataFrame, preds: np.ndarray, run_dir: str) -> pd.DataFrame:
    """Turn wide predictions into a long submission frame and write it to disk.

    Args:
        test_long_df: Long-format test rows; must provide ``sample_id``,
            ``sample_id_prefix`` and ``target_name``.
        test_wide_df: Frame whose rows line up one-to-one with ``preds``;
            supplies ``sample_id_prefix``.
        preds: Primary predictions, one row per row of ``test_wide_df``.
        run_dir: Run directory; the CSV is written to
            ``<run_dir>/submission/submission.csv``.

    Returns:
        The submission frame with columns ``sample_id`` and ``target``, in the
        row order of ``test_long_df``. Rows without a matching prediction get NaN.

    Raises:
        ValueError: if ``preds`` and ``test_wide_df`` differ in length.

    Side effects:
        Also writes a copy of the CSV to ``WORKING_SUBMISSION``.
    """
    if len(preds) != len(test_wide_df):
        raise ValueError(
            f"Prediction rows ({len(preds)}) do not match test_wide rows ({len(test_wide_df)})"
        )

    full_preds = expand_targets(preds)
    # Column labels must follow the column order produced by expand_targets.
    pred_df = pd.DataFrame(full_preds, columns=["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"])
    pred_df["sample_id_prefix"] = test_wide_df["sample_id_prefix"].values

    # Long format: one row per (sample_id_prefix, target_name) pair.
    pred_long = pred_df.melt(id_vars="sample_id_prefix", var_name="target_name", value_name="target")

    # Left merge keeps the row order and row count of the test file.
    merged = test_long_df.merge(
        pred_long[["sample_id_prefix", "target_name", "target"]],
        on=["sample_id_prefix", "target_name"],
        how="left",
    )

    submission = merged[["sample_id", "target"]].copy()

    # Write into the directory derived from run_dir so that the directory
    # created here is the one that actually receives the file.
    submission_dir = os.path.join(run_dir, "submission")
    os.makedirs(submission_dir, exist_ok=True)
    submission.to_csv(os.path.join(submission_dir, "submission.csv"), index=False)
    submission.to_csv(WORKING_SUBMISSION, index=False)
    return submission


# Build and save the submission, then print basic sanity checks.
submission_df = build_submission(test_long, test_wide, preds_primary, RUN_DIR)
print("Submission shape:", submission_df.shape)
print("NaN present:", submission_df["target"].isna().any())

# When the sample submission is available, compare row count and exact
# sample_id ordering against it. Results are printed only; nothing is raised.
if os.path.exists(sample_submission_csv):
    sample_sub = pd.read_csv(sample_submission_csv)
    sample_ids = sample_sub["sample_id"].tolist()
    submission_ids = submission_df["sample_id"].tolist()
    print("Matches sample_submission length:", len(sample_ids) == len(submission_ids))
    print("Sample ID order identical:", sample_ids == submission_ids)
else:
    print("sample_submission.csv not found at", sample_submission_csv)

print("Saved submission to:", SUBMISSION_PATH)
print("Copied submission to:", WORKING_SUBMISSION)
# Trailing expression: presumably shown as the cell output when run in a notebook.
submission_df.head()
