In [1]:
import xarray as xr
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import ConvLSTM2D, Flatten, Dense, BatchNormalization, Dropout

In [2]:
# ============================================================================
# CONFIGURATION
# ============================================================================
# Data source and the surf spot to forecast for
DATA_PATH = "surf_data_2020.nc"
SPOT_LAT = 6.8399   # Arugam Bay (latitude, °N)
SPOT_LON = 81.8396  # Arugam Bay (longitude, °E)

# Windowing. NOTE(review): these count *time steps* of the dataset; they are
# only hours if `valid_time` is hourly -- confirm against the data.
LOOKBACK_HOURS = 16   # input steps fed to the model per sample
LOOKAHEAD_HOURS = 1   # steps between the last input step and the target

# Training
VALIDATION_SPLIT = 0.2  # fraction of sequences held out (Keras takes the last ones)
EPOCHS = 15
BATCH_SIZE = 32
MODEL_VERSION = 3       # architecture to build, 1-5 (see build_model)

# Grid variables stacked as input channels, in this order
INPUT_FEATURES = [
    "u10", "v10",             # wind components
    "msl", "tp",              # presumably mean sea-level pressure and total precipitation
    "shts", "mpts", "mdts",   # swell height, period and direction (degrees)
]

# Targets at the virtual buoy, in the column order of `y`
TARGET_VARS = [
    "shts", "mpts",           # swell height and period
    "mdts_sin", "mdts_cos",   # swell direction encoded as sine/cosine
    "wind_speed",             # sqrt(u10**2 + v10**2)
]

In [3]:
# DATA LOADING & BUOY SELECTION
# ============================================================================
def find_nearest_ocean_point(ds, lat, lon):
    """Return (lat, lon) of the grid cell nearest to the requested point that has wave data.

    Validity is judged from `shts` at the first time step: a NaN there marks a
    land (masked) cell. If the grid cell nearest to (lat, lon) is valid it is
    returned directly. Otherwise every valid cell is ranked by great-circle
    (haversine) distance from (lat, lon) and the closest one is returned.

    Args:
        ds: dataset with an `shts` variable on (valid_time, latitude, longitude).
        lat, lon: requested location in degrees.

    Returns:
        (latitude, longitude) of the chosen grid cell as Python floats.

    Raises:
        ValueError: if `shts` has no valid cell at the first time step.
    """
    first_shts = ds["shts"].isel(valid_time=0)

    nearest = first_shts.sel(latitude=lat, longitude=lon, method="nearest")
    if not np.isnan(nearest.values):
        print("✓ Found valid offshore point")
        return float(nearest.latitude), float(nearest.longitude)

    print("⚠ Nearest point on land. Searching for closest ocean point...")
    ocean = first_shts.stack(point=("latitude", "longitude")).dropna("point")
    if ocean.sizes["point"] == 0:
        raise ValueError(
            "No valid 'shts' values at the first time step; cannot place a virtual buoy."
        )

    # Haversine great-circle distance from the requested point to every ocean cell
    earth_radius_km = 6371.0
    lat1, lon1 = np.radians(lat), np.radians(lon)
    lat2 = np.radians(ocean["latitude"].values)
    lon2 = np.radians(ocean["longitude"].values)

    a = (np.sin((lat2 - lat1) / 2) ** 2
         + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2)
    # Rounding can push `a` a hair above 1, which would make arcsin return NaN
    distance_km = 2 * earth_radius_km * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))

    closest = ocean.isel(point=int(np.argmin(distance_km)))
    return float(closest.latitude), float(closest.longitude)

In [4]:
# FEATURE ENGINEERING
# ============================================================================
def engineer_features(ds, buoy_lat, buoy_lon):
    """Build the spatial model input X and the point targets y.

    X stacks every variable in INPUT_FEATURES over the full grid as channels,
    ordered (valid_time, latitude, longitude, channel).

    y has one column per entry of TARGET_VARS, in that order, taken at the
    virtual buoy cell. Names are resolved as:
      * "mdts_sin" / "mdts_cos": sine / cosine of `mdts` (degrees -> radians),
        which avoids the 359° -> 0° wrap-around of a raw angle;
      * "wind_speed": sqrt(u10**2 + v10**2);
      * any other name: the dataset variable of that name at the buoy.
    Deriving y from TARGET_VARS keeps its column order in sync with the labels
    that evaluation uses.

    Raises:
        KeyError: if a TARGET_VARS entry is neither derived nor in the dataset.
    """
    X = ds[INPUT_FEATURES].to_array(dim="channel").transpose(
        "valid_time", "latitude", "longitude", "channel"
    ).values

    buoy_data = ds.sel(latitude=buoy_lat, longitude=buoy_lon)

    def _mdts_rad():
        return np.deg2rad(buoy_data["mdts"].values)

    # Derived targets are computed lazily so only the requested ones need inputs
    derived = {
        "mdts_sin": lambda: np.sin(_mdts_rad()),
        "mdts_cos": lambda: np.cos(_mdts_rad()),
        "wind_speed": lambda: np.sqrt(buoy_data["u10"]**2 + buoy_data["v10"]**2).values,
    }

    columns = []
    for name in TARGET_VARS:
        if name in derived:
            columns.append(derived[name]())
        elif name in buoy_data:
            columns.append(buoy_data[name].values)
        else:
            raise KeyError(f"Unknown target variable {name!r}")

    y = np.column_stack(columns)
    return X, y

In [5]:
# PREPROCESSING
# ============================================================================
def preprocess_data(X, y, fit_end=None):
    """Standardize inputs per channel and targets per column.

    Non-finite values (NaN, +/-inf), e.g. wave variables over land cells, are
    treated as missing. They are excluded when each StandardScaler estimates
    mean/std, and are filled with 0.0 *after* scaling, i.e. with the channel or
    column mean. Filling with a raw 0.0 before fitting would count every land
    cell as a real "0" observation and distort the statistics.

    Args:
        X: array whose last axis is channels; here (time, lat, lon, channel).
        y: array (time, n_targets).
        fit_end: optional int. If given, scalers are fitted only on X[:fit_end]
            and y[:fit_end] (e.g. the training period) so validation statistics
            do not leak into training. Default None fits on all time steps.

    Returns:
        X_scaled, y_scaled, scalers_X (one StandardScaler per channel), scaler_y.
    """
    fit_rows = slice(None, fit_end)

    # Turn +/-inf into NaN: StandardScaler ignores NaN but rejects inf
    X_clean = np.where(np.isfinite(X), X, np.nan)
    y_clean = np.where(np.isfinite(y), y, np.nan)

    X_scaled = np.zeros_like(X_clean)
    scalers_X = []
    for ch in range(X_clean.shape[-1]):
        channel = X_clean[..., ch]
        scaler = StandardScaler().fit(channel[fit_rows].reshape(-1, 1))
        X_scaled[..., ch] = scaler.transform(channel.reshape(-1, 1)).reshape(channel.shape)
        scalers_X.append(scaler)
    # Missing cells become the channel mean in standardized space
    X_scaled = np.nan_to_num(X_scaled, nan=0.0)

    scaler_y = StandardScaler().fit(y_clean[fit_rows])
    y_scaled = np.nan_to_num(scaler_y.transform(y_clean), nan=0.0)

    return X_scaled, y_scaled, scalers_X, scaler_y

def create_sequences(X, y, lookback, lookahead):
    """Cut aligned (input window, target) pairs from time-ordered arrays.

    Window i is X[i : i + lookback]. Its target is the row `lookahead` steps
    after the window's last step, y[i + lookback + lookahead - 1].

    Args:
        X: array (time, ...) of inputs.
        y: array (time, ...) of targets aligned with X on the first axis.
        lookback: number of time steps per input window.
        lookahead: steps from the last input step to the target (1 = next step).

    Returns:
        X_seq of shape (n_windows, lookback, *X.shape[1:]) and y_seq of shape
        (n_windows, *y.shape[1:]). When the series is too short, both are empty
        but still carry those trailing dimensions.
    """
    X = np.asarray(X)
    y = np.asarray(y)

    n_windows = max(len(X) - lookback - lookahead + 1, 0)
    starts = np.arange(n_windows)
    # (n_windows, lookback) matrix of time indices; fancy indexing gathers all windows at once
    window_idx = starts[:, None] + np.arange(lookback)

    return X[window_idx], y[starts + lookback + lookahead - 1]

In [6]:
# MODEL BUILDING
# ============================================================================

# ARCHITECTURE 1: Original (Baseline)
def build_model_v1(input_shape, output_dim):
    """Baseline: a single ConvLSTM2D layer feeding a small dense head.

    Args:
        input_shape: (timesteps, height, width, channels) of one sample.
        output_dim: number of regression targets.

    Returns:
        A compiled Sequential model (Adam, MSE loss, MAE metric).
    """
    model = Sequential()
    model.add(ConvLSTM2D(32, (3, 3), padding='same', return_sequences=False,
                         input_shape=input_shape))
    head = (
        BatchNormalization(),
        Flatten(),
        Dense(64, activation='relu'),
        Dropout(0.2),
        Dense(output_dim, activation='linear'),
    )
    for layer in head:
        model.add(layer)
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model


# ARCHITECTURE 2: Deeper ConvLSTM Stack (RECOMMENDED START)
def build_model_v2(input_shape, output_dim):
    """Two stacked ConvLSTM2D layers (64 -> 32 filters) plus a two-layer dense head.

    Args:
        input_shape: (timesteps, height, width, channels) of one sample.
        output_dim: number of regression targets.

    Returns:
        A compiled Sequential model (Adam, MSE loss, MAE metric).
    """
    # The first layer returns the full sequence so the second can consume it
    recurrent_stack = [
        ConvLSTM2D(64, (3, 3), padding='same', return_sequences=True, input_shape=input_shape),
        BatchNormalization(),
        Dropout(0.2),
        ConvLSTM2D(32, (3, 3), padding='same', return_sequences=False),
        BatchNormalization(),
    ]
    dense_head = [
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.3),
        Dense(64, activation='relu'),
        Dense(output_dim, activation='linear'),
    ]
    model = Sequential(recurrent_stack + dense_head)
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model


# ARCHITECTURE 3: Multi-Scale ConvLSTM
def build_model_v3(input_shape, output_dim):
    """Two parallel ConvLSTM2D branches (3x3 and 5x5 kernels) merged into a dense head.

    Both branches read the same input sequence; their flattened final states
    are concatenated before regression.

    Args:
        input_shape: (timesteps, height, width, channels) of one sample.
        output_dim: number of regression targets.

    Returns:
        A compiled functional Keras model (Adam, MSE loss, MAE metric).
    """
    from tensorflow.keras import Input
    from tensorflow.keras.layers import Concatenate
    from tensorflow.keras.models import Model

    inputs = Input(shape=input_shape)

    def conv_branch(kernel):
        """ConvLSTM2D(32, kernel) -> BatchNorm -> Flatten on the shared input."""
        h = ConvLSTM2D(32, kernel, padding='same', return_sequences=False)(inputs)
        h = BatchNormalization()(h)
        return Flatten()(h)

    # 3x3 captures local structure, 5x5 a wider regional neighbourhood
    branches = [conv_branch(kernel) for kernel in ((3, 3), (5, 5))]

    x = Concatenate()(branches)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu')(x)
    outputs = Dense(output_dim, activation='linear')(x)

    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model


# ARCHITECTURE 4: Attention-Enhanced ConvLSTM
def build_model_v4(input_shape, output_dim):
    """Stacked ConvLSTM with a learned spatial attention gate.

    Two ConvLSTM2D layers (64 then 32 filters) summarize the input sequence
    into one (height, width, 32) feature map. A 1x1 Conv2D with a sigmoid turns
    that map into a (height, width, 1) weight in [0, 1] per grid cell. The
    weight is multiplied into the features (broadcast over channels), so the
    model can learn which regions matter for the target point. A dense head
    then regresses the targets.

    Args:
        input_shape: (timesteps, height, width, channels) of one sample.
        output_dim: number of regression targets.

    Returns:
        A compiled functional Keras model (Adam lr=5e-4, MSE loss, MAE metric).
    """
    from tensorflow.keras import Input
    from tensorflow.keras.layers import Conv2D, Multiply
    from tensorflow.keras.models import Model
    from tensorflow.keras.optimizers import Adam

    inputs = Input(shape=input_shape)

    x = ConvLSTM2D(64, (3, 3), padding='same', return_sequences=True)(inputs)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)

    x = ConvLSTM2D(32, (3, 3), padding='same', return_sequences=False)(x)
    x = BatchNormalization()(x)

    # Spatial attention: one sigmoid weight per grid cell, broadcast across the channels
    attention = Conv2D(1, (1, 1), activation='sigmoid')(x)
    x = Multiply()([x, attention])

    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu')(x)
    x = Dropout(0.2)(x)
    outputs = Dense(output_dim, activation='linear')(x)

    model = Model(inputs, outputs)
    # Lower learning rate for stability
    model.compile(optimizer=Adam(learning_rate=0.0005), loss='mse', metrics=['mae'])
    return model


# ARCHITECTURE 5: Residual ConvLSTM (For Deep Networks)
def build_model_v5(input_shape, output_dim):
    """Three stacked ConvLSTM2D blocks with an identity skip around the second.

    Blocks 1 and 2 both use 64 filters, 'same' padding and return_sequences=True.
    Their outputs therefore share a shape, and block 2's output can be added to
    its input (a residual connection). Block 3 (32 filters) collapses the time
    axis before a three-layer dense head.

    Args:
        input_shape: (timesteps, height, width, channels) of one sample.
        output_dim: number of regression targets.

    Returns:
        A compiled functional Keras model (Adam lr=3e-4, MSE loss, MAE metric).
    """
    from tensorflow.keras.layers import Add
    from tensorflow.keras.models import Model
    from tensorflow.keras import Input
    from tensorflow.keras.optimizers import Adam

    inputs = Input(shape=input_shape)

    # First ConvLSTM block; its output also feeds the skip connection
    x = ConvLSTM2D(64, (3, 3), padding='same', return_sequences=True)(inputs)
    x = BatchNormalization()(x)
    shortcut = Dropout(0.2)(x)

    # Second ConvLSTM block with identity residual
    x = ConvLSTM2D(64, (3, 3), padding='same', return_sequences=True)(shortcut)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)
    x = Add()([x, shortcut])

    # Final ConvLSTM collapses the time axis
    x = ConvLSTM2D(32, (3, 3), padding='same', return_sequences=False)(x)
    x = BatchNormalization()(x)

    # Dense layers
    x = Flatten()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.3)(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.2)(x)
    x = Dense(64, activation='relu')(x)
    outputs = Dense(output_dim, activation='linear')(x)

    model = Model(inputs, outputs)

    # Lower learning rate for deeper networks
    model.compile(optimizer=Adam(learning_rate=0.0003), loss='mse', metrics=['mae'])
    return model


# Main build function - switch between architectures
def build_model(input_shape, output_dim, version=2):
    """
    Build and compile one of the candidate architectures.

    Args:
        input_shape: (timesteps, height, width, channels)
        output_dim: number of target variables
        version: 1-5, selects the architecture

    Returns:
        The compiled Keras model returned by build_model_v<version>.

    Raises:
        ValueError: if `version` is not one of 1-5.

    Suggested progression:
        v1: Baseline (single ConvLSTM layer)
        v2: Stacked ConvLSTM (START HERE) ⭐
        v3: Multi-scale (if v2 plateaus)
        v4: With attention (for complex patterns)
        v5: Deep residual (if you have more data/compute)
    """
    builders = {
        1: build_model_v1,
        2: build_model_v2,
        3: build_model_v3,
        4: build_model_v4,
        5: build_model_v5,
    }

    builder = builders.get(version)
    if builder is None:
        raise ValueError(f"Version must be 1-5, got {version}")

    print(f"\n🏗️  Building Architecture v{version}")
    return builder(input_shape, output_dim)

In [7]:
# EVALUATION
# ============================================================================
def evaluate_model(model, X_val, y_val, scaler_y, target_names):
    """Print validation metrics in physical units and return the de-scaled arrays.

    Args:
        model: trained Keras model.
        X_val: validation inputs, as fed to the model.
        y_val: validation targets in *scaled* space (i.e. transformed by scaler_y).
        scaler_y: fitted target scaler; inverts both y_val and the predictions.
        target_names: one label per target column, in column order.

    Returns:
        (y_true, y_pred) in original units, shape (n_samples, n_targets).

    Raises:
        ValueError: if there are no validation samples.
    """
    if len(X_val) == 0:
        raise ValueError("evaluate_model needs at least one validation sample")

    y_pred_scaled = model.predict(X_val, verbose=0)

    # Inverse transform back to physical units
    y_true = scaler_y.inverse_transform(y_val)
    y_pred = scaler_y.inverse_transform(y_pred_scaled)

    # "Overall" averages across targets with different units (m, s, unitless
    # sin/cos, m/s), so treat it as a rough summary only.
    mae_overall = mean_absolute_error(y_true, y_pred)
    rmse_overall = np.sqrt(mean_squared_error(y_true, y_pred))
    r2_overall = r2_score(y_true, y_pred)

    print("\n" + "="*60)
    print("VALIDATION METRICS")
    print("="*60)
    print(f"Overall | MAE={mae_overall:.4f}, RMSE={rmse_overall:.4f}, R²={r2_overall:.4f}")
    print("-"*60)

    # Per-variable metrics
    for i, var in enumerate(target_names):
        mae_i = mean_absolute_error(y_true[:, i], y_pred[:, i])
        rmse_i = np.sqrt(mean_squared_error(y_true[:, i], y_pred[:, i]))
        r2_i = r2_score(y_true[:, i], y_pred[:, i])
        print(f"{var:>12s} | MAE={mae_i:.3f}, RMSE={rmse_i:.3f}, R²={r2_i:.3f}")

    # Direction error in degrees, reconstructed from the sin/cos pair when present
    names = list(target_names)
    if "mdts_sin" in names and "mdts_cos" in names:
        i_sin, i_cos = names.index("mdts_sin"), names.index("mdts_cos")
        dir_true = np.arctan2(y_true[:, i_sin], y_true[:, i_cos])
        dir_pred = np.arctan2(y_pred[:, i_sin], y_pred[:, i_cos])
        # Wrap into [-180, 180) so e.g. 359° vs 1° counts as 2°, not 358°
        diff_deg = (np.rad2deg(dir_pred - dir_true) + 180.0) % 360.0 - 180.0
        print(f"{'direction':>12s} | circular MAE={np.mean(np.abs(diff_deg)):.1f}°")

    print("="*60)
    return y_true, y_pred

In [8]:
# MAIN PIPELINE
# ============================================================================
def _time_step_hours(ds):
    """Return the spacing between the first two `valid_time` stamps in hours.

    Returns None if there are fewer than two stamps or `valid_time` is not a
    datetime64 coordinate, so callers can fall back to reporting "steps".
    """
    times = np.asarray(ds["valid_time"].values)
    if times.size < 2 or not np.issubdtype(times.dtype, np.datetime64):
        return None
    return float((times[1] - times[0]) / np.timedelta64(1, "h"))


def main():
    """Run the pipeline end to end and print validation metrics plus a forecast.

    Steps:
        1. Load DATA_PATH and place a virtual buoy at the nearest ocean cell.
        2. Build features and targets, standardize them, and window them into
           sequences.
        3. Train the architecture selected by MODEL_VERSION.
        4. Evaluate on the same tail Keras held out.
        5. Forecast from the last LOOKBACK_HOURS steps.

    Returns:
        (model, history, scaler_y): the trained Keras model, its training
        History, and the target scaler needed to invert predictions.
    """
    from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

    print("Loading dataset...")
    # The context manager releases the NetCDF file handle. Everything needed
    # later is materialized as numpy arrays (`.values`) inside this block.
    with xr.open_dataset(DATA_PATH) as ds:
        # Find valid buoy location
        buoy_lat, buoy_lon = find_nearest_ocean_point(ds, SPOT_LAT, SPOT_LON)
        print(f"Virtual buoy: ({buoy_lat:.2f}°N, {buoy_lon:.2f}°E)")

        # Feature engineering
        print("\nEngineering features...")
        X, y = engineer_features(ds, buoy_lat, buoy_lon)
        step_hours = _time_step_hours(ds)
    print(f"Spatial input shape: {X.shape}")
    print(f"Target shape: {y.shape}")

    # Preprocessing
    print("Preprocessing data...")
    X_scaled, y_scaled, scalers_X, scaler_y = preprocess_data(X, y)

    # Create sequences
    print(f"Creating sequences (lookback={LOOKBACK_HOURS}, lookahead={LOOKAHEAD_HOURS})...")
    X_seq, y_seq = create_sequences(X_scaled, y_scaled, LOOKBACK_HOURS, LOOKAHEAD_HOURS)
    print(f"Sequence shape: X={X_seq.shape}, y={y_seq.shape}")

    # Build model
    print("\nBuilding model...")
    model = build_model(X_seq.shape[1:], y_seq.shape[1], version=MODEL_VERSION)
    model.summary()

    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            verbose=1
        )
    ]

    # Train
    print(f"\nTraining Architecture v{MODEL_VERSION} for up to {EPOCHS} epochs...")
    history = model.fit(
        X_seq, y_seq,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_split=VALIDATION_SPLIT,
        callbacks=callbacks,
        verbose=1
    )

    # Evaluate on exactly the samples Keras validated on. `validation_split`
    # holds out the *last* samples, splitting at floor(n * (1 - split)).
    # Slicing forward from that index also stays correct when nothing is held
    # out, whereas a negative slice `[-0:]` would return the whole array.
    split_at = int(len(X_seq) * (1.0 - VALIDATION_SPLIT))
    if split_at < len(X_seq):
        evaluate_model(
            model,
            X_seq[split_at:],
            y_seq[split_at:],
            scaler_y,
            TARGET_VARS
        )
    else:
        print("\nNo validation samples held out; skipping evaluation.")

    # Forecast LOOKAHEAD_HOURS steps past the last observation. The horizon
    # label comes from the data's actual time spacing, not a hard-coded value.
    if step_hours is None:
        horizon = f"+{LOOKAHEAD_HOURS} time step(s)"
    else:
        horizon = f"+{LOOKAHEAD_HOURS * step_hours:g} h"
    print("\n" + "="*60)
    print(f"FORECAST ({horizon} after the last time step)")
    print("="*60)
    last_seq = np.expand_dims(X_scaled[-LOOKBACK_HOURS:], axis=0)
    forecast_scaled = model.predict(last_seq, verbose=0)
    forecast = scaler_y.inverse_transform(forecast_scaled)[0]

    # Reconstruct direction from sin/cos (arctan2 takes sin first)
    direction_deg = np.rad2deg(np.arctan2(forecast[2], forecast[3])) % 360

    print(f"Swell Height:    {forecast[0]:.2f} m")
    print(f"Swell Period:    {forecast[1]:.2f} s")
    print(f"Swell Direction: {direction_deg:.1f}°")
    print(f"Wind Speed:      {forecast[4]:.2f} m/s ({forecast[4]*3.6:.1f} km/h)")
    print("="*60)

    return model, history, scaler_y

# Entry point: runs the full pipeline and keeps the trained model, its
# training history and the target scaler as notebook globals. In a Jupyter
# kernel __name__ is "__main__", so running this cell always trains.
if __name__ == "__main__":
    model, history, scaler = main()

Loading dataset...
⚠ Nearest point on land. Searching for closest ocean point...
Virtual buoy: (7.00°N, 82.00°E)

Engineering features...
Spatial input shape: (1464, 21, 21, 7)
Target shape: (1464, 5)
Preprocessing data...
Creating sequences (lookback=16, lookahead=1)...
Sequence shape: X=(1448, 16, 21, 21, 7), y=(1448, 5)

Building model...

🏗️  Building Architecture v3
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 16, 21, 21, 7)]      0         []                            
                                                                                                  
 conv_lstm2d (ConvLSTM2D)    (None, 21, 21, 32)           45056     ['input_1[0][0]']             
                                                                                                  
 conv_ls