In [None]:
# Multiyear ConvLSTM Wave Forecasting
# Jupyter-notebook friendly (cell-separated) refactor
# - Minimal, readable, modular
# - Keeps full functionality from original script
# - Designed to run in notebook cells (or as a linear script)

# %%
"""
CONFIG / IMPORTS
"""
import xarray as xr
import numpy as np
from pathlib import Path
import pickle
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, ConvLSTM2D, Conv3D, BatchNormalization, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# config
# -- data selection ------------------------------------------------------------
DATA_DIR = "../"                       # where surf_data_YYYY.nc live
DATA_YEARS = list(range(2020, 2025))   # 2020..2024 inclusive
TRAIN_YEARS = list(range(2020, 2024))  # 2020..2023 inclusive
VAL_YEAR = 2024                        # held-out validation year

# -- sequence windowing ----------------------------------------------------------
# create_sequences counts these in dataset timesteps; "hours" assumes hourly data (TODO confirm).
LOOKBACK_HOURS = 19
LOOKAHEAD_HOURS = 1

# -- training --------------------------------------------------------------------
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 0.001

# -- variables (order defines channel order) -----------------------------------------
INPUT_FEATURES = ["u10", "v10", "msl", "shts", "mpts", "mdts"]
TARGET_FEATURES = ["shts", "mpts", "mdts"]

# -- artifacts -------------------------------------------------------------------
PREPROCESSED_FILE = "preprocessed_multiyear.pkl"
MODEL_SAVE_PATH = "best_multiyear_model.keras"

# -- forecast spot (degrees) ----------------------------------------------------------
SPOT_LAT, SPOT_LON = 6.8399, 81.8396

# Toggle to load preprocessed data (useful in notebooks)
LOAD_PREPROCESSED = True

In [None]:
"""
UTILITY: Data loading and splitting
"""
def load_multiyear_dataset(years, data_dir=".", required_vars=None):
    """Open ``surf_data_<year>.nc`` for each year and concatenate them along ``valid_time``.

    Files that do not exist, or that lack any required variable, are reported and
    skipped. Each kept dataset gets ``attrs["year"]`` parsed from its file name.

    Args:
        years: Iterable of years; each maps to ``<data_dir>/surf_data_<year>.nc``.
        data_dir: Directory holding the yearly files.
        required_vars: Variables every file must contain. Defaults to
            ``INPUT_FEATURES + TARGET_FEATURES``.

    Returns:
        The combined dataset, sorted by ``valid_time``.

    Raises:
        FileNotFoundError: If no requested year yields a usable dataset.
    """
    if required_vars is None:
        required_vars = INPUT_FEATURES + TARGET_FEATURES
    required = set(required_vars)

    datasets = []
    for year in years:
        f = Path(data_dir) / f"surf_data_{year}.nc"
        if not f.exists():
            print(f"⚠️  {f} not found, skipping")
            continue
        ds = xr.open_dataset(f)
        missing = required - set(ds.data_vars)
        if missing:
            print(f"⚠️  {f.name}: missing {missing}, skipping")
            # This dataset is discarded, so release its file handle now instead of leaking it.
            ds.close()
            continue
        ds.attrs["year"] = int(f.name.split("_")[-1].split(".")[0])
        datasets.append(ds)
        print(f"Loaded {f.name}: {len(ds.valid_time)} timesteps")
    if not datasets:
        raise FileNotFoundError("No valid datasets found for requested years")

    combined = xr.concat(datasets, dim="valid_time").sortby("valid_time")
    return combined


def split_by_year(ds, train_years, val_year):
    """Partition ``ds`` along ``valid_time`` by calendar year.

    Returns:
        ``(train, val)``: timesteps whose year is in ``train_years``, and timesteps
        whose year equals ``val_year``.
    """
    step_year = ds.valid_time.dt.year
    train_part = ds.isel(valid_time=step_year.isin(train_years))
    val_part = ds.isel(valid_time=step_year == val_year)
    return train_part, val_part

In [None]:
"""
PREPROCESSING helpers
"""

def create_ocean_mask(ds, threshold=0.5, ref_var="shts"):
    """Flag grid cells that are NaN in less than ``threshold`` of the timesteps.

    The fraction is taken along axis 0 of ``ds[ref_var].values`` (presumably
    ``valid_time``). A summary line with the count and percentage is printed.

    Returns:
        Boolean array over the remaining (spatial) axes; True marks an ocean cell.
    """
    nan_fraction = np.isnan(ds[ref_var].values).mean(axis=0)
    ocean = nan_fraction < threshold
    n_ocean, n_cells = ocean.sum(), ocean.size
    print(f"Ocean points: {n_ocean}/{n_cells} ({100*n_ocean/n_cells:.1f}%)")
    return ocean


def _stack_feature_channels(ds, feats):
    """Stack ``feats`` into a fresh float32 array ordered (valid_time, latitude, longitude, channel)."""
    arr = (
        ds[list(feats)]
        .to_array(dim="channel")
        .transpose("valid_time", "latitude", "longitude", "channel")
        .values
    )
    # astype always copies, so the in-place edits below never touch dataset-backed memory.
    return arr.astype(np.float32)


def _fill_nans_with_ocean_mean(frames, ocean):
    """In place: per channel, replace NaNs with that channel's mean over ocean cells.

    ``frames`` is (time, lat, lon, channel) and ``ocean`` a (lat, lon) bool mask.
    ``np.nan_to_num`` also clips +/-inf to the largest finite float32 values.
    A channel with no non-NaN ocean values falls back to 0.0, so no NaN survives.
    """
    for c in range(frames.shape[-1]):
        channel = frames[..., c]
        ocean_vals = channel[:, ocean]  # (time, n_ocean)
        has_data = ocean_vals.size > 0 and not np.isnan(ocean_vals).all()
        fill = float(np.nanmean(ocean_vals, dtype=np.float64)) if has_data else 0.0
        frames[..., c] = np.nan_to_num(channel, nan=fill)


def engineer_frames(ds, ocean_mask, input_feats=INPUT_FEATURES, target_feats=TARGET_FEATURES):
    """Build model-ready input/target frames from ``ds``.

    Both outputs are float32 arrays shaped (valid_time, latitude, longitude, channel),
    with channels in the order of ``input_feats`` / ``target_feats``. In each channel,
    NaNs are replaced by the channel's mean over ocean cells, computed from ``ds``
    itself. Every land cell (``~ocean_mask``) is then set to 0 in all channels.

    Args:
        ds: Dataset with ``valid_time``, ``latitude`` and ``longitude`` dimensions that
            holds every variable in ``input_feats`` and ``target_feats``.
        ocean_mask: Boolean (latitude, longitude) array, True for ocean cells, as
            returned by ``create_ocean_mask``.
        input_feats: Variables stacked into ``X``.
        target_feats: Variables stacked into ``y``.

    Returns:
        ``(X, y)`` as float32 arrays.
    """
    ocean = np.asarray(ocean_mask, dtype=bool)
    X = _stack_feature_channels(ds, input_feats)
    y = _stack_feature_channels(ds, target_feats)
    for frames in (X, y):
        # Fill before zeroing land, using ocean-only means, so the land
        # placeholders do not drag the fill value toward 0.
        _fill_nans_with_ocean_mean(frames, ocean)
        # A 2-D (lat, lon) mask broadcasts over the channel axis. It therefore fits
        # both X and y even though they have different channel counts.
        frames[:, ~ocean] = 0.0
    return X, y


def fit_scalers(X_train, y_train):
    """Fit one ``StandardScaler`` per channel (last axis) of the inputs and of the targets.

    Each scaler is fitted on every value of its channel, pooled across all leading
    axes (time and space).

    Returns:
        ``(scalers_X, scalers_y)``: lists with one fitted scaler per channel, in channel order.
    """
    def _per_channel(arr):
        # StandardScaler.fit returns the scaler itself, so it can be collected directly.
        return [
            StandardScaler().fit(arr[..., ch].reshape(-1, 1))
            for ch in range(arr.shape[-1])
        ]

    return _per_channel(X_train), _per_channel(y_train)


def apply_scalers(X, y, scalers_X, scalers_y):
    """Standardise ``X`` and ``y`` channel-by-channel with previously fitted scalers.

    Channel ``i`` of an array is transformed by the ``i``-th scaler of its list. The
    outputs keep the input dtype. Any channel without a matching scaler is left as zeros.

    Returns:
        ``(Xs, ys)``: the transformed copies of ``X`` and ``y``.
    """
    def _scale(arr, scalers):
        out = np.zeros_like(arr)
        for ch, scaler in enumerate(scalers):
            plane = arr[..., ch]
            out[..., ch] = scaler.transform(plane.reshape(-1, 1)).reshape(plane.shape)
        return out

    return _scale(X, scalers_X), _scale(y, scalers_y)


def create_sequences(X, y, lookback, lookahead):
    """Slice time-aligned frame stacks into supervised (input window, target frame) pairs.

    Sample ``i`` uses inputs ``X[i : i + lookback]`` and target
    ``y[i + lookback + lookahead - 1]``, i.e. the frame ``lookahead`` steps after
    the last input frame.

    Args:
        X: Input frames, time along axis 0.
        y: Target frames, time along axis 0. Must have at least ``len(X)`` timesteps.
        lookback: Consecutive input frames per sample (>= 1).
        lookahead: Steps from the last input frame to the target (>= 0).

    Returns:
        ``(X_seq, y_seq)`` as float32 arrays shaped ``(n, lookback, *X.shape[1:])``
        and ``(n, *y.shape[1:])``, with ``n = max(len(X) - lookback - lookahead + 1, 0)``.
        When ``X`` is too short to form a single sample, both arrays are empty
        rather than raising a cryptic negative-dimension error.

    Raises:
        ValueError: If ``lookback`` < 1 (empty input window) or ``lookahead`` < 0
            (the window would overrun the end of ``X``).
    """
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback}")
    if lookahead < 0:
        raise ValueError(f"lookahead must be >= 0, got {lookahead}")

    X = np.asarray(X)
    y = np.asarray(y)
    n = max(len(X) - lookback - lookahead + 1, 0)
    starts = np.arange(n)
    # (n, lookback) matrix of frame indices; advanced indexing gathers every window in one copy.
    window_idx = starts[:, None] + np.arange(lookback)[None, :]
    X_seq = X[window_idx].astype(np.float32, copy=False)
    y_seq = y[starts + (lookback + lookahead - 1)].astype(np.float32, copy=False)
    return X_seq, y_seq

In [None]:
"""
SAVE / LOAD preprocessed
"""
def save_preprocessed(path, **kwargs):
    """Pickle all keyword arguments as a single dict to ``path`` (protocol 4).

    The file can be read back with ``load_preprocessed``. A confirmation line is printed.
    """
    payload = dict(kwargs)
    with Path(path).open("wb") as fh:
        pickle.dump(payload, fh, protocol=4)
    print(f"Saved preprocessed -> {path}")


def load_preprocessed(path):
    """Unpickle and return the object stored at ``path``, e.g. a dict from ``save_preprocessed``.

    Security: unpickling can execute arbitrary code. Only load cache files that
    this pipeline wrote itself.
    """
    return pickle.loads(Path(path).read_bytes())

In [None]:
# %%
"""
MODEL definition
"""
def build_enhanced_model(input_shape, output_channels, lr=LEARNING_RATE):
    """Build and compile a stacked ConvLSTM2D sequence-to-frame regressor.

    The model has three ConvLSTM2D stages (128 -> 64 -> 32 filters, 3x3 kernels,
    "same" padding, full sequences returned), each followed by BatchNormalization.
    The first two stages are L2-regularised (1e-5) and followed by Dropout(0.2).
    A linear 3x3x3 Conv3D maps features to ``output_channels`` at every timestep,
    and only the final timestep is returned.

    Args:
        input_shape: Per-sample input shape given to ``Input``, presumably
            ``(lookback, lat, lon, channels)`` as built by ``create_sequences`` (TODO confirm).
        output_channels: Number of predicted channels.
        lr: Adam learning rate.

    Returns:
        A compiled ``Model`` with MSE loss and an MAE metric.
    """
    # (filters, L2 weight or None, apply dropout after BatchNormalization)
    stages = (
        (128, 1e-5, True),
        (64, 1e-5, True),
        (32, None, False),
    )
    inputs = Input(shape=input_shape)
    h = inputs
    for filters, l2_weight, with_dropout in stages:
        reg = tf.keras.regularizers.l2(l2_weight) if l2_weight is not None else None
        h = ConvLSTM2D(filters, (3, 3), padding="same", return_sequences=True, kernel_regularizer=reg)(h)
        h = BatchNormalization()(h)
        if with_dropout:
            h = Dropout(0.2)(h)
    per_step = Conv3D(output_channels, kernel_size=(3, 3, 3), padding="same", activation="linear")(h)
    last_step = per_step[:, -1, :, :, :]
    model = Model(inputs, last_step)
    model.compile(optimizer=Adam(learning_rate=lr), loss="mse", metrics=["mae"])
    return model

In [None]:
"""
EVALUATION helpers
"""
def _inverse_scale_channels(arr, scalers):
    """Undo per-channel standardisation; channel ``i`` is mapped back with ``scalers[i]``.

    Channels without a matching scaler stay 0 (the output starts as ``np.zeros_like``).
    """
    out = np.zeros_like(arr)
    for i, s in enumerate(scalers):
        out[..., i] = s.inverse_transform(arr[..., i].reshape(-1, 1)).reshape(arr[..., i].shape)
    return out


def evaluate_model(model, X_val, y_val, scalers_y, target_names=TARGET_FEATURES,
                   ocean_mask=None, batch_size=32):
    """Predict on the validation set, un-scale, print error metrics, and return both arrays.

    Args:
        model: Trained Keras model. Its prediction is assumed to have the same
            shape as ``y_val`` (TODO confirm for custom models).
        X_val: Scaled validation inputs.
        y_val: Scaled validation targets, channel on the last axis.
        scalers_y: One fitted scaler per target channel.
        target_names: Channel names used in the per-variable report.
        ocean_mask: Optional boolean (lat, lon) mask. When given, metrics are
            computed over ocean cells only. Land cells were zeroed during
            preprocessing and would otherwise inflate R2 and deflate MAE.
            ``None`` keeps the previous behaviour (all cells).
        batch_size: Batch size passed to ``model.predict``.

    Returns:
        ``(y_true, y_pred)`` in physical units over the full grid, whether or not
        ``ocean_mask`` is given.
    """
    yp_scaled = model.predict(X_val, batch_size=batch_size, verbose=1)
    y_true = _inverse_scale_channels(y_val, scalers_y)
    y_pred = _inverse_scale_channels(yp_scaled, scalers_y)

    if ocean_mask is None:
        true_eval, pred_eval = y_true, y_pred
    else:
        sea = np.asarray(ocean_mask, dtype=bool)
        # (samples, lat, lon, ch) -> (samples, n_ocean, ch)
        true_eval, pred_eval = y_true[:, sea], y_pred[:, sea]

    # The "overall" figures pool channels with different units; treat them as a rough summary only.
    mae = mean_absolute_error(true_eval.flatten(), pred_eval.flatten())
    rmse = np.sqrt(mean_squared_error(true_eval.flatten(), pred_eval.flatten()))
    print(f"Overall MAE={mae:.4f}, RMSE={rmse:.4f}")

    for i, name in enumerate(target_names):
        t = true_eval[..., i].flatten()
        p = pred_eval[..., i].flatten()
        mae_v = mean_absolute_error(t, p)
        rmse_v = np.sqrt(mean_squared_error(t, p))
        r2_v = r2_score(t, p)
        print(f"{name:>6s}: MAE={mae_v:.3f}, RMSE={rmse_v:.3f}, R2={r2_v:.3f}")
    return y_true, y_pred

In [None]:
# %%
"""
NEAREST OCEAN POINT
"""
def find_nearest_ocean_point(ds, lat, lon, ref_var="shts"):
    """Return ``(latitude, longitude)`` of the grid cell closest to the requested point that has data.

    "Has data" means ``ds[ref_var]`` is not NaN at the first ``valid_time``.
    If the grid cell nearest to ``(lat, lon)`` qualifies, its coordinates are
    returned directly. Otherwise every qualifying cell is ranked by great-circle
    (haversine) distance on a sphere of radius 6371, and the closest one is returned.

    Args:
        ds: Dataset with ``latitude``/``longitude`` coordinates in degrees and a
            ``valid_time`` dimension.
        lat, lon: Target point in degrees.
        ref_var: Variable whose NaNs mark cells without data.

    Returns:
        Tuple of floats ``(latitude, longitude)``.
    """
    snapped = ds.sel(latitude=lat, longitude=lon, method="nearest")
    if not np.isnan(snapped[ref_var].isel(valid_time=0)):
        return float(snapped.latitude), float(snapped.longitude)

    # Fall back to a brute-force search over all cells that have data at the first timestep.
    wet = ds[ref_var].isel(valid_time=0).stack(point=("latitude", "longitude")).dropna("point")
    phi0, lam0 = np.radians(lat), np.radians(lon)
    phi = np.radians(wet.latitude)
    lam = np.radians(wet.longitude)
    half_dphi = (phi - phi0) / 2
    half_dlam = (lam - lam0) / 2
    hav = np.sin(half_dphi) ** 2 + np.cos(phi0) * np.cos(phi) * np.sin(half_dlam) ** 2
    earth_radius = 6371
    dist = 2 * earth_radius * np.arcsin(np.sqrt(hav))
    best = wet.isel(point=dist.argmin())
    return float(best.latitude), float(best.longitude)


In [None]:
"""
MAIN pipeline (call cells step-by-step in notebook)

Recommended notebook workflow:
 1) Run data loading cell
 2) Run preprocessing cell
 3) Run model build + train cell
 4) Run evaluation + forecast cell

This keeps each step inspectable in the notebook.
"""

# Cell: LOAD data
if __name__ == "__main__":
    # This block is safe to run as a linear script. In a notebook, run the cells individually.

    # Seed Python, NumPy and TensorFlow RNGs so weight init and shuffling are repeatable.
    tf.keras.utils.set_random_seed(42)

    # Every setting that shapes the preprocessed arrays. It is stored inside the cache,
    # so a cache built with different settings is rebuilt instead of silently reused.
    cache_config = {
        "data_dir": str(DATA_DIR),
        "data_years": list(DATA_YEARS),
        "train_years": list(TRAIN_YEARS),
        "val_year": VAL_YEAR,
        "lookback_hours": LOOKBACK_HOURS,
        "lookahead_hours": LOOKAHEAD_HOURS,
        "input_features": list(INPUT_FEATURES),
        "target_features": list(TARGET_FEATURES),
    }

    # -- load or preprocess
    data = None
    if LOAD_PREPROCESSED and Path(PREPROCESSED_FILE).exists():
        # NOTE: unpickling can execute arbitrary code; only reuse caches this script wrote.
        data = load_preprocessed(PREPROCESSED_FILE)
        if data.get("config") != cache_config:
            # Includes caches written before the config stamp existed.
            print(f"⚠️  {PREPROCESSED_FILE} was built with different settings, rebuilding")
            data = None

    # The raw grid is needed on both paths: the forecast example reads its coordinates.
    ds = load_multiyear_dataset(DATA_YEARS, DATA_DIR)

    if data is not None:
        X_train, y_train = data["X_train"], data["y_train"]
        X_val, y_val = data["X_val"], data["y_val"]
        scalers_X, scalers_y = data["scalers_X"], data["scalers_y"]
        print("Loaded preprocessed data")
    else:
        ds_train, ds_val = split_by_year(ds, TRAIN_YEARS, VAL_YEAR)
        ocean_mask = create_ocean_mask(ds)
        X_train, y_train = engineer_frames(ds_train, ocean_mask)
        X_val, y_val = engineer_frames(ds_val, ocean_mask)
        # Scalers are fitted on the training years only, so validation statistics do not leak.
        scalers_X, scalers_y = fit_scalers(X_train, y_train)
        X_train, y_train = apply_scalers(X_train, y_train, scalers_X, scalers_y)
        X_val, y_val = apply_scalers(X_val, y_val, scalers_X, scalers_y)
        X_train, y_train = create_sequences(X_train, y_train, LOOKBACK_HOURS, LOOKAHEAD_HOURS)
        X_val, y_val = create_sequences(X_val, y_val, LOOKBACK_HOURS, LOOKAHEAD_HOURS)
        save_preprocessed(
            PREPROCESSED_FILE,
            X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val,
            scalers_X=scalers_X, scalers_y=scalers_y, config=cache_config,
        )

    print(f"Train shapes: {X_train.shape}, {y_train.shape} | Val shapes: {X_val.shape}, {y_val.shape}")
    if len(X_train) == 0 or len(X_val) == 0:
        raise ValueError(
            "Not enough timesteps to build train/validation sequences; "
            "check the year split and LOOKBACK_HOURS/LOOKAHEAD_HOURS"
        )

    # Build
    model = build_enhanced_model(input_shape=X_train.shape[1:], output_channels=y_train.shape[-1])
    model.summary()

    # Callbacks
    callbacks = [
        EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True, verbose=1),
        ReduceLROnPlateau(monitor="val_loss", factor=0.1, patience=2, min_lr=1e-7, verbose=1)
    ]

    # Train
    history = model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=EPOCHS, batch_size=BATCH_SIZE, callbacks=callbacks, verbose=1)
    model.save(MODEL_SAVE_PATH)

    # Evaluate
    y_true, y_pred = evaluate_model(model, X_val, y_val, scalers_y)

    # Forecast example. X_val[-1:] is the last validation window, so its target is
    # the final validation frame: this re-predicts known data (a hindcast), not the
    # hour after the dataset ends.
    buoy_lat, buoy_lon = find_nearest_ocean_point(ds, SPOT_LAT, SPOT_LON)
    last_seq = X_val[-1:]
    fc_scaled = model.predict(last_seq, verbose=0)[0]
    fc = np.zeros_like(fc_scaled)
    for i, s in enumerate(scalers_y):
        fc[..., i] = s.inverse_transform(fc_scaled[..., i].reshape(-1, 1)).reshape(fc_scaled[..., i].shape)
    lats = ds.latitude.values
    lons = ds.longitude.values
    lat_idx = np.argmin(np.abs(lats - buoy_lat))
    lon_idx = np.argmin(np.abs(lons - buoy_lon))
    bf = fc[lat_idx, lon_idx, :]
    print("Forecast (example):")
    print(f"  Swell Height: {bf[0]:.2f} m | Period: {bf[1]:.2f} s | Direction: {bf[2]:.1f}°")

# End of notebook-like script
