In [1]:
import sys
from pathlib import Path

PROJECT_ROOT = Path.cwd().parent
sys.path.insert(0, str(PROJECT_ROOT))


In [2]:
import pandas as pd
import numpy as np

from sklearn.metrics import root_mean_squared_error
from statsmodels.tsa.ar_model import AutoReg

from src.data.dataloader import load_precipitation
from src.indices.spi import compute_spi
from src.splits.temporal import split_pre_post


In [3]:
with open("../data/curated/gauges.txt") as f:
    gauges = [line.strip() for line in f if line.strip()]

print(f"Number of gauges: {len(gauges)}")


Number of gauges: 18


In [4]:
def persistence_forecast(series):
    return series.shift(1)


In [5]:
def rmse(y_true, y_pred):
    return root_mean_squared_error(y_true, y_pred)


In [6]:
def extreme_rmse(y_true, y_pred, threshold=-1.5):
    mask = (
        (y_true < threshold) &
        y_true.notna() &
        y_pred.notna()
    )
    if mask.sum() < 5:  # guard against tiny samples
        return np.nan
    return rmse(y_true[mask], y_pred[mask])


In [7]:
gauge = gauges[0]  # pick the first gauge only

precip = load_precipitation(gauge)

spi3 = compute_spi(precip, scale=3)
spi6 = compute_spi(precip, scale=6)

for scale, spi in [("SPI-3", spi3), ("SPI-6", spi6)]:
    train, test = split_pre_post(spi, split_year=2000)

    if len(train) < 20 or len(test) < 20:
        print("Too short, skipping")
        continue

    print(scale, "train len:", len(train), "test len:", len(test))


SPI-3 train len: 250 test len: 240
SPI-6 train len: 247 test len: 240


In [8]:
# ---- Persistence only ----
p_train_pred = train.shift(1)
p_test_pred  = test.shift(1)

from sklearn.metrics import root_mean_squared_error

p_train_rmse = root_mean_squared_error(
    train.iloc[1:], p_train_pred.iloc[1:]
)
p_test_rmse = root_mean_squared_error(
    test.iloc[1:], p_test_pred.iloc[1:]
)

print(scale, "Persistence RMSE:", p_test_rmse)


SPI-6 Persistence RMSE: 0.6215794818483302


In [9]:
# ---- Extreme RMSE (Persistence) ----
y_true = test.iloc[1:]["SPI_6"]
y_pred = p_test_pred.iloc[1:]["SPI_6"]

mask = (y_true < -1.5) & y_true.notna() & y_pred.notna()

print("Extreme count:", (y_true < -1.5).sum())
print("Valid extreme count:", mask.sum())

if mask.sum() >= 5:
    rmse_ext = root_mean_squared_error(
        y_true[mask],
        y_pred[mask]
    )
    print(scale, "Persistence extreme RMSE:", rmse_ext)
else:
    print("Too few extremes")


Extreme count: 25
Valid extreme count: 25
SPI-6 Persistence extreme RMSE: 0.7499567711040134


In [10]:
from sklearn.metrics import root_mean_squared_error
from statsmodels.tsa.ar_model import AutoReg
import numpy as np

results = []

for gauge in gauges:
    try:
        # -----------------------------
        # Load data and compute SPI
        # -----------------------------
        precip = load_precipitation(gauge)

        spi3 = compute_spi(precip, scale=3)
        spi6 = compute_spi(precip, scale=6)

        for scale_name, spi in [("SPI_3", spi3), ("SPI_6", spi6)]:

            # -----------------------------
            # Temporal split
            # -----------------------------
            train, test = split_pre_post(spi, split_year=2000)

            # Guard against very short records
            if len(train) < 20 or len(test) < 20:
                continue

            # -----------------------------
            # Persistence baseline
            # -----------------------------
            p_train_pred = train.shift(1)
            p_test_pred  = test.shift(1)

            # Align for RMSE
            p_train_true = train.iloc[1:][scale_name]
            p_train_pred = p_train_pred.iloc[1:][scale_name]

            p_test_true = test.iloc[1:][scale_name]
            p_test_pred = p_test_pred.iloc[1:][scale_name]

            p_train_rmse = root_mean_squared_error(
                p_train_true, p_train_pred
            )
            p_test_rmse = root_mean_squared_error(
                p_test_true, p_test_pred
            )

            # Extreme RMSE (Persistence)
            mask_p = (
                (p_test_true < -1.5) &
                p_test_true.notna() &
                p_test_pred.notna()
            )

            p_ext_rmse = (
                root_mean_squared_error(
                    p_test_true[mask_p],
                    p_test_pred[mask_p],
                )
                if mask_p.sum() >= 5 else np.nan
            )

            # -----------------------------
            # AR(1) model (stationary)
            # -----------------------------
            ar1 = AutoReg(train[scale_name], lags=1, old_names=False).fit()

            # IMPORTANT: integer indexing avoids ambiguity
            ar_test_pred = ar1.predict(
                start=len(train),
                end=len(train) + len(test) - 1,
                dynamic=False
            )

            ar_test_pred.index = test.index[:len(ar_test_pred)]
            ar_test_true = test[scale_name].iloc[:len(ar_test_pred)]

            ar_test_rmse = root_mean_squared_error(
                ar_test_true, ar_test_pred
            )

            # Extreme RMSE (AR1)
            mask_ar = (
                (ar_test_true < -1.5) &
                ar_test_true.notna() &
                ar_test_pred.notna()
            )

            ar_ext_rmse = (
                root_mean_squared_error(
                    ar_test_true[mask_ar],
                    ar_test_pred[mask_ar],
                )
                if mask_ar.sum() >= 5 else np.nan
            )

            # -----------------------------
            # Store results
            # -----------------------------
            results.append({
                "gauge": gauge,
                "scale": scale_name,

                "persistence_test_rmse": p_test_rmse,
                "persistence_extreme_rmse": p_ext_rmse,

                "ar1_test_rmse": ar_test_rmse,
                "ar1_extreme_rmse": ar_ext_rmse,

                "ar1_rmse_inflation": ar_test_rmse / p_test_rmse
            })

    except Exception as e:
        print(f"Gauge {gauge} failed: {e}")


In [11]:
df = pd.DataFrame(results)
df.head()


Unnamed: 0,gauge,scale,persistence_test_rmse,persistence_extreme_rmse,ar1_test_rmse,ar1_extreme_rmse,ar1_rmse_inflation
0,3004,SPI_3,0.616729,0.634928,0.968496,1.713429,1.570375
1,3004,SPI_6,0.621579,0.749957,1.002406,2.044161,1.612676
2,3008,SPI_3,0.671802,1.033673,1.048878,1.990366,1.561291
3,3008,SPI_6,0.621432,0.724292,1.057742,2.118404,1.702104
4,3014,SPI_3,0.667378,0.857072,1.037127,1.847931,1.554032


In [12]:
numeric_cols = [
    "persistence_test_rmse",
    "persistence_extreme_rmse",
    "ar1_test_rmse",
    "ar1_extreme_rmse",
    "ar1_rmse_inflation",
]

df.groupby("scale")[numeric_cols].median()


Unnamed: 0_level_0,persistence_test_rmse,persistence_extreme_rmse,ar1_test_rmse,ar1_extreme_rmse,ar1_rmse_inflation
scale,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
SPI_3,0.66959,0.93089,1.036804,1.924709,1.5567
SPI_6,0.632748,0.782897,1.060737,2.151054,1.666499
