# 🚀 Stage 4: Perceptual Loss Lambda Sweep - READY TO RUN

**Just click Runtime → Run All and go get coffee! ☕**

This notebook contains ALL code inline - no file uploads needed!

**What it does:**
1. Mounts Drive and checks GPU
2. Loads data from your existing SEVIR_Data folder
3. Trains 3 models with λ = {0.0001, 0.0005, 0.001}
4. Compares results and picks winner
5. Saves everything to Drive

**Expected time:** 25-35 minutes on L4/A100 GPU

**Success criteria:**
- ✅ CSI@74 ≥ 0.65 (maintains forecast skill; higher is better, MSE-only baseline ≈ 0.68)
- ✅ LPIPS < 0.35 (improves sharpness; lower is better, MSE-only baseline ≈ 0.40)

---

## 1. Setup: Mount Drive & Check GPU

In [None]:
import os
import shutil

from google.colab import drive

# Attach Google Drive to the Colab filesystem.
drive.mount('/content/drive')

# Folder on Drive that holds the SEVIR data; created on first run if absent.
DRIVE_DATA_ROOT = "/content/drive/MyDrive/SEVIR_Data"
os.makedirs(DRIVE_DATA_ROOT, exist_ok=True)

print(f"✓ Google Drive mounted")
print(f"✓ Data directory: {DRIVE_DATA_ROOT}")

# Report Drive capacity in decimal terabytes (bytes / 1e12).
drive_usage = shutil.disk_usage("/content/drive/MyDrive")
total, used, free = drive_usage
print(f"\nGoogle Drive Storage:")
print(f"  Total: {total / 1e12:.2f} TB")
print(f"  Free:  {free / 1e12:.2f} TB")

In [None]:
!nvidia-smi

import torch
print(f"\n{'='*70}")
print("GPU CHECK")
print(f"{'='*70}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
else:
    print("⚠️  WARNING: Select GPU runtime!")
    print("   Runtime → Change runtime type → GPU")
print(f"{'='*70}")

## 2. Install Dependencies

In [None]:
# Install the third-party packages used below, in pip's quiet mode.
# NOTE(review): versions are unpinned, so a re-run may pick up different
# releases; consider pinning (and `%pip`, which targets the kernel's env).
!pip install -q h5py lpips tqdm matplotlib scikit-image
# Printed unconditionally: pip's exit status is not checked here.
print("✓ Dependencies installed")

## 3. Data Setup

In [None]:
from pathlib import Path

# Locations of the SEVIR catalog and HDF5 tree on Drive.
DATA_ROOT = "/content/drive/MyDrive/SEVIR_Data"
SEVIR_ROOT = f"{DATA_ROOT}/data/sevir"
CATALOG_PATH = f"{DATA_ROOT}/data/SEVIR_CATALOG.csv"

# Presence check: the catalog CSV plus one specific 2019 VIL file.
# NOTE(review): only this single HDF5 file is probed; the events used later
# may live in other files — confirm against the catalog.
vil_probe = Path(SEVIR_ROOT) / "vil" / "2019" / "SEVIR_VIL_STORMEVENTS_2019_0701_1231.h5"
catalog_exists = Path(CATALOG_PATH).exists()
vil_exists = vil_probe.exists()

print(f"Data Check:")
print(f"  Catalog: {'✓' if catalog_exists else '✗'} {CATALOG_PATH}")
print(f"  VIL data: {'✓' if vil_exists else '✗'} {SEVIR_ROOT}/vil/2019/")

# Only a message is printed when data is missing; execution continues.
if catalog_exists and vil_exists:
    print("\n✓ Data ready!")
else:
    print("\n⚠ Data missing! Please ensure SEVIR data is in Drive.")

## 4. Create Event ID Files

In [None]:
# Folder that will hold the event-ID lists.
os.makedirs(f"{DATA_ROOT}/data/samples", exist_ok=True)

# Tiny train split (IDs carried over from Stage 1).
train_ids = [
    "S851839", "S856840", "S853914", "S858016",
    "S847132", "S828550", "S844358", "S833039"
]

# Tiny validation split (IDs carried over from Stage 1).
val_ids = [
    "S848711", "S851205", "S849773", "S849364"
]

TRAIN_IDS = f"{DATA_ROOT}/data/samples/tiny_train_ids.txt"
VAL_IDS = f"{DATA_ROOT}/data/samples/tiny_val_ids.txt"

# One ID per line, no trailing newline; existing files are overwritten.
for ids_path, id_list in ((TRAIN_IDS, train_ids), (VAL_IDS, val_ids)):
    with open(ids_path, 'w') as f:
        f.write('\n'.join(id_list))

print(f"✓ Created event ID files")
print(f"  Train: {len(train_ids)} events")
print(f"  Val: {len(val_ids)} events")

## 5. Dataset Implementation (All code inline!)

In [None]:
import h5py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class SevirNowcastDataset(Dataset):
    """SEVIR VIL nowcasting dataset.

    Each item is a randomly positioned temporal window cut from one event:
    ``input_steps`` consecutive frames as the input and the following
    ``output_steps`` frames as the target.

    Args:
        index: Sequence of ``(file_path, file_index, event_id)`` tuples, as
            produced by ``build_tiny_index``.
        input_steps: Number of input frames per sample.
        output_steps: Number of target frames per sample.
        target_size: Stored on the instance but not used by this class.
    """
    def __init__(self, index, input_steps=12, output_steps=1, target_size=(384, 384)):
        self.index = index
        self.in_steps = input_steps
        self.out_steps = output_steps
        self.target_size = target_size

    def __len__(self):
        """Return the number of indexed events."""
        return len(self.index)

    def __getitem__(self, idx):
        """Load one event and return ``(x, y)`` float32 tensors.

        Returns:
            ``x`` with shape ``(input_steps, H, W)`` and ``y`` with shape
            ``(output_steps, H, W)``, values scaled by 1/255.

        Raises:
            ValueError: If the event has fewer frames than
                ``input_steps + output_steps``.
        """
        file_path, file_index, event_id = self.index[idx]

        # Load data from HDF5 and scale raw values by 1/255.
        with h5py.File(file_path, "r") as h5:
            data = h5["vil"][file_index].astype(np.float32) / 255.0

        # The time axis is the last one: (H, W, T).
        window = self.in_steps + self.out_steps
        total_frames = data.shape[2]
        if total_frames < window:
            # Previously this silently produced truncated/empty tensors.
            raise ValueError(
                f"Event {event_id} has {total_frames} frames but "
                f"{window} are required (input_steps + output_steps)."
            )

        # Random temporal crop. torch's RNG is used instead of NumPy's global
        # RNG because DataLoader seeds torch differently in every worker,
        # whereas NumPy's state is duplicated across workers (and across
        # epochs when workers are recreated), which repeats the same crops.
        max_start = total_frames - window
        t_start = int(torch.randint(0, max_start + 1, (1,)).item())

        # Extract sequences: (H, W, T) → (T, H, W)
        split = t_start + self.in_steps
        x = data[:, :, t_start:split]
        y = data[:, :, split:split + self.out_steps]

        x = np.transpose(x, (2, 0, 1))
        y = np.transpose(y, (2, 0, 1))

        return torch.from_numpy(x).float(), torch.from_numpy(y).float()


def build_tiny_index(catalog_path, ids_txt, sevir_root, modality="vil"):
    """Resolve a list of event IDs to HDF5 locations.

    Args:
        catalog_path: CSV catalog with ``id``, ``img_type``, ``file_name`` and
            ``file_index`` columns.
        ids_txt: Text file with one event ID per line (blank lines ignored).
        sevir_root: Directory that catalog ``file_name`` values are relative to.
        modality: ``img_type`` value to keep.

    Returns:
        List of ``(file_path, file_index, event_id)`` tuples, in the order the
        IDs appear in ``ids_txt``. IDs with no catalog row for the modality,
        or whose file does not exist on disk, are skipped without a message.
        When an ID has several rows, the first one is used.
    """
    with open(ids_txt, 'r') as handle:
        wanted_ids = [token for token in (raw.strip() for raw in handle) if token]

    table = pd.read_csv(catalog_path, low_memory=False)
    table = table.loc[table["img_type"] == modality]

    entries = []
    for wanted in wanted_ids:
        matches = table.loc[table["id"] == wanted]
        if len(matches) == 0:
            continue

        first = matches.iloc[0]
        candidate = os.path.join(sevir_root, first["file_name"])
        if not os.path.exists(candidate):
            continue
        entries.append((candidate, int(first["file_index"]), wanted))

    print(f"✓ Built index: {len(entries)} events")
    return entries

print("✓ Dataset classes defined")

## 6. Model Architecture (U-Net2D)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

def conv_block(in_ch, out_ch, use_bn=True):
    """Build two 3x3 conv stages, each Conv2d → ReLU → optional BatchNorm2d.

    The first conv maps ``in_ch`` to ``out_ch``; the second keeps ``out_ch``.
    Padding of 1 leaves the spatial size unchanged.
    """
    stages = []
    for stage_in in (in_ch, out_ch):
        stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
        stages.append(nn.ReLU(inplace=True))
        if use_bn:
            stages.append(nn.BatchNorm2d(out_ch))
    return nn.Sequential(*stages)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pool followed by a ``conv_block``."""

    def __init__(self, in_ch, out_ch, use_bn=True):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.block = conv_block(in_ch, out_ch, use_bn)

    def forward(self, x):
        """Halve the spatial size, then apply the conv block."""
        pooled = self.pool(x)
        return self.block(pooled)

class Up(nn.Module):
    """Decoder stage: transposed-conv upsample, merge with skip, ``conv_block``.

    The conv block is built for ``in_ch`` input channels, i.e. the
    concatenation of the skip tensor and the ``out_ch``-channel upsampled
    tensor must total ``in_ch`` channels.
    """

    def __init__(self, in_ch, out_ch, use_bn=True):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.block = conv_block(in_ch, out_ch, use_bn)

    def forward(self, x, skip):
        """Upsample ``x`` 2x, fit it to ``skip``'s size, concatenate and convolve."""
        upsampled = self.up(x)
        gap_h = skip.size(-2) - upsampled.size(-2)
        gap_w = skip.size(-1) - upsampled.size(-1)
        if gap_h != 0 or gap_w != 0:
            # Positive gaps zero-pad; negative gaps crop (F.pad accepts negatives).
            left, top = gap_w // 2, gap_h // 2
            upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        # Skip features come first along the channel axis.
        merged = torch.cat([skip, upsampled], dim=1)
        return self.block(merged)

class UNet2D(nn.Module):
    """Four-level 2D U-Net mapping stacked input frames to output frames.

    Args:
        in_channels: Channels of the input tensor (input frames stacked on
            the channel axis).
        out_channels: Channels produced by the final 1x1 convolution.
        base_ch: Width of the first encoder stage; doubled at each level.
        use_bn: Whether conv blocks include BatchNorm2d.

    Resolution per stage (relative to the input): c1 = 1, c2 = 1/2,
    c3 = 1/4, c4 = 1/8, bottleneck = 1/16. Each decoder stage upsamples 2x
    and merges with the encoder feature map of the same resolution.

    The bottleneck is preceded by a max-pool. Without it the bottleneck ran
    at 1/8 resolution, so ``u3`` upsampled to 1/4 and ``Up`` then had to
    centre-crop that map back to 1/8 (negative padding) to match ``c4``,
    pairing skip features with a zoomed-in, misaligned decoder map. The pool
    has no parameters, so ``state_dict`` keys are unchanged; weights trained
    with the earlier graph still load but were fitted to a different network.
    """

    def __init__(self, in_channels=12, out_channels=1, base_ch=32, use_bn=True):
        super().__init__()
        self.inc = conv_block(in_channels, base_ch, use_bn)
        self.d1 = Down(base_ch, base_ch*2, use_bn)
        self.d2 = Down(base_ch*2, base_ch*4, use_bn)
        self.d3 = Down(base_ch*4, base_ch*8, use_bn)
        # Parameter-free downsample so the bottleneck sits one level below c4.
        self.bottleneck_pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(base_ch*8, base_ch*16, use_bn)
        self.u3 = Up(base_ch*16, base_ch*8, use_bn)
        self.u2 = Up(base_ch*8, base_ch*4, use_bn)
        self.u1 = Up(base_ch*4, base_ch*2, use_bn)
        self.u0 = Up(base_ch*2, base_ch, use_bn)
        self.outc = nn.Conv2d(base_ch, out_channels, kernel_size=1)

    def forward(self, x):
        """Run encoder, bottleneck and decoder; returns raw (unactivated) output."""
        c1 = self.inc(x)
        c2 = self.d1(c1)
        c3 = self.d2(c2)
        c4 = self.d3(c3)
        b = self.bottleneck(self.bottleneck_pool(c4))
        x = self.u3(b, c4)
        x = self.u2(x, c3)
        x = self.u1(x, c2)
        x = self.u0(x, c1)
        return self.outc(x)

print("✓ U-Net2D model defined")

## 7. VGG Perceptual Loss

In [None]:
import torchvision

class VGGPerceptualLoss(nn.Module):
    """Feature-space MSE using four frozen slices of ImageNet-pretrained VGG16.

    ``forward`` averages the MSE between prediction and target activations
    taken after feature layers [0:4], [4:9], [9:16] and [16:23].
    """

    def __init__(self):
        super().__init__()
        features = torchvision.models.vgg16(weights='IMAGENET1K_V1').features
        # Consecutive layer ranges; each slice consumes the previous slice's output.
        layer_ranges = ((0, 4), (4, 9), (9, 16), (16, 23))
        for number, (start, stop) in enumerate(layer_ranges, start=1):
            setattr(self, f"slice{number}", nn.Sequential(*list(features[start:stop])))
        # VGG weights are fixed; gradients still flow through to the inputs.
        for weight in self.parameters():
            weight.requires_grad = False
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def normalize(self, x):
        """Apply per-channel ``(x - mean) / std`` using the registered buffers."""
        return (x - self.mean) / self.std

    def forward(self, pred, target):
        """Return the mean of the four per-slice feature MSEs.

        Single-channel inputs are tiled to three channels first; other
        channel counts are passed through unchanged.
        """
        if pred.shape[1] == 1:
            pred = pred.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        feat_pred = self.normalize(pred)
        feat_target = self.normalize(target)

        accumulated = 0
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4):
            feat_pred = stage(feat_pred)
            feat_target = stage(feat_target)
            accumulated += F.mse_loss(feat_pred, feat_target)

        return accumulated / 4

print("✓ VGG Perceptual Loss defined")

## 8. Forecast Metrics & LPIPS

In [None]:
# Intensity thresholds in raw 0-255 units; binarize() rescales them by 1/255.
VIP_THRESHOLDS = [16, 74, 133, 160, 181, 219]

def binarize(x, thr):
    """Return an int32 mask that is 1 where ``x >= thr / 255``."""
    cutoff = thr / 255.0
    return (x >= cutoff).to(torch.int32)

def scores(pred, truth, thresholds=VIP_THRESHOLDS):
    """Compute POD, SUCR, CSI and BIAS for each threshold.

    Counts are taken over every element of ``pred``/``truth`` together.
    Returns ``{threshold: {"POD", "SUCR", "CSI", "BIAS"}}``; a 1e-9 term in
    each denominator avoids division by zero when a category is empty.
    """
    eps = 1e-9
    table = {}
    for level in thresholds:
        forecast = binarize(pred, level).bool()
        observed = binarize(truth, level).bool()
        hits = (forecast & observed).sum().item()
        miss = (~forecast & observed).sum().item()
        fa = (forecast & ~observed).sum().item()
        table[level] = dict(
            POD=hits / (hits + miss + eps),
            SUCR=hits / (hits + fa + eps),
            CSI=hits / (hits + miss + fa + eps),
            BIAS=(hits + fa) / (hits + miss + eps),
        )
    return table

# LPIPS
import lpips
# One shared AlexNet-backed LPIPS model, moved to the GPU when one is present.
lpips_fn = lpips.LPIPS(net='alex')
if torch.cuda.is_available():
    lpips_fn = lpips_fn.cuda()

def compute_lpips(pred, target):
    """Return the batch-mean LPIPS distance between ``pred`` and ``target`` as a float.

    Single-channel inputs are tiled to three channels before scoring.
    """
    # NOTE(review): LPIPS is called with its default input scaling. The tensors
    # fed here appear to be in [0, 1] (the dataset divides by 255) — confirm
    # whether the library's `normalize=True` option was intended. Changing it
    # would shift values relative to the 0.35 / 0.40 figures used later.
    if pred.shape[1] == 1:
        pred, target = pred.repeat(1, 3, 1, 1), target.repeat(1, 3, 1, 1)
    distances = lpips_fn(pred, target)
    return distances.mean().item()

print("✓ Metrics defined")

## 9. Training Function

In [None]:
from tqdm.auto import tqdm
import json
import time

def train_model(lambda_perc, epochs=10, batch_size=4, lr=1e-4, perceptual_scale=6000.0):
    """Train one U-Net with loss ``MSE + lambda_perc * perceptual / perceptual_scale``.

    Args:
        lambda_perc: Weight applied to the scaled perceptual loss.
        epochs: Number of training epochs (also the cosine schedule length).
        batch_size: Batch size for both the train and validation loaders.
        lr: Initial AdamW learning rate.
        perceptual_scale: Divisor applied to the raw VGG perceptual loss
            before weighting.

    Returns:
        ``history`` dict of per-epoch lists: ``train_mse_loss``,
        ``train_perc_loss`` (unscaled perceptual loss), ``val_mse``,
        ``val_lpips`` and ``val_csi_74``.

    Raises:
        ValueError: If the train or validation index is empty (catalog or
            HDF5 files missing), instead of failing later with an opaque
            sampler or division-by-zero error.

    Side effects:
        Writes ``lambda{lambda_perc}.pt`` (last-epoch weights + history) and
        ``unet_perceptual_lambda{lambda_perc}_history.json`` under
        ``/content/outputs/checkpoints``.

    Notes:
        Validation metrics are averaged per batch, and validation samples go
        through the dataset's random temporal crop, so they can vary from
        epoch to epoch for reasons other than the model.
    """
    from torch.utils.data import DataLoader

    print(f"\n{'='*70}")
    print(f"TRAINING WITH LAMBDA = {lambda_perc}")
    print(f"{'='*70}\n")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Mixed precision only on GPU. With enabled=False both autocast and
    # GradScaler are pass-throughs, so a single code path serves CPU and GPU.
    use_amp = device.type == 'cuda'

    # Load datasets
    train_index = build_tiny_index(CATALOG_PATH, TRAIN_IDS, SEVIR_ROOT, "vil")
    val_index = build_tiny_index(CATALOG_PATH, VAL_IDS, SEVIR_ROOT, "vil")
    if not train_index or not val_index:
        raise ValueError(
            f"No usable SEVIR events found (train={len(train_index)}, val={len(val_index)}). "
            f"Check that {CATALOG_PATH} and the HDF5 files under {SEVIR_ROOT} exist."
        )

    train_dataset = SevirNowcastDataset(train_index, 12, 1)
    val_dataset = SevirNowcastDataset(val_index, 12, 1)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    # Create model
    model = UNet2D(12, 1, 32, True).to(device)
    mse_criterion = nn.MSELoss()
    perceptual_criterion = VGGPerceptualLoss().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    history = {'train_mse_loss': [], 'train_perc_loss': [], 'val_mse': [], 'val_lpips': [], 'val_csi_74': []}

    for epoch in range(epochs):
        # Train
        model.train()
        train_mse = 0
        train_perc = 0

        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                pred = model(x)
                mse_loss = mse_criterion(pred, y)
                perc_loss = perceptual_criterion(pred, y)
                perc_scaled = perc_loss / perceptual_scale
                total = mse_loss + lambda_perc * perc_scaled

            scaler.scale(total).backward()
            # Unscale first so the clip threshold applies to true gradient norms.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            train_mse += mse_loss.item()
            train_perc += perc_loss.item()

        train_mse /= len(train_loader)
        train_perc /= len(train_loader)

        # Validate
        model.eval()
        val_mse = 0
        val_lpips_total = 0
        agg_scores = None

        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                pred = model(x)
                val_mse += mse_criterion(pred, y).item()
                val_lpips_total += compute_lpips(pred, y)

                batch_scores = scores(pred, y)
                if agg_scores is None:
                    agg_scores = {k: {m: 0.0 for m in batch_scores[k]} for k in batch_scores}
                for threshold in batch_scores:
                    for metric, value in batch_scores[threshold].items():
                        agg_scores[threshold][metric] += value

        # Per-batch means (each batch weighted equally).
        val_mse /= len(val_loader)
        val_lpips_avg = val_lpips_total / len(val_loader)
        for threshold in agg_scores:
            for metric in agg_scores[threshold]:
                agg_scores[threshold][metric] /= len(val_loader)

        csi_74 = agg_scores[74]['CSI']

        history['train_mse_loss'].append(train_mse)
        history['train_perc_loss'].append(train_perc)
        history['val_mse'].append(val_mse)
        history['val_lpips'].append(val_lpips_avg)
        history['val_csi_74'].append(csi_74)

        print(f"Epoch {epoch+1}/{epochs}: MSE={val_mse:.4f}, LPIPS={val_lpips_avg:.4f}, CSI@74={csi_74:.3f}")

        scheduler.step()

    # Save the last-epoch weights together with the full history.
    os.makedirs('/content/outputs/checkpoints', exist_ok=True)
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'history': history,
        'lambda': lambda_perc
    }
    torch.save(checkpoint, f'/content/outputs/checkpoints/lambda{lambda_perc}.pt')
    with open(f'/content/outputs/checkpoints/unet_perceptual_lambda{lambda_perc}_history.json', 'w') as f:
        json.dump(history, f)

    return history

print("✓ Training function defined")

## 10. RUN LAMBDA SWEEP! 🚀

In [None]:
# Perceptual-loss weights to compare, and the per-lambda summary they produce.
lambdas = [0.0001, 0.0005, 0.001]
results = {}

start_time = time.time()

for lambda_val in lambdas:
    history = train_model(lambda_val, epochs=10, batch_size=4)
    # NOTE(review): each figure is the best value over all epochs, taken
    # independently, so they may come from different epochs — and the saved
    # checkpoint holds the last epoch's weights. Confirm this is intended.
    best_csi = max(history['val_csi_74'])
    best_lpips = min(history['val_lpips'])
    best_mse = min(history['val_mse'])
    results[lambda_val] = dict(csi=best_csi, lpips=best_lpips, mse=best_mse)

total_time = time.time() - start_time
print(f"\n✅ SWEEP COMPLETE in {total_time/60:.1f} minutes")

## 11. Compare Results & Pick Winner!

In [None]:
rule = "=" * 70
print("\n" + rule)
print("LAMBDA SWEEP RESULTS")
print(rule + "\n")

# Reference figures for the MSE-only model (hard-coded, not recomputed here).
baseline_csi = 0.68
baseline_lpips = 0.40

print(f"Baseline (MSE only): CSI@74={baseline_csi:.3f}, LPIPS={baseline_lpips:.3f}\n")

best_lambda, best_score = None, -1

for lambda_val in lambdas:
    res = results[lambda_val]
    csi, lpips_val = res['csi'], res['lpips']

    # Both criteria must hold for a lambda to count as a success.
    csi_pass = csi >= 0.65
    lpips_pass = lpips_val < 0.35
    success = csi_pass and lpips_pass

    print(f"Lambda = {lambda_val}:")
    print(f"  CSI@74:  {csi:.3f} {'✅' if csi_pass else '❌'}")
    print(f"  LPIPS:   {lpips_val:.3f} {'✅' if lpips_pass else '❌'}")
    print(f"  SUCCESS: {'YES ✅✅✅' if success else 'NO ❌'}\n")

    # Among successful lambdas keep the highest CSI; ties keep the earlier one.
    if not success or csi <= best_score:
        continue
    best_score, best_lambda = csi, lambda_val

if best_lambda is None:
    print(f"⚠️  No lambda met both criteria.")
else:
    print(f"🎉 BEST LAMBDA: {best_lambda}")
    print(f"   CSI@74: {results[best_lambda]['csi']:.3f}")
    print(f"   LPIPS:  {results[best_lambda]['lpips']:.3f}")
    print(f"\n✅ Stage 4 SUCCESS! Use lambda={best_lambda} for Stage 5.")

## 12. Save Results to Drive

In [None]:
!mkdir -p /content/drive/MyDrive/stormfusion_results/stage4
!cp -r /content/outputs/checkpoints/* /content/drive/MyDrive/stormfusion_results/stage4/

# Save summary
summary = f"""Stage 4 Results
===============
Best Lambda: {best_lambda}
CSI@74: {results[best_lambda]['csi']:.3f}
LPIPS: {results[best_lambda]['lpips']:.3f}
Total Time: {total_time/60:.1f} min
"""

with open('/content/drive/MyDrive/stormfusion_results/stage4/summary.txt', 'w') as f:
    f.write(summary)

print("✅ Results saved to Drive!")
print(f"   Location: /content/drive/MyDrive/stormfusion_results/stage4/")

## ✅ DONE! 

Check the results above. If you got a winner (both ✅), Stage 4 is complete!

Next: Take the winning checkpoint from `MyDrive/stormfusion_results/stage4/` on Drive and proceed to Stage 5 (multi-step forecasting).