# Stage 4: Perceptual Loss Lambda Sweep

**READY TO RUN**: After the StormFusion code is in place (Section 3), click Runtime → Run All

This notebook:
1. Sets up environment (GPU, packages)
2. Mounts Google Drive for data access
3. Installs StormFusion code
4. Runs λ sweep: {0.0001, 0.0005, 0.001}
5. Compares results
6. Saves checkpoints to Drive

**Expected time:** 20-30 minutes on L4/A100 GPU

**Success criteria:**
- ✅ CSI@74 ≥ 0.65 (maintains forecast skill)
- ✅ LPIPS < 0.35 (improves sharpness)
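CSI@74 is the Critical Success Index at a threshold of 74: hits / (hits + misses + false alarms), computed on the binary "value ≥ 74" event maps of the forecast and the observation. A minimal sketch of that metric and of the two success criteria follows. It assumes, as is common for SEVIR VIL, that values are on the raw 0-255 scale and that a pixel is an event when it is ≥ the threshold. The function names are hypothetical, and the repo's `scores` in `stormfusion/training/forecast_metrics.py` may differ in details such as the comparison operator or how empty fields are handled.

```python
import numpy as np

def csi_at_threshold(pred, target, threshold=74):
    """Critical Success Index: hits / (hits + misses + false alarms).

    A pixel counts as an event when its value is >= threshold (assumption).
    Returns NaN when neither field contains an event.
    """
    pred_event = np.asarray(pred) >= threshold
    true_event = np.asarray(target) >= threshold
    hits = np.sum(pred_event & true_event)           # forecast yes, observed yes
    misses = np.sum(~pred_event & true_event)        # forecast no,  observed yes
    false_alarms = np.sum(pred_event & ~true_event)  # forecast yes, observed no
    denom = hits + misses + false_alarms
    return float(hits / denom) if denom > 0 else float("nan")


def meets_stage4_criteria(csi, lpips_score):
    """Stage 4 success: CSI@74 >= 0.65 and LPIPS < 0.35."""
    return csi >= 0.65 and lpips_score < 0.35


# Toy 2x2 fields (made-up values): 2 hits, 1 miss, 1 false alarm
pred = np.array([[80, 10], [90, 75]])
target = np.array([[100, 80], [20, 74]])
print(csi_at_threshold(pred, target))  # 0.5
```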

**See:** `docs/WHY_PERCEPTUAL_LOSS_MATTERS.md` for context

---
## 1. Setup Environment

In [None]:
# Check GPU
import torch
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️  WARNING: No GPU detected. Training will be VERY slow.")
    print("   Go to Runtime → Change runtime type → Select GPU")

In [None]:
# Install dependencies
!pip install -q torch torchvision h5py pandas tqdm pyyaml lpips scikit-image
print("✅ Dependencies installed")

---
## 2. Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Verify data exists
import os
data_root = '/content/drive/MyDrive/SEVIR_Data'
if os.path.exists(data_root):
    print(f"✅ Data found at {data_root}")
    !ls -lh {data_root}
else:
    print(f"❌ Data not found at {data_root}")
    print("   Please ensure SEVIR data is uploaded to Google Drive")

---
## 3. Install StormFusion Code

**Option A:** Upload from local machine  
**Option B:** Clone from git (if you pushed to GitHub)  
**Option C:** Install files manually (implemented below)

In [None]:
# Create directory structure
!mkdir -p stormfusion/data
!mkdir -p stormfusion/models/losses
!mkdir -p stormfusion/training
!mkdir -p stormfusion/models/layers
!mkdir -p scripts
!mkdir -p outputs/checkpoints
!mkdir -p outputs/logs
!mkdir -p data/samples

# Create __init__.py files
!touch stormfusion/__init__.py
!touch stormfusion/data/__init__.py
!touch stormfusion/models/__init__.py
!touch stormfusion/models/layers/__init__.py
!touch stormfusion/training/__init__.py

print("✅ Directory structure created")

### 📤 UPLOAD REQUIRED FILES

**Please upload these files from your local machine:**

1. `stormfusion/data/sevir_dataset.py`
2. `stormfusion/models/unet2d.py`
3. `stormfusion/models/layers/conv_blocks.py`
4. `stormfusion/models/losses/vgg_perceptual.py`
5. `stormfusion/models/losses/__init__.py`
6. `stormfusion/training/metrics.py`
7. `stormfusion/training/forecast_metrics.py`
8. `scripts/train_unet_with_perceptual.py`
9. `data/samples/tiny_train_ids.txt`
10. `data/samples/tiny_val_ids.txt`

**Run the cell below to upload files:**

In [None]:
from google.colab import files
import shutil

print("Upload files one by one (or use Drive if already uploaded)")
print("\nAfter uploading, manually move them to correct locations.")
print("\nAlternatively, upload entire stormfusion-sevir folder to Drive and:")
print("  !cp -r /content/drive/MyDrive/stormfusion-sevir/* .")

# Uncomment if you uploaded to Drive:
# !cp -r /content/drive/MyDrive/stormfusion-sevir/stormfusion ./
# !cp -r /content/drive/MyDrive/stormfusion-sevir/scripts ./
# !cp -r /content/drive/MyDrive/stormfusion-sevir/data/samples ./data/
# print("✅ Files copied from Drive")

### Alternative: Quick Copy from Drive

**If you uploaded the entire project to Drive:**

In [None]:
# Uncomment and modify path if needed:
# !cp -r /content/drive/MyDrive/stormfusion-sevir/stormfusion ./
# !cp -r /content/drive/MyDrive/stormfusion-sevir/scripts ./
# !cp -r /content/drive/MyDrive/stormfusion-sevir/data ./
# print("✅ Project copied from Drive")

### Verify Installation

In [None]:
# Test imports
try:
    from stormfusion.data.sevir_dataset import build_tiny_index, SevirNowcastDataset
    from stormfusion.models.unet2d import UNet2D
    from stormfusion.models.losses import VGGPerceptualLoss
    from stormfusion.training.metrics import mse, lpips_metric
    from stormfusion.training.forecast_metrics import scores
    print("✅ All imports successful")
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print("   Please upload missing files")
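An `ImportError` names only the first missing module. A complementary check, sketched below, lists every file from the upload list that is not yet at its expected path relative to the working directory. `REQUIRED_FILES` and `missing_files` are hypothetical helpers, not part of the StormFusion code.

```python
from pathlib import Path

# Paths copied from the upload list in Section 3
REQUIRED_FILES = [
    "stormfusion/data/sevir_dataset.py",
    "stormfusion/models/unet2d.py",
    "stormfusion/models/layers/conv_blocks.py",
    "stormfusion/models/losses/vgg_perceptual.py",
    "stormfusion/models/losses/__init__.py",
    "stormfusion/training/metrics.py",
    "stormfusion/training/forecast_metrics.py",
    "scripts/train_unet_with_perceptual.py",
    "data/samples/tiny_train_ids.txt",
    "data/samples/tiny_val_ids.txt",
]


def missing_files(root=".", required=REQUIRED_FILES):
    """Return the required paths that are not regular files under root."""
    root = Path(root)
    return [p for p in required if not (root / p).is_file()]


missing = missing_files()
if missing:
    print("❌ Missing files:")
    for p in missing:
        print(f"   {p}")
else:
    print("✅ All required files present")
```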

### Setup Data Paths

In [None]:
# Link SEVIR data from Drive
!ln -sf /content/drive/MyDrive/SEVIR_Data/SEVIR_CATALOG.csv data/SEVIR_CATALOG.csv
!ln -sf /content/drive/MyDrive/SEVIR_Data/sevir data/sevir

# Verify
!ls -lh data/
print("\n✅ Data paths configured")
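`ln -sf` succeeds even when the Drive target does not exist, leaving a dangling link that only fails later inside the dataset loader. The sketch below (a hypothetical helper using only `os.path`) distinguishes a working path from a broken symlink or a missing path; `os.path.exists` follows links, so it returns `False` for a dangling one.

```python
import os


def check_links(paths):
    """Classify each path as 'ok', 'broken symlink', or 'missing'."""
    status = {}
    for p in paths:
        if os.path.islink(p) and not os.path.exists(p):
            status[p] = "broken symlink"  # link exists, target does not
        elif os.path.exists(p):
            status[p] = "ok"
        else:
            status[p] = "missing"
    return status


for path, state in check_links(["data/SEVIR_CATALOG.csv", "data/sevir"]).items():
    icon = "✅" if state == "ok" else "❌"
    print(f"{icon} {path}: {state}")
```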

---
## 4. Lambda Sweep Training

**We'll train 3 models with different λ values:**
- λ = 0.0001 (conservative: minimal perceptual)
- λ = 0.0005 (balanced: recommended)
- λ = 0.001 (aggressive: maximum sharpness)

**Each training takes ~5-10 minutes on GPU**
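The sweep assumes the training script combines the two losses as a weighted sum, with `--lambda_perc` as the weight $\lambda$ on the VGG perceptual term (an assumption; check `scripts/train_unet_with_perceptual.py` for the exact form and any extra scaling):

$$
\mathcal{L}_{\text{total}}(\hat{y}, y) = \mathcal{L}_{\text{MSE}}(\hat{y}, y) + \lambda \, \mathcal{L}_{\text{perc}}(\hat{y}, y), \qquad \lambda \in \{10^{-4},\ 5 \times 10^{-4},\ 10^{-3}\}
$$

Under this form, $\lambda = 0$ recovers the MSE-only baseline, and larger $\lambda$ shifts more of the gradient signal toward the perceptual term.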

In [None]:
import time
import json

# Lambda values to test
lambdas = [0.0001, 0.0005, 0.001]
results_summary = {}

start_time = time.time()

for lambda_val in lambdas:
    print("\n" + "="*80)
    print(f"TRAINING WITH LAMBDA = {lambda_val}")
    print("="*80 + "\n")
    
    run_start = time.time()
    
    # Run training
    !python scripts/train_unet_with_perceptual.py \
        --lambda_perc {lambda_val} \
        --epochs 10 \
        --run_name lambda{lambda_val}
    
    run_time = time.time() - run_start
    
    # Load results
    history_file = f'outputs/checkpoints/unet_perceptual_lambda{lambda_val}_history.json'
    try:
        with open(history_file) as f:
            hist = json.load(f)
        
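        # Note: each metric's best value may come from a different epoch,
        # so these three numbers do not necessarily describe one checkpoint.
        best_csi = max(hist['val_csi_74'])
        best_lpips = min(hist['val_lpips'])
        best_mse = min(hist['val_mse'])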
        
        results_summary[lambda_val] = {
            'csi': best_csi,
            'lpips': best_lpips,
            'mse': best_mse,
            'time': run_time
        }
        
        print(f"\n✅ Lambda {lambda_val} completed in {run_time/60:.1f} minutes")
        print(f"   CSI@74: {best_csi:.3f}")
        print(f"   LPIPS:  {best_lpips:.3f}")
    except FileNotFoundError:
        print(f"\n❌ Lambda {lambda_val} failed - history file not found")
        results_summary[lambda_val] = {'error': 'Training failed'}

total_time = time.time() - start_time
print(f"\n{'='*80}")
print(f"TOTAL TIME: {total_time/60:.1f} minutes")
print(f"{'='*80}")
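Because the loop above takes the best CSI and the best LPIPS independently, a run can appear to pass even if no single epoch meets both criteria. A minimal sketch of a per-epoch joint check follows. It assumes the history lists are aligned by epoch (as the plotting code in Section 6 also assumes); `best_joint_epoch` is a hypothetical helper.

```python
def best_joint_epoch(hist, csi_min=0.65, lpips_max=0.35):
    """Return (epoch, csi, lpips) for the epoch with the highest CSI@74 among
    epochs that meet both criteria; ties go to lower LPIPS, then the later
    epoch. Returns None if no epoch qualifies. Epochs are 1-indexed."""
    candidates = [
        (csi, -lp, epoch)
        for epoch, (csi, lp) in enumerate(
            zip(hist["val_csi_74"], hist["val_lpips"]), start=1
        )
        if csi >= csi_min and lp < lpips_max
    ]
    if not candidates:
        return None
    csi, neg_lp, epoch = max(candidates)
    return epoch, csi, -neg_lp


# Illustrative values only, not results from this sweep.
# Best-of-each would report CSI 0.70 and LPIPS 0.30, yet only epoch 3 passes both.
example_hist = {
    "val_csi_74": [0.60, 0.70, 0.66],
    "val_lpips": [0.30, 0.38, 0.34],
}
print(best_joint_epoch(example_hist))  # (3, 0.66, 0.34)
```

Inside the sweep loop, `best_joint_epoch(hist)` could be stored alongside the per-metric bests so that Section 5 can judge each λ on a single checkpoint.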

---
## 5. Compare Results

In [None]:
import pandas as pd

print("\n" + "="*80)
print("LAMBDA SWEEP RESULTS")
print("="*80 + "\n")

# Baseline for comparison
baseline_csi = 0.68
baseline_lpips = 0.40

print(f"Baseline (MSE only):")
print(f"  CSI@74: {baseline_csi:.3f}")
print(f"  LPIPS:  {baseline_lpips:.3f}")
print("\n" + "-"*80 + "\n")

best_lambda = None
best_score = None

for lambda_val in lambdas:
    if 'error' in results_summary[lambda_val]:
        print(f"Lambda = {lambda_val}: FAILED")
        continue
    
    res = results_summary[lambda_val]
    csi = res['csi']
    lpips_val = res['lpips']  # avoid shadowing the `lpips` package
    mse = res['mse']
    
    # Check success criteria
    csi_pass = csi >= 0.65
    lpips_pass = lpips_val < 0.35
    success = csi_pass and lpips_pass
    
    # Relative change vs. baseline (for LPIPS, negative = improvement)
    csi_change = ((csi - baseline_csi) / baseline_csi) * 100
    lpips_change = ((lpips_val - baseline_lpips) / baseline_lpips) * 100
    
    print(f"Lambda = {lambda_val}:")
    print(f"  CSI@74:  {csi:.3f} {'✅' if csi_pass else '❌'} ({csi_change:+.1f}% vs baseline)")
    print(f"  LPIPS:   {lpips_val:.3f} {'✅' if lpips_pass else '❌'} ({lpips_change:+.1f}% vs baseline)")
    print(f"  Val MSE: {mse:.4f}")
    print(f"  SUCCESS: {'YES ✅✅✅' if success else 'NO ❌'}")
    print(f"  Time:    {res['time']/60:.1f} min")
    print()
    
    # Track best successful run: highest CSI, ties broken by lower LPIPS
    score = (csi, -lpips_val)
    if success and (best_score is None or score > best_score):
        best_score = score
        best_lambda = lambda_val

print("-"*80)
if best_lambda is not None:
    print(f"\n🎉 BEST LAMBDA: {best_lambda}")
    print(f"   CSI@74: {results_summary[best_lambda]['csi']:.3f}")
    print(f"   LPIPS:  {results_summary[best_lambda]['lpips']:.3f}")
    print(f"\n✅ Stage 4 SUCCESS! Use lambda={best_lambda} for Stage 5.")
else:
    print(f"\n⚠️  No lambda met both criteria.")
    print(f"   Check individual results above.")
    print(f"   May need to adjust lambda range or accept partial success.")
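The printed summary can also be collected into one table with pandas, which makes the λ trade-off easier to scan and export. The sketch below assumes the `results_summary` layout built in Section 4 (keys `csi`, `lpips`, `mse`, `time` in seconds, or an `error` key for failed runs); `summary_frame` is a hypothetical helper.

```python
import pandas as pd


def summary_frame(results, csi_min=0.65, lpips_max=0.35):
    """One row per successful run, indexed by lambda, with a success flag."""
    rows = {lam: r for lam, r in results.items() if "error" not in r}
    if not rows:
        return pd.DataFrame(columns=["csi", "lpips", "mse", "time_min", "success"])
    df = pd.DataFrame.from_dict(rows, orient="index")
    df.index.name = "lambda"
    df["time_min"] = df["time"] / 60  # seconds -> minutes
    df["success"] = (df["csi"] >= csi_min) & (df["lpips"] < lpips_max)
    return df.drop(columns="time").sort_index()
```

Usage in this notebook: `print(summary_frame(results_summary).round(3))`.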

---
## 6. Visualize Training Curves

In [None]:
import json
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

colors = {0.0001: 'blue', 0.0005: 'green', 0.001: 'red'}
baseline_csi = 0.68    # MSE-only baseline (same values as Section 5)
baseline_lpips = 0.40

for lambda_val in lambdas:
    history_file = f'outputs/checkpoints/unet_perceptual_lambda{lambda_val}_history.json'
    try:
        with open(history_file) as f:
            hist = json.load(f)
    except FileNotFoundError:
        print(f"Could not plot lambda={lambda_val} - file not found")
        continue

    color = colors.get(lambda_val)  # None lets matplotlib pick a color
    label = f'λ={lambda_val}'
    epochs = range(1, len(hist['val_csi_74']) + 1)

    axes[0, 0].plot(epochs, hist['train_mse_loss'], label=label, color=color)  # Training MSE
    axes[0, 1].plot(epochs, hist['val_mse'], label=label, color=color)         # Validation MSE
    axes[1, 0].plot(epochs, hist['val_csi_74'], label=label, color=color, linewidth=2)
    axes[1, 1].plot(epochs, hist['val_lpips'], label=label, color=color, linewidth=2)

# Add baselines and thresholds
axes[1, 0].axhline(y=baseline_csi, color='black', linestyle='--', label='Baseline (MSE only)', linewidth=1)
axes[1, 0].axhline(y=0.65, color='gray', linestyle=':', label='Target (≥0.65)', linewidth=1)
axes[1, 1].axhline(y=baseline_lpips, color='black', linestyle='--', label='Baseline (MSE only)', linewidth=1)
axes[1, 1].axhline(y=0.35, color='gray', linestyle=':', label='Target (<0.35)', linewidth=1)

# Label every panel once, after all runs have been plotted
panel_labels = {
    (0, 0): ('Train MSE', 'Training MSE Loss'),
    (0, 1): ('Val MSE', 'Validation MSE'),
    (1, 0): ('CSI@74', 'Forecast Skill (CSI@74)'),
    (1, 1): ('LPIPS', 'Perceptual Quality (LPIPS, lower is better)'),
}
for (row, col), (ylabel, title) in panel_labels.items():
    ax = axes[row, col]
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('outputs/stage4_lambda_sweep_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✅ Plots saved to outputs/stage4_lambda_sweep_curves.png")

---
## 7. Save Results to Drive

In [None]:
# Copy checkpoints and logs to Drive for safekeeping
!mkdir -p /content/drive/MyDrive/stormfusion_results/stage4
!cp -r outputs/checkpoints/*.pt /content/drive/MyDrive/stormfusion_results/stage4/
!cp -r outputs/checkpoints/*.json /content/drive/MyDrive/stormfusion_results/stage4/
!cp -r outputs/logs/*.log /content/drive/MyDrive/stormfusion_results/stage4/
!cp outputs/stage4_lambda_sweep_curves.png /content/drive/MyDrive/stormfusion_results/stage4/

print("✅ Results saved to Drive: /content/drive/MyDrive/stormfusion_results/stage4/")
print("\nYou can now download:")
if best_lambda is not None:
    print(f"  - Best checkpoint: unet_perceptual_lambda{best_lambda}_best.pt")
print("  - Training curves: stage4_lambda_sweep_curves.png")
print("  - Histories and logs: *.json, *.log (same Drive folder)")
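The `!cp` calls above print an error and copy nothing when a glob matches no files (for example, if no `.log` files were written). A pure-Python alternative, sketched below with the standard-library `glob` and `shutil.copy2`, warns per empty pattern and copies whatever exists. `backup_files` is a hypothetical helper, and the Drive guard simply skips the copy when Drive is not mounted.

```python
import glob
import os
import shutil


def backup_files(patterns, dest):
    """Copy every regular file matching any glob pattern into dest.

    Returns the list of destination paths; warns for patterns with no match.
    """
    os.makedirs(dest, exist_ok=True)
    copied = []
    for pattern in patterns:
        matches = [p for p in glob.glob(pattern) if os.path.isfile(p)]
        if not matches:
            print(f"⚠️  No files match {pattern}")
        for src in matches:
            copied.append(shutil.copy2(src, dest))  # keeps file timestamps
    return copied


DRIVE_DIR = "/content/drive/MyDrive/stormfusion_results/stage4"
if os.path.isdir("/content/drive/MyDrive"):
    backup_files(
        [
            "outputs/checkpoints/*.pt",
            "outputs/checkpoints/*.json",
            "outputs/logs/*.log",
            "outputs/stage4_lambda_sweep_curves.png",
        ],
        DRIVE_DIR,
    )
```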

---
## 8. Generate Summary Report

In [None]:
# Create summary report
report = []
report.append("="*80)
report.append("STAGE 4: PERCEPTUAL LOSS LAMBDA SWEEP - FINAL REPORT")
report.append("="*80)
report.append("")
report.append(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}")
report.append(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
report.append(f"Total Training Time: {total_time/60:.1f} minutes")
report.append("")
report.append("Baseline (MSE only):")
report.append(f"  CSI@74: {baseline_csi:.2f}")
report.append(f"  LPIPS:  {baseline_lpips:.2f}")
report.append("")
report.append("-"*80)
report.append("")

for lambda_val in lambdas:
    if 'error' in results_summary[lambda_val]:
        report.append(f"Lambda = {lambda_val}: FAILED")
        report.append("")
        continue
    
    res = results_summary[lambda_val]
    csi = res['csi']
    lpips_val = res['lpips']  # avoid shadowing the `lpips` package
    mse = res['mse']
    
    csi_pass = "PASS" if csi >= 0.65 else "FAIL"
    lpips_pass = "PASS" if lpips_val < 0.35 else "FAIL"
    success = csi >= 0.65 and lpips_val < 0.35
    
    report.append(f"Lambda = {lambda_val}:")
    report.append(f"  Val CSI@74:  {csi:.3f} ({csi_pass})")
    report.append(f"  Val LPIPS:   {lpips_val:.3f} ({lpips_pass})")
    report.append(f"  Val MSE:     {mse:.4f}")
    report.append(f"  Training Time: {res['time']/60:.1f} min")
    report.append(f"  SUCCESS: {'YES' if success else 'NO'}")
    report.append("")

report.append("-"*80)
report.append("")

if best_lambda is not None:
    report.append(f"BEST LAMBDA: {best_lambda}")
    report.append(f"  CSI@74: {results_summary[best_lambda]['csi']:.3f}")
    report.append(f"  LPIPS:  {results_summary[best_lambda]['lpips']:.3f}")
    report.append("")
    report.append("RECOMMENDATION: Use lambda={} for Stage 5".format(best_lambda))
else:
    report.append("NO LAMBDA MET BOTH CRITERIA")
    report.append("")
    report.append("RECOMMENDATION: Review results and either:")
    report.append("  1. Accept partial success (best CSI or best LPIPS)")
    report.append("  2. Try intermediate lambda values")
    report.append("  3. Adjust scaling factor")

report.append("")
report.append("="*80)
report.append("See docs/WHY_PERCEPTUAL_LOSS_MATTERS.md for context")
report.append("="*80)

report_text = "\n".join(report)
print(report_text)

# Save report
with open('outputs/logs/stage4_final_report.txt', 'w') as f:
    f.write(report_text)

!cp outputs/logs/stage4_final_report.txt /content/drive/MyDrive/stormfusion_results/stage4/

print("\n✅ Report saved to outputs/logs/stage4_final_report.txt")
print("   and backed up to Drive")
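When no λ meets both criteria, the report suggests accepting partial success. One way to narrow that choice is to keep only the Pareto-optimal runs: those for which no other run has both higher-or-equal CSI@74 and lower-or-equal LPIPS (and is strictly better on at least one). The sketch below assumes the `results_summary` layout from Section 4; `pareto_lambdas` is a hypothetical helper, and the demo numbers are made up.

```python
def pareto_lambdas(results):
    """Return lambdas whose runs are not dominated on (CSI@74 up, LPIPS down).

    Run B dominates run A if B is at least as good on both metrics and
    strictly better on one. Failed runs (with an 'error' key) are skipped.
    """
    ok = {lam: r for lam, r in results.items() if "error" not in r}
    front = []
    for lam, r in ok.items():
        dominated = any(
            o["csi"] >= r["csi"] and o["lpips"] <= r["lpips"]
            and (o["csi"] > r["csi"] or o["lpips"] < r["lpips"])
            for other, o in ok.items()
            if other != lam
        )
        if not dominated:
            front.append(lam)
    return sorted(front)


# Hypothetical numbers for illustration only (not sweep results)
demo = {
    0.0001: {"csi": 0.70, "lpips": 0.38},
    0.0005: {"csi": 0.66, "lpips": 0.39},  # worse than 0.0001 on both metrics
    0.001: {"csi": 0.62, "lpips": 0.33},
}
print(pareto_lambdas(demo))  # [0.0001, 0.001]
```

Running `pareto_lambdas(results_summary)` leaves only the λ values worth weighing against each other; a dominated λ is never the better choice on either criterion.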

---
## ✅ Stage 4 Complete!

### Next Steps:

1. **Review results above** - Which λ won?
2. **Download checkpoint** from Drive if needed
3. **Proceed to Stage 5** - Multi-step forecasting (6 frames)

### Files Saved:
- Checkpoints: `/content/drive/MyDrive/stormfusion_results/stage4/*.pt`
- Histories: `/content/drive/MyDrive/stormfusion_results/stage4/*.json`
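- Plots: `/content/drive/MyDrive/stormfusion_results/stage4/stage4_lambda_sweep_curves.png`
- Report: `/content/drive/MyDrive/stormfusion_results/stage4/stage4_final_report.txt`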

### Success Criteria:
- ✅ CSI@74 ≥ 0.65 (maintains forecast skill)
- ✅ LPIPS < 0.35 (improves sharpness)

If both met: **Stage 4 SUCCESS!** 🎉

### Remember:
Sharper forecasts from the perceptual loss enable spatial granularity in probabilistic forecasting (Stages 6-7).  
See `docs/WHY_PERCEPTUAL_LOSS_MATTERS.md` for full context.