# 📝 Paper 1: Storm-Graph Transformer (SGT)

**Complete Standalone Training Notebook**

**Title:** "Physics-Informed Graph Neural Networks with Transformers for Severe Weather Nowcasting"

**Core Innovation:**
- Hybrid GNN-Transformer-Physics architecture
- Treats storms as discrete graph nodes (not continuous fields)
- Physics-constrained predictions (conservation laws)
- Interpretable attention mechanisms
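
To make the "storms as discrete graph nodes" idea concrete, here is a minimal pure-Python sketch. It is not the SGT implementation: the function names, the 4-connected flood fill, the threshold of 74, and the distance-based edge rule are all illustrative assumptions. It shows one plausible way to turn a single VIL frame into nodes (one per storm cell) and edges (between nearby cells).

```python
from collections import deque
from math import hypot


def extract_storm_nodes(grid, threshold):
    """Label 4-connected regions with value >= threshold; return one node per region."""
    rows, cols = len(grid), len(grid[0])
    seen = [[False] * cols for _ in range(rows)]
    nodes = []
    for r0 in range(rows):
        for c0 in range(cols):
            if seen[r0][c0] or grid[r0][c0] < threshold:
                continue
            # Breadth-first flood fill over one connected region.
            queue = deque([(r0, c0)])
            seen[r0][c0] = True
            pixels = []
            while queue:
                r, c = queue.popleft()
                pixels.append((r, c))
                for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    nr, nc = r + dr, c + dc
                    inside = 0 <= nr < rows and 0 <= nc < cols
                    if inside and not seen[nr][nc] and grid[nr][nc] >= threshold:
                        seen[nr][nc] = True
                        queue.append((nr, nc))
            area = len(pixels)
            nodes.append({
                "centroid": (sum(p[0] for p in pixels) / area,
                             sum(p[1] for p in pixels) / area),
                "area": area,
                "mean_intensity": sum(grid[r][c] for r, c in pixels) / area,
            })
    return nodes


def build_edges(nodes, max_distance):
    """Connect every pair of nodes whose centroids lie within max_distance pixels."""
    edges = []
    for i in range(len(nodes)):
        for j in range(i + 1, len(nodes)):
            (r1, c1), (r2, c2) = nodes[i]["centroid"], nodes[j]["centroid"]
            if hypot(r1 - r2, c1 - c2) <= max_distance:
                edges.append((i, j))
    return edges


# Toy 5x6 "VIL frame" with two separate cells (values are made up).
toy_vil = [
    [0, 0, 0, 0, 0, 0],
    [0, 80, 90, 0, 0, 0],
    [0, 85, 0, 0, 0, 0],
    [0, 0, 0, 0, 100, 0],
    [0, 0, 0, 0, 120, 0],
]
nodes = extract_storm_nodes(toy_vil, threshold=74)  # threshold is an assumption
edges = build_edges(nodes, max_distance=4.0)        # so is the edge radius
print(len(nodes), edges)
```

A real pipeline would more likely use `scipy.ndimage.label` on full-resolution frames and richer node features; the flood fill here only keeps the example dependency-free.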

**Timeline:** Week 1-3 (Oct 10-31)
**Target:** ArXiv + NeurIPS workshop

---

## 1️⃣ Setup & Installation

In [1]:
# Mount Google Drive
from google.colab import drive
import os

drive.mount('/content/drive')

DRIVE_ROOT = "/content/drive/MyDrive/SEVIR_Data"
print(f"✓ Drive mounted: {DRIVE_ROOT}")

Mounted at /content/drive
✓ Drive mounted: /content/drive/MyDrive/SEVIR_Data


In [2]:
# Check GPU
!nvidia-smi

import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

Sun Oct 12 19:27:54 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   39C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
# Install dependencies
print("Installing dependencies (5-10 min on first run)...")
!pip install -q torch-geometric h5py pandas tqdm matplotlib lpips scikit-image scipy
print("✓ All dependencies installed")

Installing dependencies (5-10 min on first run)...
✓ All dependencies installed


In [4]:
# Clone/pull repo
if not os.path.exists('/content/stormfusion-sevir'):
    !git clone https://github.com/syedhaliz/stormfusion-sevir.git
else:
    !cd /content/stormfusion-sevir && git pull

import sys
sys.path.insert(0, '/content/stormfusion-sevir')
print("✓ Repo ready")

Cloning into 'stormfusion-sevir'...
remote: Enumerating objects: 284, done.
remote: Counting objects: 100% (284/284), done.
remote: Compressing objects: 100% (208/208), done.
remote: Total 284 (delta 121), reused 221 (delta 65), pack-reused 0 (from 0)
Receiving objects: 100% (284/284), 5.88 MiB | 10.72 MiB/s, done.
Resolving deltas: 100% (121/121), done.
✓ Repo ready


## 2️⃣ Data Download & Verification

**IMPORTANT:** You need ~50 GB of SEVIR data for training.

In [5]:
# Check existing data
from pathlib import Path

SEVIR_ROOT = f"{DRIVE_ROOT}/data/sevir"
CATALOG_PATH = f"{DRIVE_ROOT}/data/SEVIR_CATALOG.csv"

print("="*70)
print("SEVIR DATA CHECK")
print("="*70)

modalities = ['vil', 'ir069', 'ir107', 'lght']
data_complete = True

for mod in modalities:
    mod_path = Path(SEVIR_ROOT) / mod / "2019"
    if mod_path.exists():
        h5_files = list(mod_path.glob("*.h5"))
        total_gb = sum(f.stat().st_size for f in h5_files) / 1e9
        status = "✅" if len(h5_files) >= 100 else "⚠️"
        print(f"{status} {mod:8s}: {len(h5_files):3d} files ({total_gb:.1f} GB)")
        if len(h5_files) < 100:
            data_complete = False
    else:
        print(f"❌ {mod:8s}: MISSING")
        data_complete = False

print("="*70)
if data_complete:
    print("✅ Data looks complete!")
else:
    print("⚠️  Incomplete data - run next cell to download")

SEVIR DATA CHECK
⚠️ vil     :   1 files (3.9 GB)
⚠️ ir069   :   1 files (2.5 GB)
⚠️ ir107   :   1 files (2.5 GB)
⚠️ lght    :  11 files (1.0 GB)
⚠️  Incomplete data - run next cell to download


In [8]:
# Download SEVIR data if incomplete
# Set DOWNLOAD = True to enable download
DOWNLOAD = True  # ⚠️ SET TO True TO DOWNLOAD

if DOWNLOAD and not data_complete:
    print("="*70)
    print("DOWNLOADING SEVIR DATA FROM AWS S3")
    print("="*70)
    print("\nThis will download ~50 GB and take 30-90 minutes")
    print("Modalities: VIL, IR069, IR107, Lightning")
    print("\nStarting download...\n")

    !pip install -q awscli

    for mod in ['vil', 'ir069', 'ir107', 'lght']:
        print(f"\n{'='*70}")
        print(f"Downloading {mod.upper()}")
        print(f"{'='*70}")

        target_dir = f"{SEVIR_ROOT}/{mod}/2019"
        !mkdir -p {target_dir}
        !aws s3 sync s3://sevir/data/{mod}/2019/ {target_dir} --no-sign-request --region us-east-1

    print("\n" + "="*70)
    print("✅ DOWNLOAD COMPLETE!")
    print("="*70)

elif not DOWNLOAD:
    print("⏭️  Download skipped (set DOWNLOAD=True to enable)")
    if not data_complete:
        print("\n⚠️  WARNING: Training with incomplete data will use zeros for missing modalities")
else:
    print("✅ Data already complete")

DOWNLOADING SEVIR DATA FROM AWS S3

This will download ~50 GB and take 30-90 minutes
Modalities: VIL, IR069, IR107, Lightning

Starting download...


Downloading VIL

Downloading IR069

Downloading IR107

Downloading LGHT

✅ DOWNLOAD COMPLETE!


## 3️⃣ Load Data

In [9]:
# Import dataset
from stormfusion.data.sevir_multimodal import (
    SEVIRMultiModalDataset,
    build_multimodal_index,
    multimodal_collate_fn
)

print("✓ Dataset imported")

✓ Dataset imported


In [10]:
# Build dataset index
TRAIN_IDS = f"{DRIVE_ROOT}/data/samples/all_train_ids.txt"
VAL_IDS = f"{DRIVE_ROOT}/data/samples/all_val_ids.txt"

print("Building dataset index...")
train_index = build_multimodal_index(CATALOG_PATH, TRAIN_IDS, SEVIR_ROOT)
val_index = build_multimodal_index(CATALOG_PATH, VAL_IDS, SEVIR_ROOT)

print(f"\n📊 Dataset:")
print(f"  Train: {len(train_index)} events")
print(f"  Val: {len(val_index)} events")

Streaming output truncated to the last 5000 lines.
✓ Built multimodal index: 6613 events
✓ Built multimodal index: 1547 events

📊 Dataset:
  Train: 6613 events
  Val: 1547 events


In [11]:
# Create datasets
print("Creating datasets...")

train_dataset = SEVIRMultiModalDataset(
    train_index,
    sevir_root=SEVIR_ROOT,
    catalog_path=CATALOG_PATH,
    input_steps=12,
    output_steps=6,
    normalize=True,
    augment=True
)

val_dataset = SEVIRMultiModalDataset(
    val_index,
    sevir_root=SEVIR_ROOT,
    catalog_path=CATALOG_PATH,
    input_steps=12,
    output_steps=6,
    normalize=True,
    augment=False
)

print("✓ Datasets created")

# Test loading
print("\nTesting data loading...")
inputs, outputs = train_dataset[0]
print("\nInput shapes:")
for mod, data in inputs.items():
    print(f"  {mod:8s}: {tuple(data.shape)}")
print("\nOutput shape:")
print(f"  vil     : {tuple(outputs['vil'].shape)}")
print("\n✅ Data loading successful!")

Creating datasets...
✓ Datasets created

Testing data loading...


KeyError: "Unable to synchronously open object (object 'lght' doesn't exist)"
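
The `KeyError` above says the opened HDF5 file has no top-level object named `lght`. One possible explanation, stated here as an assumption because the loader's internals are not shown, is that SEVIR lightning files store one dataset per event ID instead of a single shared `lght` dataset like the image modalities. The sketch below shows a lookup that tries the shared-dataset layout first and falls back to a per-event key. `read_event_array` and the event ID `S001` are hypothetical names, and plain dicts stand in for open `h5py.File` objects, which support the same `in`, `[]`, and `.keys()` operations.

```python
def read_event_array(store, modality, event_id, row=None):
    """Return the data for one event from an h5py.File-like mapping.

    Tries store[modality][row] (shared dataset, indexed by row), then falls
    back to store[event_id] (one dataset per event).
    """
    if modality in store and row is not None:
        return store[modality][row]
    if event_id in store:
        # With a real h5py.File this is a Dataset; append [:] to read it into memory.
        return store[event_id]
    available = list(store.keys())[:5]
    raise KeyError(
        f"neither {modality!r} nor {event_id!r} found; first keys: {available}"
    )


# Stand-ins for two open files (contents are placeholders, not real SEVIR data).
image_file = {"vil": ["frame0", "frame1"]}
lght_file = {"S001": "flashes"}

print(read_event_array(image_file, "vil", "S001", row=1))
print(read_event_array(lght_file, "lght", "S001", row=0))
```

Listing `list(hf.keys())[:5]` on the actual lightning file is a quick way to confirm which layout applies before changing the dataset class.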

## 4️⃣ Create Model

In [None]:
# Import model
from stormfusion.models.sgt import create_sgt_model

print("Creating Storm-Graph Transformer...")

model_config = {
    'modalities': ['vil', 'ir069', 'ir107', 'lght'],
    'input_steps': 12,
    'output_steps': 6,
    'hidden_dim': 128,
    'gnn_layers': 3,
    'transformer_layers': 4,
    'num_heads': 8,
    'use_physics': True
}

model = create_sgt_model(model_config)

# Model info
total_params = sum(p.numel() for p in model.parameters())
print(f"\n✅ Model Created!")
print(f"   Parameters: {total_params:,}")
print(f"   Size: ~{total_params * 4 / 1e6:.1f} MB")

# Move to GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(f"   Device: {device}")

In [None]:
# Test forward pass
print("Testing forward pass...")

inputs, targets = train_dataset[0]
inputs_batch = {mod: data.unsqueeze(0).to(device) for mod, data in inputs.items()}
targets_batch = targets['vil'].unsqueeze(0).to(device)

with torch.no_grad():
    predictions, attention_info, physics_info = model(inputs_batch)
    loss, loss_dict = model.compute_loss(predictions, targets_batch, physics_info)

print(f"\n✅ Forward pass successful!")
print(f"   Predictions: {tuple(predictions.shape)}")
print(f"   Loss: {loss.item():.4f}")
print("\n🎉 Model ready for training!")

## 5️⃣ Training Setup

In [None]:
# Training configuration
from torch.utils.data import DataLoader
from tqdm import tqdm
import time

# Hyperparameters
BATCH_SIZE = 4
LR = 1e-4
EPOCHS = 20
LAMBDA_MSE = 1.0
LAMBDA_PHYSICS = 0.1
LAMBDA_EXTREME = 2.0

CHECKPOINT_DIR = f"{DRIVE_ROOT}/checkpoints/paper1_sgt"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print("Training Configuration:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LR}")
print(f"  Epochs: {EPOCHS}")
print(f"  Loss weights: MSE={LAMBDA_MSE}, Physics={LAMBDA_PHYSICS}, Extreme={LAMBDA_EXTREME}")
print(f"  Checkpoints: {CHECKPOINT_DIR}")
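
The three weights above are passed to `model.compute_loss` later in the notebook. Its body is not shown here, so the following is only the most natural reading of the parameter names, namely a weighted sum of three terms:

$$
\mathcal{L}_{\text{total}} = \lambda_{\text{mse}}\,\mathcal{L}_{\text{mse}} + \lambda_{\text{physics}}\,\mathcal{L}_{\text{physics}} + \lambda_{\text{extreme}}\,\mathcal{L}_{\text{extreme}}
$$

Under that assumption, the configured values give $\lambda_{\text{mse}} = 1.0$, $\lambda_{\text{physics}} = 0.1$, and $\lambda_{\text{extreme}} = 2.0$, so errors on extreme values count twice as much as the plain MSE term and the physics penalty acts as a light regularizer.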

In [None]:
# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    collate_fn=multimodal_collate_fn,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    collate_fn=multimodal_collate_fn,
    pin_memory=True
)

print(f"✅ DataLoaders created:")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")

In [None]:
# Optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-5)
# `verbose` is deprecated in recent PyTorch releases, so it is omitted here
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

print("✅ Optimizer: AdamW")
print("✅ Scheduler: ReduceLROnPlateau")
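
To see what `factor=0.5, patience=3` means in practice, the sketch below re-implements the plateau rule in plain Python. It is a simplified model of `ReduceLROnPlateau` under its default settings as I understand them (relative threshold of `1e-4`, no cooldown, no minimum learning rate), not the library code itself, and the validation losses are made up.

```python
def simulate_plateau_schedule(val_losses, lr=1e-4, factor=0.5, patience=3,
                              threshold=1e-4):
    """Return the learning rate in effect after each epoch."""
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        # 'min' mode with a relative threshold: the loss must beat the best
        # value by more than 0.01% to count as an improvement.
        if loss < best * (1.0 - threshold):
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # The rate is cut only once the bad-epoch count EXCEEDS patience.
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs


lrs = simulate_plateau_schedule([1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8])
print(lrs)
```

With `patience=3`, the rate is halved on the fourth consecutive epoch without improvement, so a 20-epoch run can absorb only a handful of reductions.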

## 6️⃣ Training Loop

In [None]:
# Training and validation functions
def train_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0
    loss_components = {'mse': 0, 'extreme': 0, 'physics': 0}

    pbar = tqdm(loader, desc="Training")
    for inputs, targets in pbar:
        inputs = {mod: data.to(device) for mod, data in inputs.items()}
        targets = targets['vil'].to(device)

        optimizer.zero_grad()
        predictions, attention_info, physics_info = model(inputs)

        loss, loss_dict = model.compute_loss(
            predictions, targets, physics_info,
            lambda_mse=LAMBDA_MSE,
            lambda_physics=LAMBDA_PHYSICS,
            lambda_extreme=LAMBDA_EXTREME
        )

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        for key in loss_components:
            if key in loss_dict:
                # float() accepts Python numbers and 0-d tensors, so no graph is retained
                loss_components[key] += float(loss_dict[key])

        pbar.set_postfix({'loss': f"{loss.item():.4f}"})

    n = len(loader)
    return total_loss / n, {k: v / n for k, v in loss_components.items()}


def validate(model, loader, device):
    model.eval()
    total_loss = 0
    loss_components = {'mse': 0, 'extreme': 0, 'physics': 0}

    with torch.no_grad():
        pbar = tqdm(loader, desc="Validation")
        for inputs, targets in pbar:
            inputs = {mod: data.to(device) for mod, data in inputs.items()}
            targets = targets['vil'].to(device)

            predictions, attention_info, physics_info = model(inputs)
            loss, loss_dict = model.compute_loss(
                predictions, targets, physics_info,
                lambda_mse=LAMBDA_MSE,
                lambda_physics=LAMBDA_PHYSICS,
                lambda_extreme=LAMBDA_EXTREME
            )

            total_loss += loss.item()
            for key in loss_components:
                if key in loss_dict:
                    # float() accepts Python numbers and 0-d tensors
                    loss_components[key] += float(loss_dict[key])

    n = len(loader)
    return total_loss / n, {k: v / n for k, v in loss_components.items()}


print("✅ Training functions defined")
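
`train_epoch` clips gradients with `clip_grad_norm_(..., max_norm=1.0)`. The sketch below reproduces the arithmetic of global-norm clipping on plain Python lists, assuming the usual formulation (scale by `max_norm / (total_norm + eps)`, capped at 1). The function name and the example gradients are illustrative.

```python
from math import sqrt


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients by one shared factor so their joint L2 norm is <= max_norm."""
    total_norm = sqrt(sum(g * g for vec in grads for g in vec))
    # Small gradients pass through unchanged because the factor is capped at 1.
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm


# Two "parameters" whose gradients have a joint norm of 5.
clipped, norm_before = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
norm_after = sqrt(sum(g * g for vec in clipped for g in vec))
print(norm_before, norm_after)
```

Because every gradient is multiplied by the same factor, the update direction is preserved and only its length is limited.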

In [None]:
# Main training loop
print("="*70)
print("STARTING TRAINING")
print("="*70)

best_val_loss = float('inf')
history = {'train_loss': [], 'val_loss': []}

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 70)

    start_time = time.time()

    # Train
    train_loss, train_components = train_epoch(model, train_loader, optimizer, device)

    # Validate
    val_loss, val_components = validate(model, val_loader, device)

    # Scheduler
    scheduler.step(val_loss)

    epoch_time = time.time() - start_time

    # Log
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f} (MSE: {train_components['mse']:.4f}, "
          f"Extreme: {train_components['extreme']:.4f}, Physics: {train_components['physics']:.4f})")
    print(f"  Val Loss:   {val_loss:.4f} (MSE: {val_components['mse']:.4f}, "
          f"Extreme: {val_components['extreme']:.4f}, Physics: {val_components['physics']:.4f})")
    print(f"  Time: {epoch_time:.1f}s")

    # Save history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    # Save checkpoints
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_loss,
            'config': model_config
        }, f"{CHECKPOINT_DIR}/best_model.pt")
        print(f"  ✅ Saved best model (val_loss: {val_loss:.4f})")

    # Save latest
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': val_loss,
        'config': model_config,
        'history': history
    }, f"{CHECKPOINT_DIR}/latest_model.pt")

print("\n" + "="*70)
print("✅ TRAINING COMPLETE!")
print("="*70)
print(f"Best validation loss: {best_val_loss:.4f}")

## 7️⃣ Visualize Results

In [None]:
# Plot training curves
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress')
plt.legend()
plt.grid(True)
plt.savefig(f"{CHECKPOINT_DIR}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Saved training curves to {CHECKPOINT_DIR}/training_curves.png")

---

## 📚 Next Steps

1. **Evaluate on test set** - compute CSI metrics at VIP thresholds
2. **Visualize predictions** - compare pred vs ground truth
3. **Analyze attention weights** - see which storms matter
4. **Run ablations** - w/o GNN, w/o Transformer, w/o Physics
5. **Baseline comparisons** - vs UNet, ConvLSTM, Persistence
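
For step 1, the Critical Success Index (CSI) at a threshold is hits / (hits + misses + false alarms). The sketch below is a minimal pure-Python version over flattened pixel values. The threshold list is the set commonly quoted for SEVIR VIL on the 0-255 scale and is an assumption here, as are the toy values. Since the datasets above use `normalize=True`, predictions would need to be mapped back to that scale first.

```python
VIP_THRESHOLDS = (16, 74, 133, 160, 181, 219)  # assumed; VIL on a 0-255 scale


def csi(pred, target, threshold):
    """Critical Success Index for one threshold over paired pixel values."""
    hits = misses = false_alarms = 0
    for p, t in zip(pred, target):
        p_event, t_event = p >= threshold, t >= threshold
        if p_event and t_event:
            hits += 1
        elif t_event:
            misses += 1
        elif p_event:
            false_alarms += 1
    denom = hits + misses + false_alarms
    # No predicted or observed events at this threshold: CSI is undefined.
    return hits / denom if denom else float("nan")


# Toy flattened prediction and ground truth (made-up values).
pred = [0, 80, 140, 200, 10]
target = [0, 90, 100, 220, 80]
scores = {thr: csi(pred, target, thr) for thr in VIP_THRESHOLDS}
print(scores)
```

The same function can score the persistence baseline from step 5 by passing the last input frame, repeated for each lead time, as `pred`.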

**Checkpoints saved to:** `CHECKPOINT_DIR` (`/content/drive/MyDrive/SEVIR_Data/checkpoints/paper1_sgt/`)

**Architecture docs:** `docs/PAPER1_ARCHITECTURE.md`

**Progress report:** `docs/PROGRESS_REPORT.md`