# 05: Test Full Integrated Model

**Purpose:** Test the complete SGT model with all of its components integrated end to end

**What this does:**
- Create the full StormGraphTransformer model
- Test end-to-end forward pass
- Verify shapes and parameter counts
- Test with dummy data and, optionally, real SEVIR samples

**What this does NOT do:**
- Training (that's in notebook 06/07)
- Evaluation on full dataset
- Hyperparameter tuning

**Expected time:** 5 minutes

---

**Prerequisites:** 
- Run `01_Setup_and_Environment.ipynb` first
- Run `04_Test_Model_Components.ipynb` to verify each module works

## Step 1: Setup

In [None]:
import sys
import torch
import torch.nn as nn
from pathlib import Path

# Add repository to path
REPO_PATH = '/content/stormfusion-sevir'
if REPO_PATH not in sys.path:
    sys.path.insert(0, REPO_PATH)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")

## Step 2: Create Full SGT Model

In [None]:
from stormfusion.models.sgt import create_sgt_model

print("="*70)
print("CREATING FULL STORMGRAPHTRANSFORMER MODEL")
print("="*70)

# Create model with default configuration
model = create_sgt_model(
    modalities=['vil', 'ir069', 'ir107', 'lght'],
    input_steps=12,
    output_steps=12,
    base_channels=64
).to(device)

print("\n✅ Model created successfully\n")
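
`.to(device)` only moves the parameters and buffers of *registered* submodules. If a component keeps per-modality layers in a plain Python `dict` or `list` instead of `nn.ModuleDict` or `nn.ModuleList`, those layers stay on the CPU and are also missing from `model.parameters()`, so they are skipped by the parameter counts and by the optimizer. Whether the SGT code uses such containers is not shown in this notebook; the toy classes below are hypothetical and only illustrate the pitfall plus a quick device check.

```python
import torch
import torch.nn as nn


class PlainDictModel(nn.Module):
    """Hypothetical: layers stored in a plain dict are NOT registered."""

    def __init__(self):
        super().__init__()
        self.per_modality = {"vil": nn.Linear(4, 4), "ir069": nn.Linear(4, 4)}


class ModuleDictModel(nn.Module):
    """Hypothetical: nn.ModuleDict registers each layer as a submodule."""

    def __init__(self):
        super().__init__()
        self.per_modality = nn.ModuleDict({"vil": nn.Linear(4, 4), "ir069": nn.Linear(4, 4)})


def n_params(module: nn.Module) -> int:
    """Number of parameter elements visible through module.parameters()."""
    return sum(p.numel() for p in module.parameters())


def devices_of(module: nn.Module) -> set:
    """All devices holding this module's parameters and buffers."""
    tensors = list(module.parameters()) + list(module.buffers())
    return {t.device for t in tensors}


# Each Linear(4, 4) has 4*4 weights + 4 biases = 20 parameters
print(n_params(PlainDictModel()), n_params(ModuleDictModel()))  # 0 40
print(devices_of(ModuleDictModel()))
```

On the real model, `devices_of(model)` should contain exactly one device. Compare it by type rather than with `==` against `device`, since `torch.device('cuda')` and `torch.device('cuda:0')` do not compare equal.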

## Step 3: Model Architecture Summary

In [None]:
print("="*70)
print("MODEL ARCHITECTURE")
print("="*70)

# Count parameters per module
print("\nParameters per module:")
print("-" * 50)

modules = [
    ('encoder', 'MultiModalEncoder'),
    ('detector', 'StormCellDetector'),
    ('gnn', 'StormGNN'),
    ('transformer', 'SpatioTemporalTransformer'),
    ('decoder', 'PhysicsDecoder'),
    ('physics_loss', 'ConservationLoss')
]

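listed_params = 0
for attr_name, module_name in modules:
    if hasattr(model, attr_name):
        module = getattr(model, attr_name)
        params = sum(p.numel() for p in module.parameters())
        listed_params += params
        print(f"{module_name:30s}: {params:>12,} params")
    else:
        print(f"{module_name:30s}: {'not found':>12s}")

# Count over the whole model so parameters outside the listed modules are included
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("-" * 50)
print(f"{'Sum of listed modules':30s}: {listed_params:>12,} params")
print(f"{'TOTAL':30s}: {total_params:>12,} params")
print(f"{'':30s}  {total_params/1e6:>10.2f} M")
print(f"{'Trainable':30s}: {trainable_params:>12,} params")
if listed_params != total_params:
    print("⚠️  Listed modules do not account for all model parameters")
print("\n" + "="*70)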
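
Summing a fixed list of attribute names can miss parameters that live outside those submodules (for example, a learnable scale registered directly on the top-level model), and `hasattr` silently skips any name that does not match. A more robust breakdown walks `named_children()` and reconciles against `model.parameters()`. The helper and toy model below are illustrative only and are not part of the stormfusion package.

```python
import torch
import torch.nn as nn


def params_by_child(model: nn.Module) -> dict:
    """Per-child parameter counts, plus parameters owned directly by `model`
    and the overall total from model.parameters()."""
    counts = {}
    for name, child in model.named_children():
        counts[name] = sum(p.numel() for p in child.parameters())
    # Parameters registered on the parent itself, not inside any child
    counts["<direct>"] = sum(p.numel() for p in model.parameters(recurse=False))
    counts["<total>"] = sum(p.numel() for p in model.parameters())
    return counts


class Toy(nn.Module):
    """Hypothetical model with one parameter outside its submodules."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)             # 4*8 + 8 = 40
        self.decoder = nn.Linear(8, 2)             # 8*2 + 2 = 18
        self.scale = nn.Parameter(torch.ones(3))   # 3, owned by Toy itself


print(params_by_child(Toy()))
```

Calling `params_by_child(model)` on the SGT model lists every top-level submodule, including any that the hand-written list above does not mention.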

## Step 4: Test Forward Pass (Dummy Data)

In [None]:
print("="*70)
print("TEST: FORWARD PASS WITH DUMMY DATA")
print("="*70)

# Create dummy input (batch_size=2); each modality is (B, T, H, W)
B = 2
dummy_input = {
    'vil': torch.randn(B, 12, 384, 384, device=device),
    'ir069': torch.randn(B, 12, 384, 384, device=device),
    'ir107': torch.randn(B, 12, 384, 384, device=device),
    'lght': torch.randn(B, 12, 384, 384, device=device)
}

print("\nInput shapes:")
for mod, tensor in dummy_input.items():
    print(f"  {mod:8s}: {tuple(tensor.shape)}")

# Inference mode: disables dropout and freezes batch-norm running statistics
model.eval()

# Forward pass
try:
    print("\nRunning forward pass...")
    with torch.no_grad():
        output = model(dummy_input)
    
    print("\n✅ Forward pass successful!")
    print(f"\nOutput shape: {tuple(output.shape)}")
    print(f"Expected: ({B}, 12, 384, 384)")
    
    # Check statistics
    print(f"\nOutput statistics:")
    print(f"  Min: {output.min().item():.4f}")
    print(f"  Max: {output.max().item():.4f}")
    print(f"  Mean: {output.mean().item():.4f}")
    print(f"  Std: {output.std().item():.4f}")
    
    if output.shape == (B, 12, 384, 384):
        print("\n✅ Output shape correct!")
    else:
        print(f"\n⚠️  Shape mismatch: got {tuple(output.shape)}")
    
except Exception as e:
    print(f"\n❌ Error during forward pass: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "="*70)
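
`torch.no_grad()` only turns off gradient tracking; it does not switch layers such as `BatchNorm` or `Dropout` into inference behaviour. That is what `model.eval()` does. A self-contained check with a single `BatchNorm2d` layer shows that train mode updates the running statistics even under `no_grad`, while eval mode leaves them untouched:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(1)
x = torch.randn(4, 1, 8, 8) * 3 + 5  # batch whose mean is far from 0

# Train mode + no_grad: running statistics are still updated
bn.train()
with torch.no_grad():
    bn(x)
after_train_mode = bn.running_mean.clone()

# Eval mode: running statistics are read, not updated
bn.eval()
with torch.no_grad():
    bn(x)
after_eval_mode = bn.running_mean.clone()

# Both equal 0.1 * batch mean (default momentum 0.1, running_mean starts at 0)
print(after_train_mode, after_eval_mode)
```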

## Step 5: Test Memory Usage

In [None]:
if torch.cuda.is_available():
    print("="*70)
    print("GPU MEMORY USAGE")
    print("="*70)
    
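    # Drop the output from the previous cell so it does not inflate the peak
    output = None
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    
    # Forward pass (inference mode)
    model.eval()
    with torch.no_grad():
        output = model(dummy_input)
    torch.cuda.synchronize(device)
    
    allocated = torch.cuda.memory_allocated(device) / 1e9
    peak = torch.cuda.max_memory_allocated(device) / 1e9
    peak_reserved = torch.cuda.max_memory_reserved(device) / 1e9
    total = torch.cuda.get_device_properties(device).total_memory / 1e9
    
    print(f"\nMemory usage (inference, batch size {B}):")
    print(f"  Current allocated: {allocated:.2f} GB")
    print(f"  Peak allocated: {peak:.2f} GB")
    print(f"  Peak reserved (caching allocator): {peak_reserved:.2f} GB")
    print(f"  Total GPU memory: {total:.2f} GB")
    print(f"  Usage: {peak/total*100:.1f}%")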
    
    if peak < total * 0.8:
        print("\n✅ Memory usage reasonable")
    else:
        print("\n⚠️  High memory usage - may need smaller batch size")
    
    print("\n" + "="*70)
else:
    print("⏭️  Skipping GPU memory check (no GPU available)")
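
A back-of-envelope estimate helps interpret the measured peak. Assuming float32 tensors (4 bytes per element) and the ~5.3M parameter figure quoted in the summary, the inputs, output, and weights together come to roughly 0.09 GB at batch size 2, so most of the measured peak is intermediate activations and temporary buffers. Training additionally keeps activations for backprop, plus gradients and Adam's two moment buffers, so expect noticeably more memory than the inference figure.

```python
def tensor_gb(*shape, bytes_per_elem=4):
    """Size in GB (1e9 bytes) of a dense tensor with the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_elem / 1e9


batch, steps, height, width = 2, 12, 384, 384
n_modalities = 4
approx_params = 5.3e6  # approximate figure from the summary (assumption)

inputs_gb = n_modalities * tensor_gb(batch, steps, height, width)
output_gb = tensor_gb(batch, steps, height, width)
weights_gb = approx_params * 4 / 1e9
# Adam in fp32 holds weights + gradients + two moment buffers = 4 copies
training_state_gb = 4 * weights_gb

print(f"inputs  : {inputs_gb:.4f} GB")   # 0.0566
print(f"output  : {output_gb:.4f} GB")   # 0.0142
print(f"weights : {weights_gb:.4f} GB")  # 0.0212
print(f"weights+grads+Adam: {training_state_gb:.4f} GB")  # 0.0848
```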

## Step 6: Test with Real SEVIR Data (Optional)

**Runs only if the SEVIR catalog from notebook 02 exists on Google Drive (Drive must be mounted)**

In [None]:
from pathlib import Path

DRIVE_ROOT = "/content/drive/MyDrive/SEVIR_Data"
CATALOG_PATH = f"{DRIVE_ROOT}/data/SEVIR_CATALOG.csv"

# Check if data exists
if Path(CATALOG_PATH).exists():
    print("="*70)
    print("TEST: FORWARD PASS WITH REAL SEVIR DATA")
    print("="*70)
    
    try:
        # Import dataset
        from stormfusion.data.sevir_multimodal import SEVIRMultiModalDataset
        import pandas as pd
        
        # Load catalog
        catalog = pd.read_csv(CATALOG_PATH, low_memory=False)
        
        # Get small subset
        train_events = catalog[catalog['split'] == 'train']['id'].unique()[:2]
        
        print(f"\nLoading {len(train_events)} training events...")
        
        # Create dataset
        dataset = SEVIRMultiModalDataset(
            catalog=catalog,
            data_root=f"{DRIVE_ROOT}/data/sevir",
            event_ids=train_events,
            input_steps=12,
            output_steps=12,
            modalities=['vil', 'ir069', 'ir107', 'lght']
        )
        
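        print(f"Dataset size: {len(dataset)}")
        if len(dataset) == 0:
            raise RuntimeError("No samples found for the selected events")
        
        # Load one sample
        print("\nLoading first sample...")
        inputs, targets = dataset[0]
        
        print("\nInput shapes:")
        for mod, data in inputs.items():
            print(f"  {mod:8s}: {tuple(data.shape)}")
            if data.abs().sum() < 0.01:
                print(f"    ⚠️  All zeros (possibly missing modality)")
        
        print(f"\nTarget shape: {tuple(targets.shape)}")
        
        # Add batch dimension and move to device
        batch_inputs = {k: v.unsqueeze(0).to(device) for k, v in inputs.items()}
        
        # Forward pass (inference mode)
        print("\nRunning forward pass with real data...")
        model.eval()
        with torch.no_grad():
            real_output = model(batch_inputs)
        
        print("\n✅ Forward pass with real data successful!")
        print(f"Output shape: {tuple(real_output.shape)}")
        # Compare with the target, ignoring the batch dimension added above
        if tuple(real_output.shape[1:]) == tuple(targets.shape):
            print("✅ Output shape matches target shape")
        else:
            print(f"⚠️  Output {tuple(real_output.shape[1:])} vs target {tuple(targets.shape)}")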
        
        # Statistics
        print(f"\nOutput statistics:")
        print(f"  Min: {real_output.min().item():.4f}")
        print(f"  Max: {real_output.max().item():.4f}")
        print(f"  Mean: {real_output.mean().item():.4f}")
        print(f"  Std: {real_output.std().item():.4f}")
        
    except Exception as e:
        print(f"\n⚠️  Could not test with real data: {e}")
        import traceback
        traceback.print_exc()
        print("\n   This is OK - the model still works with dummy data")
    
    print("\n" + "="*70)
else:
    print("⏭️  Skipping real data test (no SEVIR catalog found)")
    print(f"   Run notebook 02 to download data first")
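
The zero check and the batching step above can be factored into small helpers. This sketch assumes, as the cell does, that the dataset returns a dict of per-modality tensors without a batch dimension; the helper names are hypothetical. An all-zero `lght` tensor may simply mean no lightning was observed in that window rather than a missing file, so treat the warning as a hint, not an error. For real batches, `torch.utils.data.DataLoader` with its default collate function already stacks dicts of tensors key by key.

```python
import torch


def find_missing_modalities(inputs, tol=0.01):
    """Names of modalities whose tensor is (near) all zeros, using the same
    abs-sum threshold as the notebook cell."""
    return [name for name, t in inputs.items() if t.abs().sum().item() < tol]


def to_batch(inputs, device="cpu"):
    """Add a leading batch dimension and move every modality to `device`."""
    return {name: t.unsqueeze(0).to(device) for name, t in inputs.items()}


sample = {
    "vil": torch.rand(12, 8, 8),
    "lght": torch.zeros(12, 8, 8),
}
print(find_missing_modalities(sample))  # ['lght']
print({k: tuple(v.shape) for k, v in to_batch(sample).items()})
```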

## Step 7: Test Gradient Flow

In [None]:
print("="*70)
print("TEST: GRADIENT FLOW (BACKPROPAGATION)")
print("="*70)

# Training mode (re-enables dropout / batch-norm updates after any eval() call)
model.train()

# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

try:
    # Forward pass with autograd enabled
    output = model(dummy_input)
    
    # Create dummy target with the same shape as the prediction
    target = torch.randn_like(output)
    
    # Compute loss
    loss = nn.MSELoss()(output, target)
    print(f"\nLoss value: {loss.item():.6f}")
    
    # Backward pass
    print("\nRunning backward pass...")
    optimizer.zero_grad()
    loss.backward()
    
    # Check gradients
    has_gradients = False
    nan_gradients = False
    no_grad_params = []
    
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            # Not reached by backprop: unused, or only behind non-differentiable ops
            no_grad_params.append(name)
            continue
        has_gradients = True
        if not torch.isfinite(param.grad).all():
            nan_gradients = True
            print(f"  ⚠️  NaN/Inf gradient in: {name}")
    
    if nan_gradients:
        print("\n⚠️  NaN/Inf gradients detected - check the listed layers for numerical issues")
    elif not has_gradients:
        print("\n⚠️  No gradients found")
    elif no_grad_params:
        print(f"\n⚠️  Gradients computed, but {len(no_grad_params)} trainable tensors got none, e.g.:")
        for name in no_grad_params[:10]:
            print(f"    {name}")
    else:
        print("\n✅ Gradients computed successfully!")
        print("   Model is ready for training")
    
except Exception as e:
    print(f"\n❌ Error during forward/backward pass: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "="*70)
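
A parameter can end up with `grad is None` even though the backward pass succeeds, for example when its output reaches the loss only through a non-differentiable step such as `argmax` or index selection. Going from dense feature maps to a graph of storm cells can involve that kind of selection; whether `StormCellDetector` does is not stated here. The toy pipeline below is hypothetical and reproduces the effect:

```python
import torch
import torch.nn as nn


class ToyPipeline(nn.Module):
    """Hypothetical two-stage model. The encoder's output is only used via
    argmax indices, so no gradient can flow back into the encoder."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.head = nn.Linear(8, 1)

    def forward(self, x):
        scores = self.encoder(x)
        idx = scores.argmax(dim=-1, keepdim=True)  # integer indices: no grad
        picked = x.gather(-1, idx)                 # depends on encoder only via idx
        return self.head(x) + picked


def params_without_grad(module):
    """Trainable parameters that received no gradient in the last backward."""
    return [n for n, p in module.named_parameters() if p.requires_grad and p.grad is None]


torch.manual_seed(0)
toy = ToyPipeline()
toy_loss = toy(torch.randn(4, 8)).pow(2).mean()
toy_loss.backward()
print(params_without_grad(toy))  # ['encoder.weight', 'encoder.bias']
```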

## Summary

**What we verified:**
- ✅ Full SGT model can be created (~5.3M parameters)
- ✅ Forward pass works with dummy data
- ✅ Output shapes are correct
- ✅ Inference memory stays below 80% of GPU memory (GPU runs only)
- ✅ Gradient flow works (backpropagation)
- ✅ (Optional) Works with real SEVIR data

**Model specification:**
```
Input:  4 modalities × 12 frames × 384×384
Output: 1 modality × 12 frames × 384×384
Parameters: ~5.3M
Memory: ~X.X GB (inference, batch size 2; see Step 5 output)
```

**Architecture:**
```
MultiModalEncoder (4 modalities → unified)
    ↓
StormCellDetector (spatial → graph)
    ↓
StormGNN (message passing)
    ↓
SpatioTemporalTransformer (attention)
    ↓
PhysicsDecoder (features → predictions)
    ↓
ConservationLoss (physics constraints, applied as a loss term during training)
```

**Next steps:**
1. If all tests passed ✅, proceed to `06_Small_Scale_Training.ipynb`
2. That notebook will train on a small subset (10-20 events)
3. Verify training works before scaling to full dataset
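
Step 6 takes the first two IDs returned by `.unique()`, which depends on catalog row order. For the 10-20 event subset in the next notebook, a seeded sample is one reproducible alternative. `sample_event_ids` is a hypothetical helper, not part of the stormfusion package.

```python
import random


def sample_event_ids(event_ids, n, seed=0):
    """Return n distinct event IDs chosen reproducibly.

    IDs are de-duplicated and sorted first, so the result does not depend on
    the order in which they appear in the catalog.
    """
    ids = sorted(set(event_ids))
    if n > len(ids):
        raise ValueError(f"requested {n} events but only {len(ids)} are available")
    return random.Random(seed).sample(ids, n)


all_ids = [f"S{i:04d}" for i in range(100)]
subset = sample_event_ids(all_ids, 15, seed=42)
print(len(subset), subset[:3])
```

With the catalog from Step 6, this would be called as `sample_event_ids(catalog.loc[catalog['split'] == 'train', 'id'], 15)`.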

---

**If any test failed:**
- Go back to `04_Test_Model_Components.ipynb`
- Identify which component has issues
- Check error messages and tracebacks
- Verify all dependencies are installed