# 03: Test Data Loader

**Purpose:** Test the dataset and data loading pipeline

**What this does:**
- Load event IDs
- Build dataset index
- Test loading a few samples
- Verify data shapes and ranges
- Test augmentation
- Test batching

**What this does NOT do:**
- Create models
- Run training
- Load all data (only a few samples for testing)

**Expected time:** 2-3 minutes

---

**Prerequisites:**
1. Run `01_Setup_and_Environment.ipynb`
2. Run `02_Data_Verification.ipynb` and ensure data exists

## Step 1: Setup Paths

In [None]:
from google.colab import drive
import sys
import os
import torch

# Mount Drive
print("Mounting Google Drive...")
drive.mount('/content/drive', force_remount=False)
print("✅ Drive mounted\n")

# Add repo to path
sys.path.insert(0, '/content/stormfusion-sevir')

# Paths
DRIVE_ROOT = "/content/drive/MyDrive/SEVIR_Data"
SEVIR_ROOT = f"{DRIVE_ROOT}/data/sevir"
CATALOG_PATH = f"{DRIVE_ROOT}/data/SEVIR_CATALOG.csv"
TRAIN_IDS = f"{DRIVE_ROOT}/data/samples/all_train_ids.txt"
VAL_IDS = f"{DRIVE_ROOT}/data/samples/all_val_ids.txt"

print("Paths configured:")
print(f"  SEVIR_ROOT: {SEVIR_ROOT}")
print(f"  Catalog exists: {os.path.exists(CATALOG_PATH)}")
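The cell above reports only whether the catalog exists, but the ID lists and `SEVIR_ROOT` are needed too. Below is a minimal sketch of a pre-flight check over every configured path. `check_paths` is a hypothetical helper, not part of `stormfusion`. The demo uses a temporary directory so it runs anywhere.

```python
import os
import tempfile

def check_paths(paths):
    """Print a status line per path and return the labels of paths that don't exist."""
    missing = []
    for name, path in paths.items():
        ok = os.path.exists(path)
        print(f"  {name:12s} {'✅' if ok else '❌ missing'}  {path}")
        if not ok:
            missing.append(name)
    return missing

# In the notebook this would be called with the variables defined above, e.g.
# check_paths({"SEVIR_ROOT": SEVIR_ROOT, "CATALOG_PATH": CATALOG_PATH,
#              "TRAIN_IDS": TRAIN_IDS, "VAL_IDS": VAL_IDS})

# Self-contained demo: one path that exists (a temp dir) and one that doesn't
with tempfile.TemporaryDirectory() as tmp:
    missing = check_paths({"EXISTS": tmp, "MISSING": os.path.join(tmp, "nope.csv")})
print("Missing:", missing)
```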

## Step 2: Import Dataset

In [None]:
from stormfusion.data.sevir_multimodal import (
    SEVIRMultiModalDataset,
    build_multimodal_index,
    multimodal_collate_fn
)

print("✅ Dataset class imported")

## Step 3: Build Index (Test with Small Subset)

In [None]:
print("Building dataset index...\n")

# Build full index first to see how many events we have
train_index_full = build_multimodal_index(CATALOG_PATH, TRAIN_IDS, SEVIR_ROOT)
val_index_full = build_multimodal_index(CATALOG_PATH, VAL_IDS, SEVIR_ROOT)

print(f"\nFull dataset:")
print(f"  Train: {len(train_index_full)} events")
print(f"  Val: {len(val_index_full)} events")

# For testing, use only first 10 events
train_index = train_index_full[:10]
val_index = val_index_full[:5]

print(f"\nTest subset:")
print(f"  Train: {len(train_index)} events")
print(f"  Val: {len(val_index)} events")
print(f"\n✅ Index built")
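`build_multimodal_index` is specific to this repo, and its internals aren't shown here. The snippet below is a rough sketch of what such an index can hold. It groups SEVIR catalog rows by event `id`, keeps only the IDs listed in the split file, and records where each modality is stored. It assumes the catalog's `id`, `img_type`, `file_name` and `file_index` columns. `sketch_index` is hypothetical, and the real function may store different fields. An event that lacks a modality in the catalog, like `S2` below, is what later shows up as an all-zero input.

```python
import csv
import io
from collections import defaultdict

def sketch_index(catalog_file, event_ids):
    """Group catalog rows into [(event_id, {img_type: (file_name, file_index)}), ...].

    Only events listed in `event_ids` are kept (illustrative sketch only).
    """
    wanted = set(event_ids)
    index = defaultdict(dict)
    for row in csv.DictReader(catalog_file):
        if row["id"] in wanted:
            index[row["id"]][row["img_type"]] = (row["file_name"], int(row["file_index"]))
    # Keep the order of the ID list so slicing like index[:10] is deterministic
    return [(eid, index[eid]) for eid in event_ids if eid in index]

# Tiny in-memory catalog standing in for SEVIR_CATALOG.csv (toy file names)
catalog = io.StringIO(
    "id,img_type,file_name,file_index\n"
    "S1,vil,vil/2019/a.h5,0\n"
    "S1,ir069,ir069/2019/b.h5,3\n"
    "S2,vil,vil/2019/a.h5,1\n"
    "S3,vil,vil/2019/a.h5,2\n"
)
index = sketch_index(catalog, ["S2", "S1"])
print(index[:1])
```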

## Step 4: Create Dataset (Small Subset)

In [None]:
print("Creating test dataset...")

test_dataset = SEVIRMultiModalDataset(
    train_index,  # Only 10 events
    sevir_root=SEVIR_ROOT,
    catalog_path=CATALOG_PATH,
    input_steps=12,
    output_steps=6,
    normalize=True,
    augment=False  # No augmentation for testing
)

print(f"✅ Dataset created with {len(test_dataset)} samples")
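With `input_steps=12` and `output_steps=6`, each sample pairs 12 past frames with the 6 frames that follow. SEVIR events contain 49 frames at 5-minute spacing. This step doesn't show whether the dataset takes one window or several per event, but the printed sample count reveals it. The sketch below shows how that count depends on the choice. Its `stride` parameter and `split_windows` are hypothetical.

```python
import numpy as np

def split_windows(event, input_steps=12, output_steps=6, stride=None):
    """Cut a (T, H, W) event into (input, target) pairs of consecutive frames.

    stride=None means a single window starting at frame 0 (hypothetical behaviour).
    """
    span = input_steps + output_steps
    starts = [0] if stride is None else range(0, event.shape[0] - span + 1, stride)
    return [(event[s:s + input_steps], event[s + input_steps:s + span]) for s in starts]

event = np.zeros((49, 8, 8), dtype=np.float32)  # 49 frames, tiny spatial size for the demo
one = split_windows(event)
many = split_windows(event, stride=6)
print(len(one), one[0][0].shape, one[0][1].shape)  # → 1 (12, 8, 8) (6, 8, 8)
print(len(many))                                    # → 6
```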

## Step 5: Load and Inspect Single Sample

In [None]:
print("Loading sample 0...\n")

inputs, outputs = test_dataset[0]

print("INPUT SHAPES:")
for modality, data in inputs.items():
    # Tuples don't accept a width spec like :20s, so convert the shape to str first
    print(f"  {modality:8s}: {str(tuple(data.shape)):20s} (T, H, W)")
    print(f"             Range: [{data.min().item():.3f}, {data.max().item():.3f}]")
    print(f"             Mean: {data.mean().item():.3f}, Std: {data.std().item():.3f}")

    # An all-zero tensor indicates the modality is missing for this event
    if data.abs().sum().item() < 0.01:
        print("             ⚠️  WARNING: All zeros (missing modality)")
    print()

print("OUTPUT SHAPES:")
for modality, data in outputs.items():
    print(f"  {modality:8s}: {str(tuple(data.shape)):20s} (T, H, W)")
    print(f"             Range: [{data.min().item():.3f}, {data.max().item():.3f}]")
    print()

print("✅ Sample loaded successfully")
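The per-modality loop above can be folded into a reusable check that returns its findings instead of only printing them. Below is a minimal sketch that assumes samples are dicts of `(T, H, W)` arrays, as printed above. `summarize_sample` is a hypothetical helper. It works on NumPy arrays so it runs without the dataset, so call `.numpy()` on the tensors first.

```python
import numpy as np

def summarize_sample(sample, expected_steps, zero_tol=0.01):
    """Return (shape_errors, missing) for a dict of (T, H, W) arrays."""
    shape_errors, missing = [], []
    for name, data in sample.items():
        if data.ndim != 3 or data.shape[0] != expected_steps:
            shape_errors.append(f"{name}: got {data.shape}, expected ({expected_steps}, H, W)")
        if np.abs(data).sum() < zero_tol:  # all-zero array -> treat as a missing modality
            missing.append(name)
    return shape_errors, missing

rng = np.random.default_rng(0)
fake_inputs = {
    "vil": rng.random((12, 4, 4), dtype=np.float32),
    "lght": np.zeros((12, 4, 4), dtype=np.float32),
}
errors, missing = summarize_sample(fake_inputs, expected_steps=12)
print(errors, missing)  # → [] ['lght']
```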

## Step 6: Visualize Sample

In [None]:
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Show last input frame for each modality
for i, modality in enumerate(['vil', 'ir069', 'ir107', 'lght']):
    data = inputs[modality][-1].numpy()  # Last timestep
    im = axes[0, i].imshow(data, cmap='viridis', vmin=data.min(), vmax=data.max())
    axes[0, i].set_title(f'{modality.upper()}\n(input t=12)')
    axes[0, i].axis('off')
    plt.colorbar(im, ax=axes[0, i], fraction=0.046)

# Show VIL predictions (first 4 output timesteps)
for i in range(4):
    if i < outputs['vil'].shape[0]:
        data = outputs['vil'][i].numpy()
        im = axes[1, i].imshow(data, cmap='viridis', vmin=0, vmax=1)
        axes[1, i].set_title(f'VIL Target\n(t+{(i+1)*5}min)')
        axes[1, i].axis('off')
        plt.colorbar(im, ax=axes[1, i], fraction=0.046)

plt.tight_layout()
plt.savefig('/content/sample_visualization.png', dpi=150, bbox_inches='tight')
plt.show()

print("✅ Saved to /content/sample_visualization.png")
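The panel titles rely on SEVIR's 5-minute frame spacing. The sketch below maps each frame index to minutes relative to the last input frame, the "now" frame labelled `t=12` above. `frame_offsets` is an illustrative helper. Under this mapping the 12 inputs cover the past 55 minutes, and the 6 targets reach 30 minutes ahead, of which the plot shows the first four.

```python
FRAME_MINUTES = 5  # SEVIR frame spacing

def frame_offsets(input_steps=12, output_steps=6, interval=FRAME_MINUTES):
    """Minutes relative to the last input frame, for each input and target frame."""
    inputs = [(i - input_steps + 1) * interval for i in range(input_steps)]
    targets = [(j + 1) * interval for j in range(output_steps)]
    return inputs, targets

inp, tgt = frame_offsets()
print(inp[0], inp[-1])  # → -55 0
print(tgt)              # → [5, 10, 15, 20, 25, 30]
```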

## Step 7: Test Multiple Samples

In [None]:
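print("Loading multiple samples...\n")

num_samples = min(5, len(test_dataset))
num_ok = 0

for idx in range(num_samples):
    try:
        inputs, outputs = test_dataset[idx]
        # Check the time dimension against input_steps=12 / output_steps=6 used above
        assert inputs['vil'].shape[0] == 12, f"expected 12 input frames, got {inputs['vil'].shape[0]}"
        assert outputs['vil'].shape[0] == 6, f"expected 6 output frames, got {outputs['vil'].shape[0]}"
        print(f"Sample {idx}: ✅ shapes correct")
        num_ok += 1
    except Exception as e:
        print(f"Sample {idx}: ❌ Error: {e}")

print(f"\n{'✅' if num_ok == num_samples else '⚠️'} Successfully loaded {num_ok}/{num_samples} samples")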
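The loop above confirms that samples load. Every read goes through the mounted Drive, so timing the reads is also worthwhile: slow loading makes training I/O bound. Below is a minimal sketch using `time.perf_counter`. `time_loads` is a hypothetical helper, and the demo uses a stand-in dataset that sleeps to mimic file reads.

```python
import time

def time_loads(dataset, num_samples=5):
    """Average seconds per __getitem__ call over the first num_samples items."""
    n = min(num_samples, len(dataset))
    start = time.perf_counter()
    for idx in range(n):
        dataset[idx]
    return (time.perf_counter() - start) / n if n else 0.0

class SlowFakeDataset:
    """Stand-in dataset that sleeps 10 ms per item to mimic file I/O."""
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        time.sleep(0.01)
        return idx

per_sample = time_loads(SlowFakeDataset())
print(f"{per_sample * 1000:.0f} ms/sample")
# With the notebook's objects, and assuming one sample per event:
# time_loads(test_dataset) * len(train_index_full) ≈ seconds for one pass over the training set
```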

## Step 8: Test DataLoader with Batching

In [None]:
from torch.utils.data import DataLoader

print("Creating DataLoader...\n")

test_loader = DataLoader(
    test_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=multimodal_collate_fn
)

print(f"DataLoader created: {len(test_loader)} batches\n")

# Test loading one batch
print("Loading first batch...")
inputs_batch, outputs_batch = next(iter(test_loader))

print("\nBatch shapes:")
print("  Inputs:")
for modality, data in inputs_batch.items():
    # str() first: tuples don't support width format specs like :25s
    print(f"    {modality:8s}: {str(tuple(data.shape)):25s} (B, T, H, W)")

print("  Outputs:")
for modality, data in outputs_batch.items():
    print(f"    {modality:8s}: {str(tuple(data.shape)):25s} (B, T, H, W)")

print("\n✅ Batching works correctly")
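`multimodal_collate_fn` belongs to this repo. The printed shapes suggest that it stacks each modality along a new leading batch axis, turning `(T, H, W)` into `(B, T, H, W)`. Below is a minimal NumPy sketch of that behaviour. `dict_collate` is hypothetical, and the repo's version presumably operates on tensors (e.g. with `torch.stack`) and may handle more cases. Note also that `len(test_loader)` equals `ceil(len(test_dataset) / batch_size)` unless `drop_last=True` is passed to the `DataLoader`.

```python
import numpy as np

def dict_collate(batch):
    """Turn a list of (inputs, outputs) dict pairs into a pair of batched dicts."""
    inputs_list, outputs_list = zip(*batch)

    def stack(dicts):
        # One array per modality, with the batch as the new leading axis
        return {key: np.stack([d[key] for d in dicts]) for key in dicts[0]}

    return stack(inputs_list), stack(outputs_list)

sample = ({"vil": np.zeros((12, 4, 4)), "ir069": np.zeros((12, 4, 4))},
          {"vil": np.zeros((6, 4, 4))})
x, y = dict_collate([sample, sample])
print(x["vil"].shape, y["vil"].shape)  # → (2, 12, 4, 4) (2, 6, 4, 4)
```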

## Step 9: Test Augmentation

In [None]:
print("Testing augmentation...\n")

# Create dataset with augmentation
aug_dataset = SEVIRMultiModalDataset(
    train_index[:5],
    sevir_root=SEVIR_ROOT,
    catalog_path=CATALOG_PATH,
    input_steps=12,
    output_steps=6,
    normalize=True,
    augment=True  # Enable augmentation
)

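# Load the same sample several times. Flips and 90° rotations only move pixels
# around, so a sum over the frames stays the same; compare pixel values instead.
print("Loading same sample 3 times with augmentation:\n")
reference_vil = aug_dataset[0][0]['vil']
num_changed = 0
for i in range(3):
    inputs, outputs = aug_dataset[0]
    changed = not torch.equal(inputs['vil'], reference_vil)
    num_changed += changed
    print(f"  Iteration {i+1}: VIL differs from first load: {changed}")

if num_changed > 0:
    print("\n✅ Augmentation working (pixel layout changes between loads)")
else:
    print("\n⚠️  No change observed; augmentation is random, so re-run or check that augment=True takes effect")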
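For nowcasting, every input and target frame of a sample must receive the same spatial transform. Otherwise the targets no longer line up with the inputs. Below is a minimal NumPy sketch of joint random flips and 90° rotations. It assumes that `augment=True` applies transforms of this kind and is not the repo's code. `joint_flip_rotate` is hypothetical. The demo also shows why a pixel sum cannot detect such transforms.

```python
import numpy as np

def joint_flip_rotate(inputs, outputs, rng):
    """Apply one random flip + 90° rotation to all (T, H, W) arrays in both dicts."""
    k = int(rng.integers(4))      # number of 90° rotations, drawn once per sample
    flip = bool(rng.integers(2))  # horizontal flip or not, drawn once per sample

    def transform(a):
        if flip:
            a = a[:, :, ::-1]
        # copy() gives a contiguous array; torch.from_numpy rejects negative strides
        return np.rot90(a, k=k, axes=(1, 2)).copy()

    return ({m: transform(a) for m, a in inputs.items()},
            {m: transform(a) for m, a in outputs.items()})

rng = np.random.default_rng(42)
frame = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
inp = {"vil": np.repeat(frame, 12, axis=0)}
out = {"vil": np.repeat(frame, 6, axis=0)}
aug_in, aug_out = joint_flip_rotate(inp, out, rng)
print(aug_in["vil"].sum() == inp["vil"].sum())               # → True (sum is unchanged)
print(np.array_equal(aug_in["vil"][-1], aug_out["vil"][0]))  # → True (same transform on both)
```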

---

## Summary

**What we tested:**
- ✅ Dataset index building
- ✅ Single sample loading
- ✅ Data shapes match the expected (T, H, W) and (B, T, H, W) layouts
- ✅ Normalization enabled (value ranges inspected)
- ✅ Batching works
- ✅ Augmentation works

**Warnings to check:**
- If you see "All zeros" warnings, those modalities are missing from your data
- The model will still run, but those channels will contain only zeros
- Download the complete data for best performance

**Next notebook:** `04_Test_Model_Components.ipynb`  
This will test each model module individually before full integration.