# 🔧 Stage 4 DEBUG: Test Baseline First

**CRITICAL ISSUE:** The previous run reported CSI@74 = 0 for ALL lambda values.

**This notebook:**
1. Tests lambda=0 (pure MSE) FIRST
2. Should match Stage 2 baseline (CSI@74 ~0.65-0.68)
3. Adds detailed loss monitoring
4. Only proceeds to perceptual if baseline works

**Expected baseline results (lambda=0):**
- Val MSE: ~0.009-0.012
- Val CSI@74: ~0.65-0.68
- Val LPIPS: ~0.40 (no improvement expected)

- **If the baseline fails:** there is a bug in data loading, the model, or the training loop.
- **If the baseline works:** perceptual loss can safely be added.

---

## 1. Setup: Mount Drive & Check GPU

In [None]:
import os
from google.colab import drive

# Attach Google Drive so data and results live outside the ephemeral VM.
drive.mount('/content/drive')

# Root folder for SEVIR data on Drive; created if it does not exist yet.
DRIVE_DATA_ROOT = "/content/drive/MyDrive/SEVIR_Data"
os.makedirs(DRIVE_DATA_ROOT, exist_ok=True)

print("✓ Google Drive mounted", f"✓ Data directory: {DRIVE_DATA_ROOT}", sep="\n")

In [None]:
!nvidia-smi

import torch
print(f"\n{'='*70}")
print("GPU CHECK")
print(f"{'='*70}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
else:
    print("⚠️  WARNING: Select GPU runtime!")
    print("   Runtime → Change runtime type → GPU")
print(f"{'='*70}")

## 2. Install Dependencies

In [None]:
# Install the packages this notebook needs: h5py (reads the SEVIR .h5 files),
# lpips (perceptual metric) and tqdm (progress bars). matplotlib and
# scikit-image are installed too but are not referenced in the cells shown here.
# NOTE: the confirmation below prints even if pip reports an error.
!pip install -q h5py lpips tqdm matplotlib scikit-image
print("✓ Dependencies installed")

## 3. Data Setup

In [None]:
from pathlib import Path

# Where the SEVIR catalog and raw HDF5 files are expected on Drive.
DATA_ROOT = "/content/drive/MyDrive/SEVIR_Data"
SEVIR_ROOT = f"{DATA_ROOT}/data/sevir"
CATALOG_PATH = f"{DATA_ROOT}/data/SEVIR_CATALOG.csv"

# Quick sanity probe: the catalog plus one known VIL file.
vil_probe = Path(SEVIR_ROOT) / "vil" / "2019" / "SEVIR_VIL_STORMEVENTS_2019_0701_1231.h5"
catalog_exists = Path(CATALOG_PATH).exists()
vil_exists = vil_probe.exists()

print("Data Check:")
print(f"  Catalog: {'✓' if catalog_exists else '✗'} {CATALOG_PATH}")
print(f"  VIL data: {'✓' if vil_exists else '✗'} {SEVIR_ROOT}/vil/2019/")

if catalog_exists and vil_exists:
    print("\n✓ Data ready!")
else:
    print("\n⚠ Data missing! Please ensure SEVIR data is in Drive.")

## 4. Create Event ID Files

In [None]:
# Folder that holds the small train/val event-ID lists.
os.makedirs(f"{DATA_ROOT}/data/samples", exist_ok=True)

# Eight training events and four validation events (from Stage 1).
train_ids = [
    "S851839", "S856840", "S853914", "S858016",
    "S847132", "S828550", "S844358", "S833039",
]
val_ids = [
    "S848711", "S851205", "S849773", "S849364",
]

TRAIN_IDS = f"{DATA_ROOT}/data/samples/tiny_train_ids.txt"
VAL_IDS = f"{DATA_ROOT}/data/samples/tiny_val_ids.txt"

# One ID per line, no trailing newline.
for _id_path, _id_list in ((TRAIN_IDS, train_ids), (VAL_IDS, val_ids)):
    with open(_id_path, 'w') as fh:
        fh.write('\n'.join(_id_list))

print("✓ Created event ID files")
print(f"  Train: {len(train_ids)} events")
print(f"  Val: {len(val_ids)} events")

## 5. Dataset Implementation

In [None]:
import h5py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class SevirNowcastDataset(Dataset):
    """SEVIR VIL nowcasting dataset.

    Each item is an ``(x, y)`` pair of float32 tensors with values divided by
    255: ``x`` holds ``input_steps`` consecutive frames and ``y`` the next
    ``output_steps`` frames, both laid out as (T, H, W).

    Args:
        index: sequence of ``(file_path, file_index, event_id)`` tuples, e.g.
            as returned by ``build_tiny_index``.
        input_steps: number of conditioning frames.
        output_steps: number of target frames following the inputs.
        target_size: stored on the instance but NOT applied; frames are
            returned at the resolution stored in the HDF5 file.
        random_crop: if True (default), the start of the temporal window is
            drawn uniformly at random on every access. If False, the centred
            window is used, so repeated evaluations see identical samples.

    Raises (from ``__getitem__``):
        ValueError: if an event has fewer than ``input_steps + output_steps``
            frames. Previously this produced silently truncated tensors that
            only failed later, inside batching.
    """

    def __init__(self, index, input_steps=12, output_steps=1, target_size=(384, 384),
                 random_crop=True):
        self.index = index
        self.in_steps = input_steps
        self.out_steps = output_steps
        self.target_size = target_size  # NOTE: not used by __getitem__
        self.random_crop = random_crop

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        file_path, file_index, event_id = self.index[idx]

        # Load one event and scale to [0, 1] (float32).
        with h5py.File(file_path, "r") as h5:
            data = h5["vil"][file_index].astype(np.float32) / 255.0

        # Stored layout is (H, W, T): time is the last axis.
        window = self.in_steps + self.out_steps
        total_frames = data.shape[2]
        max_start = total_frames - window
        if max_start < 0:
            raise ValueError(
                f"Event {event_id} has {total_frames} frames but {window} are required "
                f"(input_steps={self.in_steps}, output_steps={self.out_steps})"
            )

        if self.random_crop:
            t_start = np.random.randint(0, max_start + 1)
        else:
            t_start = max_start // 2

        # Extract sequences: (H, W, T) → (T, H, W)
        x = np.transpose(data[:, :, t_start:t_start + self.in_steps], (2, 0, 1))
        y = np.transpose(data[:, :, t_start + self.in_steps:t_start + window], (2, 0, 1))

        return torch.from_numpy(x), torch.from_numpy(y)


def build_tiny_index(catalog_path, ids_txt, sevir_root, modality="vil"):
    """Resolve a list of SEVIR event IDs to HDF5 locations.

    Args:
        catalog_path: path to the SEVIR catalog CSV (columns ``id``,
            ``img_type``, ``file_name`` and ``file_index`` are read).
        ids_txt: text file with one event ID per line; blank lines are ignored.
        sevir_root: directory that catalog ``file_name`` values are relative to.
        modality: catalog ``img_type`` to select (default ``"vil"``).

    Returns:
        List of ``(file_path, file_index, event_id)`` tuples in the order the
        IDs appear in ``ids_txt``. IDs that are absent from the catalog, or
        whose HDF5 file does not exist, are skipped. They are listed in a
        warning line so a silently shrunken dataset is visible.
    """
    with open(ids_txt, 'r') as f:
        event_ids = [line.strip() for line in f if line.strip()]

    catalog = pd.read_csv(catalog_path, low_memory=False)
    modality_cat = catalog[catalog["img_type"] == modality]

    index = []
    not_in_catalog = []
    missing_files = []
    for event_id in event_ids:
        event_rows = modality_cat[modality_cat["id"] == event_id]
        if event_rows.empty:
            not_in_catalog.append(event_id)
            continue

        # If several catalog rows match, the first one is used.
        row = event_rows.iloc[0]
        file_path = os.path.join(sevir_root, row["file_name"])
        if not os.path.exists(file_path):
            missing_files.append(event_id)
            continue
        index.append((file_path, int(row["file_index"]), event_id))

    if not_in_catalog:
        print(f"⚠ {len(not_in_catalog)} ID(s) not in catalog for '{modality}': {not_in_catalog}")
    if missing_files:
        print(f"⚠ {len(missing_files)} ID(s) with missing HDF5 file: {missing_files}")
    print(f"✓ Built index: {len(index)} events")
    return index


print("✓ Dataset classes defined")

## 6. Model Architecture (U-Net2D)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

def conv_block(in_ch, out_ch, use_bn=True):
    """Two 3x3 convolutions, each followed by ReLU and (optionally) BatchNorm.

    ``padding=1`` keeps the spatial size unchanged. The ReLU-before-BatchNorm
    order is left unchanged so previously saved state_dicts still load.
    """
    layers = [nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True)]
    if use_bn:
        layers.append(nn.BatchNorm2d(out_ch))
    layers += [nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True)]
    if use_bn:
        layers.append(nn.BatchNorm2d(out_ch))
    return nn.Sequential(*layers)


class Down(nn.Module):
    """Encoder stage: 2x max-pool (halves H and W), then ``conv_block``."""

    def __init__(self, in_ch, out_ch, use_bn=True):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.block = conv_block(in_ch, out_ch, use_bn)

    def forward(self, x):
        return self.block(self.pool(x))


class Up(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, concat with skip, ``conv_block``.

    The upsample maps ``in_ch`` -> ``out_ch`` channels. The skip must carry
    ``in_ch - out_ch`` channels so the concatenation has ``in_ch`` channels.
    """

    def __init__(self, in_ch, out_ch, use_bn=True):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.block = conv_block(in_ch, out_ch, use_bn)

    def forward(self, x, skip):
        x = self.up(x)
        # Pad when the skip is larger, e.g. odd encoder sizes that max-pooling
        # floor-divided. A negative difference would make F.pad *crop* x and
        # misalign it with the skip; UNet2D's wiring keeps the difference >= 0.
        dy, dx = skip.size(-2) - x.size(-2), skip.size(-1) - x.size(-1)
        if dy or dx:
            x = F.pad(x, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        x = torch.cat([skip, x], dim=1)
        return self.block(x)


class UNet2D(nn.Module):
    """2D U-Net mapping ``in_channels`` stacked frames to ``out_channels`` frames.

    The encoder is ``inc`` at full resolution plus three ``Down`` stages
    (1/2, 1/4, 1/8). The bottleneck runs at 1/16 resolution after one more 2x
    max-pool, so each of the four ``Up`` stages upsamples back to exactly its
    skip's resolution.

    Previously the bottleneck ran at 1/8 resolution. ``u3`` then upsampled to
    1/4 and ``Up``'s negative padding centre-cropped the result to 1/8, which
    discarded the border and misaligned it with the ``c4`` skip. The added
    pool is parameter-free, so state_dict keys are unchanged. Weights trained
    with the old wiring will still need retraining.

    Spatial input sizes must be at least 16 (four 2x poolings). Sizes that are
    not multiples of 16 are handled by the padding in ``Up``.
    """

    def __init__(self, in_channels=12, out_channels=1, base_ch=32, use_bn=True):
        super().__init__()
        self.inc = conv_block(in_channels, base_ch, use_bn)
        self.d1 = Down(base_ch, base_ch * 2, use_bn)
        self.d2 = Down(base_ch * 2, base_ch * 4, use_bn)
        self.d3 = Down(base_ch * 4, base_ch * 8, use_bn)
        # Downsample once more before the bottleneck so u3's 2x upsample lands
        # on c4's resolution (no parameters -> no new state_dict entries).
        self.bottleneck_pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(base_ch * 8, base_ch * 16, use_bn)
        self.u3 = Up(base_ch * 16, base_ch * 8, use_bn)
        self.u2 = Up(base_ch * 8, base_ch * 4, use_bn)
        self.u1 = Up(base_ch * 4, base_ch * 2, use_bn)
        self.u0 = Up(base_ch * 2, base_ch, use_bn)
        self.outc = nn.Conv2d(base_ch, out_channels, kernel_size=1)

    def forward(self, x):
        c1 = self.inc(x)        # full resolution
        c2 = self.d1(c1)        # 1/2
        c3 = self.d2(c2)        # 1/4
        c4 = self.d3(c3)        # 1/8
        b = self.bottleneck(self.bottleneck_pool(c4))  # 1/16
        x = self.u3(b, c4)
        x = self.u2(x, c3)
        x = self.u1(x, c2)
        x = self.u0(x, c1)
        return self.outc(x)


print("✓ U-Net2D model defined")

## 7. Forecast Metrics & LPIPS

In [None]:
VIP_THRESHOLDS = [16, 74, 133, 160, 181, 219]


def binarize(x, thr):
    """Return an int32 mask: 1 where ``x >= thr / 255``, else 0."""
    cutoff = thr / 255.0
    return torch.ge(x, cutoff).to(dtype=torch.int32)


def scores(pred, truth, thresholds=VIP_THRESHOLDS):
    """Contingency-table skill scores of ``pred`` against ``truth``.

    For every threshold ``t`` the result maps ``t`` to a dict with keys
    ``POD``, ``SUCR``, ``CSI`` and ``BIAS``. These are computed from
    hit / miss / false-alarm pixel counts, with a 1e-9 term in each
    denominator to avoid division by zero.
    """
    eps = 1e-9
    table = {}
    for thr in thresholds:
        fcst = binarize(pred, thr).bool()
        obs = binarize(truth, thr).bool()
        hits = int((fcst & obs).sum())
        misses = int((~fcst & obs).sum())
        false_alarms = int((fcst & ~obs).sum())
        table[thr] = {
            "POD": hits / (hits + misses + eps),
            "SUCR": hits / (hits + false_alarms + eps),
            "CSI": hits / (hits + misses + false_alarms + eps),
            "BIAS": (hits + false_alarms) / (hits + misses + eps),
        }
    return table

# LPIPS perceptual-distance network (AlexNet backbone), moved to the GPU when available.
import lpips
lpips_fn = lpips.LPIPS(net='alex')
if torch.cuda.is_available():
    lpips_fn = lpips_fn.cuda()

def compute_lpips(pred, target):
    """Return the batch-mean LPIPS distance between two (N, C, H, W) batches in [0, 1].

    - Single-channel inputs are tiled to 3 channels for the RGB backbone.
    - ``normalize=True`` makes LPIPS rescale [0, 1] inputs to the [-1, 1]
      range it expects by default. Without it, [0, 1] VIL data is read as
      only the upper half of that range.
    - Both tensors are first clamped to [0, 1], because the regression
      model's output is unbounded.

    Returns:
        Python float.
    """
    pred = pred.clamp(0.0, 1.0)
    target = target.clamp(0.0, 1.0)
    if pred.shape[1] == 1:
        pred = pred.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)
    return lpips_fn(pred, target, normalize=True).mean().item()


print("✓ Metrics defined")

## 8. Training Function (WITH DETAILED LOGGING)

In [None]:
from tqdm.auto import tqdm
import json
import time

def _train_one_epoch(model, loader, optimizer, criterion, scaler, device, desc):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss.

    Raises:
        FloatingPointError: if a batch yields a non-finite loss, so a diverging
            run stops immediately instead of logging NaN metrics for every epoch.
    """
    model.train()
    loss_sum, n_samples = 0.0, 0
    for step, (x, y) in enumerate(tqdm(loader, desc=desc, leave=False)):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        if scaler is not None:
            with torch.amp.autocast('cuda'):
                pred = model(x)
                loss = criterion(pred, y)
            scaler.scale(loss).backward()
            # Unscale first so the clip threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        batch_loss = loss.item()
        if not np.isfinite(batch_loss):
            raise FloatingPointError(f"Non-finite training loss {batch_loss} at batch {step}")
        loss_sum += batch_loss * x.size(0)
        n_samples += x.size(0)
    return loss_sum / n_samples


def _evaluate(model, loader, criterion, device, threshold=74):
    """Validate ``model``; return a dict of per-sample metrics and diagnostics.

    CSI is computed from hit / miss / false-alarm counts pooled over the whole
    validation set. Averaging per-batch CSI instead weights every batch equally,
    and a batch with no pixels above the threshold contributes a spurious 0.
    """
    model.eval()
    cutoff = threshold / 255.0
    mse_sum = lpips_sum = 0.0
    n_samples = n_pixels = 0
    hits = misses = false_alarms = 0
    pred_min, pred_max = float('inf'), float('-inf')

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            batch = x.size(0)
            # criterion / compute_lpips return batch means; re-weight by batch size.
            mse_sum += criterion(pred, y).item() * batch
            lpips_sum += compute_lpips(pred, y) * batch

            fcst = pred >= cutoff
            obs = y >= cutoff
            hits += int((fcst & obs).sum())
            misses += int((~fcst & obs).sum())
            false_alarms += int((fcst & ~obs).sum())

            pred_min = min(pred_min, pred.min().item())
            pred_max = max(pred_max, pred.max().item())
            n_samples += batch
            n_pixels += y.numel()

    contingency = hits + misses + false_alarms
    return {
        'mse': mse_sum / n_samples,
        'lpips': lpips_sum / n_samples,
        # NaN rather than 0 when neither truth nor prediction reaches the
        # threshold: CSI is undefined in that case.
        'csi': hits / contingency if contingency else float('nan'),
        'pred_min': pred_min,
        'pred_max': pred_max,
        'obs_frac': (hits + misses) / n_pixels,
        'fcst_frac': (hits + false_alarms) / n_pixels,
    }


def train_baseline(epochs=10, batch_size=4, lr=1e-4):
    """Train PURE MSE baseline (lambda=0) to verify pipeline works.

    Returns:
        dict of per-epoch lists with keys 'train_loss', 'val_mse', 'val_lpips'
        and 'val_csi_74'.

    Side effects:
        Saves the model state_dict and history to
        /content/outputs/checkpoints/baseline_lambda0.pt.

    Raises:
        RuntimeError: if the train or validation index is empty.
    """
    from torch.utils.data import DataLoader

    print(f"\n{'='*70}")
    print(f"BASELINE TEST (LAMBDA = 0.0)")
    print(f"{'='*70}\n")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_index = build_tiny_index(CATALOG_PATH, TRAIN_IDS, SEVIR_ROOT, "vil")
    val_index = build_tiny_index(CATALOG_PATH, VAL_IDS, SEVIR_ROOT, "vil")
    if not train_index or not val_index:
        raise RuntimeError(
            f"Empty index (train={len(train_index)}, val={len(val_index)}); "
            "check CATALOG_PATH, SEVIR_ROOT and the event ID files."
        )

    # NOTE: SevirNowcastDataset draws a random temporal window on each access,
    # so validation metrics include sampling noise from epoch to epoch.
    train_dataset = SevirNowcastDataset(train_index, 12, 1)
    val_dataset = SevirNowcastDataset(val_index, 12, 1)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    model = UNet2D(12, 1, 32, True).to(device)
    mse_criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None

    history = {'train_loss': [], 'val_mse': [], 'val_lpips': [], 'val_csi_74': []}

    for epoch in range(epochs):
        train_loss = _train_one_epoch(
            model, train_loader, optimizer, mse_criterion, scaler, device,
            desc=f"Epoch {epoch+1}/{epochs}",
        )
        val = _evaluate(model, val_loader, mse_criterion, device, threshold=74)

        history['train_loss'].append(train_loss)
        history['val_mse'].append(val['mse'])
        history['val_lpips'].append(val['lpips'])
        history['val_csi_74'].append(val['csi'])

        print(f"Epoch {epoch+1}/{epochs}: Train={train_loss:.4f}, Val MSE={val['mse']:.4f}, LPIPS={val['lpips']:.4f}, CSI@74={val['csi']:.3f}")
        # Diagnostics for a CSI of 0: e.g. predictions that never reach VIP 74.
        print(f"    pred range [{val['pred_min']:.3f}, {val['pred_max']:.3f}], "
              f"pixels >= VIP 74: truth {val['obs_frac']:.4%}, pred {val['fcst_frac']:.4%}")

    os.makedirs('/content/outputs/checkpoints', exist_ok=True)
    torch.save({'model': model.state_dict(), 'history': history}, '/content/outputs/checkpoints/baseline_lambda0.pt')

    return history


print("✓ Baseline training function defined")

## 9. RUN BASELINE TEST FIRST! 🔧

In [None]:
baseline_history = train_baseline(epochs=10, batch_size=4)

# Last-epoch value of each tracked validation metric.
final_csi, final_mse, final_lpips = (
    baseline_history[key][-1] for key in ('val_csi_74', 'val_mse', 'val_lpips')
)

banner = '=' * 70
print(f"\n{banner}")
print("BASELINE RESULTS (Lambda=0)")
print(banner)
print(f"Final Val MSE:    {final_mse:.4f} (expected: ~0.009-0.012)")
print(f"Final Val CSI@74: {final_csi:.3f} (expected: ~0.65-0.68)")
print(f"Final Val LPIPS:  {final_lpips:.3f} (expected: ~0.40)")
print(f"{banner}\n")

# Pass/fail gate for moving on to the perceptual-loss sweep.
if not final_csi >= 0.60:
    print("❌ BASELINE FAILED! CSI too low.")
    print("   → Bug in data loading, model, or training loop")
    print("   → DO NOT proceed to perceptual loss yet")
else:
    print("✅ BASELINE WORKS! Pipeline is correct.")
    print("   → Safe to proceed with perceptual loss sweep")

## 10. Save Baseline Results to Drive

In [None]:
!mkdir -p /content/drive/MyDrive/stormfusion_results/stage4_debug
!cp -r /content/outputs/checkpoints/* /content/drive/MyDrive/stormfusion_results/stage4_debug/

# Save summary
summary = f"""Stage 4 DEBUG: Baseline Test
============================
Lambda: 0.0 (pure MSE)
Final Val MSE: {final_mse:.4f}
Final Val CSI@74: {final_csi:.3f}
Final Val LPIPS: {final_lpips:.3f}

Expected: CSI ~0.65-0.68
Status: {'PASS ✅' if final_csi >= 0.60 else 'FAIL ❌'}
"""

with open('/content/drive/MyDrive/stormfusion_results/stage4_debug/baseline_summary.txt', 'w') as f:
    f.write(summary)

print("✅ Baseline results saved to Drive!")
print(f"   Location: /content/drive/MyDrive/stormfusion_results/stage4_debug/")

## ✅ NEXT STEPS

**If baseline PASSED (CSI ≥ 0.60):**
- Pipeline is working correctly
- Can now safely add perceptual loss
- The previous failure was then most likely caused by the perceptual loss term being too strong
- Try MUCH LOWER lambda values: {0.00001, 0.00005, 0.0001}

**If baseline FAILED (CSI < 0.60):**
- Bug in data loading, model, or training
- Compare to Stage 2 baseline script
- Check data shapes, loss values, gradient flow
- DO NOT add perceptual loss until baseline works