# 04: Test Model Components (Individually)

**Purpose:** Test each SGT module in isolation before integration

**What this does:**
- Test each of the 7 SGT modules independently
- Verify input/output shapes
- Check for tensor errors
- Ensure each component works before assembly

**What this does NOT do:**
- Load real SEVIR data (uses dummy tensors)
- Train the model
- Test the integrated model (that's notebook 05)

**Expected time:** 5-10 minutes

---

## ⚠️ IMPORTANT: How to Run This Notebook

**YOU MUST run cells IN ORDER:**
1. **FIRST:** Run "Step 1: Setup" cell below (mounts Drive, clones repo)
2. **THEN:** Run the other cells in sequence

**If you skip Step 1, you'll get `ModuleNotFoundError`!**

---

**Prerequisites:** 
- Run `01_Setup_and_Environment.ipynb` first
- Repository should be cloned and in Python path
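
A quick way to confirm the second prerequisite is to ask Python whether the package can be found before importing anything from it. The sketch below uses only the standard library; `is_importable` is a hypothetical helper name, and `stormfusion` is the package name used by the imports later in this notebook.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if `module_name` can be found on the current Python path."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package of a dotted name is missing or malformed.
        return False


# `stormfusion` is the package name used by the imports in this notebook.
if is_importable("stormfusion"):
    print("stormfusion is importable")
else:
    print("stormfusion not found: run the Step 1 setup cell first")
```

If the check fails, running the Step 1 cell (which clones the repository and adds it to `sys.path`) should resolve it.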

## Step 1: Setup

In [None]:
from google.colab import drive
import sys
import os
import torch
import torch.nn as nn

# Mount Drive
print("Mounting Google Drive...")
drive.mount('/content/drive', force_remount=False)
print("✅ Drive mounted\n")

# Install dependencies (each Colab session needs this!)
print("Installing dependencies...")
!pip install -q torch-geometric h5py pandas tqdm matplotlib lpips scikit-image scipy
print("✅ Dependencies installed\n")

# Clone/update repository
REPO_PATH = '/content/stormfusion-sevir'
if not os.path.exists(REPO_PATH):
    print("Cloning repository...")
    !git clone https://github.com/syedhaliz/stormfusion-sevir.git {REPO_PATH}
    print("✅ Repository cloned\n")
else:
    print("Repository exists, pulling latest changes...")
    !cd {REPO_PATH} && git pull
    print("✅ Repository updated\n")

# Add repo to path
if REPO_PATH not in sys.path:
    sys.path.insert(0, REPO_PATH)
    print(f"✅ Added {REPO_PATH} to Python path\n")

# Force reload of modules to get latest code
import importlib
for module_name in ['stormfusion.models.sgt.encoder', 'stormfusion.models.sgt.decoder', 
                    'stormfusion.models.sgt.detector', 'stormfusion.models.sgt.gnn',
                    'stormfusion.models.sgt.transformer', 'stormfusion.models.sgt.physics_loss']:
    if module_name in sys.modules:
        importlib.reload(sys.modules[module_name])
        
print("✅ Modules reloaded\n")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")

## Step 2: Test MultiModalEncoder

**What it does:** Encodes 4 input modalities into a unified representation

**Expected:**
- Input: `{vil: [B,12,384,384], ir069: [B,12,384,384], ir107: [B,12,384,384], lght: [B,12,384,384]}`
- Output: `[B, 128, 96, 96]` (encoded spatial features: `hidden_dim=128` channels at 1/4 of the input resolution)
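
The test cell for this step compares the encoder output against `(B, 128, 96, 96)`, that is, `hidden_dim` channels at one quarter of the 384x384 input resolution. The arithmetic can be sketched as follows; `expected_encoder_shape` is a hypothetical helper, and the downsampling factor of 4 is taken from the `384/4=96` comment in that cell rather than from the encoder source.

```python
def expected_encoder_shape(batch_size, hidden_dim, height, width, downsample=4):
    """Shape of the encoder output, assuming an integer spatial downsampling factor."""
    if height % downsample or width % downsample:
        raise ValueError("spatial size must be divisible by the downsampling factor")
    return (batch_size, hidden_dim, height // downsample, width // downsample)


print(expected_encoder_shape(2, 128, 384, 384))  # (2, 128, 96, 96)
```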

In [None]:
# Ensure repo is in path (in case you skipped setup cell)
import sys
REPO_PATH = '/content/stormfusion-sevir'
if REPO_PATH not in sys.path:
    print("⚠️  WARNING: Run Step 1 (Setup) cell first!")
    print("   Adding repo to path now, but you should run the setup cell.\n")
    sys.path.insert(0, REPO_PATH)

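from stormfusion.models.sgt.encoder import MultiModalEncoder
import torch

# `device` is normally defined in Step 1; define it here too in case that cell was skipped
if 'device' not in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')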

print("="*70)
print("TEST 1: MultiModalEncoder")
print("="*70)

# Create encoder (correct parameters: hidden_dim not base_channels!)
encoder = MultiModalEncoder(
    modalities=['vil', 'ir069', 'ir107', 'lght'],
    input_steps=12,
    hidden_dim=128  # ← Correct parameter name
).to(device)

# Count parameters
encoder_params = sum(p.numel() for p in encoder.parameters())
print(f"\nEncoder parameters: {encoder_params:,}")

# Create dummy input (batch_size=2)
B = 2
dummy_input = {
    # Allocate directly on the target device instead of creating on CPU and copying
    'vil': torch.randn(B, 12, 384, 384, device=device),
    'ir069': torch.randn(B, 12, 384, 384, device=device),
    'ir107': torch.randn(B, 12, 384, 384, device=device),
    'lght': torch.randn(B, 12, 384, 384, device=device)
}

print(f"\nInput shapes:")
for mod, tensor in dummy_input.items():
    print(f"  {mod:8s}: {tuple(tensor.shape)}")

# Forward pass
try:
    with torch.no_grad():
        encoded = encoder(dummy_input)
    
    expected_shape = (B, 128, 96, 96)  # hidden_dim=128, 384/4=96
    print(f"\nOutput shape: {tuple(encoded.shape)}")
    print(f"Expected: {expected_shape}")
    
    if tuple(encoded.shape) == expected_shape:
        print("\n✅ Encoder works correctly!")
    else:
        print(f"\n⚠️  Shape mismatch: got {tuple(encoded.shape)}")
        
except Exception as e:
    print(f"\n❌ Error: {e}")
    import traceback
    traceback.print_exc()

## Step 3: Test StormCellDetector

**What it does:** Detects storm cells and converts them to graph nodes

**Expected:**
- Input: `[B, 128, 96, 96]` (encoded features from Step 2) plus the raw VIL input `[B, 12, 384, 384]`
- Output: Graph with detected nodes (variable number per sample)
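
The detector in the cell below is configured with `min_intensity`, `min_distance`, and `max_storms`. One plausible reading of those parameters is greedy peak picking: keep the strongest pixels above a threshold and drop any candidate too close to a peak already accepted. The plain-Python sketch below illustrates that idea only; `detect_peaks` is a hypothetical function and is not the repository's `StormCellDetector` implementation, which may work differently.

```python
from math import hypot


def detect_peaks(grid, min_intensity=0.3, min_distance=8, max_storms=50):
    """Greedy peak picking on a 2D grid (list of rows).

    Candidates are visited from strongest to weakest; a candidate is kept only
    if it lies at least `min_distance` pixels from every peak kept so far.
    """
    candidates = [
        (value, row, col)
        for row, line in enumerate(grid)
        for col, value in enumerate(line)
        if value >= min_intensity
    ]
    # Strongest first; ties broken by position so the result is deterministic.
    candidates.sort(key=lambda c: (-c[0], c[1], c[2]))

    peaks = []
    for _, row, col in candidates:
        if len(peaks) >= max_storms:
            break
        if all(hypot(row - r, col - c) >= min_distance for r, c in peaks):
            peaks.append((row, col))
    return peaks


# Toy 20x20 field with two well-separated cells, one near-duplicate, one weak pixel.
field = [[0.0] * 20 for _ in range(20)]
field[2][2] = 0.9    # strong cell
field[3][3] = 0.8    # too close to (2, 2), suppressed
field[15][15] = 0.5  # second cell
field[10][10] = 0.2  # below min_intensity, ignored
print(detect_peaks(field))
```

Because the number of accepted peaks depends on the input, the node count varies per sample, which is why the detector output is described as a variable number of nodes.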

In [None]:
from stormfusion.models.sgt.detector import StormCellDetector

print("="*70)
print("TEST 2: StormCellDetector")
print("="*70)

# Create detector (correct parameters!)
detector = StormCellDetector(
    feature_dim=128,  # ← Correct parameter name
    min_intensity=0.3,
    min_distance=8,
    max_storms=50
).to(device)

detector_params = sum(p.numel() for p in detector.parameters())
print(f"\nDetector parameters: {detector_params:,}")

# Detector needs BOTH encoded features AND vil_input
print(f"\nInput shapes:")
print(f"  Encoded features: {tuple(encoded.shape)}")
print(f"  VIL input: {tuple(dummy_input['vil'].shape)}")

try:
    with torch.no_grad():
        # Detector needs: forward(features, vil_input)
        node_features, node_positions, batch_idx = detector(encoded, dummy_input['vil'])
    
    print(f"\nDetector output:")
    print(f"  Number of batches: {len(node_features)}")
    # Report every sample instead of hard-coding batch indices 0 and 1
    for i, feats in enumerate(node_features):
        print(f"  Batch {i} - nodes: {feats.shape[0]}, features: {feats.shape[1]}")
    print(f"  Batch index shape: {batch_idx.shape}")
    
    print("\n✅ Detector works correctly!")
    
except Exception as e:
    print(f"\n❌ Error: {e}")
    import traceback
    traceback.print_exc()

## Step 4: Test StormGNN

**What it does:** Updates node features via message passing

**Expected:**
- Input: Graph with node features `[N, 128]`
- Output: Graph with updated features `[N, 128]`
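
The cell below builds the graph with `StormGraphBuilder(k_neighbors=8)`, which suggests that each storm cell is linked to its nearest neighbours. A minimal sketch of that idea in plain Python follows, assuming Euclidean distance between node positions and a `[sources, targets]` edge layout like PyTorch Geometric's `edge_index`. `knn_edges` is a hypothetical helper; the builder's actual rule is not shown in this notebook and may differ.

```python
from math import dist


def knn_edges(positions, k):
    """Return [sources, targets] where each node receives edges from its k nearest neighbours."""
    sources, targets = [], []
    for i, point in enumerate(positions):
        # Rank all other nodes by distance to node i (index breaks ties).
        others = sorted(
            (j for j in range(len(positions)) if j != i),
            key=lambda j: (dist(point, positions[j]), j),
        )
        for j in others[:k]:
            sources.append(j)  # message flows from neighbour j ...
            targets.append(i)  # ... to node i
    return [sources, targets]


# Three storm cells on a line; with k=1 each node gets exactly one incoming edge.
edge_index = knn_edges([(0, 0), (1, 0), (10, 0)], k=1)
print(edge_index)  # [[1, 0, 1], [0, 1, 2]]
```

With `N` nodes and `k` smaller than `N`, this produces `N * k` directed edges, which is a useful sanity check against the edge count printed by the test.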

In [None]:
from stormfusion.models.sgt.gnn import StormGNN, StormGraphBuilder

print("="*70)
print("TEST 3: StormGNN")
print("="*70)

# Create graph builder
graph_builder = StormGraphBuilder(k_neighbors=8)

# Build graph from detector output
graph_data = graph_builder.build_graph(node_positions, node_features, batch_idx)

print(f"\nGraph constructed:")
print(f"  Nodes: {graph_data.x.shape[0]}")
print(f"  Features per node: {graph_data.x.shape[1]}")
print(f"  Edges: {graph_data.edge_index.shape[1]}")

# Create GNN
gnn = StormGNN(
    hidden_dim=128,
    num_layers=3,
    num_heads=4,
    dropout=0.1
).to(device)

gnn_params = sum(p.numel() for p in gnn.parameters())
print(f"\nGNN parameters: {gnn_params:,}")

try:
    with torch.no_grad():
        # Move graph to device
        graph_data = graph_data.to(device)
        
        # Apply GNN
        updated_features, attention_weights = gnn(graph_data)
    
    print(f"\nGNN output:")
    print(f"  Updated features shape: {tuple(updated_features.shape)}")
    print(f"  Attention layers: {len(attention_weights)}")
    
    # Should preserve shape
    if updated_features.shape == graph_data.x.shape:
        print("\n✅ GNN works correctly!")
    else:
        print(f"\n⚠️  Shape changed unexpectedly")
    
except Exception as e:
    print(f"\n❌ Error: {e}")
    import traceback
    traceback.print_exc()

## Step 5: Test SpatioTemporalTransformer

**What it does:** Applies attention over spatial patches and temporal steps

**Expected:**
- Input: encoded features from Step 2 (`[B, 128, 96, 96]` with the settings above) + graph context from Step 4
- Output: same shape as the input (attended features)
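
The transformer below is created with `patch_size=4`. Assuming non-overlapping square patches (an assumption about the transformer's internals, not something this notebook confirms), the number of spatial tokens follows directly from the feature-map size. `num_patches` is a hypothetical helper for that arithmetic.

```python
def num_patches(height, width, patch_size):
    """Number of non-overlapping patch tokens for a feature map of the given size."""
    if height % patch_size or width % patch_size:
        raise ValueError("feature map size must be divisible by patch_size")
    return (height // patch_size) * (width // patch_size)


print(num_patches(96, 96, 4))  # 576 tokens for a 96x96 feature map
print(num_patches(24, 24, 4))  # 36 tokens for a 24x24 feature map
```

Since self-attention cost grows with the square of the token count, this number is worth checking before increasing the input resolution.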

In [None]:
from stormfusion.models.sgt.transformer import SpatioTemporalTransformer

print("="*70)
print("TEST 4: SpatioTemporalTransformer")
print("="*70)

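# Create transformer
# NOTE: dim=256 here, while the encoder in Step 2 is built with hidden_dim=128.
# If the forward pass below fails with a channel mismatch, check that `dim`
# matches the channel count of `encoded` (encoded.shape[1]).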
transformer = SpatioTemporalTransformer(
    dim=256,
    depth=4,
    heads=8,
    mlp_dim=1024,
    patch_size=4
).to(device)

transformer_params = sum(p.numel() for p in transformer.parameters())
print(f"\nTransformer parameters: {transformer_params:,}")

# Use encoded features
print(f"\nInput shape: {tuple(encoded.shape)}")

try:
    with torch.no_grad():
        # Transformer takes spatial features + optional graph context.
        # Use the GNN output from Step 4 if that cell ran; otherwise pass None.
        graph_context = globals().get('updated_features')
        attended = transformer(encoded, graph_features=graph_context)
    
    print(f"Output shape: {tuple(attended.shape)}")
    print(f"Expected: {tuple(encoded.shape)}")
    
    if attended.shape == encoded.shape:
        print("\n✅ Transformer works correctly!")
    else:
        print(f"\n⚠️  Shape mismatch")
    
except Exception as e:
    print(f"\n❌ Error: {e}")
    import traceback
    traceback.print_exc()

## Step 6: Test PhysicsDecoder

**What it does:** Decodes features into future predictions with physics constraints

**Expected:**
- Input: attended features from Step 5 (same shape as the encoder output)
- Output: `[B, 12, 384, 384]` (predicted future frames)

In [None]:
from stormfusion.models.sgt.decoder import PhysicsDecoder

print("="*70)
print("TEST 5: PhysicsDecoder")
print("="*70)

# Create decoder
decoder = PhysicsDecoder(
    in_channels=256,
    base_channels=64,
    output_steps=12
).to(device)

decoder_params = sum(p.numel() for p in decoder.parameters())
print(f"\nDecoder parameters: {decoder_params:,}")

# Use attended features
print(f"\nInput shape: {tuple(attended.shape)}")

try:
    with torch.no_grad():
        predictions = decoder(attended)
    
    print(f"Output shape: {tuple(predictions.shape)}")
    print(f"Expected: ({B}, 12, 384, 384)")
    
    # Check statistics
    print(f"\nOutput statistics:")
    print(f"  Min: {predictions.min().item():.4f}")
    print(f"  Max: {predictions.max().item():.4f}")
    print(f"  Mean: {predictions.mean().item():.4f}")
    print(f"  Std: {predictions.std().item():.4f}")
    
    if predictions.shape == (B, 12, 384, 384):
        print("\n✅ Decoder works correctly!")
    else:
        print(f"\n⚠️  Shape mismatch")
    
except Exception as e:
    print(f"\n❌ Error: {e}")
    import traceback
    traceback.print_exc()

## Step 7: Test ConservationLoss

**What it does:** Physics-informed loss for conservation of mass/energy

**Expected:**
- Input: predictions `[B, 12, 384, 384]`, targets `[B, 12, 384, 384]`
- Output: scalar loss value
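
To make "conservation" concrete, one simple form of such a penalty compares the total intensity of each predicted frame with that of the matching target frame. The plain-Python sketch below is an illustration under that assumption only; `conservation_penalty` and `frame_total` are hypothetical names, and the repository's `ConservationLoss` may define the constraint differently.

```python
def frame_total(frame):
    """Sum of all pixel values in one 2D frame (list of rows)."""
    return sum(sum(row) for row in frame)


def conservation_penalty(pred_frames, target_frames, weight=0.1):
    """Weighted mean absolute difference between per-frame totals.

    Returns a non-negative scalar; zero when every predicted frame has the
    same total intensity as its target frame.
    """
    if len(pred_frames) != len(target_frames) or not pred_frames:
        raise ValueError("need the same, non-zero number of frames")
    diffs = [
        abs(frame_total(p) - frame_total(t))
        for p, t in zip(pred_frames, target_frames)
    ]
    return weight * sum(diffs) / len(diffs)


pred = [[[1, 2], [3, 4]]]    # one 2x2 frame, total 10
target = [[[1, 2], [3, 2]]]  # one 2x2 frame, total 8
print(conservation_penalty(pred, target))
```

A penalty of this kind is a scalar and never negative, which matches the two properties the test cell checks.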

In [None]:
from stormfusion.models.sgt.physics_loss import ConservationLoss

print("="*70)
print("TEST 6: ConservationLoss")
print("="*70)

# Create physics loss
physics_loss = ConservationLoss(weight=0.1).to(device)

print(f"\nLoss weight: {physics_loss.weight}")

# Create dummy target
dummy_target = torch.randn(B, 12, 384, 384).to(device)

print(f"\nPredictions shape: {tuple(predictions.shape)}")
print(f"Targets shape: {tuple(dummy_target.shape)}")

try:
    with torch.no_grad():
        loss = physics_loss(predictions, dummy_target)
    
    print(f"\nLoss value: {loss.item():.6f}")
    print(f"Loss shape: {loss.shape}")
    
    if loss.dim() == 0 and loss.item() >= 0:
        print("\n✅ Physics loss works correctly!")
    else:
        print(f"\n⚠️  Unexpected loss value or shape")
    
except Exception as e:
    print(f"\n❌ Error: {e}")
    import traceback
    traceback.print_exc()

## Summary

**What we tested:**
- ✅ MultiModalEncoder: 4 modalities → unified representation
- ✅ StormCellDetector: spatial features → graph nodes
- ✅ StormGNN: message passing on storm cell graph
- ✅ SpatioTemporalTransformer: attention over patches
- ✅ PhysicsDecoder: features → future predictions
- ✅ ConservationLoss: physics-informed constraints

**Parameter counts** (placeholders; fill in from the counts printed by each test above):
- Encoder: ~XXX K
- Detector: ~XXX K
- GNN: ~XXX K
- Transformer: ~XXX M (largest component)
- Decoder: ~XXX K
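
The counts printed by each test are raw integers. A small formatter such as the sketch below can turn them into the K/M style used in this list; `format_count` is a hypothetical helper, not part of the repository.

```python
def format_count(n: int) -> str:
    """Format a parameter count as a short human-readable string."""
    if n >= 1_000_000:
        return f"{n / 1_000_000:.2f} M"
    if n >= 1_000:
        return f"{n / 1_000:.1f} K"
    return str(n)


print(format_count(1_234_567))  # 1.23 M
print(format_count(45_600))     # 45.6 K
print(format_count(999))        # 999
```

In the notebook it could be applied to the variables already computed in the tests, for example `format_count(encoder_params)`.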

**Next steps:**
1. If all tests passed ✅, proceed to `05_Test_Full_Model.ipynb`
2. That notebook will test the integrated SGT model
3. Then we'll move to small-scale training

---

**If any test failed:**
- Check the error traceback
- Verify shapes match expected
- Check for missing dependencies
- Report issues with specific error messages