# 03: Test Data Loader

**Purpose:** Test the dataset and data loading pipeline

**What this does:**
- Build the event ID list from the SEVIR catalog (2019 VIL events, 80/20 train/val split)
- Build dataset index
- Test loading a few samples
- Verify data shapes and ranges
- Test augmentation
- Test batching

**What this does NOT do:**
- Create models
- Run training
- Load all data (only a few samples for testing)

**Expected time:** 2-3 minutes

---

**Prerequisites:**
1. Run `01_Setup_and_Environment.ipynb`
2. Run `02_Data_Verification.ipynb` and ensure data exists

## Step 1: Setup Paths

In [None]:
from google.colab import drive
import sys
import os
import torch

# Mount Drive: the SEVIR data and catalog are read from MyDrive/SEVIR_Data
# (see DRIVE_ROOT below).
print("Mounting Google Drive...")
drive.mount('/content/drive', force_remount=False)
print("✅ Drive mounted\n")

# Clone/update repository
# NOTE: the `!` lines are IPython/Colab shell escapes, not plain Python, so this
# cell only runs inside a notebook kernel.
# NOTE(review): the git exit status is not checked; the "✅" lines below print
# even if the clone or pull fails (e.g. no network, local changes).
REPO_PATH = '/content/stormfusion-sevir'
if not os.path.exists(REPO_PATH):
    print("Cloning repository...")
    !git clone https://github.com/syedhaliz/stormfusion-sevir.git {REPO_PATH}
    print("✅ Repository cloned\n")
else:
    print("Repository exists, pulling latest changes...")
    !cd {REPO_PATH} && git pull
    print("✅ Repository updated\n")

# Add repo to path. It is inserted at position 0 so the checkout's
# `stormfusion` package takes precedence over any other copy on sys.path.
if REPO_PATH not in sys.path:
    sys.path.insert(0, REPO_PATH)
    print(f"✅ Added {REPO_PATH} to Python path\n")

# Force reload of modules to get latest code.
# This only triggers when the module was already imported in this kernel
# (e.g. re-running the notebook after a `git pull`). Only this one submodule
# is reloaded. Names bound earlier via `from ... import ...` keep pointing at
# the old objects until those import lines are executed again (Step 2, etc.).
import importlib
if 'stormfusion.data.sevir_multimodal' in sys.modules:
    print("Reloading sevir_multimodal module to get latest code...")
    importlib.reload(sys.modules['stormfusion.data.sevir_multimodal'])
    print("✅ Module reloaded\n")

# Paths
DRIVE_ROOT = "/content/drive/MyDrive/SEVIR_Data"
SEVIR_ROOT = f"{DRIVE_ROOT}/data/sevir"
CATALOG_PATH = f"{DRIVE_ROOT}/data/SEVIR_CATALOG.csv"
# NOTE(review): TRAIN_IDS / VAL_IDS are not referenced by the cells of this
# notebook; Step 3 derives its own train/val split from the catalog instead.
TRAIN_IDS = f"{DRIVE_ROOT}/data/samples/all_train_ids.txt"
VAL_IDS = f"{DRIVE_ROOT}/data/samples/all_val_ids.txt"

# Catalog presence is only reported here, not enforced. A missing catalog
# surfaces later when Step 3 calls pd.read_csv(CATALOG_PATH).
print("Paths configured:")
print(f"  SEVIR_ROOT: {SEVIR_ROOT}")
print(f"  Catalog exists: {os.path.exists(CATALOG_PATH)}")

## Step 2: Import Dataset

In [None]:
# Resolved from the repository checkout that Step 1 placed at the front of sys.path.
from stormfusion.data.sevir_multimodal import SEVIRMultiModalDataset

print("✅ Dataset class imported")

## Step 3: Build Event List and Train/Val Split (Select Small Test Subset)

In [None]:
import numpy as np
import pandas as pd

print("Loading catalog and building event list...\n")

# Load catalog
catalog = pd.read_csv(CATALOG_PATH, low_memory=False)
print(f"Total catalog entries: {len(catalog)}")

# Filter to only VIL events (target modality)
vil_catalog = catalog[catalog['img_type'] == 'vil']
print(f"VIL entries: {len(vil_catalog)}")

# Filter to only 2019 (the data we have).
# This is a plain substring test (regex=False). Rows with a missing file_name
# count as non-matching (na=False). Without that, the mask would contain NaN
# and pandas refuses to index with it.
vil_catalog_2019 = vil_catalog[
    vil_catalog['file_name'].str.contains('2019', regex=False, na=False)
]
print(f"VIL entries from 2019: {len(vil_catalog_2019)}")

# Get unique event IDs (order of first appearance in the catalog)
all_event_ids = vil_catalog_2019['id'].unique()
print(f"Unique events: {len(all_event_ids)}")

# Fail fast with an actionable message. Otherwise an empty split only shows
# up later as an IndexError when Step 5 reads sample 0.
if len(all_event_ids) == 0:
    raise RuntimeError(
        f"No 2019 VIL events found in {CATALOG_PATH}; "
        "check the catalog and re-run 02_Data_Verification.ipynb."
    )

# Create train/val split (80/20) by event, so no event lands in both sets.
# The global seed is kept so the split and any later np.random use in this
# session stay reproducible.
np.random.seed(42)
indices = np.random.permutation(len(all_event_ids))
split_idx = int(0.8 * len(all_event_ids))

train_event_ids = all_event_ids[indices[:split_idx]]
val_event_ids = all_event_ids[indices[split_idx:]]

print(f"\nTrain events: {len(train_event_ids)}")
print(f"Val events: {len(val_event_ids)}")

# For testing, use small subset
test_train_ids = train_event_ids[:10]
test_val_ids = val_event_ids[:5]

print("\nTest subset:")
print(f"  Train: {len(test_train_ids)} events")
print(f"  Val: {len(test_val_ids)} events")
print("\n✅ Event IDs prepared")

## Step 4: Create Dataset (Small Subset)

In [None]:
print("Building dataset index from event IDs...\n")

# Build index: List of (event_id, file_index) tuples
def build_index_from_ids(catalog, event_ids, modality='vil'):
    """Build a dataset index from a list of event IDs.

    For each event ID, take the first catalog row whose ``img_type`` equals
    ``modality`` and record that row's ``file_index``.

    Args:
        catalog: Catalog DataFrame with at least ``img_type``, ``id`` and
            ``file_index`` columns.
        event_ids: Iterable of event IDs. The returned index follows this order.
        modality: ``img_type`` value used to select catalog rows.

    Returns:
        List of ``(event_id, file_index)`` tuples with ``file_index`` as a
        Python ``int``. Event IDs with no row for ``modality`` are skipped.
    """
    modality_cat = catalog[catalog["img_type"] == modality]

    # Resolve "first matching row per event" once, instead of re-scanning the
    # whole catalog for every event (that was O(len(catalog) * len(event_ids))).
    first_rows = modality_cat.drop_duplicates(subset="id", keep="first")
    file_index_by_id = dict(zip(first_rows["id"], first_rows["file_index"]))

    return [
        (event_id, int(file_index_by_id[event_id]))
        for event_id in event_ids
        if event_id in file_index_by_id
    ]

# Index only the small training subset chosen in Step 3.
train_index = build_index_from_ids(catalog, test_train_ids)
print(f"Train index: {len(train_index)} events")

# Imported again so this cell still works when Step 2 was skipped.
from stormfusion.data.sevir_multimodal import SEVIRMultiModalDataset

# 12 input frames -> 12 target frames, normalized, augmentation disabled so
# the samples inspected below are deterministic.
test_dataset = SEVIRMultiModalDataset(
    sevir_root=SEVIR_ROOT,
    catalog_path=CATALOG_PATH,
    index=train_index,
    normalize=True,
    augment=False,
    input_steps=12,
    output_steps=12,
)

print(f"✅ Dataset created with {len(test_dataset)} samples")

## Step 5: Load and Inspect Single Sample

In [None]:
print("Loading sample 0...\n")

inputs, outputs = test_dataset[0]

# Input window: one report per modality (shape, value range, mean/std).
print("INPUT SHAPES:")
for name, tensor in inputs.items():
    print(f"  {name:8s}: {tuple(tensor.shape)} (T, H, W)")
    print(f"             Range: [{tensor.min():.3f}, {tensor.max():.3f}]")
    print(f"             Mean: {tensor.mean():.3f}, Std: {tensor.std():.3f}")

    # A near-zero absolute sum means the modality is effectively all zeros,
    # which this notebook treats as missing data.
    is_blank = tensor.abs().sum() < 0.01
    if is_blank:
        print("             ⚠️  WARNING: All zeros (missing modality)")
    print()

# The target is also a dict; only its 'vil' entry is reported.
print("OUTPUT SHAPE:")
vil_target = outputs['vil']
print(f"  vil:      {tuple(vil_target.shape)} (T, H, W)")
print(f"            Range: [{vil_target.min():.3f}, {vil_target.max():.3f}]")

print("\n✅ Sample loaded successfully")

## Step 6: Visualize Sample

In [None]:
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Row 0: the last input frame of each modality, color-scaled per panel.
for col, modality in enumerate(['vil', 'ir069', 'ir107', 'lght']):
    ax = axes[0, col]
    frame = inputs[modality][-1].numpy()  # index -1 = last input timestep
    im = ax.imshow(frame, cmap='viridis', vmin=frame.min(), vmax=frame.max())
    ax.set_title(f'{modality.upper()}\n(input t=12)')
    ax.axis('off')
    plt.colorbar(im, ax=ax, fraction=0.046)

# Row 1: up to the first four VIL target frames on a fixed [0, 1] color scale.
vil_output = outputs['vil']
for col in range(min(4, vil_output.shape[0])):
    ax = axes[1, col]
    frame = vil_output[col].numpy()
    im = ax.imshow(frame, cmap='viridis', vmin=0, vmax=1)
    ax.set_title(f'VIL Target\n(t+{(col + 1) * 5}min)')
    ax.axis('off')
    plt.colorbar(im, ax=ax, fraction=0.046)

plt.tight_layout()
plt.savefig('/content/sample_visualization.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Saved to /content/sample_visualization.png")

## Step 7: Test Multiple Samples

In [None]:
print("Loading multiple samples...\n")

num_samples = min(5, len(test_dataset))
num_ok = 0
# Shapes of the first sample that loads; every other sample must match them.
reference_shapes = None

for idx in range(num_samples):
    try:
        inputs, outputs = test_dataset[idx]
    except Exception as e:  # report every failing sample rather than stopping at the first
        print(f"Sample {idx}: ❌ Error: {e}")
        continue

    # Key by input/output and modality so a mismatch message says exactly
    # which tensor differs.
    shapes = {f"input/{m}": tuple(t.shape) for m, t in inputs.items()}
    shapes.update({f"output/{m}": tuple(t.shape) for m, t in outputs.items()})
    if reference_shapes is None:
        reference_shapes = shapes

    if shapes == reference_shapes:
        num_ok += 1
        print(f"Sample {idx}: ✅ shapes consistent")
    else:
        print(f"Sample {idx}: ❌ Shape mismatch: {shapes} (expected {reference_shapes})")

# Only claim success when every requested sample loaded with consistent shapes.
status = "✅" if num_samples > 0 and num_ok == num_samples else "⚠️"
print(f"\n{status} Successfully loaded {num_ok}/{num_samples} samples")

## Step 8: Test DataLoader with Batching

In [None]:
from torch.utils.data import DataLoader
from stormfusion.data.sevir_multimodal import multimodal_collate_fn

print("Creating DataLoader...\n")

# Fixed order, single-process loading (0 workers for Colab), and the project's
# collate function to batch the per-modality dicts.
test_loader = DataLoader(
    dataset=test_dataset,
    collate_fn=multimodal_collate_fn,
    batch_size=2,
    num_workers=0,
    shuffle=False,
)

print(f"DataLoader created: {len(test_loader)} batches\n")

# Pull just the first batch to check the collated layout.
print("Loading first batch...")
inputs_batch, outputs_batch = next(iter(test_loader))

print("\nBatch shapes:")
print("  Inputs:")
for name, tensor in inputs_batch.items():
    print(f"    {name:8s}: {tuple(tensor.shape)} (B, T, H, W)")

print("  Outputs:")
vil_batch = outputs_batch['vil']
print(f"    vil:      {tuple(vil_batch.shape)} (B, T, H, W)")

print("\n✅ Batching works correctly")

## Step 9: Test Augmentation

In [None]:
print("Testing augmentation...\n")

# Build index for augmentation test
aug_index = build_index_from_ids(catalog, test_train_ids[:5])

# Create dataset with augmentation
aug_dataset = SEVIRMultiModalDataset(
    index=aug_index,
    sevir_root=SEVIR_ROOT,
    catalog_path=CATALOG_PATH,
    input_steps=12,
    output_steps=12,
    normalize=True,
    augment=True  # Enable augmentation
)

# Load the same sample several times and compare the VIL tensors directly.
# Flips and 90-degree rotations only move pixels around, so a global
# statistic such as sum() is identical for every augmented view and cannot
# detect them. For display, use a position-sensitive statistic: the sum over
# the top-left quadrant, taking the last two dims as (H, W).
num_trials = 5
print(f"Loading same sample {num_trials} times with augmentation:\n")
vil_views = []
for i in range(num_trials):
    inputs, outputs = aug_dataset[0]
    vil = inputs['vil']
    vil_views.append(vil.clone())  # clone in case the dataset reuses/mutates buffers
    h, w = vil.shape[-2], vil.shape[-1]
    quadrant_sum = vil[..., : h // 2, : w // 2].sum().item()
    print(f"  Iteration {i+1}: VIL top-left quadrant sum = {quadrant_sum:.2f}")

# torch.equal returns False for differing shapes too (e.g. a non-square rotation).
num_changed = sum(not torch.equal(view, vil_views[0]) for view in vil_views[1:])
if num_changed:
    print(f"\n✅ Augmentation working ({num_changed}/{num_trials - 1} views differ from the first)")
else:
    print(
        "\n⚠️  All views identical: augmentation may not be applied "
        "(or every random draw happened to be a no-op)"
    )

---

## Summary

**What we tested:**
- ✅ Dataset index building
- ✅ Single sample loading
- ✅ Data shapes printed and inspected
- ✅ Value ranges printed with normalization enabled
- ✅ Batching works
- ✅ Augmentation works

**Warnings to check:**
- If you see "All zeros" warnings, those modalities are missing from your data
- Model will still run but use zeros for those channels
- Download complete data for best performance

**Next notebook:** `04_Test_Model_Components.ipynb`  
This will test each model module individually before full integration.