# 01 Environment Check

**Stage 0: Setup**

This notebook validates that the StormFusion-SEVIR environment is correctly set up.

## Checks:
1. Import all core packages
2. Verify PyTorch + GPU availability
3. Test UNet2D forward pass
4. Test ConvLSTM forward pass
5. Test data loading stubs
6. Test metric calculations
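Each check below runs in its own cell, so a failure halts the notebook at that cell. The following is a minimal sketch of an alternative that runs named checks and records a status for each one instead of stopping at the first failure. `run_checks` and the three demo checks are hypothetical helpers, not part of the `stormfusion` package. `NotImplementedError` is reported as `STUB`, which mirrors how section 5 treats the data-loading stub.

```python
def run_checks(checks):
    """Run each named check and record PASS / STUB / FAIL without stopping early."""
    results = {}
    for name, check in checks.items():
        try:
            check()
            results[name] = "PASS"
        except NotImplementedError:
            # A placeholder that is not implemented yet is expected in Stage 0
            results[name] = "STUB"
        except Exception as exc:  # record any other failure and keep going
            results[name] = f"FAIL: {type(exc).__name__}: {exc}"
    return results


# Hypothetical demo checks standing in for the six checks listed above
def ok():
    pass


def stub():
    raise NotImplementedError("implemented in Stage 1")


def broken():
    return 1 / 0


summary = run_checks({"ok": ok, "stub": stub, "broken": broken})
for name, status in summary.items():
    print(f"{name:<8} {status}")
```

Catching broad `Exception` is deliberate here: the goal is a complete report, not early termination.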

## 1. Import Core Packages

In [None]:
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")
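A successful import shows that a package is present, but not which version is installed. The following is a minimal sketch that uses only the standard library's `importlib.metadata` to report installed versions without importing the packages. The package list is an assumption about this project's core dependencies, not its actual requirements file, and `installed_versions` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Assumed core dependencies (PyPI distribution names)
for pkg, ver in installed_versions(["torch", "numpy", "matplotlib"]).items():
    print(f"{pkg:<12} {ver or 'NOT INSTALLED'}")
```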

## 2. Check PyTorch Device

In [None]:
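# Pick the best available backend: CUDA first, then Apple Silicon MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Report the device only after the final choice, so the printout matches what is used
print(f"Device: {device}")
if device.type == "cuda":
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
elif device.type == "mps":
    print("  Apple Silicon MPS available")
else:
    print("  Running on CPU")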
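The device cell follows a fixed priority: CUDA, then MPS, then CPU. Moving that decision into a pure function makes the order explicit and testable on a machine without a GPU. This is a minimal sketch, and `choose_device` is a hypothetical helper. In the notebook, its two flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.

```python
def choose_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device string: CUDA first, then Apple MPS, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# e.g. an Apple Silicon laptop: no CUDA, MPS present
print(choose_device(cuda_available=False, mps_available=True))  # prints "mps"
```

The returned string can be passed directly to `torch.device(...)`.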

## 3. Test UNet2D Forward Pass

In [None]:
from stormfusion.models.unet2d import UNet2D

# Create model
model = UNet2D(in_channels=12, out_channels=1).to(device)
print(f"UNet2D created with {sum(p.numel() for p in model.parameters()):,} parameters")

# Test forward pass
batch_size = 2
t_in = 12
h, w = 128, 128  # Small resolution for testing

x = torch.randn(batch_size, t_in, h, w).to(device)
print(f"Input shape: {x.shape}")

with torch.no_grad():
    y = model(x)
    
print(f"Output shape: {y.shape}")
print("✓ UNet2D forward pass successful!")

# Visualize one sample
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(x[0, -1].cpu().numpy(), cmap='turbo')
axes[0].set_title('Last Input Frame (t=55 min)')
axes[0].axis('off')

axes[1].imshow(y[0, 0].cpu().numpy(), cmap='turbo')
axes[1].set_title('Prediction (t=60 min)')
axes[1].axis('off')

plt.suptitle('UNet2D Dummy Forward Pass', fontweight='bold')
plt.tight_layout()
plt.show()
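The plot titles assume SEVIR's 5-minute frame cadence. Under that assumption, 12 input frames cover t = 0 to 55 min, and a single-frame forecast one step ahead lands at t = 60 min. The following minimal sketch makes that arithmetic explicit. `frame_times` and its `lead_steps` parameter are hypothetical. A one-step lead is also an assumption, consistent with `out_channels=1` but not confirmed by the notebook.

```python
def frame_times(t_in: int = 12, step_min: int = 5, lead_steps: int = 1):
    """Minutes (relative to the first input frame) for the inputs and the forecast target."""
    inputs = [i * step_min for i in range(t_in)]
    target = (t_in - 1 + lead_steps) * step_min
    return inputs, target


inputs, target = frame_times()
print(inputs[-1], target)  # prints "55 60"
```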

## 4. Test ConvLSTM Forward Pass

In [None]:
from stormfusion.models.convlstm import ConvLSTM

# Create model
model_lstm = ConvLSTM(in_channels=1, hidden_dim=64, kernel_size=3, num_layers=2).to(device)
print(f"ConvLSTM created with {sum(p.numel() for p in model_lstm.parameters()):,} parameters")

# Test forward pass
x_seq = torch.randn(batch_size, t_in, 1, h, w).to(device)
print(f"Input shape: {x_seq.shape}")

with torch.no_grad():
    y_seq = model_lstm(x_seq)
    
print(f"Output shape: {y_seq.shape}")
print("✓ ConvLSTM forward pass successful!")
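The two models take the same frames in different layouts:

- UNet2D stacks the time steps as channels, `(B, T, H, W)`.
- ConvLSTM expects an explicit per-step channel axis, `(B, T, 1, H, W)`.

The following minimal NumPy sketch converts between the two. The helper names are hypothetical. In PyTorch, the equivalent calls are `x.unsqueeze(2)` and `x.squeeze(2)`.

```python
import numpy as np


def to_sequence_layout(frames: np.ndarray) -> np.ndarray:
    """(B, T, H, W) channel-stacked frames -> (B, T, 1, H, W) sequence layout."""
    if frames.ndim != 4:
        raise ValueError(f"expected 4-D (B, T, H, W), got shape {frames.shape}")
    return frames[:, :, np.newaxis, :, :]  # basic indexing: returns a view, no copy


def to_channel_layout(seq: np.ndarray) -> np.ndarray:
    """(B, T, 1, H, W) sequence layout -> (B, T, H, W) by dropping the channel axis."""
    if seq.ndim != 5 or seq.shape[2] != 1:
        raise ValueError(f"expected (B, T, 1, H, W), got shape {seq.shape}")
    return seq[:, :, 0, :, :]


demo = np.random.rand(2, 12, 128, 128).astype(np.float32)
demo_seq = to_sequence_layout(demo)
print(demo.shape, "->", demo_seq.shape)  # prints "(2, 12, 128, 128) -> (2, 12, 1, 128, 128)"
```

Both conversions are views, so switching layouts costs no extra memory for the input tensor.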

## 5. Test Data Loading Stub

In [None]:
from stormfusion.data.sevir_dataset import build_tiny_index

# Check if data directory exists
data_dir = Path("../data")
if not data_dir.exists():
    print(f"⚠ Data directory not found: {data_dir}")
    print("  This is expected for Stage 0 - data download happens in Stage 1")
else:
    print(f"✓ Data directory found: {data_dir}")

# Test stub function
try:
    result = build_tiny_index(None, None, None)
    print(f"build_tiny_index stub returns: {result}")
    print("✓ Data loading stub functional")
except NotImplementedError:
    print("✓ build_tiny_index raises NotImplementedError as expected")
    print("  (Will be implemented in Stage 1)")

## 6. Test Metric Calculations

In [None]:
from stormfusion.training.forecast_metrics import scores

# Create dummy predictions and targets
pred = torch.rand(batch_size, 1, h, w)
target = torch.rand(batch_size, 1, h, w)

# Test metrics
thresholds = [0.1, 0.5]
metrics = scores(pred, target, thresholds)

print("Forecast Metrics (dummy data):")
for thr in thresholds:
    print(f"  Threshold {thr}:")
    for metric_name in ['pod', 'sucr', 'csi', 'bias']:
        key = f"{metric_name}_{thr}"
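        if key in metrics:
            print(f"    {metric_name.upper()}: {metrics[key]:.3f}")
        else:
            # Flag missing keys instead of skipping them silently
            print(f"    {metric_name.upper()}: missing (no key '{key}')")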

print("\n✓ Metric calculation successful!")
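The four scores come from the standard contingency table at a threshold. Let H be hits, M misses and F false alarms. Then:

- POD = H / (H + M)
- SUCR = H / (H + F)
- CSI = H / (H + M + F)
- BIAS = (H + F) / (H + M)

The following minimal reference sketch can be used to cross-check `scores` on tiny arrays. It counts a pixel as an event when its value is `>= thr`. That convention is an assumption, as is pooling all pixels of the batch into one table; the project's `scores` may do either differently. `contingency_scores` is a hypothetical helper.

```python
import numpy as np


def contingency_scores(pred: np.ndarray, target: np.ndarray, thr: float) -> dict:
    """POD, SUCR, CSI and BIAS at one threshold (event = value >= thr, an assumption)."""
    p = pred >= thr
    t = target >= thr
    hits = int(np.sum(p & t))
    misses = int(np.sum(~p & t))
    false_alarms = int(np.sum(p & ~t))

    def ratio(num, den):
        # Undefined scores (empty denominator) are reported as NaN
        return num / den if den > 0 else float("nan")

    return {
        "pod": ratio(hits, hits + misses),
        "sucr": ratio(hits, hits + false_alarms),
        "csi": ratio(hits, hits + misses + false_alarms),
        "bias": ratio(hits + false_alarms, hits + misses),
    }


pred_demo = np.array([[0.6, 0.2], [0.7, 0.9]])
target_demo = np.array([[0.8, 0.1], [0.3, 0.6]])
# hits=2, misses=0, false alarms=1 -> POD 1.0, SUCR 0.667, CSI 0.667, BIAS 1.5
print(contingency_scores(pred_demo, target_demo, thr=0.5))
```

This gives a sanity check for the dummy data above. Take independent uniform `pred` and `target` at threshold 0.5: each outcome (H, M, F) has probability 0.25, so, assuming `scores` uses the standard definitions, you should see POD ≈ SUCR ≈ 0.5, CSI ≈ 1/3 and BIAS ≈ 1.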

## Summary

✅ **Stage 0 Complete!**

All core components are functional:
- Python environment set up
- PyTorch installed and working
- UNet2D model forward pass works
- ConvLSTM model forward pass works
- Data loading stubs in place
- Metrics calculation works

**Next Steps:**
- Stage 1: Download a tiny SEVIR subset (8 train / 4 val events)
- Implement `build_tiny_index` function
- Create `SevirNowcastDataset` class
- Test data loading pipeline

In [None]:
# Clean up
import gc
del model, model_lstm, x, y, x_seq, y_seq
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

print("✓ Environment check complete!")