# 02_preprocessing.ipynb

**Purpose**: Load raw NetCDF files, normalize data, create sliding window sequences, and save Train/Val/Test splits.

**Fixed (v2)**: 
- Handles both `valid_time` and `time` coordinates
- Supports both `data_0*.nc` and `era5land_*.nc` file patterns
- Robust error handling for CDS API changes

In [None]:
# Cell 1: Mount Drive & Install Dependencies
# Colab-only setup: this cell relies on the google.colab package and on
# IPython's "!" shell escape, so it is meant to be run as a notebook cell.
from google.colab import drive
# Mount Google Drive at /content/drive.
drive.mount('/content/drive')

# Base directory on the mounted Drive; later cells join the config, raw-data
# and processed-data paths onto it.
PROJECT_ROOT = '/content/drive/MyDrive/WeatherPaper'

# Shell escape: install the packages this notebook names explicitly.
!pip install xarray netCDF4 pyyaml numpy

In [None]:
# Cell 2: Imports and Configuration
import os
import glob
import yaml
import numpy as np
import xarray as xr
import pandas as pd

# Load Config
# Year split used when the YAML file is absent or leaves a split key out.
default_time_split = {
    'train_years': [2015, 2016, 2017, 2018, 2019, 2020, 2021],
    'val_years': [2022, 2023],
    'test_years': [2024, 2025],
}

config_path = os.path.join(PROJECT_ROOT, 'config/project_scope.yaml')
config_found = os.path.exists(config_path)
if config_found:
    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document; normalize that to an
        # empty mapping so the lookups below do not fail on None.
        scope_config = yaml.safe_load(f) or {}
else:
    # Default config if not found
    scope_config = {}
    print("Using default config.")

if not isinstance(scope_config, dict):
    raise TypeError(
        f"{config_path} must contain a YAML mapping, "
        f"got {type(scope_config).__name__}"
    )

time_split = scope_config.get('time_split') or {}
if not isinstance(time_split, dict):
    raise TypeError(
        f"'time_split' in {config_path} must be a mapping, "
        f"got {type(time_split).__name__}"
    )

# Values from the file win; any split key the file omits falls back to the
# default so the prints below and the later cells always find all three keys.
missing_keys = [key for key in default_time_split if key not in time_split]
if config_found and missing_keys:
    print(f"Config is missing {missing_keys}; using default years for them.")
scope_config['time_split'] = {**default_time_split, **time_split}

print("Config Loaded.")
print(f"Train Years: {scope_config['time_split']['train_years']}")
print(f"Val Years: {scope_config['time_split']['val_years']}")
print(f"Test Years: {scope_config['time_split']['test_years']}")

## 1. Load Dataset (Robust Approach)

We load ALL files directly. The `preprocess` function handles:
- `valid_time` → `time` coordinate renaming (CDS API v2 change)
- `expver` dimension flattening
- Extra coordinate cleanup

In [None]:
# Cell 3: Define Preprocessing Function (FIXED FOR CDS API v2)
def preprocess_era5(ds):
    """Bring a single ERA5 file into a common layout before files are combined.

    Applied in order:
    - rename the ``valid_time`` coordinate to ``time`` when the dataset has
      no ``time`` coordinate yet (CDS API v2 naming);
    - collapse an ``expver`` dimension, or drop ``expver`` when it is only a
      non-dimension coordinate;
    - drop the ``number`` coordinate (ensemble member indicator);
    - cast float64 data variables to float32.

    Returns the adjusted dataset.
    """

    # --- time coordinate naming -------------------------------------------
    if 'time' not in ds.coords and 'valid_time' in ds.coords:
        ds = ds.rename({'valid_time': 'time'})
        print("  Renamed valid_time -> time")

    # --- expver (ERA5 vs ERA5T mixing) ------------------------------------
    if 'expver' in ds.dims:
        # Prefer ERA5 (1) and fill its gaps with ERA5T (5). The labels are
        # tried as strings first (common in newer CDS downloads), then as
        # integers.
        for final_label, prelim_label in (('0001', '0005'), (1, 5)):
            try:
                final_part = ds.sel(expver=final_label)
                prelim_part = ds.sel(expver=prelim_label)
                merged = final_part.combine_first(prelim_part)
            except (KeyError, ValueError):
                continue
            ds = merged
            break
        else:
            # Neither labelling worked: just take the first index.
            ds = ds.isel(expver=0, drop=True)
    elif 'expver' in ds.coords:
        # A coordinate that is not a dimension carries nothing to flatten.
        ds = ds.drop_vars('expver', errors='ignore')

    # --- ensemble member indicator ----------------------------------------
    if 'number' in ds.coords:
        ds = ds.drop_vars('number', errors='ignore')

    # --- memory footprint -------------------------------------------------
    for name in ds.data_vars:
        variable = ds[name]
        if variable.dtype == 'float64':
            ds[name] = variable.astype('float32')

    return ds

In [None]:
# Cell 4: Find and Load All Files
raw_dir = os.path.join(PROJECT_ROOT, 'data/raw')

# Support several file naming patterns, tried in priority order:
# era5land_*.nc, then data_0*.nc (direct downloads), then any NetCDF file.
# The first pattern that matches anything wins.
all_files = []
for pattern in ("era5land_*.nc", "data_0*.nc", "*.nc"):
    all_files = sorted(glob.glob(os.path.join(raw_dir, pattern)))
    if all_files:
        break

print(f"Found {len(all_files)} NetCDF files.")
if not all_files:
    raise FileNotFoundError(f"No NetCDF files found in {raw_dir}")
print(f"Sample: {os.path.basename(all_files[0])}")

# Quick validation of first file
print("\nValidating first file structure...")
# The context manager closes the file even if one of the prints raises.
with xr.open_dataset(all_files[0]) as test_ds:
    print(f"  Coordinates: {list(test_ds.coords)}")
    print(f"  Variables: {list(test_ds.data_vars)}")
    # .sizes is the name -> length mapping; the mapping behaviour of
    # Dataset.dims is deprecated in recent xarray releases.
    print(f"  Dimensions: {dict(test_ds.sizes)}")

In [None]:
# Cell 5: Load All Files with Preprocessing
print("Loading dataset (this may take a few minutes)...")

# Load with netcdf4 engine
# parallel=False to prevent HDF5 locking issues on Drive
ds = xr.open_mfdataset(
    all_files,
    combine='by_coords',
    engine='netcdf4',
    preprocess=preprocess_era5,
    parallel=False,  # Critical: prevents HDF5 race conditions
    lock=False       # Also helps with Drive access
)

# Sort by time to ensure chronological order
ds = ds.sortby('time')

# Materialize the time coordinate once; both ends of the range come from it.
loaded_times = ds.time.values

print("\n=== DATASET LOADED SUCCESSFULLY ===")
print(f"Time range: {str(loaded_times[0])[:19]} to {str(loaded_times[-1])[:19]}")
print(f"Variables: {list(ds.data_vars)}")
# .sizes is the name -> length mapping; the mapping behaviour of
# Dataset.dims is deprecated in recent xarray releases.
print(f"Dimensions: {dict(ds.sizes)}")

In [None]:
# Cell 6: Enforce Hourly Continuity
# Re-grid the time axis to a fixed 1-hour step. Per the message below, hours
# that are missing from the source end up as NaN rather than being filled in.
print("Resampling to hourly frequency (filling gaps with NaN)...")
hourly_groups = ds.resample(time='1h')
ds = hourly_groups.asfreq()
print(f"Total hourly timesteps: {ds.time.size}")

## 2. Normalization (Train Stats Only)

In [None]:
# Cell 7: Load into Memory and Normalize
train_years = scope_config['time_split']['train_years']
val_years = scope_config['time_split']['val_years']
test_years = scope_config['time_split']['test_years']

# Use variables in consistent order; fall back to whatever the dataset has
# when none of the preferred names is present.
available_vars = list(ds.data_vars)
ordered_vars = [v for v in ['tp', 't2m', 'msl'] if v in available_vars]
if not ordered_vars:
    ordered_vars = available_vars

print(f"Processing variables: {ordered_vars}")

# Convert to numpy array: (time, lat, lon, channel)
print("Loading into memory...")
data_xr = ds[ordered_vars].to_array(dim='channel').transpose('time', 'latitude', 'longitude', 'channel')
# copy=False: when the loaded values are already float32, reuse them instead
# of allocating a second full-size array.
data_np = data_xr.values.astype(np.float32, copy=False)
times = ds.time.values

print(f"Data shape: {data_np.shape}")

# Create year-based masks
years = pd.to_datetime(times).year
train_mask = np.isin(years, train_years)
val_mask = np.isin(years, val_years)
test_mask = np.isin(years, test_years)

print(f"Train samples: {train_mask.sum()}, Val samples: {val_mask.sum()}, Test samples: {test_mask.sum()}")

# Without training timesteps the statistics below would be NaN and every
# normalized value (hence every sequence) would silently become NaN too.
if not train_mask.any():
    raise ValueError(
        f"No timesteps fall in the training years {train_years}; "
        "cannot compute normalization statistics."
    )

# Compute normalization stats from TRAINING data only (prevent leakage)
train_data = data_np[train_mask]
mean = np.nanmean(train_data, axis=(0, 1, 2), keepdims=True)
std = np.nanstd(train_data, axis=(0, 1, 2), keepdims=True)
std = np.where(std < 1e-6, 1.0, std)  # Prevent division by zero
# Keep the stats in float32 so normalizing never upcasts the full array.
std = std.astype(np.float32, copy=False)

print(f"Mean: {mean.flatten()}")
print(f"Std: {std.flatten()}")

# Normalize entire dataset. The division is done in place on the result of
# the subtraction to avoid one extra full-size temporary; data_np itself is
# left untouched.
data_norm = data_np - mean
data_norm /= std
print("Normalization complete.")

## 3. Create Sequences

In [None]:
# Cell 8: Sequence Generation with NaN Filtering
# Window lengths are counted in timesteps; they are passed to
# create_sequences as t_in / t_out when the splits are built.
T_IN = 24   # 24 hours of input (past)
T_OUT = 6   # 6 hours of output (future)

def create_sequences(data, timestamps, year_mask, t_in, t_out):
    """Build sliding-window (input, target) pairs for the masked timesteps.

    For every index ``i`` where ``year_mask`` is true and a full window fits,
    the input is ``data[i - t_in:i]`` and the target is ``data[i:i + t_out]``.
    Windows containing any NaN (a gap in the data) are skipped, and a summary
    line with the kept/skipped counts is printed.

    Args:
        data: Array whose first axis is time.
        timestamps: Per-timestep labels aligned with the first axis of ``data``.
        year_mask: Boolean array over time selecting the timesteps that may
            serve as the first target step of a window.
        t_in: Number of timesteps in each input window.
        t_out: Number of timesteps in each target window.

    Returns:
        Tuple ``(X, Y, T)``: ``X`` is float32 with shape
        ``(n, t_in, *data.shape[1:])``, ``Y`` is float32 with shape
        ``(n, t_out, *data.shape[1:])``, and ``T`` holds the timestamp of each
        window's first target step. With no surviving window, ``n`` is 0 and
        the trailing dimensions are still present.
    """
    data = np.asarray(data)
    timestamps = np.asarray(timestamps)
    n_steps = len(data)

    # Get indices where year mask is true
    anchors = np.where(year_mask)[0]
    # Ensure we have enough context before and after. The slice end
    # i + t_out is exclusive, so i == n_steps - t_out is still a full window.
    anchors = anchors[(anchors >= t_in) & (anchors <= n_steps - t_out)]

    # Flag each timestep containing a NaN once, instead of re-scanning the
    # overlapping windows. A prefix sum over the flags then gives the number
    # of NaN-affected timesteps inside any window in O(1).
    step_has_nan = np.fromiter(
        (np.isnan(step).any() for step in data), dtype=bool, count=n_steps
    )
    nan_before = np.concatenate(([0], np.cumsum(step_has_nan)))
    nan_in_window = nan_before[anchors + t_out] - nan_before[anchors - t_in]

    kept = anchors[nan_in_window == 0]
    nan_skip_count = len(anchors) - len(kept)

    print(f"  Created {len(kept)} sequences, skipped {nan_skip_count} due to NaN.")

    # Pre-sized outputs keep dtype and rank consistent even when nothing
    # survives the filtering.
    trailing = data.shape[1:]
    X = np.empty((len(kept), t_in) + trailing, dtype='float32')
    Y = np.empty((len(kept), t_out) + trailing, dtype='float32')
    for row, i in enumerate(kept):
        X[row] = data[i - t_in:i]    # past t_in steps
        Y[row] = data[i:i + t_out]   # next t_out steps

    return X, Y, timestamps[kept]

print("Generating sequences...")


def _sequences_for(label, mask):
    """Print the split label, then window the normalized data for that mask."""
    print(label)
    return create_sequences(data_norm, times, mask, T_IN, T_OUT)


X_train, Y_train, T_train = _sequences_for("Train:", train_mask)
X_val, Y_val, T_val = _sequences_for("Validation:", val_mask)
X_test, Y_test, T_test = _sequences_for("Test:", test_mask)

# Shape report; the tags are padded so the X= columns line up.
print("\nFinal shapes:")
for tag, inputs, targets in (
    ("Train:", X_train, Y_train),
    ("Val:  ", X_val, Y_val),
    ("Test: ", X_test, Y_test),
):
    print(f"  {tag} X={inputs.shape}, Y={targets.shape}")

In [None]:
# Cell 9: Save Processed Data
processed_dir = os.path.join(PROJECT_ROOT, 'data/processed')
os.makedirs(processed_dir, exist_ok=True)

# One compressed archive per split, each holding the x / y / time arrays.
split_archives = (
    ('train.npz', X_train, Y_train, T_train),
    ('val.npz', X_val, Y_val, T_val),
    ('test.npz', X_test, Y_test, T_test),
)
for archive_name, inputs, targets, stamps in split_archives:
    archive_path = os.path.join(processed_dir, archive_name)
    np.savez_compressed(archive_path, x=inputs, y=targets, time=stamps)

# Normalization stats (plus the variable order) are stored for inference.
stats_path = os.path.join(processed_dir, 'stats.npz')
np.savez_compressed(stats_path, mean=mean, std=std, variables=ordered_vars)

print("\n=== PREPROCESSING COMPLETE ===")
print(f"Saved to: {processed_dir}")
print(f"  - train.npz: {len(X_train)} samples")
print(f"  - val.npz:   {len(X_val)} samples")
print(f"  - test.npz:  {len(X_test)} samples")
print("  - stats.npz: normalization parameters")