# 🚀 Stage 5: Multi-Step Forecasting (6 time steps)

**Building on Stage 4 Success:**
- Stage 4: Single-step nowcasting (12 input → 1 output frame)
- Results: CSI@74=0.82, CSI@181=0.50, CSI@219=0.33, LPIPS=0.137 ✅

**Stage 5 Goal:** Extend to multi-step forecasting
- Input: 12 frames (0-55 min of history)
- Output: **6 frames** (predict 5, 10, 15, 20, 25, 30 min ahead)
- Challenge: Maintain skill and sharpness across all lead times

**Key Questions:**
1. Does CSI degrade gracefully with lead time?
2. Does blur accumulate (LPIPS increases)?
3. Can we maintain extreme event skill (CSI@181, CSI@219)?
4. Do we need perceptual loss to prevent blur?

**Success Criteria:**
- CSI@74 ≥ 0.70 at t+5min, ≥ 0.50 at t+30min
- CSI@181 ≥ 0.40 at all lead times
- LPIPS < 0.20 at all lead times (no blur accumulation)
- Temporal consistency maintained
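
The notebook does not define a metric for the temporal-consistency criterion. One possible proxy, offered here as an assumption rather than the method used in this project, compares how much the predicted sequence changes from frame to frame with how much the observed sequence changes. The names `temporal_roughness` and `temporal_consistency_ratio` are hypothetical.

```python
import numpy as np

def temporal_roughness(frames):
    """Mean absolute change between consecutive frames; frames has shape (T, H, W)."""
    return float(np.abs(np.diff(frames, axis=0)).mean())

def temporal_consistency_ratio(pred, truth, eps=1e-9):
    """Predicted frame-to-frame change divided by observed frame-to-frame change."""
    return temporal_roughness(pred) / (temporal_roughness(truth) + eps)

# Toy data (not SEVIR): a diagonal band that shifts one pixel per frame
truth = np.stack([np.roll(np.eye(8, dtype=np.float32), s, axis=1) for s in range(6)])

# A forecast that simply repeats the first frame shows no motion at all
static_pred = np.repeat(truth[:1], 6, axis=0)

ratio_static = temporal_consistency_ratio(static_pred, truth)
ratio_perfect = temporal_consistency_ratio(truth, truth)
print(f"static forecast: {ratio_static:.2f}, perfect forecast: {ratio_perfect:.2f}")
```

Under this proxy, a ratio well below 1 would suggest an over-smoothed or static forecast, and a ratio well above 1 would suggest flicker between frames. It says nothing about whether the motion is in the right place, so it complements CSI rather than replacing it.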

---

## 1. Setup: Mount Drive & Check GPU

In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')

DRIVE_DATA_ROOT = "/content/drive/MyDrive/SEVIR_Data"
os.makedirs(DRIVE_DATA_ROOT, exist_ok=True)

print(f"✓ Google Drive mounted")
print(f"✓ Data directory: {DRIVE_DATA_ROOT}")

In [None]:
!nvidia-smi

import torch
print(f"\n{'='*70}")
print("GPU CHECK")
print(f"{'='*70}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
else:
    print("⚠️  WARNING: Select GPU runtime!")
print(f"{'='*70}")

## 2. Install Dependencies

In [None]:
!pip install -q h5py lpips tqdm matplotlib scikit-image pandas
print("✓ Dependencies installed")

## 3. Data Setup - Use ALL 541 Events (Stage 4 Standard)

In [None]:
from pathlib import Path

DATA_ROOT = "/content/drive/MyDrive/SEVIR_Data"
SEVIR_ROOT = f"{DATA_ROOT}/data/sevir"
CATALOG_PATH = f"{DATA_ROOT}/data/SEVIR_CATALOG.csv"

# Check data exists
catalog_exists = Path(CATALOG_PATH).exists()
vil_exists = Path(f"{SEVIR_ROOT}/vil/2019/SEVIR_VIL_STORMEVENTS_2019_0701_1231.h5").exists()

print(f"Data Check:")
print(f"  Catalog: {'✓' if catalog_exists else '✗'} {CATALOG_PATH}")
print(f"  VIL data: {'✓' if vil_exists else '✗'} {SEVIR_ROOT}/vil/2019/")

if not (catalog_exists and vil_exists):
    print("\n⚠ Data missing!")
else:
    print("\n✓ Data ready!")

## 4. Get Event IDs (Same as Stage 4)

In [None]:
import pandas as pd
import numpy as np

# Load catalog and get ALL VIL events
catalog = pd.read_csv(CATALOG_PATH, low_memory=False)
vil_catalog = catalog[catalog['img_type'] == 'vil'].copy()

print(f"Total VIL events in SEVIR: {len(vil_catalog)}")

# Get all unique event IDs
all_event_ids = vil_catalog['id'].unique().tolist()
print(f"Unique events: {len(all_event_ids)}")

# Create 80/20 train/val split (same seed as Stage 4)
np.random.seed(42)
shuffled_ids = np.random.permutation(all_event_ids)

n_train = int(len(all_event_ids) * 0.8)
all_train_ids = shuffled_ids[:n_train].tolist()
all_val_ids = shuffled_ids[n_train:].tolist()

print(f"\n📊 Dataset Split:")
print(f"  Train: {len(all_train_ids)} events")
print(f"  Val: {len(all_val_ids)} events")
print(f"  Total: {len(all_event_ids)} events")

# Save event ID files
os.makedirs(f"{DATA_ROOT}/data/samples", exist_ok=True)

TRAIN_IDS = f"{DATA_ROOT}/data/samples/all_train_ids.txt"
VAL_IDS = f"{DATA_ROOT}/data/samples/all_val_ids.txt"

with open(TRAIN_IDS, 'w') as f:
    f.write('\n'.join(all_train_ids))

with open(VAL_IDS, 'w') as f:
    f.write('\n'.join(all_val_ids))

print(f"\n✓ Event ID files ready")

## 5. Dataset Implementation (Multi-Step Version)

In [None]:
import h5py
import torch
from torch.utils.data import Dataset

class SevirMultiStepDataset(Dataset):
    """SEVIR VIL multi-step nowcasting dataset.
    
    Args:
        index: List of (file_path, file_index, event_id) tuples
        input_steps: Number of input frames (default: 12)
        output_steps: Number of output frames to predict (default: 6)
        target_size: Spatial size of frames (stored but currently unused;
            frames are returned at their native resolution)
    """
    def __init__(self, index, input_steps=12, output_steps=6, target_size=(384, 384)):
        self.index = index
        self.in_steps = input_steps
        self.out_steps = output_steps
        self.target_size = target_size

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        file_path, file_index, event_id = self.index[idx]

        with h5py.File(file_path, "r") as h5:
            data = h5["vil"][file_index].astype(np.float32) / 255.0

        total_frames = data.shape[2]
        max_start = total_frames - (self.in_steps + self.out_steps)
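        # Random window start. This also applies to the validation set, so
        # validation metrics are computed on different windows each epoch.
        t_start = np.random.randint(0, max(1, max_start + 1))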

        # Input: 12 frames
        x = data[:, :, t_start:t_start + self.in_steps]
        
        # Output: 6 frames (next 6 time steps)
        y = data[:, :, t_start + self.in_steps:t_start + self.in_steps + self.out_steps]

        # Transpose to (T, H, W)
        x = np.transpose(x, (2, 0, 1))
        y = np.transpose(y, (2, 0, 1))

        return torch.from_numpy(x).float(), torch.from_numpy(y).float()


def build_index(catalog_path, ids_txt, sevir_root, modality="vil"):
    """Build index from event ID file."""
    with open(ids_txt, 'r') as f:
        event_ids = [line.strip() for line in f if line.strip()]

    catalog = pd.read_csv(catalog_path, low_memory=False)
    modality_cat = catalog[catalog["img_type"] == modality].copy()

    index = []
    for event_id in event_ids:
        event_rows = modality_cat[modality_cat["id"] == event_id]
        if event_rows.empty:
            continue

        row = event_rows.iloc[0]
        file_path = os.path.join(sevir_root, row["file_name"])
        if os.path.exists(file_path):
            index.append((file_path, int(row["file_index"]), event_id))

    print(f"✓ Built index: {len(index)} events")
    return index

print("✓ Multi-step dataset class defined")
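
The window arithmetic in `__getitem__` can be checked on a dummy array. The sketch below assumes 49 frames per event (an assumption about the SEVIR VIL files, not something this notebook verifies) and uses a small 16×16 grid to keep it fast. `window_bounds` is a hypothetical helper, not part of the dataset class.

```python
import numpy as np

def window_bounds(total_frames, in_steps=12, out_steps=6, t_start=0):
    """Return (input slice, target slice) along the time axis."""
    last_start = total_frames - (in_steps + out_steps)
    if last_start < 0:
        raise ValueError("event is shorter than in_steps + out_steps")
    if not 0 <= t_start <= last_start:
        raise ValueError(f"t_start must be in [0, {last_start}]")
    x_sl = slice(t_start, t_start + in_steps)
    y_sl = slice(t_start + in_steps, t_start + in_steps + out_steps)
    return x_sl, y_sl

# Dummy event in (H, W, T) layout; 49 frames is an assumed event length
data = np.zeros((16, 16, 49), dtype=np.float32)

# Latest valid start: 49 - (12 + 6) = 31
x_sl, y_sl = window_bounds(data.shape[2], t_start=31)
x = np.transpose(data[:, :, x_sl], (2, 0, 1))  # (T_in, H, W)
y = np.transpose(data[:, :, y_sl], (2, 0, 1))  # (T_out, H, W)
print(x.shape, y.shape, y_sl.stop)
```

Unlike this helper, the dataset class does not raise on short events: `max(1, max_start + 1)` falls back to `t_start = 0`, and the target would then contain fewer than 6 frames.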

## 6. Model: UNet2D Adapted for Multi-Step Output

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import lpips

# U-Net2D building blocks
def conv_block(in_ch, out_ch, use_bn=True):
    layers = [nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True)]
    if use_bn: layers.append(nn.BatchNorm2d(out_ch))
    layers += [nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True)]
    if use_bn: layers.append(nn.BatchNorm2d(out_ch))
    return nn.Sequential(*layers)

class Down(nn.Module):
    def __init__(self, in_ch, out_ch, use_bn=True):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.block = conv_block(in_ch, out_ch, use_bn)
    def forward(self, x): return self.block(self.pool(x))

class Up(nn.Module):
    def __init__(self, in_ch, out_ch, use_bn=True):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.block = conv_block(in_ch, out_ch, use_bn)
    def forward(self, x, skip):
        x = self.up(x)
        dy, dx = skip.size(-2) - x.size(-2), skip.size(-1) - x.size(-1)
        if dy or dx:
            x = F.pad(x, [dx//2, dx - dx//2, dy//2, dy - dy//2])
        x = torch.cat([skip, x], dim=1)
        return self.block(x)

class UNet2DMultiStep(nn.Module):
    """UNet2D adapted for multi-step forecasting.
    
    Args:
        in_channels: Number of input time steps (default: 12)
        out_channels: Number of output time steps (default: 6)
        base_ch: Base number of channels (default: 32)
        use_bn: Use batch normalization (default: True)
    """
    def __init__(self, in_channels=12, out_channels=6, base_ch=32, use_bn=True):
        super().__init__()
        self.inc = conv_block(in_channels, base_ch, use_bn)
        self.d1 = Down(base_ch, base_ch*2, use_bn)
        self.d2 = Down(base_ch*2, base_ch*4, use_bn)
        self.d3 = Down(base_ch*4, base_ch*8, use_bn)
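        # The bottleneck downsamples once more so that u3 upsamples back to
        # the resolution of its skip tensor c4 (four Down steps, four Up steps)
        self.bottleneck = Down(base_ch*8, base_ch*16, use_bn)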
        self.u3 = Up(base_ch*16, base_ch*8, use_bn)
        self.u2 = Up(base_ch*8, base_ch*4, use_bn)
        self.u1 = Up(base_ch*4, base_ch*2, use_bn)
        self.u0 = Up(base_ch*2, base_ch, use_bn)
        # Output 6 frames instead of 1
        self.outc = nn.Conv2d(base_ch, out_channels, kernel_size=1)

    def forward(self, x):
        c1 = self.inc(x)
        c2 = self.d1(c1)
        c3 = self.d2(c2)
        c4 = self.d3(c3)
        b  = self.bottleneck(c4)
        x = self.u3(b, c4)
        x = self.u2(x, c3)
        x = self.u1(x, c2)
        x = self.u0(x, c1)
        return self.outc(x)  # Shape: (B, 6, H, W)

# VGG Perceptual Loss (same as Stage 4)
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = torchvision.models.vgg16(weights='IMAGENET1K_V1').features
        self.slice1 = nn.Sequential(*list(vgg[:4]))
        self.slice2 = nn.Sequential(*list(vgg[4:9]))
        self.slice3 = nn.Sequential(*list(vgg[9:16]))
        self.slice4 = nn.Sequential(*list(vgg[16:23]))
        for param in self.parameters():
            param.requires_grad = False
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def normalize(self, x):
        return (x - self.mean) / self.std

    def forward(self, pred, target):
        if pred.shape[1] == 1:
            pred = pred.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        pred = self.normalize(pred)
        target = self.normalize(target)
        
        loss = 0
        pred_1 = self.slice1(pred)
        target_1 = self.slice1(target)
        loss += F.mse_loss(pred_1, target_1)
        
        pred_2 = self.slice2(pred_1)
        target_2 = self.slice2(target_1)
        loss += F.mse_loss(pred_2, target_2)
        
        pred_3 = self.slice3(pred_2)
        target_3 = self.slice3(target_2)
        loss += F.mse_loss(pred_3, target_3)
        
        pred_4 = self.slice4(pred_3)
        target_4 = self.slice4(target_3)
        loss += F.mse_loss(pred_4, target_4)
        
        return loss / 4

print("✓ Multi-step UNet2D model defined")
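
As a sanity check on shapes, the spatial bookkeeping of a U-Net can be worked out with integer arithmetic alone. The sketch below assumes every encoder level halves the resolution (`nn.MaxPool2d(2)`) and every `Up` stage doubles it (`ConvTranspose2d` with `kernel_size=2, stride=2`). `unet_spatial_sizes` and `n_down` are names introduced for this illustration.

```python
def unet_spatial_sizes(size, n_down):
    """Spatial size after each 2x pooling step, then after each 2x upsampling step."""
    down = [size]
    for _ in range(n_down):
        down.append(down[-1] // 2)  # pooling floors odd sizes
    up = [down[-1]]
    for _ in range(n_down):
        up.append(up[-1] * 2)       # transposed conv doubles the size
    return down, up

# 384x384 frames with four halvings, matching the four Up stages (u3..u0)
down, up = unet_spatial_sizes(384, 4)
print("encoder:", down)
print("decoder:", up)

# A size that is not divisible by 2**n_down does not come back to itself,
# which is the case the F.pad call in Up.forward is there to handle
down_odd, up_odd = unet_spatial_sizes(100, 3)
print("encoder:", down_odd, "decoder:", up_odd)
```

For 384-pixel frames every decoder size matches the encoder size at the same level, so the skip connections line up without padding. For the 100-pixel example the decoder ends at 96, so padding would be needed.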

## 7. Metrics (Per Lead Time)

In [None]:
VIP_THRESHOLDS = [16, 74, 133, 160, 181, 219]

def binarize(x, thr):
    return (x >= thr/255.0).to(torch.int32)

def scores(pred, truth, thresholds=VIP_THRESHOLDS):
    """Compute CSI, POD, SUCR, BIAS for single frame."""
    out = {}
    for t in thresholds:
        p = binarize(pred, t)
        y = binarize(truth, t)
        hits = ((p==1)&(y==1)).sum().item()
        miss = ((p==0)&(y==1)).sum().item()
        fa   = ((p==1)&(y==0)).sum().item()
        pod = hits / (hits + miss + 1e-9)
        sucr = hits / (hits + fa + 1e-9)
        csi = hits / (hits + miss + fa + 1e-9)
        bias = (hits + fa) / (hits + miss + 1e-9)
        out[t] = dict(POD=pod, SUCR=sucr, CSI=csi, BIAS=bias)
    return out

def scores_per_lead_time(pred, truth, thresholds=VIP_THRESHOLDS):
    """Compute metrics for each of 6 lead times.
    
    Args:
        pred: (B, 6, H, W) predictions
        truth: (B, 6, H, W) ground truth
    
    Returns:
        List of 6 dictionaries (one per lead time)
    """
    lead_time_scores = []
    for t in range(pred.shape[1]):  # one entry per predicted lead time
        pred_t = pred[:, t:t+1, :, :]  # (B, 1, H, W)
        truth_t = truth[:, t:t+1, :, :]
        lead_time_scores.append(scores(pred_t, truth_t, thresholds))
    return lead_time_scores

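# LPIPS is used for evaluation only; move it to the GPU when one is available
lpips_fn = lpips.LPIPS(net='alex')
if torch.cuda.is_available():
    lpips_fn = lpips_fn.cuda()

def compute_lpips(pred, target):
    """Compute LPIPS for a single-frame batch of shape (B, 1, H, W)."""
    # LPIPS expects 3-channel input, so grayscale frames are repeated across channels.
    # Note: the lpips package assumes inputs in [-1, 1] unless normalize=True is passed.
    # Frames here are in [0, 1] and are passed unchanged, so the values are
    # comparable only to runs computed the same way.
    if pred.shape[1] == 1:
        pred = pred.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)
    return lpips_fn(pred, target).mean().item()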

def compute_lpips_per_lead_time(pred, truth):
    """Compute LPIPS for each of 6 lead times."""
    lpips_scores = []
    for t in range(pred.shape[1]):  # one entry per predicted lead time
        pred_t = pred[:, t:t+1, :, :]
        truth_t = truth[:, t:t+1, :, :]
        lpips_scores.append(compute_lpips(pred_t, truth_t))
    return lpips_scores

print("✓ Per-lead-time metrics defined")
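
To make the contingency-table arithmetic in `scores` concrete, here is a small NumPy restatement applied to a 2×2 toy frame. It is a sketch: `contingency_scores` is a name introduced here, and the pixel values are made up so that the counts can be checked by hand.

```python
import numpy as np

def contingency_scores(pred, truth, thr, eps=1e-9):
    """POD, SUCR, CSI and BIAS for one threshold given on the 0-255 scale."""
    p = pred >= thr / 255.0
    y = truth >= thr / 255.0
    hits = int(np.sum(p & y))    # predicted and observed
    miss = int(np.sum(~p & y))   # observed but not predicted
    fa = int(np.sum(p & ~y))     # predicted but not observed
    return {
        "POD": hits / (hits + miss + eps),
        "SUCR": hits / (hits + fa + eps),
        "CSI": hits / (hits + miss + fa + eps),
        "BIAS": (hits + fa) / (hits + miss + eps),
    }

# Toy frames in [0, 1]; threshold 74/255 is about 0.29
truth = np.array([[0.5, 0.5], [0.5, 0.1]])  # three pixels above threshold
pred = np.array([[0.5, 0.5], [0.1, 0.5]])   # 2 hits, 1 miss, 1 false alarm

s = contingency_scores(pred, truth, thr=74)
print({k: round(v, 3) for k, v in s.items()})

# A frame pair with no pixels above the threshold yields CSI = 0, not NaN
empty = contingency_scores(np.zeros((2, 2)), np.zeros((2, 2)), thr=219)
print(empty["CSI"])
```

With 2 hits, 1 miss and 1 false alarm, CSI is 2/4 = 0.5 and BIAS is 3/3 = 1. The second call shows a side effect of the `1e-9` guard: batches containing no pixels above a high threshold contribute a CSI of 0, which may pull down batch-averaged CSI@181 and CSI@219.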

## 8. Training Function (Multi-Step)

In [None]:
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import json
import time

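def train_multistep_model(lambda_perc, epochs=10, batch_size=4, lr=1e-4, perceptual_scale=6000.0):
    """Train the multi-step model for one perceptual-loss weight.

    Args:
        lambda_perc: Weight on the scaled perceptual loss (0 disables it).
        epochs: Number of training epochs.
        batch_size: Batch size for both data loaders.
        lr: Adam learning rate.
        perceptual_scale: Divisor applied to the raw VGG loss before weighting.

    Returns:
        history dict with per-epoch training losses and per-lead-time
        validation metrics.
    """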
    
    print(f"\n{'='*70}")
    print(f"TRAINING MULTI-STEP MODEL (Lambda = {lambda_perc})")
    print(f"{'='*70}\n")
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load datasets
    train_index = build_index(CATALOG_PATH, TRAIN_IDS, SEVIR_ROOT, "vil")
    val_index = build_index(CATALOG_PATH, VAL_IDS, SEVIR_ROOT, "vil")
    
    train_dataset = SevirMultiStepDataset(train_index, input_steps=12, output_steps=6)
    val_dataset = SevirMultiStepDataset(val_index, input_steps=12, output_steps=6)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    
    # Create model (12 input frames → 6 output frames)
    model = UNet2DMultiStep(in_channels=12, out_channels=6, base_ch=32, use_bn=True).to(device)
    mse_criterion = nn.MSELoss()
    perceptual_criterion = VGGPerceptualLoss().to(device) if lambda_perc > 0 else None
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None
    
    # Enhanced history: track metrics PER LEAD TIME
    history = {
        'train_mse': [], 'train_perc': [], 
        'val_mse': [],
    }
    # Per lead time (t+5, t+10, t+15, t+20, t+25, t+30 min)
    for t in range(6):
        lead_min = (t+1) * 5
        history[f'val_csi_74_t{lead_min}'] = []
        history[f'val_csi_181_t{lead_min}'] = []
        history[f'val_csi_219_t{lead_min}'] = []
        history[f'val_lpips_t{lead_min}'] = []
    
    for epoch in range(epochs):
        # Train
        model.train()
        train_mse = 0
        train_perc = 0
        
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
            x, y = x.to(device), y.to(device)  # x: (B, 12, H, W), y: (B, 6, H, W)
            optimizer.zero_grad()
            
            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    pred = model(x)  # (B, 6, H, W)
                    mse_loss = mse_criterion(pred, y)
                    
                    if lambda_perc > 0:
                        # Compute perceptual loss averaged over all 6 frames
                        perc_loss_total = 0
                        for t in range(6):
                            perc_loss_total += perceptual_criterion(pred[:, t:t+1], y[:, t:t+1])
                        perc_loss = perc_loss_total / 6.0
                        perc_scaled = perc_loss / perceptual_scale
                        total = mse_loss + lambda_perc * perc_scaled
                    else:
                        perc_loss = torch.tensor(0.0)
                        total = mse_loss
                        
                scaler.scale(total).backward()
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                pred = model(x)
                mse_loss = mse_criterion(pred, y)
                
                if lambda_perc > 0:
                    perc_loss_total = 0
                    for t in range(6):
                        perc_loss_total += perceptual_criterion(pred[:, t:t+1], y[:, t:t+1])
                    perc_loss = perc_loss_total / 6.0
                    perc_scaled = perc_loss / perceptual_scale
                    total = mse_loss + lambda_perc * perc_scaled
                else:
                    perc_loss = torch.tensor(0.0)
                    total = mse_loss
                    
                total.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            
            train_mse += mse_loss.item()
            train_perc += perc_loss.item()
        
        train_mse /= len(train_loader)
        train_perc /= len(train_loader)
        
        # Validate - Track metrics PER LEAD TIME
        model.eval()
        val_mse = 0
        
        # Accumulate per lead time
        csi_per_lead = {t: {74: [], 181: [], 219: []} for t in range(6)}
        lpips_per_lead = {t: [] for t in range(6)}
        
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                pred = model(x)  # (B, 6, H, W)
                val_mse += mse_criterion(pred, y).item()
                
                # Metrics per lead time
                lead_scores = scores_per_lead_time(pred, y)
                lead_lpips = compute_lpips_per_lead_time(pred, y)
                
                for t in range(6):
                    csi_per_lead[t][74].append(lead_scores[t][74]['CSI'])
                    csi_per_lead[t][181].append(lead_scores[t][181]['CSI'])
                    csi_per_lead[t][219].append(lead_scores[t][219]['CSI'])
                    lpips_per_lead[t].append(lead_lpips[t])
        
        val_mse /= len(val_loader)
        
        # Store history
        history['train_mse'].append(train_mse)
        history['train_perc'].append(train_perc)
        history['val_mse'].append(val_mse)
        
        for t in range(6):
            lead_min = (t+1) * 5
            history[f'val_csi_74_t{lead_min}'].append(np.mean(csi_per_lead[t][74]))
            history[f'val_csi_181_t{lead_min}'].append(np.mean(csi_per_lead[t][181]))
            history[f'val_csi_219_t{lead_min}'].append(np.mean(csi_per_lead[t][219]))
            history[f'val_lpips_t{lead_min}'].append(np.mean(lpips_per_lead[t]))
        
        # Print progress (show t+5 and t+30 for comparison)
        print(f"Epoch {epoch+1}/{epochs}:")
        print(f"  MSE={val_mse:.4f}")
        print(f"  t+5min:  CSI@74={history['val_csi_74_t5'][-1]:.3f}, CSI@181={history['val_csi_181_t5'][-1]:.3f}, LPIPS={history['val_lpips_t5'][-1]:.3f}")
        print(f"  t+30min: CSI@74={history['val_csi_74_t30'][-1]:.3f}, CSI@181={history['val_csi_181_t30'][-1]:.3f}, LPIPS={history['val_lpips_t30'][-1]:.3f}")
    
    # Save
    os.makedirs('/content/outputs/checkpoints', exist_ok=True)
    torch.save({'model': model.state_dict(), 'history': history, 'lambda': lambda_perc}, 
               f'/content/outputs/checkpoints/multistep_lambda{lambda_perc}.pt')
    
    return history

print("✓ Multi-step training function defined")
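
The training objective combines the two terms as `mse_loss + lambda_perc * perc_loss / perceptual_scale`, so the weight that actually multiplies the raw VGG loss is `lambda_perc / perceptual_scale`. The short sketch below works through that arithmetic. The loss values passed to `combined_loss` are hypothetical and chosen only to make the numbers easy to follow.

```python
def effective_perceptual_weight(lambda_perc, perceptual_scale=6000.0):
    """Weight applied to the raw (unscaled) perceptual loss."""
    return lambda_perc / perceptual_scale

def combined_loss(mse, perc, lambda_perc, perceptual_scale=6000.0):
    """Same combination as in the training loop, on plain floats."""
    return mse + lambda_perc * (perc / perceptual_scale)

w = effective_perceptual_weight(0.0005)
print(f"effective weight on raw perceptual loss: {w:.3e}")

# Hypothetical loss magnitudes, for illustration only
total = combined_loss(mse=0.01, perc=3000.0, lambda_perc=0.0005)
print(f"total loss: {total:.5f}")
```

With λ = 0.0005 and the default scale of 6000, the effective weight is roughly 8.3e-8. The perceptual term only matters if the raw VGG loss is several orders of magnitude larger than the MSE, which is worth confirming from the logged `train_perc` values.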

## 9. RUN TRAINING! 🚀

**Strategy:**
1. Test baseline (λ=0) first
2. Check for blur accumulation across lead times
3. If blur emerges, test with perceptual loss

In [None]:
lambdas = [0.0, 0.0005]  # Start with baseline, optionally add perceptual
results = {}

start_time = time.time()

for lambda_val in lambdas:
    history = train_multistep_model(lambda_val, epochs=10, batch_size=4)
    
    # Extract the best value of each metric over all epochs, per lead time
    # (taken per metric, so the values may come from different epochs)
    results[lambda_val] = {'history': history}
    for t in range(6):
        lead_min = (t+1) * 5
        results[lambda_val][f't{lead_min}'] = {
            'csi_74': max(history[f'val_csi_74_t{lead_min}']),
            'csi_181': max(history[f'val_csi_181_t{lead_min}']),
            'csi_219': max(history[f'val_csi_219_t{lead_min}']),
            'lpips': min(history[f'val_lpips_t{lead_min}'])
        }
    
    # Early analysis for baseline
    if lambda_val == 0.0:
        print(f"\n{'='*70}")
        print("BASELINE (λ=0) ANALYSIS")
        print(f"{'='*70}")
        
        # Check skill degradation
        csi_74_t5 = results[0.0]['t5']['csi_74']
        csi_74_t30 = results[0.0]['t30']['csi_74']
        degradation = ((csi_74_t5 - csi_74_t30) / csi_74_t5) * 100
        
        print(f"\nSkill Degradation (CSI@74):")
        print(f"  t+5min:  {csi_74_t5:.3f}")
        print(f"  t+30min: {csi_74_t30:.3f}")
        print(f"  Degradation: {degradation:.1f}%")
        
        # Check blur accumulation
        lpips_t5 = results[0.0]['t5']['lpips']
        lpips_t30 = results[0.0]['t30']['lpips']
        blur_increase = ((lpips_t30 - lpips_t5) / lpips_t5) * 100
        
        print(f"\nBlur Accumulation (LPIPS):")
        print(f"  t+5min:  {lpips_t5:.3f}")
        print(f"  t+30min: {lpips_t30:.3f}")
        print(f"  Increase: {blur_increase:.1f}%")
        
        # Check extreme event skill
        csi_181_t30 = results[0.0]['t30']['csi_181']
        print(f"\nExtreme Event Skill at t+30min:")
        print(f"  CSI@181: {csi_181_t30:.3f} (target: ≥0.40)")
        
        if lpips_t30 > 0.20:
            print(f"\n⚠️  BLUR ACCUMULATION DETECTED!")
            print(f"   → Perceptual loss recommended for next experiment")
        else:
            print(f"\n✅ SHARPNESS MAINTAINED!")
            print(f"   → Pure MSE sufficient for multi-step")

total_time = time.time() - start_time
print(f"\n✅ TRAINING COMPLETE in {total_time/60:.1f} minutes")
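
The degradation and blur figures printed above are both relative changes between t+5 and t+30. A minimal sketch of that calculation, with a guard for a zero starting value that the cell above does not have, might look like this. `relative_change_pct` is a hypothetical helper and the input numbers are made up.

```python
def relative_change_pct(start, end):
    """Percent change from start to end; returns None when start is 0."""
    if start == 0:
        return None
    return (end - start) / start * 100.0

# Hypothetical values for illustration, not results from this notebook
csi_t5, csi_t30 = 0.80, 0.60
lpips_t5, lpips_t30 = 0.10, 0.15

csi_drop = -relative_change_pct(csi_t5, csi_t30)      # skill lost, as a positive number
lpips_rise = relative_change_pct(lpips_t5, lpips_t30)  # blur gained

print(f"CSI@74 degradation: {csi_drop:.1f}%")
print(f"LPIPS increase: {lpips_rise:.1f}%")
```

The guard matters mainly for the high thresholds, where CSI at t+5 could be 0 early in training and the unguarded division would raise `ZeroDivisionError`.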

## 10. Detailed Results Analysis

In [None]:
print("\n" + "="*80)
print("MULTI-STEP RESULTS: Per Lead Time Analysis")
print("="*80 + "\n")

for lambda_val in results:
    print(f"\nLambda = {lambda_val}:")
    print(f"{'Lead Time':<12} {'CSI@74':<10} {'CSI@181':<10} {'CSI@219':<10} {'LPIPS':<10}")
    print("-"*80)
    
    for t in range(6):
        lead_min = (t+1) * 5
        res = results[lambda_val][f't{lead_min}']
        print(f"t+{lead_min}min{' '*6} {res['csi_74']:<10.3f} {res['csi_181']:<10.3f} {res['csi_219']:<10.3f} {res['lpips']:<10.3f}")

# Success evaluation
print("\n" + "="*80)
print("SUCCESS CRITERIA CHECK (Lambda=0.0)")
print("="*80)

baseline = results[0.0]
checks = [
    ("CSI@74 ≥ 0.70 at t+5min", baseline['t5']['csi_74'] >= 0.70, baseline['t5']['csi_74']),
    ("CSI@74 ≥ 0.50 at t+30min", baseline['t30']['csi_74'] >= 0.50, baseline['t30']['csi_74']),
    # The success criterion for CSI@181 applies to every lead time, not only t+30
    ("CSI@181 ≥ 0.40 at all times", all(baseline[f't{(t+1)*5}']['csi_181'] >= 0.40 for t in range(6)),
     min(baseline[f't{(t+1)*5}']['csi_181'] for t in range(6))),
    ("LPIPS < 0.20 at all times", all(baseline[f't{(t+1)*5}']['lpips'] < 0.20 for t in range(6)), "See table")
]

all_pass = True
for criterion, passed, value in checks:
    status = "✅ PASS" if passed else "❌ FAIL"
    print(f"{criterion:<40} {status:>10} (value: {value})")
    if not passed:
        all_pass = False

if all_pass:
    print("\n🎉 STAGE 5 SUCCESS! All criteria met.")
else:
    print("\n⚠️  Some criteria not met - consider architectural changes or perceptual loss")

## 11. Visualization: Skill Degradation Over Lead Time

In [None]:
import matplotlib.pyplot as plt

if 0.0 in results:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    lead_times = [5, 10, 15, 20, 25, 30]
    
    # Plot 1: CSI degradation
    csi_74 = [results[0.0][f't{t}']['csi_74'] for t in lead_times]
    csi_181 = [results[0.0][f't{t}']['csi_181'] for t in lead_times]
    csi_219 = [results[0.0][f't{t}']['csi_219'] for t in lead_times]
    
    axes[0].plot(lead_times, csi_74, 'o-', label='CSI@74 (Moderate)', linewidth=2)
    axes[0].plot(lead_times, csi_181, 's-', label='CSI@181 (Extreme)', linewidth=2)
    axes[0].plot(lead_times, csi_219, '^-', label='CSI@219 (Hail)', linewidth=2)
    axes[0].axhline(y=0.50, color='green', linestyle='--', alpha=0.5, label='Target (0.50)')
    axes[0].axhline(y=0.40, color='orange', linestyle='--', alpha=0.5, label='Min Extreme (0.40)')
    axes[0].set_xlabel('Lead Time (minutes)')
    axes[0].set_ylabel('CSI')
    axes[0].set_title('Forecast Skill vs Lead Time (λ=0)')
    axes[0].legend()
    axes[0].grid(alpha=0.3)
    
    # Plot 2: LPIPS (sharpness) over lead time
    lpips_vals = [results[0.0][f't{t}']['lpips'] for t in lead_times]
    axes[1].plot(lead_times, lpips_vals, 'o-', linewidth=2, color='purple')
    axes[1].axhline(y=0.20, color='red', linestyle='--', alpha=0.5, label='Blur threshold (0.20)')
    axes[1].axhline(y=0.137, color='green', linestyle='--', alpha=0.5, label='Stage 4 baseline (0.137)')
    axes[1].set_xlabel('Lead Time (minutes)')
    axes[1].set_ylabel('LPIPS (lower = sharper)')
    axes[1].set_title('Sharpness vs Lead Time (λ=0)')
    axes[1].legend()
    axes[1].grid(alpha=0.3)
    
    # Plot 3: Compare Stage 4 vs Stage 5 (t+5min and t+30min)
    categories = ['CSI@74', 'CSI@181', 'CSI@219']
    stage4 = [0.818, 0.499, 0.334]  # From Stage 4 results
    stage5_t5 = [results[0.0]['t5']['csi_74'], results[0.0]['t5']['csi_181'], results[0.0]['t5']['csi_219']]
    stage5_t30 = [results[0.0]['t30']['csi_74'], results[0.0]['t30']['csi_181'], results[0.0]['t30']['csi_219']]
    
    x = np.arange(len(categories))
    width = 0.25
    
    axes[2].bar(x - width, stage4, width, label='Stage 4 (1-step)', alpha=0.8)
    axes[2].bar(x, stage5_t5, width, label='Stage 5 (t+5min)', alpha=0.8)
    axes[2].bar(x + width, stage5_t30, width, label='Stage 5 (t+30min)', alpha=0.8)
    axes[2].set_ylabel('CSI')
    axes[2].set_title('Stage 4 vs Stage 5 Performance')
    axes[2].set_xticks(x)
    axes[2].set_xticklabels(categories)
    axes[2].legend()
    axes[2].grid(alpha=0.3)
    
    plt.tight_layout()
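    plot_dir = '/content/drive/MyDrive/stormfusion_results'
    os.makedirs(plot_dir, exist_ok=True)  # savefig fails if the directory does not exist yet
    plt.savefig(f'{plot_dir}/stage5_multistep_analysis.png', dpi=150, bbox_inches='tight')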
    plt.show()
    
    print("✅ Analysis plots saved!")

## 12. Save Results

In [None]:
!mkdir -p /content/drive/MyDrive/stormfusion_results/stage5_multistep
!cp -r /content/outputs/checkpoints/* /content/drive/MyDrive/stormfusion_results/stage5_multistep/

summary = f"""Stage 5 Results - Multi-Step Forecasting
==========================================
Dataset: {len(all_train_ids)} train / {len(all_val_ids)} val events
Architecture: UNet2D (12 input frames → 6 output frames)
Total Time: {total_time/60:.1f} min

BASELINE (λ=0.0) RESULTS:
-------------------------
"""

for t in range(6):
    lead_min = (t+1) * 5
    res = results[0.0][f't{lead_min}']
    summary += f"\nt+{lead_min}min: CSI@74={res['csi_74']:.3f}, CSI@181={res['csi_181']:.3f}, CSI@219={res['csi_219']:.3f}, LPIPS={res['lpips']:.3f}"

summary += "\n\nKEY INSIGHTS:\n"
baseline = results[0.0]
degradation = ((baseline['t5']['csi_74'] - baseline['t30']['csi_74']) / baseline['t5']['csi_74']) * 100
summary += f"\nSkill degradation (CSI@74): {degradation:.1f}% from t+5 to t+30"

lpips_change = ((baseline['t30']['lpips'] - baseline['t5']['lpips']) / baseline['t5']['lpips']) * 100
summary += f"\nBlur accumulation (LPIPS): {lpips_change:.1f}% from t+5 to t+30"

with open('/content/drive/MyDrive/stormfusion_results/stage5_multistep/summary.txt', 'w') as f:
    f.write(summary)

print("✅ Results saved to Drive!")
print(f"   Location: /content/drive/MyDrive/stormfusion_results/stage5_multistep/")
print(f"\n{summary}")
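
The per-epoch `history` is stored only inside the `.pt` checkpoint, so reading it back requires PyTorch. Since `json` is already imported in the training cell, one option is to also write the history as plain JSON. The sketch below is an assumption about how that could be done. The helper names, the file name and the demo values are all hypothetical.

```python
import json
import os
import tempfile

def save_history_json(history, path):
    """Write a {metric_name: [values per epoch]} dict as JSON."""
    # Cast to built-in floats so NumPy scalar types do not break json.dump
    serializable = {k: [float(v) for v in vals] for k, vals in history.items()}
    with open(path, "w") as f:
        json.dump(serializable, f, indent=2)

def load_history_json(path):
    """Read a history dict written by save_history_json."""
    with open(path) as f:
        return json.load(f)

# Hypothetical two-epoch history, for illustration only
demo_history = {"val_mse": [0.012, 0.010], "val_csi_74_t5": [0.70, 0.72]}

demo_path = os.path.join(tempfile.mkdtemp(), "history_demo.json")
save_history_json(demo_history, demo_path)
restored = load_history_json(demo_path)
print(restored["val_csi_74_t5"])
```

In the notebook, the same call could take `results[0.0]['history']` and a path under the Drive results folder, which would make the curves readable without loading the model weights.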

## ✅ Stage 5 Complete!

**What We Tested:**
- Multi-step forecasting (12 input frames → 6 output frames, up from 1 output frame in Stage 4)
- Skill degradation over lead time
- Blur accumulation (LPIPS trend)
- Extreme event skill at longer horizons

**Expected Outcomes:**

1. **If sharpness maintained (LPIPS < 0.20 at all times):**
   - ✅ Pure MSE sufficient for multi-step
   - Ready for Stage 6 (Generative models)

2. **If blur accumulates (LPIPS > 0.20 at t+30):**
   - Add perceptual loss (λ≈0.0005)
   - Helps maintain sharpness over time

**Next Steps:**
- Review per-lead-time metrics
- Check analysis plots
- Proceed to Stage 6 (GANs/Diffusion for probabilistic forecasting)