# 🌾 CSIRO Image2Biomass: Data Augmentation and Training Notebook

**Competition:** [CSIRO Image2Biomass Prediction](https://www.kaggle.com/competitions/csiro-biomass)

**Thank you CigarCat for the useful EDA** [NB link](https://www.kaggle.com/code/takahitomizunobyts/csiro-customizable-eda)

**Thank you Zhuang Jia for some of the albumentations** [NB link](https://www.kaggle.com/code/jiazhuang/csiro-simple)

**Thank you Sagar Nagpure for the DINO baseline model** [NB link](https://www.kaggle.com/code/sagarnagpure1310/csiro-dinov2-tiled-lb-0-65)

In [None]:
# Imports
import os
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pathlib import Path

# PyTorch & Lightning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Scikit-Learn
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler, OneHotEncoder, RobustScaler, PowerTransformer
import timm
from sklearn.metrics import r2_score, mean_absolute_error

# Silence all Python warnings for cleaner notebook output, then seed the
# Python / NumPy / torch RNGs through Lightning for reproducible runs.
warnings.filterwarnings(action='ignore')
pl.seed_everything(seed=42)

In [None]:
class CFG:
    """Global configuration: data paths, backbone choice, training
    hyperparameters, feature/target column names and target weights."""

    # Data Paths
    # Local (Windows) paths -- uncomment these and comment out the Kaggle
    # paths below when running outside Kaggle.
    # CSV_PATH_TRAIN = "csiro-biomass\\train.csv"
    # IMAGE_DIR_TRAIN = "csiro-biomass\\train"

    # CSV_PATH_TEST = "csiro-biomass\\test.csv"
    # IMAGE_DIR_TEST = "csiro-biomass\\test"      # Contains a single image; the remaining test images are hidden

    # Kaggle Data Paths (active on Kaggle)
    CSV_PATH_TRAIN = "/kaggle/input/csiro-biomass/train.csv"
    IMAGE_DIR_TRAIN = "/kaggle/input/csiro-biomass/train"

    CSV_PATH_TEST = "/kaggle/input/csiro-biomass/test.csv"
    IMAGE_DIR_TEST = "/kaggle/input/csiro-biomass/test"

    # Model Settings
    # 'vit_base_patch14_dinov2.lvd142m' is good, 'vit_small...' is faster
    MODEL_NAME = "vit_base_patch14_dinov2.lvd142m"
    # Keep this a multiple of the backbone's 14-px patch size:
    # 518 / 14 = 37 patches per side.
    IMG_SIZE = 518

    # Hyperparameters
    BATCH_SIZE = 16                # Lower if OOM
    EPOCHS = 20
    LR = 1e-4
    WEIGHT_DECAY = 1e-2
    N_FOLDS = 5
    SEED = 42

    # Features & Targets (column order here is the model-output order)
    TARGET_COLS = ['Dry_Total_g', 'GDM_g', 'Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g']

    NUM_FEATS = ["Pre_GSHH_NDVI", "Height_Ave_cm"]
    CAT_FEATS = ["State", "Species", "season"]

    # Per-target weights for the weighted R^2 (specified in competition overview);
    # they sum to 1.0.
    TARGET_WEIGHTS = {"Dry_Green_g": 0.1, "Dry_Dead_g": 0.1, "Dry_Clover_g": 0.1, "GDM_g": 0.2 , "Dry_Total_g": 0.5}

In [None]:
def load_and_preprocess_data(csv_path, image_dir=None):
    """Load the long-format training CSV and pivot it to one row per image.

    The CSV holds one row per (image, target) pair; the part of ``sample_id``
    before ``"__"`` identifies the image. The returned frame has one row per
    image with the metadata columns, one column per ``target_name`` value, and
    the derived columns ``date``, ``month``, ``season`` and ``full_image_path``.

    Args:
        csv_path: path to the long-format CSV (e.g. ``CFG.CSV_PATH_TRAIN``).
        image_dir: directory holding the image files. Defaults to
            ``CFG.IMAGE_DIR_TRAIN``.

    Returns:
        pd.DataFrame in wide format.
    """
    if image_dir is None:
        image_dir = CFG.IMAGE_DIR_TRAIN

    print("Loading and formatting data...")
    df_long = pd.read_csv(csv_path)

    # Extract image ID
    df_long['image_id'] = df_long['sample_id'].str.split('__').str[0]

    # Pivot to Wide Format
    meta_cols = ['image_path', 'Sampling_Date', 'State', 'Species',
                 'Pre_GSHH_NDVI', 'Height_Ave_cm']

    df_wide = df_long.pivot_table(
        index=['image_id'] + meta_cols,
        columns='target_name',
        values='target',
        aggfunc='first'
    ).reset_index()

    # Temporal Features
    df_wide['date'] = pd.to_datetime(df_wide['Sampling_Date'], format='%Y/%m/%d')
    df_wide['month'] = df_wide['date'].dt.month

    # Season (Southern Hemisphere)
    def get_season(month):
        if month in [9, 10, 11]: return "Spring"
        elif month in [12, 1, 2]: return "Summer"
        elif month in [3, 4, 5]: return "Autumn"
        else: return "Winter"

    df_wide['season'] = df_wide['month'].apply(get_season)

    # Construct full image path from the file name recorded in the CSV's
    # `image_path` column, so the real extension is used. Guessing
    # "<image_id>.png" breaks when the files have another extension, and
    # BiomassDataset then silently substitutes black images for every sample.
    df_wide['full_image_path'] = [
        os.path.join(image_dir, os.path.basename(p)) for p in df_wide['image_path']
    ]

    # Surface path problems here instead of letting training run on placeholders.
    n_missing = int((~df_wide['full_image_path'].map(os.path.exists)).sum())
    if n_missing:
        print(f"WARNING: {n_missing}/{len(df_wide)} image files not found under {image_dir!r}.")

    print(f"Loaded {len(df_wide)} unique samples.")
    return df_wide

In [None]:
class BiomassPreprocessor:
    """
    Feature and target preprocessing, fitted on the training fold only.

    * numeric features (CFG.NUM_FEATS): Yeo-Johnson power transform followed
      by robust scaling
    * categorical features (CFG.CAT_FEATS): one-hot encoding; categories not
      seen during fitting are ignored (encoded as all zeros)
    * targets (CFG.TARGET_COLS): log1p on the way in, expm1 on the way out
    """
    def __init__(self):
        self.num_cols = CFG.NUM_FEATS
        self.cat_cols = CFG.CAT_FEATS
        self.pt = PowerTransformer(method="yeo-johnson")
        self.scaler = RobustScaler()
        self.encoder = OneHotEncoder(sparse_output=False, handle_unknown="ignore")

    def fit_transform(self, df):
        """Fit every feature transformer on `df` and return its feature matrix."""
        powered = self.pt.fit_transform(df[self.num_cols])
        numeric = self.scaler.fit_transform(powered)
        categorical = self.encoder.fit_transform(df[self.cat_cols])
        return np.hstack((numeric, categorical))

    def transform(self, df):
        """Apply the already-fitted transformers to `df` (no refitting)."""
        powered = self.pt.transform(df[self.num_cols])
        numeric = self.scaler.transform(powered)
        categorical = self.encoder.transform(df[self.cat_cols])
        return np.hstack((numeric, categorical))

    def transform_targets(self, df):
        """Return log1p of the target columns as float32 (tames skewed biomass)."""
        raw = df[CFG.TARGET_COLS].values.astype(np.float32)
        return np.log1p(raw)

    def inverse_transform_targets(self, y_pred):
        """Undo `transform_targets` (expm1 is the inverse of log1p)."""
        return np.expm1(y_pred)

In [None]:
# ==========================================
# 4. DATASET & TRANSFORMS
# ==========================================
# ImageNet channel statistics used to normalize inputs for the pretrained backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training augmentation.
# Each image carries a single set of biomass targets (one row per image_id), so
# the full frame is kept: a random crop would remove vegetation without changing
# the label. Using the same full-frame Resize as validation also keeps train and
# val inputs geometrically consistent (same field of view and aspect distortion).
train_transforms = transforms.Compose([
    transforms.Resize((CFG.IMG_SIZE, CFG.IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),          # NOTE: corners are cut and filled with black
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Validation / inference: deterministic full-frame resize + normalization.
val_transforms = transforms.Compose([
    transforms.Resize((CFG.IMG_SIZE, CFG.IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

class BiomassDataset(Dataset):
    """Pairs pre-processed tabular features with an image and its targets.

    Args:
        X_tab: 2-D array of tabular features, one row per sample.
        image_paths: image file paths aligned with the rows of X_tab.
        y: targets aligned with the rows of X_tab.
        transform: optional callable applied to the PIL image.

    Each item is ``((x_tab, image), y)``.
    """
    def __init__(self, X_tab, image_paths, y, transform=None):
        self.X_tab = torch.tensor(X_tab, dtype=torch.float32)
        self.image_paths = image_paths
        self.y = torch.tensor(y, dtype=torch.float32)
        self.transform = transform
        # Paths already reported as unreadable (tracked per worker process).
        self._reported_missing = set()

    def __len__(self):
        return len(self.y)

    def _load_image(self, path):
        """Open `path` as RGB; fall back to a black image if it cannot be read."""
        try:
            # Context manager closes the file handle; convert() returns a loaded copy.
            with Image.open(path) as img:
                return img.convert("RGB")
        except OSError as err:
            # FileNotFoundError and PIL's UnidentifiedImageError are OSError
            # subclasses. Keep the best-effort fallback, but report it once per
            # path: training silently on black images hides data-path bugs.
            if path not in self._reported_missing:
                self._reported_missing.add(path)
                print(f"[BiomassDataset] could not read image {path!r} ({err}); using a black placeholder.")
            return Image.new('RGB', (CFG.IMG_SIZE, CFG.IMG_SIZE), (0, 0, 0))

    def __getitem__(self, idx):
        # Tabular
        x_tab = self.X_tab[idx]

        # Image
        image = self._load_image(self.image_paths[idx])
        if self.transform:
            image = self.transform(image)

        return (x_tab, image), self.y[idx]

In [None]:
# ==========================================
# 5. MODEL ARCHITECTURE
# ==========================================
class BiomassModel(pl.LightningModule):
    """Fuses frozen DINOv2 image features with an MLP over tabular features.

    Args:
        tabular_input_dim: width of the tabular feature vector.
        target_dim: number of regression outputs. Defaults to
            ``len(CFG.TARGET_COLS)`` so the head matches the configured targets.
    """
    def __init__(self, tabular_input_dim, target_dim=len(CFG.TARGET_COLS)):
        super().__init__()
        self.save_hyperparameters()

        # 1. Vision Backbone (DINOv2); num_classes=0 drops timm's classifier head
        #    so the model returns pooled features of size `num_features`.
        self.dino = timm.create_model(CFG.MODEL_NAME, pretrained=True, num_classes=0, img_size=CFG.IMG_SIZE)

        # The backbone stays frozen for the whole run (nothing in this class
        # unfreezes it): only the tabular branch and fusion head are trained.
        for param in self.dino.parameters():
            param.requires_grad = False

        dino_out_dim = self.dino.num_features

        # 2. Tabular Branch
        self.tabular_branch = nn.Sequential(
            nn.Linear(tabular_input_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )

        # 3. Fusion Head
        self.head = nn.Sequential(
            nn.Linear(dino_out_dim + 64, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, target_dim)
        )

        # Huber Loss is robust to outliers
        self.criterion = nn.HuberLoss(delta=1.0)

    def forward(self, x_tab, x_img):
        """Return predictions of shape (batch, target_dim)."""
        img_feat = self.dino(x_img)          # backbone params have requires_grad=False
        tab_feat = self.tabular_branch(x_tab)
        combined = torch.cat([img_feat, tab_feat], dim=1)
        return self.head(combined)

    def _shared_step(self, batch):
        """Compute the loss for one ((x_tab, x_img), y) batch."""
        (x_tab, x_img), y = batch
        return self.criterion(self(x_tab, x_img), y)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Optimize only trainable parameters; the frozen backbone never gets gradients.
        trainable = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable, lr=CFG.LR, weight_decay=CFG.WEIGHT_DECAY)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}
        }

In [None]:
#=========================================
# 5.1 COMPUTING WEIGHTED R2 CALLBACK
#=========================================
class WeightedR2Callback(pl.Callback):
    """Computes the target-weighted R^2 on the validation set after each epoch.

    Predictions and targets are mapped back to original units with
    ``processor.inverse_transform_targets`` before scoring. The result is
    logged as ``w_r2`` so ModelCheckpoint / EarlyStopping can monitor it.

    Args:
        val_loader: DataLoader yielding ``((x_tab, x_img), y)`` batches.
        processor: object exposing ``inverse_transform_targets``.
        target_names: target column names in model-output order.
    """
    def __init__(self, val_loader, processor, target_names):
        super().__init__()
        self.val_loader = val_loader
        self.processor = processor
        self.target_names = target_names

    @staticmethod
    def compute_weighted_r2(y_true, y_pred, target_names):
        """Return the weighted mean of per-target R^2 scores.

        Column ``i`` of y_true / y_pred is scored as ``target_names[i]``;
        weights come from CFG.TARGET_WEIGHTS (0 for unlisted names).
        A staticmethod, so it can be called on an instance or on the class.
        """
        r2s = []
        weights = []
        for i, name in enumerate(target_names):
            r2s.append(r2_score(y_true[:, i], y_pred[:, i]))
            weights.append(CFG.TARGET_WEIGHTS.get(name, 0))
        return np.average(np.array(r2s), weights=np.array(weights))

    def on_validation_epoch_end(self, trainer, pl_module):
        # The sanity-check pass is ignored by checkpointing/early stopping,
        # so skip the extra full pass over the validation set there.
        if trainer.sanity_checking:
            return
        pl_module.eval()
        preds, targets = [], []
        with torch.no_grad():
            for (x_tab, x_img), y in self.val_loader:
                y_pred = pl_module(x_tab.to(pl_module.device), x_img.to(pl_module.device))
                preds.append(y_pred.float().cpu().numpy())
                targets.append(y.cpu().numpy())
        y_pred = np.concatenate(preds, axis=0)
        y_true = np.concatenate(targets, axis=0)
        y_pred_orig = self.processor.inverse_transform_targets(y_pred)
        y_true_orig = self.processor.inverse_transform_targets(y_true)
        w_r2 = float(self.compute_weighted_r2(y_true_orig, y_pred_orig, self.target_names))
        # Logging through the module puts the value into trainer.callback_metrics,
        # which ModelCheckpoint(monitor='w_r2') and EarlyStopping read.
        pl_module.log("w_r2", w_r2, prog_bar=True)

In [None]:
# ==========================================
# 6. MAIN TRAINING LOOP (K-FOLD)
# ==========================================
def _predict(model, loader, device):
    """Run `model` over `loader` in eval mode; return (preds, targets) numpy arrays.

    Targets are returned exactly as the loader yields them (still in the
    transformed target space).
    """
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for (x_tab, x_img), y in loader:
            y_pred = model(x_tab.to(device), x_img.to(device))
            preds.append(y_pred.float().cpu().numpy())
            targets.append(y.cpu().numpy())
    return np.concatenate(preds, axis=0), np.concatenate(targets, axis=0)


if __name__ == "__main__":
    # Load Data (one row per image)
    df = load_and_preprocess_data(CFG.CSV_PATH_TRAIN)

    # Setup K-Fold
    kfold = KFold(n_splits=CFG.N_FOLDS, shuffle=True, random_state=CFG.SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fold_scores = []

    print(f"\nStarting training with {CFG.MODEL_NAME}...")

    for fold, (train_idx, val_idx) in enumerate(kfold.split(df)):
        print(f"\n{'='*20} FOLD {fold+1}/{CFG.N_FOLDS} {'='*20}")

        # 1. Split Data
        train_df = df.iloc[train_idx].reset_index(drop=True)
        val_df = df.iloc[val_idx].reset_index(drop=True)

        # 2. Preprocessing (Fit on Train, Transform Val)
        processor = BiomassPreprocessor()
        X_train_tab = processor.fit_transform(train_df)
        X_val_tab = processor.transform(val_df)
        y_train = processor.transform_targets(train_df)
        y_val = processor.transform_targets(val_df)

        # 3. Create Datasets & Loaders
        train_ds = BiomassDataset(X_train_tab, train_df['full_image_path'].values, y_train, transform=train_transforms)
        val_ds = BiomassDataset(X_val_tab, val_df['full_image_path'].values, y_val, transform=val_transforms)

        # drop_last: BatchNorm1d raises on a training batch of a single sample.
        train_loader = DataLoader(train_ds, batch_size=CFG.BATCH_SIZE, shuffle=True, num_workers=4,
                                  pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_ds, batch_size=CFG.BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

        # 4. Initialize Model
        model = BiomassModel(tabular_input_dim=X_train_tab.shape[1], target_dim=len(CFG.TARGET_COLS))

        # 5. Setup Trainer
        checkpoint_callback = ModelCheckpoint(
            dirpath=f'checkpoints/fold_{fold}',
            filename='best-{epoch:02d}-{w_r2:.4f}',
            monitor='w_r2',   # monitor weighted R²
            mode='max',
            save_top_k=1
        )
        early_stop = EarlyStopping(monitor='w_r2', patience=5, mode='max')

        weighted_r2_cb = WeightedR2Callback(val_loader, processor, CFG.TARGET_COLS)

        trainer = pl.Trainer(
            max_epochs=CFG.EPOCHS,
            accelerator="auto",
            devices=1,
            precision="16-mixed",
            callbacks=[checkpoint_callback, early_stop, weighted_r2_cb],
            logger=False,
            enable_progress_bar=True
        )

        # 6. Train
        trainer.fit(model, train_loader, val_loader)

        # 7. Restore the best checkpoint (by w_r2) so the reported score is the best one,
        #    not the last epoch's.
        best_path = checkpoint_callback.best_model_path
        if best_path:
            # The checkpoint was written by this script, so full unpickling is trusted.
            ckpt = torch.load(best_path, map_location="cpu", weights_only=False)
            model.load_state_dict(ckpt["state_dict"])
        else:
            print("No checkpoint was saved; scoring the last-epoch weights.")
        model.to(device)

        # 8. Compute Weighted R^2 Score on Validation Set
        val_preds, val_targets = _predict(model, val_loader, device)

        # Inverse transform back to original biomass units
        val_preds_orig = processor.inverse_transform_targets(val_preds)
        val_targets_orig = processor.inverse_transform_targets(val_targets)

        # Weighted R². Called via the instance so `self` is bound; the unbound
        # class call shifted every argument by one and raised a TypeError.
        w_r2 = weighted_r2_cb.compute_weighted_r2(val_targets_orig, val_preds_orig, CFG.TARGET_COLS)
        fold_scores.append(w_r2)
        print(f"Fold {fold+1} Complete. Best Weighted R²: {w_r2:.4f}")

    print(f"\nCV Weighted R²: {np.mean(fold_scores):.4f} ± {np.std(fold_scores):.4f}")
    print("\nCross-Validation Completed Successfully!")
