<div style="background-color:#F0F8FF; border-radius:15px; border:1px solid #4682B4; padding: 25px;">
    <div style="text-align:center;">
        <h1 style="color:#2E8B57; font-family:sans-serif; font-weight:800; font-size:32px;">🌿 CSIRO Biomass: SOTA Training (Leakage Fixed)</h1>
        <h3 style="color:#4682B4; font-family:sans-serif; font-weight:600;">ViT-Large (DINOv2) + High-Res Tiling + Multi-Task Learning</h3>
    </div>
    <hr style="border-top: 1px solid #4682B4; margin: 20px 0;">
    <div style="font-family:sans-serif; color:#333; line-height:1.6;">
        <p><b>🏆 Strategy Update:</b> Previous iterations scored high locally (0.82) but low on the leaderboard (LB: 0.40) due to <b>Data Leakage</b> (using ground-truth Height as input). This notebook fixes that logic.</p>
        <p><b>🔑 Key Methodologies:</b></p>
        <ul>
            <li><b>Backbone:</b> <code>vit_large_patch14_dinov2</code> (State-of-the-art texture encoder).</li>
            <li><b>Resolution Tiling:</b> Images are resized to <b>1036x1036</b> and split into four <b>518x518</b> tiles to match DINOv2's native resolution.</li>
            <li><b>Multi-Task Learning:</b> We predict <b>Biomass</b> (Main Target) AND <b>Height/NDVI</b> (Auxiliary Targets). Predicting Height forces the model to learn 3D geometry from pixels alone.</li>
            <li><b>No Leakage:</b> Metadata is <b>NOT</b> fed as input. The model relies 100% on the image.</li>
        </ul>
    </div>
</div>

## Configuration & Imports

In [None]:
# Libraries
import os
import gc
import sys
import random
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image

# Image Processing
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger

# Models & Metrics
import timm
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import RobustScaler, MinMaxScaler
from sklearn.metrics import r2_score

# Config
class CFG:
    """Notebook-wide configuration: paths, model name, tiling, hyperparameters and targets.

    All values are read as class attributes (``CFG.X``); the code in this notebook
    never instantiates the class.
    """
    # Paths (Kaggle dataset mount points for images and label/metadata CSVs)
    TRAIN_DIR = "/kaggle/input/csiro-biomass/train"
    TEST_DIR = "/kaggle/input/csiro-biomass/test"
    TRAIN_CSV = "/kaggle/input/csiro-biomass/train.csv"
    TEST_CSV = "/kaggle/input/csiro-biomass/test.csv"
    
    # Model Architecture (timm model identifier passed to timm.create_model)
    MODEL_NAME = "vit_large_patch14_dinov2.lvd142m" 
    
    # Tiling Strategy
    # DINOv2 Native Res is 518px. We split image into 4 tiles of 518px (total effective resolution: 1036x1036)
    TILE_SIZE = 518       
    # Informational: the tiling code shown in this notebook does not read this value.
    TILES_PER_IMG = 4   
    # Side length every image is resized to before it is cut into tiles.
    IMG_LOAD_SIZE = 1036
    
    # Hyperparameters
    BATCH_SIZE = 8       # Fits in 16GB VRAM
    NUM_WORKERS = 4      # DataLoader worker processes
    EPOCHS = 15          # Trainer max_epochs (early stopping may end a fold sooner)
    LR = 1e-4            # AdamW learning rate and OneCycleLR max_lr
    WEIGHT_DECAY = 1e-4  # AdamW weight decay
    GRAD_CLIP = 1.0      # Trainer gradient_clip_val
    DROPOUT = 0.1        # Dropout probability in the main regression head
    SEED = 42            # Global seed and StratifiedKFold random_state
    N_FOLDS = 5          # Number of cross-validation folds
    
    # Targets
    TARGET_COLS = ['Dry_Total_g', 'GDM_g', 'Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g'] # weighted R2
    # Per-target weights for the weighted R2; aligned position-by-position with TARGET_COLS.
    TARGET_WEIGHTS = np.array([0.5, 0.2, 0.1, 0.1, 0.1])

    # Auxiliary regression targets (predicted by the model, never fed to it as inputs).
    AUX_COLS = ['Height_Ave_cm', 'Pre_GSHH_NDVI'] 

# Silence every Python warning for the rest of the notebook (this also hides deprecation notices).
warnings.filterwarnings('ignore')
# Let PyTorch use lower-precision float32 matmul kernels (trades some precision for speed).
torch.set_float32_matmul_precision('medium')
# Seed the random number generators through Lightning's helper.
pl.seed_everything(CFG.SEED)
# Report the GPU name when CUDA is available, otherwise 'CPU'.
print(f"Setup Complete. Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

## Robust Data Loading
This section auto-detects the image file extension (.jpg vs .png) from the image folder and fails early if that folder is missing or empty.

In [None]:
def check_folder_content(folder_path):
    """
    Detect the file extension used by the images stored in ``folder_path``.

    Entries are inspected in sorted order so the result is deterministic
    (``os.listdir`` returns names in arbitrary order). The extension of the first
    regular file with a common image suffix is returned, with its original case.
    If no entry looks like an image, the extension of the first non-hidden entry
    is returned, which matches the previous behaviour.

    Args:
        folder_path: Directory expected to contain the image files.

    Returns:
        The extension including the leading dot (e.g. ``".jpg"``), or ``None`` if
        the folder does not exist, is not a directory, or has no non-hidden entries.
    """
    image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}

    if not os.path.isdir(folder_path):
        return None

    # Hidden entries (e.g. .DS_Store, .ipynb_checkpoints) never count as content.
    names = sorted(f for f in os.listdir(folder_path) if not f.startswith('.'))
    if not names:
        return None

    # Prefer a real image file over stray files (CSV, README, ...) or sub-directories.
    for name in names:
        ext = os.path.splitext(name)[1]
        if ext.lower() in image_exts and os.path.isfile(os.path.join(folder_path, name)):
            return ext

    # Nothing looked like an image: fall back to the first entry's extension.
    return os.path.splitext(names[0])[1]

def load_data(csv_path, img_dir, is_train=True):
    """
    Load a long-format CSV and attach an image path to every row.

    Training data is pivoted to one row per image with one column per
    ``target_name``; a ``stratify_group`` column (State + Species) is added for
    stratified splitting. Test data is reduced to unique (image_id, full_path) pairs.

    Missing auxiliary values are imputed with the per-image median BEFORE the
    pivot. The auxiliary columns are part of the pivot index, and ``pivot_table``
    drops rows whose index keys contain NaN, so imputing after the pivot (as was
    done previously) silently discarded those images instead of filling them.

    Args:
        csv_path: Path to the CSV file; must contain a ``sample_id`` column.
        img_dir: Directory holding the image files.
        is_train: Pivot to wide training format when True, otherwise return the
            unique test images.

    Returns:
        pandas.DataFrame in wide format (train) or with columns
        ``image_id`` / ``full_path`` (test).

    Raises:
        FileNotFoundError: If no image extension can be detected in ``img_dir``.
    """
    print(f"Processing {csv_path}...")
    df = pd.read_csv(csv_path)

    # Detect Extension
    ext = check_folder_content(img_dir)
    if not ext:
        raise FileNotFoundError(f"No images found in {img_dir}")

    # Create Paths: the image id is the part of sample_id before the '__' separator
    df['image_id'] = df['sample_id'].str.split('__').str[0]
    df['full_path'] = df['image_id'].apply(lambda x: os.path.join(img_dir, f"{x}{ext}"))

    # Ensure Aux columns exist (fill with NaN if missing in Test)
    for col in CFG.AUX_COLS:
        if col not in df.columns:
            df[col] = np.nan

    if not is_train:
        # Test Data: one row per image
        return df[['image_id', 'full_path']].drop_duplicates().reset_index(drop=True)

    # Impute missing Aux data with the median over images (one value per image, so
    # images with more rows do not weigh more). Must happen before the pivot, see docstring.
    for col in CFG.AUX_COLS:
        per_image_median = df.groupby('image_id')[col].first().median()
        df[col] = df[col].fillna(per_image_median)

    # Pivot to Wide Format
    meta_cols = ['image_path', 'Sampling_Date', 'State', 'Species', 'full_path'] + CFG.AUX_COLS
    df_wide = df.pivot_table(
        index=['image_id'] + meta_cols,
        columns='target_name',
        values='target',
        aggfunc='first'
    ).reset_index()

    # Stratify Group
    df_wide['stratify_group'] = df_wide['State'].astype(str) + "_" + df_wide['Species'].astype(str)

    return df_wide

# Build the wide per-image training table and the de-duplicated test image list
train_df = load_data(csv_path=CFG.TRAIN_CSV, img_dir=CFG.TRAIN_DIR, is_train=True)
test_df = load_data(csv_path=CFG.TEST_CSV, img_dir=CFG.TEST_DIR, is_train=False)

# Quick sanity check on the resulting table sizes
print(f"\nTrain Shape: {train_df.shape} | Test Shape: {test_df.shape}")

## Scaling & Augmentations
*   **Main Targets (Biomass):** `RobustScaler` (Handles outliers/heavy tails).
*   **Aux Targets (Height/NDVI):** `MinMaxScaler` (Bounded ranges).

In [None]:
class DualProcessor:
    """
    Scales the main biomass targets and the auxiliary targets with separate scalers.

    The main targets (CFG.TARGET_COLS) go through a RobustScaler, the auxiliary
    targets (CFG.AUX_COLS) through a MinMaxScaler.
    """
    def __init__(self):
        self.main_scaler, self.aux_scaler = RobustScaler(), MinMaxScaler()

    def fit(self, df):
        """Fit both scalers on the matching column groups of ``df``."""
        column_groups = (
            (self.main_scaler, CFG.TARGET_COLS),
            (self.aux_scaler, CFG.AUX_COLS),
        )
        for scaler, cols in column_groups:
            scaler.fit(df[cols])

    def transform(self, df):
        """
        Return ``(y_main, y_aux)`` as float32 arrays of scaled values.

        Returns ``(None, None)`` when the first target column is absent from ``df``.
        """
        if CFG.TARGET_COLS[0] not in df.columns:
            return None, None
        scaled_main = self.main_scaler.transform(df[CFG.TARGET_COLS])
        scaled_aux = self.aux_scaler.transform(df[CFG.AUX_COLS])
        return scaled_main.astype(np.float32), scaled_aux.astype(np.float32)

    def inverse_main(self, preds):
        """Map scaled main-target values back to original units, clipped below at 0."""
        restored = self.main_scaler.inverse_transform(preds)
        return np.maximum(0, restored)

# Single module-level processor shared by every dataset built later in the notebook.
# NOTE(review): it is fitted on the full training table, before the K-fold split, so each
# fold's scaling statistics also include that fold's validation rows.
processor = DualProcessor()
processor.fit(train_df)

def get_transforms(mode="train"):
    """
    Build the Albumentations pipeline applied to each tile.

    Every pipeline starts with a resize to the tile size and ends with ImageNet
    normalisation plus tensor conversion. Only ``mode == "train"`` inserts the
    random augmentations in between; any other value yields the plain
    evaluation pipeline.
    """
    pipeline = [A.Resize(CFG.TILE_SIZE, CFG.TILE_SIZE)]  # Safety resize

    if mode == "train":
        pipeline.extend([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=30, p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.2),
        ])

    pipeline.extend([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    return A.Compose(pipeline)

## Tiled Dataset (No Metadata Inputs)
This dataset handles the high-res tiling but **does not return metadata inputs**, preventing leakage.

In [None]:
class TiledDataset(Dataset):
    """
    Dataset that loads an image, resizes it to a square and splits it into tiles.

    Each item is a tensor of stacked tiles. When ``is_train`` is True the item is
    ``(tiles, (y_main, y_aux))`` with the scaled targets produced by ``processor``;
    otherwise only the tile tensor is returned. No metadata is returned as input.

    Args:
        df: DataFrame with a ``full_path`` column (plus the target/aux columns
            when ``is_train`` is True).
        transforms: Optional Albumentations pipeline applied to every tile
            independently. If omitted, tiles are converted to CHW uint8 tensors.
        processor: Fitted object exposing ``transform(df) -> (y_main, y_aux)``;
            required when ``is_train`` is True.
        is_train: Whether targets are returned alongside the tiles.

    Raises:
        ValueError: If ``is_train`` is True and no ``processor`` is given.
    """
    def __init__(self, df, transforms=None, processor=None, is_train=True):
        self.df = df.reset_index(drop=True)
        self.paths = df['full_path'].values
        self.transforms = transforms
        self.is_train = is_train

        if is_train:
            if processor is None:
                raise ValueError("processor is required when is_train=True")
            self.y_main, self.y_aux = processor.transform(df)

    def __len__(self):
        return len(self.df)

    def tile_image(self, image):
        """Resize ``image`` to IMG_LOAD_SIZE x IMG_LOAD_SIZE and cut it into a row-major grid of tiles."""
        side = CFG.IMG_LOAD_SIZE
        image = cv2.resize(image, (side, side))

        # Grid size follows the configuration (1036 // 518 == 2 -> 4 tiles) instead of a literal 2.
        t_sz = CFG.TILE_SIZE
        grid = side // t_sz
        return [
            image[row * t_sz:(row + 1) * t_sz, col * t_sz:(col + 1) * t_sz, :]
            for row in range(grid)
            for col in range(grid)
        ]

    def __getitem__(self, idx):
        path = self.paths[idx]
        try:
            # Context manager closes the file handle as soon as the pixels are decoded.
            with Image.open(path) as pil_img:
                image = np.array(pil_img.convert("RGB"))
        except Exception as err:
            # Best-effort fallback is kept (a black image), but the failure is reported:
            # a silent placeholder would be trained against a real label unnoticed.
            # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through.
            # Printed to stderr because this notebook disables the warnings module globally.
            print(
                f"[TiledDataset] Could not read '{path}' ({err!r}); using a black placeholder image.",
                file=sys.stderr,
            )
            image = np.zeros((CFG.IMG_LOAD_SIZE, CFG.IMG_LOAD_SIZE, 3), dtype=np.uint8)

        tiles = self.tile_image(image)

        # Apply Transforms (each tile is transformed independently)
        processed_tiles = []
        for t in tiles:
            if self.transforms:
                t = self.transforms(image=t)['image']
            else:
                # Without a pipeline the tiles are still numpy HWC arrays, which
                # torch.stack cannot handle; convert them to CHW tensors.
                t = torch.from_numpy(np.ascontiguousarray(t)).permute(2, 0, 1)
            processed_tiles.append(t)

        img_tensor = torch.stack(processed_tiles)

        if self.is_train:
            return img_tensor, (self.y_main[idx], self.y_aux[idx])
        return img_tensor

## Multi-Task DINOv2 Architecture

In [None]:
class TiledDINO(pl.LightningModule):
    """
    PyTorch Lightning Module for Tiled DINOv2 with Dual Heads for Biomass Prediction.

    Every tile of an image is encoded by the shared backbone, the tile embeddings
    are averaged, and two MLP heads predict the main targets (CFG.TARGET_COLS) and
    the auxiliary targets (CFG.AUX_COLS). Only the last four transformer blocks of
    the backbone and the two heads are trainable.
    """

    # Weight of the auxiliary loss relative to the main loss during training.
    AUX_LOSS_WEIGHT = 0.3

    def __init__(self):
        super().__init__()
        self.save_hyperparameters()

        # Load ViT-Large Backbone (num_classes=0 -> the model returns pooled features)
        self.backbone = timm.create_model(
            CFG.MODEL_NAME,
            pretrained=True,
            num_classes=0,
            img_size=CFG.TILE_SIZE
        )

        # Unfreeze Strategy (Last 4 blocks for better feature adaptation)
        for param in self.backbone.parameters():
            param.requires_grad = False
        if hasattr(self.backbone, 'blocks'):
            for param in self.backbone.blocks[-4:].parameters():
                param.requires_grad = True

        embed_dim = self.backbone.num_features
        # Head widths follow the configuration (currently 5 main and 2 auxiliary targets).
        n_main = len(CFG.TARGET_COLS)
        n_aux = len(CFG.AUX_COLS)

        # Main Head: Predict the biomass targets
        self.head_main = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.BatchNorm1d(512),
            nn.SiLU(),
            nn.Dropout(CFG.DROPOUT),
            nn.Linear(512, n_main)
        )

        # Auxiliary Head: Predict the auxiliary targets
        self.head_aux = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, n_aux)
        )

        self.loss_fn = nn.HuberLoss(delta=1.0)

        # Per-epoch buffer of (prediction, target) pairs; created here rather than
        # lazily inside validation_step so the attribute always exists.
        self.validation_step_outputs = []

    def forward(self, x):
        """
        Args:
            x: Tile batch of shape [batch, n_tiles, channels, height, width].

        Returns:
            Tuple ``(main_pred, aux_pred)`` with one row per image.
        """
        b, n_tiles, c, h, w = x.shape
        # reshape (not view) also accepts non-contiguous inputs.
        x_flat = x.reshape(b * n_tiles, c, h, w)

        # Extract Features: one embedding per tile
        feats = self.backbone(x_flat)
        feats = feats.reshape(b, n_tiles, -1)

        # Average Pooling across tiles
        global_feat = torch.mean(feats, dim=1)

        return self.head_main(global_feat), self.head_aux(global_feat)

    def training_step(self, batch, batch_idx):
        """Huber loss on the main targets plus a down-weighted Huber loss on the auxiliary targets."""
        img, (y_main, y_aux) = batch
        p_main, p_aux = self(img)

        loss_main = self.loss_fn(p_main, y_main)
        loss_aux = self.loss_fn(p_aux, y_aux)

        total_loss = loss_main + (self.AUX_LOSS_WEIGHT * loss_aux)
        self.log("train_loss", total_loss, prog_bar=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Log the main-target loss and buffer predictions for the epoch-level weighted R2."""
        img, (y_main, y_aux) = batch
        p_main, _ = self(img)
        loss = self.loss_fn(p_main, y_main)

        # Store outputs for epoch-end metrics. They are moved to CPU as float32 so the
        # buffer does not hold accelerator memory (or half-precision values) all epoch.
        self.validation_step_outputs.append(
            (p_main.detach().float().cpu(), y_main.detach().float().cpu())
        )
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        """Compute and log the weighted R2 (``val_w_r2``) over all buffered validation batches."""
        outputs = self.validation_step_outputs
        if not outputs:
            # Nothing was validated; avoid torch.cat on an empty list.
            return

        preds = torch.cat([p for p, _ in outputs]).numpy()
        targets = torch.cat([t for _, t in outputs]).numpy()

        r2s = [r2_score(targets[:, i], preds[:, i]) for i in range(preds.shape[1])]

        w_r2 = np.average(r2s, weights=CFG.TARGET_WEIGHTS)
        self.log("val_w_r2", w_r2, prog_bar=True)
        outputs.clear()

    def configure_optimizers(self):
        """AdamW over the trainable parameters with a per-step OneCycleLR schedule."""
        optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()),
                                      lr=CFG.LR, weight_decay=CFG.WEIGHT_DECAY)
        scheduler = OneCycleLR(optimizer, max_lr=CFG.LR, total_steps=self.trainer.estimated_stepping_batches)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}

## Training Loop (Stratified K-Fold)

We perform Stratified K-Fold CV to ensure we aren't biased by one specific geographic state or crop species.

In [None]:
# Force download model weights first to avoid progress bar glitches
print("‚¨áDownloading Model Weights...")
timm.create_model(CFG.MODEL_NAME, pretrained=True, num_classes=0)
print("Ready to train.")

# Device for the manual OOF inference below; follows availability instead of assuming CUDA,
# in line with the trainer's accelerator="auto".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

kf = StratifiedKFold(n_splits=CFG.N_FOLDS, shuffle=True, random_state=CFG.SEED)
n_targets = len(CFG.TARGET_COLS)
oof_preds = np.zeros((len(train_df), n_targets))
oof_targets = np.zeros((len(train_df), n_targets))

# Training Loop: one model per fold, out-of-fold predictions collected in original units
for fold, (train_idx, val_idx) in enumerate(kf.split(train_df, train_df['stratify_group'])):
    print(f"\n{'='*15} FOLD {fold+1}/{CFG.N_FOLDS} {'='*15}")
    
    train_sub = train_df.iloc[train_idx].reset_index(drop=True)
    val_sub = train_df.iloc[val_idx].reset_index(drop=True)
    
    # drop_last=True: the heads contain BatchNorm1d, which raises in training mode when a
    # batch holds a single sample (possible whenever len(train_sub) % BATCH_SIZE == 1).
    train_loader = DataLoader(
        TiledDataset(train_sub, get_transforms('train'), processor, is_train=True),
        batch_size=CFG.BATCH_SIZE, shuffle=True, num_workers=CFG.NUM_WORKERS,
        drop_last=True
    )
    val_loader = DataLoader(
        TiledDataset(val_sub, get_transforms('valid'), processor, is_train=True),
        batch_size=CFG.BATCH_SIZE, shuffle=False, num_workers=CFG.NUM_WORKERS
    )
    
    model = TiledDINO()
    
    # Keep only the checkpoint with the best weighted R2 on the validation fold
    checkpoint_cb = ModelCheckpoint(
        dirpath=f"models/fold_{fold}", filename="{epoch}-{val_w_r2:.4f}",
        monitor="val_w_r2", mode="max", save_top_k=1, save_weights_only=True
    )
    early_stop_cb = EarlyStopping(monitor="val_w_r2", patience=5, mode="max")
    
    trainer = pl.Trainer(
        max_epochs=CFG.EPOCHS,
        accelerator="auto", devices=1,
        precision="16-mixed",
        callbacks=[checkpoint_cb, early_stop_cb],
        logger=CSVLogger("logs"),
        enable_progress_bar=True,
        gradient_clip_val=CFG.GRAD_CLIP
    )
    
    trainer.fit(model, train_loader, val_loader)
    
    # Reload best model for OOF
    model = TiledDINO.load_from_checkpoint(checkpoint_cb.best_model_path)
    model.eval().to(device)
    
    # OOF Predictions (val_loader is not shuffled, so rows line up with val_sub / val_idx)
    preds = []
    with torch.no_grad():
        for img, _ in tqdm(val_loader, desc="Validating"):
            p_main, _ = model(img.to(device))
            preds.append(p_main.float().cpu().numpy())
            
    oof_preds[val_idx] = processor.inverse_main(np.vstack(preds))
    # Ground truth is read straight from the frame instead of being round-tripped through
    # float32 scaling + inverse scaling + clipping, which can only lose precision.
    oof_targets[val_idx] = val_sub[CFG.TARGET_COLS].to_numpy()
    
    del model, trainer
    torch.cuda.empty_cache()
    gc.collect()

## Evaluation & Submission

In [None]:
def get_weighted_r2(y_true, y_pred):
    """
    Weighted mean of the per-target R2 scores.

    Scores the first five columns of ``y_true`` / ``y_pred`` individually and
    averages them with CFG.TARGET_WEIGHTS (normalised to sum to 1).

    Returns:
        Tuple ``(weighted_r2, per_target_r2_list)``.
    """
    per_target = [r2_score(y_true[:, col], y_pred[:, col]) for col in range(5)]
    normalised = CFG.TARGET_WEIGHTS / CFG.TARGET_WEIGHTS.sum()
    weighted = np.average(per_target, weights=normalised)
    return weighted, per_target

# Evaluate OOF Results: overall weighted R2 plus the R2 of every individual target
score, ind_scores = get_weighted_r2(oof_targets, oof_preds)
print(f"\nOverall Weighted R2: {score:.5f}")
for n, s in zip(CFG.TARGET_COLS, ind_scores):
    print(f"  üîπ {n}: {s:.4f}")

# Plot Results: predicted vs. true scatter per target; the dashed red line is y = x
fig, axes = plt.subplots(1, 5, figsize=(20, 4))
for i, ax in enumerate(axes):
    ax.scatter(oof_targets[:, i], oof_preds[:, i], alpha=0.5, s=10)
    limit = max(oof_targets[:, i].max(), oof_preds[:, i].max())
    ax.plot([0, limit], [0, limit], 'r--')
    ax.set_title(f"{CFG.TARGET_COLS[i]}\nR2: {ind_scores[i]:.2f}")
    # Axis labels so each panel can be read on its own (x = ground truth, y = OOF prediction)
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
plt.tight_layout()
plt.show()