# 🚀 Stage 4: Perceptual Loss with MEDIUM Dataset

**CRITICAL INSIGHT (hypothesis under test):** Perceptual loss needs MORE data than MSE!

**Why?**
- MSE (pixel-wise): Simple, direct optimization → 8 events are sufficient
- Perceptual (VGG features): VGG-16 is pretrained on ImageNet, so radar imagery is out of domain → training needs diverse patterns
- With 8 events: Model can't balance pixel accuracy against perceptual quality
- With 24 events: More radar pattern diversity → should give a better VGG feature mapping

**This notebook:**
- Uses **MEDIUM dataset: 24 train / 6 val events** (3× more data)
- Tests lambda=0 baseline FIRST (should work like Stage 2)
- Then sweeps perceptual lambda: {0.0001, 0.0005, 0.001}
- Expects better balance between CSI and LPIPS

**Expected results with more data:**
- Lambda=0: CSI@74 ~0.65-0.68 (baseline)
- Lambda=0.0005: CSI@74 ~0.62-0.66, LPIPS ~0.32-0.35 ✅
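
The lambda values in the sweep weight a rescaled VGG term against MSE. A minimal sketch of that objective, using plain floats in place of tensors: `combined_loss` is a hypothetical helper name, the default divisor mirrors the `perceptual_scale=6000.0` default of the training function defined later in this notebook, and the input numbers are illustrative only.

```python
def combined_loss(mse, perc, lambda_perc, perceptual_scale=6000.0):
    """Total loss = MSE + lambda * (perceptual / scale); plain MSE when lambda is 0."""
    if lambda_perc > 0:
        return mse + lambda_perc * (perc / perceptual_scale)
    return mse

# Illustrative numbers only (not measured values)
for lam in (0.0, 0.0001, 0.0005, 0.001):
    print(lam, combined_loss(mse=0.01, perc=12.0, lambda_perc=lam))
```

Because the perceptual term is divided by the scale before lambda is applied, its share of the total stays small unless the raw VGG loss is large.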

---

## 1. Setup: Mount Drive & Check GPU

In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')

DRIVE_DATA_ROOT = "/content/drive/MyDrive/SEVIR_Data"
os.makedirs(DRIVE_DATA_ROOT, exist_ok=True)

print(f"✓ Google Drive mounted")
print(f"✓ Data directory: {DRIVE_DATA_ROOT}")

In [None]:
!nvidia-smi

import torch
print(f"\n{'='*70}")
print("GPU CHECK")
print(f"{'='*70}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
else:
    print("⚠️  WARNING: No GPU found. In Colab: Runtime > Change runtime type > GPU")
print(f"{'='*70}")

## 2. Install Dependencies

In [None]:
!pip install -q h5py lpips tqdm matplotlib scikit-image
print("✓ Dependencies installed")

## 3. Data Setup - MEDIUM Dataset (24 train / 6 val)

In [None]:
from pathlib import Path

DATA_ROOT = "/content/drive/MyDrive/SEVIR_Data"
SEVIR_ROOT = f"{DATA_ROOT}/data/sevir"
CATALOG_PATH = f"{DATA_ROOT}/data/SEVIR_CATALOG.csv"

# Check data exists
catalog_exists = Path(CATALOG_PATH).exists()
vil_exists = Path(f"{SEVIR_ROOT}/vil/2019/SEVIR_VIL_STORMEVENTS_2019_0701_1231.h5").exists()

print(f"Data Check:")
print(f"  Catalog: {'✓' if catalog_exists else '✗'} {CATALOG_PATH}")
print(f"  VIL data: {'✓' if vil_exists else '✗'} {SEVIR_ROOT}/vil/2019/")

if not (catalog_exists and vil_exists):
    print("\n⚠ Data missing!")
else:
    print("\n✓ Data ready!")

## 4. Create MEDIUM Event ID Files

In [None]:
os.makedirs(f"{DATA_ROOT}/data/samples", exist_ok=True)

# MEDIUM train IDs (24 events - 3× more than tiny)
train_ids = [
    "S834603", "S831344", "S821019", "S830464",
    "S825041", "S818174", "S830199", "S842121",
    "S807772", "S814242", "S814355", "S815466",
    "S818247", "S837670", "S835047", "S818395",
    "S833894", "S825114", "S792176", "S795292",
    "S835883", "S816030", "S833564", "S805592"
]

# MEDIUM val IDs (6 events)
val_ids = [
    "S822148", "S808982", "S810940",
    "S818371", "S819602", "S815906"
]

TRAIN_IDS = f"{DATA_ROOT}/data/samples/medium_train_ids.txt"
VAL_IDS = f"{DATA_ROOT}/data/samples/medium_val_ids.txt"

with open(TRAIN_IDS, 'w') as f:
    f.write('\n'.join(train_ids))

with open(VAL_IDS, 'w') as f:
    f.write('\n'.join(val_ids))

print(f"✓ Created MEDIUM dataset event ID files")
print(f"  Train: {len(train_ids)} events (vs. 8 in tiny)")
print(f"  Val: {len(val_ids)} events (vs. 4 in tiny)")
print(f"  Total: {len(train_ids) + len(val_ids)} events")
print(f"\n💡 3× more training data for perceptual loss!")
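
Before training, it may be worth confirming that the two ID lists are clean, since an event that appears in both splits would inflate the validation scores. A small sketch, where `check_split` is a hypothetical helper and the IDs are toy values:

```python
def check_split(train_ids, val_ids):
    """Return (IDs duplicated within either list, IDs shared by both lists)."""
    dupes = sorted({e for ids in (train_ids, val_ids) for e in ids if ids.count(e) > 1})
    overlap = sorted(set(train_ids) & set(val_ids))
    return dupes, overlap

# Toy IDs for illustration; pass the real train_ids / val_ids in the notebook
dupes, overlap = check_split(["S1", "S2", "S3"], ["S3", "S4"])
print("duplicates:", dupes, "overlap:", overlap)
```

Both returned lists should be empty for a valid split.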

## 5. Dataset Implementation

In [None]:
import h5py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

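class SevirNowcastDataset(Dataset):
    """SEVIR VIL nowcasting dataset.

    Each item is one (x, y) window from a single event, scaled to [0, 1]:
    x is (input_steps, H, W) and y is (output_steps, H, W).
    `target_size` is stored but not applied; frames keep their native size.
    """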
    def __init__(self, index, input_steps=12, output_steps=1, target_size=(384, 384)):
        self.index = index
        self.in_steps = input_steps
        self.out_steps = output_steps
        self.target_size = target_size

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        file_path, file_index, event_id = self.index[idx]

        with h5py.File(file_path, "r") as h5:
            data = h5["vil"][file_index].astype(np.float32) / 255.0

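        # Frames are on the last axis: data is (H, W, T)
        total_frames = data.shape[2]
        max_start = total_frames - (self.in_steps + self.out_steps)
        # Random start frame. The validation set uses this class too, so its
        # metrics vary slightly from epoch to epoch.
        t_start = np.random.randint(0, max(1, max_start + 1))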

        x = data[:, :, t_start:t_start + self.in_steps]
        y = data[:, :, t_start + self.in_steps:t_start + self.in_steps + self.out_steps]

        x = np.transpose(x, (2, 0, 1))
        y = np.transpose(y, (2, 0, 1))

        return torch.from_numpy(x).float(), torch.from_numpy(y).float()


def build_index(catalog_path, ids_txt, sevir_root, modality="vil"):
    """Build index from event ID file."""
    with open(ids_txt, 'r') as f:
        event_ids = [line.strip() for line in f if line.strip()]

    catalog = pd.read_csv(catalog_path, low_memory=False)
    modality_cat = catalog[catalog["img_type"] == modality].copy()

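    index = []
    missing = []  # event IDs with no catalog row or no file on disk
    for event_id in event_ids:
        event_rows = modality_cat[modality_cat["id"] == event_id]
        if event_rows.empty:
            missing.append(event_id)
            continue

        row = event_rows.iloc[0]
        file_path = os.path.join(sevir_root, row["file_name"])
        if os.path.exists(file_path):
            index.append((file_path, int(row["file_index"]), event_id))
        else:
            missing.append(event_id)

    print(f"✓ Built index: {len(index)}/{len(event_ids)} events")
    if missing:
        print(f"  ⚠ Skipped (not in catalog or file missing): {missing}")
    return index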

print("✓ Dataset classes defined")
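
The window arithmetic in `__getitem__` can be checked without any data. The sketch below reproduces it with plain integers; `window_bounds` and `split_window` are hypothetical helper names, and the frame count of 49 is an assumption about the VIL events rather than something this notebook verifies.

```python
def window_bounds(total_frames, in_steps=12, out_steps=1):
    """Return the inclusive range of valid start frames for one (input, target) window."""
    max_start = total_frames - (in_steps + out_steps)
    return 0, max(0, max_start)

def split_window(t_start, in_steps=12, out_steps=1):
    """Return (input frame indices, target frame indices) for a given start frame."""
    x_idx = list(range(t_start, t_start + in_steps))
    y_idx = list(range(t_start + in_steps, t_start + in_steps + out_steps))
    return x_idx, y_idx

lo, hi = window_bounds(total_frames=49)  # 49 is an assumed frame count
x_idx, y_idx = split_window(hi)
print(f"start range: [{lo}, {hi}]")
print(f"latest window: inputs {x_idx[0]}..{x_idx[-1]}, target {y_idx}")
```

Under that assumption the latest window uses frames 36 to 47 as input and frame 48 as the target, so every frame of the event can appear in some window.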

## 6. Model, Loss, and Metrics (All inline)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import lpips

# U-Net2D
def conv_block(in_ch, out_ch, use_bn=True):
    layers = [nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True)]
    if use_bn: layers.append(nn.BatchNorm2d(out_ch))
    layers += [nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True)]
    if use_bn: layers.append(nn.BatchNorm2d(out_ch))
    return nn.Sequential(*layers)

class Down(nn.Module):
    def __init__(self, in_ch, out_ch, use_bn=True):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.block = conv_block(in_ch, out_ch, use_bn)
    def forward(self, x): return self.block(self.pool(x))

class Up(nn.Module):
    def __init__(self, in_ch, out_ch, use_bn=True):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.block = conv_block(in_ch, out_ch, use_bn)
    def forward(self, x, skip):
        x = self.up(x)
        dy, dx = skip.size(-2) - x.size(-2), skip.size(-1) - x.size(-1)
        if dy or dx:
            x = F.pad(x, [dx//2, dx - dx//2, dy//2, dy - dy//2])
        x = torch.cat([skip, x], dim=1)
        return self.block(x)

class UNet2D(nn.Module):
    def __init__(self, in_channels=12, out_channels=1, base_ch=32, use_bn=True):
        super().__init__()
        self.inc = conv_block(in_channels, base_ch, use_bn)
        self.d1 = Down(base_ch, base_ch*2, use_bn)
        self.d2 = Down(base_ch*2, base_ch*4, use_bn)
        self.d3 = Down(base_ch*4, base_ch*8, use_bn)
        # Down (pool + convs): without the pool, u3's upsampled map is center-cropped to fit c4
        self.bottleneck = Down(base_ch*8, base_ch*16, use_bn)
        self.u3 = Up(base_ch*16, base_ch*8, use_bn)
        self.u2 = Up(base_ch*8, base_ch*4, use_bn)
        self.u1 = Up(base_ch*4, base_ch*2, use_bn)
        self.u0 = Up(base_ch*2, base_ch, use_bn)
        self.outc = nn.Conv2d(base_ch, out_channels, kernel_size=1)

    def forward(self, x):
        c1 = self.inc(x)
        c2 = self.d1(c1)
        c3 = self.d2(c2)
        c4 = self.d3(c3)
        b  = self.bottleneck(c4)
        x = self.u3(b, c4)
        x = self.u2(x, c3)
        x = self.u1(x, c2)
        x = self.u0(x, c1)
        return self.outc(x)

# VGG Perceptual Loss
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = torchvision.models.vgg16(weights='IMAGENET1K_V1').features
        self.slice1 = nn.Sequential(*list(vgg[:4]))
        self.slice2 = nn.Sequential(*list(vgg[4:9]))
        self.slice3 = nn.Sequential(*list(vgg[9:16]))
        self.slice4 = nn.Sequential(*list(vgg[16:23]))
        for param in self.parameters():
            param.requires_grad = False
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def normalize(self, x):
        return (x - self.mean) / self.std

    def forward(self, pred, target):
        if pred.shape[1] == 1:
            pred = pred.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        pred = self.normalize(pred)
        target = self.normalize(target)
        
        loss = 0
        pred_1 = self.slice1(pred)
        target_1 = self.slice1(target)
        loss += F.mse_loss(pred_1, target_1)
        
        pred_2 = self.slice2(pred_1)
        target_2 = self.slice2(target_1)
        loss += F.mse_loss(pred_2, target_2)
        
        pred_3 = self.slice3(pred_2)
        target_3 = self.slice3(target_2)
        loss += F.mse_loss(pred_3, target_3)
        
        pred_4 = self.slice4(pred_3)
        target_4 = self.slice4(target_3)
        loss += F.mse_loss(pred_4, target_4)
        
        return loss / 4

# Metrics
VIP_THRESHOLDS = [16, 74, 133, 160, 181, 219]

def binarize(x, thr):
    return (x >= thr/255.0).to(torch.int32)

def scores(pred, truth, thresholds=VIP_THRESHOLDS):
    out = {}
    for t in thresholds:
        p = binarize(pred, t)
        y = binarize(truth, t)
        hits = ((p==1)&(y==1)).sum().item()
        miss = ((p==0)&(y==1)).sum().item()
        fa   = ((p==1)&(y==0)).sum().item()
        pod = hits / (hits + miss + 1e-9)
        sucr = hits / (hits + fa + 1e-9)
        csi = hits / (hits + miss + fa + 1e-9)
        bias = (hits + fa) / (hits + miss + 1e-9)
        out[t] = dict(POD=pod, SUCR=sucr, CSI=csi, BIAS=bias)
    return out

lpips_fn = lpips.LPIPS(net='alex').cuda() if torch.cuda.is_available() else lpips.LPIPS(net='alex')

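def compute_lpips(pred, target):
    # Clamp to the valid [0, 1] range; the raw network output is unbounded
    pred = pred.clamp(0, 1)
    if pred.shape[1] == 1:
        pred = pred.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)
    # normalize=True rescales [0, 1] inputs to the [-1, 1] range LPIPS expects
    return lpips_fn(pred, target, normalize=True).mean().item()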

print("✓ Model, loss, and metrics defined")
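
The four scores returned by `scores` all derive from the same three counts. A worked example with made-up counts, using the same formulas and epsilon as the function above (`contingency_scores` is a hypothetical name for this standalone version):

```python
def contingency_scores(hits, miss, fa, eps=1e-9):
    """POD, success ratio, CSI and frequency bias from contingency counts."""
    pod = hits / (hits + miss + eps)          # fraction of observed events that were predicted
    sucr = hits / (hits + fa + eps)           # fraction of predicted events that were observed
    csi = hits / (hits + miss + fa + eps)     # hits over everything predicted or observed
    bias = (hits + fa) / (hits + miss + eps)  # predicted area over observed area
    return dict(POD=pod, SUCR=sucr, CSI=csi, BIAS=bias)

# Made-up counts for illustration
s = contingency_scores(hits=30, miss=10, fa=20)
print({k: round(v, 3) for k, v in s.items()})
```

With these counts POD is 0.75, SUCR is 0.6, CSI is 0.5 and BIAS is 1.25: the forecast over-predicts area (BIAS > 1), and CSI is lower than both POD and SUCR because it penalizes misses and false alarms together.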

## 7. Training Function

In [None]:
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import json
import time

def train_model(lambda_perc, epochs=10, batch_size=4, lr=1e-4, perceptual_scale=6000.0):
    """Train one model with given lambda."""
    
    print(f"\n{'='*70}")
    print(f"TRAINING WITH LAMBDA = {lambda_perc}")
    print(f"{'='*70}\n")
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load datasets
    train_index = build_index(CATALOG_PATH, TRAIN_IDS, SEVIR_ROOT, "vil")
    val_index = build_index(CATALOG_PATH, VAL_IDS, SEVIR_ROOT, "vil")
    
    train_dataset = SevirNowcastDataset(train_index, 12, 1)
    val_dataset = SevirNowcastDataset(val_index, 12, 1)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    
    # Create model
    model = UNet2D(12, 1, 32, True).to(device)
    mse_criterion = nn.MSELoss()
    perceptual_criterion = VGGPerceptualLoss().to(device) if lambda_perc > 0 else None
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None
    
    history = {'train_mse': [], 'train_perc': [], 'val_mse': [], 'val_lpips': [], 'val_csi_74': []}
    
    for epoch in range(epochs):
        # Train
        model.train()
        train_mse = 0
        train_perc = 0
        
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            
            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    pred = model(x)
                    mse_loss = mse_criterion(pred, y)
                    
                    if lambda_perc > 0:
                        perc_loss = perceptual_criterion(pred, y)
                        perc_scaled = perc_loss / perceptual_scale
                        total = mse_loss + lambda_perc * perc_scaled
                    else:
                        perc_loss = torch.tensor(0.0)
                        total = mse_loss
                        
                scaler.scale(total).backward()
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                pred = model(x)
                mse_loss = mse_criterion(pred, y)
                
                if lambda_perc > 0:
                    perc_loss = perceptual_criterion(pred, y)
                    perc_scaled = perc_loss / perceptual_scale
                    total = mse_loss + lambda_perc * perc_scaled
                else:
                    perc_loss = torch.tensor(0.0)
                    total = mse_loss
                    
                total.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            
            train_mse += mse_loss.item()
            train_perc += perc_loss.item()
        
        train_mse /= len(train_loader)
        train_perc /= len(train_loader)
        
        # Validate
        model.eval()
        val_mse = 0
        val_lpips_total = 0
        all_csi = []
        
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                pred = model(x)
                val_mse += mse_criterion(pred, y).item()
                val_lpips_total += compute_lpips(pred, y)
                
                batch_scores = scores(pred, y)
                all_csi.append(batch_scores[74]['CSI'])
        
        val_mse /= len(val_loader)
        val_lpips_avg = val_lpips_total / len(val_loader)
        csi_74 = np.mean(all_csi)
        
        history['train_mse'].append(train_mse)
        history['train_perc'].append(train_perc)
        history['val_mse'].append(val_mse)
        history['val_lpips'].append(val_lpips_avg)
        history['val_csi_74'].append(csi_74)
        
        print(f"Epoch {epoch+1}/{epochs}: MSE={val_mse:.4f}, LPIPS={val_lpips_avg:.4f}, CSI@74={csi_74:.3f}")
    
    # Save
    os.makedirs('/content/outputs/checkpoints', exist_ok=True)
    torch.save({'model': model.state_dict(), 'history': history, 'lambda': lambda_perc}, 
               f'/content/outputs/checkpoints/lambda{lambda_perc}_medium.pt')
    
    return history

print("✓ Training function defined")
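
One detail of the validation loop: CSI@74 is computed per batch and then averaged, which is not the same as computing CSI once over all validation pixels. A batch with little or no storm area above the threshold scores near zero and pulls the average down. The sketch below compares the two with hypothetical counts; it is an illustration of the difference, not a change to the training function.

```python
def csi(hits, miss, fa, eps=1e-9):
    return hits / (hits + miss + fa + eps)

# Hypothetical per-batch (hits, misses, false alarms) counts
batches = [(90, 5, 5), (0, 5, 5)]

# What the validation loop does: average the per-batch CSI values
mean_of_batches = sum(csi(*b) for b in batches) / len(batches)

# Alternative: sum the counts over all batches, then compute CSI once
pooled = csi(*(sum(col) for col in zip(*batches)))

print(round(mean_of_batches, 3), round(pooled, 3))
```

Here the batch average is 0.45 while the pooled value is about 0.818. With only 6 validation events and a batch size of 4, the second batch holds 2 events, so the per-batch average can be noisy.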

## 8. RUN FULL SWEEP! 🚀

**Strategy:**
1. Test lambda=0 FIRST (baseline)
2. If baseline works, sweep perceptual lambdas
3. Compare results

In [None]:
lambdas = [0.0, 0.0001, 0.0005, 0.001]
results = {}

start_time = time.time()

for lambda_val in lambdas:
    history = train_model(lambda_val, epochs=10, batch_size=4)
    # Best value of each metric over all epochs; the three may come from different epochs
    results[lambda_val] = {
        'csi': max(history['val_csi_74']),
        'lpips': min(history['val_lpips']),
        'mse': min(history['val_mse'])
    }
    
    # Check baseline first
    if lambda_val == 0.0:
        if results[0.0]['csi'] < 0.55:
            print(f"\n❌ BASELINE FAILED! CSI={results[0.0]['csi']:.3f} < 0.55")
            print("   Stopping sweep - fix baseline first!")
            break
        else:
            print(f"\n✅ BASELINE PASSED! CSI={results[0.0]['csi']:.3f} ≥ 0.55")
            print("   Proceeding with perceptual loss sweep...\n")

total_time = time.time() - start_time
print(f"\n✅ SWEEP COMPLETE in {total_time/60:.1f} minutes")
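
The sweep records the best CSI, LPIPS and MSE independently, so the reported numbers for one lambda can come from different epochs. An alternative is to pick a single epoch and report all metrics from it. A sketch assuming the same `history` keys as the training function; `best_epoch_summary` is a hypothetical helper and the toy history is not real output:

```python
def best_epoch_summary(history):
    """Pick the epoch with the highest CSI@74 and report all metrics from that epoch."""
    csi = history['val_csi_74']
    best = max(range(len(csi)), key=csi.__getitem__)
    return {
        'epoch': best + 1,
        'csi': csi[best],
        'lpips': history['val_lpips'][best],
        'mse': history['val_mse'][best],
    }

# Toy history for illustration (not real results)
toy = {
    'val_csi_74': [0.50, 0.64, 0.61],
    'val_lpips': [0.40, 0.37, 0.33],
    'val_mse': [0.012, 0.010, 0.011],
}
print(best_epoch_summary(toy))
```

In the toy history, the independent view would pair CSI 0.64 (epoch 2) with LPIPS 0.33 (epoch 3), a combination no single checkpoint achieves; the per-epoch view reports 0.64 and 0.37.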

## 9. Compare Results

In [None]:
print("\n" + "="*70)
print("RESULTS: MEDIUM DATASET (24 train / 6 val)")
print("="*70 + "\n")

baseline_csi = results.get(0.0, {}).get('csi', 0.0)
baseline_lpips = results.get(0.0, {}).get('lpips', 0.40)

print(f"Baseline (λ=0): CSI@74={baseline_csi:.3f}, LPIPS={baseline_lpips:.3f}\n")

best_lambda = None
best_score = -1

for lambda_val in [l for l in lambdas if l > 0]:
    if lambda_val not in results:
        continue
        
    res = results[lambda_val]
    csi = res['csi']
    lpips_val = res['lpips']
    
    csi_pass = csi >= 0.65
    lpips_pass = lpips_val < 0.35
    success = csi_pass and lpips_pass
    
    print(f"Lambda = {lambda_val}:")
    print(f"  CSI@74:  {csi:.3f} {'✅' if csi_pass else '❌'} (target: ≥0.65)")
    print(f"  LPIPS:   {lpips_val:.3f} {'✅' if lpips_pass else '❌'} (target: <0.35)")
    print(f"  SUCCESS: {'YES ✅✅✅' if success else 'NO ❌'}\n")
    
    if success and csi > best_score:
        best_score = csi
        best_lambda = lambda_val

if best_lambda is not None:
    print(f"\n🎉 WINNER: Lambda = {best_lambda}")
    print(f"   CSI@74: {results[best_lambda]['csi']:.3f}")
    print(f"   LPIPS:  {results[best_lambda]['lpips']:.3f}")
    print(f"\n✅ Stage 4 SUCCESS! Use lambda={best_lambda} for Stage 5.")
else:
    print(f"\n⚠️  No lambda met both criteria.")
    print(f"   Try: Lower lambda (0.00005) OR more data OR longer training")

## 10. Save Results to Drive

In [None]:
!mkdir -p /content/drive/MyDrive/stormfusion_results/stage4_medium
!cp -r /content/outputs/checkpoints/* /content/drive/MyDrive/stormfusion_results/stage4_medium/

# Format the winner's metrics up front: a format spec cannot contain a conditional
best_csi = f"{results[best_lambda]['csi']:.3f}" if best_lambda is not None else 'N/A'
best_lpips = f"{results[best_lambda]['lpips']:.3f}" if best_lambda is not None else 'N/A'

summary = f"""Stage 4 Results - MEDIUM Dataset
=================================
Dataset: 24 train / 6 val events
Best Lambda: {best_lambda if best_lambda is not None else 'None'}
CSI@74: {best_csi}
LPIPS: {best_lpips}
Total Time: {total_time/60:.1f} min

All Results:
"""

for lam in results:
    summary += f"\nLambda={lam}: CSI={results[lam]['csi']:.3f}, LPIPS={results[lam]['lpips']:.3f}"

with open('/content/drive/MyDrive/stormfusion_results/stage4_medium/summary.txt', 'w') as f:
    f.write(summary)

print("✅ Results saved to Drive!")
print(f"   Location: /content/drive/MyDrive/stormfusion_results/stage4_medium/")
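
A machine-readable copy of the sweep results can sit next to `summary.txt` for later stages to load. A minimal sketch, assuming a `results` dict shaped like the one built in the sweep; `save_results_json` is a hypothetical helper, and the example writes to a temporary directory rather than to Drive.

```python
import json
import os
import tempfile

def save_results_json(results, path):
    """Write sweep results to JSON; lambda keys are stored as strings."""
    serializable = {
        str(lam): {name: float(value) for name, value in metrics.items()}
        for lam, metrics in results.items()
    }
    with open(path, 'w') as f:
        json.dump(serializable, f, indent=2)

# Toy results for illustration (not real output)
toy_results = {0.0: {'csi': 0.6, 'lpips': 0.4, 'mse': 0.01}}
out_path = os.path.join(tempfile.mkdtemp(), 'results.json')
save_results_json(toy_results, out_path)

with open(out_path) as f:
    loaded = json.load(f)
print(loaded)
```

JSON object keys are always strings, so a reader has to convert them back with `float(key)` when loading.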

## ✅ DONE!

**Key insight to confirm from the results above:**
- Perceptual loss needs MORE data than MSE
- MEDIUM dataset (24 events) should provide better VGG feature diversity
- If a lambda met both targets, the balance between CSI and LPIPS has improved

**Next:** Download checkpoint and proceed to Stage 5!