---

# Methodology: CSIRO Pasture Biomass Prediction

## 1. Core Strategy: Predicting Key Components

The primary goal is to predict five biomass targets. Based on exploratory data analysis (EDA), we identified linear dependencies:
* `Dry_Total_g` $\approx$ `Dry_Green_g` + `Dry_Dead_g` + `Dry_Clover_g`
* `GDM_g` $\approx$ `Dry_Green_g` + `Dry_Clover_g`

To avoid redundancy, the model is trained to predict only the **three most visually distinct and/or highest-weighted targets**:
* `Dry_Total_g` (50% of the score)
* `GDM_g` (20% of the score)
* `Dry_Green_g` (10% of the score)

The remaining two targets (`Dry_Dead_g` and `Dry_Clover_g`) are then **calculated during validation and inference** by subtraction, clipped at zero: `pred_Dead = max(0, pred_Total - pred_GDM)` and `pred_Clover = max(0, pred_GDM - pred_Green)`.
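
The subtraction step can be sketched as follows. This is a minimal illustration, assuming the three predictions arrive as NumPy-compatible arrays; the helper name `derive_remaining_targets` and the sample values are hypothetical and not taken from the training code.

```python
import numpy as np

def derive_remaining_targets(pred_total, pred_gdm, pred_green):
    """Recover Dry_Dead_g and Dry_Clover_g from the three predicted targets.

    Uses the linear dependencies found in EDA:
      Dry_Total_g ~ Dry_Green_g + Dry_Dead_g + Dry_Clover_g
      GDM_g       ~ Dry_Green_g + Dry_Clover_g
    """
    pred_total = np.asarray(pred_total, dtype=float)
    pred_gdm = np.asarray(pred_gdm, dtype=float)
    pred_green = np.asarray(pred_green, dtype=float)

    # Clip at zero: a biomass weight cannot be negative.
    pred_dead = np.maximum(0.0, pred_total - pred_gdm)
    pred_clover = np.maximum(0.0, pred_gdm - pred_green)
    return pred_dead, pred_clover

# Illustrative values only (two images); the second row shows the clipping case.
dead, clover = derive_remaining_targets(
    pred_total=[60.0, 10.0],
    pred_gdm=[35.0, 12.0],
    pred_green=[33.0, 5.0],
)
print(dead, clover)
```

The clipping matters when independently predicted targets are inconsistent, for example when the predicted `GDM_g` exceeds the predicted `Dry_Total_g`.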

---

## 2. Data Handling & K-Fold Strategy

* **Image Input:** All source images are high-resolution (`2000x1000` pixels).
* **Two-Stream Processing:** To preserve fine-grained details (like clover leaves) that would be lost by resizing the entire image, the `Dataset` class crops each image into two `1000x1000` patches (a "left" and "right" half).
* **High-Resolution Input:** Each `1000x1000` patch is then resized to **`768x768`**, maintaining a high level of detail.
* **K-Fold Strategy:** We use a **5-Fold Cross-Validation** strategy due to the small dataset (357 images).
* **Robust Splitting (GroupKFold):** To prevent data leakage (where similar images from the same day are in both train and validation), we use `GroupKFold` grouped by `Sampling_Date`. This ensures the model is validated on dates it has never seen.
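
A grouped split of this kind might look like the sketch below. It uses scikit-learn's `GroupKFold`; the `sampling_dates` array is synthetic stand-in data (in practice the groups would come from the `Sampling_Date` column), and the image count is a placeholder.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Synthetic stand-in: 40 images spread over 10 sampling dates.
n_images = 40
sampling_dates = np.arange(n_images) % 10
X = np.zeros((n_images, 1))  # features are irrelevant for the split itself

gkf = GroupKFold(n_splits=5)
fold_of_image = np.full(n_images, -1)
splits = []

for fold, (train_idx, val_idx) in enumerate(gkf.split(X, groups=sampling_dates)):
    # Every sampling date lands entirely in train or entirely in validation.
    train_dates = set(sampling_dates[train_idx])
    val_dates = set(sampling_dates[val_idx])
    assert train_dates.isdisjoint(val_dates)
    fold_of_image[val_idx] = fold
    splits.append((train_idx, val_idx))

print(fold_of_image)
```

Because whole dates are held out, the validation score reflects performance on unseen sampling days rather than on near-duplicate images.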

---

## 3. Model Architecture: Two-Stream, Multi-Head

The model uses a "Two-Stream, Multi-Head" architecture.
* **Shared Backbone:** A single `timm` backbone (e.g., `convnext_tiny`) with pre-trained ImageNet weights is used.
* **Two-Stream Input:**
    * `img_left` $\rightarrow$ `backbone` $\rightarrow$ `features_left`
    * `img_right` $\rightarrow$ (same) `backbone` $\rightarrow$ `features_right`
* **Fusion:** The two feature vectors are concatenated along the feature dimension: `combined_features = torch.cat([features_left, features_right], dim=1)`.
* **Multi-Head Output:** This combined vector is fed into **three separate, specialized MLP heads** (one for each target: `head_total`, `head_gdm`, `head_green`) to allow for task specialization.
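
The architecture described above can be sketched in PyTorch as follows. This is an illustration under stated assumptions: the tiny convolutional `toy_backbone` stands in for the `timm` backbone so that the example runs without pretrained weights, and the class name `TwoStreamMultiHead` and the hidden size are hypothetical.

```python
import torch
import torch.nn as nn

class TwoStreamMultiHead(nn.Module):
    """Shared backbone applied to both patches, followed by three MLP heads."""

    def __init__(self, backbone: nn.Module, feat_dim: int, hidden: int = 64):
        super().__init__()
        self.backbone = backbone  # the same weights process both streams

        def make_head() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(2 * feat_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 1),
            )

        self.head_total = make_head()
        self.head_gdm = make_head()
        self.head_green = make_head()

    def forward(self, img_left: torch.Tensor, img_right: torch.Tensor):
        features_left = self.backbone(img_left)
        features_right = self.backbone(img_right)
        # Concatenate along the feature dimension: (B, F) + (B, F) -> (B, 2F)
        combined = torch.cat([features_left, features_right], dim=1)
        return self.head_total(combined), self.head_gdm(combined), self.head_green(combined)

# Stand-in backbone (assumption): maps (B, 3, H, W) to an 8-dim feature vector.
toy_backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

model = TwoStreamMultiHead(toy_backbone, feat_dim=8)
left = torch.randn(2, 3, 64, 64)
right = torch.randn(2, 3, 64, 64)
total, gdm, green = model(left, right)
print(total.shape, gdm.shape, green.shape)
```

Sharing one backbone keeps the parameter count the same as a single-stream model, while the separate heads let each target learn its own mapping from the fused features.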

---

## 4. Data Augmentation

To compensate for the small dataset, augmentations are applied **independently** to the `img_left` and `img_right` patches.
* `HorizontalFlip (p=0.5)`
* `VerticalFlip (p=0.5)`
* `RandomRotate90 (p=0.5)` (Only 90-degree rotations)
* `ColorJitter`

This independent application creates a much larger variety of training combinations.
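
To illustrate what "independently" means here, the sketch below applies random flips and 90-degree rotations to each half separately. It is a plain NumPy stand-in for the augmentation library transforms listed above, not the actual pipeline; `ColorJitter` is omitted, and the helper name `random_flip_rot` and the image size are assumptions for the example.

```python
import numpy as np

def random_flip_rot(patch: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Apply flips and a 90-degree rotation, each with probability 0.5."""
    if rng.random() < 0.5:
        patch = patch[:, ::-1]          # horizontal flip
    if rng.random() < 0.5:
        patch = patch[::-1, :]          # vertical flip
    if rng.random() < 0.5:
        k = int(rng.integers(1, 4))     # rotate by 90, 180 or 270 degrees
        patch = np.rot90(patch, k=k)
    return np.ascontiguousarray(patch)

rng = np.random.default_rng(0)

# Small synthetic image with the same 2:1 aspect ratio as the source images.
image = rng.integers(0, 256, size=(100, 200, 3), dtype=np.uint8)
mid = image.shape[1] // 2
left, right = image[:, :mid], image[:, mid:]

# Separate random draws per patch: the two halves usually get different transforms.
aug_left = random_flip_rot(left, rng)
aug_right = random_flip_rot(right, rng)
print(aug_left.shape, aug_right.shape)
```

Because each half draws its own random transform, one source image can yield many distinct (left, right) pairs across epochs.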

---

## 5. Loss Function: Weighted SmoothL1Loss

The model is optimized using a custom weighted loss function that aligns with the competition's scoring metric.
* **Base Loss:** `nn.SmoothL1Loss` (a Huber-style loss) is used instead of `MSELoss`: it is quadratic for small errors and linear for large ones, which makes training more stable and less sensitive to outliers.
* **Weighted Sum:** The final loss is a weighted sum of the individual losses, using the competition's scoring weights:
    $$Loss = (0.5 \cdot Loss_{Total}) + (0.2 \cdot Loss_{GDM}) + (0.1 \cdot Loss_{Green})$$
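
The weighted sum above could be implemented along these lines. This is a minimal sketch, assuming each head outputs a `(B, 1)` tensor and that `nn.SmoothL1Loss` is used with its default settings; the class name `WeightedSmoothL1` and the sample tensors are illustrative.

```python
import torch
import torch.nn as nn

class WeightedSmoothL1(nn.Module):
    """Weighted sum of per-target SmoothL1 losses, using the scoring weights."""

    def __init__(self, w_total: float = 0.5, w_gdm: float = 0.2, w_green: float = 0.1):
        super().__init__()
        self.base = nn.SmoothL1Loss()  # default beta=1.0, mean reduction
        self.w_total = w_total
        self.w_gdm = w_gdm
        self.w_green = w_green

    def forward(self, pred_total, pred_gdm, pred_green, y_total, y_gdm, y_green):
        return (
            self.w_total * self.base(pred_total, y_total)
            + self.w_gdm * self.base(pred_gdm, y_gdm)
            + self.w_green * self.base(pred_green, y_green)
        )

criterion = WeightedSmoothL1()

# Illustrative batch of one sample: errors of 2.0, 0.5 and 0.0 respectively.
y = torch.zeros(1, 1)
loss = criterion(
    pred_total=torch.tensor([[2.0]]),
    pred_gdm=torch.tensor([[0.5]]),
    pred_green=torch.tensor([[0.0]]),
    y_total=y, y_gdm=y, y_green=y,
)
print(loss.item())
```

With the default `beta=1.0`, an error of 2.0 falls in the linear regime (loss 1.5) and an error of 0.5 in the quadratic regime (loss 0.125), so the example evaluates to 0.5 · 1.5 + 0.2 · 0.125 = 0.775.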

---

## 6. Training Strategy: Two-Stage Fine-Tuning

A two-stage "Freeze/Unfreeze" strategy is used to stabilize training on the small dataset.
* **Stage 1 (Freeze):**
    * **Epochs:** 1-5
    * **Action:** The entire `backbone` is frozen. Only the three MLP heads are trained.
    * **LR:** `1e-4`
* **Stage 2 (Unfreeze/Fine-Tuning):**
    * **Epochs:** 6-20
    * **Action:** The `backbone` is "unfrozen," and the entire model is trained.
    * **LR:** A very low learning rate (`1e-5`) is used to slowly adapt the backbone features.
* **Model Saving:** A `ModelCheckpoint` saves the model based on the **highest `Score (R^2)`** on the validation set, *not* the lowest loss. This is critical for capturing the model's peak performance (like the `R^2=0.64` spike at Epoch 11) and ignoring the unstable, overfitted epochs.
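
The freeze/unfreeze schedule and score-based checkpointing can be sketched as below. Several parts are assumptions made for illustration: `ToyModel` stands in for the real network, `AdamW` is an assumed optimizer choice (the text does not name one), and `BestScoreKeeper` is a hypothetical in-memory substitute for the `ModelCheckpoint` callback.

```python
import copy
import torch
import torch.nn as nn

FREEZE_EPOCHS = 5    # Stage 1 covers epochs 1-5
TOTAL_EPOCHS = 20    # Stage 2 covers epochs 6-20

class ToyModel(nn.Module):
    """Stand-in with a `backbone` and one head; not the real architecture."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 1)

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))

def configure_stage(model: nn.Module, epoch: int) -> torch.optim.Optimizer:
    """Freeze or unfreeze the backbone for a 1-based epoch and build the optimizer."""
    frozen = epoch <= FREEZE_EPOCHS
    for p in model.backbone.parameters():
        p.requires_grad = not frozen
    lr = 1e-4 if frozen else 1e-5
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=lr)  # optimizer type is an assumption

class BestScoreKeeper:
    """Keep the weights from the epoch with the highest validation score."""

    def __init__(self):
        self.best_score = float("-inf")
        self.best_epoch = None
        self.best_state = None

    def update(self, model: nn.Module, score: float, epoch: int) -> bool:
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            self.best_state = copy.deepcopy(model.state_dict())
            return True
        return False

model = ToyModel()

# Stage 1: only the head is trainable.
stage1_opt = configure_stage(model, epoch=1)
backbone_trainable_stage1 = [p.requires_grad for p in model.backbone.parameters()]

# Stage 2: the whole model is trainable, at a lower learning rate.
stage2_opt = configure_stage(model, epoch=FREEZE_EPOCHS + 1)
backbone_trainable_stage2 = [p.requires_grad for p in model.backbone.parameters()]

print(backbone_trainable_stage1, backbone_trainable_stage2)
```

In a training loop, `configure_stage` would be called once at epoch 1 and once at epoch 6, and `BestScoreKeeper.update` after each validation pass with that epoch's R² score, so a later unstable epoch cannot overwrite the best weights.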

# ============================================================
# CSIRO Image2Biomass: Dual-Ensemble Inference (A×W_A + B×W_B)
# - Each ensemble is a 5-fold ensemble + simple TTA (original / horizontal flip / vertical flip)
# - Automatically extracts the student.* subtree from the training weights (prefix removed) and drops the training-only multimodal parts
# - Detects the variant (tiled_film / tiled / plain) from whether the weights contain film_left/right
# - Enumerates DINO backbone names (same as training); a model is used only if it matches strictly
# - final = W_A*final_A + W_B*final_B -> submission.csv (weights set in CFG: 0.965 / 0.035)
# ============================================================

import os
import gc
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm

# =============================== Config =======================================
class CFG:
    # Data paths (Kaggle defaults)
    BASE_PATH = "/kaggle/input/csiro-biomass"
    TEST_CSV = os.path.join(BASE_PATH, "test.csv")
    TEST_IMAGE_DIR = os.path.join(BASE_PATH, "test")

    # -- Ensemble A (first script) --
    MODEL_DIR_A = "/kaggle/input/m/gothamjocker/csiro/pytorch/default/6"   # ← change to your dataset path
    CKPTS_A = [
        os.path.join(MODEL_DIR_A, "tiled_film_best_model_fold1.pth"),
        os.path.join(MODEL_DIR_A, "tiled_film_best_model_fold2.pth"),
        os.path.join(MODEL_DIR_A, "tiled_film_best_model_fold3.pth"),
        os.path.join(MODEL_DIR_A, "tiled_film_best_model_fold4.pth"),
        os.path.join(MODEL_DIR_A, "tiled_film_best_model_fold5.pth"),
    ]

    # -- Ensemble B (second script) --
    MODEL_DIR_B = "/kaggle/input/m/gothamjocker/csiro/pytorch/default/11"  # ← change to your dataset path
    CKPTS_B = [
        os.path.join(MODEL_DIR_B, "tiled_film_best_model_fold1.pth"),
        os.path.join(MODEL_DIR_B, "tiled_film_best_model_fold2.pth"),
        os.path.join(MODEL_DIR_B, "tiled_film_best_model_fold3.pth"),
        os.path.join(MODEL_DIR_B, "tiled_film_best_model_fold4.pth"),
        os.path.join(MODEL_DIR_B, "tiled_film_best_model_fold5.pth"),
    ]

    # Blend weights for the two ensembles
    W_A = 0.965
    W_B = 0.035

    SUBMISSION_FILE = "submission.csv"
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    BATCH_SIZE = 1
    NUM_WORKERS = 0

    # Must match training
    DROPOUT = 0.30
    HIDDEN_RATIO = 0.25
    GRID = (2, 2)  # tile grid for tiled / tiled_film (non-tiled weights are detected automatically)

    # Output column order (must match training)
    ALL_TARGET_COLS = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]

    # DINO candidates used in training (highest priority first)
    DINO_CANDIDATES = [
        "vit_base_patch14_dinov3",
        "vit_base_patch14_reg4_dinov3",
        "vit_small_patch14_dinov3",
        "vit_base_patch14_reg4_dinov2",
        "vit_base_patch14_dinov2",
        "vit_small_patch14_dinov2",
    ]

print(f"Device: {CFG.DEVICE}")

# =============================== Dataset (left/right streams, same as training) =================
class TestBiomassDataset(Dataset):
    def __init__(self, df, transform, image_dir):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.image_dir = image_dir
        self.paths = self.df["image_path"].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        filename = os.path.basename(self.paths[idx])
        full_path = os.path.join(self.image_dir, filename)
        img = cv2.imread(full_path)
        if img is None:
            # Fallback: if the image cannot be read, use a black placeholder
            img = np.zeros((1000, 2000, 3), dtype=np.uint8)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Split into left and right halves (same strategy as training/validation)
        h, w, _ = img.shape
        mid = w // 2
        left = img[:, :mid]
        right = img[:, mid:]

        left_t = self.transform(image=left)["image"]
        right_t = self.transform(image=right)["image"]
        return left_t, right_t

# =============================== Build DINO backbone (inference side) ======================
def _infer_input_res(m) -> int:
    if hasattr(m, "patch_embed") and hasattr(m.patch_embed, "img_size"):
        isz = m.patch_embed.img_size
        return int(isz if isinstance(isz, (int, float)) else isz[0])
    if hasattr(m, "img_size"):
        isz = m.img_size
        return int(isz if isinstance(isz, (int, float)) else isz[0])
    dc = getattr(m, "default_cfg", {}) or {}
    ins = dc.get("input_size", None)
    if ins:
        if isinstance(ins, (tuple, list)) and len(ins) >= 2:
            return int(ins[1])
        return int(ins if isinstance(ins, (int, float)) else 224)
    name = getattr(m, "default_cfg", {}).get("architecture", "") or str(type(m))
    return 518 if ("dinov2" in name.lower()) else 224

def _build_dino_by_name(name: str):
    # No pretrained download is needed at inference: use pretrained=False, since the checkpoint overwrites all weights
    m = timm.create_model(name, pretrained=False, num_classes=0)
    feat = m.num_features
    input_res = _infer_input_res(m)
    return m, feat, input_res

# =============================== Models (strictly aligned with training) ========================
class TwoStreamDINOBase(nn.Module):
    def __init__(self, backbone_name: str, dropout: float = 0.3, hidden_ratio: float = 0.25):
        super().__init__()
        self.backbone, feat, input_res = _build_dino_by_name(backbone_name)
        self.used_backbone_name = backbone_name
        self.input_res = int(input_res)
        self.feat_dim = feat
        self.combined = feat * 2

        hidden = max(8, int(self.combined * hidden_ratio))

        def head():
            return nn.Sequential(
                nn.Linear(self.combined, hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
                nn.Linear(hidden, 1),
            )

        self.head_green = head()
        self.head_clover = head()
        self.head_dead = head()
        self.softplus = nn.Softplus(beta=1.0)

    def _merge_heads(self, f_l: torch.Tensor, f_r: torch.Tensor):
        f = torch.cat([f_l, f_r], dim=1)
        green_pos = self.softplus(self.head_green(f))
        clover_pos = self.softplus(self.head_clover(f))
        dead_pos = self.softplus(self.head_dead(f))
        gdm = green_pos + clover_pos
        total = gdm + dead_pos
        return total, gdm, green_pos

class TwoStreamDINOPlain(TwoStreamDINOBase):
    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor):
        f_l = self.backbone(x_left)
        f_r = self.backbone(x_right)
        return self._merge_heads(f_l, f_r)

def _make_edges(L: int, parts: int):
    step = L // parts
    edges = []
    start = 0
    for _ in range(parts - 1):
        edges.append((start, start + step))
        start += step
    edges.append((start, L))
    return edges

class TwoStreamDINOTiled(TwoStreamDINOBase):
    def __init__(self, backbone_name: str, grid=(2, 2), **kwargs):
        super().__init__(backbone_name, **kwargs)
        self.grid = tuple(grid)

    def _encode_tiles(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        r, c = self.grid
        rows = _make_edges(H, r)
        cols = _make_edges(W, c)
        feats = []
        for (rs, re) in rows:
            for (cs, ce) in cols:
                xt = x[:, :, rs:re, cs:ce]
                if xt.shape[-2:] != (self.input_res, self.input_res):
                    xt = F.interpolate(xt, size=(self.input_res, self.input_res), mode="bilinear", align_corners=False)
                ft = self.backbone(xt)
                feats.append(ft)
        feats = torch.stack(feats, dim=0).permute(1, 0, 2)  # (B, T, F)
        feat_stream = feats.mean(dim=1)  # (B, F)
        return feat_stream

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor):
        f_l = self._encode_tiles(x_left)
        f_r = self._encode_tiles(x_right)
        return self._merge_heads(f_l, f_r)

class FiLM(nn.Module):
    def __init__(self, in_dim: int):
        super().__init__()
        hid = max(32, in_dim // 2)
        self.mlp = nn.Sequential(nn.Linear(in_dim, hid), nn.ReLU(inplace=True), nn.Linear(hid, in_dim * 2))

    def forward(self, context: torch.Tensor):
        gb = self.mlp(context)
        gamma, beta = torch.chunk(gb, 2, dim=1)
        return gamma, beta

class TwoStreamDINOTiledFiLM(TwoStreamDINOBase):
    def __init__(self, backbone_name: str, grid=(2, 2), **kwargs):
        super().__init__(backbone_name, **kwargs)
        self.grid = tuple(grid)
        self.film_left = FiLM(self.feat_dim)
        self.film_right = FiLM(self.feat_dim)

    def _tiles_backbone(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        r, c = self.grid
        rows = _make_edges(H, r)
        cols = _make_edges(W, c)
        feats = []
        for (rs, re) in rows:
            for (cs, ce) in cols:
                xt = x[:, :, rs:re, cs:ce]
                if xt.shape[-2:] != (self.input_res, self.input_res):
                    xt = F.interpolate(xt, size=(self.input_res, self.input_res), mode="bilinear", align_corners=False)
                ft = self.backbone(xt)
                feats.append(ft)
        feats = torch.stack(feats, dim=0).permute(1, 0, 2)  # (B, T, F)
        return feats

    def _encode_stream(self, x: torch.Tensor, film: FiLM) -> torch.Tensor:
        tiles = self._tiles_backbone(x)  # (B, T, F)
        context = tiles.mean(dim=1)      # (B, F)
        gamma, beta = film(context)      # (B, F)
        tiles = tiles * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1)
        feat_stream = tiles.mean(dim=1)  # (B, F)
        return feat_stream

    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor):
        f_l = self._encode_stream(x_left, self.film_left)
        f_r = self._encode_stream(x_right, self.film_right)
        return self._merge_heads(f_l, f_r)

# =============================== Weight cleaning and loading =================================
def _strip_module_prefix(sd: dict):
    if not sd:
        return sd
    keys = list(sd.keys())
    if all(k.startswith("module.") for k in keys):
        return {k[len("module.") :]: v for k, v in sd.items()}
    return sd

def _extract_student_substate(sd: dict) -> dict:
    """
    Training saves the state_dict of MultiModalStudentTeacher.
    This function strips it down:
      - first remove the 'module.' prefix
      - if a 'student.' prefix exists, keep only that subtree and drop the 'student.' prefix
      - discard the training-only multimodal/distillation layers (txt_enc.*, img_proj.*, txt_film_left.*, txt_film_right.*)
    """
    sd = _strip_module_prefix(sd)
    has_student = any(k.startswith("student.") for k in sd.keys())
    if has_student:
        sd = {k[len("student.") :]: v for k, v in sd.items() if k.startswith("student.")}
    drop_prefixes = ("txt_enc.", "img_proj.", "txt_film_left.", "txt_film_right.")
    sd = {k: v for k, v in sd.items() if not k.startswith(drop_prefixes)}
    return sd

def _has_film(sd_keys: set) -> bool:
    return any(k.startswith("film_left.mlp.") for k in sd_keys) or any(k.startswith("film_right.mlp.") for k in sd_keys)

def _try_build_and_load(sd: dict, backbone_name: str, variant: str, grid=(2, 2)):
    if variant == "tiled_film":
        model = TwoStreamDINOTiledFiLM(backbone_name, grid=grid, dropout=CFG.DROPOUT, hidden_ratio=CFG.HIDDEN_RATIO)
    elif variant == "tiled":
        model = TwoStreamDINOTiled(backbone_name, grid=grid, dropout=CFG.DROPOUT, hidden_ratio=CFG.HIDDEN_RATIO)
    else:
        model = TwoStreamDINOPlain(backbone_name, dropout=CFG.DROPOUT, hidden_ratio=CFG.HIDDEN_RATIO)

    # Strict match: load non-strictly, then require that both missing and unexpected keys are empty
    result = model.load_state_dict(sd, strict=False)
    missing = getattr(result, "missing_keys", [])
    unexpected = getattr(result, "unexpected_keys", [])
    if len(missing) == 0 and len(unexpected) == 0:
        model.to(CFG.DEVICE)
        model.eval()
        return model
    return None

def load_fold_model_auto(path: str, grid=(2, 2)):
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    # Handle the PyTorch 2.6 weights_only change; fall back for versions that lack the argument
    try:
        raw_sd = torch.load(path, map_location=CFG.DEVICE, weights_only=True)
    except TypeError:
        raw_sd = torch.load(path, map_location=CFG.DEVICE)
    sd = _extract_student_substate(raw_sd)
    keys = set(sd.keys())

    # Detect whether the checkpoint uses FiLM
    is_film = _has_film(keys)
    variant_order = ["tiled_film"] if is_film else ["tiled", "plain"]

    # Enumerate variants and backbones until one matches strictly
    for variant in variant_order:
        for backbone in CFG.DINO_CANDIDATES:
            try:
                m = _try_build_and_load(sd, backbone, variant, grid=grid)
                if m is not None:
                    return m, variant, backbone
            except Exception:
                continue

    # Still no match: raise with diagnostic information
    raise RuntimeError(
        f"Could not find a matching variant/backbone for {os.path.basename(path)}."
        f" Detected {'tiled_film' if is_film else 'non-film'} weights; check train/inference consistency."
    )

# =============================== TTA transforms (sized by the model's input_res) ===============
def get_tta_transforms(img_size: int):
    base = [A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2()]
    original = A.Compose([A.Resize(img_size, img_size, interpolation=cv2.INTER_AREA), *base])
    hflip = A.Compose([A.HorizontalFlip(p=1.0), A.Resize(img_size, img_size, interpolation=cv2.INTER_AREA), *base])
    vflip = A.Compose([A.VerticalFlip(p=1.0), A.Resize(img_size, img_size, interpolation=cv2.INTER_AREA), *base])
    return [original, hflip, vflip]

# =============================== Inference (always outputs 5 targets) =======================
@torch.no_grad()
def predict_one_view(models, loader):
    out_list = []
    # Device type string for autocast ("cuda" or "cpu"); this is not a dtype
    amp_device_type = "cuda" if CFG.DEVICE.type == "cuda" else "cpu"

    for (xl, xr) in tqdm(loader, desc="  Predicting View", leave=False):
        xl = xl.to(CFG.DEVICE, non_blocking=True)
        xr = xr.to(CFG.DEVICE, non_blocking=True)

        per_model_preds = []
        with torch.amp.autocast(amp_device_type, enabled=(CFG.DEVICE.type == "cuda")):
            for m in models:
                total, gdm, green = m(xl, xr)
                dead = total - gdm
                clover = gdm - green
                five = torch.cat([green, dead, clover, gdm, total], dim=1)
                five = torch.clamp(five, min=0.0)  # non-negativity constraint
                per_model_preds.append(five.float().cpu())

        stacked = torch.mean(torch.stack(per_model_preds, dim=0), dim=0)
        out_list.append(stacked.numpy())

    return np.concatenate(out_list, axis=0)

def run_inference_for_ckpts(ckpt_list, test_unique, image_dir):
    print("\n================= Loading models (5 folds) =================")
    models = []

    # Use the first fold to determine the input resolution
    m1, v1, b1 = load_fold_model_auto(ckpt_list[0], grid=CFG.GRID)
    models.append(m1)
    backbone_res = int(getattr(m1, "input_res", 518))
    print(f"fold1 => variant={v1}, backbone={b1}, input_res={backbone_res}")

    for p in ckpt_list[1:]:
        m, v, b = load_fold_model_auto(p, grid=CFG.GRID)
        print(f"{os.path.basename(p)} => variant={v}, backbone={b}, input_res={getattr(m, 'input_res', '?')}")
        models.append(m)

    # Prepare the TTA views
    tta_trans = get_tta_transforms(backbone_res)
    per_view_preds = []
    for i, t in enumerate(tta_trans):
        print(f"\n--- TTA view {i+1}/{len(tta_trans)} (resize={backbone_res}) ---")
        ds = TestBiomassDataset(test_unique, t, image_dir)
        dl = DataLoader(ds, batch_size=CFG.BATCH_SIZE, shuffle=False, num_workers=CFG.NUM_WORKERS, pin_memory=True)
        view_5 = predict_one_view(models, dl)  # [N,5]
        per_view_preds.append(view_5)

    # Average over the TTA views
    final_5 = np.mean(per_view_preds, axis=0)  # [N,5]: [green, dead, clover, gdm, total]
    return final_5

def run_dual_ensembles_and_fuse():
    print("\n================= Loading test data =================")
    test_long = pd.read_csv(CFG.TEST_CSV)
    test_unique = test_long.drop_duplicates(subset=["image_path"]).reset_index(drop=True)
    print(f"Found {len(test_unique)} unique test images.")

    print("\n================= Ensemble A inference =================")
    final_a = run_inference_for_ckpts(CFG.CKPTS_A, test_unique, CFG.TEST_IMAGE_DIR)

    print("\n================= Ensemble B inference =================")
    final_b = run_inference_for_ckpts(CFG.CKPTS_B, test_unique, CFG.TEST_IMAGE_DIR)

    assert final_a.shape == final_b.shape, "Ensemble A/B outputs differ in shape; cannot blend them."
    final = CFG.W_A * final_a + CFG.W_B * final_b
    return final, test_long, test_unique

# =============================== Build submission ======================================
def create_submission(final_5, test_long, test_unique):
    green = final_5[:, 0]
    dead = final_5[:, 1]
    clover = final_5[:, 2]
    gdm = final_5[:, 3]
    total = final_5[:, 4]

    # Final non-negativity clip and NaN/Inf handling
    def nnz(x):
        x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        return np.maximum(0, x)

    green, dead, clover, gdm, total = map(nnz, [green, dead, clover, gdm, total])

    wide = pd.DataFrame(
        {
            "image_path": test_unique["image_path"],
            "Dry_Green_g": green,
            "Dry_Dead_g": dead,
            "Dry_Clover_g": clover,
            "GDM_g": gdm,
            "Dry_Total_g": total,
        }
    )

    long_preds = wide.melt(
        id_vars=["image_path"],
        value_vars=CFG.ALL_TARGET_COLS,
        var_name="target_name",
        value_name="target",
    )

    sub = pd.merge(
        test_long[["sample_id", "image_path", "target_name"]],
        long_preds,
        on=["image_path", "target_name"],
        how="left",
    )[["sample_id", "target"]]

    sub["target"] = np.nan_to_num(sub["target"], nan=0.0, posinf=0.0, neginf=0.0)
    sub.to_csv(CFG.SUBMISSION_FILE, index=False)
    print(f"\n🎉 Submission file written: {CFG.SUBMISSION_FILE}")
    print(sub.head())
    return sub

# =============================== Entry point =========================================
if __name__ == "__main__":
    final_5, df_long, df_unique = run_dual_ensembles_and_fuse()
    _ = create_submission(final_5, df_long, df_unique)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()




Device: cuda

Found 1 unique test images.


fold1 => variant=tiled_film, backbone=vit_base_patch14_reg4_dinov2, input_res=518
tiled_film_best_model_fold2.pth => variant=tiled_film, backbone=vit_base_patch14_reg4_dinov2, input_res=518
tiled_film_best_model_fold3.pth => variant=tiled_film, backbone=vit_base_patch14_reg4_dinov2, input_res=518
tiled_film_best_model_fold4.pth => variant=tiled_film, backbone=vit_base_patch14_reg4_dinov2, input_res=518
tiled_film_best_model_fold5.pth => variant=tiled_film, backbone=vit_base_patch14_reg4_dinov2, input_res=518

--- TTA view 1/3 (resize=518) ---

--- TTA view 2/3 (resize=518) ---

--- TTA view 3/3 (resize=518) ---

fold1 => variant=tiled_film, backbone=vit_base_patch14_reg4_dinov2, input_res=518
tiled_film_best_model_fold2.pth => variant=tiled_film, backbone=vit_base_patch14_reg4_dinov2, input_res=518
tiled_film_best_model_fold3.pth => variant=tiled_film, backbone=vit_base_patch14_reg4_dinov2, input_res=518
tiled_film_best_model_fold4.pth => variant=tiled_film, backbone=vit_base_patch14_reg4_dinov2, input_res=518
tiled_film_best_model_fold5.pth => variant=tiled_film, backbone=vit_base_patch14_reg4_dinov2, input_res=518

--- TTA view 1/3 (resize=518) ---

--- TTA view 2/3 (resize=518) ---

--- TTA view 3/3 (resize=518) ---

🎉 Submission file written: submission.csv
                    sample_id     target
0  ID1001187975__Dry_Clover_g   1.782589
1    ID1001187975__Dry_Dead_g  26.787958
2   ID1001187975__Dry_Green_g  32.905235
3   ID1001187975__Dry_Total_g  61.475784
4         ID1001187975__GDM_g  34.687828
