<div style="background-color:#F0F8FF; border-radius:15px; border:1px solid #4682B4; padding: 25px;">
    <div style="text-align:center;">
        <h1 style="color:#2E8B57; font-family:sans-serif; font-weight:800; font-size:32px;">CSIRO Inference: Pure Vision SOTA</h1>
        <h3 style="color:#4682B4; font-family:sans-serif; font-weight:600;">ViT-Large + Tiling + TTA (No Metadata Leakage)</h3>
    </div>
</div>

In [None]:
# Libraries
import sys
import os
import gc
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image
import cv2

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from sklearn.preprocessing import RobustScaler, MinMaxScaler

# CONFIG
class CFG:
    """Static settings for the inference run: file locations, backbone, tiling, loader and targets."""

    # ---- Paths ----
    # IMPORTANT: point MODEL_DIR at the location where you uploaded your .ckpt files
    MODEL_DIR = "/kaggle/input/csiro-models-film-tiling"

    # Competition data
    TEST_CSV = "/kaggle/input/csiro-biomass/test.csv"
    TEST_DIR = "/kaggle/input/csiro-biomass/test"
    TRAIN_CSV = "/kaggle/input/csiro-biomass/train.csv"

    # ---- Backbone (timm model name) ----
    MODEL_NAME = "vit_large_patch14_dinov2.lvd142m"

    # ---- Tiling ----
    # Side length of one square tile fed to the backbone.
    TILE_SIZE = 518
    # Images are resized to this side length before being cut into a 2x2 tile grid (= 1036).
    IMG_LOAD_SIZE = 2 * TILE_SIZE

    # ---- DataLoader ----
    BATCH_SIZE = 8
    NUM_WORKERS = 2

    # ---- Targets ----
    # Column order here is the order used when fitting the target scaler
    # and when indexing model outputs for the submission.
    TARGET_COLS = [
        'Dry_Total_g',
        'GDM_g',
        'Dry_Green_g',
        'Dry_Dead_g',
        'Dry_Clover_g',
    ]
    AUX_COLS = [
        'Height_Ave_cm',
        'Pre_GSHH_NDVI',
    ]

## Model Classes

In [None]:
class TiledDINO(pl.LightningModule):
    """
    Tiled regressor built on a timm backbone.

    Every tile of an image is embedded by the same backbone, the tile
    embeddings are mean-pooled into one vector per image, and two MLP heads
    (main: 5 outputs, auxiliary: 2 outputs) read that pooled vector.
    """
    def __init__(self):
        super().__init__()

        # Backbone is built without pretrained weights (offline use); the
        # weights are expected to come from a checkpoint instead.
        backbone_kwargs = dict(
            pretrained=False,
            num_classes=0,
            img_size=CFG.TILE_SIZE,
        )
        self.backbone = timm.create_model(CFG.MODEL_NAME, **backbone_kwargs)

        feat_dim = self.backbone.num_features

        # NOTE: attribute names and the order of layers inside each Sequential
        # determine the state_dict keys, so they are kept exactly as they were.

        # Main head: feat_dim -> 512 -> 5
        main_layers = [
            nn.Linear(feat_dim, 512),
            nn.BatchNorm1d(512),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 5),
        ]
        self.head_main = nn.Sequential(*main_layers)

        # Auxiliary head: feat_dim -> 256 -> 2
        aux_layers = [
            nn.Linear(feat_dim, 256),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, 2),
        ]
        self.head_aux = nn.Sequential(*aux_layers)

    def forward(self, x):
        """
        Args:
            x: 5-D tensor laid out as [batch, n_tiles, channels, height, width].

        Returns:
            Tuple (main_out, aux_out): outputs of the main and auxiliary heads,
            one row per image in the batch.
        """
        batch, tiles, chans, height, width = x.shape

        # Fold the tile axis into the batch axis so the backbone sees plain images.
        flat_tiles = x.view(batch * tiles, chans, height, width)
        tile_feats = self.backbone(flat_tiles)

        # Restore the tile axis and average the embeddings over it.
        tile_feats = tile_feats.view(batch, tiles, -1)
        pooled = tile_feats.mean(dim=1)

        main_out = self.head_main(pooled)
        aux_out = self.head_aux(pooled)
        return main_out, aux_out

# TTA Prediction Function
def predict_tta(model_list, img_batch, meta_batch=None):
    """
    Ensemble prediction for a single (already augmented) batch.

    Each model is called on the batch, the first element of its output is
    taken, and those first elements are averaged across models.

    Args:
        model_list: non-empty sequence of callables; each must return an
            indexable result whose element 0 is a torch tensor.
        img_batch: input passed to every model.
        meta_batch: optional second positional input. When None (default)
            models are called with ``img_batch`` only, which is what the
            single-argument ``TiledDINO.forward`` in this notebook accepts.

    Returns:
        numpy.ndarray: element-wise mean of the models' first outputs.

    Note:
        A later cell defines ``predict_tta(model_list, img_batch)`` again and
        shadows this one once it has been executed.
    """
    # Only forward the metadata when the caller actually supplied it.
    model_args = (img_batch,) if meta_batch is None else (img_batch, meta_batch)

    per_model = []
    for m in model_list:
        main_out = m(*model_args)[0]
        # detach() keeps this usable even when called outside torch.no_grad().
        per_model.append(main_out.detach().cpu().numpy())
    return np.mean(per_model, axis=0)

## Data Handling & Processing

In [None]:
class TiledDataset(Dataset):
    """
    Dataset that loads one image per row and returns it as a stack of tiles.

    Each image is resized to CFG.IMG_LOAD_SIZE x CFG.IMG_LOAD_SIZE, cut into a
    2x2 grid of CFG.TILE_SIZE-sized tiles, and every tile is passed through
    ``transforms`` before the tiles are stacked into one tensor.
    """
    def __init__(self, df, transforms=None):
        """
        Args:
            df: DataFrame with a 'full_path' column holding image file paths.
            transforms: optional callable used as ``transforms(image=tile)['image']``.
                The results are combined with torch.stack, so it is expected to
                return tensors (e.g. a pipeline ending in ToTensorV2).
        """
        self.df = df.reset_index(drop=True)
        self.paths = df['full_path'].values
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def tile_image(self, image):
        """
        Resize ``image`` to the configured load size and split it into 4 tiles.

        Returns:
            List of 4 array slices in row-major order
            (top-left, top-right, bottom-left, bottom-right).
        """
        H, W = CFG.IMG_LOAD_SIZE, CFG.IMG_LOAD_SIZE
        image = cv2.resize(image, (W, H))  # cv2 takes the size as (width, height)

        t_sz = CFG.TILE_SIZE
        return [
            image[i * t_sz:(i + 1) * t_sz, j * t_sz:(j + 1) * t_sz, :]
            for i in range(2)
            for j in range(2)
        ]

    def __getitem__(self, idx):
        """
        Load image ``idx`` and return its transformed tiles stacked on a new
        leading axis.

        An image that cannot be read is replaced by an all-black image so that
        the run keeps going; the failure is reported instead of being hidden.
        """
        path = self.paths[idx]
        try:
            with Image.open(path) as img_file:
                image = np.array(img_file.convert("RGB"))
        except Exception as exc:
            # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
            # propagate; the message makes silently blank inputs visible.
            print(f"Warning: could not read image {path} ({exc}); using a blank image.")
            image = np.zeros((CFG.IMG_LOAD_SIZE, CFG.IMG_LOAD_SIZE, 3), dtype=np.uint8)

        tiles = self.tile_image(image)

        if self.transforms:
            tiles = [self.transforms(image=t)['image'] for t in tiles]

        img_tensor = torch.stack(tiles)
        return img_tensor

# Data Transforms
def get_transforms():
    """Build the per-tile preprocessing pipeline: resize, normalize, convert to tensor."""
    tile_side = CFG.TILE_SIZE

    # Per-channel normalization constants (the standard ImageNet mean/std values).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        A.Resize(tile_side, tile_side),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# Fit Scaler on Training Data
def get_processor():
    """
    Fit a RobustScaler on the training targets and return it.

    Reads CFG.TRAIN_CSV (long format: one row per sample/target pair), pivots
    it so each target name becomes a column, and fits the scaler on the
    CFG.TARGET_COLS columns in that order.

    Returns:
        The fitted RobustScaler.
    """
    print("Fitting Scaler on Training Data...")
    long_df = pd.read_csv(CFG.TRAIN_CSV)
    # sample_id looks like "<image_id>__<suffix>"; keep the part before "__".
    long_df['image_id'] = long_df['sample_id'].str.split('__').str[0]

    # Pivot logic to match training: one row per
    # (image_id, image_path, Sampling_Date) combination.
    index_cols = ['image_id', 'image_path', 'Sampling_Date']
    pivoted = long_df.pivot_table(
        index=index_cols,
        columns='target_name',
        values='target',
        aggfunc='first',
    )
    wide_df = pivoted.reset_index()

    target_scaler = RobustScaler().fit(wide_df[CFG.TARGET_COLS])
    print("Scaler fitted.")
    return target_scaler

# TTA Prediction Function without meta_batch
def predict_tta(model_list, img_batch):
    """
    Average the first output of every model in ``model_list`` for one batch.

    Args:
        model_list: sequence of callables; each returns an indexable result
            whose element 0 is a torch tensor.
        img_batch: input passed unchanged to every model.

    Returns:
        numpy.ndarray: element-wise mean across models.
    """
    per_model = []
    for net in model_list:
        main_out = net(img_batch)[0]
        per_model.append(main_out.cpu().numpy())
    return np.mean(per_model, axis=0)

## Main Inference Pipeline
Pure-vision pipeline (no metadata inputs): builds the tiled test loader, loads the checkpoint ensemble, and runs the **TTA Loop**.

In [None]:
print("Starting Pure Vision Inference...")

# Load Test Data
test_df = pd.read_csv(CFG.TEST_CSV)
# sample_id looks like "<image_id>__<suffix>"; keep the part before "__".
test_df['image_id'] = test_df['sample_id'].str.split('__').str[0]

# Detect Extension from the first non-hidden file in the test folder.
# Only the two expected failures fall back to ".png": the folder cannot be
# listed (OSError) or it contains no visible file (IndexError). Anything else
# (including KeyboardInterrupt) is no longer swallowed.
try:
    sample_file = [f for f in os.listdir(CFG.TEST_DIR) if not f.startswith('.')][0]
    ext = os.path.splitext(sample_file)[1]
except (OSError, IndexError):
    ext = ".png"

test_df['full_path'] = test_df['image_id'].apply(lambda x: os.path.join(CFG.TEST_DIR, f"{x}{ext}"))

# Unique images only: test.csv has several rows per image, but each image is predicted once.
df_unique = test_df[['image_id', 'full_path']].drop_duplicates(subset=['image_id']).reset_index(drop=True)
test_ds = TiledDataset(df_unique, get_transforms())
# shuffle=False keeps predictions aligned with the row order of df_unique.
loader = DataLoader(test_ds, batch_size=CFG.BATCH_SIZE, shuffle=False, num_workers=CFG.NUM_WORKERS)

# Load Models: every *.ckpt found (recursively) under CFG.MODEL_DIR joins the ensemble.
ckpt_files = list(Path(CFG.MODEL_DIR).rglob("*.ckpt"))
print(f"Found {len(ckpt_files)} checkpoints.")

models = []
for ckpt_path in ckpt_files:
    try:
        model = TiledDINO.load_from_checkpoint(str(ckpt_path))
        model.eval().cuda()
    except Exception as e:
        # Best effort: a checkpoint that fails to load is reported and left out.
        print(f"Skipping {ckpt_path.name}: {e}")
        continue
    models.append(model)
    print(f"Loaded {ckpt_path.name}")

if len(models) == 0:
    raise RuntimeError("No models loaded! Check CFG.MODEL_DIR")



# Inference with TTA
print("Running Inference (x5 TTA)...")
processor = get_processor()
all_preds = []

# Test-time augmentation views. Each one maps the batch tensor to an augmented
# copy; dims 3 and 4 are the last two axes (tile height and width) of the
# [batch, tiles, channels, height, width] tensor the dataset produces.
tta_views = [
    lambda t: t,                           # TTA 1: Original
    lambda t: torch.flip(t, dims=[4]),     # TTA 2: Horizontal Flip
    lambda t: torch.flip(t, dims=[3]),     # TTA 3: Vertical Flip
    lambda t: torch.rot90(t, 1, [3, 4]),   # TTA 4: Rotate 90 degrees
    lambda t: torch.rot90(t, 3, [3, 4]),   # TTA 5: Rotate 270 degrees
]

with torch.no_grad():
    # Iterate the loader directly (not enumerate(loader)) so tqdm can read
    # len(loader) and show a real total / ETA.
    for img in tqdm(loader, desc="Predicting"):
        img = img.cuda()

        # One ensemble prediction per view, computed one view at a time.
        view_preds = [predict_tta(models, view(img)) for view in tta_views]

        # Average all TTA predictions (same left-to-right summation order as
        # p1 + p2 + p3 + p4 + p5).
        batch_preds = sum(view_preds) / len(tta_views)
        all_preds.append(batch_preds)

# Post-Processing
all_preds = np.vstack(all_preds)                    # one row per unique image
all_preds = processor.inverse_transform(all_preds)  # Scale back to grams
all_preds = np.maximum(0, all_preds)                # Clip negatives

# Write Submission
print("Writing Submission...")
# image_id -> prediction row; rows of all_preds follow the order of df_unique.
pred_map = dict(zip(df_unique['image_id'].values, all_preds))
# target name -> column position inside a prediction row.
target_idx = {name: i for i, name in enumerate(CFG.TARGET_COLS)}

sub_template = pd.read_csv(CFG.TEST_CSV)
final_values = []
for sample_id, t_name in zip(sub_template['sample_id'], sub_template['target_name']):
    img_id = sample_id.split('__')[0]
    image_preds = pred_map.get(img_id)
    if image_preds is None:
        # No prediction for this image: fall back to 0.0.
        final_values.append(0.0)
    else:
        final_values.append(image_preds[target_idx[t_name]])

sub_template['target'] = final_values
sub_template[['sample_id', 'target']].to_csv("submission.csv", index=False)
print("Done! submission.csv created successfully.")