# Multi-Horizon River Differential Prediction with Future Rainfall

This notebook implements a multi-horizon LSTM model that predicts river differential for the next 10 days (240 hours) using:
1. **Historical rainfall data** (for soil moisture context)
2. **Historical differential data** (for river state context)
3. **Future rainfall forecasts** (key improvement!)

In [1]:
# Imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys
import warnings
import requests
from sklearn.model_selection import train_test_split, TimeSeriesSplit
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from datetime import datetime, timedelta
from tqdm import tqdm
import glob

# PyTorch imports for LSTM
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset

# PyTorch Information
print("="*70)
print(f"PyTorch version: {torch.__version__}")
print("="*70)

# Set device for PyTorch - Check for MPS (Apple Silicon), CUDA, then CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using device: cuda (NVIDIA GPU)")
elif torch.backends.mps.is_available():
    device = torch.device('mps')
    print(f"Using device: mps (Apple Silicon GPU)")
else:
    device = torch.device('cpu')
    print(f"Using device: cpu")


# Set the working directory
os.chdir('/Users/robertds413/Documents/Flag_Predictor')
print(f"Working directory: {os.getcwd()}")


PyTorch version: 2.2.2
Using device: mps (Apple Silicon GPU)
Working directory: /Users/robertds413/Documents/Flag_Predictor


## 1. Data Loading Functions


In [2]:
def process_lock_api_data(data):
    """Process API data from flood monitoring service"""
    if 'items' not in data or not data['items']:
        return pd.DataFrame()
    
    temp_df = pd.DataFrame(data['items'])
    if 'dateTime' not in temp_df.columns or 'value' not in temp_df.columns:
        return pd.DataFrame()
    
    temp_df = temp_df[['dateTime', 'value']]
    temp_df.rename(columns={'dateTime': 'timestamp', 'value': 'level'}, inplace=True)
    temp_df['timestamp'] = pd.to_datetime(temp_df['timestamp'])
    df = temp_df.set_index('timestamp')
    return df

def process_rainfall_api_data(data):
    """Process rainfall API data from flood monitoring service"""
    if 'items' not in data or not data['items']:
        return pd.DataFrame()
    
    temp_df = pd.DataFrame(data['items'])
    if 'dateTime' not in temp_df.columns or 'value' not in temp_df.columns:
        return pd.DataFrame()
    
    temp_df = temp_df[['dateTime', 'value']]
    temp_df.rename(columns={'dateTime': 'timestamp', 'value': 'rainfall'}, inplace=True)
    temp_df['timestamp'] = pd.to_datetime(temp_df['timestamp'])
    df = temp_df.set_index('timestamp')
    return df

def load_flag_model_data(file_path, differential_column):
    """Load flag model training data from CSV file"""
    df = pd.read_csv(file_path)
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    df = df.set_index('timestamp')
    df = df[[differential_column]]
    df = df.rename(columns={differential_column: 'differential'})
    return df

print("âœ“ Data loading functions defined")


âœ“ Data loading functions defined


## 2. Load API Data (River Levels & Current Rainfall)


In [3]:
# API URLs
kings_mill_downstream_url = 'http://environment.data.gov.uk/flood-monitoring/id/measures/1491TH-level-downstage-i-15_min-mASD/readings?_sorted&_limit=90000'
godstow_downstream_url = 'http://environment.data.gov.uk/flood-monitoring/id/measures/1302TH-level-downstage-i-15_min-mASD/readings?_sorted&_limit=90000'
osney_upstream_url = 'http://environment.data.gov.uk/flood-monitoring/id/measures/1303TH-level-stage-i-15_min-mASD/readings?_sorted&_limit=90000'
osney_downstream_url = 'http://environment.data.gov.uk/flood-monitoring/id/measures/1303TH-level-downstage-i-15_min-mASD/readings?_sorted&_limit=90000'
iffley_upstream_url = 'http://environment.data.gov.uk/flood-monitoring/id/measures/1501TH-level-stage-i-15_min-mASD/readings?_sorted&_limit=90000'

# Rainfall API URLs
rainfall_urls = [
    'http://environment.data.gov.uk/flood-monitoring/id/measures/256230TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/254336TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/251530TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/248332TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/248965TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/251556TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/253340TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/254829TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/257039TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/259110TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/256345TPrainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/249744TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000',
    'http://environment.data.gov.uk/flood-monitoring/id/measures/253861TP-rainfall-tipping_bucket_raingauge-t-15_min-mm/readings?_sorted&_limit=90000'
]

location_names = ['Osney', 'Eynsham', 'St', 'Shorncote', 'Rapsgate', 'Stowell', 
                  'Bourton', 'Chipping', 'Grimsbury', 'Bicester', 'Byfield', 'Swindon', 'Worsham']

print("Fetching river level data...")
# Get river level data
kings_mill_downstream_df = process_lock_api_data(requests.get(kings_mill_downstream_url).json())
godstow_downstream_df = process_lock_api_data(requests.get(godstow_downstream_url).json())
osney_upstream_df = process_lock_api_data(requests.get(osney_upstream_url).json())
osney_downstream_df = process_lock_api_data(requests.get(osney_downstream_url).json())
iffley_upstream_df = process_lock_api_data(requests.get(iffley_upstream_url).json())

print("Fetching rainfall data...")
# Get rainfall data
rainfall_api_dfs = {}
for url, name in zip(rainfall_urls, location_names):
    response = requests.get(url)
    data = response.json()
    df = process_rainfall_api_data(data)
    if not df.empty:
        df.rename(columns={'rainfall': name}, inplace=True)
        rainfall_api_dfs[name] = df

# Combine all rainfall API datasets
mega_rainfall_api_df = pd.concat(rainfall_api_dfs.values(), axis=1)
print(f"\nâœ“ Combined rainfall API data: {mega_rainfall_api_df.shape}")
print(f"Date range: {mega_rainfall_api_df.index.min()} to {mega_rainfall_api_df.index.max()}")


Fetching river level data...
Fetching rainfall data...

âœ“ Combined rainfall API data: (2828, 12)
Date range: 2025-12-22 00:00:00+00:00 to 2026-01-20 10:45:00+00:00


## 3. Calculate Differentials from API Data


In [4]:
# Calculate isis and godstow differential
isis_diff_isis_contrib = 0.71 * (osney_downstream_df['level'] - iffley_upstream_df['level'] - 2.14)
isis_diff_cherwell_contrib = 0.29 * (kings_mill_downstream_df['level'] - iffley_upstream_df['level'] - 0.73)
isis_differential_api = isis_diff_isis_contrib + isis_diff_cherwell_contrib

godstow_differential_api = godstow_downstream_df['level'] - osney_upstream_df['level'] - 1.63

# Convert to DataFrames
isis_api_diff_df = pd.DataFrame({'differential': isis_differential_api})
godstow_api_diff_df = pd.DataFrame({'differential': godstow_differential_api})

print(f"âœ“ ISIS differential API data: {isis_api_diff_df.shape}")
print(f"âœ“ Godstow differential API data: {godstow_api_diff_df.shape}")


âœ“ ISIS differential API data: (2828, 1)
âœ“ Godstow differential API data: (2828, 1)


## 4. Load Historical Data


In [5]:
# Load historical rainfall data
rainfall_data_path = './data/rainfall_training_data/'
csv_files = glob.glob(os.path.join(rainfall_data_path, '*.csv'))

rainfall_dfs = {}
print("Loading historical rainfall data...")

for csv_file in csv_files:
    file_name = os.path.splitext(os.path.basename(csv_file))[0]
    
    try:
        df = pd.read_csv(csv_file, dtype=str, on_bad_lines='warn')
    except Exception as e:
        print(f"Error reading {file_name}: {e}")
        continue
    
    if 'dateTime' not in df.columns or 'value' not in df.columns:
        continue
    
    df = df[['dateTime', 'value']]
    df = df.rename(columns={'dateTime': 'timestamp', 'value': f'rainfall_mm_{file_name}'})
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    df[f'rainfall_mm_{file_name}'] = pd.to_numeric(df[f'rainfall_mm_{file_name}'], errors='coerce').fillna(0)
    df = df.set_index('timestamp')
    rainfall_dfs[file_name] = df

mega_rainfall_hist_df = pd.concat(rainfall_dfs.values(), axis=1)
print(f"\nâœ“ Combined historical rainfall: {mega_rainfall_hist_df.shape}")
print(f"Date range: {mega_rainfall_hist_df.index.min()} to {mega_rainfall_hist_df.index.max()}")

# Load historical differential data
print("\nLoading historical differential data...")
godstow_hist_diff_df = load_flag_model_data('data/godstow_flag_model_data.csv', 'jameson_godstow_differential')
isis_hist_diff_df = load_flag_model_data('data/isis_flag_model_data.csv', 'jameson_isis_differential')
print(f"âœ“ ISIS historical differential: {isis_hist_diff_df.shape}")
print(f"âœ“ Godstow historical differential: {godstow_hist_diff_df.shape}")


Loading historical rainfall data...

âœ“ Combined historical rainfall: (314283, 13)
Date range: 2017-02-02 00:00:00 to 2026-01-19 18:15:00

Loading historical differential data...
âœ“ ISIS historical differential: (216385, 1)
âœ“ Godstow historical differential: (196557, 1)


## 5. Data Merging and Cleaning


In [6]:
def merge_and_clean_data(hist_diff_df, mega_rainfall_hist_df, api_diff_df, mega_rainfall_api_df, differential_column):
    """
    Merge flag data with rainfall data and API data, then resample to hourly frequency.
    Cleans spurious rainfall values (negatives, extremes >50mm/h, statistical outliers).
    """
    rainfall_hist_df = mega_rainfall_hist_df.copy()
    
    # Rename rainfall columns for consistency
    new_columns = {}
    for col in rainfall_hist_df.columns:
        if 'mm_' in col and '-' in col:
            parts = col.split('mm_')
            if len(parts) > 1:
                after_mm = parts[1]
                if '-' in after_mm:
                    new_name = after_mm.split('-')[0]
                    new_columns[col] = new_name
    if new_columns:
        rainfall_hist_df = rainfall_hist_df.rename(columns=new_columns)
    
    if rainfall_hist_df.index.tz is None:
        rainfall_hist_df = rainfall_hist_df.tz_localize('UTC')
    
    # Join historical data
    df = hist_diff_df.join(rainfall_hist_df, how='inner')
    
    # Merge with API data (API data takes precedence)
    df = df.combine_first(mega_rainfall_api_df)
    df = df.combine_first(api_diff_df)
    
    # Resample to hourly
    aggregation_rules = {}
    for col in df.columns:
        if col != differential_column:
            aggregation_rules[col] = 'sum'
        else:
            aggregation_rules[col] = 'mean'
    
    df_hourly = df.resample('1H').agg(aggregation_rules)
    df = df_hourly.copy()
    
    # Clean data
    df = df.dropna(subset=[differential_column])
    df.fillna(0, inplace=True)
    
    # --- NEW: Clean spurious rainfall values ---
    rainfall_cols = [col for col in df.columns if col != differential_column]
    print(f"Cleaning rainfall data for {len(rainfall_cols)} stations...")
    
    for col in rainfall_cols:
        # Remove negative rainfall (sensor errors)
        negative_count = (df[col] < 0).sum()
        if negative_count > 0:
            print(f"  {col}: Removing {negative_count} negative values")
            df.loc[df[col] < 0, col] = 0
        
        # Remove physically impossible values (>50mm/hour is extremely rare in UK)
        high_count = (df[col] > 50).sum()
        if high_count > 0:
            print(f"  {col}: Capping {high_count} values >50mm/hour")
            df.loc[df[col] > 50, col] = 0
        
    print("âœ“ Rainfall cleaning complete")
    
    # Remove extreme values in differential
    df = df[(df[differential_column] > -0.1) & (df[differential_column] <= 1.5)].copy()
    
    # Remove spikes: where value changes by >0.5 and then returns to similar value
    diff_series = df[differential_column]
    changes = diff_series.diff()
    next_changes = changes.shift(-1)
    
    # Detect spikes: large change followed by opposite large change
    spike_mask = (np.abs(changes) > 0.5) & (np.abs(next_changes) > 0.5) & (np.sign(changes) != np.sign(next_changes))
    
    # Replace spikes with interpolated values
    if spike_mask.sum() > 0:
        print(f"  Removing {spike_mask.sum()} spike values in differential")
        df.loc[spike_mask, differential_column] = np.nan
        df[differential_column] = df[differential_column].interpolate(method='linear')
    
    # Also remove statistical outliers
    window = 6
    rolling_mean = df[differential_column].rolling(window=window, center=True).mean()
    rolling_std = df[differential_column].rolling(window=window, center=True).std()
    outliers = np.abs(df[differential_column] - rolling_mean) > (3 * rolling_std)
    
    if outliers.sum() > 0:
        df.loc[outliers, differential_column] = df[differential_column].rolling(
            window=window, center=True
        ).median()[outliers]
        df[differential_column] = df[differential_column].rolling(
            window=3, center=True
        ).mean().fillna(df[differential_column])
    
    return df

print("âœ“ Data merging and cleaning function defined")


âœ“ Data merging and cleaning function defined


## 6. Feature Engineering with Future Rainfall Support

**This is the key innovation!** We create features that include future rainfall forecasts.

### ðŸ†• IMPROVED FEATURES (to fix underprediction of rapid rises):
- **Velocity features**: Rate of change at multiple timescales (1h, 3h, 6h, 12h, 24h)
- **Acceleration features**: Second derivative to detect when rise is accelerating
- **Momentum indicator**: Trend strength over 12h window
- **Rising signals**: Binary flags when differential is actively rising
- **Rainfall intensity ratios**: Recent vs historical rainfall comparison
- **Rainfall acceleration**: Is rainfall rate increasing?
- **Near-term future rainfall**: 6h and 12h ahead (critical for rapid response)
- **Flood amplification**: Interaction between recent rainfall and current differential


In [7]:
def create_features_with_future_rainfall(df, future_rainfall_df=None, differential_column='differential'):
    """
    Create features including FUTURE rainfall forecasts for better long-term predictions.
    
    Args:
        df: DataFrame with historical differential and rainfall data (hourly frequency)
        future_rainfall_df: DataFrame with future rainfall forecasts (optional)
        differential_column: Name of the differential column
    
    Returns:
        DataFrame: DataFrame with all engineered features including future rainfall
    """
    df = df.copy()
    
    # --- Historical Rainfall Features ---
    rainfall_station_cols = [col for col in df.columns if col not in ['differential']]
    df[rainfall_station_cols] = df[rainfall_station_cols].fillna(0)
    df['catchment_rainfall_total'] = df[rainfall_station_cols].sum(axis=1)
    
    # ðŸ†• FIX: LOW-FLOW REGIME FEATURES - Detect stable low periods to prevent false upticks
    # These help the model recognize when conditions support staying low
    
    # Hours since significant rainfall (>1mm total)
    significant_rain = (df['catchment_rainfall_total'] > 1.0).astype(int)
    df['hours_since_rain'] = significant_rain.groupby((significant_rain != significant_rain.shift()).cumsum()).cumcount()
    df.loc[significant_rain == 1, 'hours_since_rain'] = 0  # Reset when rain occurs
    # Cap at 720 hours (30 days) to avoid extreme values
    df['hours_since_rain'] = df['hours_since_rain'].clip(upper=720)
    
    # Is currently in a "dry spell"? (no significant rain for 48+ hours)
    df['is_dry_spell'] = (df['hours_since_rain'] > 48).astype(float)
    
    # Is currently in a "low flow" state? (differential < 0.2m)
    df['is_low_flow'] = (df[differential_column] < 0.2).astype(float)
    
    # Combined: dry spell AND low flow - strong stability signal
    df['stable_low_regime'] = df['is_dry_spell'] * df['is_low_flow']
    
    # Recent rainfall trend: is rainfall decreasing? (last 24h vs 24-48h ago)
    df['rainfall_trend_48h'] = (
        df['catchment_rainfall_total'].rolling(window=24).sum() - 
        df['catchment_rainfall_total'].rolling(window=24).sum().shift(24)
    )
    
    # Is drainage ongoing? (differential decreasing while no new rain)
    df['is_draining'] = (
        (df[differential_column].diff(6) < -0.01) &  # Differential decreasing
        (df['catchment_rainfall_total'].rolling(window=6).sum() < 1.0)  # No significant recent rain
    ).astype(float)
    
    # Historical rolling features for soil saturation
    df['rainfall_rolling_24h'] = df['catchment_rainfall_total'].rolling(window=24).sum()
    df['rainfall_rolling_72h'] = df['catchment_rainfall_total'].rolling(window=72).sum()
    df['rainfall_rolling_168h'] = df['catchment_rainfall_total'].rolling(window=168).sum()
    df['rainfall_rolling_720h'] = df['catchment_rainfall_total'].rolling(window=720).sum()
    
    # ðŸ†• NEW: Short-term rainfall intensity (recent vs historical)
    df['rainfall_rolling_6h'] = df['catchment_rainfall_total'].rolling(window=6).sum()
    df['rainfall_rolling_12h'] = df['catchment_rainfall_total'].rolling(window=12).sum()
    df['rainfall_intensity_ratio_6_24'] = df['rainfall_rolling_6h'] / (df['rainfall_rolling_24h'] + 0.01)
    df['rainfall_intensity_ratio_24_168'] = df['rainfall_rolling_24h'] / (df['rainfall_rolling_168h'] + 0.01)
    
    # ðŸ†• NEW: Rainfall acceleration (is rainfall increasing?)
    df['rainfall_accel_1h'] = df['catchment_rainfall_total'].rolling(window=1).sum().diff(1)
    df['rainfall_accel_3h'] = df['catchment_rainfall_total'].rolling(window=3).sum().diff(3)
    df['rainfall_accel_6h'] = df['catchment_rainfall_total'].rolling(window=6).sum().diff(6)
    df['rainfall_accel_12h'] = df['catchment_rainfall_total'].rolling(window=12).sum().diff(12)
    df['rainfall_accel_24h'] = df['catchment_rainfall_total'].rolling(window=24).sum().diff(24)
    df['rainfall_accel_48h'] = df['catchment_rainfall_total'].rolling(window=48).sum().diff(48)
    
    # --- NEW: Future Rainfall Features ---
    if future_rainfall_df is not None:
        # Ensure timezone alignment
        if future_rainfall_df.index.tz is None and df.index.tz is not None:
            future_rainfall_df = future_rainfall_df.tz_localize(df.index.tz)
        elif future_rainfall_df.index.tz is not None and df.index.tz is None:
            df.index = df.index.tz_localize(future_rainfall_df.index.tz)
        
        # Align future rainfall with df index
        future_rainfall_aligned = future_rainfall_df.reindex(df.index)
        
        # Create forward-looking rainfall totals for different horizons
        catchment_future = future_rainfall_aligned.sum(axis=1)
        
        # ðŸ†• IMPROVED: Catchment response lag - rain takes time to affect river levels
        # Typical catchment lag is 12-24 hours for this river system
        CATCHMENT_LAG = 18  # hours - rain at time T affects river at T+18h
        
        # Rolling FORWARD windows WITH CATCHMENT LAG
        # These represent rainfall that will have had time to reach the river
        # For horizon H, we look at rainfall from (H - CATCHMENT_LAG) to H
        df['rainfall_future_24h'] = catchment_future.rolling(window=24).sum().shift(-24 - CATCHMENT_LAG)
        df['rainfall_future_48h'] = catchment_future.rolling(window=48).sum().shift(-48 - CATCHMENT_LAG)
        df['rainfall_future_72h'] = catchment_future.rolling(window=72).sum().shift(-72 - CATCHMENT_LAG)
        df['rainfall_future_120h'] = catchment_future.rolling(window=120).sum().shift(-120)
        df['rainfall_future_240h'] = catchment_future.rolling(window=240).sum().shift(-240)
        
        # ðŸ†• IMPROVED: Near-term future rainfall WITH LAG
        # Rain in next 6-12h won't affect river yet - need to account for catchment travel time
        df['rainfall_future_6h'] = catchment_future.rolling(window=6).sum().shift(-6 - CATCHMENT_LAG)
        df['rainfall_future_12h'] = catchment_future.rolling(window=12).sum().shift(-12 - CATCHMENT_LAG)
        
        # ðŸ†• IMPROVED: Horizon-aligned future rainfall WITH LAG
        # For each prediction horizon, show rainfall that will have arrived by then
        # For 0-24h: every 2 hours (with lag, so rain that fell CATCHMENT_LAG hours before)
        for h in [2, 4, 8, 10, 14, 16, 18, 20, 22]:
            # For short horizons < CATCHMENT_LAG, use historical/recent rain instead
            effective_shift = max(h, CATCHMENT_LAG) 
            df[f'rainfall_future_{h}h'] = catchment_future.rolling(window=h).sum().shift(-effective_shift)
        # For 24-48h: every 6 hours
        for h in [30, 36, 42]:
            df[f'rainfall_future_{h}h'] = catchment_future.rolling(window=h).sum().shift(-h)
        
        # ðŸ†• FIX: FUTURE DRY FEATURES - Detect when future has no significant rain
        # These help prevent false upticks when no rain is expected
        
        # Is future dry? (less than 5mm total over entire forecast)
        df['future_is_dry_240h'] = (df['rainfall_future_240h'] < 5.0).astype(float)
        df['future_is_dry_120h'] = (df['rainfall_future_120h'] < 3.0).astype(float)
        df['future_is_dry_72h'] = (df['rainfall_future_72h'] < 2.0).astype(float)
        df['future_is_dry_48h'] = (df['rainfall_future_48h'] < 1.0).astype(float)
        
        # Stability signal: currently low + future dry = should stay low
        df['expect_stable_low'] = df['stable_low_regime'] * df['future_is_dry_240h']
        
        # Future rainfall intensity (mm per hour on average)
        df['future_rainfall_intensity_240h'] = df['rainfall_future_240h'] / 240
        df['future_rainfall_intensity_120h'] = df['rainfall_future_120h'] / 120
        df['future_rainfall_intensity_72h'] = df['rainfall_future_72h'] / 72
    
    # --- River Differential Features (historical lags) ---
    for i in [1, 2, 3, 6, 12, 24, 48]:
        df[f'differential_lag_{i}h'] = df[differential_column].shift(i)
    
    df['differential_rolling_mean_6h'] = df[differential_column].rolling(window=6).mean()
    df['differential_rolling_std_6h'] = df[differential_column].rolling(window=6).std()
    df['differential_rolling_mean_24h'] = df[differential_column].rolling(window=24).mean()
    df['differential_rolling_std_24h'] = df[differential_column].rolling(window=24).std()
    df['differential_roc_6h'] = df[differential_column].diff(periods=6)
    df['differential_ewma_6h'] = df[differential_column].ewm(span=6, adjust=False).mean()
    
    # ðŸ†• NEW: Rate of Change Features (velocity) - CRITICAL for detecting rising limbs
    for i in [1, 3, 6, 12, 24]:
        df[f'differential_velocity_{i}h'] = df[differential_column].diff(periods=i) / i
    
    # ðŸ†• NEW: Acceleration Features (second derivative) - Detect when rise is accelerating
    for i in [3, 6, 12, 24]:
        velocity = df[differential_column].diff(periods=i)
        df[f'differential_acceleration_{i}h'] = velocity.diff(periods=i) / i
    
    # ðŸ†• NEW: Momentum indicator (trend strength)
    df['differential_momentum_12h'] = df[differential_column].rolling(window=12).apply(
        lambda x: (x.iloc[-1] - x.iloc[0]) / (x.std() + 0.001) if len(x) > 0 else 0,
        raw=False
    )
    
    # ðŸ†• NEW: Is differential rising? (binary signal)
    df['is_rising_6h'] = (df['differential_velocity_6h'] > 0.01).astype(float)
    df['is_rising_24h'] = (df['differential_velocity_24h'] > 0.01).astype(float)
    
    # Interaction features
    df['rainfall_last_hour'] = df['catchment_rainfall_total']
    df['rainfall_interaction_720h'] = df['rainfall_last_hour'] * df['rainfall_rolling_720h']
    
    # ðŸ†• NEW: Critical interaction - recent rainfall * current differential (flood amplification)
    df['flood_amplification_24h'] = df['rainfall_rolling_24h'] * df[differential_column]
    df['flood_amplification_72h'] = df['rainfall_rolling_72h'] * df[differential_column]
    df['flood_amplification_168h'] = df['rainfall_rolling_168h'] * df[differential_column]
    
    # ðŸ†• IMPROVED: Future rainfall * current differential (predictive flood risk)
    # These now incorporate catchment lag from the rainfall features
    df['future_flood_risk_24h'] = df['rainfall_future_24h'] * df[differential_column]
    df['future_flood_risk_48h'] = df['rainfall_future_48h'] * df[differential_column]
    df['future_flood_risk_72h'] = df['rainfall_future_72h'] * df[differential_column]
    
    # ðŸ†• NEW: Lagged rainfall impact - emphasize recent rain that's still in transit
    # Rain from 12-24h ago is currently flowing through catchment
    df['rainfall_in_transit_12_24h'] = df['catchment_rainfall_total'].rolling(window=12).sum().shift(12)
    df['rainfall_in_transit_6_18h'] = df['catchment_rainfall_total'].rolling(window=12).sum().shift(6)
    
    # --- ðŸ†• IMPROVED: Multi-day river state context (helps shape and timing) ---
    # 3-day (72h) behaviour
    df['differential_rolling_mean_72h'] = df[differential_column].rolling(window=72).mean()
    df['differential_rolling_max_72h'] = df[differential_column].rolling(window=72).max()
    df['differential_rolling_min_72h'] = df[differential_column].rolling(window=72).min()
    df['differential_range_72h'] = (
        df['differential_rolling_max_72h'] - df['differential_rolling_min_72h']
    )
    
    # 7-day (168h) behaviour
    df['differential_rolling_mean_168h'] = df[differential_column].rolling(window=168).mean()
    df['differential_rolling_max_168h'] = df[differential_column].rolling(window=168).max()
    df['differential_rolling_min_168h'] = df[differential_column].rolling(window=168).min()
    df['differential_range_168h'] = (
        df['differential_rolling_max_168h'] - df['differential_rolling_min_168h']
    )
    
    # Antecedent wetness index: blends short, medium, long rainfall
    df['antecedent_wetness'] = (
        0.5 * df['rainfall_rolling_24h']
        + 0.3 * df['rainfall_rolling_72h']
        + 0.2 * df['rainfall_rolling_168h']
    )
    
    # Fast vs slow runoff proxies: exponential rainfall memory
    for span, name in [(6, 'fast'), (24, 'medium'), (72, 'slow')]:
        df[f'catchment_rainfall_exp_{name}'] = (
            df['catchment_rainfall_total']
            .ewm(span=span, adjust=False)
            .mean()
        )
    
    # --- Cyclical Features ---
    df['day_of_year'] = df.index.dayofyear
    df['day_of_year_sin'] = np.sin(2 * np.pi * df['day_of_year']/365.25)
    df['day_of_year_cos'] = np.cos(2 * np.pi * df['day_of_year']/365.25)
    
    return df

print("âœ“ Feature engineering function defined")


âœ“ Feature engineering function defined


## 7. Multi-Horizon LSTM Model

The model predicts all time horizons simultaneously (24h, 48h, ..., 240h) from a single input sequence.


In [8]:
class MultiHorizonLSTMModel(nn.Module):
    def __init__(self, input_size, hidden_sizes=[128, 64], dropout_rate=0.2, n_horizons=10):
        """
        LSTM that predicts multiple future time horizons simultaneously.
        
        Args:
            input_size: Number of input features
            hidden_sizes: List of hidden layer sizes
            dropout_rate: Dropout rate
            n_horizons: Number of future horizons to predict
        """
        super(MultiHorizonLSTMModel, self).__init__()
        
        self.hidden_sizes = hidden_sizes
        self.num_layers = len(hidden_sizes)
        self.n_horizons = n_horizons
        
        # Create LSTM layers
        self.lstm_layers = nn.ModuleList()
        self.dropout_layers = nn.ModuleList()
        
        for i, hidden_size in enumerate(hidden_sizes):
            input_dim = input_size if i == 0 else hidden_sizes[i-1]
            self.lstm_layers.append(
                nn.LSTM(input_dim, hidden_size, batch_first=True)
            )
            self.dropout_layers.append(nn.Dropout(dropout_rate))
        
        # Multiple output heads for different time horizons
        self.fc_layers = nn.ModuleList([
            nn.Linear(hidden_sizes[-1], 1) for _ in range(n_horizons)
        ])
    
    def forward(self, x, debug=False):
        # x shape: (batch_size, sequence_length, input_size)
        
        if debug:
            print(f"[FORWARD] Input shape: {x.shape}")
            import sys
            sys.stdout.flush()
        
        for i, (lstm, dropout) in enumerate(zip(self.lstm_layers, self.dropout_layers)):
            if debug:
                print(f"[FORWARD] Running LSTM layer {i+1}/{len(self.lstm_layers)}...")
                sys.stdout.flush()
            x, (h_n, c_n) = lstm(x)
            if debug:
                print(f"[FORWARD] âœ“ LSTM layer {i+1} done, shape: {x.shape}")
                sys.stdout.flush()
            x = dropout(x)
        
        if debug:
            print(f"[FORWARD] Taking last timestep...")
            sys.stdout.flush()
        # Take the output from the last time step
        x = x[:, -1, :]
        if debug:
            print(f"[FORWARD] âœ“ Last timestep extracted, shape: {x.shape}")
            sys.stdout.flush()
        
        # Predict multiple horizons
        if debug:
            print(f"[FORWARD] Running {len(self.fc_layers)} output heads...")
            sys.stdout.flush()
        predictions = []
        for fc in self.fc_layers:
            predictions.append(fc(x))
        
        if debug:
            print(f"[FORWARD] âœ“ All output heads done")
            sys.stdout.flush()
        
        # Stack predictions: (batch_size, n_horizons)
        predictions = torch.cat(predictions, dim=1)
        
        if debug:
            print(f"[FORWARD] âœ“ Predictions stacked, final shape: {predictions.shape}")
            sys.stdout.flush()
        
        return predictions

print("âœ“ Multi-Horizon LSTM Model defined")


âœ“ Multi-Horizon LSTM Model defined


## 8. Training Functions for Multi-Horizon Model

### ðŸ†• IMPROVED TRAINING (weighted loss function):
- **3x weight** for high-flow events (>0.3m) - forces model to pay attention to floods
- **1.5x additional weight** for rising events (>5cm increase) - prioritizes rapid rises
- This addresses the systematic underprediction of flood peaks!


In [9]:
def create_sequences_lstm_multihorizon(X, y_multi, sequence_length=24):
    """
    Create sequences for multi-horizon LSTM training - OPTIMIZED VERSION.
    
    Args:
        X: Feature DataFrame
        y_multi: Target DataFrame with multiple horizons (columns)
        sequence_length: Number of time steps to look back
    
    Returns:
        tuple: (X_sequences, y_sequences, valid_indices)
    """
    # Convert to numpy arrays ONCE (much faster than repeated .iloc calls)
    X_values = X.values
    y_values = y_multi.values
    
    # Pre-allocate arrays (much faster than appending to lists)
    n_samples = len(X) - sequence_length
    n_features = X_values.shape[1]
    n_horizons = y_values.shape[1]
    
    X_sequences = np.zeros((n_samples, sequence_length, n_features), dtype=np.float32)
    y_sequences = np.zeros((n_samples, n_horizons), dtype=np.float32)
    valid_indices = []
    
    # Vectorized sequence creation
    for i in range(n_samples):
        X_sequences[i] = X_values[i:i+sequence_length]
        y_sequences[i] = y_values[i+sequence_length]
        valid_indices.append(X.index[i+sequence_length])
    
    return X_sequences, y_sequences, valid_indices

def create_target_and_X_y_multihorizon(df_featureless, future_rainfall_df, differential_column='differential', 
                                       horizons=[24, 48, 72, 96, 120, 144, 168, 192, 216, 240]):
    """
    Create multi-horizon targets and feature set.
    
    Args:
        df_featureless: DataFrame with differential and rainfall (no features yet)
        future_rainfall_df: DataFrame with future rainfall forecasts
        differential_column: Name of the differential column to predict
        horizons: List of hours ahead to predict
    
    Returns:
        tuple: (X, y_multi, mask, horizons)
    """
    print(f"\nCreating features with future rainfall...")
    
    # Create features WITH future rainfall
    df_with_features = create_features_with_future_rainfall(df_featureless, future_rainfall_df, differential_column)
    
    print(f"Creating targets for {len(horizons)} horizons...")
    
    # Create multiple target variables (one for each horizon)
    targets = []
    for horizon in horizons:
        target_col = f'target_{horizon}h'
        df_with_features[target_col] = df_with_features[differential_column].shift(-horizon)
        targets.append(target_col)
    
    # Define features (exclude targets and differential)
    exclude_cols = targets + [differential_column] + [col for col in df_with_features.columns if 'target_' in col]
    features = [col for col in df_with_features.columns if col not in exclude_cols]
    
    X = df_with_features[features]
    y_multi = df_with_features[targets]
    
    # Remove rows where ANY target is NaN
    mask = ~y_multi.isna().any(axis=1)
    X = X[mask]
    y_multi = y_multi[mask]
    
    print(f"\n{'='*70}")
    print(f"Multi-Horizon Model Setup:")
    print(f"{'='*70}")
    print(f"Number of features: {len(features)}")
    print(f"Number of horizons: {len(horizons)}")
    print(f"Horizons: {horizons}")
    print(f"Training samples: {len(X)}")
    print(f"Date range: {X.index.min()} to {X.index.max()}")
    print(f"{'='*70}")
    
    return X, y_multi, mask, horizons

print("âœ“ Sequence creation functions defined")


âœ“ Sequence creation functions defined


In [10]:
def train_multihorizon_model(X, y_multi, horizons, sequence_length=24, epochs=50, batch_size=32, 
                            validation_split=0.2, learning_rate=0.0001, patience=10, max_grad_norm=1.0,
                            hidden_sizes=[192, 128, 64], dropout_rate=0.3):
    """
    Train multi-horizon PyTorch LSTM model.
    
    Args:
        hidden_sizes: List of hidden layer sizes for the LSTM stack (default: [192, 128, 64])
        dropout_rate: Dropout rate between layers (default: 0.3)
    
    Returns:
        tuple: (model, scaler, history, sequence_length, horizons)
    """
    # Scale the features
    scaler = MinMaxScaler()
    X_scaled = pd.DataFrame(
        scaler.fit_transform(X),
        columns=X.columns,
        index=X.index
    )
    
    # Create sequences
    X_seq, y_seq, _ = create_sequences_lstm_multihorizon(X_scaled, y_multi, sequence_length)
    
    print(f"\nSequence shape: {X_seq.shape}")
    print(f"Target shape: {y_seq.shape}")
    
    # Split into train and validation (time series split)
    n_train = int(len(X_seq) * (1 - validation_split))
    X_train, X_val = X_seq[:n_train], X_seq[n_train:]
    y_train, y_val = y_seq[:n_train], y_seq[n_train:]
    
    # Convert to PyTorch tensors - KEEP ON CPU for DataLoader
    # This fixes MPS hanging issue with DataLoader + shuffle
    X_train_tensor = torch.FloatTensor(X_train)
    y_train_tensor = torch.FloatTensor(y_train)
    X_val_tensor = torch.FloatTensor(X_val).to(device)
    y_val_tensor = torch.FloatTensor(y_val).to(device)
    
    # Create data loaders - data stays on CPU during loading
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                             num_workers=0, pin_memory=False)
    
    # Initialize model with configurable architecture
    n_features = X_seq.shape[2]
    n_horizons = len(horizons)
    model = MultiHorizonLSTMModel(input_size=n_features, hidden_sizes=hidden_sizes, 
                                  dropout_rate=dropout_rate, n_horizons=n_horizons)
    model = model.to(device)
    
    # ðŸ†• Horizon importance: emphasise near-term horizons in the loss
    # 0-24h weighted 2.0x, 24-72h weighted 1.5x, 72-240h weighted 1.0x
    horizon_importance = []
    for h in horizons:
        if h <= 24:
            horizon_importance.append(2.0)   # strongest weight for short term
        elif h <= 72:
            horizon_importance.append(1.5)   # medium weight
        else:
            horizon_importance.append(1.0)   # baseline
    horizon_importance = torch.tensor(horizon_importance, device=device).view(1, -1)
    print(f"Horizon importance weights: {horizon_importance.squeeze().tolist()}")
    
    print("\nModel Architecture:")
    print(model)
    
    print(f"\n{'='*70}")
    print(f"Training Configuration:")
    print(f"{'='*70}")
    print(f"Device: {device}")
    print(f"Training samples: {len(X_train)}")
    print(f"Validation samples: {len(X_val)}")
    print(f"Batch size: {batch_size}")
    print(f"Epochs: {epochs}")
    print(f"{'='*70}")
    
    import sys
    sys.stdout.flush()
    
    print("\n[DEBUG] Creating loss criterion...")
    sys.stdout.flush()
    # Loss and optimizer
    criterion = nn.MSELoss()
    print("[DEBUG] âœ“ Loss criterion created")
    sys.stdout.flush()
    
    print("[DEBUG] Creating optimizer...")
    sys.stdout.flush()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    print("[DEBUG] âœ“ Optimizer created")
    sys.stdout.flush()
    
    print("[DEBUG] Creating learning rate scheduler...")
    sys.stdout.flush()
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=False
    )
    print("[DEBUG] âœ“ Scheduler created")
    sys.stdout.flush()
    
    print("[DEBUG] Initializing training variables...")
    sys.stdout.flush()
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_mae': [],
        'val_mae': []
    }
    
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None
    print("[DEBUG] âœ“ Variables initialized")
    sys.stdout.flush()
    
    print("\n" + "="*70)
    print("STARTING TRAINING")
    print("="*70)
    sys.stdout.flush()
    
    # Test first batch to catch errors early
    print("\n[DEBUG] Getting first batch from DataLoader...")
    sys.stdout.flush()
    test_batch = next(iter(train_loader))
    print(f"[DEBUG] âœ“ Got first batch")
    sys.stdout.flush()
    
    print(f"[DEBUG] Moving batch to device ({device})...")
    sys.stdout.flush()
    test_X, test_y = test_batch[0].to(device), test_batch[1].to(device)
    print(f"[DEBUG] âœ“ Batch moved to {device}")
    sys.stdout.flush()
    
    print("[DEBUG] Setting model to training mode...")
    sys.stdout.flush()
    model.train()
    print("[DEBUG] âœ“ Model in training mode")
    sys.stdout.flush()
    
    print("[DEBUG] Running forward pass...")
    sys.stdout.flush()
    test_output = model(test_X)
    print(f"[DEBUG] âœ“ Forward pass complete, output shape: {test_output.shape}")
    sys.stdout.flush()
    
    print("[DEBUG] Computing loss...")
    sys.stdout.flush()
    test_loss = criterion(test_output, test_y)
    print(f"[DEBUG] âœ“ Loss computed: {test_loss.item():.6f}")
    sys.stdout.flush()
    
    print("\n" + "="*70)
    print(f"âœ“ FIRST BATCH TEST PASSED! Loss: {test_loss.item():.6f}")
    print("="*70)
    sys.stdout.flush()
    
    import time
    print("\n[DEBUG] Starting epoch loop...")
    sys.stdout.flush()
    
    for epoch in range(epochs):
        print(f"\n{'='*70}")
        print(f"EPOCH {epoch+1}/{epochs}")
        print(f"{'='*70}")
        sys.stdout.flush()
        epoch_start_time = time.time()
        
        # Training phase
        model.train()
        train_losses = []
        train_maes = []
        
        # Progress bar
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}', leave=True)
        
        for batch_X, batch_y in pbar:
            # Move batch to device (MPS/CUDA/CPU)
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)
            
            # Forward pass
            outputs = model(batch_X)
            
            # ðŸ†• NEW: Weighted loss - penalize errors during high-flow events more
            # Weight high-flow predictions (>0.3m) 3x more than low-flow
            weights = torch.where(batch_y > 0.3, 3.0, 1.0)
            # Also weight rising events (where any horizon shows increase)
            rising_mask = (batch_y > batch_y[:, 0:1] + 0.05)  # If rising >5cm
            weights = torch.where(rising_mask, weights * 1.5, weights)
            
            # ðŸ†• FIX: Reduce weight for low-flow predictions at far horizons
            # This prevents the model from being penalized for staying near current low value
            low_flow_mask = (batch_y[:, 0:1] < 0.2)  # Currently low
            stable_mask = (batch_y[:, -1:] < 0.25)  # Stays low at end
            low_stable_mask = low_flow_mask & stable_mask
            # For low-stable cases, don't up-weight far horizons
            weights = torch.where(
                low_stable_mask.expand_as(weights),
                torch.clamp(weights, max=1.5),  # Cap weights for stable low periods
                weights
            )
            
            # Calculate weighted MSE loss with horizon importance
            mse_loss = (outputs - batch_y) ** 2
            # Combine event/level weights + horizon importance
            weighted_loss = (mse_loss * weights * horizon_importance).mean()
            
            # ðŸ†• FIX: Early rising event detection loss - penalize missing early signs of rises
            # This addresses the lag issue by forcing the model to detect rising events earlier
            early_rise_loss = 0.0
            if outputs.shape[1] >= 2:
                # Check if actual is rising (comparing first horizon to later horizons)
                actual_rising = (batch_y[:, -1] > batch_y[:, 0] + 0.02)  # Rising >2cm over forecast
                # Check if prediction is rising
                pred_rising = (outputs[:, -1] > outputs[:, 0] + 0.02)
                
                # Penalize when actual is rising but prediction isn't (missed early detection)
                missed_rise_mask = actual_rising & ~pred_rising
                if missed_rise_mask.any():
                    # Calculate how much the prediction missed the early rise
                    actual_rise_magnitude = batch_y[:, -1] - batch_y[:, 0]
                    pred_rise_magnitude = outputs[:, -1] - outputs[:, 0]
                    rise_error = torch.clamp(actual_rise_magnitude - pred_rise_magnitude, min=0)
                    early_rise_loss = (rise_error * missed_rise_mask.float()).mean()
            
            # (Old anchor_loss removed - replaced by extended_anchor_loss below)
            
            # ðŸ†• NEW: Continuity loss - first prediction must be close to current value
            # This ensures smooth transition from observed to predicted
            continuity_loss = torch.mean((outputs[:, 0] - batch_y[:, 0]) ** 2)
            
            # ðŸ†• FIX: PERSISTENCE LOSS - Penalize predicting upticks when actual stays low
            # This is the KEY fix for the mean-reversion bias
            persistence_loss = 0.0
            current_val = batch_y[:, 0]  # Current differential value
            
            # Identify stable-low cases: currently low AND ends low
            is_stable_low = (current_val < 0.2) & (batch_y[:, -1] < 0.25)
            
            if is_stable_low.any():
                # For stable-low cases, penalize predictions that diverge upward
                for i in range(outputs.shape[1]):
                    # How much does prediction rise above current?
                    pred_rise = torch.clamp(outputs[:, i] - current_val - 0.02, min=0)  # 2cm tolerance
                    # Only apply to stable-low cases
                    persistence_loss += (pred_rise * is_stable_low.float()).mean()
                persistence_loss = persistence_loss / outputs.shape[1]
            
            # ðŸ†• FIX: EXTENDED ANCHOR LOSS - Apply to ALL horizons, not just first 6
            # For low-flow cases, anchor predictions more strongly to current value
            extended_anchor_loss = 0.0
            for i in range(outputs.shape[1]):
                # Weight decays with horizon, but doesn't disappear entirely
                weight = 0.3 / (1 + i * 0.1)  # Slower decay than before
                # Stronger anchoring for low-flow cases
                anchor_strength = torch.where(current_val < 0.2, 2.0, 1.0)
                extended_anchor_loss += weight * torch.mean(
                    anchor_strength * (outputs[:, i] - current_val) ** 2
                )
            extended_anchor_loss = extended_anchor_loss / outputs.shape[1]
            
            # Combine losses: main loss + early rise detection + anchor + continuity + persistence
            # ðŸ†• IMPROVED: Adjusted weights to fix mean-reversion bias
            # - Moderate early_rise (0.3) to catch peaks
            # - Extended anchor (0.1) applied across all horizons
            # - Strong persistence (0.5) to prevent false upticks
            # - Moderate continuity (1.0) for smooth join at t0
            loss = (weighted_loss + 
                    0.3 * early_rise_loss + 
                    0.1 * extended_anchor_loss + 
                    0.5 * persistence_loss +
                    1.0 * continuity_loss)
            
            # Check for NaN
            if torch.isnan(loss):
                print(f"\nWarning: NaN loss detected at epoch {epoch+1}")
                raise ValueError("NaN loss encountered")
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            
            train_losses.append(loss.item())
            train_maes.append(torch.mean(torch.abs(outputs - batch_y)).item())
            
            # Update progress bar
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'mae': f'{torch.mean(torch.abs(outputs - batch_y)).item():.4f}'})
        
        # Validation phase
        model.eval()
        with torch.no_grad():
            val_outputs = model(X_val_tensor)
            val_loss = criterion(val_outputs, y_val_tensor)
            val_mae = torch.mean(torch.abs(val_outputs - y_val_tensor))
        
        # Record history
        avg_train_loss = np.mean(train_losses)
        avg_train_mae = np.mean(train_maes)
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(val_loss.item())
        history['train_mae'].append(avg_train_mae)
        history['val_mae'].append(val_mae.item())
        
        # Learning rate scheduling
        scheduler.step(val_loss)
        
        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict().copy()
            patience_counter = 0
        else:
            patience_counter += 1
        
        # Print progress
        epoch_time = time.time() - epoch_start_time
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1}/{epochs}] ({epoch_time:.1f}s) - "
                  f"Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss.item():.6f}, "
                  f"Train MAE: {avg_train_mae:.6f}, Val MAE: {val_mae.item():.6f}")
        
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered at epoch {epoch+1}")
            break
    
    # Restore best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("\nâœ“ Restored best model weights")
    
    print("âœ“ Model training complete!")
    
    return model, scaler, history, sequence_length, horizons

print("âœ“ Training function defined")


âœ“ Training function defined


## 10. Prepare Training Data


In [11]:
## 10. Prepare Training Data
stretch = "Isis"
# Merge and clean ISIS data


print("Preparing ISIS training data...")
isis_df_featureless = merge_and_clean_data(
    isis_hist_diff_df, 
    mega_rainfall_hist_df, 
    isis_api_diff_df, 
    mega_rainfall_api_df, 
    'differential'
)

isis_df_featureless = merge_and_clean_data(
    isis_hist_diff_df, 
    mega_rainfall_hist_df, 
    isis_api_diff_df, 
    mega_rainfall_api_df, 
    'differential'
)

print(f"\nâœ“ ISIS data prepared: {isis_df_featureless.shape}")
print(f"Date range: {isis_df_featureless.index.min()} to {isis_df_featureless.index.max()}")

# Define prediction horizons:
# - Every 2 hours for next 24 hours
# - Every 6 hours for 24-48 hours  
# - Every 24 hours (daily) for 48-240 hours
horizons = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24,  # 0-24h: every 2h
            30, 36, 42, 48,                               # 24-48h: every 6h
            72, 96, 120, 144, 168, 192, 216, 240]         # 48-240h: every 24h (daily)

# For TRAINING: Use the historical rainfall data itself as "future" rainfall
# Extract just the rainfall columns from the historical data
rainfall_cols = [col for col in isis_df_featureless.columns if col != 'differential']
historical_rainfall_for_training = isis_df_featureless[rainfall_cols].copy()

# Create multi-horizon targets and features
X_isis, y_isis_multi, mask_isis, horizons = create_target_and_X_y_multihorizon(
    isis_df_featureless,
    historical_rainfall_for_training,  # Use historical data, not forecast!
    differential_column='differential',
    horizons=horizons
)

# Check for data issues
print("\nChecking for data issues...")
print(f"X_isis contains NaN: {X_isis.isna().any().any()}")
print(f"X_isis contains Inf: {np.isinf(X_isis.values).any()}")
print(f"y_isis_multi contains NaN: {y_isis_multi.isna().any().any()}")

# Clean if needed
if X_isis.isna().any().any() or np.isinf(X_isis.values).any():
    print("Cleaning X_isis...")
    X_isis = X_isis.replace([np.inf, -np.inf], np.nan)
    X_isis = X_isis.ffill().bfill().fillna(0)
    print("âœ“ Cleaned")

if y_isis_multi.isna().any().any():
    print("Cleaning y_isis_multi...")
    y_isis_multi = y_isis_multi.ffill().bfill().fillna(0)
    print("âœ“ Cleaned")

print("\nâœ“ Data ready for training!")

Preparing ISIS training data...


  df_hourly = df.resample('1H').agg(aggregation_rules)
  df_hourly = df.resample('1H').agg(aggregation_rules)


Cleaning rainfall data for 13 stations...
  Swindon: Capping 1 values >50mm/hour
âœ“ Rainfall cleaning complete
  Removing 45 spike values in differential
Cleaning rainfall data for 13 stations...
  Swindon: Capping 1 values >50mm/hour
âœ“ Rainfall cleaning complete
  Removing 45 spike values in differential

âœ“ ISIS data prepared: (67315, 14)
Date range: 2017-02-04 12:00:00+00:00 to 2026-01-20 10:00:00+00:00

Creating features with future rainfall...


  df['differential_range_72h'] = (
  df['differential_rolling_mean_168h'] = df[differential_column].rolling(window=168).mean()
  df['differential_rolling_max_168h'] = df[differential_column].rolling(window=168).max()
  df['differential_rolling_min_168h'] = df[differential_column].rolling(window=168).min()
  df['differential_range_168h'] = (
  df['antecedent_wetness'] = (
  df[f'catchment_rainfall_exp_{name}'] = (
  df[f'catchment_rainfall_exp_{name}'] = (
  df[f'catchment_rainfall_exp_{name}'] = (
  df['day_of_year'] = df.index.dayofyear
  df['day_of_year_sin'] = np.sin(2 * np.pi * df['day_of_year']/365.25)
  df['day_of_year_cos'] = np.cos(2 * np.pi * df['day_of_year']/365.25)
  df_with_features[target_col] = df_with_features[differential_column].shift(-horizon)
  df_with_features[target_col] = df_with_features[differential_column].shift(-horizon)
  df_with_features[target_col] = df_with_features[differential_column].shift(-horizon)
  df_with_features[target_col] = df_with_features[dif

Creating targets for 24 horizons...

Multi-Horizon Model Setup:
Number of features: 111
Number of horizons: 24
Horizons: [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 30, 36, 42, 48, 72, 96, 120, 144, 168, 192, 216, 240]
Training samples: 67075
Date range: 2017-02-04 12:00:00+00:00 to 2026-01-10 10:00:00+00:00

Checking for data issues...
X_isis contains NaN: True
X_isis contains Inf: False
y_isis_multi contains NaN: False
Cleaning X_isis...
âœ“ Cleaned

âœ“ Data ready for training!


In [14]:
X_isis

Unnamed: 0_level_0,Bicester,Bourton,Byfield,Chipping,Eynsham,Grimsbury,Osney,Rapsgate,Shorncote,St,...,differential_rolling_max_168h,differential_rolling_min_168h,differential_range_168h,antecedent_wetness,catchment_rainfall_exp_fast,catchment_rainfall_exp_medium,catchment_rainfall_exp_slow,day_of_year,day_of_year_sin,day_of_year_cos
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2017-02-04 12:00:00+00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.603580,0.156300,0.447280,7.755,0.000000,0.000000,0.000000,35,0.566362,0.824157
2017-02-04 13:00:00+00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.603580,0.156300,0.447280,7.755,0.000000,0.000000,0.000000,35,0.566362,0.824157
2017-02-04 14:00:00+00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.603580,0.156300,0.447280,7.755,0.000000,0.000000,0.000000,35,0.566362,0.824157
2017-02-04 15:00:00+00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.603580,0.156300,0.447280,7.755,0.000000,0.000000,0.000000,35,0.566362,0.824157
2017-02-04 16:00:00+00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.603580,0.156300,0.447280,7.755,0.000000,0.000000,0.000000,35,0.566362,0.824157
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2026-01-10 06:00:00+00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.426772,0.142698,0.284075,149.650,0.007688,1.419657,2.956617,10,0.171177,0.985240
2026-01-10 07:00:00+00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.426772,0.142698,0.284075,144.776,0.005491,1.306084,2.875614,10,0.171177,0.985240
2026-01-10 08:00:00+00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.426772,0.142698,0.284075,142.266,0.003922,1.201597,2.796830,10,0.171177,0.985240
2026-01-10 09:00:00+00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.426772,0.142698,0.284075,141.001,0.002802,1.105470,2.720204,10,0.171177,0.985240


## 11. Train Multi-Horizon Model

This will take a few minutes depending on your hardware (faster on GPU/MPS).


## ðŸš€ READY TO RETRAIN WITH IMPROVEMENTS!

The model has been enhanced with:

### ðŸ†• FIX: Low-Flow Stability Features (prevents false upticks)
- **hours_since_rain**: Time since last significant rainfall
- **is_dry_spell**: Binary flag for 48+ hours without rain
- **stable_low_regime**: Combined low-flow + dry-spell indicator
- **future_is_dry_XXh**: Whether future rainfall is minimal
- **expect_stable_low**: Strong signal that river should stay low

### ðŸ†• FIX: Improved Loss Function
1. **Persistence Loss** (NEW): Penalizes predicting upticks when actual stays low
2. **Extended Anchor Loss** (IMPROVED): Applied to ALL horizons, not just first 6
3. **Asymmetric Weighting** (NEW): Doesn't over-penalize staying near current low value
4. **Weighted MSE**: High-flow (3x), Rising events (1.5x)
5. **Early Rise Detection**: Catches flood peaks earlier

### Loss Function Formula:
```
loss = weighted_mse + 0.3*early_rise + 0.1*extended_anchor + 0.5*persistence + 1.0*continuity
```

**Next steps:**
- Run Cell 20 to prepare training data with new features
- Run Cell 23 to train the improved model
- Re-run the what-if forecasts to verify the fix!


In [12]:
# ============================================================================
# SKIP TRAINING FLAG: Set to True to load the latest saved model instead
# ============================================================================
# ðŸ†• FIX APPLIED: Set to False to retrain with improvements that fix the
# mean-reversion bias (false upticks at end of forecast when values are low)
#
# KEY IMPROVEMENTS:
# 1. LOW-FLOW REGIME FEATURES: hours_since_rain, is_dry_spell, stable_low_regime
# 2. FUTURE DRY FEATURES: future_is_dry_XXh, expect_stable_low
# 3. PERSISTENCE LOSS: Penalizes predicting rises when actual stays low
# 4. EXTENDED ANCHOR LOSS: Applies to ALL horizons (not just first 6)
# 5. ASYMMETRIC WEIGHTING: Doesn't over-penalize staying near current low value
# ============================================================================
SKIP_TRAINING = False  # Change to True after retraining to skip in future

import os
import pickle
import glob

if SKIP_TRAINING:
    print("="*70)
    print("SKIPPING TRAINING - Loading latest saved model")
    print("="*70)
    
    # Check if models directory exists
    models_dir = '../models'
    latest_model_path = os.path.join(models_dir, 'multihorizon_model_latest.pth')
    latest_scaler_path = os.path.join(models_dir, 'scaler_latest.pkl')
    latest_config_path = os.path.join(models_dir, 'config_latest.pkl')
    
    # Try to find the latest model (either latest_* or most recent timestamped)
    if not os.path.exists(latest_model_path):
        # Look for most recent timestamped model
        model_files = glob.glob(os.path.join(models_dir, 'multihorizon_model_*.pth'))
        if model_files:
            # Sort by modification time, get most recent
            latest_model_path = max(model_files, key=os.path.getmtime)
            # Find corresponding scaler and config
            base_name = os.path.basename(latest_model_path).replace('multihorizon_model_', '').replace('.pth', '')
            latest_scaler_path = os.path.join(models_dir, f'scaler_{base_name}.pkl')
            latest_config_path = os.path.join(models_dir, f'config_{base_name}.pkl')
            print(f"  Found timestamped model: {os.path.basename(latest_model_path)}")
        else:
            raise FileNotFoundError(f"No saved model found in {models_dir}. Please train a model first or set SKIP_TRAINING=False")
    
    # Load configuration
    if not os.path.exists(latest_config_path):
        raise FileNotFoundError(f"Config file not found: {latest_config_path}")
    
    with open(latest_config_path, 'rb') as f:
        config = pickle.load(f)
    
    # Extract parameters from config
    sequence_length = config.get('sequence_length', 120)
    horizons = config.get('horizons', [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24,
                                        30, 36, 42, 48, 72, 96, 120, 144, 168, 192, 216, 240])
    
    # Recreate model architecture
    n_features = config['input_size']
    n_horizons = len(horizons)
    best_model = MultiHorizonLSTMModel(
        input_size=n_features,
        hidden_sizes=config['hidden_sizes'],
        n_horizons=n_horizons,
        dropout_rate=config['dropout_rate']
    )
    best_model = best_model.to(device)
    
    # Load weights
    best_model.load_state_dict(torch.load(latest_model_path, map_location=device))
    best_model.eval()  # Set to evaluation mode
    
    # Load scaler
    if not os.path.exists(latest_scaler_path):
        raise FileNotFoundError(f"Scaler file not found: {latest_scaler_path}")
    
    with open(latest_scaler_path, 'rb') as f:
        scaler = pickle.load(f)
    
    print(f"\nâœ“ Model loaded from: {latest_model_path}")
    print(f"âœ“ Scaler loaded from: {latest_scaler_path}")
    print(f"âœ“ Config loaded from: {latest_config_path}")
    print(f"âœ“ Sequence length: {sequence_length}")
    print(f"âœ“ Horizons: {horizons}")
    print(f"âœ“ Model ready for predictions!")
    
else:
    # Train the improved multi-horizon model with enhanced features and weighted loss
    # 
    # HYPERPARAMETER GUIDE (for hourly data with 24 horizons):
    # - sequence_length: 72h (3 days) - captures recent weather patterns & river response
    # - epochs: 30 (with early stopping, anchor loss helps faster convergence)
    # - batch_size: 64 (larger batch = faster training + more stable gradients)
    # - learning_rate: 0.0001 (conservative, won't overshoot with weighted loss)
    # - patience: 7 (early stopping for efficient training)
    # 
    # LOSS FUNCTION: weighted_loss + 0.3*early_rise + 0.1*anchor + 2.0*continuity
    # - Removed smoothness regularization to allow sharper predictions
    # - Added anchor loss to reduce overdispersion
    # - Added continuity loss to ensure smooth transition from observed to predicted

    model_tuple = train_multihorizon_model(
        X_isis, 
        y_isis_multi, 
        horizons,
        sequence_length=120,       # ðŸ†• 5 days of history (was 72) for better pattern recognition
        epochs=30,                 # Increased from 10 for better convergence with new loss
        batch_size=64,             # Larger batch for speed (was 32)
        learning_rate=0.0001,      # Keep conservative
        patience=7,                # Early stopping if no improvement
        hidden_sizes=[192, 128, 64],  # ðŸ†• Deeper 3-layer LSTM with more capacity
        dropout_rate=0.3           # ðŸ†• Slightly stronger dropout
    )

    best_model, scaler, history, sequence_length, horizons = model_tuple

    print("\nâœ“ Training complete!")



Sequence shape: (66955, 120, 111)
Target shape: (66955, 24)
Horizon importance weights: [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.5, 1.5, 1.5, 1.5, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

Model Architecture:
MultiHorizonLSTMModel(
  (lstm_layers): ModuleList(
    (0): LSTM(111, 192, batch_first=True)
    (1): LSTM(192, 128, batch_first=True)
    (2): LSTM(128, 64, batch_first=True)
  )
  (dropout_layers): ModuleList(
    (0-2): 3 x Dropout(p=0.3, inplace=False)
  )
  (fc_layers): ModuleList(
    (0-23): 24 x Linear(in_features=64, out_features=1, bias=True)
  )
)

Training Configuration:
Device: mps
Training samples: 53564
Validation samples: 13391
Batch size: 64
Epochs: 30

[DEBUG] Creating loss criterion...
[DEBUG] âœ“ Loss criterion created
[DEBUG] Creating optimizer...
[DEBUG] âœ“ Optimizer created
[DEBUG] Creating learning rate scheduler...
[DEBUG] âœ“ Scheduler created
[DEBUG] Initializing training variables...
[DEBUG] âœ“ Variables initialized

STARTING TR



[DEBUG] âœ“ Forward pass complete, output shape: torch.Size([64, 24])
[DEBUG] Computing loss...
[DEBUG] âœ“ Loss computed: 0.125022

âœ“ FIRST BATCH TEST PASSED! Loss: 0.125022

[DEBUG] Starting epoch loop...

EPOCH 1/30


Epoch 1/30:  35%|â–ˆâ–ˆâ–ˆâ–Œ      | 297/837 [00:33<01:01,  8.79it/s, loss=0.0895, mae=0.0877]


KeyboardInterrupt: 

In [None]:
# # Save the trained model weights and training artifacts
# import os
# from datetime import datetime

# # Create models directory if it doesn't exist
# models_dir = '../models'
# os.makedirs(models_dir, exist_ok=True)

# # Generate timestamp for unique filename
# timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# model_path = os.path.join(models_dir, f'multihorizon_model_{timestamp}.pth')
# scaler_path = os.path.join(models_dir, f'scaler_{timestamp}.pkl')
# config_path = os.path.join(models_dir, f'config_{timestamp}.pkl')

# # Save model weights
# torch.save(best_model.state_dict(), model_path)
# print(f"âœ“ Model weights saved to: {model_path}")

# # Save scaler
# import pickle
# with open(scaler_path, 'wb') as f:
#     pickle.dump(scaler, f)
# print(f"âœ“ Scaler saved to: {scaler_path}")

# # Save configuration (for reproducibility)
# config = {
#     'sequence_length': sequence_length,
#     'horizons': horizons,
#     'hidden_sizes': [192, 128, 64],
#     'dropout_rate': 0.3,
#     'input_size': len(X_isis.columns),
#     'feature_columns': list(X_isis.columns),
#     'training_history': history
# }
# with open(config_path, 'wb') as f:
#     pickle.dump(config, f)
# print(f"âœ“ Configuration saved to: {config_path}")

# # Also save a "latest" version for easy loading
# latest_model_path = os.path.join(models_dir, 'multihorizon_model_latest.pth')
# latest_scaler_path = os.path.join(models_dir, 'scaler_latest.pkl')
# latest_config_path = os.path.join(models_dir, 'config_latest.pkl')

# torch.save(best_model.state_dict(), latest_model_path)
# with open(latest_scaler_path, 'wb') as f:
#     pickle.dump(scaler, f)
# with open(latest_config_path, 'wb') as f:
#     pickle.dump(config, f)
# print(f"\nâœ“ Latest versions saved for easy loading")



In [None]:
# Function to load a saved model
def load_trained_model(model_path=None, scaler_path=None, config_path=None):
    """
    Load a previously trained model, scaler, and configuration.
    
    Args:
        model_path: Path to model weights (default: latest)
        scaler_path: Path to scaler (default: latest)
        config_path: Path to config (default: latest)
    
    Returns:
        tuple: (model, scaler, config)
    """
    import pickle
    from torch import nn
    
    # Use latest versions if paths not specified
    if model_path is None:
        model_path = os.path.join(models_dir, 'multihorizon_model_2025_08.pth')
    if scaler_path is None:
        scaler_path = os.path.join(models_dir, 'scaler_2025_08.pkl')
    if config_path is None:
        config_path = os.path.join(models_dir, 'config_2025_08.pkl')
    
    # Load configuration
    with open(config_path, 'rb') as f:
        config = pickle.load(f)
    
    # Recreate model architecture - MultiHorizonLSTMModel must be defined earlier in the notebook
    model = MultiHorizonLSTMModel(
        input_size=config['input_size'],
        hidden_sizes=config['hidden_sizes'],
        n_horizons=len(config['horizons']),
        dropout_rate=config['dropout_rate']
    )
    model = model.to(device)
    
    # Load weights
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()  # Set to evaluation mode
    
    # Load scaler
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)
    
    print(f"âœ“ Model loaded from: {model_path}")
    print(f"âœ“ Scaler loaded from: {scaler_path}")
    print(f"âœ“ Config loaded from: {config_path}")
    
    return model, scaler, config

print("\nðŸ“¦ Use load_trained_model() to reload this model in future sessions")

# Don't call load_trained_model() here - it will fail if MultiHorizonLSTM isn't defined yet
# Users should call it manually when needed: model, scaler, config = load_trained_model()
load_trained_model()

## 12. Generate Future Predictions


In [None]:

def predict_single_ensemble_member(model, scaler, historical_df, rainfall_forecast_df, 
                                    X_train_columns, sequence_length=72, 
                                    horizons=None, verbose=True):
    """
    Generate a 240-hour river flow prediction for ONE ensemble rainfall member.
    
    This function:
    1. Combines historical data with the FULL future rainfall forecast
    2. Computes features ONCE with future rainfall visible
    3. Makes ONE prediction from the current state
    4. Interpolates the sparse horizon predictions into a continuous timeseries
    
    Args:
        model: Trained PyTorch model
        scaler: Fitted scaler
        historical_df: Historical data with differential and rainfall
        rainfall_forecast_df: 240-hour rainfall forecast for this ensemble member
        X_train_columns: Column names used in training
        sequence_length: LSTM sequence length (default 72 = 3 days)
        horizons: List of prediction horizons [2, 4, 6, ..., 240]
        verbose: Whether to print progress
    
    Returns:
        pd.Series: 240-hour river differential prediction (hourly)
    """
    if horizons is None:
        horizons = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24,
                    30, 36, 42, 48,
                    72, 96, 120, 144, 168, 192, 216, 240]
    
    # Ensure timezone alignment
    if rainfall_forecast_df.index.tz is None and historical_df.index.tz is not None:
        rainfall_forecast_df = rainfall_forecast_df.tz_localize(historical_df.index.tz)
    elif rainfall_forecast_df.index.tz is not None and historical_df.index.tz is None:
        rainfall_forecast_df = rainfall_forecast_df.tz_convert('UTC').tz_localize(None)
        historical_df = historical_df.copy()
    
    # Get enough historical data for rolling features (30 days = 720 hours)
    lookback_hours = 720
    recent_history = historical_df.iloc[-lookback_hours:].copy()
    
    # Store the forecast start time and current differential
    forecast_start_time = recent_history.index[-1]
    current_differential = recent_history['differential'].iloc[-1]
    
    if verbose:
        print(f"  Current time: {forecast_start_time}")
        print(f"  Current differential: {current_differential:.3f}m")
    
    # Create combined dataframe: historical + future rainfall
    # This allows the feature engineering to "see" future rainfall
    combined_df = recent_history.copy()
    
    # Add future rainfall data to the combined dataframe
    for col in combined_df.columns:
        if col in rainfall_forecast_df.columns:
            # Append future rainfall
            future_data = rainfall_forecast_df[[col]].copy()
            # Only add future timestamps not already in combined_df
            future_data = future_data[future_data.index > combined_df.index[-1]]
            combined_df = pd.concat([combined_df, future_data])
        elif col == 'differential':
            # Forward fill differential for future timestamps
            future_times = rainfall_forecast_df.index[rainfall_forecast_df.index > combined_df.index[-1]]
            future_diff = pd.DataFrame(
                {'differential': [current_differential] * len(future_times)},
                index=future_times
            )
            combined_df = pd.concat([combined_df, future_diff])
    
    # Remove any duplicate indices
    combined_df = combined_df[~combined_df.index.duplicated(keep='first')]
    combined_df = combined_df.sort_index()
    
    # Now create features with the FULL future rainfall visible
    # This is the key fix - features are computed with future data available
    df_with_features = create_features_with_future_rainfall(
        combined_df,
        rainfall_forecast_df,
        differential_column='differential'
    )
    
    # Get features at the "current" time (end of historical data)
    # This is where we make our prediction FROM
    current_time_idx = df_with_features.index.get_loc(forecast_start_time)
    
    # Get only the features used during training (in same order)
    X_forecast = df_with_features[X_train_columns].copy()
    
    # Fill any remaining NaN (should be minimal now)
    X_forecast = X_forecast.ffill().bfill().fillna(0)
    
    # Get the input sequence (last `sequence_length` hours before current time)
    sequence_start = current_time_idx - sequence_length + 1
    sequence_end = current_time_idx + 1
    sequence_data = X_forecast.iloc[sequence_start:sequence_end]
    
    if len(sequence_data) < sequence_length:
        raise ValueError(f"Not enough historical data. Need {sequence_length} hours, got {len(sequence_data)}")
    
    # Scale and convert to tensor
    sequence_scaled = scaler.transform(sequence_data)
    sequence_tensor = torch.FloatTensor(sequence_scaled).unsqueeze(0).to(device)
    
    # Make prediction
    model.eval()
    with torch.no_grad():
        predictions = model(sequence_tensor).cpu().numpy()[0]  # Shape: (n_horizons,)
    
    # Create sparse prediction series (at model horizons)
    horizon_times = [forecast_start_time + pd.Timedelta(hours=h) for h in horizons]
    sparse_predictions = pd.Series(predictions, index=horizon_times)
    
    # Add current value at t=0 for interpolation
    sparse_predictions[forecast_start_time] = current_differential
    sparse_predictions = sparse_predictions.sort_index()
    
    # Interpolate to hourly resolution for full 240-hour forecast
    full_timeline = pd.date_range(
        start=forecast_start_time,
        periods=241,  # 0 to 240 hours inclusive
        freq='1h'
    )
    
    # Reindex and interpolate
    full_predictions = sparse_predictions.reindex(full_timeline)
    full_predictions = full_predictions.interpolate(method='linear')
    
    # Fill any edge NaNs
    full_predictions = full_predictions.ffill().bfill()
    
    if verbose:
        print(f"  Prediction range: {full_predictions.min():.3f}m to {full_predictions.max():.3f}m")
    
    return full_predictions


def predict_ensemble(model, scaler, historical_df, rainfall_ensemble_df,
                     X_train_columns, sequence_length=72, horizons=None,
                     station_names=None, n_members=20, verbose=True):
    """
    Generate ensemble river flow predictions from multiple rainfall scenarios.
    
    This is the main prediction function that:
    1. Extracts each rainfall ensemble member
    2. Runs prediction for each member (FAST - one prediction per member)
    3. Returns all predictions for statistical analysis
    
    Args:
        model: Trained PyTorch model
        scaler: Fitted scaler
        historical_df: Historical data with differential and rainfall
        rainfall_ensemble_df: DataFrame with all ensemble members
                              Columns like: Osney_member_0, Osney_member_1, ...
        X_train_columns: Column names used in training
        sequence_length: LSTM sequence length
        horizons: List of prediction horizons
        station_names: List of station names
        n_members: Number of ensemble members to use
        verbose: Whether to print progress
    
    Returns:
        pd.DataFrame: Each column is one ensemble member's 240-hour prediction
    """
    if horizons is None:
        horizons = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24,
                    30, 36, 42, 48,
                    72, 96, 120, 144, 168, 192, 216, 240]
    
    if station_names is None:
        # Try to infer from column names
        station_names = list(set(col.rsplit('_member_', 1)[0] 
                                  for col in rainfall_ensemble_df.columns 
                                  if '_member_' in col))
    
    print("="*70)
    print("ENSEMBLE PREDICTION")
    print("="*70)
    print(f"Using {n_members} ensemble members")
    print(f"Stations: {len(station_names)}")
    print(f"Forecast horizons: {len(horizons)} points")
    print("="*70)
    
    ensemble_predictions = {}
    
    for member_idx in tqdm(range(n_members), desc="Processing ensemble members"):
        # Extract this member's rainfall across all stations
        member_columns = [f'{station}_member_{member_idx}' for station in station_names]
        
        # Check columns exist
        missing = [col for col in member_columns if col not in rainfall_ensemble_df.columns]
        if missing:
            if verbose:
                print(f"  Skipping member {member_idx}: missing {len(missing)} columns")
            continue
        
        # Create rainfall df for this member (rename columns to station names)
        member_rainfall = rainfall_ensemble_df[member_columns].copy()
        member_rainfall.columns = station_names
        
        try:
            # Generate prediction for this ensemble member
            prediction = predict_single_ensemble_member(
                model=model,
                scaler=scaler,
                historical_df=historical_df,
                rainfall_forecast_df=member_rainfall,
                X_train_columns=X_train_columns,
                sequence_length=sequence_length,
                horizons=horizons,
                verbose=False
            )
            
            ensemble_predictions[f'member_{member_idx}'] = prediction
            
        except Exception as e:
            print(f"  Error on member {member_idx}: {e}")
            continue
    
    # Combine all predictions into a DataFrame
    ensemble_df = pd.DataFrame(ensemble_predictions)
    
    print(f"\nâœ“ Generated {len(ensemble_predictions)} ensemble predictions")
    print(f"  Forecast shape: {ensemble_df.shape}")
    print(f"  Time range: {ensemble_df.index[0]} to {ensemble_df.index[-1]}")
    
    return ensemble_df


print("âœ“ Prediction functions defined (FIXED VERSION)")

In [None]:
# Fetch rainfall forecast from Open-Meteo API
def get_rainfall_forecast(locations):
    """
    Fetches 10-day hourly rainfall forecast from the Open-Meteo API for multiple locations.
    """
    location_names = list(locations.keys())
    latitudes = [loc['latitude'] for loc in locations.values()]
    longitudes = [loc['longitude'] for loc in locations.values()]

    url = "https://api.open-meteo.com/v1/forecast"
    params = {
        "latitude": latitudes,
        "longitude": longitudes,
        "hourly": "precipitation",
        "forecast_days": 10
    }

    try:
        response = requests.get(url, params=params)
        response.raise_for_status()
        data = response.json()

        forecast_dfs = []
        for i, location_data in enumerate(data):
            df = pd.DataFrame(location_data['hourly'])
            df['timestamp'] = pd.to_datetime(df['time'])
            df = df.set_index('timestamp')
            df = df[['precipitation']]
            df = df.rename(columns={'precipitation': location_names[i]})
            forecast_dfs.append(df)
        
        combined_df = pd.concat(forecast_dfs, axis=1)
        print("Successfully fetched and processed rainfall forecast.")
        return combined_df

    except requests.exceptions.RequestException as e:
        print(f"API request failed: {e}")
        return pd.DataFrame()
    except (KeyError, TypeError) as e:
        print(f"Failed to parse API response: {e}")
        return pd.DataFrame()

def get_rainfall_forecast_ensemble(locations, ensemble_method='mean', ensemble_model='ecmwf_ifs025'):
    """
    Fetches ensemble rainfall forecast from the Open-Meteo API for multiple locations.
    
    Parameters:
    -----------
    locations : dict
        Dictionary of location names and their coordinates
    ensemble_method : str
        'mean', 'median', 'percentiles', or 'all'
    ensemble_model : str
        'ecmwf_ifs025', 'icon_seamless', or 'gfs_seamless'
    
    Returns:
    --------
    pd.DataFrame
        Rainfall forecast data with timestamps as index
    """
    location_names = list(locations.keys())
    latitudes = [loc['latitude'] for loc in locations.values()]
    longitudes = [loc['longitude'] for loc in locations.values()]

    url = "https://ensemble-api.open-meteo.com/v1/ensemble"
    params = {
        "latitude": latitudes,
        "longitude": longitudes,
        "hourly": "precipitation",
        "models": ensemble_model
    }

    try:
        response = requests.get(url, params=params)
        response.raise_for_status()
        data = response.json()

        forecast_dfs = []
        for i, location_data in enumerate(data):
            df = pd.DataFrame(location_data['hourly'])
            df['timestamp'] = pd.to_datetime(df['time'])
            df = df.set_index('timestamp')
            
            # Handle ensemble members
            precip_cols = [col for col in df.columns if col.startswith('precipitation')]
            
            if ensemble_method == 'mean':
                df_agg = df[precip_cols].mean(axis=1).to_frame(name=location_names[i])
            elif ensemble_method == 'median':
                df_agg = df[precip_cols].median(axis=1).to_frame(name=location_names[i])
            elif ensemble_method == 'percentiles':
                df_agg = pd.DataFrame({
                    f"{location_names[i]}_p10": df[precip_cols].quantile(0.1, axis=1),
                    f"{location_names[i]}_p50": df[precip_cols].quantile(0.5, axis=1),
                    f"{location_names[i]}_p90": df[precip_cols].quantile(0.9, axis=1)
                })
            elif ensemble_method == 'all':
                df_agg = df[precip_cols].copy()
                df_agg.columns = [f"{location_names[i]}_member{j}" for j in range(len(precip_cols))]
            else:
                raise ValueError(f"Unknown ensemble_method: {ensemble_method}")
            
            forecast_dfs.append(df_agg)
        
        combined_df = pd.concat(forecast_dfs, axis=1)
        print(f"Successfully fetched and processed ensemble rainfall forecast ({ensemble_method}).")
        return combined_df

    except requests.exceptions.RequestException as e:
        print(f"API request failed: {e}")
        return pd.DataFrame()
    except (KeyError, TypeError) as e:
        print(f"Failed to parse API response: {e}")
        return pd.DataFrame()

# Coordinates for rainfall stations
station_coordinates = {
    'Osney': {'latitude': 51.750, 'longitude': -1.272},
    'Eynsham': {'latitude': 51.789, 'longitude': -1.402},
    'St': {'latitude': 51.7, 'longitude': -1.5},
    'Shorncote': {'latitude': 51.666, 'longitude': -1.916},
    'Rapsgate': {'latitude': 51.815, 'longitude': -1.975},
    'Stowell': {'latitude': 51.833, 'longitude': -1.821},
    'Bourton': {'latitude': 51.884, 'longitude': -1.758},
    'Chipping': {'latitude': 51.942, 'longitude': -1.547},
    'Grimsbury': {'latitude': 52.065, 'longitude': -1.326},
    'Bicester': {'latitude': 51.899, 'longitude': -1.155},
    'Byfield': {'latitude': 52.179, 'longitude': -1.274},
    'Swindon': {'latitude': 51.556, 'longitude': -1.779},
    'Worsham': {'latitude': 51.817, 'longitude': -1.498}
}




## ðŸŽ¯ ENSEMBLE OF FORECASTS

**The Ultimate Upgrade:** Instead of averaging rainfall ensemble members BEFORE prediction, we now:

1. **Get ALL rainfall ensemble members** (50 members from ECMWF)
2. **Generate a separate forecast for EACH rainfall member** (50 flag forecasts!)
3. **Analyze the ensemble spread** to quantify prediction uncertainty

### Why This is Better:
- **Preserves full uncertainty chain**: Rainfall uncertainty â†’ Flow uncertainty
- **Probabilistic predictions**: Calculate probability of different flag colors
- **Distribution-aware**: See the full range of possible outcomes, not just mean/median
- **Operational decision support**: "30% chance of red flag" is more useful than a single prediction


In [None]:
def get_rainfall_forecast_ensemble(locations, ensemble_method='mean', 
                                   ensemble_model='icon_seamless'):
    """
    Fetches 10-day hourly ENSEMBLE rainfall forecast from the Open-Meteo Ensemble API.
    
    This function retrieves multiple ensemble members (perturbed forecasts) and combines
    them according to the specified method.

    Args:
        locations (dict): A dictionary where keys are location names and values are dicts
                          with 'latitude' and 'longitude'.
        ensemble_method (str): How to combine ensemble members:
                               - 'mean': Average across all members (default)
                               - 'median': Median across all members
                               - 'all': Return all individual members
                               - 'percentiles': Return 10th, 50th, 90th percentiles
        ensemble_model (str): Which ensemble model to use:
                               - 'icon_seamless': DWD ICON (default, best for Europe)
                               - 'gfs_seamless': NOAA GFS (good for global)
                               - 'ecmwf_ifs025': ECMWF (best accuracy, 51 members)

    Returns:
        pd.DataFrame: A DataFrame with a timestamp index and columns for each location's
                      predicted rainfall in mm. For ensemble_method='all', columns will be
                      named like 'Location_member_0', 'Location_member_1', etc.
    """
    location_names = list(locations.keys())
    latitudes = [loc['latitude'] for loc in locations.values()]
    longitudes = [loc['longitude'] for loc in locations.values()]

    # API endpoint and parameters for ENSEMBLE forecast
    # NOTE: Use ensemble-api.open-meteo.com, not api.open-meteo.com
    url = "https://ensemble-api.open-meteo.com/v1/ensemble"
    params = {
        "latitude": latitudes,
        "longitude": longitudes,
        "hourly": "precipitation",
        "forecast_days": 10,
        "models": ensemble_model
    }

    try:
        # Make the API request
        print(f"Fetching ensemble forecast from {ensemble_model}...")
        response = requests.get(url, params=params)
        response.raise_for_status()
        data = response.json()

        # Process the response - ensemble API returns data differently
        # For multiple locations, data is a list; for single location, it's a dict
        if not isinstance(data, list):
            data = [data]  # Convert to list for consistent handling
        
        forecast_dfs = []
        
        for i, location_data in enumerate(data):
            location_name = location_names[i]
            hourly_data = location_data['hourly']
            
            # Extract timestamps
            timestamps = pd.to_datetime(hourly_data['time'])
            
            # Extract all ensemble members for precipitation
            # Ensemble members are named like 'precipitation_member01', 'precipitation_member02', etc.
            # Note: they use 2-digit format with leading zeros (01, 02, ... 09, 10, 11, ...)
            ensemble_members = []
            member_idx = 1  # Start at 1, not 0
            
            while True:
                # Try with leading zero format for members 1-9
                member_key = f'precipitation_member{member_idx:02d}'
                if member_key in hourly_data:
                    member_data = hourly_data[member_key]
                    ensemble_members.append(member_data)
                    member_idx += 1
                else:
                    break  # No more members found
            
            # Convert to DataFrame (rows=time, columns=members)
            ensemble_df = pd.DataFrame(ensemble_members).T
            ensemble_df.index = timestamps
            
            print(f"  {location_name}: Retrieved {len(ensemble_members)} ensemble members")
            
            # Combine ensemble members according to specified method
            if ensemble_method == 'mean':
                # Take the mean across all ensemble members
                result_df = pd.DataFrame({
                    location_name: ensemble_df.mean(axis=1)
                })
            elif ensemble_method == 'median':
                # Take the median across all ensemble members
                result_df = pd.DataFrame({
                    location_name: ensemble_df.median(axis=1)
                })
            elif ensemble_method == 'percentiles':
                # Return 10th, 50th (median), and 90th percentiles
                result_df = pd.DataFrame({
                    f'{location_name}_p10': ensemble_df.quantile(0.10, axis=1),
                    f'{location_name}_p50': ensemble_df.quantile(0.50, axis=1),
                    f'{location_name}_p90': ensemble_df.quantile(0.90, axis=1)
                })
            elif ensemble_method == 'all':
                # Return all individual ensemble members
                result_df = ensemble_df.copy()
                result_df.columns = [f'{location_name}_member_{j}' for j in range(len(ensemble_members))]
            else:
                raise ValueError(f"Unknown ensemble_method: {ensemble_method}")
            
            forecast_dfs.append(result_df)
        
        # Combine all location forecasts into a single DataFrame
        combined_df = pd.concat(forecast_dfs, axis=1)
        
        print(f"\nâœ“ Successfully fetched and processed ENSEMBLE rainfall forecast")
        print(f"  Model: {ensemble_model}")
        print(f"  Method: {ensemble_method}")
        print(f"  Shape: {combined_df.shape}")
        
        return combined_df

    except requests.exceptions.RequestException as e:
        print(f"API request failed: {e}")
        return pd.DataFrame()
    except (KeyError, TypeError) as e:
        print(f"Failed to parse API response: {e}")
        print(f"Response structure: {data.keys() if isinstance(data, dict) else 'Not a dict'}")
        return pd.DataFrame()

print("âœ“ Ensemble rainfall forecast function defined")


In [None]:
# Step 1: Fetch ALL individual rainfall ensemble members
print("="*80)
print("STEP 1: Fetching ALL Rainfall Ensemble Members")
print("="*80)

rainfall_forecast_all_members = get_rainfall_forecast_ensemble(
    station_coordinates, 
    ensemble_method='all',  # Get ALL individual members!
    ensemble_model='ecmwf_ifs025'  # ECMWF has 50 members
)

print(f"\nâœ“ Retrieved rainfall ensemble with shape: {rainfall_forecast_all_members.shape}")
print(f"  Columns: {len(rainfall_forecast_all_members.columns)}")
print(f"  Hours: {len(rainfall_forecast_all_members)}")

# Count how many members per station
station_names = list(station_coordinates.keys())
members_per_station = len([col for col in rainfall_forecast_all_members.columns if col.startswith(f'{station_names[0]}_member_')])
print(f"  Ensemble members per station: {members_per_station}")
print(f"  Total stations: {len(station_names)}")

print("\nSample column names:")
print(rainfall_forecast_all_members.columns[:5].tolist())
print("...")


In [None]:
# Step 2: Generate ensemble predictions (FAST - one prediction per member)
# This uses the new efficient prediction function

print("\n" + "="*80)
print("STEP 2: Generating Ensemble River Flow Predictions")
print("="*80)
print("Each ensemble member has different rainfall â†’ different river prediction")
print("="*80)

# Generate predictions for ALL 50 ensemble members
# This should take ~30-60 seconds total
n_members_to_use = 50

ensemble_predictions_df = predict_ensemble(
    model=best_model,
    scaler=scaler,
    historical_df=isis_df_featureless,
    rainfall_ensemble_df=rainfall_forecast_all_members,
    X_train_columns=X_isis.columns,
    sequence_length=sequence_length,
    horizons=horizons,
    station_names=station_names,
    n_members=n_members_to_use,
    verbose=True
)

# Show sample of the predictions
print("\n" + "="*70)
print("ENSEMBLE PREDICTIONS SAMPLE")
print("="*70)
print(f"\nFirst few hours:")
print(ensemble_predictions_df.head())
print(f"\nLast few hours:")
print(ensemble_predictions_df.tail())


In [None]:
# Step 3: Calculate ensemble statistics from the 20 timeseries
print("\n" + "="*80)
print("STEP 3: Computing Ensemble Statistics")
print("="*80)

# The new ensemble_predictions_df has:
# - Index: hourly timestamps (241 rows: hour 0 to hour 240)
# - Columns: member_0, member_1, ..., member_19

# Calculate statistics across all ensemble members
ensemble_stats = pd.DataFrame({
    'mean': ensemble_predictions_df.mean(axis=1),
    'median': ensemble_predictions_df.median(axis=1),
    'std': ensemble_predictions_df.std(axis=1),
    'p05': ensemble_predictions_df.quantile(0.05, axis=1),
    'p10': ensemble_predictions_df.quantile(0.10, axis=1),
    'p25': ensemble_predictions_df.quantile(0.25, axis=1),
    'p75': ensemble_predictions_df.quantile(0.75, axis=1),
    'p90': ensemble_predictions_df.quantile(0.90, axis=1),
    'p95': ensemble_predictions_df.quantile(0.95, axis=1),
    'min': ensemble_predictions_df.min(axis=1),
    'max': ensemble_predictions_df.max(axis=1),
})

print(f"âœ“ Statistics calculated for {len(ensemble_predictions_df)} hourly timesteps")
print(f"  Statistics available: {list(ensemble_stats.columns)}")
print(f"  Number of ensemble members: {len(ensemble_predictions_df.columns)}")

# Show spread at key horizons
print("\n" + "="*70)
print("ENSEMBLE SPREAD AT KEY HORIZONS")
print("="*70)
print(f"{'Horizon':<12} {'Mean':>10} {'Std Dev':>10} {'IQR':>10} {'Range':>12}")
print("-"*70)

for hours_ahead in [0, 6, 12, 24, 48, 72, 120, 168, 240]:
    if hours_ahead < len(ensemble_stats):
        row = ensemble_stats.iloc[hours_ahead]
        iqr = row['p75'] - row['p25']
        rng = row['max'] - row['min']
        print(f"{hours_ahead}h ahead    {row['mean']:>10.3f}m {row['std']:>10.4f}m {iqr:>10.4f}m {rng:>10.4f}m")

print("="*70)

# Check if ensemble members are actually different
print("\nâœ“ Verification: Are ensemble members different?")
member_means = ensemble_predictions_df.mean(axis=0)
print(f"  Mean prediction by member (should vary if working correctly):")
print(f"  {member_means.values[:5].round(4)} ... {member_means.values[-3:].round(4)}")
print(f"  Range across members: {member_means.max() - member_means.min():.4f}m")


In [None]:
# Step 4: Visualize the 20 Ensemble Predictions (all 240 hours)
print("\n" + "="*80)
print("STEP 4: Visualizing Ensemble Predictions")
print("="*80)

# Configuration: Choose threshold set
# Options: 'fixed' (0.215, 0.33, 0.44, 0.535) or 'historical' (0.1366, 0.2582, 0.387, 0.6047)
THRESHOLD_SET = 'historical'  # Change to 'historical' to use old thresholds

# Define both threshold sets
FIXED_THRESHOLDS = {
    'green': (-float('inf'), 0.215),
    'light_blue': (0.215, 0.33),
    'dark_blue': (0.33, 0.44),
    'amber': (0.44, 0.535),
    'red': (0.535, float('inf'))
}

HISTORICAL_THRESHOLDS = {
    'green': (-float('inf'), 0.1366),
    'light_blue': (0.1366, 0.2582),
    'dark_blue': (0.2582, 0.387),
    'amber': (0.387, 0.6047),
    'red': (0.6047, float('inf'))
}

# Select the threshold set based on configuration
if THRESHOLD_SET == 'fixed':
    FLAG_THRESHOLDS = FIXED_THRESHOLDS
    print(f"Using FIXED thresholds: {[FIXED_THRESHOLDS['light_blue'][0], FIXED_THRESHOLDS['dark_blue'][0], FIXED_THRESHOLDS['amber'][0], FIXED_THRESHOLDS['red'][0]]}")
elif THRESHOLD_SET == 'historical':
    FLAG_THRESHOLDS = HISTORICAL_THRESHOLDS
    print(f"Using HISTORICAL thresholds: {[HISTORICAL_THRESHOLDS['light_blue'][0], HISTORICAL_THRESHOLDS['dark_blue'][0], HISTORICAL_THRESHOLDS['amber'][0], HISTORICAL_THRESHOLDS['red'][0]]}")
else:
    raise ValueError(f"THRESHOLD_SET must be 'fixed' or 'historical', got '{THRESHOLD_SET}'")

FLAG_COLORS = {
    'green': '#008001',
    'light_blue': '#02bfff',
    'dark_blue': '#000080',
    'amber': '#ffa503',
    'red': '#ff0000'
}

# Create a comprehensive visualization
fig, axes = plt.subplots(2, 1, figsize=(18, 12), height_ratios=[2, 1])

# ============= Plot 1: Fan Chart with all ensemble members =============
ax1 = axes[0]

# Make index timezone-naive for cleaner plotting
plot_df = ensemble_predictions_df.copy()
if hasattr(plot_df.index, 'tz') and plot_df.index.tz is not None:
    plot_df.index = plot_df.index.tz_localize(None)
    
plot_stats = ensemble_stats.copy()
if hasattr(plot_stats.index, 'tz') and plot_stats.index.tz is not None:
    plot_stats.index = plot_stats.index.tz_localize(None)

# Add flag color background bands using FLAG_THRESHOLDS
ax1.axhspan(-2, FLAG_THRESHOLDS['light_blue'][0], color=FLAG_COLORS['green'], alpha=0.15, zorder=0, label='Green Flag')
ax1.axhspan(FLAG_THRESHOLDS['light_blue'][0], FLAG_THRESHOLDS['dark_blue'][0], color=FLAG_COLORS['light_blue'], alpha=0.15, zorder=0, label='Light Blue')
ax1.axhspan(FLAG_THRESHOLDS['dark_blue'][0], FLAG_THRESHOLDS['amber'][0], color=FLAG_COLORS['dark_blue'], alpha=0.15, zorder=0, label='Dark Blue')
ax1.axhspan(FLAG_THRESHOLDS['amber'][0], FLAG_THRESHOLDS['red'][0], color=FLAG_COLORS['amber'], alpha=0.15, zorder=0, label='Amber')
ax1.axhspan(FLAG_THRESHOLDS['red'][0], 2, color=FLAG_COLORS['red'], alpha=0.15, zorder=0, label='Red Flag')

# Plot individual ensemble members (thin lines)
for col in plot_df.columns:
    ax1.plot(plot_df.index, plot_df[col], color='steelblue', alpha=0.3, linewidth=0.8, zorder=1)

# Plot uncertainty bands (filled regions)
ax1.fill_between(plot_stats.index, plot_stats['p05'], plot_stats['p95'], 
                  alpha=0.2, color='blue', label='5th-95th percentile', zorder=2)
ax1.fill_between(plot_stats.index, plot_stats['p25'], plot_stats['p75'], 
                  alpha=0.3, color='blue', label='25th-75th percentile (IQR)', zorder=3)

# Plot ensemble mean
ax1.plot(plot_stats.index, plot_stats['mean'], color='darkred', linewidth=2.5, 
         label='Ensemble Mean', zorder=4)

# Formatting
ax1.set_title(f'10-Day River Differential Forecast with Ensemble Uncertainty\n({n_members_to_use} rainfall scenarios)', 
              fontsize=14, fontweight='bold')
ax1.set_xlabel('Date/Time', fontsize=12)
ax1.set_ylabel('River Differential (m)', fontsize=12)
ax1.legend(loc='upper right', fontsize=9)
ax1.grid(True, alpha=0.3)

# Set reasonable y limits
y_min = min(plot_stats['min'].min(), -0.1)
y_max = max(plot_stats['max'].max(), 0.7)
ax1.set_ylim(y_min - 0.05, y_max + 0.05)

# Add vertical lines for day markers
for i in range(1, 11):
    day_time = plot_df.index[0] + pd.Timedelta(hours=24*i)
    if day_time <= plot_df.index[-1]:
        ax1.axvline(x=day_time, color='gray', linestyle='--', alpha=0.3)
        ax1.text(day_time, y_max, f'Day {i}', ha='center', va='bottom', fontsize=9, alpha=0.5)

# ============= Plot 2: Ensemble Spread Over Time =============
ax2 = axes[1]

# Plot standard deviation and IQR over time
ax2.fill_between(plot_stats.index, 0, plot_stats['std'], alpha=0.5, color='coral', label='Std Dev')
ax2.plot(plot_stats.index, plot_stats['p75'] - plot_stats['p25'], color='darkgreen', 
         linewidth=2, label='IQR (P75-P25)')
ax2.plot(plot_stats.index, plot_stats['max'] - plot_stats['min'], color='purple', 
         linewidth=1.5, linestyle='--', label='Full Range (Max-Min)')

ax2.set_title('Forecast Uncertainty Over Time', fontsize=12, fontweight='bold')
ax2.set_xlabel('Date/Time', fontsize=12)
ax2.set_ylabel('Spread (m)', fontsize=12)
ax2.legend(loc='upper right', fontsize=9)
ax2.grid(True, alpha=0.3)
ax2.set_ylim(0, None)

plt.tight_layout()
plt.savefig('ensemble_forecast.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nâœ“ Plot saved as 'ensemble_forecast.png'")

In [None]:
# Step 5: Calculate FLAG PROBABILITIES from ensemble
print("\n" + "="*80)
print("STEP 5: Calculating Flag Probabilities Over Time")
print("="*80)

# Use the same thresholds as FLAG_THRESHOLDS for consistency
# FLAG_BOUNDARIES will use the same threshold set selected in cell 34 (THRESHOLD_SET)
FLAG_BOUNDARIES = FLAG_THRESHOLDS.copy()

def calculate_flag_probabilities_new(ensemble_df, flag_boundaries):
    """
    Calculate the probability of each flag color at each timestep.
    
    Args:
        ensemble_df: DataFrame where columns are ensemble members, rows are hours
        flag_boundaries: Dict of flag name -> (lower, upper) thresholds
    
    Returns:
        DataFrame with probability columns for each flag at each timestep
    """
    n_members = ensemble_df.shape[1]
    probs_df = pd.DataFrame(index=ensemble_df.index)
    
    for flag_name, (lower, upper) in flag_boundaries.items():
        # Count how many members predict each flag at each timestep
        if upper == float('inf'):
            count = (ensemble_df >= lower).sum(axis=1)
        elif lower == -float('inf'):
            count = (ensemble_df < upper).sum(axis=1)
        else:
            count = ((ensemble_df >= lower) & (ensemble_df < upper)).sum(axis=1)
        
        probs_df[flag_name] = count / n_members * 100  # Convert to percentage
    
    return probs_df

# Calculate flag probabilities for all 241 timesteps
flag_probabilities = calculate_flag_probabilities_new(
    ensemble_predictions_df, 
    FLAG_BOUNDARIES
)

print(f"âœ“ Flag probabilities calculated for {len(flag_probabilities)} timesteps")
print(f"  Flags tracked: {list(flag_probabilities.columns)}")

# Show flag probabilities at key horizons
print("\n" + "="*70)
print("FLAG PROBABILITIES AT KEY HORIZONS")
print("="*70)
print(f"{'Hour':<8} {'Green':>8} {'Lt Blue':>10} {'Dk Blue':>10} {'Amber':>8} {'Red':>8}")
print("-"*70)

for hours_ahead in [0, 6, 12, 24, 48, 72, 120, 168, 240]:
    if hours_ahead < len(flag_probabilities):
        row = flag_probabilities.iloc[hours_ahead]
        print(f"{hours_ahead:<8} {row['green']:>7.1f}% {row['light_blue']:>9.1f}% "
              f"{row['dark_blue']:>9.1f}% {row['amber']:>7.1f}% {row['red']:>7.1f}%")

print("="*70)

In [None]:
# Step 6: Visualize Flag Probabilities Over Time
print("\n" + "="*80)
print("STEP 6: Flag Probability Visualization")
print("="*80)

# Create stacked area chart of flag probabilities over time
fig, ax = plt.subplots(figsize=(18, 6))

# Make index timezone-naive for plotting
plot_probs = flag_probabilities.copy()
if hasattr(plot_probs.index, 'tz') and plot_probs.index.tz is not None:
    plot_probs.index = plot_probs.index.tz_localize(None)

# Order flags from low to high risk
flag_order = ['green', 'light_blue', 'dark_blue', 'amber', 'red']
colors = ['#008001', '#02bfff', '#000080', '#ffa503', '#ff0000']

# Create stacked area plot
ax.stackplot(plot_probs.index, 
             [plot_probs[flag] for flag in flag_order],
             labels=['Green', 'Light Blue', 'Dark Blue', 'Amber', 'Red'],
             colors=colors,
             alpha=0.8)

# Formatting
ax.set_title(f'Flag Probability Distribution Over 10-Day Forecast\n(based on {n_members_to_use} rainfall ensemble members)', 
             fontsize=14, fontweight='bold')
ax.set_xlabel('Date/Time', fontsize=12)
ax.set_ylabel('Probability (%)', fontsize=12)
ax.set_ylim(0, 100)
ax.legend(loc='upper right', fontsize=10)
ax.grid(True, alpha=0.3, axis='y')

# Add day markers
for i in range(1, 11):
    day_time = plot_probs.index[0] + pd.Timedelta(hours=24*i)
    if day_time <= plot_probs.index[-1]:
        ax.axvline(x=day_time, color='white', linestyle='-', linewidth=1, alpha=0.5)

plt.tight_layout()
plt.savefig('flag_probabilities.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nâœ“ Plot saved as 'flag_probabilities.png'")


In [None]:
# Step 7: All Ensemble Members Spaghetti Plot with Rainfall
print("\n" + "="*80)
print("STEP 7: Spaghetti Plot - All Ensemble Members with Rainfall")
print("="*80)

fig, ax = plt.subplots(figsize=(20, 12))
ax_rain = ax.twinx()  # Secondary axis for rainfall

# Define flag colors
flag_colors = {
    'green': '#008001',
    'light_blue': '#02bfff',
    'dark_blue': '#000080',
    'amber': '#ffa503',
    'red': '#ff0000'
}

# Add flag boundaries as horizontal filled regions using FLAG_BOUNDARIES
# (FLAG_BOUNDARIES is defined in cell 35 and uses the same threshold set as FLAG_THRESHOLDS)
ax.axhspan(-4, FLAG_BOUNDARIES['light_blue'][0], color=flag_colors['green'], alpha=0.08, zorder=0)
ax.axhspan(FLAG_BOUNDARIES['light_blue'][0], FLAG_BOUNDARIES['dark_blue'][0], color=flag_colors['light_blue'], alpha=0.08, zorder=0)
ax.axhspan(FLAG_BOUNDARIES['dark_blue'][0], FLAG_BOUNDARIES['amber'][0], color=flag_colors['dark_blue'], alpha=0.08, zorder=0)
ax.axhspan(FLAG_BOUNDARIES['amber'][0], FLAG_BOUNDARIES['red'][0], color=flag_colors['amber'], alpha=0.08, zorder=0)
ax.axhspan(FLAG_BOUNDARIES['red'][0], 4, color=flag_colors['red'], alpha=0.08, zorder=0)

# ============= RAINFALL DATA PREPARATION =============
print("Preparing rainfall data...")

# 1. Historical rainfall (last 7 days) - aggregate across all stations
last_7_days = isis_df_featureless.iloc[-24*7:].copy()
if hasattr(last_7_days.index, 'tz') and last_7_days.index.tz is not None:
    last_7_days.index = last_7_days.index.tz_localize(None)

rainfall_cols = [col for col in last_7_days.columns if col != 'differential']
historical_rainfall_hourly = last_7_days[rainfall_cols].mean(axis=1)  # Average across stations
historical_rainfall_daily = historical_rainfall_hourly.resample('1D').sum()  # Daily totals

# 2. Forecast rainfall - calculate ensemble statistics
rainfall_forecast_naive = rainfall_forecast_all_members.copy()
if hasattr(rainfall_forecast_naive.index, 'tz') and rainfall_forecast_naive.index.tz is not None:
    rainfall_forecast_naive.index = rainfall_forecast_naive.index.tz_localize(None)

# For each timestamp, sum across all stations for each member, then get statistics
n_members_rain = 50  # All ECMWF ensemble members
station_names_list = list(station_coordinates.keys())

# Calculate average rainfall per ensemble member (mean across all stations)
member_totals = pd.DataFrame(index=rainfall_forecast_naive.index)
for member_idx in range(n_members_rain):
    member_cols = [f'{station}_member_{member_idx}' for station in station_names_list]
    existing_cols = [col for col in member_cols if col in rainfall_forecast_naive.columns]
    if existing_cols:
        member_totals[f'member_{member_idx}'] = rainfall_forecast_naive[existing_cols].mean(axis=1)

# Resample to daily for cleaner visualization
member_totals_daily = member_totals.resample('1D').sum()

# Calculate rainfall ensemble statistics
forecast_rain_median = member_totals_daily.median(axis=1)
forecast_rain_p10 = member_totals_daily.quantile(0.10, axis=1)
forecast_rain_p90 = member_totals_daily.quantile(0.90, axis=1)

# Error bars: distance from median to percentiles
error_lower = forecast_rain_median - forecast_rain_p10
error_upper = forecast_rain_p90 - forecast_rain_median

print(f"  Historical rainfall: {len(historical_rainfall_daily)} daily bars (avg across stations)")
print(f"  Forecast rainfall: {len(forecast_rain_median)} daily bars with ensemble spread (P10-P90)")

# ============= PLOT RAINFALL BARS =============
bar_width = 0.8  # Width in days (for daily bars)

# Historical rainfall bars (gray)
ax_rain.bar(historical_rainfall_daily.index, historical_rainfall_daily.values,
           width=bar_width, color='gray', alpha=0.4, label='Historical Rainfall', zorder=1)

# Forecast rainfall bars with error bars showing ensemble spread
ax_rain.bar(forecast_rain_median.index, forecast_rain_median.values,
           width=bar_width, color='cornflowerblue', alpha=0.5,
           yerr=[error_lower.values, error_upper.values],
           error_kw={'elinewidth': 1.5, 'capsize': 3, 'capthick': 1, 'alpha': 0.7, 'color': 'navy'},
           label='Forecast Rainfall (median Â± P10-P90)', zorder=2)

# ============= PLOT RIVER DIFFERENTIAL =============
# Plot historical differential as a solid black line
ax.plot(last_7_days.index, last_7_days['differential'].values, 
        color='black', linewidth=3, label='Historical Differential', zorder=100, alpha=0.9)

# Current time marker
current_time = last_7_days.index[-1]
ax.axvline(x=current_time, color='red', linestyle='--', linewidth=2.5, 
          alpha=0.8, label='Now', zorder=101)

# Make ensemble predictions timezone-naive for plotting
plot_ensemble = ensemble_predictions_df.copy()
if hasattr(plot_ensemble.index, 'tz') and plot_ensemble.index.tz is not None:
    plot_ensemble.index = plot_ensemble.index.tz_localize(None)

# Plot ALL ensemble member river forecasts (spaghetti plot)
n_members = len(plot_ensemble.columns)
print(f"Plotting {n_members} ensemble member trajectories...")

for idx, col in enumerate(plot_ensemble.columns):
    if idx == 0:
        ax.plot(plot_ensemble.index, plot_ensemble[col].values, 
                color='steelblue', linewidth=1.2, alpha=0.5, 
                label=f'Ensemble Predictions (n={n_members})', zorder=50)
    else:
        ax.plot(plot_ensemble.index, plot_ensemble[col].values, 
                color='steelblue', linewidth=1.2, alpha=0.5, zorder=50)

# Overlay the ensemble MEAN as a bold line
ax.plot(plot_ensemble.index, ensemble_stats['mean'].values, 
       color='darkviolet', linewidth=3, 
       label='Ensemble Mean', zorder=102, alpha=0.95)

# Also plot the ensemble MEDIAN
ax.plot(plot_ensemble.index, ensemble_stats['median'].values, 
       color='darkgreen', linewidth=2.5, linestyle='--',
       label='Ensemble Median', zorder=103, alpha=0.85)

# ============= FORMATTING =============
ax.set_xlabel('Date/Time', fontsize=14, fontweight='bold')
ax.set_ylabel('River Differential (m)', fontsize=14, fontweight='bold', color='black')
ax_rain.set_ylabel('Rainfall (mm/day, avg across stations)', fontsize=14, fontweight='bold', color='cornflowerblue')

ax.tick_params(axis='y', labelcolor='black')
ax_rain.tick_params(axis='y', labelcolor='cornflowerblue')

# Set y-limits
ax.set_ylim(-0.1, max(0.9, ensemble_stats['max'].max() + 0.1))
max_rain = max(historical_rainfall_daily.max(), forecast_rain_p90.max()) if len(forecast_rain_p90) > 0 else historical_rainfall_daily.max()
ax_rain.set_ylim(0, max_rain * 1.3)

ax.set_title(f'10-Day River Forecast - All {n_members} Ensemble Members\n'
             f'(Different rainfall scenarios produce different river predictions)', 
            fontsize=16, fontweight='bold', pad=20)

ax.grid(True, alpha=0.3, linestyle=':', linewidth=0.8)

# Combined legend
lines1, labels1 = ax.get_legend_handles_labels()
lines2, labels2 = ax_rain.get_legend_handles_labels()
ax.legend(lines1 + lines2, labels1 + labels2, fontsize=10, loc='upper left', framealpha=0.9)

# Add day markers
for i in range(1, 11):
    day_time = current_time + pd.Timedelta(days=i)
    if day_time <= plot_ensemble.index[-1]:
        ax.axvline(x=day_time, color='gray', linestyle=':', alpha=0.3, zorder=0)

plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')
plt.tight_layout()
plt.savefig('spaghetti_with_rainfall.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nâœ“ Spaghetti plot with rainfall complete!")
print(f"  River: {n_members} ensemble forecasts (blue lines) + mean (purple) + median (green dashed)")
print(f"  Rainfall: Historical (gray bars) + Forecast with P10-P90 error bars (blue bars)")
print(f"  Saved as 'spaghetti_with_rainfall.png'")


In [None]:
# Step 8: Total Rainfall vs Final Differential (Scatter Plot)
print("\n" + "="*80)
print("STEP 8: Total Rainfall vs Final River Differential")
print("="*80)

# Calculate total rainfall for each ensemble member over the 10-day forecast
rainfall_forecast_naive = rainfall_forecast_all_members.copy()
station_names_list = list(station_coordinates.keys())
n_members = 50  # All ECMWF ensemble members

total_rainfall_by_member = []
final_differential_by_member = []
member_labels = []

for member_idx in range(n_members):
    # Get total rainfall for this member (sum across all stations and all hours)
    member_cols = [f'{station}_member_{member_idx}' for station in station_names_list]
    existing_cols = [col for col in member_cols if col in rainfall_forecast_naive.columns]
    
    if existing_cols:
        total_rain = rainfall_forecast_naive[existing_cols].sum().sum()  # Sum across stations and time
        total_rainfall_by_member.append(total_rain)
        
        # Get final differential for this member (at hour 240)
        final_diff = ensemble_predictions_df[f'member_{member_idx}'].iloc[-1]
        final_differential_by_member.append(final_diff)
        
        member_labels.append(member_idx)

# Create scatter plot
fig, ax = plt.subplots(figsize=(12, 8))

# Plot each point
scatter = ax.scatter(total_rainfall_by_member, final_differential_by_member, 
                     c=member_labels, cmap='viridis', s=150, alpha=0.8, edgecolors='black', linewidths=1)

# Add member labels to each point
for i, (x, y, label) in enumerate(zip(total_rainfall_by_member, final_differential_by_member, member_labels)):
    ax.annotate(f'{label}', (x, y), textcoords="offset points", xytext=(5, 5), 
                fontsize=9, alpha=0.8)

# Add trend line
z = np.polyfit(total_rainfall_by_member, final_differential_by_member, 1)
p = np.poly1d(z)
x_line = np.linspace(min(total_rainfall_by_member), max(total_rainfall_by_member), 100)
ax.plot(x_line, p(x_line), 'r--', linewidth=2, alpha=0.7, label=f'Trend line')

# Calculate correlation
correlation = np.corrcoef(total_rainfall_by_member, final_differential_by_member)[0, 1]

# Add flag threshold lines using FLAG_BOUNDARIES
ax.axhline(y=FLAG_BOUNDARIES['light_blue'][0], color='green', linestyle=':', alpha=0.5, label='Green/Light Blue threshold')
ax.axhline(y=FLAG_BOUNDARIES['dark_blue'][0], color='cyan', linestyle=':', alpha=0.5, label='Light Blue/Dark Blue threshold')
ax.axhline(y=FLAG_BOUNDARIES['amber'][0], color='blue', linestyle=':', alpha=0.5, label='Dark Blue/Amber threshold')
ax.axhline(y=FLAG_BOUNDARIES['red'][0], color='orange', linestyle=':', alpha=0.5, label='Amber/Red threshold')

# Formatting
ax.set_xlabel('Total 10-Day Rainfall (mm, all stations)', fontsize=14, fontweight='bold')
ax.set_ylabel('Final River Differential at Day 10 (m)', fontsize=14, fontweight='bold')
ax.set_title(f'Rainfall vs River Response: How Different Rainfall Totals Affect Final River Level\n'
             f'Correlation: r = {correlation:.3f}', fontsize=14, fontweight='bold')

ax.grid(True, alpha=0.3)
ax.legend(loc='upper left', fontsize=9)

# Add colorbar for member index
cbar = plt.colorbar(scatter, ax=ax, label='Ensemble Member')

plt.tight_layout()
plt.savefig('rainfall_vs_differential.png', dpi=150, bbox_inches='tight')
plt.show()

# Print summary statistics
print(f"\nâœ“ Scatter plot complete!")
print(f"  Correlation coefficient: r = {correlation:.3f}")
print(f"\nSummary by ensemble member:")
print(f"{'Member':<8} {'Total Rain (mm)':<18} {'Final Diff (m)':<15} {'Flag':<12}")
print("-" * 55)

for member, rain, diff in sorted(zip(member_labels, total_rainfall_by_member, final_differential_by_member), key=lambda x: x[1]):
    if diff >= FLAG_BOUNDARIES['red'][0]:
        flag = 'RED'
    elif diff >= FLAG_BOUNDARIES['amber'][0]:
        flag = 'AMBER'
    elif diff >= FLAG_BOUNDARIES['dark_blue'][0]:
        flag = 'DARK BLUE'
    elif diff >= FLAG_BOUNDARIES['light_blue'][0]:
        flag = 'LIGHT BLUE'
    else:
        flag = 'GREEN'
    print(f"{member:<8} {rain:<18.1f} {diff:<15.3f} {flag:<12}")


In [None]:
# Step 7: Historical what-if forecast using ACTUAL future rainfall on 2020-01-20
# 
# This cell answers: "What would the model have predicted on 2020-01-20 if it had
# perfect knowledge of the future rainfall?"
# It compares that prediction with what actually happened from Jan 10â€“30, 2020.

# 1. Define forecast time (t0)
forecast_time = pd.Timestamp('2020-01-10 00:00:00')
if hasattr(isis_df_featureless.index, 'tz') and isis_df_featureless.index.tz is not None:
    forecast_time = forecast_time.tz_localize(isis_df_featureless.index.tz)

print("\n" + "="*80)
print("STEP 7: Historical what-if forecast on 2020-01-20 using ACTUAL rainfall")
print("="*80)
print(f"Forecast time (t0): {forecast_time}")

# 2. Build historical data up to t0
historical_until_t0 = isis_df_featureless.loc[:forecast_time].copy()

# Require at least 30 days of history for features and sequence_length hours for LSTM
min_history_hours = max(720, sequence_length + 1)
if len(historical_until_t0) < min_history_hours:
    raise ValueError(f"Not enough history before {forecast_time}. "
                     f"Need {min_history_hours} hours, got {len(historical_until_t0)}.")

# 3. Build a "forecast" rainfall DataFrame using ACTUAL future rainfall from Jan 20â€“30, 2020
start_future = forecast_time
end_future = forecast_time + pd.Timedelta(hours=240)  # 10 days

rainfall_station_cols = [c for c in isis_df_featureless.columns if c != 'differential']
actual_future_rainfall = isis_df_featureless.loc[start_future:end_future, rainfall_station_cols].copy()

if actual_future_rainfall.empty:
    raise ValueError("Actual future rainfall data for Jan 20â€“30, 2020 is missing in isis_df_featureless.")

print(f"Using ACTUAL rainfall from {actual_future_rainfall.index.min()} to {actual_future_rainfall.index.max()} for what-if forecast.")

# 4. Run a single-member prediction where the rainfall forecast = actual future rainfall
clairvoyant_pred = predict_single_ensemble_member(
    model=best_model,
    scaler=scaler,
    historical_df=historical_until_t0,
    rainfall_forecast_df=actual_future_rainfall,
    X_train_columns=X_isis.columns,
    sequence_length=sequence_length,
    horizons=horizons,
    verbose=True
)

# 5. Extract actual differential for comparison over Jan 10â€“30, 2020
window_start = pd.Timestamp('2020-01-10 00:00:00')
window_end = pd.Timestamp('2020-01-30 23:00:00')
if hasattr(isis_df_featureless.index, 'tz') and isis_df_featureless.index.tz is not None:
    window_start = window_start.tz_localize(isis_df_featureless.index.tz)
    window_end = window_end.tz_localize(isis_df_featureless.index.tz)

actual_diff_window = isis_df_featureless['differential'].loc[window_start:window_end]

# Align clairvoyant prediction to the same plotting window (predictions start at t0)
clairvoyant_window = clairvoyant_pred.loc[forecast_time:window_end]

# 6. Plot comparison: 10 days before and 10 days after t0
plt.figure(figsize=(14, 6))

# Plot actual history
plt.plot(actual_diff_window.index, actual_diff_window.values,
         label='Actual differential', color='black', linewidth=2)

# Plot clairvoyant prediction (only from t0 onwards)
plt.plot(clairvoyant_window.index, clairvoyant_window.values,
         label='Model prediction (with actual future rainfall)',
         color='tab:blue', linewidth=2)

# Mark forecast time
ymin, ymax = plt.ylim()
plt.axvline(forecast_time, color='red', linestyle='--', linewidth=1.5,
            label='Forecast time (2020-01-20 00:00)')

plt.title('What-if forecast on 2020-01-20 with actual rainfall\nJan 10â€“30, 2020')
plt.xlabel('Time')
plt.ylabel('River differential (m)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()

plt.show()

In [None]:
# Step 8: Historical what-if forecasts for EACH day in a specified month
# 
# For each day in the specified month, this cell asks:
# "What would the model have predicted at 00Z on this day if it had perfect
# knowledge of the future rainfall?"
#
# We:
# 1. Use historical data up to t0 (the day's 00Z)
# 2. Use ACTUAL future rainfall from t0 onward as the "forecast"
# 3. Generate a 10-day clairvoyant prediction for each t0
# 4. Plot all predictions and the actual differential on a single plot

print("\n" + "="*80)
print("STEP 8: Historical what-if forecasts for EACH day in specified month (00Z)")
print("="*80)

# Configuration - CHOOSE YOUR MONTH AND YEAR HERE
year = 2024
month = 10

# Determine number of days in the month
import calendar
num_days = calendar.monthrange(year, month)[1]
month_days = range(1, num_days + 1)

# Timezone handling
if hasattr(isis_df_featureless.index, 'tz') and isis_df_featureless.index.tz is not None:
    tz = isis_df_featureless.index.tz
else:
    tz = None

# History requirement: 30 days of features + LSTM sequence length
min_history_hours = max(720, sequence_length + 1)

# Rainfall station columns (all non-differential columns)
rainfall_station_cols = [c for c in isis_df_featureless.columns if c != 'differential']

clairvoyant_preds = {}
skipped = []

for d in month_days:
    # 1) Define t0 for this day
    t0 = pd.Timestamp(f"{year}-{month:02d}-{d:02d} 00:00:00")
    if tz is not None:
        t0 = t0.tz_localize(tz)

    # Make sure we actually have a data point at t0
    if t0 not in isis_df_featureless.index:
        skipped.append((t0, "no data at t0"))
        continue

    # 2) Historical data up to t0
    hist_until_t0 = isis_df_featureless.loc[:t0].copy()
    if len(hist_until_t0) < min_history_hours:
        skipped.append((t0, f"not enough history ({len(hist_until_t0)} < {min_history_hours})"))
        continue

    # 3) Build ACTUAL future rainfall from t0 to t0+240h (10 days)
    start_future = t0
    end_future = t0 + pd.Timedelta(hours=240)
    actual_future_rainfall = isis_df_featureless.loc[start_future:end_future, rainfall_station_cols].copy()
    if actual_future_rainfall.empty:
        skipped.append((t0, "no future rainfall data"))
        continue

    print(f"\n--- Running what-if forecast for {t0} ---")

    # 4) Run what-if prediction for this t0
    pred = predict_single_ensemble_member(
        model=best_model,
        scaler=scaler,
        historical_df=hist_until_t0,
        rainfall_forecast_df=actual_future_rainfall,
        X_train_columns=X_isis.columns,
        sequence_length=sequence_length,
        horizons=horizons,
        verbose=False
    )

    clairvoyant_preds[t0] = pred

print(f"\nâœ“ Generated {len(clairvoyant_preds)} what-if forecasts")
if skipped:
    print("Skipped t0 values:")
    for t0, reason in skipped:
        print(f"  - {t0}: {reason}")

# Comparison window: full month plus 10 days after last t0
window_start = pd.Timestamp(f'{year}-{month:02d}-01 00:00:00')
# Calculate end of window (10 days after last day of month)
last_day_of_month = pd.Timestamp(f'{year}-{month:02d}-{num_days:02d} 23:00:00')
window_end = last_day_of_month + pd.Timedelta(days=10)
if tz is not None:
    window_start = window_start.tz_localize(tz)
    window_end = window_end.tz_localize(tz)

actual_diff_window = isis_df_featureless['differential'].loc[window_start:window_end]

# Get month name for title
month_name = calendar.month_name[month]

plt.figure(figsize=(16, 7))

# Plot actual differential
plt.plot(actual_diff_window.index, actual_diff_window.values,
         label='Actual differential', color='black', linewidth=2)

# Plot all clairvoyant predictions (one per day)
if clairvoyant_preds:
    n = len(clairvoyant_preds)
    colors = plt.cm.viridis(np.linspace(0, 1, n))

    for (t0, color) in zip(sorted(clairvoyant_preds.keys()), colors):
        series = clairvoyant_preds[t0]
        # Restrict to plotting window
        series_window = series.loc[window_start:window_end]
        plt.plot(series_window.index, series_window.values,
                 color=color, alpha=0.5)

    # Add a single legend entry for all prediction curves
    plt.plot([], [], color='tab:blue', alpha=0.5,
             label=f'Model predictions (one per day, clairvoyant rainfall)')

# Mark each forecast time with a faint vertical line
for t0 in sorted(clairvoyant_preds.keys()):
    plt.axvline(t0, color='red', linestyle=':', alpha=0.2)

plt.title(f'What-if forecasts using ACTUAL rainfall\nOne 10-day forecast per day in {month_name} {year} (00Z starts)')
plt.xlabel('Time')
plt.ylabel('River differential (m)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()

plt.show()