# 2 — Data Preprocessing Pipeline
## Next Day Wildfire Spread — TFRecord → Normalised Tensors

This notebook documents the data preparation pipeline:

1. **Download** — Kaggle API fetches TFRecord shards from `fantineh/next-day-wildfire-spread`
2. **Parse** — TFRecords are decoded into 64×64 numpy arrays (12 input channels + 1 target)
3. **Filter** — Samples with no fire pixels in the target are discarded
4. **NaN handling** — Missing values are replaced with 0
5. **Normalise** — Channel-wise z-score normalisation (skip `prev_fire_mask`)
6. **Augment** — Random flips and 90° rotations (training only)
7. **Save** — Split files `train.npz`, `val.npz`, `test.npz` + `metadata.json`
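
Steps 3 and 4 act on whole arrays once the TFRecords have been parsed. Below is a minimal NumPy sketch of those two steps. It assumes inputs stacked as `X` with shape `(N, 12, 64, 64)` and targets as `Y` with shape `(N, 1, 64, 64)`, which is the layout the later cells index. It also assumes that any target value `> 0` counts as a fire pixel. The helper name `filter_and_clean` is hypothetical and is not taken from `download_data.py`.

```python
import numpy as np


def filter_and_clean(X: np.ndarray, Y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: drop fire-free samples, then replace NaNs with 0.

    Assumes X has shape (N, C, H, W) and Y has shape (N, 1, H, W),
    and that any target value > 0 marks a fire pixel.
    """
    # Step 3: keep samples whose target has at least one fire pixel.
    # NaN > 0 is False, so a missing target value never counts as fire.
    has_fire = (Y > 0).any(axis=(1, 2, 3))
    X, Y = X[has_fire], Y[has_fire]

    # Step 4: replace missing values with 0 in inputs and target.
    X = np.where(np.isnan(X), 0.0, X)
    Y = np.where(np.isnan(Y), 0.0, Y)
    return X, Y


# Usage on synthetic data: 3 samples, only sample 1 contains fire.
rng = np.random.default_rng(0)
X = rng.normal(size=(3, 12, 64, 64)).astype(np.float32)
X[1, 0, 0, 0] = np.nan            # a missing input value in the fire sample
Y = np.zeros((3, 1, 64, 64), dtype=np.float32)
Y[1, 0, 10, 10] = 1.0             # the single fire pixel
X_clean, Y_clean = filter_and_clean(X, Y)
print(X_clean.shape, Y_clean.shape)  # (1, 12, 64, 64) (1, 1, 64, 64)
```

Filtering before NaN replacement means the discarded samples are never cleaned, which saves work on large shards.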

The command-line script `download_data.py` automates steps 1–4 and 7.

In [None]:
import json, numpy as np, pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# ── Load the config generated by 00_Setup.ipynb ──────────────────────────────
_cfg_path = Path().resolve() / "setup_config.json"
if not _cfg_path.exists():
    raise FileNotFoundError(
        "setup_config.json not found.\n"
        "→ Run the 00_Setup.ipynb notebook first"
    )
with open(_cfg_path) as f:  # close the file handle once the JSON is read
    cfg = json.load(f)

PROCESSED_DIR    = Path(cfg["PROCESSED_DIR"])
FIGURES_DIR      = Path(cfg["FIGURES_DIR"])
MODELS_DIR       = Path(cfg["MODELS_DIR"])
FEATURE_CHANNELS = cfg["FEATURE_CHANNELS"]
N_INPUT_CHANNELS = cfg["N_INPUT_CHANNELS"]
CH               = cfg["CH"]
GRID_SIZE        = cfg["GRID_SIZE"]
norm_stats       = cfg["norm_stats"]

sns.set_theme(style='whitegrid', font_scale=1.1)
%matplotlib inline

print(f"Config loaded from: {_cfg_path}")

## 2.1 Download & Convert (if needed)

If the processed `.npz` files do not exist, run:
```bash
python download_data.py
```
This downloads ~2 GB of TFRecord shards from Kaggle and converts them to numpy.
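
After conversion, each split is stored as one `.npz` file that holds the arrays `X` and `Y`. These are the keys the existence check below reads. A `metadata.json` file is written alongside the splits. The sketch below writes that layout with `np.savez_compressed`. Several details are assumptions rather than what `download_data.py` necessarily does: the helper `save_split`, the use of compression, the temporary output directory, and the metadata fields.

```python
import json
import tempfile
from pathlib import Path

import numpy as np


def save_split(out_dir: Path, split: str, X: np.ndarray, Y: np.ndarray) -> Path:
    """Hypothetical helper: write X/Y for one split to <out_dir>/<split>.npz."""
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"{split}.npz"
    # Each keyword argument becomes a named array inside the archive.
    np.savez_compressed(path, X=X.astype(np.float32), Y=Y.astype(np.float32))
    return path


out_dir = Path(tempfile.mkdtemp())  # stand-in for PROCESSED_DIR
counts = {}
for split, n in [("train", 4), ("val", 2), ("test", 2)]:
    X = np.zeros((n, 12, 64, 64), dtype=np.float32)
    Y = np.zeros((n, 1, 64, 64), dtype=np.float32)
    save_split(out_dir, split, X, Y)
    counts[split] = n

# Hypothetical metadata fields; the real metadata.json may differ.
metadata = {"grid_size": 64, "n_input_channels": 12, "n_samples": counts}
(out_dir / "metadata.json").write_text(json.dumps(metadata, indent=2))

with np.load(out_dir / "train.npz") as d:
    print(d["X"].shape, d["Y"].shape)  # (4, 12, 64, 64) (4, 1, 64, 64)
```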

In [None]:
# Check whether the processed splits exist
for split in ['train', 'val', 'test']:
    p = Path(PROCESSED_DIR) / f'{split}.npz'
    if p.exists():
        with np.load(p) as d:  # NpzFile keeps the file open; close it when done
            print(f'{split}: X={d["X"].shape}, Y={d["Y"].shape}, size={p.stat().st_size / 1e6:.1f} MB')
    else:
        print(f'{split}: NOT FOUND — run `python download_data.py`')

## 2.2 Normalisation Statistics
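
The notebook reads its normalisation statistics from `setup_config.json` (`norm_stats`). If they ever need to be recomputed, a common approach is to take the per-channel mean and standard deviation over the training split only. This keeps the validation and test data from influencing the scaling. The sketch below makes three assumptions: `X` has shape `(N, C, H, W)`, channels to skip (here `prev_fire_mask`) are passed by name, and NaNs are ignored. `compute_norm_stats` is a hypothetical helper.

```python
import numpy as np


def compute_norm_stats(X, channel_names, skip=("prev_fire_mask",)):
    """Hypothetical helper: per-channel (mean, std) over all samples and pixels.

    X has shape (N, C, H, W); channels listed in `skip` are left out.
    NaNs are ignored via nanmean/nanstd.
    """
    stats = {}
    for i, name in enumerate(channel_names):
        if name in skip:
            continue
        values = X[:, i].astype(np.float64)  # float64 avoids accumulation error
        stats[name] = (float(np.nanmean(values)), float(np.nanstd(values)))
    return stats


# Synthetic check: channel 0 ~ N(5, 2), channel 1 is a binary mask (skipped).
rng = np.random.default_rng(42)
X = np.stack([rng.normal(5.0, 2.0, size=(200, 64, 64)),
              rng.integers(0, 2, size=(200, 64, 64))], axis=1)
stats = compute_norm_stats(X, ["elevation", "prev_fire_mask"])
print(stats)  # mean close to 5.0, std close to 2.0; no prev_fire_mask entry
```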

In [None]:
print('Default normalisation statistics (channel-wise):\n')
print(f'{"Channel":>18s}  {"Mean":>10s}  {"Std":>10s}')
print('-' * 42)
for name in FEATURE_CHANNELS:
    if name in norm_stats:
        m, s = norm_stats[name]  # (mean, std) loaded from setup_config.json
        print(f'{name:>18s}  {m:10.4f}  {s:10.4f}')
    else:
        print(f'{name:>18s}  (not normalised)')

## 2.3 Normalisation Visualisation
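
The cell below calls `normalise(x)` on one `(C, H, W)` sample, but that function is not defined in this notebook. The sketch below shows the channel-wise z-score it is described as applying, `(x - mean_c) / std_c`. It assumes that `norm_stats` maps each channel name to `(mean, std)` and that `CH` maps each channel name to an index, as in the config. The epsilon guard and the name `normalise_sketch` are assumptions, not the project's implementation.

```python
import numpy as np


def normalise_sketch(x, norm_stats, ch_index, eps=1e-6):
    """Hypothetical z-score normaliser for one sample of shape (C, H, W).

    norm_stats: {channel_name: (mean, std)}; channels absent from it
                (e.g. the binary prev_fire_mask) are copied through unchanged.
    ch_index:   {channel_name: channel_index}.
    """
    out = x.astype(np.float32, copy=True)  # never modify the caller's array
    for name, (mean, std) in norm_stats.items():
        ci = ch_index[name]
        # eps guards against division by zero for a constant channel
        out[ci] = (out[ci] - mean) / (std + eps)
    return out


# Usage with toy stats; channel 1 has no stats, so it is left untouched.
CH_demo = {"elevation": 0, "prev_fire_mask": 1}
stats_demo = {"elevation": (100.0, 50.0)}
x = np.stack([np.full((64, 64), 200.0), np.ones((64, 64))])
x_norm = normalise_sketch(x, stats_demo, CH_demo)
print(x_norm[0, 0, 0], x_norm[1, 0, 0])  # ≈ 2.0 and 1.0
```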

In [None]:
# Load a few training samples
train_path = Path(PROCESSED_DIR) / 'train.npz'
if train_path.exists():
    data = np.load(train_path)
    X_raw = data['X'][:20]  # first 20 samples
    Y_raw = data['Y'][:20]
    
    # Normalise
    X_norm = np.stack([normalise(x) for x in X_raw])
    
    # Compare raw vs normalised for a sample
    idx = 0
    fig, axes = plt.subplots(2, 4, figsize=(18, 8))
    show_channels = ['elevation', 'wind_speed', 'ndvi', 'humidity']
    
    for j, ch_name in enumerate(show_channels):
        ci = CH[ch_name]
        axes[0, j].imshow(X_raw[idx, ci], cmap='viridis')
        axes[0, j].set_title(f'{ch_name} (raw)', fontweight='bold')
        axes[0, j].axis('off')
        
        axes[1, j].imshow(X_norm[idx, ci], cmap='viridis')
        axes[1, j].set_title(f'{ch_name} (normalised)', fontweight='bold')
        axes[1, j].axis('off')
    
    plt.suptitle('Raw vs Normalised Features', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'preprocessing_normalisation.png', dpi=150, bbox_inches='tight')  # configured figures dir
    plt.show()
else:
    print('No training data. Run download_data.py first.')

## 2.4 Data Augmentation
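
The cell below calls `augment_sample(x, y, seed=...)`, which is also defined outside this notebook. The key constraint is that the inputs and the target must receive the same geometric transform, so that fire pixels stay aligned with their features. The sketch below applies random flips and a rotation by a multiple of 90° to `(C, H, W)` arrays, using a NumPy `Generator` seeded on each call. `augment_sample_sketch` is hypothetical, and the project's version may sample its transforms differently.

```python
import numpy as np


def augment_sample_sketch(x, y, seed=None):
    """Hypothetical augmentation: identical random flips + 90° rotation on x and y.

    x: (C, H, W) features, y: (1, H, W) target. Spatial axes are (1, 2).
    """
    rng = np.random.default_rng(seed)
    if rng.random() < 0.5:          # horizontal flip (mirror along width)
        x, y = x[:, :, ::-1], y[:, :, ::-1]
    if rng.random() < 0.5:          # vertical flip (mirror along height)
        x, y = x[:, ::-1, :], y[:, ::-1, :]
    k = int(rng.integers(0, 4))     # rotate by k * 90 degrees
    x = np.rot90(x, k, axes=(1, 2))
    y = np.rot90(y, k, axes=(1, 2))
    # Flips and rot90 return strided views; return contiguous copies.
    return np.ascontiguousarray(x), np.ascontiguousarray(y)


# Usage: the marked pixel moves, but input and target move together.
x = np.zeros((12, 64, 64), dtype=np.float32)
y = np.zeros((1, 64, 64), dtype=np.float32)
x[0, 5, 10] = 1.0
y[0, 5, 10] = 1.0
for seed in range(4):
    xa, ya = augment_sample_sketch(x, y, seed=seed)
    # Both rows should report the same (row, col) for the marked pixel.
    print(seed, np.argwhere(xa[0] == 1)[0], np.argwhere(ya[0] == 1)[0])
```

On a square 64×64 grid, flips and 90° rotations preserve the array shape, so augmented samples can be batched with unaugmented ones.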

In [None]:
if train_path.exists():
    x_sample = X_norm[0]  # (C, H, W)
    y_sample = Y_raw[0]   # (1, H, W)
    
    fig, axes = plt.subplots(2, 4, figsize=(18, 8))
    
    for j in range(4):
        xa, ya = augment_sample(x_sample, y_sample, seed=j)
        axes[0, j].imshow(xa[CH['elevation']], cmap='terrain')
        axes[0, j].set_title(f'Elevation (aug {j})', fontweight='bold')
        axes[0, j].axis('off')
        
        axes[1, j].imshow(ya.squeeze(), cmap='hot', vmin=0, vmax=1)
        axes[1, j].set_title(f'FireMask (aug {j})', fontweight='bold')
        axes[1, j].axis('off')
    
    plt.suptitle('Augmented Views of the Same Sample', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'preprocessing_augmentations.png', dpi=150, bbox_inches='tight')  # configured figures dir
    plt.show()

## 2.5 NaN Handling Verification

In [None]:
if train_path.exists():
    full_train = np.load(train_path)
    X_all = full_train['X']
    
    print('NaN fraction per channel (should be 0 after processing):\n')
    for i, name in enumerate(FEATURE_CHANNELS):
        nan_frac = np.isnan(X_all[:, i]).mean()
        status = 'OK' if nan_frac == 0 else f'WARNING: {nan_frac:.4%}'
        print(f'  {name:>18s}: {status}')
    
    print(f'\nTarget NaN: {np.isnan(full_train["Y"]).mean():.4%}')

## Summary

The preprocessing pipeline converts raw TFRecord satellite observations into normalised 64×64 tensors:

- **12 input channels**, z-score normalised (except binary `prev_fire_mask`)
- **1 binary target** (next-day fire mask)
- **Geographic pre-split**: train / val / test — no spatial leakage
- **Augmentation**: random flips & rotations applied on-the-fly during training
- **NaN → 0**: missing values replaced to ensure numerical stability