In [1]:
# Multi-Year ConvLSTM Wave Forecasting Pipeline
# Using 2020-2024 ERA5 Data
# ======================================================================

import xarray as xr
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    ConvLSTM2D, Conv3D, BatchNormalization, Input, Dropout
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
import tensorflow as tf
import pickle
from pathlib import Path

# CONFIGURATION
DATA_YEARS = [2020, 2021, 2022, 2023, 2024]
DATA_DIR = "../"  # Directory containing surf_data_YYYY.nc files
SPOT_LAT, SPOT_LON = 6.8399, 81.8396

# Training config
LOOKBACK_HOURS = 19
LOOKAHEAD_HOURS = 1
BATCH_SIZE = 16  # Can increase with more data
EPOCHS = 10
LEARNING_RATE = 1e-3

# Data split strategy
TRAIN_YEARS = [2020, 2021, 2022, 2023]  # 4 years training
VAL_YEAR = 2024  # Hold out 2024 for validation (most recent)

INPUT_FEATURES = ["u10", "v10", "msl", "shts", "mpts", "mdts"]
TARGET_FEATURES = ["shts", "mpts", "mdts"]

In [2]:
# MULTI-YEAR DATA LOADING
# ======================================================================
def load_multiyear_dataset(years, data_dir="."):
    """
    Load and concatenate multiple years of ERA5 data.
    """
    print("="*70)
    print("LOADING MULTI-YEAR DATASET")
    print("="*70)
    
    datasets = []
    for year in years:
        filepath = Path(data_dir) / f"surf_data_{year}.nc"
        
        if not filepath.exists():
            print(f"‚ö†Ô∏è  Warning: {filepath} not found, skipping...")
            continue
        
        try:
            ds = xr.open_dataset(filepath)
            
            # Verify required variables exist
            missing_vars = set(INPUT_FEATURES + TARGET_FEATURES) - set(ds.data_vars)
            if missing_vars:
                print(f"‚ö†Ô∏è  {year}: Missing variables {missing_vars}, skipping...")
                continue
            
            # Add year attribute for tracking
            ds.attrs['year'] = year
            datasets.append(ds)
            
            print(f"‚úì Loaded {year}: {len(ds.valid_time)} timesteps, "
                  f"shape={ds[INPUT_FEATURES[0]].shape}")
            
        except Exception as e:
            print(f"‚ùå Error loading {year}: {e}")
            continue
    
    if not datasets:
        raise ValueError("No valid datasets loaded!")
    
    # Concatenate along time dimension
    print("\nüîó Concatenating datasets...")
    combined_ds = xr.concat(datasets, dim='valid_time')
    
    # Sort by time (important!)
    combined_ds = combined_ds.sortby('valid_time')
    
    print(f"\n‚úì Combined dataset:")
    print(f"   Total timesteps: {len(combined_ds.valid_time)}")
    print(f"   Time range: {combined_ds.valid_time.values[0]} to {combined_ds.valid_time.values[-1]}")
    print(f"   Spatial shape: {len(combined_ds.latitude)} √ó {len(combined_ds.longitude)}")
    print("="*70)
    
    return combined_ds


def split_by_year(ds, train_years, val_years):
    """
    Split dataset by year (better than random split for time series).
    """
    train_mask = ds.valid_time.dt.year.isin(train_years)
    val_mask = ds.valid_time.dt.year.isin(val_years)
    
    ds_train = ds.isel(valid_time=train_mask)
    ds_val = ds.isel(valid_time=val_mask)
    
    print(f"\nüìä Data Split:")
    print(f"   Training years: {train_years} ‚Üí {len(ds_train.valid_time)} samples")
    print(f"   Validation years: {val_years} ‚Üí {len(ds_val.valid_time)} samples")
    print(f"   Split ratio: {len(ds_train.valid_time)/(len(ds_train.valid_time)+len(ds_val.valid_time)):.1%} train")
    
    return ds_train, ds_val

In [3]:
# ENHANCED PREPROCESSING FOR MULTI-YEAR DATA
# ======================================================================
def create_ocean_mask(ds, threshold=0.5):
    """Create persistent ocean mask across all years."""
    shts_data = ds["shts"].values
    nan_ratio = np.isnan(shts_data).sum(axis=0) / shts_data.shape[0]
    ocean_mask = nan_ratio < threshold
    
    valid_points = ocean_mask.sum()
    print(f"\nüåä Ocean Mask: {valid_points}/{ocean_mask.size} points ({100*valid_points/ocean_mask.size:.1f}%)")
    return ocean_mask


def engineer_frames(ds, ocean_mask):
    """Extract and prepare spatial frames."""
    X = ds[INPUT_FEATURES].to_array(dim="channel").transpose(
        "valid_time", "latitude", "longitude", "channel"
    ).values
    
    y = ds[TARGET_FEATURES].to_array(dim="channel").transpose(
        "valid_time", "latitude", "longitude", "channel"
    ).values
    
    # Apply ocean mask
    mask_3d_X = np.broadcast_to(ocean_mask[..., np.newaxis], X.shape[1:])
    mask_3d_y = np.broadcast_to(ocean_mask[..., np.newaxis], y.shape[1:])
    
    X[:, ~mask_3d_X] = 0.0
    y[:, ~mask_3d_y] = 0.0
    
    # Fill NaNs with channel-wise means
    for i in range(X.shape[-1]):
        channel_mean = np.nanmean(X[..., i])
        X[..., i] = np.nan_to_num(X[..., i], nan=channel_mean)
    
    for i in range(y.shape[-1]):
        channel_mean = np.nanmean(y[..., i])
        y[..., i] = np.nan_to_num(y[..., i], nan=channel_mean)
    
    return X, y


def fit_scalers(X_train, y_train):
    """
    Fit scalers ONLY on training data to prevent data leakage.
    """
    print("\nüìè Fitting scalers on training data...")
    
    scalers_X = []
    for i in range(X_train.shape[-1]):
        scaler = StandardScaler()
        channel = X_train[..., i].reshape(-1, 1)
        scaler.fit(channel)
        scalers_X.append(scaler)
    
    scalers_y = []
    for i in range(y_train.shape[-1]):
        scaler = StandardScaler()
        channel = y_train[..., i].reshape(-1, 1)
        scaler.fit(channel)
        scalers_y.append(scaler)
    
    return scalers_X, scalers_y


def apply_scalers(X, y, scalers_X, scalers_y):
    """Apply pre-fitted scalers."""
    X_scaled = np.zeros_like(X)
    y_scaled = np.zeros_like(y)
    
    for i in range(X.shape[-1]):
        channel = X[..., i].reshape(-1, 1)
        X_scaled[..., i] = scalers_X[i].transform(channel).reshape(X[..., i].shape)
    
    for i in range(y.shape[-1]):
        channel = y[..., i].reshape(-1, 1)
        y_scaled[..., i] = scalers_y[i].transform(channel).reshape(y[..., i].shape)
    
    return X_scaled, y_scaled


def create_sequences(X, y, lookback, lookahead):
    """Generate sequences with progress tracking."""
    total_sequences = len(X) - lookback - lookahead + 1
    print(f"\nüîÑ Creating {total_sequences} sequences...")
    
    X_seq = np.zeros((total_sequences, lookback, *X.shape[1:]), dtype=np.float32)
    y_seq = np.zeros((total_sequences, *y.shape[1:]), dtype=np.float32)
    
    for i in range(total_sequences):
        X_seq[i] = X[i:i + lookback]
        y_seq[i] = y[i + lookback + lookahead - 1]
        
        if (i + 1) % 500 == 0:
            print(f"   Progress: {i+1}/{total_sequences} ({100*(i+1)/total_sequences:.1f}%)")
    
    return X_seq, y_seq

In [4]:
# SAVE/LOAD PREPROCESSED DATA
# ======================================================================
def save_preprocessed_data(X_train, y_train, X_val, y_val, scalers_X, scalers_y, filename="preprocessed_data.pkl"):
    """Save preprocessed data to avoid reprocessing."""
    print(f"\nüíæ Saving preprocessed data to {filename}...")
    
    data = {
        'X_train': X_train,
        'y_train': y_train,
        'X_val': X_val,
        'y_val': y_val,
        'scalers_X': scalers_X,
        'scalers_y': scalers_y
    }
    
    with open(filename, 'wb') as f:
        pickle.dump(data, f, protocol=4)
    
    print(f"‚úì Saved {filename} ({Path(filename).stat().st_size / 1e6:.1f} MB)")


def load_preprocessed_data(filename="preprocessed_data.pkl"):
    """Load preprocessed data."""
    print(f"\nüìÇ Loading preprocessed data from {filename}...")
    
    with open(filename, 'rb') as f:
        data = pickle.load(f)
    
    print("‚úì Loaded preprocessed data")
    return data

In [5]:
# IMPROVED MODEL ARCHITECTURE
# ======================================================================
def build_enhanced_model(input_shape, output_channels):

    inputs = Input(shape=input_shape)
    
    # Encoder with increasing capacity
    x = ConvLSTM2D(
        128, (3, 3),
        padding='same',
        return_sequences=True,
        kernel_regularizer=tf.keras.regularizers.l2(1e-5)
    )(inputs)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)
    
    x = ConvLSTM2D(
        64, (3, 3),
        padding='same',
        return_sequences=True,
        kernel_regularizer=tf.keras.regularizers.l2(1e-5)
    )(x)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)
    
    x = ConvLSTM2D(
        32, (3, 3),
        padding='same',
        return_sequences=True
    )(x)
    x = BatchNormalization()(x)
    
    # Conv3D prediction layer
    x = Conv3D(
        output_channels,
        kernel_size=(3, 3, 3),
        padding='same',
        activation='linear'
    )(x)
    
    # Extract last timestep
    outputs = x[:, -1, :, :, :]
    
    model = Model(inputs=inputs, outputs=outputs)
    
    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss='mse',
        metrics=['mae']
    )
    
    return model

In [6]:
# EVALUATION
# ======================================================================
def evaluate_model(model, X_val, y_val, scalers_y, target_names):
    """Comprehensive evaluation."""
    print("\n" + "="*70)
    print("MODEL EVALUATION")
    print("="*70)
    
    y_pred_scaled = model.predict(X_val, batch_size=32, verbose=1)
    
    # Inverse transform
    y_true = np.zeros_like(y_val)
    y_pred = np.zeros_like(y_pred_scaled)
    
    for i in range(y_val.shape[-1]):
        y_true[..., i] = scalers_y[i].inverse_transform(
            y_val[..., i].reshape(-1, 1)
        ).reshape(y_val[..., i].shape)
        y_pred[..., i] = scalers_y[i].inverse_transform(
            y_pred_scaled[..., i].reshape(-1, 1)
        ).reshape(y_pred_scaled[..., i].shape)
    
    # Overall metrics
    mae = mean_absolute_error(y_true.flatten(), y_pred.flatten())
    rmse = np.sqrt(mean_squared_error(y_true.flatten(), y_pred.flatten()))
    
    print(f"\nOverall Spatial Metrics:")
    print(f"   MAE:  {mae:.4f}")
    print(f"   RMSE: {rmse:.4f}")
    
    # Per-variable metrics
    print(f"\nPer-Variable Metrics:")
    for i, var in enumerate(target_names):
        mae_var = mean_absolute_error(y_true[..., i].flatten(), y_pred[..., i].flatten())
        rmse_var = np.sqrt(mean_squared_error(y_true[..., i].flatten(), y_pred[..., i].flatten()))
        r2_var = r2_score(y_true[..., i].flatten(), y_pred[..., i].flatten())
        
        print(f"   {var:>6s}: MAE={mae_var:.3f}, RMSE={rmse_var:.3f}, R¬≤={r2_var:.3f}")
    
    print("="*70)
    
    return y_true, y_pred


def find_nearest_ocean_point(ds, lat, lon):
    """Find nearest valid ocean point."""
    nearest = ds.sel(latitude=lat, longitude=lon, method="nearest")
    
    if not np.isnan(nearest["shts"].isel(valid_time=0)):
        return float(nearest.latitude), float(nearest.longitude)
    
    shts_data = ds["shts"].isel(valid_time=0).stack(
        point=("latitude", "longitude")
    ).dropna("point")
    
    R = 6371
    lat1, lon1 = np.radians(lat), np.radians(lon)
    lat2 = np.radians(shts_data.latitude)
    lon2 = np.radians(shts_data.longitude)
    
    dlat = (lat2 - lat1) / 2
    dlon = (lon2 - lon1) / 2
    
    a = np.sin(dlat)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon)**2
    distance = 2 * R * np.arcsin(np.sqrt(a))
    
    closest = shts_data.isel(point=distance.argmin())
    return float(closest.latitude), float(closest.longitude)

In [7]:
# MAIN PIPELINE
def main():
    print("üåä MULTI-YEAR CONVLSTM WAVE FORECASTING")
    print("="*70)
    
    # Check if preprocessed data exists
    preprocessed_file = "preprocessed_multiyear.pkl"
    
    if Path(preprocessed_file).exists():
        print(f"\n‚úì Found {preprocessed_file}")
        user_input = input("Load preprocessed data? (y/n): ")
        
        if user_input.lower() == 'y':
            data = load_preprocessed_data(preprocessed_file)
            X_train, y_train = data['X_train'], data['y_train']
            X_val, y_val = data['X_val'], data['y_val']
            scalers_X, scalers_y = data['scalers_X'], data['scalers_y']
            
            print(f"‚úì Loaded: Train={X_train.shape}, Val={X_val.shape}")
            
            # Still need to load ds for metadata
            ds = load_multiyear_dataset(DATA_YEARS, DATA_DIR)
            
            skip_preprocessing = True
        else:
            skip_preprocessing = False
    else:
        skip_preprocessing = False
    
    if not skip_preprocessing:
        # Step 1: Load all years
        ds = load_multiyear_dataset(DATA_YEARS, DATA_DIR)
        
        # Step 2: Split by year
        ds_train, ds_val = split_by_year(ds, TRAIN_YEARS, [VAL_YEAR])
        
        # Step 3: Create ocean mask (from combined data)
        ocean_mask = create_ocean_mask(ds)
        
        # Step 4: Engineer features
        print("\nüìä Engineering features...")
        X_train, y_train = engineer_frames(ds_train, ocean_mask)
        X_val, y_val = engineer_frames(ds_val, ocean_mask)
        
        print(f"   Train: X={X_train.shape}, y={y_train.shape}")
        print(f"   Val:   X={X_val.shape}, y={y_val.shape}")
        
        # Step 5: Fit scalers on training data only
        scalers_X, scalers_y = fit_scalers(X_train, y_train)
        
        # Step 6: Apply scalers
        print("\nüîß Applying scalers...")
        X_train, y_train = apply_scalers(X_train, y_train, scalers_X, scalers_y)
        X_val, y_val = apply_scalers(X_val, y_val, scalers_X, scalers_y)
        
        # Step 7: Create sequences
        X_train, y_train = create_sequences(X_train, y_train, LOOKBACK_HOURS, LOOKAHEAD_HOURS)
        X_val, y_val = create_sequences(X_val, y_val, LOOKBACK_HOURS, LOOKAHEAD_HOURS)
        
        print(f"\n‚úì Final shapes:")
        print(f"   Train: X={X_train.shape}, y={y_train.shape}")
        print(f"   Val:   X={X_val.shape}, y={y_val.shape}")
        
        # Save for future runs
        save_preprocessed_data(X_train, y_train, X_val, y_val, scalers_X, scalers_y, preprocessed_file)
    
    # Step 8: Build model
    print("\nüèóÔ∏è Building enhanced model...")
    model = build_enhanced_model(X_train.shape[1:], y_train.shape[-1])
    model.summary()
    
    # Step 9: Callbacks
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=3,
            restore_best_weights=True,
            verbose=1
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.1,
            patience=2,
            min_lr=1e-7,
            verbose=1
        )
    ]
    
    # Step 10: Train
    print(f"\nüöÄ Training on {len(X_train)} samples...")
    print(f"   Validation on {len(X_val)} samples from {VAL_YEAR}")
    
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=callbacks,
        verbose=1
    )
    
    # Step 11: Evaluate
    evaluate_model(model, X_val, y_val, scalers_y, TARGET_FEATURES)
    
    # Step 12: Forecast
    print("\nüîÆ Making forecast...")
    buoy_lat, buoy_lon = find_nearest_ocean_point(ds, SPOT_LAT, SPOT_LON)
    
    # Get most recent sequence from validation set
    last_seq = X_val[-1:] 
    forecast_scaled = model.predict(last_seq, verbose=0)[0]
    
    # Inverse transform
    forecast = np.zeros_like(forecast_scaled)
    for i in range(forecast_scaled.shape[-1]):
        forecast[..., i] = scalers_y[i].inverse_transform(
            forecast_scaled[..., i].reshape(-1, 1)
        ).reshape(forecast_scaled[..., i].shape)
    
    # Extract buoy location
    lats = ds.latitude.values
    lons = ds.longitude.values
    lat_idx = np.argmin(np.abs(lats - buoy_lat))
    lon_idx = np.argmin(np.abs(lons - buoy_lon))
    
    buoy_forecast = forecast[lat_idx, lon_idx, :]
    
    print("\n" + "="*70)
    print(f"FORECAST at ({buoy_lat:.2f}¬∞N, {buoy_lon:.2f}¬∞E)")
    print("="*70)
    print(f"Swell Height:    {buoy_forecast[0]:.2f} m")
    print(f"Swell Period:    {buoy_forecast[1]:.2f} s")
    print(f"Swell Direction: {buoy_forecast[2]:.1f}¬∞")
    print("="*70)
    
    print("\n‚úì Training complete! Model saved as 'best_multiyear_model.keras'")


if __name__ == "__main__":
    main()

üåä MULTI-YEAR CONVLSTM WAVE FORECASTING

‚úì Found preprocessed_multiyear.pkl


Load preprocessed data? (y/n):  y



üìÇ Loading preprocessed data from preprocessed_multiyear.pkl...
‚úì Loaded preprocessed data
‚úì Loaded: Train=(5825, 19, 21, 21, 6), Val=(1445, 19, 21, 21, 6)
LOADING MULTI-YEAR DATASET
‚úì Loaded 2020: 1464 timesteps, shape=(1464, 21, 21)
‚úì Loaded 2021: 1460 timesteps, shape=(1460, 21, 21)
‚úì Loaded 2022: 1460 timesteps, shape=(1460, 21, 21)
‚úì Loaded 2023: 1460 timesteps, shape=(1460, 21, 21)
‚úì Loaded 2024: 1464 timesteps, shape=(1464, 21, 21)

üîó Concatenating datasets...

‚úì Combined dataset:
   Total timesteps: 7308
   Time range: 2020-01-01T00:00:00.000000000 to 2024-12-31T18:00:00.000000000
   Spatial shape: 21 √ó 21

üèóÔ∏è Building enhanced model...
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 19, 21, 21, 6)]   0         
                                                                 
 conv_lstm2d (ConvLSTM2D)    (None, 19, 

In [9]:
# After evaluate_model(...)
model.save("best_multiyear_model_final.keras")
print("\n‚úì Final model (with best weights) saved as 'best_multiyear_model_final.keras'")

NameError: name 'model' is not defined

In [11]:
import matplotlib.pyplot as plt

def plot_accuracy_metrics(history):
    plt.figure(figsize=(10, 5))
    plt.plot(history.history['mae'], label='Train MAE')
    plt.plot(history.history['val_mae'], label='Val MAE')
    plt.title('Model Accuracy (MAE)')
    plt.xlabel('Epoch')
    plt.ylabel('Mean Absolute Error')
    plt.legend()
    plt.grid(True)
    plt.show()

# Example use after model.fit():
plot_accuracy_metrics(history)


NameError: name 'history' is not defined

In [None]:
def plot_performance(history):
    plt.figure(figsize=(10, 5))
    plt.plot(history.history['loss'], label='Train Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title('Model Performance (Loss)')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

# Example use:
# plot_performance(history)


In [None]:
import time
import psutil

start_time = time.time()
process = psutil.Process()

# Load or preprocess (simulate)
_ = load_multiyear_dataset(DATA_YEARS, DATA_DIR)

elapsed = time.time() - start_time
memory_used = process.memory_info().rss / (1024**3)
print(f"Elapsed Time: {elapsed:.2f} s, Memory Used: {memory_used:.2f} GB")


In [None]:
plt.scatter(y_true[...,0].flatten(), y_pred[...,0].flatten(), alpha=0.3)
plt.xlabel("True Swell Height (m)")
plt.ylabel("Predicted Swell Height (m)")
plt.title("True vs Predicted Swell Height")
plt.grid(True)
plt.show()


In [12]:
whos

Variable                   Type        Data/Info
------------------------------------------------
Adam                       type        <class 'keras.src.optimizers.adam.Adam'>
BATCH_SIZE                 int         16
BatchNormalization         type        <class 'keras.src.layers.<...>tion.BatchNormalization'>
Conv3D                     type        <class 'keras.src.layers.<...>olutional.conv3d.Conv3D'>
ConvLSTM2D                 type        <class 'keras.src.layers.<...>.conv_lstm2d.ConvLSTM2D'>
DATA_DIR                   str         ../
DATA_YEARS                 list        n=5
Dropout                    type        <class 'keras.src.layers.<...>ization.dropout.Dropout'>
EPOCHS                     int         10
EarlyStopping              type        <class 'keras.src.callbacks.EarlyStopping'>
INPUT_FEATURES             list        n=6
Input                      function    <function Input at 0x0000023AFB431AF0>
LEARNING_RATE              float       0.001
LOOKAHEAD_HOURS        