# 06: Small-Scale Training Test

**Purpose:** Test the training loop with a tiny dataset (10-20 events)

**What this does:**
- Train SGT model on small subset
- Verify training loop works
- Check that the model can overfit (it should be able to memorize the small dataset)
- Test validation and checkpointing
- Measure training speed

**What this does NOT do:**
- Train on full dataset (that's notebook 07)
- Run for many epochs
- Compute final metrics

**Expected time:** 10-20 minutes

---

**Prerequisites:** 
- Run `01_Setup_and_Environment.ipynb` first
- Run `02_Data_Verification.ipynb` to check/download data
- Run `05_Test_Full_Model.ipynb` to verify model works

## Step 1: Setup

In [None]:
from google.colab import drive
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Mount Drive
print("Mounting Google Drive...")
drive.mount('/content/drive', force_remount=False)
print("✅ Drive mounted\n")

# Clone/update repository
REPO_PATH = '/content/stormfusion-sevir'
if not os.path.exists(REPO_PATH):
    print("Cloning repository...")
    !git clone https://github.com/syedhaliz/stormfusion-sevir.git {REPO_PATH}
    print("✅ Repository cloned\n")
else:
    print("Repository exists, pulling latest changes...")
    !cd {REPO_PATH} && git pull
    print("✅ Repository updated\n")

# Add repository to path
if REPO_PATH not in sys.path:
    sys.path.insert(0, REPO_PATH)
    print(f"✅ Added {REPO_PATH} to Python path\n")

# Force reload of modules to get latest code
import importlib
for module_name in ['stormfusion.models.sgt', 'stormfusion.data.sevir_multimodal']:
    if module_name in sys.modules:
        importlib.reload(sys.modules[module_name])
        
print("✅ Modules reloaded\n")

# Paths
DRIVE_ROOT = "/content/drive/MyDrive/SEVIR_Data"
SEVIR_ROOT = f"{DRIVE_ROOT}/data/sevir"
CATALOG_PATH = f"{DRIVE_ROOT}/data/SEVIR_CATALOG.csv"
CHECKPOINT_DIR = f"{DRIVE_ROOT}/checkpoints/small_scale_test"

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")

## Step 2: Load Small Dataset

In [None]:
from stormfusion.data.sevir_multimodal import SEVIRMultiModalDataset

print("="*70)
print("LOADING SMALL DATASET")
print("="*70)

# Load catalog
catalog = pd.read_csv(CATALOG_PATH, low_memory=False)
print(f"\nCatalog loaded: {len(catalog)} entries")

# Get small subset: 16 train, 4 val
train_events = catalog[catalog['split'] == 'train']['id'].unique()[:16]
val_events = catalog[catalog['split'] == 'val']['id'].unique()[:4]

print(f"\nSmall dataset:")
print(f"  Train events: {len(train_events)}")
print(f"  Val events: {len(val_events)}")

# Create datasets
train_dataset = SEVIRMultiModalDataset(
    catalog=catalog,
    data_root=SEVIR_ROOT,
    event_ids=train_events,
    input_steps=12,
    output_steps=12,
    modalities=['vil', 'ir069', 'ir107', 'lght'],
    augment=True
)

val_dataset = SEVIRMultiModalDataset(
    catalog=catalog,
    data_root=SEVIR_ROOT,
    event_ids=val_events,
    input_steps=12,
    output_steps=12,
    modalities=['vil', 'ir069', 'ir107', 'lght'],
    augment=False
)

print(f"\nDataset sizes:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val: {len(val_dataset)} samples")

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"\nDataloader batches:")
print(f"  Train: {len(train_loader)} batches")
print(f"  Val: {len(val_loader)} batches")
print("\n" + "="*70)
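
Before training, it can help to confirm that no modality in a batch is entirely zero, which is how the Troubleshooting notes describe missing data. A minimal sketch, assuming each batch's inputs are a dict mapping modality name to a tensor, as the training loop in Step 4 expects. `find_empty_modalities` is a hypothetical helper, not part of `stormfusion`, and the tensor shapes are placeholders.

```python
import torch


def find_empty_modalities(inputs):
    """Return the names of modalities whose tensors contain only zeros."""
    return [name for name, tensor in inputs.items()
            if torch.count_nonzero(tensor).item() == 0]


# Synthetic batch for illustration; shapes are placeholders, not SEVIR's.
fake_inputs = {
    'vil': torch.rand(2, 12, 8, 8) + 0.1,
    'lght': torch.zeros(2, 12, 8, 8),
}
print(find_empty_modalities(fake_inputs))
```

With the real loader, `inputs, _ = next(iter(train_loader))` followed by `find_empty_modalities(inputs)` would check a single batch.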

## Step 3: Create Model and Optimizer

In [None]:
from stormfusion.models.sgt import create_sgt_model

print("="*70)
print("CREATING MODEL")
print("="*70)

# Create model
model = create_sgt_model(
    modalities=['vil', 'ir069', 'ir107', 'lght'],
    input_steps=12,
    output_steps=12,
    base_channels=64
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,} ({total_params/1e6:.2f}M)")

# Create optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5
)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2  # verbose= omitted (deprecated in recent PyTorch); the LR is printed each epoch
)

# Loss criterion
criterion = nn.MSELoss()

print("\n✅ Model, optimizer, and scheduler created")
print("\n" + "="*70)
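
With `patience=2`, `ReduceLROnPlateau` lowers the learning rate only after more than two consecutive epochs without improvement, so in a 5-epoch run it can fire at most once. The sketch below illustrates this by feeding the scheduler a constant loss. The single dummy parameter stands in for the model, and the `demo_` names are introduced here only for illustration.

```python
import torch

demo_param = torch.nn.Parameter(torch.zeros(1))
demo_optimizer = torch.optim.AdamW([demo_param], lr=1e-4)
demo_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    demo_optimizer, mode='min', factor=0.5, patience=2
)

lrs = []
for epoch in range(5):
    # A validation loss that never improves after the first epoch.
    demo_scheduler.step(1.0)
    lrs.append(demo_optimizer.param_groups[0]['lr'])

print(lrs)
```

The learning rate stays at 1e-4 for the first three epochs and is halved at the fourth, so a flat learning-rate curve in Step 5 is normal for a run this short.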

## Step 4: Training Loop

In [None]:
print("="*70)
print("STARTING SMALL-SCALE TRAINING")
print("="*70)

num_epochs = 5
best_val_loss = float('inf')
history = {'train_loss': [], 'val_loss': [], 'lr': []}

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 50)
    
    # ============= TRAINING =============
    model.train()
    train_loss = 0.0
    
    pbar = tqdm(train_loader, desc='Training')
    for batch_idx, (inputs, targets) in enumerate(pbar):
        # Move to device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        targets = targets.to(device)
        
        # Zero gradients, then forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        
        # Compute loss
        loss = criterion(outputs, targets)
        
        # Add physics loss if available
        if hasattr(model, 'physics_loss'):
            physics_loss = model.physics_loss(outputs, targets)
            loss = loss + physics_loss
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        # Update metrics
        train_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.6f}'})
    
    avg_train_loss = train_loss / len(train_loader)
    
    # ============= VALIDATION =============
    model.eval()
    val_loss = 0.0
    
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc='Validation'):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            targets = targets.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            if hasattr(model, 'physics_loss'):
                physics_loss = model.physics_loss(outputs, targets)
                loss = loss + physics_loss
            
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_loader)
    
    # Update learning rate
    scheduler.step(avg_val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    
    # Save history
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['lr'].append(current_lr)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {avg_train_loss:.6f}")
    print(f"  Val Loss:   {avg_val_loss:.6f}")
    print(f"  LR:         {current_lr:.2e}")
    
    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
        }
        torch.save(checkpoint, f"{CHECKPOINT_DIR}/best_model.pt")
        print(f"  ✅ Saved best model (val_loss: {best_val_loss:.6f})")

print("\n" + "="*70)
print("TRAINING COMPLETE")
print("="*70)
print(f"\nBest validation loss: {best_val_loss:.6f}")
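
The overview lists measuring training speed, but the loop above does not record timings. A minimal sketch of one way to do it, using only the standard library. `ThroughputTimer` and its method names are hypothetical, not part of `stormfusion`. On GPU, calling `torch.cuda.synchronize()` before reading the clock would give more accurate numbers, since CUDA kernels run asynchronously.

```python
import time


class ThroughputTimer:
    """Accumulates wall-clock time and sample counts across batches."""

    def __init__(self):
        self.elapsed = 0.0
        self.samples = 0
        self._start = None

    def start(self):
        self._start = time.perf_counter()

    def stop(self, batch_size):
        if self._start is None:
            raise RuntimeError("stop() called before start()")
        self.elapsed += time.perf_counter() - self._start
        self.samples += batch_size
        self._start = None

    @property
    def samples_per_second(self):
        return self.samples / self.elapsed if self.elapsed > 0 else 0.0


timer = ThroughputTimer()
for _ in range(3):
    timer.start()
    time.sleep(0.01)  # stand-in for forward pass, backward pass, optimizer step
    timer.stop(batch_size=2)

print(f"{timer.samples} samples in {timer.elapsed:.3f}s "
      f"({timer.samples_per_second:.1f} samples/s)")
```

In the training loop, `timer.start()` would go at the top of each batch iteration and `timer.stop(batch_size=targets.size(0))` after `optimizer.step()`.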

## Step 5: Plot Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curve
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Learning rate
axes[1].plot(history['lr'], marker='o', color='green')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Learning Rate')
axes[1].set_title('Learning Rate Schedule')
axes[1].set_yscale('log')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{CHECKPOINT_DIR}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print("\n✅ Training curves saved to:", f"{CHECKPOINT_DIR}/training_curves.png")

## Step 6: Test Model Predictions

In [None]:
print("="*70)
print("TESTING MODEL PREDICTIONS")
print("="*70)

# Load best model
checkpoint = torch.load(f"{CHECKPOINT_DIR}/best_model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"\nLoaded best model from epoch {checkpoint['epoch']+1}")
print(f"Validation loss: {checkpoint['val_loss']:.6f}")

# Get a validation sample
inputs, targets = val_dataset[0]

# Add batch dimension
inputs_batch = {k: v.unsqueeze(0).to(device) for k, v in inputs.items()}
targets_batch = targets.unsqueeze(0).to(device)

# Generate prediction
with torch.no_grad():
    predictions = model(inputs_batch)

# Move to CPU for visualization
predictions = predictions.cpu().squeeze(0).numpy()
targets = targets.numpy()

print(f"\nPrediction shape: {predictions.shape}")
print(f"Target shape: {targets.shape}")

# Visualize middle frame (frame 6)
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

frame_idx = 6

# Input (VIL)
vil_input = inputs['vil'][frame_idx].numpy()
im0 = axes[0].imshow(vil_input, cmap='turbo', vmin=0, vmax=255)
axes[0].set_title(f'Input VIL (t={frame_idx})')
axes[0].axis('off')
plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)

# Prediction
im1 = axes[1].imshow(predictions[frame_idx], cmap='turbo', vmin=0, vmax=255)
axes[1].set_title(f'Prediction (t={frame_idx+12})')
axes[1].axis('off')
plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

# Ground truth
im2 = axes[2].imshow(targets[frame_idx], cmap='turbo', vmin=0, vmax=255)
axes[2].set_title(f'Ground Truth (t={frame_idx+12})')
axes[2].axis('off')
plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig(f"{CHECKPOINT_DIR}/sample_prediction.png", dpi=150, bbox_inches='tight')
plt.show()

print("\n✅ Sample prediction saved to:", f"{CHECKPOINT_DIR}/sample_prediction.png")
print("\n" + "="*70)
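
The checkpoint saved in Step 4 also stores optimizer and scheduler state, which this notebook does not reload. A minimal sketch of a save/restore round trip, assuming the goal is to resume training later. The tiny `nn.Linear` model and the temporary directory are placeholders for the SGT model and `CHECKPOINT_DIR`, and the `demo_`/`restored_` names are introduced here.

```python
import os
import tempfile

import torch
import torch.nn as nn

demo_model = nn.Linear(4, 2)
demo_optimizer = torch.optim.AdamW(demo_model.parameters(), lr=1e-4)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "best_model.pt")
    torch.save({
        'epoch': 3,
        'model_state_dict': demo_model.state_dict(),
        'optimizer_state_dict': demo_optimizer.state_dict(),
    }, path)

    # Fresh objects, as you would have after restarting the runtime.
    restored_model = nn.Linear(4, 2)
    restored_optimizer = torch.optim.AdamW(restored_model.parameters(), lr=1e-3)

    demo_checkpoint = torch.load(path, map_location='cpu')
    restored_model.load_state_dict(demo_checkpoint['model_state_dict'])
    restored_optimizer.load_state_dict(demo_checkpoint['optimizer_state_dict'])
    start_epoch = demo_checkpoint['epoch'] + 1  # resume from the next epoch

print(start_epoch, restored_optimizer.param_groups[0]['lr'])
```

Loading the optimizer state restores the saved learning rate, so the `lr=1e-3` passed to the fresh optimizer is overwritten.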

## Summary

**What we verified:**
- ✅ Training loop works end-to-end
- ✅ Model can learn from small dataset
- ✅ Validation and checkpointing work
- ✅ Loss decreases over epochs
- ✅ Model generates reasonable predictions

**Training configuration:**
```
Dataset: 16 train events, 4 val events
Batch size: 2
Epochs: 5
Optimizer: AdamW (lr=1e-4)
Scheduler: ReduceLROnPlateau
```

**Expected behavior:**
- Training loss should decrease consistently
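- Validation loss should decrease at first; with only 16 training events it may plateau or rise later as the model memorizes the training set
- If validation loss rises from the first epoch, check whether the dataset has missing modalities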
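
These expectations can be checked against the `history` dict built in Step 4. A rough sketch: `summarize_history` is a hypothetical helper, and the numbers below are made up for illustration, not results from this notebook.

```python
def summarize_history(history):
    """Report simple trends from per-epoch train and validation losses."""
    train, val = history['train_loss'], history['val_loss']
    return {
        'train_decreased': train[-1] < train[0],
        'val_decreased': val[-1] < val[0],
        'best_val_epoch': val.index(min(val)) + 1,  # 1-based epoch number
        'final_gap': val[-1] - train[-1],  # a large positive gap suggests memorization
    }


# Illustrative values only, not output from a real run.
example_history = {
    'train_loss': [0.50, 0.30, 0.20, 0.12, 0.08],
    'val_loss': [0.52, 0.40, 0.38, 0.39, 0.41],
}
print(summarize_history(example_history))
```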

**Next steps:**
1. If training worked ✅, proceed to `07_Full_Training.ipynb`
2. That notebook will train on the full dataset
3. You'll need complete SEVIR data (from notebook 02)
4. Full training takes several hours to days depending on dataset size

---

**Troubleshooting:**
- If loss doesn't decrease: Check for missing modalities (all zeros)
- If OOM error: Reduce batch size or use gradient accumulation
- If NaN loss: Gradient clipping is already included, so try lowering the learning rate and check the inputs for NaN/inf values
- If slow training: Check GPU utilization with `nvidia-smi`
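
For the OOM case, gradient accumulation keeps the effective batch size while lowering per-step memory. A minimal sketch, assuming an accumulation factor of 4 and a toy linear model in place of the SGT model. `accum_steps` and the `toy_` names are introduced here for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_model = nn.Linear(8, 1)
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-4)
toy_criterion = nn.MSELoss()

accum_steps = 4  # micro-batches per optimizer step
batches = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]

optimizer_steps = 0
toy_optimizer.zero_grad()
for i, (x, y) in enumerate(batches, start=1):
    # Scale so the accumulated gradient matches the mean over the larger batch.
    loss = toy_criterion(toy_model(x), y) / accum_steps
    loss.backward()  # gradients add up across micro-batches

    if i % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(toy_model.parameters(), max_norm=1.0)
        toy_optimizer.step()
        toy_optimizer.zero_grad()
        optimizer_steps += 1

print(f"optimizer steps: {optimizer_steps}")
```

With `batch_size=2` and `accum_steps=4`, each optimizer update sees gradients from 8 samples while only 2 are held in memory at a time.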