In [1]:
import xarray as xr
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import ConvLSTM2D, Flatten, Dense, BatchNormalization, Dropout, PReLU
from tensorflow.keras.optimizers import AdamW
import tensorflow_addons as tfa


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
# CONFIGURATION
# --- Data source and surf spot ---------------------------------------------
DATA_PATH = "surf_data_2020.nc"
SPOT_LAT = 6.8399   # Arugam Bay (°N)
SPOT_LON = 81.8396  # Arugam Bay (°E)

# --- Sequence windows ------------------------------------------------------
# create_sequences uses these as counts of dataset time steps; they equal
# hours only if the data is hourly — NOTE(review): confirm the time spacing.
LOOKBACK_HOURS = 16
LOOKAHEAD_HOURS = 1

# --- Training --------------------------------------------------------------
VALIDATION_SPLIT = 0.2  # passed to model.fit(validation_split=...)
EPOCHS = 20
BATCH_SIZE = 32

# --- Features and targets --------------------------------------------------
# Gridded input channels, stacked in this order along the last axis of X.
INPUT_FEATURES = ["u10", "v10", "msl", "tp", "shts", "mpts", "mdts"]
# Column order of the target array built by engineer_features.
TARGET_VARS = ["shts", "mpts", "mdts_sin", "mdts_cos", "wind_speed"]

In [3]:
# DATA LOADING & BUOY SELECTION
def find_nearest_ocean_point(ds, lat, lon, var="shts", time_dim="valid_time"):
    """Return the (lat, lon) of the closest grid cell where `var` is valid.

    First tries the cell picked by xarray's nearest-neighbour lookup. If `var`
    is NaN there at the first time step (e.g. a land cell for wave variables),
    every cell with a non-NaN value at that time step is considered and the one
    with the smallest Haversine (great-circle) distance to (lat, lon) is chosen.

    Args:
        ds: xarray Dataset with `latitude`/`longitude` coordinates.
        lat, lon: target location in degrees.
        var: variable whose NaN-ness marks invalid (land) cells.
        time_dim: time dimension; only index 0 is inspected.

    Returns:
        (latitude, longitude) as Python floats taken from the dataset's grid.

    Raises:
        ValueError: if no grid cell has a valid `var` value at time index 0.
    """
    snapshot = ds[var].isel({time_dim: 0})

    nearest = snapshot.sel(latitude=lat, longitude=lon, method="nearest")
    if not bool(np.isnan(nearest)):
        print("Found valid offshore point")
        return float(nearest.latitude), float(nearest.longitude)

    print("Nearest point on land. Searching for closest ocean point...")
    ocean = snapshot.stack(point=("latitude", "longitude")).dropna("point")
    if ocean.sizes["point"] == 0:
        raise ValueError(
            f"No grid point has a valid '{var}' value at {time_dim} index 0"
        )

    # Haversine distance from the target to every valid cell.
    earth_radius_km = 6371.0
    lat1, lon1 = np.radians(lat), np.radians(lon)
    lat2 = np.radians(ocean["latitude"].values)
    lon2 = np.radians(ocean["longitude"].values)
    a = (
        np.sin((lat2 - lat1) / 2) ** 2
        + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2
    )
    # Clip guards arcsin against round-off pushing `a` slightly above 1.
    distance_km = 2 * earth_radius_km * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))

    closest = ocean.isel(point=int(np.argmin(distance_km)))
    return float(closest.latitude), float(closest.longitude)

In [4]:
# FEATURE ENGINEERING
def engineer_features(ds, buoy_lat, buoy_lon, features=None):
    """Build the gridded model input and the point-target array.

    Args:
        ds: xarray Dataset containing the input variables plus `u10`, `v10`,
            `shts`, `mpts`, `mdts`.
        buoy_lat, buoy_lon: target location; the nearest grid cell is used, so
            coordinates that differ from the grid by float round-off still work.
        features: input variable names to stack as channels
            (default: INPUT_FEATURES).

    Returns:
        X: numpy array ordered (valid_time, latitude, longitude, channel).
        y: numpy array (valid_time, 5) with columns
           [shts, mpts, sin(mdts), cos(mdts), wind_speed].
    """
    if features is None:
        features = INPUT_FEATURES

    # Spatial features over the full grid.
    X = (
        ds[list(features)]
        .to_array(dim="channel")
        .transpose("valid_time", "latitude", "longitude", "channel")
        .values
    )

    # Point series at the virtual buoy. "nearest" avoids KeyError from exact
    # float matching while returning the same cell for exact grid coordinates.
    buoy_data = ds.sel(latitude=buoy_lat, longitude=buoy_lon, method="nearest")

    # Wind speed magnitude from the u/v components.
    wind_speed = np.sqrt(buoy_data["u10"] ** 2 + buoy_data["v10"] ** 2).values

    # Encode direction (degrees) as sin/cos so 359° and 1° are close.
    mdts_rad = np.deg2rad(buoy_data["mdts"].values)

    y = np.column_stack([
        buoy_data["shts"].values,
        buoy_data["mpts"].values,
        np.sin(mdts_rad),
        np.cos(mdts_rad),
        wind_speed,
    ])

    return X, y

In [5]:
# PREPROCESSING
def preprocess_data(X, y, fit_end=None):
    """Clean and standardize the gridded inputs and the point targets.

    Each channel of X (last axis) gets its own StandardScaler; y is scaled
    column-wise by one StandardScaler. Non-finite values (NaN, ±Inf) are
    treated as missing: they are excluded when fitting the mean/std and are
    filled with 0 *after* scaling, i.e. with the channel/column mean. Filling
    with 0 *before* fitting would inject a fake raw value of 0 (e.g. for land
    cells of wave variables) into the statistics.

    Args:
        X: array whose axis 0 is time and last axis indexes channels.
        y: 2-D array (time, target).
        fit_end: if given, only the first `fit_end` time steps are used to fit
            the scalers (pass the training length to keep validation data out
            of the normalization). None fits on all time steps.

    Returns:
        X_scaled: float array with X's shape.
        y_scaled: float array with y's shape.
        scalers_X: list of fitted StandardScaler, one per channel.
        scaler_y: fitted StandardScaler for y.
    """
    fit_rows = slice(None) if fit_end is None else slice(0, fit_end)

    # StandardScaler ignores NaN during fit but rejects Inf, so map Inf -> NaN.
    X = np.where(np.isfinite(X), X, np.nan)
    y = np.where(np.isfinite(y), y, np.nan)

    # Always produce a float result (zeros_like on int input would truncate).
    X_scaled = np.empty(X.shape, dtype=np.result_type(X.dtype, np.float32))
    scalers_X = []

    for c in range(X.shape[-1]):
        channel = X[..., c]
        scaler = StandardScaler().fit(channel[fit_rows].reshape(-1, 1))
        X_scaled[..., c] = scaler.transform(channel.reshape(-1, 1)).reshape(channel.shape)
        scalers_X.append(scaler)

    X_scaled = np.nan_to_num(X_scaled, nan=0.0)

    scaler_y = StandardScaler().fit(y[fit_rows])
    y_scaled = np.nan_to_num(scaler_y.transform(y), nan=0.0)

    return X_scaled, y_scaled, scalers_X, scaler_y

def create_sequences(X, y, lookback, lookahead):
    """Slice aligned (input window, target) pairs out of time-ordered arrays.

    Window k covers X[k : k + lookback]. Its target is the y row `lookahead`
    steps after the window's last row, i.e. y[k + lookback + lookahead - 1].
    When the series is too short for a single window, both results are empty.
    """
    window_count = len(X) - lookback - lookahead + 1
    starts = range(window_count)
    offset = lookback + lookahead - 1

    windows = [X[start:start + lookback] for start in starts]
    targets = [y[start + offset] for start in starts]

    return np.array(windows), np.array(targets)

In [6]:
# MODEL BUILDING
def build_model(input_shape, output_dim):
    """Build and compile a ConvLSTM2D model for spatiotemporal forecasting.

    Architecture: ConvLSTM2D(32, 3x3, last state only) -> BatchNorm -> Flatten
    -> Dense(64) + PReLU -> Dropout(0.2) -> linear Dense(output_dim).

    Args:
        input_shape: per-sample input shape for ConvLSTM2D; in this notebook it
            is (lookback, latitude, longitude, channel).
        output_dim: number of regression targets.

    Returns:
        A compiled Keras Sequential model (MSE loss, MAE metric).
    """
    model = Sequential([
        ConvLSTM2D(
            filters=32,
            kernel_size=(3, 3),
            padding='same',
            return_sequences=False,
            input_shape=input_shape
        ),
        BatchNormalization(),
        Flatten(),
        Dense(64),
        PReLU(),
        Dropout(0.2),
        Dense(output_dim, activation='linear')
    ])

    learning_rate = 1e-3
    # Use Keras' built-in AdamW instead of the end-of-life tensorflow_addons
    # one. Keras multiplies the decoupled decay by the learning rate
    # (w -= lr * wd * w), while tfa applied it unscaled (w -= wd * w), so the
    # coefficient is divided by lr to keep the previous per-step decay of 1e-5.
    per_step_decay = 1e-5
    optimizer = AdamW(
        learning_rate=learning_rate,
        weight_decay=per_step_decay / learning_rate,
    )

    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=['mae']
    )

    return model

In [7]:
# EVALUATION
def evaluate_model(model, X_val, y_val, scaler_y, target_names,
                   direction_pair=("mdts_sin", "mdts_cos")):
    """Print validation metrics in physical units and return unscaled arrays.

    Predictions and targets are inverse-transformed with `scaler_y` before
    scoring. Overall MAE/RMSE/R² are followed by per-target values. If both
    names in `direction_pair` appear in `target_names`, a circular direction
    error (degrees, wrapped to ±180°) is also reported. Scoring sin and cos
    separately does not show how far off the predicted angle is.

    Args:
        model: fitted model exposing `predict(X, verbose=0)`.
        X_val, y_val: validation inputs and *scaled* targets.
        scaler_y: fitted target scaler with `inverse_transform`.
        target_names: column names of y, in order.
        direction_pair: (sin_name, cos_name) used for the angular error,
            or None to skip it.

    Returns:
        (y_true, y_pred) in physical units.

    Raises:
        ValueError: if the validation set is empty.
    """
    if len(y_val) == 0:
        raise ValueError("evaluate_model received an empty validation set")

    y_pred_scaled = model.predict(X_val, verbose=0)

    # Undo target standardisation so metrics are in physical units.
    y_true = scaler_y.inverse_transform(y_val)
    y_pred = scaler_y.inverse_transform(y_pred_scaled)

    # Overall metrics average across targets with different units (m, s,
    # unitless, m/s); treat them as a rough summary only.
    mae_overall = mean_absolute_error(y_true, y_pred)
    rmse_overall = np.sqrt(mean_squared_error(y_true, y_pred))
    r2_overall = r2_score(y_true, y_pred)

    print("\n" + "="*60)
    print("VALIDATION METRICS")
    print("="*60)
    print(f"Overall | MAE={mae_overall:.4f}, RMSE={rmse_overall:.4f}, R²={r2_overall:.4f}")
    print("-"*60)

    names = list(target_names)
    for i, var in enumerate(names):
        mae_i = mean_absolute_error(y_true[:, i], y_pred[:, i])
        rmse_i = np.sqrt(mean_squared_error(y_true[:, i], y_pred[:, i]))
        r2_i = r2_score(y_true[:, i], y_pred[:, i])
        print(f"{var:>12s} | MAE={mae_i:.3f}, RMSE={rmse_i:.3f}, R²={r2_i:.3f}")

    if direction_pair is not None and all(n in names for n in direction_pair):
        i_sin = names.index(direction_pair[0])
        i_cos = names.index(direction_pair[1])
        dir_true = np.rad2deg(np.arctan2(y_true[:, i_sin], y_true[:, i_cos]))
        dir_pred = np.rad2deg(np.arctan2(y_pred[:, i_sin], y_pred[:, i_cos]))
        # Wrap the difference into [-180, 180) so 359° vs 1° counts as 2°.
        ang_err = np.abs((dir_pred - dir_true + 180.0) % 360.0 - 180.0)
        ang_rmse = np.sqrt(np.mean(ang_err ** 2))
        label = "direction"
        print(f"{label:>12s} | MAE={ang_err.mean():.3f}°, RMSE={ang_rmse:.3f}°")

    print("="*60)
    return y_true, y_pred


In [8]:
# MAIN PIPELINE
def main():
    """Run the pipeline: load, engineer features, train, evaluate, forecast.

    Returns:
        (model, history, scaler_y): the trained Keras model, the History object
        from `model.fit`, and the target scaler for inverse-transforming
        predictions.
    """
    print("Loading dataset...")
    # The context manager releases the NetCDF file handle. Everything used
    # afterwards (floats, numpy arrays) is materialised inside the block.
    with xr.open_dataset(DATA_PATH) as ds:
        # Find valid buoy location
        buoy_lat, buoy_lon = find_nearest_ocean_point(ds, SPOT_LAT, SPOT_LON)
        print(f"Virtual buoy: ({buoy_lat:.2f}°N, {buoy_lon:.2f}°E)")

        # Feature engineering
        print("\nEngineering features...")
        X, y = engineer_features(ds, buoy_lat, buoy_lon)

        # Spacing of the time axis in hours, used to label the forecast horizon.
        times = ds["valid_time"].values
        step_hours = None
        if times.size > 1 and np.issubdtype(times.dtype, np.datetime64):
            step_hours = float((times[1] - times[0]) / np.timedelta64(1, "h"))

    print(f"Spatial input shape: {X.shape}")
    print(f"Target shape: {y.shape}")

    # Preprocessing
    print("Preprocessing data...")
    X_scaled, y_scaled, _scalers_X, scaler_y = preprocess_data(X, y)

    # Create sequences
    print(f"Creating sequences (lookback={LOOKBACK_HOURS}, lookahead={LOOKAHEAD_HOURS})...")
    X_seq, y_seq = create_sequences(X_scaled, y_scaled, LOOKBACK_HOURS, LOOKAHEAD_HOURS)
    print(f"Sequence shape: X={X_seq.shape}, y={y_seq.shape}")

    # Build model
    print("\nBuilding model...")
    model = build_model(X_seq.shape[1:], y_seq.shape[1])
    model.summary()

    # Train
    print(f"\nTraining for {EPOCHS} epochs...")
    history = model.fit(
        X_seq, y_seq,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_split=VALIDATION_SPLIT,
        verbose=1
    )

    # Evaluate on exactly the samples Keras held out: validation_split keeps
    # the samples after floor(n * (1 - split)). The old X[-int(split*n):] was
    # off by one here and became X[-0:] (everything) when it rounded to 0.
    split_at = int(len(X_seq) * (1.0 - VALIDATION_SPLIT))
    evaluate_model(
        model,
        X_seq[split_at:],
        y_seq[split_at:],
        scaler_y,
        TARGET_VARS
    )

    # Forecast LOOKAHEAD_HOURS time steps past the last window. create_sequences
    # counts these in time steps, so convert to hours via the data's spacing
    # instead of hard-coding a horizon.
    if step_hours is not None:
        horizon = f"Next {LOOKAHEAD_HOURS * step_hours:g} Hours"
    else:
        horizon = f"{LOOKAHEAD_HOURS} step(s) ahead"
    print("\n" + "="*60)
    print(f"FORECAST ({horizon})")
    print("="*60)
    last_seq = np.expand_dims(X_scaled[-LOOKBACK_HOURS:], axis=0)
    forecast_scaled = model.predict(last_seq, verbose=0)
    forecast = scaler_y.inverse_transform(forecast_scaled)[0]

    # Reconstruct direction from sin/cos (arctan2 is insensitive to their scale)
    direction_deg = np.rad2deg(np.arctan2(forecast[2], forecast[3])) % 360

    print(f"Swell Height:    {forecast[0]:.2f} m")
    print(f"Swell Period:    {forecast[1]:.2f} s")
    print(f"Swell Direction: {direction_deg:.1f}°")
    print(f"Wind Speed:      {forecast[4]:.2f} m/s ({forecast[4]*3.6:.1f} km/h)")
    print("="*60)

    return model, history, scaler_y

if __name__ == "__main__":
    model, history, scaler = main()

Loading dataset...
Nearest point on land. Searching for closest ocean point...
Virtual buoy: (7.00°N, 82.00°E)

Engineering features...
Spatial input shape: (1464, 21, 21, 7)
Target shape: (1464, 5)
Preprocessing data...
Creating sequences (lookback=16, lookahead=1)...
Sequence shape: X=(1448, 16, 21, 21, 7), y=(1448, 5)

Building model...
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv_lstm2d (ConvLSTM2D)    (None, 21, 21, 32)        45056     
                                                                 
 batch_normalization (Batch  (None, 21, 21, 32)        128       
 Normalization)                                                  
                                                                 
 flatten (Flatten)           (None, 14112)             0         
                                                                 
 dense (Dense)               (None, 64)     