# Kaggle CSIRO Biomass ConvNeXt Submission

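Self-contained inference notebook that mirrors the repository's patch-based ConvNeXt pipeline: each test image is split into a grid of patches, every patch is encoded by a shared ConvNeXt backbone, and the concatenated features are regressed to the biomass targets required by the submission.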

In [None]:
# ================================
# USER-EDITABLE CONFIG (KAGGLE)
# ================================

# --- Input locations on the Kaggle filesystem ---
TEST_CSV_PATH = "/kaggle/input/csiro-biomass/test.csv"
TEST_IMAGE_DIR = "/kaggle/input/csiro-biomass/test"

# --- Trained weights; predict_ensemble averages the predictions of every entry ---
CHECKPOINTS = [
    # Example: "/kaggle/input/my-csiro-weights/fold1.ckpt",
    # Example: "/kaggle/input/my-csiro-weights/fold2.ckpt",
]

# --- Output CSV written by create_submission ---
SUBMISSION_FILE = "submission.csv"

# --- DataLoader settings ---
BATCH_SIZE = 8
NUM_WORKERS = 2

# --- Model / preprocessing ---
# BACKBONE and PATCH_COUNT define the model architecture, so they must match the
# checkpoints (load_state_dict fails otherwise). IMAGE_SIZE presumably should match
# the training resolution — NOTE(review): confirm against the training config.
IMAGE_SIZE = 1024
PATCH_COUNT = 2
BACKBONE = "convnext_large"

In [None]:
import os
import random
from typing import Any, Dict, List, Sequence

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Run on CUDA whenever a GPU is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
def set_seed(seed: int = 42):
    """Make runs reproducible.

    Seeds Python's ``random``, NumPy's global RNG, the PyTorch CPU RNG and all CUDA
    RNGs with the same value. It also disables cuDNN autotuning and requests
    deterministic cuDNN kernels.

    Args:
        seed: Value fed to every RNG.
    """
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


set_seed(42)

In [None]:
PATCH_GRID = {    1: (1, 1),    2: (1, 2),    4: (2, 2),    6: (2, 3),}class BiomassPatchDataset(Dataset):    def __init__(        self,        metadata: pd.DataFrame,        image_dir: str,        patch_count: int,        image_size: int,        augment_cfg: Dict[str, Any] | None = None,        is_train: bool = False,        target_columns: List[str] | None = None,    ) -> None:        if patch_count not in PATCH_GRID:            raise ValueError(f"Unsupported patch_count: {patch_count}")        self.metadata = metadata.reset_index(drop=True)        self.image_dir = image_dir        self.patch_count = patch_count        self.grid = PATCH_GRID[patch_count]        self.image_size = image_size        self.is_train = is_train        self.target_columns = target_columns or ["Dry", "Clover", "Green"]        self.transform = self._build_transform(augment_cfg or {})        if self.is_train:            for col in self.target_columns:                if col not in self.metadata.columns:                    raise ValueError(                        f"Expected target column '{col}' not found in dataframe columns: {list(self.metadata.columns)}"                    )    def _build_transform(self, augment_cfg: Dict[str, Any]) -> A.Compose:        transforms: List[A.BasicTransform] = [A.Resize(self.image_size, self.image_size)]        if augment_cfg.get("color_jitter", False):            transforms.append(A.ColorJitter())        if augment_cfg.get("horizontal_flip", False):            transforms.append(A.HorizontalFlip(p=0.5))        transforms.extend([A.Normalize(), ToTensorV2()])        return A.Compose(transforms)    def __len__(self) -> int:        return len(self.metadata)    def __getitem__(self, idx: int) -> Dict[str, Any]:        row = self.metadata.iloc[idx]        image_path = os.path.join(self.image_dir, row["image_path"])        image = cv2.imread(image_path)        if image is None:            raise FileNotFoundError(f"Image not found at {image_path}")        image = 
cv2.cvtColor(image, cv2.COLOR_BGR2RGB)        image = cv2.resize(image, (self.image_size, self.image_size))        patches = self._split_into_patches(image)        augmented_patches = [self.transform(image=patch)["image"] for patch in patches]        item: Dict[str, Any] = {"patches": augmented_patches}        if self.is_train:            target = torch.tensor([row[col] for col in self.target_columns], dtype=torch.float32)            item["target"] = target        return item    def _split_into_patches(self, image: np.ndarray) -> List[np.ndarray]:        rows, cols = self.grid        h, w, _ = image.shape        patch_h, patch_w = h // rows, w // cols        patches = []        for r in range(rows):            for c in range(cols):                y0, y1 = r * patch_h, (r + 1) * patch_h                x0, x1 = c * patch_w, (c + 1) * patch_w                patch = image[y0:y1, x0:x1]                patches.append(patch)        return patchesdef collate_fn(batch: List[Dict[str, Any]]):    patch_count = len(batch[0]["patches"])    patch_batches = []    for idx in range(patch_count):        stacked = torch.stack([sample["patches"][idx] for sample in batch])        patch_batches.append(stacked)    return patch_batches

In [None]:
# Short backbone names accepted by PatchFusionModel -> timm model identifiers.
BACKBONE_MAP = {
    "convnext_large": "convnext_large.fb_in22k",
    "convnext_base": "convnext_base.fb_in22k",
}


class RegressionHead(nn.Module):
    """MLP ``in_features -> in_features // 2 -> 3`` with GELU and dropout in between."""

    def __init__(self, in_features: int, dropout: float = 0.3) -> None:
        super().__init__()
        hidden = in_features // 2
        # The attribute name and layer order determine the state_dict keys
        # (layers.0.*, layers.3.*); keep them fixed so saved checkpoints still load.
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return self.layers(x)


class PatchFusionModel(nn.Module):
    """Encode each patch with one shared timm backbone, concatenate, and regress 3 values.

    Args:
        backbone_name: Key of ``BACKBONE_MAP``.
        patch_count: Number of patches ``forward`` expects.
        pretrained: Forwarded to ``timm.create_model``.
        dropout: Dropout rate inside the regression head.

    Raises:
        ValueError: If ``backbone_name`` is not in ``BACKBONE_MAP``.
    """

    def __init__(
        self,
        backbone_name: str = "convnext_large",
        patch_count: int = 2,
        pretrained: bool = True,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        timm_name = BACKBONE_MAP.get(backbone_name)
        if timm_name is None:
            raise ValueError(f"Unsupported backbone: {backbone_name}")
        # num_classes=0 + global_pool="avg" makes the backbone return pooled feature vectors.
        self.backbone = timm.create_model(
            timm_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        self.feature_dim = self.backbone.num_features
        self.patch_count = patch_count
        self.head = RegressionHead(self.feature_dim * patch_count, dropout=dropout)

    def forward(self, patches: Sequence[torch.Tensor]) -> torch.Tensor:  # type: ignore[override]
        """Accept a sequence of per-patch batches, or a single 5-D tensor.

        The 5-D tensor is indexed as ``(batch, patch, C, H, W)``. Patch features are
        concatenated in sequence order before the head.

        Raises:
            ValueError: If the number of patches differs from ``patch_count``.
        """
        if isinstance(patches, torch.Tensor) and patches.dim() == 5:
            # Split dim 1 so each element is one patch position across the whole batch.
            patches = list(patches.unbind(dim=1))
        if len(patches) != self.patch_count:
            raise ValueError(
                f"Expected {self.patch_count} patches but received {len(patches)}. "
                "Ensure model.patch_count matches dataset patch generation."
            )
        fused = torch.cat([self.backbone(patch) for patch in patches], dim=1)
        return self.head(fused)

In [None]:
def _unwrap_state_dict(state: Any) -> Any:
    """Reduce a raw ``torch.load`` result to a plain ``state_dict`` for PatchFusionModel.

    Nested weights are unwrapped from ``model_state_dict`` first (the original format),
    then from ``state_dict``, which is assumed to be the Lightning-style ``.ckpt`` layout
    — NOTE(review): confirm against the training code. A prefix shared by *every* key
    (``module.`` from DataParallel, ``model.`` from a wrapper attribute) is stripped. A
    valid PatchFusionModel state never has all keys under such a prefix, so the
    original formats pass through untouched.
    """
    if isinstance(state, dict):
        for key in ("model_state_dict", "state_dict"):
            if key in state and isinstance(state[key], dict):
                state = state[key]
                break
    if isinstance(state, dict) and state:
        for prefix in ("module.", "model."):
            if all(isinstance(k, str) and k.startswith(prefix) for k in state):
                state = {k[len(prefix):]: v for k, v in state.items()}
    return state


def load_model(checkpoint_path: str, backbone: str, patch_count: int) -> "PatchFusionModel":
    """Build a PatchFusionModel, load the checkpoint weights, and return it in eval mode on ``device``.

    Raises:
        RuntimeError: From ``load_state_dict`` when keys or shapes do not match
            (e.g. wrong BACKBONE or PATCH_COUNT).
    """
    model = PatchFusionModel(backbone_name=backbone, patch_count=patch_count, pretrained=False)
    # Load onto CPU so the checkpoint tensors and the model are never both resident
    # in GPU memory. NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(_unwrap_state_dict(state))
    model.to(device)
    model.eval()
    return model


def predict_ensemble(loader: DataLoader, checkpoints: List[str], backbone: str, patch_count: int) -> np.ndarray:
    """Run every checkpoint over ``loader`` and return the mean prediction.

    Each model is loaded, used, and freed in turn, so only one occupies device memory
    at a time.

    Returns:
        Array of shape ``(num_samples, 3)`` with the per-sample mean over checkpoints.
        An empty loader yields ``(0, 3)``.

    Raises:
        ValueError: If ``checkpoints`` is empty.
    """
    if not checkpoints:
        raise ValueError("Please provide at least one checkpoint path in CHECKPOINTS.")
    all_preds: List[np.ndarray] = []
    for ckpt in checkpoints:
        model = load_model(ckpt, backbone=backbone, patch_count=patch_count)
        batch_preds: List[torch.Tensor] = []
        with torch.no_grad():
            for patches in tqdm(loader, desc=f"Infer {os.path.basename(ckpt)}"):
                patches = [p.to(device, non_blocking=True) for p in patches]
                batch_preds.append(model(patches).cpu())
        if batch_preds:
            all_preds.append(torch.cat(batch_preds, dim=0).numpy())
        else:
            # torch.cat([]) raises; an empty loader means zero samples (head emits 3 values).
            all_preds.append(np.zeros((0, 3), dtype=np.float32))
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return np.stack(all_preds, axis=0).mean(axis=0)


def expand_to_five_targets(preds_3: np.ndarray) -> np.ndarray:
    """Map ``[Dry_Dead, Dry_Clover, Dry_Green]`` predictions to the five submission targets.

    Output columns are ``[Dry_Green_g, Dry_Dead_g, Dry_Clover_g, GDM_g, Dry_Total_g]``, with
    ``GDM = green + clover`` and ``Total = dead + clover + green``.

    The three components are cleaned first: NaN/±inf become 0 and negatives are clipped
    to 0. Only then are the derived columns computed, so GDM and Total always equal the
    sums of the reported components.
    """
    # Model column order: [Dry, Clover, Green] == [Dry_Dead_g, Dry_Clover_g, Dry_Green_g]
    clean = np.maximum(np.nan_to_num(preds_3, nan=0.0, posinf=0.0, neginf=0.0), 0.0)
    dry_dead = clean[:, 0]
    dry_clover = clean[:, 1]
    dry_green = clean[:, 2]
    gdm = dry_green + dry_clover  # green biomass
    dry_total = dry_dead + dry_clover + dry_green
    return np.stack([dry_green, dry_dead, dry_clover, gdm, dry_total], axis=1)

In [None]:
def create_submission(final_5: np.ndarray, test_long: pd.DataFrame, test_unique: pd.DataFrame, submission_path: str):
    """Convert per-image predictions to the long submission format and write it as CSV.

    Args:
        final_5: Array ``(len(test_unique), 5)`` with columns
            ``[Dry_Green_g, Dry_Dead_g, Dry_Clover_g, GDM_g, Dry_Total_g]``, row-aligned
            with ``test_unique``.
        test_long: Original test CSV with ``sample_id``, ``image_path`` and
            ``target_name``. Its row order is preserved in the output.
        test_unique: One row per image, in the order predictions were produced.
        submission_path: Destination CSV path.

    Returns:
        DataFrame with columns ``sample_id`` and ``target``. Rows whose
        ``(image_path, target_name)`` has no prediction are filled with 0.0 and reported.

    Raises:
        ValueError: If ``final_5`` and ``test_unique`` differ in length.
    """
    target_names = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]
    if len(final_5) != len(test_unique):
        raise ValueError(
            f"Got {len(final_5)} prediction rows for {len(test_unique)} unique images."
        )
    # Positional pairing: row i of final_5 belongs to row i of test_unique regardless of its index labels.
    wide = pd.DataFrame(np.asarray(final_5), columns=target_names)
    wide.insert(0, "image_path", test_unique["image_path"].to_numpy())
    long_preds = wide.melt(
        id_vars=["image_path"],
        value_vars=target_names,
        var_name="target_name",
        value_name="target",
    )
    # A left merge keeps test_long's row order and one output row per sample_id.
    submission = pd.merge(
        test_long[["sample_id", "image_path", "target_name"]],
        long_preds,
        on=["image_path", "target_name"],
        how="left",
    )[["sample_id", "target"]]
    n_missing = int(submission["target"].isna().sum())
    if n_missing:
        print(f"Warning: {n_missing} submission rows had no matching prediction; filling with 0.0")
    submission["target"] = np.nan_to_num(submission["target"], nan=0.0, posinf=0.0, neginf=0.0)
    submission.to_csv(submission_path, index=False)
    print(f"Saved submission to: {submission_path}")
    try:
        display(submission.head())  # noqa: F821 - provided by IPython/Jupyter
    except NameError:
        # Running as a plain Python script: IPython's display() is not defined.
        print(submission.head())
    return submission

In [None]:
if __name__ == "__main__":
    # The test CSV is long-format: one row per (image, target_name) pair.
    test_long = pd.read_csv(TEST_CSV_PATH)
    required_cols = {"sample_id", "image_path", "target_name"}
    missing = required_cols - set(test_long.columns)
    if missing:
        raise ValueError(f"Missing required columns in test CSV: {missing}")

    # The model predicts all targets from one image, so infer once per unique image.
    test_unique = test_long.drop_duplicates(subset=["image_path"]).reset_index(drop=True)

    dataset = BiomassPatchDataset(
        test_unique,
        image_dir=TEST_IMAGE_DIR,
        patch_count=PATCH_COUNT,
        image_size=IMAGE_SIZE,
        augment_cfg={"horizontal_flip": False, "color_jitter": False},
        is_train=False,
    )
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,  # order must stay aligned with test_unique for create_submission
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn,
        # Pinned host memory only helps host->GPU copies; PyTorch warns when it is requested without an accelerator.
        pin_memory=device.type == "cuda",
    )

    preds_3 = predict_ensemble(loader, CHECKPOINTS, backbone=BACKBONE, patch_count=PATCH_COUNT)
    final_5 = expand_to_five_targets(preds_3)
    create_submission(final_5, test_long, test_unique, SUBMISSION_FILE)