# 07: Full-Scale Training (Paper 1 - Storm-Graph Transformer)

**Purpose:** Train the SGT model on the full SEVIR dataset for Paper 1

**What this does:**
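- Train on the full dataset (all events marked `train` in the catalog)
- Run for many epochs (20-50; the default config uses 30)
- Track training and validation loss and MSE per epoch
- Save checkpoints regularly
- Report MSE, MAE, and RMSE for the best model and plot sample predictions

**What this does NOT do:**
- Compute CSI or compare with baseline methods (both are listed under next steps in the Summary)
- Use Stage 5 techniques (that's Paper 2)
- Run ablation studies (create a separate notebook if needed)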

**Expected time:** Several hours to days (depending on dataset size)
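
The range above can be narrowed once a few batches have been timed. A minimal sketch, assuming per-batch time stays roughly constant and ignoring validation time; `estimate_training_hours` is a hypothetical helper and the numbers are illustrative, not measurements:

```python
import math

def estimate_training_hours(num_samples, batch_size, num_epochs, seconds_per_batch):
    """Rough wall-clock estimate for the training passes only (validation excluded)."""
    batches_per_epoch = math.ceil(num_samples / batch_size)
    total_seconds = batches_per_epoch * num_epochs * seconds_per_batch
    return total_seconds / 3600

# Illustrative numbers only: 10,000 samples, batch size 4, 30 epochs, 0.5 s per batch
hours = estimate_training_hours(10_000, 4, 30, 0.5)
print(f"Estimated training time: {hours:.1f} hours")
```

Timing the first 20 or so batches of Step 5 gives a usable value for `seconds_per_batch`.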

---

**Prerequisites:** 
- Run `01_Setup_and_Environment.ipynb` first
- Run `02_Data_Verification.ipynb` and download the full dataset
- Run `06_Small_Scale_Training.ipynb` to verify training works
- Ensure you have sufficient GPU memory and disk space
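
The disk-space part of the last prerequisite can be checked before starting. A minimal sketch using only the standard library; `free_space_gb`, `has_enough_space`, and the 50 GB figure are hypothetical placeholders, not measured requirements of this notebook:

```python
import shutil

def free_space_gb(path):
    """Return free disk space at `path` in gigabytes (10**9 bytes)."""
    return shutil.disk_usage(path).free / 1e9

def has_enough_space(path, required_gb):
    """True if at least `required_gb` gigabytes are free at `path`."""
    return free_space_gb(path) >= required_gb

# REQUIRED_GB is an illustrative placeholder; adjust to the dataset and checkpoint sizes
REQUIRED_GB = 50
print(f"Free space: {free_space_gb('.'):.1f} GB")
print(f"Enough space for {REQUIRED_GB} GB: {has_enough_space('.', REQUIRED_GB)}")
```

Free GPU memory can be inspected separately with `torch.cuda.mem_get_info()` once PyTorch is imported.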

## Step 1: Setup

In [None]:
from google.colab import drive
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import json
from datetime import datetime

# Mount Drive
print("Mounting Google Drive...")
drive.mount('/content/drive', force_remount=False)
print("✅ Drive mounted\n")

# Clone/update repository
REPO_PATH = '/content/stormfusion-sevir'
if not os.path.exists(REPO_PATH):
    print("Cloning repository...")
    !git clone https://github.com/syedhaliz/stormfusion-sevir.git {REPO_PATH}
    print("✅ Repository cloned\n")
else:
    print("Repository exists, pulling latest changes...")
    !cd {REPO_PATH} && git pull
    print("✅ Repository updated\n")

# Add repository to path
if REPO_PATH not in sys.path:
    sys.path.insert(0, REPO_PATH)
    print(f"✅ Added {REPO_PATH} to Python path\n")

# Force reload of modules to get latest code
import importlib
for module_name in ['stormfusion.models.sgt', 'stormfusion.data.sevir_multimodal']:
    if module_name in sys.modules:
        importlib.reload(sys.modules[module_name])
        
print("✅ Modules reloaded\n")

# Paths
DRIVE_ROOT = "/content/drive/MyDrive/SEVIR_Data"
SEVIR_ROOT = f"{DRIVE_ROOT}/data/sevir"
CATALOG_PATH = f"{DRIVE_ROOT}/data/SEVIR_CATALOG.csv"

# Create experiment directory with timestamp
EXPERIMENT_NAME = f"sgt_full_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
CHECKPOINT_DIR = f"{DRIVE_ROOT}/experiments/{EXPERIMENT_NAME}"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Experiment: {EXPERIMENT_NAME}")
print(f"Checkpoint directory: {CHECKPOINT_DIR}")
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")

## Step 2: Configuration

In [None]:
# Training configuration
config = {
    # Model
    'modalities': ['vil', 'ir069', 'ir107', 'lght'],
    'input_steps': 12,
    'output_steps': 12,
    'base_channels': 64,
    
    # Training
    'num_epochs': 30,
    'batch_size': 4,
    'num_workers': 4,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'grad_clip': 1.0,
    
    # Scheduler
    'scheduler_patience': 3,
    'scheduler_factor': 0.5,
    'min_lr': 1e-7,
    
    # Data
    'train_augment': True,
    'max_train_events': None,  # None = use all
    'max_val_events': 100,     # Limit validation for speed
    
    # Checkpointing
    'save_every': 5,  # Save checkpoint every N epochs
    'keep_last_n': 3,  # Keep only last N checkpoints
}

# Save config
with open(f"{CHECKPOINT_DIR}/config.json", 'w') as f:
    json.dump(config, f, indent=2)

print("Training configuration:")
print(json.dumps(config, indent=2))

## Step 3: Load Full Dataset

In [None]:
from stormfusion.data.sevir_multimodal import SEVIRMultiModalDataset

print("="*70)
print("LOADING FULL DATASET")
print("="*70)

# Load catalog
catalog = pd.read_csv(CATALOG_PATH, low_memory=False)
print(f"\nCatalog loaded: {len(catalog)} entries")

# Get train/val splits
all_train_events = catalog[catalog['split'] == 'train']['id'].unique()
all_val_events = catalog[catalog['split'] == 'val']['id'].unique()

print(f"\nFull dataset:")
print(f"  Total train events: {len(all_train_events)}")
print(f"  Total val events: {len(all_val_events)}")

# Apply limits if specified
if config['max_train_events'] is not None:
    train_events = all_train_events[:config['max_train_events']]
    print(f"  Using train events: {len(train_events)} (limited)")
else:
    train_events = all_train_events
    print(f"  Using train events: {len(train_events)} (all)")

if config['max_val_events'] is not None:
    val_events = all_val_events[:config['max_val_events']]
    print(f"  Using val events: {len(val_events)} (limited)")
else:
    val_events = all_val_events
    print(f"  Using val events: {len(val_events)} (all)")

# Create datasets
print("\nCreating datasets...")
train_dataset = SEVIRMultiModalDataset(
    catalog=catalog,
    data_root=SEVIR_ROOT,
    event_ids=train_events,
    input_steps=config['input_steps'],
    output_steps=config['output_steps'],
    modalities=config['modalities'],
    augment=config['train_augment']
)

val_dataset = SEVIRMultiModalDataset(
    catalog=catalog,
    data_root=SEVIR_ROOT,
    event_ids=val_events,
    input_steps=config['input_steps'],
    output_steps=config['output_steps'],
    modalities=config['modalities'],
    augment=False
)

print(f"\nDataset sizes:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val: {len(val_dataset)} samples")

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=config['num_workers'],
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=config['num_workers'],
    pin_memory=True
)

print(f"\nDataloader info:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Batch size: {config['batch_size']}")
print("\n" + "="*70)
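
Slicing the first `max_val_events` IDs keeps whatever order the catalog happens to have, which may not be representative if the catalog is sorted by time or event type. A minimal sketch of a seeded random subset instead; `sample_events` and the seed value are hypothetical and not part of `stormfusion`:

```python
import random

def sample_events(event_ids, max_events, seed=0):
    """Return a reproducible random subset of at most `max_events` IDs."""
    ids = list(event_ids)
    if max_events is None or max_events >= len(ids):
        return ids
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    return sorted(rng.sample(ids, max_events))

# Illustrative IDs only; real IDs come from the catalog's 'id' column
demo_ids = [f"S{n:06d}" for n in range(1000)]
subset = sample_events(demo_ids, 100, seed=42)
print(len(subset), subset[:3])
```

In this notebook the call would be `sample_events(all_val_events, config['max_val_events'])`, with the seed recorded in `config` so the subset can be reproduced.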

## Step 4: Create Model and Training Components

In [None]:
from stormfusion.models.sgt import create_sgt_model

print("="*70)
print("CREATING MODEL AND TRAINING COMPONENTS")
print("="*70)

# Create model
model = create_sgt_model(
    modalities=config['modalities'],
    input_steps=config['input_steps'],
    output_steps=config['output_steps'],
    base_channels=config['base_channels']
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel created: {total_params:,} parameters ({total_params/1e6:.2f}M)")

# Create optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay']
)

# Learning rate scheduler
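# `verbose` is omitted: it is deprecated in recent PyTorch releases,
# and the current learning rate is already printed each epoch in Step 5
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=config['scheduler_factor'],
    patience=config['scheduler_patience'],
    min_lr=config['min_lr']
)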

# Loss criterion
criterion = nn.MSELoss()

print("\n✅ All components created")
print("\n" + "="*70)

## Step 5: Training Loop

In [None]:
print("="*70)
print("STARTING FULL-SCALE TRAINING")
print("="*70)

best_val_loss = float('inf')
history = {
    'train_loss': [],
    'val_loss': [],
    'train_mse': [],
    'val_mse': [],
    'lr': [],
    'epoch_times': []
}

import time

for epoch in range(config['num_epochs']):
    epoch_start = time.time()
    
    print(f"\n{'='*70}")
    print(f"Epoch {epoch+1}/{config['num_epochs']}")
    print(f"{'='*70}")
    
    # ============= TRAINING =============
    model.train()
    train_loss = 0.0
    train_mse = 0.0
    
    pbar = tqdm(train_loader, desc='Training')
    for batch_idx, (inputs, targets) in enumerate(pbar):
        # Move to device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        targets = targets.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        
        # Compute loss
        mse_loss = criterion(outputs, targets)
        loss = mse_loss
        
        # Add physics loss if available
        if hasattr(model, 'physics_loss'):
            physics_loss = model.physics_loss(outputs, targets)
            loss = loss + physics_loss
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['grad_clip'])
        
        optimizer.step()
        
        # Update metrics
        train_loss += loss.item()
        train_mse += mse_loss.item()
        
        pbar.set_postfix({
            'loss': f'{loss.item():.6f}',
            'mse': f'{mse_loss.item():.6f}'
        })
    
    avg_train_loss = train_loss / len(train_loader)
    avg_train_mse = train_mse / len(train_loader)
    
    # ============= VALIDATION =============
    model.eval()
    val_loss = 0.0
    val_mse = 0.0
    
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc='Validation'):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            targets = targets.to(device)
            
            outputs = model(inputs)
            
            mse_loss = criterion(outputs, targets)
            loss = mse_loss
            
            if hasattr(model, 'physics_loss'):
                physics_loss = model.physics_loss(outputs, targets)
                loss = loss + physics_loss
            
            val_loss += loss.item()
            val_mse += mse_loss.item()
    
    avg_val_loss = val_loss / len(val_loader)
    avg_val_mse = val_mse / len(val_loader)
    
    # Update learning rate
    scheduler.step(avg_val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    
    # Compute epoch time
    epoch_time = time.time() - epoch_start
    
    # Save history
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['train_mse'].append(avg_train_mse)
    history['val_mse'].append(avg_val_mse)
    history['lr'].append(current_lr)
    history['epoch_times'].append(epoch_time)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {avg_train_loss:.6f}")
    print(f"  Val Loss:   {avg_val_loss:.6f}")
    print(f"  Train MSE:  {avg_train_mse:.6f}")
    print(f"  Val MSE:    {avg_val_mse:.6f}")
    print(f"  LR:         {current_lr:.2e}")
    print(f"  Time:       {epoch_time/60:.1f} min")
    
    # Save checkpoint
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': config,
        'history': history,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
    }
    
    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(checkpoint, f"{CHECKPOINT_DIR}/best_model.pt")
        print(f"  ✅ Saved best model (val_loss: {best_val_loss:.6f})")
    
    # Save latest checkpoint
    torch.save(checkpoint, f"{CHECKPOINT_DIR}/latest_checkpoint.pt")
    
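    # Save periodic checkpoint
    if (epoch + 1) % config['save_every'] == 0:
        torch.save(checkpoint, f"{CHECKPOINT_DIR}/checkpoint_epoch_{epoch+1}.pt")
        print(f"  💾 Saved checkpoint at epoch {epoch+1}")
        
        # Enforce keep_last_n: remove the oldest periodic checkpoints
        # (best_model.pt and latest_checkpoint.pt are never touched)
        prefix, suffix = 'checkpoint_epoch_', '.pt'
        periodic = sorted(
            (f for f in os.listdir(CHECKPOINT_DIR)
             if f.startswith(prefix) and f.endswith(suffix)),
            key=lambda f: int(f[len(prefix):-len(suffix)])
        )
        for old in periodic[:max(0, len(periodic) - config['keep_last_n'])]:
            os.remove(os.path.join(CHECKPOINT_DIR, old))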
    
    # Save history
    with open(f"{CHECKPOINT_DIR}/history.json", 'w') as f:
        json.dump(history, f, indent=2)

print("\n" + "="*70)
print("TRAINING COMPLETE")
print("="*70)
print(f"\nBest validation loss: {best_val_loss:.6f}")
print(f"Total training time: {sum(history['epoch_times'])/3600:.2f} hours")
print(f"\nAll results saved to: {CHECKPOINT_DIR}")
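
Long Colab runs are often interrupted, and the checkpoint dictionary saved each epoch already holds what is needed to continue. A minimal sketch of recovering the loop state from it; `resume_state` is a hypothetical helper, and the commented lines show how it might be wired to `latest_checkpoint.pt`, assuming the Step 5 loop is changed to start at `start_epoch`:

```python
def resume_state(checkpoint):
    """Derive loop state from a checkpoint dict with the keys saved in Step 5."""
    history = checkpoint['history']
    start_epoch = checkpoint['epoch'] + 1      # 'epoch' is the 0-based epoch just finished
    best_val_loss = min(history['val_loss'])   # best validation loss seen so far
    return start_epoch, best_val_loss, history

# Wiring sketch (needs torch and the objects created in Steps 1-4):
# checkpoint = torch.load(f"{CHECKPOINT_DIR}/latest_checkpoint.pt", map_location=device)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# start_epoch, best_val_loss, history = resume_state(checkpoint)
# for epoch in range(start_epoch, config['num_epochs']): ...

# Stand-in checkpoint with made-up losses, for illustration only
demo = {'epoch': 4, 'history': {'val_loss': [0.9, 0.7, 0.8, 0.6, 0.65]}}
start_epoch, best_val_loss, _ = resume_state(demo)
print(start_epoch, best_val_loss)
```

Because Step 1 builds a new timestamped `EXPERIMENT_NAME` on every run, `CHECKPOINT_DIR` would have to be pointed at the earlier experiment directory before resuming.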

## Step 6: Plot Training Curves

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Loss curve
axes[0, 0].plot(history['train_loss'], label='Train Loss', marker='o', markersize=4)
axes[0, 0].plot(history['val_loss'], label='Val Loss', marker='s', markersize=4)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# MSE curve
axes[0, 1].plot(history['train_mse'], label='Train MSE', marker='o', markersize=4)
axes[0, 1].plot(history['val_mse'], label='Val MSE', marker='s', markersize=4)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('MSE')
axes[0, 1].set_title('Mean Squared Error')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Learning rate
axes[1, 0].plot(history['lr'], marker='o', color='green', markersize=4)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Learning Rate')
axes[1, 0].set_title('Learning Rate Schedule')
axes[1, 0].set_yscale('log')
axes[1, 0].grid(True, alpha=0.3)

# Epoch time
epoch_times_min = [t/60 for t in history['epoch_times']]
axes[1, 1].plot(epoch_times_min, marker='o', color='purple', markersize=4)
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Time (minutes)')
axes[1, 1].set_title('Training Time per Epoch')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{CHECKPOINT_DIR}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✅ Training curves saved to: {CHECKPOINT_DIR}/training_curves.png")

## Step 7: Evaluate Best Model

In [None]:
print("="*70)
print("EVALUATING BEST MODEL")
print("="*70)

# Load best model
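# map_location lets the checkpoint load even if the device differs from the training run
checkpoint = torch.load(f"{CHECKPOINT_DIR}/best_model.pt", map_location=device)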
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"\nLoaded best model from epoch {checkpoint['epoch']+1}")
print(f"Validation loss: {checkpoint['val_loss']:.6f}")

# Evaluate on validation set
print("\nComputing metrics on validation set...")

all_predictions = []
all_targets = []

with torch.no_grad():
    for inputs, targets in tqdm(val_loader, desc='Evaluating'):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        targets = targets.to(device)
        
        outputs = model(inputs)
        
        all_predictions.append(outputs.cpu())
        all_targets.append(targets.cpu())

# Concatenate
all_predictions = torch.cat(all_predictions, dim=0)
all_targets = torch.cat(all_targets, dim=0)

print(f"\nEvaluation set: {all_predictions.shape[0]} samples")

# Compute metrics
mse = nn.MSELoss()(all_predictions, all_targets).item()
mae = nn.L1Loss()(all_predictions, all_targets).item()

print(f"\nMetrics:")
print(f"  MSE: {mse:.6f}")
print(f"  MAE: {mae:.6f}")
print(f"  RMSE: {np.sqrt(mse):.6f}")

# Save metrics
metrics = {
    'mse': mse,
    'mae': mae,
    'rmse': np.sqrt(mse),
    'best_epoch': checkpoint['epoch'] + 1,
    'best_val_loss': checkpoint['val_loss'],
}

with open(f"{CHECKPOINT_DIR}/metrics.json", 'w') as f:
    json.dump(metrics, f, indent=2)

print(f"\n✅ Metrics saved to: {CHECKPOINT_DIR}/metrics.json")
print("\n" + "="*70)
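
A single MSE averages over all forecast frames, while nowcasting error typically grows with lead time. A minimal sketch of a per-frame breakdown, assuming predictions and targets are shaped `(N, T, H, W)` as the Step 8 plotting code suggests; `per_lead_time_mse` is a hypothetical helper and the arrays below are random stand-ins:

```python
import numpy as np

def per_lead_time_mse(predictions, targets):
    """Mean squared error per forecast frame for arrays shaped (N, T, H, W)."""
    predictions = np.asarray(predictions, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    if predictions.shape != targets.shape:
        raise ValueError("predictions and targets must have the same shape")
    # Average over samples and both spatial axes, keeping the time axis
    return ((predictions - targets) ** 2).mean(axis=(0, 2, 3))

# Random stand-in data: 8 samples, 12 frames, 16x16 pixels
rng = np.random.default_rng(0)
demo_targets = rng.random((8, 12, 16, 16))
demo_predictions = demo_targets + 0.1
frame_mse = per_lead_time_mse(demo_predictions, demo_targets)
print(frame_mse.shape)
```

With the tensors from this step, the call would be `per_lead_time_mse(all_predictions.numpy(), all_targets.numpy())`.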

## Step 8: Generate Sample Predictions

In [None]:
print("="*70)
print("GENERATING SAMPLE PREDICTIONS")
print("="*70)

# Select 5 random validation samples
num_samples = min(5, len(val_dataset))
sample_indices = np.random.choice(len(val_dataset), num_samples, replace=False)

for idx, sample_idx in enumerate(sample_indices):
    print(f"\nGenerating prediction {idx+1}/{num_samples}...")
    
    inputs, targets = val_dataset[sample_idx]
    
    # Add batch dimension
    inputs_batch = {k: v.unsqueeze(0).to(device) for k, v in inputs.items()}
    
    # Generate prediction
    with torch.no_grad():
        predictions = model(inputs_batch)
    
    # Move to CPU
    predictions = predictions.cpu().squeeze(0).numpy()
    targets = targets.numpy()
    
    # Visualize three frames: beginning, middle, end (derived from config, not hardcoded)
    last_step = config['output_steps'] - 1
    frames_to_plot = [0, last_step // 2, last_step]
    
    fig, axes = plt.subplots(len(frames_to_plot), 3, figsize=(18, 5*len(frames_to_plot)))
    
    for row_idx, frame_idx in enumerate(frames_to_plot):
        # Absolute time index of the forecast frame (inputs come first)
        lead_idx = frame_idx + config['input_steps']
        
        # Input (VIL)
        vil_input = inputs['vil'][frame_idx].numpy()
        im0 = axes[row_idx, 0].imshow(vil_input, cmap='turbo', vmin=0, vmax=255)
        axes[row_idx, 0].set_title(f'Input VIL (t={frame_idx})')
        axes[row_idx, 0].axis('off')
        plt.colorbar(im0, ax=axes[row_idx, 0], fraction=0.046, pad=0.04)
        
        # Prediction
        im1 = axes[row_idx, 1].imshow(predictions[frame_idx], cmap='turbo', vmin=0, vmax=255)
        axes[row_idx, 1].set_title(f'Prediction (t={lead_idx})')
        axes[row_idx, 1].axis('off')
        plt.colorbar(im1, ax=axes[row_idx, 1], fraction=0.046, pad=0.04)
        
        # Ground truth
        im2 = axes[row_idx, 2].imshow(targets[frame_idx], cmap='turbo', vmin=0, vmax=255)
        axes[row_idx, 2].set_title(f'Ground Truth (t={lead_idx})')
        axes[row_idx, 2].axis('off')
        plt.colorbar(im2, ax=axes[row_idx, 2], fraction=0.046, pad=0.04)
    
    plt.tight_layout()
    plt.savefig(f"{CHECKPOINT_DIR}/sample_prediction_{idx+1}.png", dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"  ✅ Saved to: {CHECKPOINT_DIR}/sample_prediction_{idx+1}.png")

print("\n" + "="*70)

## Summary

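**If every cell above ran without errors, training is complete.**

**Final results** (printed at the end of Step 5 and in Step 7; markdown cells do not fill these in automatically):
- Best validation loss: `best_val_loss`
- Best epoch: `checkpoint['epoch'] + 1`
- Total training time: `sum(history['epoch_times']) / 3600` hours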

**Files saved:**
```
{CHECKPOINT_DIR}/
├── config.json                 # Training configuration
├── history.json                # Loss/metric history
├── metrics.json                # Final evaluation metrics
├── best_model.pt               # Best model checkpoint
├── latest_checkpoint.pt        # Latest checkpoint
├── checkpoint_epoch_*.pt       # Periodic checkpoints
├── training_curves.png         # Training visualizations
└── sample_prediction_*.png     # Sample predictions
```

**Next steps for Paper 1:**

1. **Compute CSI metrics:** Critical Success Index at thresholds [16, 74, 133, 160, 181, 219]
2. **Run ablation studies:** Test without GNN, Transformer, Physics loss
3. **Compare with baselines:** ConvLSTM, U-Net, Persistence
4. **Generate paper figures:** Create visualizations for publication
5. **Write results section:** Document findings in paper
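
Steps 1 and 3 can be prototyped before the full comparison is set up. A minimal sketch of CSI (hits divided by hits plus misses plus false alarms) and a persistence baseline; `csi` and `persistence_forecast` are hypothetical helpers, and two choices are assumptions rather than the paper's method: a pixel counts as an event when its value is `>=` the threshold, and predictions are on the same 0-255 VIL scale as the thresholds:

```python
import numpy as np

def csi(predictions, targets, threshold):
    """Critical Success Index: hits / (hits + misses + false alarms)."""
    pred_event = np.asarray(predictions) >= threshold
    true_event = np.asarray(targets) >= threshold
    hits = np.sum(pred_event & true_event)
    misses = np.sum(~pred_event & true_event)
    false_alarms = np.sum(pred_event & ~true_event)
    denom = hits + misses + false_alarms
    # No observed or predicted events at this threshold: CSI is undefined
    return float('nan') if denom == 0 else float(hits / denom)

def persistence_forecast(vil_inputs, output_steps):
    """Repeat the last input frame for every forecast step; input shaped (T, H, W)."""
    last_frame = np.asarray(vil_inputs)[-1]
    return np.repeat(last_frame[None], output_steps, axis=0)

THRESHOLDS = [16, 74, 133, 160, 181, 219]  # from the list above; assumes a 0-255 VIL scale

# Tiny made-up example, not real data
target = np.array([[0, 100], [200, 250]])
pred = np.array([[20, 100], [50, 250]])
for t in THRESHOLDS:
    print(t, csi(pred, target, t))
```

If the dataset normalizes VIL before training, the thresholds would need the same scaling before they are applied to model outputs.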

---

**For Paper 2 (Stage 5):**
- Use this trained model as baseline
- Add perceptual loss (LPIPS)
- Add GAN discriminator
- That will be a separate notebook series