In [1]:
import xarray as xr
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    ConvLSTM2D, Conv3D, BatchNormalization, Input, Dropout
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
import tensorflow as tf
import pickle
from pathlib import Path

# Years whose `surf_data_{year}.nc` files are opened by load_multiyear_dataset.
# NOTE(review): 2025 is loaded but is in neither TRAIN_YEARS nor VAL_YEAR —
# presumably reserved as a hold-out year; confirm.
DATA_YEARS = [2020, 2021, 2022, 2023, 2024, 2025]
# Directory that holds the yearly NetCDF files (relative path).
DATA_DIR = "../"
# Latitude / longitude of the spot the final point forecast is reported for.
SPOT_LAT, SPOT_LON = 6.8399, 81.8396

# Sequence geometry handed to create_sequences: number of input frames per
# sample and offset of the target frame after the last input frame.
# NOTE(review): despite the "_HOURS" suffix these are used as counts of
# timesteps, not hours. The recorded output (1464 steps for 2020, last stamp
# 18:00) suggests 6-hourly data, i.e. 19 steps would be ~114 h — confirm.
LOOKBACK_HOURS = 19
LOOKAHEAD_HOURS = 1
# Training hyper-parameters used by model.fit / Adam.
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 1e-3

# Year-based split: all timesteps of TRAIN_YEARS train the model, VAL_YEAR validates it.
TRAIN_YEARS = [2020, 2021, 2022, 2023]
VAL_YEAR = 2024

# Dataset variables stacked (in this order) as the model's input channels.
INPUT_FEATURES = ["u10", "v10", "msl", "shts", "mpts", "mdts"]
# Variables the model predicts; main() reports them as swell height (m),
# swell period (s) and swell direction (degrees), in this order.
TARGET_FEATURES = ["shts", "mpts", "mdts"]

In [2]:
def load_multiyear_dataset(years, data_dir="."):
    """Open one NetCDF file per year and concatenate them along ``valid_time``.

    Args:
        years: Iterable of years; each maps to ``<data_dir>/surf_data_<year>.nc``.
        data_dir: Directory containing the yearly files.

    Returns:
        The combined dataset, sorted by ``valid_time``.

    Raises:
        ValueError: If none of the requested files could be opened.
    """
    datasets = []
    for year in years:
        filepath = Path(data_dir) / f"surf_data_{year}.nc"
        try:
            ds = xr.open_dataset(filepath)
            ds.attrs['year'] = year
            datasets.append(ds)
        except Exception as e:
            # Deliberately best-effort: a missing or unreadable year is
            # reported and skipped so the remaining years can still be used.
            print(f"Error loading {year}: {e}")
            continue
    if not datasets:
        raise ValueError("No valid datasets loaded!")

    combined_ds = xr.concat(datasets, dim='valid_time')
    # Sort so the result is chronological regardless of the order of `years`.
    combined_ds = combined_ds.sortby('valid_time')

    time_values = combined_ds.valid_time.values
    print("\n Combined dataset:")
    print(f"   Total timesteps: {len(combined_ds.valid_time)}")
    print(f"   Time range: {time_values[0]} to {time_values[-1]}")
    # The multiplication sign was previously mis-encoded and printed as "Ã—".
    print(f"   Spatial shape: {len(combined_ds.latitude)} × {len(combined_ds.longitude)}")

    return combined_ds

def split_by_year(ds, train_years, val_years):
    """Split a dataset into training and validation parts by calendar year.

    Args:
        ds: Dataset with a ``valid_time`` datetime coordinate.
        train_years: Years whose timesteps go into the training split.
        val_years: Years whose timesteps go into the validation split.

    Returns:
        Tuple ``(ds_train, ds_val)``.
    """
    years = ds.valid_time.dt.year
    ds_train = ds.isel(valid_time=years.isin(train_years))
    ds_val = ds.isel(valid_time=years.isin(val_years))

    n_train = len(ds_train.valid_time)
    n_val = len(ds_val.valid_time)
    n_total = n_train + n_val

    print("\n Data Split:")
    # The arrows were previously mis-encoded and printed as "â†’".
    print(f"   Training years: {train_years} → {n_train} samples")
    print(f"   Validation years: {val_years} → {n_val} samples")
    if n_total:
        print(f"   Split ratio: {n_train / n_total:.1%} train")
    else:
        # No timestep matched either list: report it instead of dividing by zero.
        print("   Split ratio: n/a (no samples matched either split)")

    return ds_train, ds_val

In [None]:
def create_ocean_mask(ds, threshold=0.5):
    """Flag grid cells whose ``shts`` series is mostly non-missing.

    A cell counts as ocean when the fraction of NaN values along the first
    (time) axis of ``ds["shts"]`` is strictly below ``threshold``.

    Returns:
        Boolean array with the shape of one ``shts`` frame.
    """
    swell_height = ds["shts"].values
    n_steps = swell_height.shape[0]
    missing_per_cell = np.isnan(swell_height).sum(axis=0)
    ocean_mask = (missing_per_cell / n_steps) < threshold

    n_ocean = ocean_mask.sum()
    n_cells = ocean_mask.size
    print(f"\n Ocean Mask: {n_ocean}/{n_cells} points ({100*n_ocean/n_cells:.1f}%)")
    return ocean_mask


def engineer_frames(ds, ocean_mask):
    """Stack dataset variables into dense ``(time, lat, lon, channel)`` arrays.

    Land cells (``ocean_mask`` False) are set to 0.0 in every channel, and
    remaining NaNs are replaced by the channel's mean.

    The fill mean is taken over ocean cells only. Previously it was computed
    after the land cells had been zeroed, so the zeros dragged the fill value
    towards 0. A channel with no valid ocean value at all falls back to 0.0
    instead of leaving NaNs in the output.

    NOTE(review): the fill value is derived from the dataset passed in, so the
    train and validation calls use different fill values — confirm that is
    intended.

    Args:
        ds: Dataset containing every variable in INPUT_FEATURES / TARGET_FEATURES
            with dimensions ``valid_time``, ``latitude``, ``longitude``.
        ocean_mask: Boolean ``(latitude, longitude)`` array, True for ocean cells.

    Returns:
        Tuple ``(X, y)``: input frames and target frames.
    """
    def _stack(features):
        # Channel order follows the order of `features`.
        return ds[features].to_array(dim="channel").transpose(
            "valid_time", "latitude", "longitude", "channel"
        ).values

    X = _stack(INPUT_FEATURES)
    y = _stack(TARGET_FEATURES)
    land_mask = ~ocean_mask

    for frames in (X, y):
        for i in range(frames.shape[-1]):
            channel = frames[..., i]
            ocean_values = channel[:, ocean_mask]
            has_valid = ocean_values.size > 0 and not np.isnan(ocean_values).all()
            fill_value = np.nanmean(ocean_values) if has_valid else 0.0

            filled = np.nan_to_num(channel, nan=fill_value)
            filled[:, land_mask] = 0.0
            frames[..., i] = filled

    return X, y

In [3]:
def fit_scalers(X_train, y_train):
    """Fit one ``StandardScaler`` per channel (last axis) of each array.

    Every channel is flattened to a single column before fitting, so each
    scaler is fitted on all values of its channel.

    Returns:
        Tuple ``(scalers_X, scalers_y)`` of lists, one scaler per channel.
    """
    def _fit_per_channel(frames):
        return [
            StandardScaler().fit(frames[..., idx].reshape(-1, 1))
            for idx in range(frames.shape[-1])
        ]

    scalers_X = _fit_per_channel(X_train)
    scalers_y = _fit_per_channel(y_train)
    return scalers_X, scalers_y


def apply_scalers(X, y, scalers_X, scalers_y):
    """Transform each channel of ``X`` and ``y`` with its matching scaler.

    Channel ``i`` (last axis) is flattened to a column, passed through
    ``scalers[i].transform`` and reshaped back. The inputs are not modified.

    Returns:
        Tuple ``(X_scaled, y_scaled)`` with the shapes and dtypes of the inputs.
    """
    def _transform(frames, scalers):
        scaled = np.zeros_like(frames)
        for idx in range(frames.shape[-1]):
            plane = frames[..., idx]
            column = plane.reshape(-1, 1)
            scaled[..., idx] = scalers[idx].transform(column).reshape(plane.shape)
        return scaled

    return _transform(X, scalers_X), _transform(y, scalers_y)

In [None]:
def create_sequences(X, y, lookback, lookahead):
    """Build sliding-window samples from time-ordered frames.

    Sample ``i`` consists of the ``lookback`` consecutive frames
    ``X[i : i + lookback]`` and the single target frame
    ``y[i + lookback + lookahead - 1]``.

    Args:
        X: Input frames, time on the first axis.
        y: Target frames, time on the first axis.
        lookback: Number of input timesteps per sample.
        lookahead: Offset (in timesteps) of the target after the last input frame.

    Returns:
        Tuple ``(X_seq, y_seq)`` of float32 arrays with shapes
        ``(n, lookback, *X.shape[1:])`` and ``(n, *y.shape[1:])``. When the
        series is shorter than ``lookback + lookahead`` both are empty (n = 0);
        previously this case failed with numpy's "negative dimensions" error.
    """
    total_sequences = max(0, len(X) - lookback - lookahead + 1)
    print(f"\n Creating {total_sequences} sequences...")

    X_seq = np.zeros((total_sequences, lookback, *X.shape[1:]), dtype=np.float32)
    y_seq = np.zeros((total_sequences, *y.shape[1:]), dtype=np.float32)

    # Index of the target frame relative to the start of its window.
    target_offset = lookback + lookahead - 1
    for i in range(total_sequences):
        X_seq[i] = X[i:i + lookback]
        y_seq[i] = y[i + target_offset]

    return X_seq, y_seq

In [4]:
def save_preprocessed_data(X_train, y_train, X_val, y_val, scalers_X, scalers_y, filename="preprocessed_data.pkl"):
    """Pickle the prepared arrays and fitted scalers into a single file.

    The payload is a dict keyed by the argument names (``X_train`` ...
    ``scalers_y``), written with pickle protocol 4.
    """
    keys = ('X_train', 'y_train', 'X_val', 'y_val', 'scalers_X', 'scalers_y')
    values = (X_train, y_train, X_val, y_val, scalers_X, scalers_y)
    bundle = dict(zip(keys, values))

    with open(filename, 'wb') as handle:
        pickle.dump(bundle, handle, protocol=4)

def load_preprocessed_data(filename="preprocessed_data.pkl"):
    """Read back a payload pickled by ``save_preprocessed_data``.

    Returns:
        Whatever object the pickle file contains.
    """
    print(f"\n Loading preprocessed data from {filename}...")
    # NOTE: unpickling can execute arbitrary code — only open files you trust.
    with open(filename, 'rb') as handle:
        payload = pickle.load(handle)
    print(" Loaded preprocessed data")
    return payload

### Model Architecture

In [5]:
def build_enhanced_model(input_shape, output_channels):
    """Assemble and compile the stacked ConvLSTM forecasting network.

    Three ConvLSTM2D layers (128, 64, 32 filters; 3x3 kernels; sequences
    returned), each followed by batch normalisation; the first two also carry
    L2 kernel regularisation and dropout. A linear Conv3D maps to
    ``output_channels`` and only the last timestep is kept as the output.

    Args:
        input_shape: Shape of one sample, excluding the batch axis.
        output_channels: Number of channels in the predicted frame.

    Returns:
        Model compiled with Adam (LEARNING_RATE), MSE loss and an MAE metric.
    """
    # (filters, L2 strength or None, dropout rate or None) per recurrent layer
    recurrent_stack = (
        (128, 1e-5, 0.2),
        (64, 1e-5, 0.2),
        (32, None, None),
    )

    inputs = Input(shape=input_shape)
    features = inputs
    for n_filters, l2_strength, drop_rate in recurrent_stack:
        regularizer = None
        if l2_strength is not None:
            regularizer = tf.keras.regularizers.l2(l2_strength)
        features = ConvLSTM2D(
            n_filters, (3, 3),
            padding='same',
            return_sequences=True,
            kernel_regularizer=regularizer
        )(features)
        features = BatchNormalization()(features)
        if drop_rate is not None:
            features = Dropout(drop_rate)(features)

    per_step = Conv3D(
        output_channels,
        kernel_size=(3, 3, 3),
        padding='same',
        activation='linear'
    )(features)

    # Keep only the final timestep of the sequence as the forecast frame.
    outputs = per_step[:, -1, :, :, :]

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss='mse',
        metrics=['mae']
    )
    return model

In [6]:
def evaluate_model(model, X_val, y_val, scalers_y, target_names):
    """Predict on the validation set and print error metrics in original units.

    Predictions and targets are mapped back through the per-channel target
    scalers before MAE / RMSE / R² are computed.

    NOTE(review): the "overall" MAE/RMSE pool every channel, and all metrics
    include every grid cell of the frames — confirm that mixing variables and
    including masked cells is intended.

    Args:
        model: Fitted model exposing ``predict``.
        X_val: Scaled validation inputs.
        y_val: Scaled validation targets, channels on the last axis.
        scalers_y: One fitted scaler per target channel.
        target_names: Names of the target channels, in channel order.

    Returns:
        Tuple ``(y_true, y_pred)`` in original units.
    """
    def _unscale(scaled):
        # Invert the per-channel scaling, keeping the array's shape and dtype.
        restored = np.zeros_like(scaled)
        for i in range(y_val.shape[-1]):
            plane = scaled[..., i]
            restored[..., i] = scalers_y[i].inverse_transform(
                plane.reshape(-1, 1)
            ).reshape(plane.shape)
        return restored

    print("\n" + "="*70)
    print("MODEL EVALUATION")
    print("="*70)

    y_pred_scaled = model.predict(X_val, batch_size=32, verbose=1)
    y_true = _unscale(y_val)
    y_pred = _unscale(y_pred_scaled)

    flat_true = y_true.flatten()
    flat_pred = y_pred.flatten()
    mae = mean_absolute_error(flat_true, flat_pred)
    rmse = np.sqrt(mean_squared_error(flat_true, flat_pred))

    print("\nOverall Spatial Metrics:")
    print(f"   MAE:  {mae:.4f}")
    print(f"   RMSE: {rmse:.4f}")

    for i, var in enumerate(target_names):
        var_true = y_true[..., i].flatten()
        var_pred = y_pred[..., i].flatten()
        mae_var = mean_absolute_error(var_true, var_pred)
        rmse_var = np.sqrt(mean_squared_error(var_true, var_pred))
        r2_var = r2_score(var_true, var_pred)

        # "R²" was previously mis-encoded and printed as "RÂ²".
        print(f"   {var:>6s}: MAE={mae_var:.3f}, RMSE={rmse_var:.3f}, R²={r2_var:.3f}")

    return y_true, y_pred

In [None]:
def find_nearest_ocean_point(ds, lat, lon):
    """Return the grid point closest to ``(lat, lon)`` that has wave data.

    The nearest grid cell is used directly when its ``shts`` value at the
    first timestep is not NaN. Otherwise every cell with a non-NaN ``shts`` at
    the first timestep is ranked by great-circle (haversine) distance and the
    closest one is returned.

    Args:
        ds: Dataset with ``latitude`` / ``longitude`` coordinates and ``shts``.
        lat: Target latitude in degrees.
        lon: Target longitude in degrees.

    Returns:
        Tuple ``(latitude, longitude)`` of the selected grid point as floats.

    Raises:
        ValueError: If no grid cell has a valid ``shts`` at the first timestep.
    """
    nearest = ds.sel(latitude=lat, longitude=lon, method="nearest")

    if not np.isnan(nearest["shts"].isel(valid_time=0)):
        return float(nearest.latitude), float(nearest.longitude)

    # Candidate cells: every (lat, lon) pair with a valid first-timestep value.
    shts_data = ds["shts"].isel(valid_time=0).stack(
        point=("latitude", "longitude")
    ).dropna("point")

    R = 6371  # mean Earth radius in km
    lat1, lon1 = np.radians(lat), np.radians(lon)
    lat2 = np.radians(shts_data.latitude)
    lon2 = np.radians(shts_data.longitude)

    # Half-angle differences, as required by the haversine formula.
    dlat = (lat2 - lat1) / 2
    dlon = (lon2 - lon1) / 2

    a = np.sin(dlat)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon)**2
    # Rounding can push `a` marginally outside [0, 1], which would make
    # sqrt/arcsin return NaN; clamp it first.
    a = np.minimum(1.0, np.maximum(0.0, a))
    distance = 2 * R * np.arcsin(np.sqrt(a))

    if distance.size == 0:
        raise ValueError(
            "No ocean grid point with valid 'shts' at the first timestep"
        )

    # A plain int indexer keeps the selection scalar.
    closest = shts_data.isel(point=int(distance.argmin()))
    return float(closest.latitude), float(closest.longitude)

In [7]:
def main():
    """Run the full pipeline: prepare data, train, evaluate, forecast, save.

    Steps:
      1. Reuse the cached preprocessed arrays when the cache file exists and
         the user answers "y"; otherwise load, split, mask, scale and window
         the raw data and write the cache.
      2. Build and train the ConvLSTM model with early stopping and
         learning-rate reduction on ``val_loss``.
      3. Print validation metrics and a point forecast for the grid cell
         nearest to (SPOT_LAT, SPOT_LON), taken from the last validation window.
      4. Save the model (``convlstm_final.keras``) and a metadata pickle
         (``model_metadata.pkl``).
    """
    preprocessed_file = "preprocessed_multiyear.pkl"

    use_cache = False
    if Path(preprocessed_file).exists():
        print(f"\nFound {preprocessed_file}")
        user_input = input("Load preprocessed data? (y/n): ")
        # Tolerate surrounding whitespace and upper case in the answer.
        use_cache = user_input.strip().lower() == 'y'

    if use_cache:
        data = load_preprocessed_data(preprocessed_file)
        X_train, y_train = data['X_train'], data['y_train']
        X_val, y_val = data['X_val'], data['y_val']
        scalers_X, scalers_y = data['scalers_X'], data['scalers_y']
        print(f"Loaded: Train={X_train.shape}, Val={X_val.shape}")
        # The raw dataset is still needed for the grid coordinates below.
        ds = load_multiyear_dataset(DATA_YEARS, DATA_DIR)
    else:
        ds = load_multiyear_dataset(DATA_YEARS, DATA_DIR)
        ds_train, ds_val = split_by_year(ds, TRAIN_YEARS, [VAL_YEAR])
        ocean_mask = create_ocean_mask(ds)

        print("\nEngineering features...")
        X_train, y_train = engineer_frames(ds_train, ocean_mask)
        X_val, y_val = engineer_frames(ds_val, ocean_mask)

        print(f"   Train: X={X_train.shape}, y={y_train.shape}")
        print(f"   Val:   X={X_val.shape}, y={y_val.shape}")

        # Scalers are fitted on the training split only and reused for validation.
        scalers_X, scalers_y = fit_scalers(X_train, y_train)
        X_train, y_train = apply_scalers(X_train, y_train, scalers_X, scalers_y)
        X_val, y_val = apply_scalers(X_val, y_val, scalers_X, scalers_y)

        X_train, y_train = create_sequences(X_train, y_train, LOOKBACK_HOURS, LOOKAHEAD_HOURS)
        X_val, y_val = create_sequences(X_val, y_val, LOOKBACK_HOURS, LOOKAHEAD_HOURS)
        save_preprocessed_data(X_train, y_train, X_val, y_val, scalers_X, scalers_y, preprocessed_file)

    model = build_enhanced_model(X_train.shape[1:], y_train.shape[-1])
    model.summary()
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=3,
            restore_best_weights=True,
            verbose=1
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.1,
            patience=2,
            min_lr=1e-7,
            verbose=1
        )
    ]

    print(f"\n Training on {len(X_train)} samples...")
    print(f"   Validation on {len(X_val)} samples from {VAL_YEAR}")

    model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=callbacks,
        verbose=1
    )

    evaluate_model(model, X_val, y_val, scalers_y, TARGET_FEATURES)

    # Point forecast from the most recent validation window.
    buoy_lat, buoy_lon = find_nearest_ocean_point(ds, SPOT_LAT, SPOT_LON)
    last_seq = X_val[-1:]
    forecast_scaled = model.predict(last_seq, verbose=0)[0]
    forecast = np.zeros_like(forecast_scaled)
    for i in range(forecast_scaled.shape[-1]):
        forecast[..., i] = scalers_y[i].inverse_transform(
            forecast_scaled[..., i].reshape(-1, 1)
        ).reshape(forecast_scaled[..., i].shape)

    # Map the selected coordinates back to grid indices of the forecast frame.
    lats = ds.latitude.values
    lons = ds.longitude.values
    lat_idx = np.argmin(np.abs(lats - buoy_lat))
    lon_idx = np.argmin(np.abs(lons - buoy_lon))

    buoy_forecast = forecast[lat_idx, lon_idx, :]

    # The degree signs below were previously mis-encoded and printed as "Â°".
    print("\n" + "="*70)
    print(f"FORECAST at ({buoy_lat:.2f}°N, {buoy_lon:.2f}°E)")
    print("="*70)
    print(f"Swell Height:    {buoy_forecast[0]:.2f} m")
    print(f"Swell Period:    {buoy_forecast[1]:.2f} s")
    print(f"Swell Direction: {buoy_forecast[2]:.1f}°")
    print("="*70)

    print("\n Saving model and metadata...")
    model.save("convlstm_final.keras")

    # Everything needed to rebuild inputs and interpret outputs at inference time.
    metadata = {
        "scalers_X": scalers_X,
        "scalers_y": scalers_y,
        "input_features": INPUT_FEATURES,
        "target_features": TARGET_FEATURES,
        "lookback_hours": LOOKBACK_HOURS,
        "lookahead_hours": LOOKAHEAD_HOURS,
        "latitude": ds.latitude.values,
        "longitude": ds.longitude.values,
        "training_years": TRAIN_YEARS,
        "validation_year": VAL_YEAR
    }
    with open("model_metadata.pkl", "wb") as f:
        pickle.dump(metadata, f)

    print(" Saved: convlstm_final.keras + model_metadata.pkl")


# Entry point: run the whole pipeline only when this code executes as
# `__main__`, not when it is imported.
if __name__ == "__main__":
    main()

MULTI-YEAR CONVLSTM WAVE FORECASTING

Found preprocessed_multiyear.pkl


Load preprocessed data? (y/n):  y



 Loading preprocessed data from preprocessed_multiyear.pkl...
 Loaded preprocessed data
Loaded: Train=(5825, 19, 21, 21, 6), Val=(1445, 19, 21, 21, 6)
Loaded 2020: 1464 timesteps, shape=(1464, 21, 21)
Loaded 2021: 1460 timesteps, shape=(1460, 21, 21)
Loaded 2022: 1460 timesteps, shape=(1460, 21, 21)
Loaded 2023: 1460 timesteps, shape=(1460, 21, 21)
Loaded 2024: 1464 timesteps, shape=(1464, 21, 21)

 Concatenating datasets...

 Combined dataset:
   Total timesteps: 7308
   Time range: 2020-01-01T00:00:00.000000000 to 2024-12-31T18:00:00.000000000
   Spatial shape: 21 × 21

 Building enhanced model...
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 19, 21, 21, 6)]   0         
                                                                 
 conv_lstm2d (ConvLSTM2D)    (None, 19, 21, 21, 128)   617984    
                                              