In [None]:
# ============================================================================
# 03_training_refactored.ipynb - Weather Nowcasting with ConvLSTM
# Implements Lead Architect Fixes: Data Integrity, Speed, Robust Persistence
# ============================================================================

import os, gc, glob, time, logging, shutil, re
from typing import Tuple, List, Optional
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import IterableDataset, DataLoader
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import psutil

# ============================================================================
# 1. REPRODUCIBILITY & CONFIGURATION
# ============================================================================
RANDOM_SEED = 42
def seed_everything(seed=42):
    """Seed every RNG the notebook uses and force deterministic cuDNN kernels.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and all
    CUDA devices). ``PYTHONHASHSEED`` is exported for any subprocess started
    afterwards; it cannot change hash randomisation of the running interpreter.

    Args:
        seed: integer seed applied to every generator.
    """
    import random

    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False # Strict reproducibility

seed_everything(RANDOM_SEED)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
print('✅ Reproducibility seeds set.')


In [None]:
class Config:
    """Notebook-wide hyper-parameters and filesystem layout.

    Every setting is a class attribute, so it can be read either as
    ``Config.X`` or through the ``CFG`` instance created below.
    """

    # --- Paths ---------------------------------------------------------------
    WORK_DIR = Path('/content/weather_nowcasting')
    DATA_DIR = WORK_DIR / 'Dataset'                # historical NetCDF files
    DATA_2025_DIR = WORK_DIR / '2025_data'         # 2025 NetCDF files
    BATCHED_DIR = WORK_DIR / 'data' / 'batched'    # preprocessed .npy shards + stats.npz
    LOCAL_CKPT_DIR = WORK_DIR / 'checkpoints'      # local checkpoints (fast I/O)

    # --- Google Drive persistence --------------------------------------------
    DRIVE_MOUNT = '/content/drive'
    DRIVE_ROOT = Path(DRIVE_MOUNT) / 'MyDrive' / 'WeatherPaper'
    DRIVE_CKPT_DIR = DRIVE_ROOT / 'checkpoints'

    # --- Data ----------------------------------------------------------------
    T_IN = 24      # input time steps per sample
    T_OUT = 6      # forecast time steps per sample
    STRIDE = 12    # step between consecutive sample windows
    VARIABLES = ['tp', 't2m']

    # --- Model ---------------------------------------------------------------
    IN_CHANNELS = len(VARIABLES)   # one channel per variable
    HIDDEN_DIM = 128
    OUT_CHANNELS = len(VARIABLES)
    N_LAYERS = 4

    # --- Training ------------------------------------------------------------
    BATCH_SIZE = 32
    NUM_WORKERS = 2
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    NUM_EPOCHS = 100
    PATIENCE = 20          # epochs without val improvement before early stopping
    MAX_GRAD_NORM = 1.0
    PRECIP_WEIGHT = 2.0    # loss weight for pixels above the precip threshold

CFG = Config()

# Create every working directory up front; exist_ok makes re-running this cell harmless.
for required_dir in (CFG.WORK_DIR, CFG.DATA_DIR, CFG.DATA_2025_DIR, CFG.BATCHED_DIR, CFG.LOCAL_CKPT_DIR):
    required_dir.mkdir(parents=True, exist_ok=True)


In [None]:
# ============================================================================
# 2. DEVICE & DRIVE SETUP
# ============================================================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device} | GPU: {torch.cuda.get_device_name(0) if device.type=='cuda' else 'N/A'}")

# --- DRIVE MOUNT (ROBUST) ---
# Checkpoints are always written to fast local disk (SAVE_DIR). When Drive is
# available, the training loop additionally copies the best model and periodic
# backups to CFG.DRIVE_CKPT_DIR.
SAVE_DIR = CFG.LOCAL_CKPT_DIR
DRIVE_AVAILABLE = False

try:
    from google.colab import drive
    print("Mounting Google Drive...")
    drive.mount(CFG.DRIVE_MOUNT, force_remount=True)

    # The mount-point directory on its own does not prove Drive is usable; the
    # ".../MyDrive" folder (parent of DRIVE_ROOT) is expected to exist only once
    # the user's Drive is attached, so check that instead.
    if CFG.DRIVE_ROOT.parent.is_dir():
        CFG.DRIVE_CKPT_DIR.mkdir(parents=True, exist_ok=True)
        DRIVE_AVAILABLE = True
        print(f"✅ Drive Mounted. Backups will sync to: {CFG.DRIVE_CKPT_DIR}")
    else:
        print("⚠️ Drive mount point not found.")
except ImportError:
    print("⚠️ Not in Colab. Using local checkpoints.")
except Exception as e:
    print(f"⚠️ Drive Mount Failed: {e}. Training locally (checkpoints will be lost if runtime disconnects).")

print(f"💾 Primary Save Directory (Local): {SAVE_DIR}")


In [None]:
# ============================================================================
# 3. DOWNLOAD DATA (GITHUB)
# ============================================================================
import requests, zipfile, io

def download_github_repo(timeout=60):
    """Download the project repository as a zip and extract its NetCDF files.

    ``.nc`` entries whose path contains ``/Dataset/`` go to ``CFG.DATA_DIR``;
    those containing ``/2025 data/`` go to ``CFG.DATA_2025_DIR``. Entries are
    flattened with ``os.path.basename``, which also keeps zip paths from
    escaping the target directories.

    Args:
        timeout: seconds before the HTTP request is abandoned, so a stalled
            connection cannot hang the notebook.

    Raises:
        Exception: if the server responds with a status other than 200.
    """
    url = 'https://github.com/ui07xWizardOp/paper-weather-nowcasting/archive/refs/heads/main.zip'
    print(f'Downloading repository: {url}')
    r = requests.get(url, timeout=timeout)
    if r.status_code != 200: raise Exception(f"Download failed: {r.status_code}")

    # Checked in order; the first matching marker wins.
    destinations = {'/Dataset/': CFG.DATA_DIR, '/2025 data/': CFG.DATA_2025_DIR}
    with zipfile.ZipFile(io.BytesIO(r.content)) as z:
        print('Extracting NetCDF files...')
        for file in tqdm(z.namelist(), desc='Extracting'):
            if not file.endswith('.nc'):
                continue
            for marker, dest_dir in destinations.items():
                if marker in file:
                    # Stream the entry to disk instead of materialising it with z.read().
                    with z.open(file) as src, open(dest_dir / os.path.basename(file), 'wb') as dst:
                        shutil.copyfileobj(src, dst)
                    break
    print('✅ Download complete.')

if not list(CFG.DATA_DIR.glob('*.nc')) or not list(CFG.DATA_2025_DIR.glob('*.nc')):
    download_github_repo()
else:
    print('✅ Dataset exists.')


In [None]:
# ============================================================================
# 4. PREPROCESSING (STREAMING) - FIXED STAT LOADING
# ============================================================================
import xarray as xr

# --- HELPERS ---
def load_single_file(filepath):
    """Open a NetCDF file with the first xarray engine that works and standardise its coordinates.

    Engines are tried in the order ``netcdf4``, ``h5netcdf``, ``scipy``, then
    xarray's default. If none can open the file, a warning is printed and
    ``None`` is returned.

    Standardisation: ``valid_time`` is renamed to ``time``; an ``expver``
    dimension is reduced to its first entry; leftover ``expver`` / ``number``
    coordinates are dropped.
    """
    last_error = None
    for engine in ['netcdf4', 'h5netcdf', 'scipy', None]:
        try:
            ds = xr.open_dataset(filepath, engine=engine) if engine else xr.open_dataset(filepath)
            break
        except Exception as e:  # missing backend / unsupported format: try the next engine
            last_error = e
    else:
        print(f"⚠️ Could not open {filepath}: {last_error}")
        return None

    # Coordinate Standardization
    if 'valid_time' in ds.coords: ds = ds.rename({'valid_time': 'time'})
    # NOTE(review): keeping only expver index 0 may drop values stored under other
    # expver entries — confirm this is the intended choice for the source data.
    if 'expver' in ds.dims: ds = ds.isel(expver=0, drop=True)
    if 'expver' in ds.coords: ds = ds.drop_vars('expver', errors='ignore')
    if 'number' in ds.coords: ds = ds.drop_vars('number', errors='ignore')
    return ds

# --- MAIN PREPROCESSING LOGIC ---
stats_path = CFG.BATCHED_DIR / 'stats.npz'
reprocess = not (CFG.BATCHED_DIR / 'train').exists() or not list((CFG.BATCHED_DIR / 'train').glob('X_batch_*.npy'))

# Split assignment uses the calendar year of each file's first timestamp.
TRAIN_YEARS = range(2015, 2022)   # 2015-2021
VAL_YEARS = range(2022, 2024)     # 2022-2023; every other year -> test
SAMPLES_PER_FILE = 500            # sequences per saved .npy shard


def _extract_features(ds):
    """Stack ``[log1p(max(tp, 0)), t2m]`` along a new trailing channel axis.

    Accepts short (``tp`` / ``t2m``) or long (``total_precipitation`` /
    ``2m_temperature``) variable names. Assumes each variable is
    (time, lat, lon) — TODO confirm — which yields a (time, lat, lon, 2) array.
    """
    raw_tp = ds['tp'].values if 'tp' in ds else ds['total_precipitation'].values
    t2m = ds['t2m'].values if 't2m' in ds else ds['2m_temperature'].values
    # Negative precipitation values are clipped to 0 before the log transform.
    tp = np.log1p(np.maximum(raw_tp, 0))
    return np.stack([tp, t2m], axis=-1)


if reprocess:
    print('🚀 Starting Preprocessing...')
    for split in ['train', 'val', 'test']: (CFG.BATCHED_DIR / split).mkdir(parents=True, exist_ok=True)

    files = sorted(list(CFG.DATA_DIR.glob('*.nc')) + list(CFG.DATA_2025_DIR.glob('*.nc')))

    # PASS 1: normalisation statistics from training years only (no val/test leakage).
    print('Pass 1: Computing Statistics...')
    train_values = []
    for f in tqdm(files, desc='Stats'):
        ds = load_single_file(f)
        if ds is None: continue
        try:
            year = pd.to_datetime(ds.time.values[0]).year
            if year in TRAIN_YEARS:
                train_values.append(_extract_features(ds)[::24])  # subsample time to bound memory
        except Exception as e: print(f"Skip {f.name}: {e}")
        finally: ds.close()

    if not train_values:
        raise RuntimeError(
            f'No readable files for training years {TRAIN_YEARS.start}-{TRAIN_YEARS.stop - 1}; '
            'cannot compute normalization statistics.'
        )

    train_sample = np.concatenate(train_values, axis=0)
    mean = np.nanmean(train_sample, axis=(0, 1, 2))
    std = np.nanstd(train_sample, axis=(0, 1, 2))
    std[std < 1e-6] = 1.0  # constant channels must not divide by ~0
    np.savez(stats_path, mean=mean, std=std)
    print(f'Stats — Mean: {mean}, Std: {std}')
    del train_values, train_sample; gc.collect()

    # PASS 2: normalised (T_IN -> T_OUT) sequences, written as fixed-size shards per split.
    print('Pass 2: Creating Sequences...')
    buffers = {'train': ([], []), 'val': ([], []), 'test': ([], [])}
    counts = {'train': 0, 'val': 0, 'test': 0}

    def flush(split):
        """Write one full shard for ``split`` once its buffer holds SAMPLES_PER_FILE sequences."""
        xs, ys = buffers[split]
        if len(xs) >= SAMPLES_PER_FILE:
            X = np.stack(xs[:SAMPLES_PER_FILE])
            Y = np.stack(ys[:SAMPLES_PER_FILE])
            np.save(CFG.BATCHED_DIR / split / f'X_batch_{counts[split]:04d}.npy', X)
            np.save(CFG.BATCHED_DIR / split / f'Y_batch_{counts[split]:04d}.npy', Y)
            counts[split] += 1
            buffers[split] = (xs[SAMPLES_PER_FILE:], ys[SAMPLES_PER_FILE:])

    for f in tqdm(files, desc='Processing'):
        ds = load_single_file(f)
        if ds is None: continue
        # Same skip-on-error policy as pass 1; the dataset handle is always closed.
        try:
            year = pd.to_datetime(ds.time.values[0]).year
            data = _extract_features(ds)
            # Normalise, then store as float32 (the loader converts to float32 anyway,
            # so float64 shards would only double disk usage).
            data = ((data - mean) / std).astype(np.float32)
            data = np.nan_to_num(data, nan=0.0)  # NaN -> 0, i.e. the training mean
        except Exception as e:
            print(f"Skip {f.name}: {e}")
            continue
        finally:
            ds.close()

        split = 'train' if year in TRAIN_YEARS else 'val' if year in VAL_YEARS else 'test'

        # Sliding windows within this file only: T_IN inputs followed by T_OUT targets.
        for i in range(CFG.T_IN, len(data) - CFG.T_OUT + 1, CFG.STRIDE):
            buffers[split][0].append(data[i-CFG.T_IN:i])
            buffers[split][1].append(data[i:i+CFG.T_OUT])
            flush(split)

    # Final Flush: write any remaining (smaller) shard per split.
    for split in ['train', 'val', 'test']:
        if buffers[split][0]:
            X = np.stack(buffers[split][0])
            Y = np.stack(buffers[split][1])
            np.save(CFG.BATCHED_DIR / split / f'X_batch_{counts[split]:04d}.npy', X)
            np.save(CFG.BATCHED_DIR / split / f'Y_batch_{counts[split]:04d}.npy', Y)
    print('✅ Preprocessing Complete.')

# CRITICAL FIX: ALWAYS LOAD STATS
# Stats are reloaded from disk on every run, so they exist even when the
# preprocessing above was skipped.
print("Loading normalization statistics...")
if not stats_path.exists():
    # Do not allow training without valid stats
    raise FileNotFoundError(f"❌ stats.npz not found at {stats_path}. Run preprocessing first!")
# np.load on an .npz returns an NpzFile holding an open file handle; the context
# manager closes it once the arrays have been read into memory.
with np.load(stats_path) as stats:
    mean = stats['mean']
    std = stats['std']
print(f"Loaded Stats — Mean: {mean}, Std: {std}")


In [None]:
# ============================================================================
# 5. DATALOADER (FIXED DATA LEAK)
# ============================================================================
def get_batch_files(split, root=None):
    """Return aligned ``(x_files, y_files)`` shard paths for ``split``.

    Y paths are derived from the X file names (``X_batch_NNNN.npy`` ->
    ``Y_batch_NNNN.npy``), so pairs line up by construction rather than by two
    independent sorts that silently misalign if a file is missing.

    Args:
        split: sub-directory name ('train' / 'val' / 'test').
        root: directory containing the split folders; defaults to ``CFG.BATCHED_DIR``.

    Returns:
        Two lists of ``Path`` of equal length, sorted by X file name. Both are
        empty if the split directory does not exist or holds no shards.

    Raises:
        FileNotFoundError: if some X shard has no matching Y shard.
    """
    split_dir = (CFG.BATCHED_DIR if root is None else Path(root)) / split
    x_files = sorted(split_dir.glob('X_batch_*.npy'))
    y_files = [x.with_name('Y' + x.name[1:]) for x in x_files]
    missing = [y.name for y in y_files if not y.exists()]
    if missing:
        raise FileNotFoundError(f'Missing target shards for split {split!r}: {missing}')
    return x_files, y_files

class BatchedWeatherDataset(IterableDataset):
    """Stream ready-made mini-batches from the ``X_batch_*`` / ``Y_batch_*`` shards of one split.

    Each yielded item is a whole batch ``(x, y)`` of float32 tensors in
    ``(B, T, C, H, W)`` layout (shards are stored channels-last and permuted
    here), so the DataLoader must use ``batch_size=None``. Partial tail batches
    of every shard are kept, so no sample is dropped.

    Args:
        split: split name passed to ``get_batch_files`` ('train' / 'val' / 'test').
        shuffle: shuffle shard order and sample order within each shard.
        batch_size: samples per yielded batch; ``None`` means ``CFG.BATCH_SIZE``.
    """

    def __init__(self, split, shuffle=False, batch_size=None):
        self.split = split
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.x_files, self.y_files = get_batch_files(split)

    def __iter__(self):
        bs = self.batch_size if self.batch_size is not None else CFG.BATCH_SIZE
        worker_info = torch.utils.data.get_worker_info()
        x_files = list(self.x_files)
        y_files = list(self.y_files)

        if worker_info is not None:
            # Each worker takes a disjoint contiguous slice of shards, so no
            # sample is produced twice per epoch.
            per_worker = int(np.ceil(len(x_files) / float(worker_info.num_workers)))
            iter_start = worker_info.id * per_worker
            iter_end = min(iter_start + per_worker, len(x_files))
            x_files = x_files[iter_start:iter_end]
            y_files = y_files[iter_start:iter_end]

        file_indices = list(range(len(x_files)))
        if self.shuffle: np.random.shuffle(file_indices)

        for idx in file_indices:
            try:
                # Memory-mapped so only the selected rows are read from disk.
                X = np.load(x_files[idx], mmap_mode='r')
                Y = np.load(y_files[idx], mmap_mode='r')
                n = len(X)
                indices = np.random.permutation(n) if self.shuffle else np.arange(n)

                for start in range(0, n, bs):
                    batch_idx = indices[start:start + bs]
                    x_np = np.ascontiguousarray(X[batch_idx])
                    y_np = np.ascontiguousarray(Y[batch_idx])

                    # (B, T, H, W, C) -> (B, T, C, H, W)
                    x = torch.from_numpy(x_np).float().permute(0, 1, 4, 2, 3)
                    y = torch.from_numpy(y_np).float().permute(0, 1, 4, 2, 3)
                    yield x, y
            except Exception as e:
                # Best-effort: a corrupt shard is reported and skipped.
                print(f'Error loading file {x_files[idx]}: {e}')

def get_loader(split, shuffle=False):
    """Build a DataLoader over the pre-batched shards of ``split``.

    ``batch_size=None`` because BatchedWeatherDataset already yields whole
    batches. Pinned memory is requested only when CUDA is available; on a
    CPU-only runtime it has no benefit.
    """
    ds = BatchedWeatherDataset(split, shuffle=shuffle)
    return DataLoader(ds, batch_size=None, num_workers=CFG.NUM_WORKERS,
                      pin_memory=torch.cuda.is_available())


In [None]:
# ============================================================================
# 6. MODEL ARCHITECTURE
# ============================================================================
class ConvLSTMCell(nn.Module):
    """Convolutional LSTM cell (no peephole terms).

    A single ``Conv2d`` over ``concat(x, h)`` produces all four gate
    pre-activations, split along channels in the order
    input, forget, output, candidate.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        same_padding = kernel_size // 2  # preserves H, W for odd kernel sizes
        self.conv = nn.Conv2d(input_dim + hidden_dim, 4 * hidden_dim, kernel_size, padding=same_padding)

    def forward(self, x, hidden):
        """Advance one time step; ``hidden`` is ``(h, c)``, returns ``(h_next, c_next)``."""
        h_prev, c_prev = hidden
        stacked = torch.cat([x, h_prev], dim=1)
        in_gate, forget_gate, out_gate, candidate = self.conv(stacked).chunk(4, dim=1)
        c_next = torch.sigmoid(forget_gate) * c_prev + torch.sigmoid(in_gate) * torch.tanh(candidate)
        h_next = torch.sigmoid(out_gate) * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, B, H, W, device):
        """Zero ``(h, c)`` state, each of shape (B, hidden_dim, H, W), on ``device``."""
        state_shape = (B, self.hidden_dim, H, W)
        return tuple(torch.zeros(state_shape, device=device) for _ in range(2))

class WeatherNowcaster(nn.Module):
    """Encoder–decoder ConvLSTM that forecasts ``future_steps`` frames from an input sequence.

    The encoder stack consumes the inputs; its final ``(h, c)`` states seed a
    separate decoder stack. Decoding is autoregressive: each predicted frame
    (projected by ``out_conv``) is fed back as the next decoder input, starting
    from a projection of the encoder's top-layer hidden state.
    """

    def __init__(self, in_channels, hidden_dim, out_channels, n_layers=4):
        super().__init__()
        self.n_layers = n_layers
        # Layer 0 reads the frame itself; deeper layers read the previous layer's h.
        self.encoder = nn.ModuleList(
            ConvLSTMCell(hidden_dim if layer else in_channels, hidden_dim, 3)
            for layer in range(n_layers)
        )
        self.decoder = nn.ModuleList(
            ConvLSTMCell(hidden_dim if layer else out_channels, hidden_dim, 3)
            for layer in range(n_layers)
        )
        self.out_conv = nn.Conv2d(hidden_dim, out_channels, 3, padding=1)

    def forward(self, x, future_steps):
        """Map ``x`` of shape (B, T, C, H, W) to (B, future_steps, out_channels, H, W)."""
        batch, seq_len, _, height, width = x.shape
        state = [cell.init_hidden(batch, height, width, x.device) for cell in self.encoder]

        # Encode the whole input sequence through the stacked cells.
        for t in range(seq_len):
            layer_in = x[:, t]
            for k, cell in enumerate(self.encoder):
                state[k] = cell(layer_in, state[k])
                layer_in = state[k][0]

        # Decoder starts from copies of the encoder's final states.
        state = [(h.clone(), c.clone()) for h, c in state]
        frame = self.out_conv(state[-1][0])  # initial decoder input

        predictions = []
        for _ in range(future_steps):
            layer_in = frame
            for k, cell in enumerate(self.decoder):
                state[k] = cell(layer_in, state[k])
                layer_in = state[k][0]
            frame = self.out_conv(layer_in)
            predictions.append(frame)  # also the next step's input (autoregressive)

        return torch.stack(predictions, dim=1)

# Module-level weight initialiser, meant for `model.apply(init_weights)`.
def init_weights(m):
    """Xavier-uniform weights and zero bias for every ``nn.Conv2d``; other modules are untouched.

    Xavier/Glorot scaling is a reasonable default for the sigmoid/tanh gate
    non-linearities of the ConvLSTM cells.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)


In [None]:
# ============================================================================
# 8. LOSS & METRICS (PHYSICALLY GROUNDED & STABILIZED)
# ============================================================================
class WeightedMSE(nn.Module):
    """MSE over precipitation (channel 0) and temperature (channel 1), up-weighting wet pixels.

    Precipitation is stored as ``(log1p(tp) - mean[0]) / std[0]``, so the
    physical threshold ``mm_threshold`` is converted into that z-score space
    once at construction. Precipitation pixels whose *target* exceeds it get
    weight ``nonzero_weight``; all others get weight 1. Temperature uses plain
    MSE. ``forward`` returns ``mse_tp + mse_t2m``.

    NOTE(review): the threshold is in the same units as the raw ``tp`` values
    fed to preprocessing. ERA5 ``tp`` is commonly distributed in metres rather
    than mm — confirm the source units, otherwise "1 mm" here is really 1 m.

    Args:
        mean, std: per-channel normalisation statistics (index 0 = precipitation).
        mm_threshold: wet/dry threshold in physical units.
        nonzero_weight: loss weight for pixels above the threshold.
    """

    def __init__(self, mean, std, mm_threshold=1.0, nonzero_weight=3.0):
        super().__init__()
        self.nonzero_weight = nonzero_weight

        # z = (log1p(x) - mean) / std, guarding against a degenerate std.
        log_thresh = np.log1p(mm_threshold)
        safe_std = std[0] if std[0] > 1e-6 else 1.0
        self.z_threshold = float((log_thresh - mean[0]) / safe_std)
        print(f"WeightedMSE: Threshold {mm_threshold}mm -> Z-score {self.z_threshold:.4f}")

    def forward(self, pred, target):
        # Compute in float32 even under autocast, for numerical stability.
        pred = pred.float()
        target = target.float()

        # 1. Precipitation: one-sided threshold on the target picks the wet pixels.
        pred_tp = pred[:, :, 0]
        target_tp = target[:, :, 0]
        weight = torch.where(
            target_tp > self.z_threshold,
            torch.full_like(target_tp, self.nonzero_weight),
            torch.ones_like(target_tp),
        )
        mse_tp = (weight * (pred_tp - target_tp) ** 2).mean()

        # 2. Temperature: unweighted MSE. The parentheses matter:
        #    `x**2 .mean()` would parse as `x ** (2).mean()` and raise AttributeError.
        mse_t2m = ((pred[:, :, 1] - target[:, :, 1]) ** 2).mean()

        return mse_tp + mse_t2m

class SSIMLoss(nn.Module):
    """``1 - mean(SSIM)``, computed per channel with a Gaussian window; clamped to [0, 2].

    Accepts (N, C, H, W) or (B, T, C, H, W) inputs; time is folded into the
    batch dimension for the latter. ``data_range`` L sets the stabilising
    constants C1 = (0.01·L)² and C2 = (0.03·L)². The default 6.0 keeps the
    original behaviour (presumably chosen as roughly ±3σ of z-scored data).

    Args:
        window_size: side length of the Gaussian window.
        channels: number of input channels (grouped convolution, one window per channel).
        sigma: standard deviation of the Gaussian window.
        data_range: dynamic range L of the inputs.
    """

    def __init__(self, window_size=7, channels=2, sigma=1.5, data_range=6.0):
        super().__init__()
        self.window_size = window_size
        self.channels = channels
        self.data_range = data_range
        coords = torch.arange(window_size).float() - window_size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g = g / g.sum()
        window = (g.unsqueeze(1) @ g.unsqueeze(0)).unsqueeze(0).unsqueeze(0)
        # Non-persistent buffer: follows module.to(device) but stays out of state_dict.
        self.register_buffer('window', window.expand(channels, 1, -1, -1).contiguous(), persistent=False)

    def forward(self, pred, target):
        # Force float32 to prevent -inf in float16 AMP.
        pred = pred.float()
        target = target.float()

        if pred.dim() == 5:
            B, T, C, H, W = pred.shape
            pred = pred.reshape(B * T, C, H, W)
            target = target.reshape(B * T, C, H, W)

        # The loss module itself is not necessarily moved to the data's device.
        window = self.window.to(pred.device)
        pad = self.window_size // 2

        mu1 = F.conv2d(pred, window, padding=pad, groups=self.channels)
        mu2 = F.conv2d(target, window, padding=pad, groups=self.channels)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(pred * pred, window, padding=pad, groups=self.channels) - mu1_sq
        sigma2_sq = F.conv2d(target * target, window, padding=pad, groups=self.channels) - mu2_sq

        # Variances cannot be negative; clamp away floating-point round-off.
        sigma1_sq = torch.clamp(sigma1_sq, min=0.0)
        sigma2_sq = torch.clamp(sigma2_sq, min=0.0)
        sigma12 = F.conv2d(pred * target, window, padding=pad, groups=self.channels) - mu1_mu2

        C1 = (0.01 * self.data_range) ** 2
        C2 = (0.03 * self.data_range) ** 2

        numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
        denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)

        ssim_map = numerator / (denominator + 1e-7)

        return torch.clamp(1 - ssim_map.mean(), min=0.0, max=2.0)

# Loss components. WeightedMSE needs the normalisation stats (`mean`, `std`,
# loaded in Section 4) to translate its physical threshold into z-score space.
wmse = WeightedMSE(mean=mean, std=std, mm_threshold=1.0, nonzero_weight=CFG.PRECIP_WEIGHT)
ssim = SSIMLoss(channels=CFG.OUT_CHANNELS)
mae = nn.L1Loss()

def criterion(pred, target):
    """Composite training loss: 0.5 * WeightedMSE + 0.3 * SSIM loss + 0.2 * L1."""
    weighted_term = 0.5 * wmse(pred, target)
    structural_term = 0.3 * ssim(pred, target)
    absolute_term = 0.2 * mae(pred, target)
    return weighted_term + structural_term + absolute_term


In [None]:
# ============================================================================
# 9. TRAINING LOOP (ROBUST)
# ============================================================================
model = WeatherNowcaster(CFG.IN_CHANNELS, CFG.HIDDEN_DIM, CFG.OUT_CHANNELS, CFG.N_LAYERS).to(device)
model.apply(init_weights)
print(f'🧠 Model Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f} M')

from torch.amp import autocast, GradScaler

# Mixed precision only on CUDA. With enabled=False both autocast and GradScaler
# are pass-throughs, so the same loop also runs on a CPU-only runtime.
USE_AMP = device.type == 'cuda'
scaler = GradScaler('cuda', enabled=USE_AMP)

optimizer = optim.Adam(model.parameters(), lr=CFG.LEARNING_RATE, weight_decay=CFG.WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.NUM_EPOCHS, eta_min=1e-6)

# --- Resume Logic ---
# Restores everything needed to continue where training stopped: weights,
# optimizer, LR schedule, AMP scaler and early-stopping bookkeeping. Keys that
# are missing from older checkpoints fall back to defaults.
start_epoch = 0
train_losses, val_losses = [], []
best_val_loss = float('inf')
patience_counter = 0
resume_path = SAVE_DIR / 'last.pth'
if resume_path.exists():
    print(f"📂 Resuming from {resume_path}")
    try:
        ckpt = torch.load(resume_path, map_location=device)
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        if 'scheduler' in ckpt:
            scheduler.load_state_dict(ckpt['scheduler'])
        if ckpt.get('scaler'):  # empty dict when saved from a disabled scaler
            scaler.load_state_dict(ckpt['scaler'])
        train_losses = list(ckpt.get('train_losses', []))
        val_losses = list(ckpt.get('val_losses', []))
        finite_val = [v for v in val_losses if np.isfinite(v)]
        best_val_loss = ckpt.get('best_val_loss', min(finite_val) if finite_val else float('inf'))
        patience_counter = ckpt.get('patience_counter', 0)
        start_epoch = ckpt['epoch'] + 1
        print(f"▶️ Resumed at Epoch {start_epoch}")
    except Exception as e:
        # NOTE(review): model/optimizer may already hold partially loaded state here.
        start_epoch, train_losses, val_losses = 0, [], []
        best_val_loss, patience_counter = float('inf'), 0
        print(f"⚠️ Checkpoint load failed: {e}. Starting fresh.")

train_loader = get_loader('train', shuffle=True)
val_loader = get_loader('val', shuffle=False)

# Progress-bar length only. Each shard holds up to 500 samples and partial
# batches are kept, so a full shard yields ceil(500 / BATCH_SIZE) batches.
n_train_files = len(list((CFG.BATCHED_DIR / 'train').glob('X_batch_*.npy')))
total_batches = n_train_files * int(np.ceil(500 / CFG.BATCH_SIZE))

print('🔥 Starting Training...')
for epoch in range(start_epoch, CFG.NUM_EPOCHS):
    model.train()
    ep_loss = 0.0
    count = 0
    nan_counter = 0  # batches skipped because of a non-finite loss

    pbar = tqdm(train_loader, total=total_batches, desc=f'Epoch {epoch+1}/{CFG.NUM_EPOCHS} [Train]')

    for x, y in pbar:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with autocast(device.type, enabled=USE_AMP):
            pred = model(x, CFG.T_OUT)
            loss = criterion(pred, y)

        # Robust Safety Check (Catches both NaN and Inf)
        if not torch.isfinite(loss):
            nan_counter += 1
            if nan_counter <= 3: # Only print first 3 warnings
                print(f"⚠️ Non-finite loss detected! Skipping batch.")
            continue

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.MAX_GRAD_NORM)
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        ep_loss += loss_value
        count += 1
        pbar.set_postfix({'loss': f'{loss_value:.4f}'})

    if nan_counter > 0:
        print(f"⚠️ Total skipped batches this epoch: {nan_counter}")

    # NaN (not 0) when no batch contributed, so an empty epoch can never look perfect.
    train_loss = ep_loss / count if count > 0 else float('nan')
    train_losses.append(train_loss)

    # Validation
    model.eval()
    val_loss_sum = 0.0
    val_count = 0
    with torch.no_grad():
        for x, y in tqdm(val_loader, desc=f'Epoch {epoch+1} [Val]', leave=False):
            x, y = x.to(device), y.to(device)
            with autocast(device.type, enabled=USE_AMP):
                pred = model(x, CFG.T_OUT)
                val_loss_sum += criterion(pred, y).item()
            val_count += 1

    val_loss = val_loss_sum / val_count if val_count > 0 else float('nan')
    val_losses.append(val_loss)
    scheduler.step()

    print(f'Epoch {epoch+1} | Train: {train_loss:.4f} | Val: {val_loss:.4f}')

    # Early-stopping bookkeeping runs regardless of Drive (NaN never counts as improvement).
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        patience_counter = 0
    else:
        patience_counter += 1

    state = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'scaler': scaler.state_dict(),
        'best_val_loss': best_val_loss,
        'patience_counter': patience_counter,
        'train_losses': train_losses,
        'val_losses': val_losses,
    }

    # SAVE LOCALLY (Fast)
    torch.save(state, SAVE_DIR / 'last.pth')
    if improved:
        torch.save(state, SAVE_DIR / 'best_model.pth')
        print('  ★ New Best Model saved locally.')

    # BACKUP TO DRIVE (best-effort: a Drive hiccup must not kill a long run)
    if DRIVE_AVAILABLE:
        try:
            if improved:
                torch.save(state, CFG.DRIVE_CKPT_DIR / 'best_model.pth')
                print('  ★ New Best Model Synced to Drive!')
            if (epoch + 1) % 5 == 0:
                shutil.copy(SAVE_DIR / 'last.pth', CFG.DRIVE_CKPT_DIR / 'last_backup.pth')
        except OSError as e:
            print(f'⚠️ Drive backup failed: {e}')

    if patience_counter >= CFG.PATIENCE:
        print(f'\n⏹️ Early stopping triggered.')
        break

print('✅ Training Complete.')


In [None]:
# ============================================================================
# TRAINING PLOTS
# ============================================================================
# Epochs are 1-based to match the training log ("Epoch 1 | ...").
epochs = range(1, len(train_losses) + 1)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: linear scale. Right panel: log scale, which keeps late-epoch
# improvements visible after the large early drop.
for ax, scale in zip(axes, ('linear', 'log')):
    ax.plot(epochs, train_losses, label='Train Loss')
    ax.plot(epochs, val_losses, label='Val Loss')
    ax.set_yscale(scale)
    ax.set_title(f'Loss Curves ({scale} scale)')
    ax.legend()
    ax.grid(True)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')

fig.tight_layout()
plt.show()


In [None]:
# ============================================================================
# EVALUATION & VISUALIZATION (PAPER-READY)
# ============================================================================
print('Loading Best Model for Evaluation...')
# Preference order: Drive best -> local best -> local last (used even if absent,
# in which case the warning below is printed).
checkpoint_candidates = ([CFG.DRIVE_CKPT_DIR / 'best_model.pth'] if DRIVE_AVAILABLE else []) + [SAVE_DIR / 'best_model.pth']
best_path = next((p for p in checkpoint_candidates if p.exists()), SAVE_DIR / 'last.pth')

if not best_path.exists():
    print('⚠️ No model found for evaluation.')
else:
    ckpt = torch.load(best_path, map_location=device)
    model.load_state_dict(ckpt['model'])
    print(f'Loaded model from {best_path} (Epoch {ckpt["epoch"]+1})')

model.eval()
test_loader = get_loader('test', shuffle=False)

# ----------------------------------------------------------------------------
# 1. QUANTITATIVE EVALUATION (METRICS vs BASELINE)
# ----------------------------------------------------------------------------
def compute_metrics(pred, target, threshold=0.5):
    """Categorical verification scores for the event ``value > threshold``.

    Returns ``(csi, pod, far, bias)`` as Python floats. Each denominator carries
    a 1e-8 guard, so an empty event set does not divide by zero.
    """
    eps = 1e-8
    forecast_yes = (pred > threshold).float()
    observed_yes = (target > threshold).float()

    hits = (forecast_yes * observed_yes).sum()
    misses = ((1 - forecast_yes) * observed_yes).sum()
    false_alarms = (forecast_yes * (1 - observed_yes)).sum()

    csi = hits / (hits + misses + false_alarms + eps)
    pod = hits / (hits + misses + eps)
    far = false_alarms / (hits + false_alarms + eps)
    bias = (hits + false_alarms) / (hits + misses + eps)
    return tuple(score.item() for score in (csi, pod, far, bias))

print('Computing Metrics on Test Set...')
# Define thresholds in MM (Physical Space)
THRESHOLDS_MM = [0.5, 2.0, 5.0]
METRIC_NAMES = ['csi', 'pod', 'far', 'bias']

# Stats for denormalization (channel 0 = precipitation). Plain floats so they
# combine with torch tensors regardless of the dtype stored in stats.npz.
tp_mean = float(mean[0])
tp_std = float(std[0])


def _z_to_mm(z):
    """Undo the precipitation preprocessing: mm = expm1(z * std + mean), clipped at 0."""
    return torch.clamp(torch.expm1(z * tp_std + tp_mean), min=0.0)


def _contingency(forecast_mm, observed_mm, threshold):
    """Return ``(hits, misses, false_alarms)`` counts for the event ``value > threshold``."""
    fc = forecast_mm > threshold
    ob = observed_mm > threshold
    return (fc & ob).sum().item(), (~fc & ob).sum().item(), (fc & ~ob).sum().item()


def _scores(hits, misses, false_alarms, eps=1e-8):
    """CSI / POD / FAR / frequency bias from contingency counts (eps avoids 0/0)."""
    return {
        'csi': hits / (hits + misses + false_alarms + eps),
        'pod': hits / (hits + misses + eps),
        'far': false_alarms / (hits + false_alarms + eps),
        'bias': (hits + false_alarms) / (hits + misses + eps),
    }


# Contingency counts and squared errors are summed over the whole test set and
# scores are computed once at the end. Averaging per-batch scores would weight
# small tail batches equally with full ones and count every event-free batch
# as CSI = 0.
counts = {src: {t: [0, 0, 0] for t in THRESHOLDS_MM} for src in ('model', 'persistence')}
sq_err = {'model': 0.0, 'persistence': 0.0}
n_values = 0

with torch.no_grad():
    for x, y in tqdm(test_loader, desc='Evaluating'):
        x, y = x.to(device), y.to(device)

        # Model Prediction
        with torch.amp.autocast('cuda', enabled=device.type == 'cuda'):
            pred = model(x, CFG.T_OUT)

        # Back to physical space (precipitation channel only).
        pred_mm = _z_to_mm(pred.float()[:, :, 0])
        target_mm = _z_to_mm(y[:, :, 0])
        # Persistence baseline: the last observed input frame repeated for every lead time.
        persistence_mm = _z_to_mm(x[:, -1, 0]).unsqueeze(1).expand_as(target_mm)

        sq_err['model'] += F.mse_loss(pred_mm, target_mm, reduction='sum').item()
        sq_err['persistence'] += F.mse_loss(persistence_mm, target_mm, reduction='sum').item()
        n_values += target_mm.numel()

        for thresh in THRESHOLDS_MM:
            for src, forecast in (('model', pred_mm), ('persistence', persistence_mm)):
                for k, c in enumerate(_contingency(forecast, target_mm, thresh)):
                    counts[src][thresh][k] += c

if n_values == 0:
    raise RuntimeError('Test loader produced no batches — nothing to evaluate.')

results = {t: _scores(*counts['model'][t]) for t in THRESHOLDS_MM}
baseline_results = {t: _scores(*counts['persistence'][t]) for t in THRESHOLDS_MM}
mse_model = sq_err['model'] / n_values
mse_persistence = sq_err['persistence'] / n_values

print('\n' + '='*60)
print('TEST SET RESULTS (Physical Space - mm/hr)')
print('='*60)
print(f'MSE (mm/hr^2) | Model: {mse_model:.4f} | Persistence: {mse_persistence:.4f}')
print('-'*60)
print(f'{ "Threshold":<10} | { "Metric":<5} | { "Model":<10} | { "Persistence":<10} | { "Improvement":<10}')
print('-'*60)
for t in THRESHOLDS_MM:
    for m in METRIC_NAMES:
        score = results[t][m]
        base = baseline_results[t][m]
        if m == 'far':
            imp = base - score                        # lower FAR is better
        elif m == 'bias':
            imp = abs(1 - base) - abs(1 - score)      # frequency bias is ideal at 1
        else:
            imp = score - base                        # higher CSI / POD is better
        print(f'{t:<10} | {m.upper():<5} | {score:.4f}     | {base:.4f}           | {imp:+.4f}')
    print('-'*60)

# ----------------------------------------------------------------------------
# 2. QUALITATIVE VISUALIZATION (ERROR MAPS)
# ----------------------------------------------------------------------------
def visualize_prediction(x, y, pred, idx=0, n_input_frames=3, error_vmax=5.0):
    """Plot recent inputs, targets, predictions and absolute error for one sample.

    Precipitation (channel 0) is converted back to physical units with the
    module-level ``tp_mean`` / ``tp_std``. All precipitation panels share one
    colour scale so intensities are comparable across rows. The grid adapts to
    the actual input/output sequence lengths.

    Args:
        x: normalised inputs, (B, T_in, C, H, W).
        y, pred: normalised targets / predictions, (B, T_out, C, H, W).
        idx: sample index within the batch.
        n_input_frames: number of most recent input frames to show.
        error_vmax: upper limit of the |target - pred| colour scale.
    """
    def to_mm(frame):
        # Inverse of the preprocessing transform (z-score, then log1p).
        return torch.expm1(frame.detach().cpu().float() * tp_std + tp_mean).clamp(0)

    t_in, t_out = x.shape[1], y.shape[1]
    n_in = min(n_input_frames, t_in)
    n_cols = max(n_in, t_out) + 1  # one extra column for the row labels
    fig, axes = plt.subplots(4, n_cols, figsize=(20 * n_cols / 7, 10), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')  # unused cells stay blank instead of showing empty frames

    input_frames = [(t, to_mm(x[idx, t, 0])) for t in range(t_in - n_in, t_in)]
    target_frames = [to_mm(y[idx, i, 0]) for i in range(t_out)]
    pred_frames = [to_mm(pred[idx, i, 0]) for i in range(t_out)]

    # One shared precipitation scale; fall back to 1.0 for an all-dry sample.
    all_frames = [frame for _, frame in input_frames] + target_frames + pred_frames
    precip_vmax = max(float(frame.max()) for frame in all_frames) or 1.0
    precip_kw = dict(cmap='Blues', origin='lower', vmin=0, vmax=precip_vmax)

    # Input (last n_in frames; t=0 is the most recent observation)
    for col, (t, frame) in enumerate(input_frames):
        axes[0, col].imshow(frame, **precip_kw)
        axes[0, col].set_title(f'Input t={t - t_in + 1}')
    axes[0, n_in].text(0.5, 0.5, 'INPUT SEQUENCE', ha='center', fontsize=12)

    # Target / Pred / Error
    for i in range(t_out):
        axes[1, i].imshow(target_frames[i], **precip_kw)
        axes[1, i].set_title(f'Target t+{i+1}')
        axes[2, i].imshow(pred_frames[i], **precip_kw)
        axes[2, i].set_title(f'Pred t+{i+1}')
        diff = torch.abs(target_frames[i] - pred_frames[i])
        axes[3, i].imshow(diff, cmap='hot', origin='lower', vmin=0, vmax=error_vmax)
        axes[3, i].set_title(f'Error t+{i+1}')

    for row, label in ((1, 'Ground Truth'), (2, 'Prediction'), (3, '|Target - Pred|')):
        axes[row, t_out].text(0, 0.5, label, rotation=90, va='center', fontsize=12)

    plt.suptitle('Model Evaluation (Physical Space - mm)', fontsize=16)
    plt.tight_layout()
    plt.show()

print('Running Inference for Visualization...')
with torch.no_grad():
    # Only the first test batch is needed.
    first_batch = next(iter(test_loader), None)
    if first_batch is None:
        print('⚠️ Test loader is empty — nothing to visualize.')
    else:
        x, y = (t.to(device) for t in first_batch)
        with torch.amp.autocast('cuda', enabled=device.type == 'cuda'):
            pred = model(x, CFG.T_OUT)
        visualize_prediction(x, y, pred.float(), idx=0)
