In [1]:
import xarray as xr
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import (
    ConvLSTM2D, Conv3D, BatchNormalization, Input, 
    TimeDistributed, Conv2D, Dropout
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
import tensorflow as tf

# CONFIGURATION
DATA_PATH = "../surf_data_2024.nc"
SPOT_LAT, SPOT_LON = 6.8399, 81.8396
LOOKBACK_HOURS = 19  # Match reference paper's temporal window
LOOKAHEAD_HOURS = 1   # Predict next frame
VALIDATION_SPLIT = 0.2
EPOCHS = 10
BATCH_SIZE = 8  # Smaller due to memory constraints
SPATIAL_SIZE = 21  # Your data size (can't increase without more data)

INPUT_FEATURES = ["u10", "v10", "msl", "shts", "mpts", "mdts"]
TARGET_FEATURES = ["shts", "mpts", "mdts"]  # Predict wave fields

In [2]:
# DATA PREPROCESSING (From previous artifact)
# ======================================================================
def create_ocean_mask(ds, threshold=0.5):
    """Create mask to exclude land points."""
    shts_data = ds["shts"].values
    nan_ratio = np.isnan(shts_data).sum(axis=0) / shts_data.shape[0]
    ocean_mask = nan_ratio < threshold
    
    valid_points = ocean_mask.sum()
    print(f"üåä Ocean Mask: {valid_points} / {ocean_mask.size} points ({100*valid_points/ocean_mask.size:.1f}%)")
    return ocean_mask


def find_nearest_ocean_point(ds, lat, lon):
    """Find nearest valid ocean point."""
    nearest = ds.sel(latitude=lat, longitude=lon, method="nearest")
    
    if not np.isnan(nearest["shts"].isel(valid_time=0)):
        return float(nearest.latitude), float(nearest.longitude)
    
    shts_data = ds["shts"].isel(valid_time=0).stack(
        point=("latitude", "longitude")
    ).dropna("point")
    
    R = 6371
    lat1, lon1 = np.radians(lat), np.radians(lon)
    lat2 = np.radians(shts_data.latitude)
    lon2 = np.radians(shts_data.longitude)
    
    dlat = (lat2 - lat1) / 2
    dlon = (lon2 - lon1) / 2
    
    a = np.sin(dlat)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon)**2
    distance = 2 * R * np.arcsin(np.sqrt(a))
    
    closest = shts_data.isel(point=distance.argmin())
    return float(closest.latitude), float(closest.longitude)

In [3]:
# FRAME-BASED FEATURE ENGINEERING
# ======================================================================
def engineer_frames(ds, ocean_mask):
    """
    Prepare spatiotemporal frames for prediction.
    Unlike point prediction, this preserves spatial structure.
    """
    # Input frames: All 6 features across space
    X = ds[INPUT_FEATURES].to_array(dim="channel").transpose(
        "valid_time", "latitude", "longitude", "channel"
    ).values
    
    # Output frames: Only wave parameters (3 features)
    y = ds[TARGET_FEATURES].to_array(dim="channel").transpose(
        "valid_time", "latitude", "longitude", "channel"
    ).values
    
    # Apply ocean mask to both
    mask_3d_X = np.broadcast_to(ocean_mask[..., np.newaxis], X.shape[1:])
    mask_3d_y = np.broadcast_to(ocean_mask[..., np.newaxis], y.shape[1:])
    
    X[:, ~mask_3d_X] = 0.0
    y[:, ~mask_3d_y] = 0.0
    
    # Fill remaining NaNs with channel means
    for i in range(X.shape[-1]):
        channel_mean = np.nanmean(X[..., i])
        X[..., i] = np.nan_to_num(X[..., i], nan=channel_mean)
    
    for i in range(y.shape[-1]):
        channel_mean = np.nanmean(y[..., i])
        y[..., i] = np.nan_to_num(y[..., i], nan=channel_mean)
    
    print(f"üìä Frames prepared:")
    print(f"   Input: {X.shape} (time, lat, lon, channels)")
    print(f"   Output: {y.shape}")
    
    return X, y


def normalize_frames(X, y):
    """Normalize each channel independently."""
    X_scaled = np.zeros_like(X)
    y_scaled = np.zeros_like(y)
    
    scalers_X = []
    scalers_y = []
    
    # Normalize input channels
    for i in range(X.shape[-1]):
        scaler = StandardScaler()
        channel = X[..., i].reshape(-1, 1)
        X_scaled[..., i] = scaler.fit_transform(channel).reshape(X[..., i].shape)
        scalers_X.append(scaler)
    
    # Normalize output channels
    for i in range(y.shape[-1]):
        scaler = StandardScaler()
        channel = y[..., i].reshape(-1, 1)
        y_scaled[..., i] = scaler.fit_transform(channel).reshape(y[..., i].shape)
        scalers_y.append(scaler)
    
    return X_scaled, y_scaled, scalers_X, scalers_y


def create_frame_sequences(X, y, lookback, lookahead):
    """
    Create sequences where:
    - Input: [lookback] frames of multi-channel data
    - Output: [lookahead] frame(s) of wave parameters
    """
    X_seq, y_seq = [], []
    
    for i in range(len(X) - lookback - lookahead + 1):
        X_seq.append(X[i:i + lookback])
        y_seq.append(y[i + lookback + lookahead - 1])  # Single frame prediction
    
    return np.array(X_seq), np.array(y_seq)

In [4]:
# IMPROVED ARCHITECTURES
# ======================================================================

def build_model_v3_encoder_decoder(input_shape, output_channels):
    """
    Architecture V3: Encoder-Decoder ConvLSTM
    Based on spatiotemporal forecasting best practices.
    
    Similar to the reference paper but adapted for wave forecasting.
    """
    inputs = Input(shape=input_shape)
    
    # Encoder: Extract spatiotemporal features
    x = ConvLSTM2D(
        filters=64,
        kernel_size=(3, 3),
        padding='same',
        return_sequences=True,
        name='encoder_1'
    )(inputs)
    x = BatchNormalization()(x)
    
    x = ConvLSTM2D(
        filters=64,
        kernel_size=(3, 3),
        padding='same',
        return_sequences=True,
        name='encoder_2'
    )(x)
    x = BatchNormalization()(x)
    
    x = ConvLSTM2D(
        filters=32,
        kernel_size=(3, 3),
        padding='same',
        return_sequences=True,
        name='encoder_3'
    )(x)
    x = BatchNormalization()(x)
    
    # Conv3D for temporal aggregation (as in reference paper)
    x = Conv3D(
        filters=output_channels,
        kernel_size=(3, 3, 3),
        padding='same',
        activation='linear',
        name='output_conv3d'
    )(x)
    
    # Take the last timestep as prediction
    outputs = x[:, -1, :, :, :]
    
    model = Model(inputs=inputs, outputs=outputs, name='ConvLSTM_EncoderDecoder')
    
    model.compile(
        optimizer=Adam(learning_rate=1e-3),
        loss='mse',
        metrics=['mae']
    )
    
    return model


def build_model_v4_deep_stack(input_shape, output_channels):
    """
    Architecture V4: Deep Stacked ConvLSTM
    More layers, better feature extraction.
    """
    model = Sequential([
        ConvLSTM2D(
            128, (3, 3),
            padding='same',
            return_sequences=True,
            input_shape=input_shape,
            name='convlstm_1'
        ),
        BatchNormalization(),
        Dropout(0.2),
        
        ConvLSTM2D(
            64, (3, 3),
            padding='same',
            return_sequences=True,
            name='convlstm_2'
        ),
        BatchNormalization(),
        Dropout(0.2),
        
        ConvLSTM2D(
            32, (3, 3),
            padding='same',
            return_sequences=True,
            name='convlstm_3'
        ),
        BatchNormalization(),
        
        # Final Conv3D for prediction
        Conv3D(
            output_channels,
            kernel_size=(3, 3, 3),
            padding='same',
            activation='linear',
            name='prediction_layer'
        )
    ], name='DeepStack_ConvLSTM')
    
    # Extract last frame
    model.add(tf.keras.layers.Lambda(lambda x: x[:, -1, :, :, :]))
    
    model.compile(
        optimizer=Adam(learning_rate=1e-3),
        loss='mse',
        metrics=['mae']
    )
    
    return model


def build_model_v5_hybrid(input_shape, output_channels, buoy_location):
    """
    Architecture V5: Hybrid Spatial + Point Prediction
    Combines frame prediction with point-specific forecasting.
    """
    inputs = Input(shape=input_shape)
    
    # Shared spatiotemporal encoder
    x = ConvLSTM2D(64, (3, 3), padding='same', return_sequences=True)(inputs)
    x = BatchNormalization()(x)
    
    x = ConvLSTM2D(32, (3, 3), padding='same', return_sequences=True)(x)
    x = BatchNormalization()(x)
    
    # Branch 1: Full frame prediction
    frame_out = Conv3D(
        output_channels,
        kernel_size=(3, 3, 3),
        padding='same',
        activation='linear'
    )(x)
    frame_out = frame_out[:, -1, :, :, :]  # Last timestep
    
    model = Model(inputs=inputs, outputs=frame_out, name='Hybrid_ConvLSTM')
    
    model.compile(
        optimizer=Adam(learning_rate=1e-3),
        loss='mse',
        metrics=['mae']
    )
    
    return model

In [5]:
# EVALUATION FOR FRAME PREDICTIONS
# ======================================================================
def evaluate_frame_model(model, X_val, y_val, scalers_y, buoy_lat, buoy_lon, ds):
    """
    Evaluate model on:
    1. Overall spatial prediction
    2. Point prediction at buoy location
    """
    y_pred_scaled = model.predict(X_val, verbose=0)
    
    # Inverse transform
    y_true = np.zeros_like(y_val)
    y_pred = np.zeros_like(y_pred_scaled)
    
    for i in range(y_val.shape[-1]):
        scaler = scalers_y[i]
        y_true[..., i] = scaler.inverse_transform(
            y_val[..., i].reshape(-1, 1)
        ).reshape(y_val[..., i].shape)
        y_pred[..., i] = scaler.inverse_transform(
            y_pred_scaled[..., i].reshape(-1, 1)
        ).reshape(y_pred_scaled[..., i].shape)
    
    print("\n" + "="*70)
    print("SPATIAL PREDICTION METRICS")
    print("="*70)
    
    # Overall spatial metrics
    mae_overall = mean_absolute_error(y_true.flatten(), y_pred.flatten())
    rmse_overall = np.sqrt(mean_squared_error(y_true.flatten(), y_pred.flatten()))
    
    print(f"Overall Spatial | MAE={mae_overall:.4f}, RMSE={rmse_overall:.4f}")
    print("-"*70)
    
    # Per-channel metrics
    for i, var in enumerate(TARGET_FEATURES):
        mae = mean_absolute_error(y_true[..., i].flatten(), y_pred[..., i].flatten())
        rmse = np.sqrt(mean_squared_error(y_true[..., i].flatten(), y_pred[..., i].flatten()))
        print(f"{var:>6s} | MAE={mae:.3f}, RMSE={rmse:.3f}")
    
    # Point-specific metrics at buoy location
    print("\n" + "="*70)
    print(f"BUOY LOCATION METRICS ({buoy_lat:.2f}¬∞N, {buoy_lon:.2f}¬∞E)")
    print("="*70)
    
    # Find buoy indices in grid
    lats = ds.latitude.values
    lons = ds.longitude.values
    lat_idx = np.argmin(np.abs(lats - buoy_lat))
    lon_idx = np.argmin(np.abs(lons - buoy_lon))
    
    y_true_buoy = y_true[:, lat_idx, lon_idx, :]
    y_pred_buoy = y_pred[:, lat_idx, lon_idx, :]
    
    for i, var in enumerate(TARGET_FEATURES):
        mae = mean_absolute_error(y_true_buoy[:, i], y_pred_buoy[:, i])
        rmse = np.sqrt(mean_squared_error(y_true_buoy[:, i], y_pred_buoy[:, i]))
        r2 = r2_score(y_true_buoy[:, i], y_pred_buoy[:, i])
        print(f"{var:>6s} | MAE={mae:.3f}, RMSE={rmse:.3f}, R¬≤={r2:.3f}")
    
    print("="*70)
    
    return y_true, y_pred, y_true_buoy, y_pred_buoy

In [6]:
# MAIN EXECUTION
# ======================================================================

print("üåä ENHANCED CONVLSTM WITH FRAME PREDICTION")
print("="*70)

# Load data
ds = xr.open_dataset(DATA_PATH)
buoy_lat, buoy_lon = find_nearest_ocean_point(ds, SPOT_LAT, SPOT_LON)
print(f"üìç Buoy: ({buoy_lat:.2f}¬∞N, {buoy_lon:.2f}¬∞E)")

# Create ocean mask
ocean_mask = create_ocean_mask(ds)

# Prepare frames
X, y = engineer_frames(ds, ocean_mask)

# Normalize
X_scaled, y_scaled, scalers_X, scalers_y = normalize_frames(X, y)

# Create sequences
print(f"\nüîÑ Creating sequences (lookback={LOOKBACK_HOURS})...")
X_seq, y_seq = create_frame_sequences(X_scaled, y_scaled, LOOKBACK_HOURS, LOOKAHEAD_HOURS)
print(f"   Sequences: X={X_seq.shape}, y={y_seq.shape}")

# Build model - Choose architecture
print("\nüèóÔ∏è Building model...")
ARCHITECTURE = "v3"  # Change to "v4" or "v5" to try different architectures

if ARCHITECTURE == "v3":
    model = build_model_v3_encoder_decoder(X_seq.shape[1:], y_seq.shape[-1])
    print("Using Architecture V3: Encoder-Decoder ConvLSTM")
elif ARCHITECTURE == "v4":
    model = build_model_v4_deep_stack(X_seq.shape[1:], y_seq.shape[-1])
    print("Using Architecture V4: Deep Stacked ConvLSTM")
else:
    model = build_model_v5_hybrid(X_seq.shape[1:], y_seq.shape[-1], (buoy_lat, buoy_lon))
    print("Using Architecture V5: Hybrid ConvLSTM")

model.summary()

# Callbacks
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=3,
        restore_best_weights=True,
        verbose=1
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,  # Match reference paper
        patience=2,
        min_lr=1e-6,
        verbose=1
    ),
    ModelCheckpoint(
        'best_wave_model.keras',
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    )
]

# Train
print(f"\nüöÄ Training for up to {EPOCHS} epochs...")
history = model.fit(
    X_seq, y_seq,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=callbacks,
    verbose=1
)

# Evaluate
val_size = int(VALIDATION_SPLIT * len(X_seq))
y_true, y_pred, y_true_buoy, y_pred_buoy = evaluate_frame_model(
    model, X_seq[-val_size:], y_seq[-val_size:], 
    scalers_y, buoy_lat, buoy_lon, ds
)

# Forecast next frame
print("\n" + "="*70)
print("üîÆ FORECAST (Next Frame)")
print("="*70)

last_seq = np.expand_dims(X_scaled[-LOOKBACK_HOURS:], axis=0)
forecast_scaled = model.predict(last_seq, verbose=0)[0]

# Inverse transform forecast
forecast = np.zeros_like(forecast_scaled)
for i in range(forecast_scaled.shape[-1]):
    forecast[..., i] = scalers_y[i].inverse_transform(
        forecast_scaled[..., i].reshape(-1, 1)
    ).reshape(forecast_scaled[..., i].shape)

# Extract buoy location forecast
lats = ds.latitude.values
lons = ds.longitude.values
lat_idx = np.argmin(np.abs(lats - buoy_lat))
lon_idx = np.argmin(np.abs(lons - buoy_lon))

buoy_forecast = forecast[lat_idx, lon_idx, :]

print(f"At Buoy Location ({buoy_lat:.2f}¬∞N, {buoy_lon:.2f}¬∞E):")
print(f"  Swell Height:    {buoy_forecast[0]:.2f} m")
print(f"  Swell Period:    {buoy_forecast[1]:.2f} s")
print(f"  Swell Direction: {buoy_forecast[2]:.1f}¬∞")
print("="*70)

üåä ENHANCED CONVLSTM WITH FRAME PREDICTION
üìç Buoy: (7.00¬∞N, 82.00¬∞E)
üåä Ocean Mask: 93 / 441 points (21.1%)
üìä Frames prepared:
   Input: (1464, 21, 21, 6) (time, lat, lon, channels)
   Output: (1464, 21, 21, 3)

üîÑ Creating sequences (lookback=19)...
   Sequences: X=(1445, 19, 21, 21, 6), y=(1445, 21, 21, 3)

üèóÔ∏è Building model...
Using Architecture V3: Encoder-Decoder ConvLSTM
Model: "ConvLSTM_EncoderDecoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 19, 21, 21, 6)]   0         
                                                                 
 encoder_1 (ConvLSTM2D)      (None, 19, 21, 21, 64)    161536    
                                                                 
 batch_normalization (Batch  (None, 19, 21, 21, 64)    256       
 Normalization)                                                  
                                        

ValueError: The following argument(s) are not supported with the native Keras format: ['options']