In [1]:
import xarray as xr
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import ConvLSTM2D, Flatten, Dense, BatchNormalization, Dropout, PReLU
import tensorflow_addons as tfa
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
import sys

# Python 3.9 compatibility fix
if sys.version_info >= (3, 10):
    from typing import TypeAlias
else:
    from typing_extensions import TypeAlias


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
# CONFIGURATION
# Every value a reader might want to tune lives in this cell.

# --- Data source and target location ---
DATA_PATH = "../surf_data_2024.nc"
SPOT_LAT = 6.8399   # Arugam Bay
SPOT_LON = 81.8396  # Arugam Bay

# --- Sliding-window geometry ---
# NOTE(review): both values are consumed as a number of time *steps* when the
# windows are built; they are hours only if the dataset is hourly - confirm.
LOOKBACK_HOURS = 28
LOOKAHEAD_HOURS = 12

# --- Training ---
VALIDATION_SPLIT = 0.2
EPOCHS = 15
BATCH_SIZE = 32

# --- Variables ---
# Gridded dataset variables stacked as input channels.
INPUT_FEATURES = [
    "u10",
    "v10",
    "msl",
    "shts",
    "mpts",
    "mdts",
]
# Display names of the target columns, in the order they are assembled.
TARGET_VARS = [
    "shts",
    "mpts",
    "mdts_sin",
    "mdts_cos",
    "wind_speed",
]

In [3]:
# IMPROVED NaN DIAGNOSTICS
def diagnose_nans(ds, features):
    """Print a per-variable NaN report for the requested variables.

    Args:
        ds: Mapping-like dataset; ``ds[name].values`` must yield a NumPy array.
            The spatial/temporal breakdown reduces over axis 0 and over axes
            (1, 2), so each array is expected to be 3-D with time first.
        features: Iterable of variable names to inspect.

    Returns:
        None. The report is written to stdout.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("NaN DIAGNOSTICS")
    print(rule)

    for var in features:
        values = ds[var].values
        missing = np.isnan(values)
        n_missing = missing.sum()

        # Clean variable: one-line confirmation, nothing else to break down
        if not n_missing > 0:
            print(f"\n{var}: ‚úì No NaNs")
            continue

        share = 100 * n_missing / values.size
        # A grid cell counts as affected if it is NaN at any time step
        cells_hit = missing.any(axis=0)
        # A time step counts as affected if any grid cell is NaN
        steps_hit = missing.any(axis=(1, 2))

        print(f"\n{var}:")
        print(f"  Total NaNs: {n_missing:,} ({share:.2f}%)")
        print(f"  Affected spatial points: {cells_hit.sum()} / {cells_hit.size}")
        print(f"  Affected timesteps: {steps_hit.sum()} / {len(steps_hit)}")

    print(rule)

In [4]:
# ======================================================================
# SOLUTION 1: OCEAN MASK (Recommended for Spatial Data)
# ======================================================================
def create_ocean_mask(ds, threshold=0.5, indicator_var="shts"):
    """
    Build a boolean mask that is True at grid points treated as ocean.

    A grid point is treated as ocean when the fraction of NaN values along
    axis 0 of the indicator variable is strictly below ``threshold``; every
    other point is treated as land.

    Args:
        ds: Mapping-like dataset; ``ds[indicator_var].values`` must yield a
            NumPy array whose first axis is the one to aggregate over.
        threshold: NaN fraction at or above which a point counts as land.
        indicator_var: Name of the variable whose NaN pattern marks land.
            Defaults to ``"shts"``, the variable this function always used.

    Returns:
        Boolean array with the indicator variable's shape minus axis 0.
    """
    indicator = ds[indicator_var].values
    # Mean of a boolean mask along axis 0 == fraction of NaN samples per point
    nan_ratio = np.isnan(indicator).mean(axis=0)
    ocean_mask = nan_ratio < threshold

    valid_points = ocean_mask.sum()
    total_points = ocean_mask.size

    print(f"\nüåä Ocean Mask Created:")
    print(f"   Valid ocean points: {valid_points} / {total_points} ({100*valid_points/total_points:.1f}%)")

    return ocean_mask


def apply_ocean_mask(X, ocean_mask):
    """Return a copy of ``X`` in which every non-ocean position is 0.

    Args:
        X: Array whose trailing axis is the channel axis; the axes just before
           it must be broadcast-compatible with ``ocean_mask``.
        ocean_mask: Boolean array, True where data should be kept.

    Returns:
        A new array shaped like ``X``; the input is left untouched.
    """
    # Give the mask a trailing channel axis, stretch it over X's full shape,
    # then flip it so True marks the positions to blank out.
    land = np.invert(np.broadcast_to(ocean_mask[..., None], X.shape))
    result = X.copy()
    result[land] = 0.0
    return result

In [5]:
# ======================================================================
# SOLUTION 2: TEMPORAL INTERPOLATION (COMMENTED OUT)
# ======================================================================
# def interpolate_temporal_gaps(ds, features, method='linear', limit=2):
#     """
#     Fill small temporal gaps using interpolation.
#     Only fills gaps up to 'limit' consecutive NaNs.
#     """
#     print(f"\n‚è±Ô∏è  Interpolating temporal gaps (method={method}, max_gap={limit})...")
#     
#     ds_interp = ds.copy()
#     for var in features:
#         original_nans = np.isnan(ds[var].values).sum()
#         
#         # Interpolate along time dimension
#         ds_interp[var] = ds[var].interpolate_na(
#             dim='valid_time', 
#             method=method,
#             limit=limit,
#             fill_value='extrapolate'
#         )
#         
#         remaining_nans = np.isnan(ds_interp[var].values).sum()
#         filled = original_nans - remaining_nans
#         
#         if filled > 0:
#             print(f"   {var}: Filled {filled:,} NaNs ({original_nans:,} ‚Üí {remaining_nans:,})")
#     
#     return ds_interp
# Simple function that just returns the original dataset
def interpolate_temporal_gaps(ds, features, method='linear', limit=2):
    """Pass-through stand-in for temporal gap interpolation.

    Prints a notice and hands the dataset back untouched. ``features``,
    ``method`` and ``limit`` are accepted so existing call sites keep working,
    but none of them is used.

    Returns:
        The very same ``ds`` object that was passed in.
    """
    notice = "‚è≠Ô∏è  Skipping temporal interpolation..."
    print(notice)
    return ds

In [6]:
# ======================================================================
# SOLUTION 3: SMART NaN REMOVAL
# ======================================================================
def remove_corrupted_timesteps(X, y, max_nan_ratio=0.1):
    """
    Drop samples along axis 0 whose NaN fraction reaches the threshold.

    The NaN fraction of each sample is computed over every axis except the
    first, so the function works for any array rank >= 2 rather than only the
    4-D layout it was originally written for.

    Args:
        X: Array whose first axis indexes samples (timesteps).
        y: Array indexable by a boolean mask of length ``len(X)``.
        max_nan_ratio: Samples with a NaN fraction strictly below this value
            are kept; all others are removed.

    Returns:
        Tuple ``(X_kept, y_kept)`` filtered with the same boolean mask.
    """
    # Reduce over every non-sample axis instead of hard-coding (1, 2, 3)
    per_sample_axes = tuple(range(1, X.ndim))
    nan_ratio_per_time = np.isnan(X).mean(axis=per_sample_axes)
    valid_mask = nan_ratio_per_time < max_nan_ratio

    removed = len(X) - valid_mask.sum()
    if removed > 0:
        print(f"\nüóëÔ∏è  Removed {removed} corrupted timesteps (>{max_nan_ratio*100:.0f}% NaNs)")

    return X[valid_mask], y[valid_mask]

In [7]:
# ======================================================================
# IMPROVED PREPROCESSING
# ======================================================================
def preprocess_data_v2(X, y, ocean_mask=None):
    """
    Clean and standardise the gridded inputs and the point targets.

    Processing order:
      1. Rows whose target contains a NaN are dropped from both X and y.
      2. Remaining input NaNs are replaced by the channel mean, computed from
         ocean points only when a mask is given.
      3. One StandardScaler per channel is fitted on ocean points only and
         then applied to the whole grid.
      4. Land points are set to 0 *after* scaling, i.e. to the centre of the
         standardised distribution.

    Fitting the scalers before zeroing land matters: if land were zeroed
    first, the placeholder zeros would enter the mean/std of every channel
    and, for channels whose physical values sit far from 0, squash the real
    variability.

    Caveats for callers:
      * Dropping rows in step 1 leaves gaps in the sample axis; windows built
        over the result can span those gaps.
      * Scaler statistics are computed over every row passed in, including
        rows a caller may later hold out for validation.

    Args:
        X: Array laid out as (sample, <grid axes>, channel). Not modified.
        y: 2-D array (sample, target). Not modified.
        ocean_mask: Optional boolean array broadcastable to the grid axes of
            X, True where data is meaningful. When None, every grid point is
            used and nothing is zeroed.

    Returns:
        Tuple ``(X_scaled, y_scaled, scalers_X, scaler_y)`` where
        ``scalers_X`` holds one fitted scaler per input channel.

    Raises:
        ValueError: if ``ocean_mask`` cannot be broadcast to the grid axes of
            X, or selects no grid point at all.
    """
    print("\n" + "="*70)
    print("PREPROCESSING")
    print("="*70)

    # Private copies: nothing below may write into the caller's arrays
    X = np.array(X, copy=True)
    y = np.array(y, copy=True)

    # Step 1: Resolve the grid points that carry real signal
    if ocean_mask is not None:
        # broadcast_to raises ValueError when the mask does not fit the grid
        ocean_mask = np.broadcast_to(np.asarray(ocean_mask, dtype=bool), X.shape[1:-1])
        if not ocean_mask.any():
            raise ValueError("ocean_mask selects no grid points; cannot fit input scalers")

    def _signal(channel):
        """Values of one channel restricted to the points used for statistics."""
        return channel[:, ocean_mask] if ocean_mask is not None else channel

    # Step 2: Count NaNs that matter (land NaNs are overwritten in Step 6)
    if ocean_mask is not None:
        nan_count_X = np.isnan(X[:, ocean_mask, :]).sum()
    else:
        nan_count_X = np.isnan(X).sum()
    nan_count_y = np.isnan(y).sum()

    print(f"\nRemaining NaNs:")
    print(f"   Input (X): {nan_count_X:,} ({100*nan_count_X/X.size:.3f}%)")
    print(f"   Target (y): {nan_count_y:,} ({100*nan_count_y/y.size:.3f}%)")

    # Step 3: Remove rows with NaN targets (critical!)
    if nan_count_y > 0:
        valid_mask = ~np.isnan(y).any(axis=1)
        X = X[valid_mask]
        y = y[valid_mask]
        print(f"‚úì Removed {(~valid_mask).sum()} samples with NaN targets")

    # Step 4: Replace every remaining input NaN with the channel mean taken
    # over the signal points, so no NaN reaches the scalers.
    nan_positions = np.isnan(X)
    if nan_positions.any():
        for i in range(X.shape[-1]):
            channel = X[..., i]  # view into the private copy
            sample = _signal(channel)
            finite = sample[~np.isnan(sample)]
            # An all-NaN channel has no mean; fall back to 0 instead of NaN
            fill_value = finite.mean() if finite.size else 0.0
            channel[nan_positions[..., i]] = fill_value
        if nan_count_X > 0:
            print(f"‚úì Filled remaining input NaNs with channel means")

    # Step 5: Scale data - statistics from signal points, applied everywhere
    X_scaled = np.zeros_like(X)
    scalers_X = []

    for i in range(X.shape[-1]):
        channel = X[..., i]
        scaler = StandardScaler()
        scaler.fit(_signal(channel).reshape(-1, 1))
        X_scaled[..., i] = scaler.transform(channel.reshape(-1, 1)).reshape(channel.shape)
        scalers_X.append(scaler)

    # Step 6: Blank out land in standardised space (0 == channel mean)
    if ocean_mask is not None:
        X_scaled = apply_ocean_mask(X_scaled, ocean_mask)
        print("‚úì Ocean mask applied (land points ‚Üí 0)")

    scaler_y = StandardScaler()
    y_scaled = scaler_y.fit_transform(y)

    print(f"\n‚úì Scaled data | X: {X_scaled.shape}, y: {y_scaled.shape}")
    print("="*70)

    return X_scaled, y_scaled, scalers_X, scaler_y

In [8]:
# ======================================================================
# IMPROVED FEATURE ENGINEERING
# ======================================================================
def engineer_features_v2(ds, buoy_lat, buoy_lon):
    """Assemble the gridded input tensor and the point-wise target matrix.

    Args:
        ds: Dataset holding the variables named in ``INPUT_FEATURES`` with
            dimensions ``valid_time``, ``latitude`` and ``longitude``.
        buoy_lat: Latitude of the target point; must match a grid coordinate
            because the selection is exact.
        buoy_lon: Longitude of the target point; same requirement.

    Returns:
        Tuple ``(X, y)``. ``X`` is ordered (valid_time, latitude, longitude,
        channel). ``y`` has one row per time step and five columns: shts,
        mpts, sin(mdts), cos(mdts) and the wind speed derived from u10/v10.
    """
    # Gridded predictors, channels last
    stacked = ds[INPUT_FEATURES].to_array(dim="channel")
    X = stacked.transpose("valid_time", "latitude", "longitude", "channel").values

    # Time series at the target grid point
    point = ds.sel(latitude=buoy_lat, longitude=buoy_lon)
    u_component = point["u10"]
    v_component = point["v10"]

    wind_speed = np.sqrt(u_component**2 + v_component**2).values
    # Direction is encoded as sin/cos so 359 deg and 1 deg end up close together
    direction_rad = np.deg2rad(point["mdts"].values)

    target_columns = [
        point["shts"].values,
        point["mpts"].values,
        np.sin(direction_rad),
        np.cos(direction_rad),
        wind_speed,
    ]
    y = np.column_stack(target_columns)

    # Diagnostics
    input_nans = np.isnan(X).sum()
    target_nans = np.isnan(y).sum()
    print(f"\nüìä Feature Extraction:")
    print(f"   Input shape: {X.shape}")
    print(f"   Target shape: {y.shape}")
    print(f"   Input NaNs: {input_nans:,}")
    print(f"   Target NaNs: {target_nans:,}")

    return X, y

In [9]:
# ======================================================================
# FIND NEAREST OCEAN POINT (Same as before)
# ======================================================================
def find_nearest_ocean_point(ds, lat, lon):
    """Return grid coordinates of a non-NaN ``shts`` point close to (lat, lon).

    The nearest grid point is used directly when its ``shts`` value at the
    first time step is not NaN. Otherwise every grid point with a non-NaN
    first-step ``shts`` is ranked by haversine distance to (lat, lon) and the
    closest one wins. Only the first time step is ever consulted.

    Args:
        ds: Dataset with an ``shts`` variable on ``valid_time``, ``latitude``
            and ``longitude``.
        lat: Requested latitude in degrees.
        lon: Requested longitude in degrees.

    Returns:
        Tuple ``(latitude, longitude)`` of the chosen grid point as floats.
    """
    candidate = ds.sel(latitude=lat, longitude=lon, method="nearest")
    candidate_is_nan = np.isnan(candidate["shts"].isel(valid_time=0))

    if not candidate_is_nan:
        print("‚úì Found valid offshore point")
        return float(candidate.latitude), float(candidate.longitude)

    print("‚ö† Nearest point on land. Searching for closest ocean point...")

    # Flatten the grid to a list of points and keep only the non-NaN ones
    first_step = ds["shts"].isel(valid_time=0)
    ocean_points = first_step.stack(point=("latitude", "longitude")).dropna("point")

    sphere_radius = 6371  # presumably Earth's mean radius in km
    phi_ref, lam_ref = np.radians(lat), np.radians(lon)
    phi = np.radians(ocean_points.latitude)
    lam = np.radians(ocean_points.longitude)

    # Haversine works on half-angle differences, so halve them up front
    half_dphi = (phi - phi_ref) / 2
    half_dlam = (lam - lam_ref) / 2

    hav = np.sin(half_dphi)**2 + np.cos(phi_ref) * np.cos(phi) * np.sin(half_dlam)**2
    distance = 2 * sphere_radius * np.arcsin(np.sqrt(hav))

    best = ocean_points.isel(point=distance.argmin())
    return float(best.latitude), float(best.longitude)


def create_sequences(X, y, lookback, lookahead):
    """Generate sliding-window samples for sequence-to-one forecasting.

    Window ``i`` covers ``X[i : i + lookback]`` and is paired with the single
    target row ``y[i + lookback + lookahead - 1]``, i.e. the row ``lookahead``
    steps after the last step inside the window.

    All windows are gathered with one vectorised index operation instead of a
    Python loop followed by ``np.array(list)``, which is considerably faster
    for large grids. When the series is too short to yield any window, the
    returned arrays are empty but keep the correct trailing dimensions.

    Args:
        X: Array-like whose first axis is time.
        y: Array-like whose first axis is time, aligned with ``X``.
        lookback: Number of consecutive steps in each input window.
        lookahead: Offset, in steps, from the end of a window to its target.

    Returns:
        Tuple ``(X_seq, y_seq)`` with shapes
        ``(n_windows, lookback, *X.shape[1:])`` and
        ``(n_windows, *y.shape[1:])``. Both are fresh arrays, not views.

    Raises:
        IndexError: if ``y`` is too short for the targets implied by ``X``.
    """
    X = np.asarray(X)
    y = np.asarray(y)

    n_windows = max(len(X) - lookback - lookahead + 1, 0)
    starts = np.arange(n_windows)

    # Row i of window_index lists the time steps that make up window i
    window_index = starts[:, np.newaxis] + np.arange(lookback)
    target_index = starts + (lookback + lookahead - 1)

    # Integer-array indexing copies the data and raises on out-of-range targets
    return X[window_index], y[target_index]


def build_model_v2(input_shape, output_dim):
    """Build and compile the stacked ConvLSTM regression model.

    Two ConvLSTM2D layers (64 then 32 filters, 3x3 kernels, 'same' padding),
    each followed by batch normalisation, feed a dense head of 128 and 64
    units with PReLU activations and dropout, ending in a linear output layer.

    The optimiser is AdamW. The Keras-native implementation is preferred
    because TensorFlow Addons is in end-of-life maintenance; the Addons class
    is only used when the installed TensorFlow does not ship AdamW.

    Args:
        input_shape: Shape of one sample without the batch axis, passed to the
            first ConvLSTM2D layer.
        output_dim: Number of regression outputs.

    Returns:
        A compiled Sequential model with MSE loss and an MAE metric.
    """
    try:
        from tensorflow.keras.optimizers import AdamW
    except ImportError:
        # Older TensorFlow releases: fall back to the Addons implementation
        AdamW = tfa.optimizers.AdamW

    model = Sequential([
        ConvLSTM2D(64, (3, 3), padding='same', return_sequences=True, input_shape=input_shape),
        BatchNormalization(),

        # Only the final state is kept so the dense head sees a single frame
        ConvLSTM2D(32, (3, 3), padding='same', return_sequences=False),
        BatchNormalization(),

        Flatten(),
        Dense(128),
        PReLU(),
        Dropout(0.2),
        Dense(64),
        PReLU(),
        Dropout(0.3),
        Dense(output_dim, activation='linear')
    ])

    # NOTE(review): the Keras and Addons AdamW classes may not scale
    # weight_decay identically relative to the learning rate - confirm the
    # regularisation strength is still the intended one after switching.
    model.compile(
        optimizer=AdamW(learning_rate=5e-4, weight_decay=1e-5),
        loss='mse',
        metrics=['mae']
    )
    return model


def evaluate_model(model, X_val, y_val, scaler_y, target_names):
    """Print MAE / RMSE / R2 on held-out data in original target units.

    Predictions and targets are mapped back through ``scaler_y`` before any
    metric is computed. One overall line is produced by passing the full 2-D
    arrays to the metric functions, followed by one line per target column.

    Args:
        model: Fitted model exposing ``predict(X, verbose=...)``.
        X_val: Scaled validation inputs.
        y_val: Scaled validation targets, one column per target.
        scaler_y: Fitted scaler exposing ``inverse_transform``.
        target_names: Labels for the target columns, in column order.

    Returns:
        Tuple ``(y_true, y_pred)`` in original units.
    """
    def _scores(truth, estimate):
        """MAE, RMSE and R2 for one pair of arrays."""
        mae = mean_absolute_error(truth, estimate)
        rmse = np.sqrt(mean_squared_error(truth, estimate))
        r2 = r2_score(truth, estimate)
        return mae, rmse, r2

    scaled_prediction = model.predict(X_val, verbose=0)

    # Undo the target scaling so the errors are reported in physical units
    y_true = scaler_y.inverse_transform(y_val)
    y_pred = scaler_y.inverse_transform(scaled_prediction)

    mae_overall, rmse_overall, r2_overall = _scores(y_true, y_pred)

    heavy_rule = "=" * 70
    print("\n" + heavy_rule)
    print("VALIDATION METRICS")
    print(heavy_rule)
    print(f"Overall | MAE={mae_overall:.4f}, RMSE={rmse_overall:.4f}, R¬≤={r2_overall:.4f}")
    print("-" * 70)

    for column, var in enumerate(target_names):
        mae_i, rmse_i, r2_i = _scores(y_true[:, column], y_pred[:, column])
        print(f"{var:>12s} | MAE={mae_i:.3f}, RMSE={rmse_i:.3f}, R¬≤={r2_i:.3f}")

    print(heavy_rule)
    return y_true, y_pred

In [10]:
# ======================================================================
# MAIN EXECUTION
# ======================================================================
# Runs the whole pipeline top to bottom: load -> diagnose -> mask ->
# features -> preprocessing -> windowing -> training -> evaluation ->
# single-step forecast.

print("üåä IMPROVED CONVLSTM PIPELINE WITH NaN HANDLING")
print("="*70)

# Load data
print("\nüìÇ Loading dataset...")
ds = xr.open_dataset(DATA_PATH)

# Diagnose NaNs BEFORE processing
diagnose_nans(ds, INPUT_FEATURES)

# STRATEGY 1: Temporal gap interpolation hook
# (interpolate_temporal_gaps is currently a pass-through stub in this notebook)
ds = interpolate_temporal_gaps(ds, INPUT_FEATURES, method='linear', limit=2)

# Find buoy location
buoy_lat, buoy_lon = find_nearest_ocean_point(ds, SPOT_LAT, SPOT_LON)
print(f"üìç Virtual buoy: ({buoy_lat:.2f}¬∞N, {buoy_lon:.2f}¬∞E)")

# STRATEGY 2: Create ocean mask
ocean_mask = create_ocean_mask(ds)

# Feature engineering
X, y = engineer_features_v2(ds, buoy_lat, buoy_lon)

# STRATEGY 3: Improved preprocessing
X_scaled, y_scaled, scalers_X, scaler_y = preprocess_data_v2(X, y, ocean_mask)

# Create sequences
print(f"\nüîÑ Creating sequences (lookback={LOOKBACK_HOURS}, lookahead={LOOKAHEAD_HOURS})...")
X_seq, y_seq = create_sequences(X_scaled, y_scaled, LOOKBACK_HOURS, LOOKAHEAD_HOURS)
print(f"   Sequence shape: X={X_seq.shape}, y={y_seq.shape}")

# Chronological hold-out: the last VALIDATION_SPLIT fraction of the windows.
# The split is done here, once, so training and evaluation are guaranteed to
# use exactly the same validation samples (an independently computed
# `int(VALIDATION_SPLIT * n)` tail can differ by one sample from what
# `validation_split` holds out, and a size of 0 would slice the whole array).
split_at = int(np.floor(len(X_seq) * (1.0 - VALIDATION_SPLIT)))
X_train, y_train = X_seq[:split_at], y_seq[:split_at]
X_val, y_val = X_seq[split_at:], y_seq[split_at:]
val_size = len(X_val)
if len(X_train) == 0 or val_size == 0:
    raise ValueError(
        f"Not enough sequences to split: {len(X_seq)} total, "
        f"{len(X_train)} train / {val_size} validation"
    )
print(f"   Train sequences: {len(X_train)}, validation sequences: {val_size}")

# Build and train
print("\nüèóÔ∏è Building model...")
model = build_model_v2(X_seq.shape[1:], y_seq.shape[1])

callbacks = [
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6, verbose=1)
]

print(f"\nüöÄ Training for up to {EPOCHS} epochs...")
history = model.fit(
    X_train, y_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(X_val, y_val),
    callbacks=callbacks,
    verbose=1
)

# Evaluate on the very same hold-out set that drove the callbacks
evaluate_model(model, X_val, y_val, scaler_y, TARGET_VARS)

# Forecast: feed the most recent LOOKBACK_HOURS steps to get one prediction
print("\n" + "="*70)
print(f"üîÆ FORECAST (Next {LOOKAHEAD_HOURS} Hours)")
print("="*70)
last_seq = np.expand_dims(X_scaled[-LOOKBACK_HOURS:], axis=0)
forecast_scaled = model.predict(last_seq, verbose=0)
forecast = scaler_y.inverse_transform(forecast_scaled)[0]

# Columns 2 and 3 hold sin/cos of the direction; atan2 turns them back into an angle
direction_deg = np.rad2deg(np.arctan2(forecast[2], forecast[3])) % 360

print(f"Swell Height:    {forecast[0]:.2f} m")
print(f"Swell Period:    {forecast[1]:.2f} s")
print(f"Swell Direction: {direction_deg:.1f}¬∞")
print(f"Wind Speed:      {forecast[4]:.2f} m/s ({forecast[4]*3.6:.1f} km/h)")
print("="*70)

üåä IMPROVED CONVLSTM PIPELINE WITH NaN HANDLING

üìÇ Loading dataset...

NaN DIAGNOSTICS

u10: ‚úì No NaNs

v10: ‚úì No NaNs

msl: ‚úì No NaNs

shts:
  Total NaNs: 509,472 (78.91%)
  Affected spatial points: 348 / 441
  Affected timesteps: 1464 / 1464

mpts:
  Total NaNs: 509,472 (78.91%)
  Affected spatial points: 348 / 441
  Affected timesteps: 1464 / 1464

mdts:
  Total NaNs: 509,472 (78.91%)
  Affected spatial points: 348 / 441
  Affected timesteps: 1464 / 1464
‚è≠Ô∏è  Skipping temporal interpolation...
‚ö† Nearest point on land. Searching for closest ocean point...
üìç Virtual buoy: (7.00¬∞N, 82.00¬∞E)

üåä Ocean Mask Created:
   Valid ocean points: 93 / 441 (21.1%)

üìä Feature Extraction:
   Input shape: (1464, 21, 21, 6)
   Target shape: (1464, 5)
   Input NaNs: 1,528,416
   Target NaNs: 0

PREPROCESSING
‚úì Ocean mask applied (land points ‚Üí 0)

Remaining NaNs:
   Input (X): 0 (0.000%)
   Target (y): 0 (0.000%)

‚úì Scaled data | X: (1464, 21, 21, 6), y: (1464, 5)

üîÑ

KeyboardInterrupt: 