In [1]:
# ==== PyTorch seq2seq with ROBUST preprocessing (median impute + safe standardize + safe MAE) ====
import duckdb, pandas as pd, numpy as np, datetime as dt, math
import torch, torch.nn as nn
from sklearn.metrics import mean_absolute_error

# ---------- run configuration ----------
# Path of the DuckDB file that every query in this cell reads from.
DB = "../ProjectMain/db/data.duckdb"
DELIVERY_DATE = pd.Timestamp("2025-08-18").date()   # change target date if needed
# Encoder look-back: the history window must hold exactly W_PAST rows spanning W_PAST hours.
W_PAST = 672
# Decoder horizon: a delivery day must contribute exactly H_FUT rows.
H_FUT = 24
# Samples whose delivery date is within HOLDOUT_DAYS of the latest sample date form the validation set.
HOLDOUT_DAYS = 60
# GRU hidden size shared by encoder and decoder.
HIDDEN = 128
EPOCHS = 30
# Learning rate handed to the Adam optimizer.
LR = 1e-2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NumPy generator consulted by the decoder for its teacher-forcing draws; torch is seeded separately.
rng = np.random.default_rng(42)
torch.manual_seed(42)

# ---------- helpers ----------
def fit_standardizer(X2d):
    """Return per-feature ``(mean, std)`` for a 2-D array, ignoring NaNs.

    X2d: array of shape (N, F).

    Both statistics are guarded so the result can always be fed to
    ``apply_standardizer``:
      * std that is ~0 (<= 1e-8) or non-finite becomes 1.0
      * mean that is non-finite (e.g. a column that is entirely NaN) becomes 0.0
    """
    import warnings  # local import keeps the helper self-contained

    with warnings.catch_warnings():
        # nanmean/nanstd emit RuntimeWarning on all-NaN columns; those columns
        # are handled explicitly below, so the warning carries no information.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mu = np.asarray(np.nanmean(X2d, axis=0))
        sd = np.asarray(np.nanstd(X2d, axis=0))
    # `sd <= 1e-8` alone is False for NaN, so test finiteness explicitly.
    mu = np.where(np.isfinite(mu), mu, 0.0).astype(mu.dtype, copy=False)
    sd = np.where(np.isfinite(sd) & (sd > 1e-8), sd, 1.0)
    return mu, sd

def apply_standardizer(X, mu, sd):
    """Center ``X`` by ``mu`` and scale by ``sd`` using ordinary broadcasting arithmetic."""
    centered = X - mu
    return centered / sd

# robust 3D imputer
def impute_inplace_3d(X, med_vec):
    """Overwrite NaN cells of ``X`` in place with the matching per-feature value.

    X: (N, T, F) or (N, W, F); med_vec: (F,), broadcast along the last axis.
    Returns None; ``X`` itself is modified.
    """
    nan_cells = np.isnan(X)
    if not nan_cells.any():
        return
    # copyto broadcasts med_vec to X's shape and writes only where the mask is True
    np.copyto(X, med_vec, where=nan_cells)

# sanity checker
def assert_finite(name, arr):
    """Raise RuntimeError, labelled with ``name``, when ``arr`` contains NaN or infinite entries."""
    if np.isfinite(arr).all():
        return
    n_nan = int(np.isnan(arr).sum())
    n_inf = int(np.isinf(arr).sum())
    bad = n_nan + n_inf
    raise RuntimeError(f"{name} has {bad} non-finite values")

# ---------- ensure weather-enhanced views (hist + fcst) exist & helper transforms ----------
# (Re)create the historical-weather view `wx_hist_enh`. Only this view is created
# here; `wx_fcst_enh` and `load_fcst_enh`, which the later query joins against,
# are not created in this cell.
# The view averages the five city temperature/humidity columns, takes their
# max-min spread, and adds a one-row lag difference ("ramp1") ordered by
# (OperatingDTM, Interval).
con_ensure = duckdb.connect(DB)
try:
    # Historical weather enhanced (to backfill when forecast is missing during training)
    con_ensure.execute("""
CREATE OR REPLACE VIEW wx_hist_enh AS
WITH agg AS (
  SELECT
    OperatingDTM, "interval" AS Interval,
    (hist_temp_the_woodlands_tx + hist_temp_katy_tx + hist_temp_friendswood_tx + hist_temp_baytown_tx + hist_temp_houston_tx)/5.0 AS temp_avg,
    (hist_hum_the_woodlands_tx  + hist_hum_katy_tx  + hist_hum_friendswood_tx  + hist_hum_baytown_tx  + hist_hum_houston_tx)/5.0  AS hum_avg,
    GREATEST(hist_temp_the_woodlands_tx, hist_temp_katy_tx, hist_temp_friendswood_tx, hist_temp_baytown_tx, hist_temp_houston_tx)
      - LEAST(hist_temp_the_woodlands_tx, hist_temp_katy_tx, hist_temp_friendswood_tx, hist_temp_baytown_tx, hist_temp_houston_tx) AS temp_spread,
    GREATEST(hist_hum_the_woodlands_tx, hist_hum_katy_tx, hist_hum_friendswood_tx, hist_hum_baytown_tx, hist_hum_houston_tx)
      - LEAST(hist_hum_the_woodlands_tx, hist_hum_katy_tx, hist_hum_friendswood_tx, hist_hum_baytown_tx, hist_hum_houston_tx) AS hum_spread
  FROM vw_historical_weather_by_city
)
SELECT
  a.*,
  a.temp_avg - LAG(a.temp_avg) OVER (ORDER BY OperatingDTM, Interval) AS temp_avg_ramp1,
  a.hum_avg  - LAG(a.hum_avg)  OVER (ORDER BY OperatingDTM, Interval) AS hum_avg_ramp1
FROM agg a;
""")
finally:
    # Release the database handle even when the DDL fails, so a retry of the
    # cell is not left with a dangling open connection.
    con_ensure.close()

# Signed log transforms for spikes
sgn = np.sign
abs_ = np.abs


def sgn_log1p(y):
    """Signed log compression: sign(y) * log1p(|y|)."""
    magnitude = np.log1p(abs_(y))
    return sgn(y) * magnitude


def sgn_expm1(z):
    """Signed exponential expansion: sign(z) * expm1(|z|); undoes ``sgn_log1p``."""
    magnitude = np.expm1(abs_(z))
    return sgn(z) * magnitude

# Hour-of-day weights (emphasize 17-21 local hours)
# Per-interval loss weights; position 0 is H1 and position 23 is H24.
# H1-H15 -> 1.0, H16 -> 1.5, H17-H18 -> 2.0, H19-H20 -> 5.0 (peak focus),
# H21 -> 3.0, H22-H24 -> 1.0
HOUR_WEIGHTS = np.array(
    [1.0] * 15 + [1.5, 2.0, 2.0, 5.0, 5.0, 3.0] + [1.0] * 3,
    dtype=np.float32,
)

# ---------- pull data ----------
# Load the two frames the rest of the cell works from:
#   df_lags   - price history with lag/rolling columns, ordered by ts (encoder side)
#   df_future - one row per (delivery date, interval) with target and decoder features
# The connection is closed in `finally` so a failing query does not leave the
# database handle open.
con = duckdb.connect(DB)
try:
    df_lags = con.execute("""
SELECT
  ts, OperatingDTM, Interval, hb_houston,
  p_lag1,p_lag2,p_lag3,p_lag6,p_lag12,p_lag24,p_lag48,p_lag72,p_lag168,
  dp1,dp24,p_roll24_mean,p_roll24_std,p_roll72_mean,p_roll168_mean
FROM vw_master_spine_lags
ORDER BY ts
""").df()
    df_lags["ts"] = pd.to_datetime(df_lags["ts"])

    df_future = con.execute("""
SELECT
  f.OperatingDTM AS delivery_date,
  f.Interval     AS delivery_interval,
  f.hb_houston   AS target,
  -- base future features
  f.wz_southcentral, f.wz_east, f.wz_west, f.wz_northcentral, f.wz_farwest, f.wz_north, f.wz_southern, f.wz_coast,
  f.lz_north, f.lz_west, f.lz_south, f.lz_houston,
  f.cal_hour AS cal_hour_f, f.cal_dow AS cal_dow_f, f.cal_is_weekend AS cal_is_weekend_f,
  f.cal_sin_hour AS cal_sin_hour_f, f.cal_cos_hour AS cal_cos_hour_f,
  f.cal_sin_dow  AS cal_sin_dow_f,  f.cal_cos_dow  AS cal_cos_dow_f,
  -- load-enhanced joins
  l.net_wz, l.net_lz, l.net_wz_ramp1, l.net_lz_ramp1, l.lz_houston_ramp1, l.wz_spread, l.lz_spread,
  -- weather: COALESCE forecast -> historical
  COALESCE(wf.temp_avg,      wh.temp_avg)      AS temp_avg,
  COALESCE(wf.temp_spread,   wh.temp_spread)   AS temp_spread,
  COALESCE(wf.temp_avg_ramp1,wh.temp_avg_ramp1)AS temp_avg_ramp1,
  COALESCE(wf.hum_avg,       wh.hum_avg)       AS hum_avg,
  COALESCE(wf.hum_spread,    wh.hum_spread)    AS hum_spread,
  COALESCE(wf.hum_avg_ramp1, wh.hum_avg_ramp1) AS hum_avg_ramp1,
  CASE WHEN wf.temp_avg IS NULL THEN 1 ELSE 0 END AS is_weather_proxy,
  -- additional ramps & day-over-day deltas to help with peaks (computed over time order)
  (l.net_wz      - LAG(l.net_wz,      3)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_wz_ramp3,
  (l.net_wz      - LAG(l.net_wz,      6)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_wz_ramp6,
  (l.net_lz      - LAG(l.net_lz,      3)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_lz_ramp3,
  (l.net_lz      - LAG(l.net_lz,      6)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_lz_ramp6,
  (l.lz_houston  - LAG(l.lz_houston,  3)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS lz_houston_ramp3,
  (l.lz_houston  - LAG(l.lz_houston,  6)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS lz_houston_ramp6,
  (l.net_wz      - LAG(l.net_wz,     24)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_wz_dod,
  (l.net_lz      - LAG(l.net_lz,     24)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_lz_dod,
  (l.lz_houston  - LAG(l.lz_houston, 24)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS lz_houston_dod
FROM vw_master_spine_ts f
LEFT JOIN load_fcst_enh l
  ON l.OperatingDTM = f.OperatingDTM AND l.Interval = f.Interval
LEFT JOIN wx_fcst_enh  wf
  ON wf.OperatingDTM = f.OperatingDTM AND wf.Interval = f.Interval
LEFT JOIN wx_hist_enh  wh
  ON wh.OperatingDTM = f.OperatingDTM AND wh.Interval = f.Interval
ORDER BY f.OperatingDTM, f.Interval
""").df()
finally:
    con.close()

# Normalize key types: delivery_date becomes datetime.date objects (compared
# against date values later), delivery_interval becomes int.
df_future["delivery_date"] = pd.to_datetime(df_future["delivery_date"]).dt.date
df_future["delivery_interval"] = df_future["delivery_interval"].astype(int)

# ---------- assemble samples (day-ahead) ----------
# Encoder inputs: columns read from df_lags for every row of the look-back window.
enc_cols = [
    "hb_houston","p_lag1","p_lag2","p_lag3","p_lag6","p_lag12","p_lag24","p_lag48","p_lag72","p_lag168",
    "dp1","dp24","p_roll24_mean","p_roll24_std","p_roll72_mean","p_roll168_mean"
]
# Decoder inputs: columns read from df_future for each interval of the delivery day.
# The order of this list fixes the feature axis of every decoder array built from it.
dec_future_cols = [
    "cal_hour_f","cal_dow_f","cal_is_weekend_f","cal_sin_hour_f","cal_cos_hour_f","cal_sin_dow_f","cal_cos_dow_f",
    "net_wz","net_lz","net_wz_ramp1","net_lz_ramp1","lz_houston_ramp1","wz_spread","lz_spread",
    "temp_avg","temp_spread","temp_avg_ramp1","hum_avg","hum_spread","hum_avg_ramp1",
    # NEW: multi-horizon ramps + day-over-day deltas
    "net_wz_ramp3","net_wz_ramp6","net_lz_ramp3","net_lz_ramp6",
    "lz_houston_ramp3","lz_houston_ramp6","net_wz_dod","net_lz_dod","lz_houston_dod",
    # Identify when weather came from historical backfill instead of forecast
    "is_weather_proxy"
]

# base rows = end-of-day with price
# One candidate sample per day: the interval-24 row with a known price is the
# "base" row, and the sample forecasts the following calendar day.
base_rows = df_lags[(df_lags["hb_houston"].notna()) & (df_lags["Interval"]==24)][["ts","OperatingDTM","hb_houston"]].copy()
base_rows["base_date"] = pd.to_datetime(base_rows["OperatingDTM"]).dt.date
base_rows["delivery_date"] = base_rows["base_date"] + dt.timedelta(days=1)

# only use delivery days with 24 actuals for training/val
has24 = df_future.groupby("delivery_date")["target"].apply(lambda s: s.notna().sum()==24)
valid_delivery_dates = set(has24[has24].index)

df_lags_idx = df_lags.set_index("ts").sort_index()

# Split the future frame by delivery date once, so each sample does a dict
# lookup instead of a boolean scan over the whole frame.
future_by_date = {
    d: g.sort_values("delivery_interval")
    for d, g in df_future.groupby("delivery_date", sort=False)
}

# Each sample is (X_enc, X_dec, y, y0, delivery_date):
#   X_enc (W_PAST, len(enc_cols)), X_dec (H_FUT, len(dec_future_cols)),
#   y (H_FUT,) targets, y0 the base-row price used to seed the decoder.
samples = []
for _, r in base_rows.iterrows():
    ddate = r["delivery_date"]
    ts0 = r["ts"]
    # encoder window
    start = ts0 - pd.Timedelta(hours=W_PAST-1)
    enc_df = df_lags_idx.loc[start:ts0][enc_cols]
    if enc_df.shape[0] != W_PAST:
        continue
    if enc_df.isna().any().any():
        continue  # strict: skip any encoder NaNs
    # decoder future rows
    fdf = future_by_date.get(ddate)
    if fdf is None or fdf.shape[0] != H_FUT:
        continue
    X_enc = enc_df.values.astype(np.float32)
    X_dec = fdf[dec_future_cols].values.astype(np.float32)
    y     = fdf["target"].values.astype(np.float32)
    y0    = np.float32(r["hb_houston"])
    samples.append((X_enc, X_dec, y, y0, ddate))

if not samples:
    raise RuntimeError("No samples assembled; check data coverage or reduce W_PAST.")

# split by date
# Distinct delivery dates over all assembled samples (not yet filtered by valid_delivery_dates).
dates = sorted({s[4] for s in samples})
# Validation covers delivery dates later than (latest sample date - HOLDOUT_DAYS);
# when there are HOLDOUT_DAYS or fewer distinct dates the cut falls back to the earliest date.
cut_date = dates[-1] - dt.timedelta(days=HOLDOUT_DAYS) if len(dates)>HOLDOUT_DAYS else dates[0]
# Only delivery days listed in valid_delivery_dates are admitted to either split.
train_idx = [i for i,s in enumerate(samples) if s[4] <= cut_date and s[4] in valid_delivery_dates]
val_idx   = [i for i,s in enumerate(samples) if s[4] >  cut_date and s[4] in valid_delivery_dates]

def stack(idxs):
    """Batch the samples at positions ``idxs`` along a new leading axis.

    Reads the module-level ``samples`` list and returns ``(Xe, Xd, Y, Y0)``,
    built from tuple fields 0..3 of each selected sample.
    """
    Xe, Xd, Y, Y0 = (
        np.stack([samples[i][field] for i in idxs], axis=0)
        for field in range(4)
    )
    return Xe, Xd, Y, Y0

# Materialize the training arrays; the validation arrays stay None when no
# sample falls after the cut date, and later code checks `Xe_va is not None`.
Xe_tr, Xd_tr, Y_tr, Y0_tr = stack(train_idx)
Xe_va, Xd_va, Y_va, Y0_va = (None, None, None, None)
if len(val_idx) > 0:
    Xe_va, Xd_va, Y_va, Y0_va = stack(val_idx)

# ---------- diagnose + drop all-NaN decoder features, then impute (robust) ----------
# Flatten decoder train to 2D to inspect per-feature missingness
# F_orig = number of decoder features before any are dropped.
F_orig = Xd_tr.shape[2]
Xd_tr_2d_raw = Xd_tr.reshape(-1, F_orig)
nan_counts = np.isnan(Xd_tr_2d_raw).sum(axis=0)
# A feature is flagged only when every (sample, interval) cell in TRAIN is NaN.
all_nan_cols = nan_counts == Xd_tr_2d_raw.shape[0]

if all_nan_cols.any():
    dropped_cols = [dec_future_cols[i] for i,flag in enumerate(all_nan_cols) if flag]
    print(f"Dropping {int(all_nan_cols.sum())} decoder feature(s) with all-NaN in TRAIN: {dropped_cols}")
else:
    dropped_cols = []

# Keep mask to be reused at prediction time
dec_keep_mask = ~all_nan_cols

# Apply mask to decoder tensors
Xd_tr = Xd_tr[:, :, dec_keep_mask]
if Xe_va is not None:
    Xd_va = Xd_va[:, :, dec_keep_mask]

# Recompute medians on KEPT features
dec_med = np.nanmedian(Xd_tr.reshape(-1, Xd_tr.shape[2]), axis=0)
# Guard: if a median is still NaN for some reason, set to 0
dec_med = np.where(np.isnan(dec_med), 0.0, dec_med)

# Encoder median (rare but safe)
enc_med = np.nanmedian(Xe_tr.reshape(-1, Xe_tr.shape[2]), axis=0)
enc_med = np.where(np.isnan(enc_med), 0.0, enc_med)

# Impute in-place (3D-safe)
# Validation arrays are filled with the TRAIN medians, never their own.
impute_inplace_3d(Xe_tr, enc_med)
impute_inplace_3d(Xd_tr, dec_med)
if Xe_va is not None:
    impute_inplace_3d(Xe_va, enc_med)
    impute_inplace_3d(Xd_va, dec_med)

# Targets should be complete for training/val; assert
# NOTE(review): plain `assert` is skipped under `python -O`, unlike assert_finite.
assert np.isfinite(Y_tr).all(), "Y_tr has non-finite values"
if Y_va is not None:
    assert np.isfinite(Y_va).all(), "Y_va has non-finite values"

# ---------- safe standardization ----------
# All scaling statistics are fitted on TRAIN only and reused for validation.
enc_mu, enc_sd = fit_standardizer(Xe_tr.reshape(-1, Xe_tr.shape[2]))
dec_mu, dec_sd = fit_standardizer(Xd_tr.reshape(-1, Xd_tr.shape[2]))
# ---- target transform (spike-aware) then standardize ----
# Targets go through sgn_log1p first; a single scalar mean/std is then fitted
# over all transformed training targets.
Y_tr_t  = sgn_log1p(Y_tr)
y_mu, y_sd = fit_standardizer(Y_tr_t.reshape(-1,1))
y_mu = float(np.atleast_1d(y_mu)[0]); y_sd = float(np.atleast_1d(y_sd)[0])

Xe_tr_n = apply_standardizer(Xe_tr, enc_mu, enc_sd)
Xd_tr_n = apply_standardizer(Xd_tr, dec_mu, dec_sd)
Y_tr_n  = apply_standardizer(Y_tr_t, y_mu, y_sd).reshape(Y_tr.shape)
# The decoder seed value y0 is put on the same scale as the targets.
Y0_tr_t = sgn_log1p(Y0_tr.reshape(-1,1)).reshape(-1,1)
Y0_tr_n = apply_standardizer(Y0_tr_t, y_mu, y_sd).reshape(-1)

if Xe_va is not None:
    Xe_va_n = apply_standardizer(Xe_va, enc_mu, enc_sd)
    Xd_va_n = apply_standardizer(Xd_va, dec_mu, dec_sd)
    Y_va_t  = sgn_log1p(Y_va)
    Y_va_n  = apply_standardizer(Y_va_t,  y_mu,  y_sd).reshape(Y_va.shape)
    Y0_va_t = sgn_log1p(Y0_va.reshape(-1,1))
    Y0_va_n = apply_standardizer(Y0_va_t, y_mu, y_sd).reshape(-1)


# Final safety: ensure everything is finite
# assert_finite raises RuntimeError naming the first offending array.
assert_finite("Xe_tr_n", Xe_tr_n)
assert_finite("Xd_tr_n", Xd_tr_n)
assert_finite("Y_tr_n",  Y_tr_n)
assert_finite("Y0_tr_n", Y0_tr_n)
if Xe_va is not None:
    assert_finite("Xe_va_n", Xe_va_n)
    assert_finite("Xd_va_n", Xd_va_n)
    assert_finite("Y_va_n",  Y_va_n)
    assert_finite("Y0_va_n", Y0_va_n)

# ---------- torch datasets ----------
class Seq2SeqDataset(torch.utils.data.Dataset):
    """In-memory dataset yielding (Xe, Xd, Y, Y0) float32 tensors per sample index."""

    def __init__(self, Xe, Xd, Y, Y0):
        # Convert each NumPy array to a tensor and cast to float32.
        self.Xe, self.Xd, self.Y, self.Y0 = (
            torch.from_numpy(arr).float() for arr in (Xe, Xd, Y, Y0)
        )

    def __len__(self):
        # Sample count is the leading dimension of the encoder tensor.
        return self.Xe.shape[0]

    def __getitem__(self, i):
        return self.Xe[i], self.Xd[i], self.Y[i], self.Y0[i]

# Mini-batch size for both loaders.
bs = 32
# Training batches are reshuffled every epoch; validation order is fixed.
train_loader = torch.utils.data.DataLoader(Seq2SeqDataset(Xe_tr_n, Xd_tr_n, Y_tr_n, Y0_tr_n), batch_size=bs, shuffle=True)
# val_loader is None when there is no validation split; eval code checks for that.
val_loader = None if Xe_va is None else torch.utils.data.DataLoader(Seq2SeqDataset(Xe_va_n, Xd_va_n, Y_va_n, Y0_va_n), batch_size=bs, shuffle=False)

# ---------- model ----------
class Encoder(nn.Module):
    """Single-layer GRU that reduces the history window to its final hidden state."""

    def __init__(self, in_dim, hidden):
        super().__init__()
        self.rnn = nn.GRU(input_size=in_dim, hidden_size=hidden, batch_first=True)

    def forward(self, x):
        """Map x of shape (B, W, E) to the GRU's last hidden state, shape (1, B, H)."""
        final_hidden = self.rnn(x)[1]  # index 0 holds per-step outputs, unused here
        return final_hidden

class Decoder(nn.Module):
    """Autoregressive GRU decoder: one step per future interval, fed the previous value."""

    def __init__(self, in_dim, hidden):
        super().__init__()
        # Input at each step = future features concatenated with the previous y (hence +1).
        self.rnn = nn.GRU(in_dim + 1, hidden, batch_first=True)
        self.out = nn.Linear(hidden, 1)

    def forward(self, future_feats, h0, y0, teacher=None, tf_prob=0.5):
        """Unroll over the T steps of ``future_feats`` (B, T, D) and return (B, T).

        ``h0`` is the initial hidden state and ``y0`` the value fed at step 0.
        When ``teacher`` is given, each step independently feeds the teacher
        value forward with probability ``tf_prob`` (drawn from the module-level
        ``rng``); otherwise the model's own output is fed forward.
        """
        B, T, _ = future_feats.shape
        prev = y0.view(B, 1)
        hidden = h0
        columns = []
        for step in range(T):
            step_in = torch.cat([future_feats[:, step, :], prev], dim=1).unsqueeze(1)  # (B,1,D+1)
            rnn_out, hidden = self.rnn(step_in, hidden)
            pred = self.out(rnn_out[:, -1, :])  # (B,1)
            columns.append(pred)
            # The random draw happens only when a teacher sequence was supplied.
            use_teacher = (teacher is not None) and (rng.random() < tf_prob)
            prev = teacher[:, step].view(B, 1) if use_teacher else pred
        return torch.cat(columns, dim=1)  # (B,T)

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper: the encoder's final state initializes the decoder."""

    def __init__(self, enc_in, dec_in, hidden):
        super().__init__()
        self.enc = Encoder(enc_in, hidden)
        self.dec = Decoder(dec_in, hidden)

    def forward(self, x_enc, x_dec, y0, y_teacher=None, tf_prob=0.5):
        """Encode ``x_enc``, then decode over ``x_dec`` starting from ``y0``."""
        context = self.enc(x_enc)
        return self.dec(x_dec, context, y0, teacher=y_teacher, tf_prob=tf_prob)

# Feature counts come from the standardized training arrays (decoder count is after dropping all-NaN columns).
enc_in = Xe_tr_n.shape[2]
dec_in = Xd_tr_n.shape[2]
model = Seq2Seq(enc_in, dec_in, HIDDEN).to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LR)
# NOTE(review): loss_fn is not referenced by train_epoch/eval_epoch in this cell; they call huber_weighted.
loss_fn = nn.HuberLoss()

def huber_weighted(yhat, y, w, delta=1.0):
    """Mean of the element-wise Huber loss between ``yhat`` and ``y``, scaled by ``w``.

    The loss is quadratic for absolute errors up to ``delta`` and linear beyond.
    """
    abs_err = (yhat - y).abs()
    capped = abs_err.clamp(max=delta)          # portion of the error inside the quadratic zone
    quadratic = 0.5 * capped**2
    linear = delta * (abs_err - capped)        # zero whenever abs_err <= delta
    return ((quadratic + linear) * w).mean()

def train_epoch():
    """Run one optimization pass over ``train_loader``; return the sample-weighted mean loss.

    Uses the module-level ``model``, ``opt``, ``DEVICE`` and ``HOUR_WEIGHTS``.
    Teacher forcing is enabled with probability 0.5 and gradients are clipped
    to a norm of 1.0 before each optimizer step.
    """
    model.train()
    running = 0.0
    for batch in train_loader:
        xenc, xdec, y, y0 = (t.to(DEVICE) for t in batch)
        opt.zero_grad()
        yhat = model(xenc, xdec, y0, y_teacher=y, tf_prob=0.5)
        # time-step weights (B,24)
        hour_w = torch.tensor(HOUR_WEIGHTS, device=y.device).unsqueeze(0).expand_as(y)
        batch_loss = huber_weighted(yhat, y, hour_w)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        # Weight by batch size so the last (possibly smaller) batch is not over-counted.
        running += batch_loss.item() * xenc.size(0)
    return running / len(train_loader.dataset)

@torch.no_grad()
def eval_epoch():
    """Return the sample-weighted mean validation loss, or None when there is no ``val_loader``.

    Runs the model without teacher forcing, so every decoder step consumes the
    model's own previous output.
    """
    if val_loader is None:
        return None
    model.eval()
    running = 0.0
    for batch in val_loader:
        xenc, xdec, y, y0 = (t.to(DEVICE) for t in batch)
        yhat = model(xenc, xdec, y0, y_teacher=None, tf_prob=0.0)
        hour_w = torch.tensor(HOUR_WEIGHTS, device=y.device).unsqueeze(0).expand_as(y)
        batch_loss = huber_weighted(yhat, y, hour_w)
        running += batch_loss.item() * xenc.size(0)
    return running / len(val_loader.dataset)

# Train for EPOCHS passes; val_loss is appended to the log line only when a validation loader exists.
# No checkpoint is kept here: the code after this loop uses the model as the final epoch leaves it.
for ep in range(1, EPOCHS+1):
    tr = train_epoch()
    va = eval_epoch()
    print(f"Epoch {ep:02d}  train_loss={tr:.4f}" + ("" if va is None else f"  val_loss={va:.4f}"))

# ---- evaluate holdout MAE in real units (SAFE) ----
# Score the validation split in original target units: undo the standardization,
# then undo the signed-log transform, for both predictions and targets.
if val_loader is not None:
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for xenc, xdec, y, y0 in val_loader:
            xenc, xdec, y, y0 = xenc.to(DEVICE), xdec.to(DEVICE), y.to(DEVICE), y0.to(DEVICE)
            yhat_n = model(xenc, xdec, y0, y_teacher=None, tf_prob=0.0).detach().cpu().numpy()
            y_n    = y.detach().cpu().numpy()
            # invert scaling + inverse spike transform
            yhat_t = (yhat_n * y_sd + y_mu)
            yt_t   = (y_n    * y_sd + y_mu)
            yhat = sgn_expm1(yhat_t).reshape(-1)
            yt   = sgn_expm1(yt_t).reshape(-1)
            preds.append(yhat)
            trues.append(yt)
    preds = np.concatenate(preds)
    trues = np.concatenate(trues)
    # Pairs with a non-finite prediction or target are excluded from the MAE and reported.
    mask = np.isfinite(preds) & np.isfinite(trues)
    dropped = int((~mask).sum())
    if dropped:
        print(f"Dropped {dropped} non-finite pairs from holdout scoring.")
    print({"holdout_MAE": float(mean_absolute_error(trues[mask], preds[mask]))})

# ---- optional LightGBM residual booster (train on TRAIN residuals) ----
# Optional second stage: a LightGBM regressor fitted on (true - seq2seq prediction)
# over the TRAIN split, using the imputed (unstandardized) decoder features as inputs.
# The whole stage is best-effort: any exception leaves `booster` as None.
USE_LGBM_RESIDUAL = True
booster = None
if USE_LGBM_RESIDUAL:
    try:
        import lightgbm as lgb
        # Build a non-shuffled loader to preserve order
        train_loader_eval = torch.utils.data.DataLoader(
            Seq2SeqDataset(Xe_tr_n, Xd_tr_n, Y_tr_n, Y0_tr_n), batch_size=64, shuffle=False)
        base_preds = []
        with torch.no_grad():
            model.eval()
            for xenc, xdec, y, y0 in train_loader_eval:
                xenc, xdec, y, y0 = xenc.to(DEVICE), xdec.to(DEVICE), y.to(DEVICE), y0.to(DEVICE)
                yhat_n = model(xenc, xdec, y0, y_teacher=None, tf_prob=0.0).detach().cpu().numpy()
                base_preds.append(yhat_n)
        base_preds = np.concatenate(base_preds, axis=0)  # (N_train, 24)
        # inverse transforms -> true $/MWh
        base_preds_t = (base_preds * y_sd + y_mu)
        base_preds_true = sgn_expm1(base_preds_t)
        y_true_tr = Y_tr  # already true units
        # flatten
        # Row k of X_flat lines up with element k of the flattened predictions/targets
        # because all three are reshaped from arrays with the same (sample, interval) leading axes.
        y_base_flat = base_preds_true.reshape(-1)
        y_true_flat = y_true_tr.reshape(-1)
        X_flat = Xd_tr.reshape(-1, Xd_tr.shape[2])  # kept raw decoder features
        mask = np.isfinite(y_base_flat) & np.isfinite(y_true_flat) & np.isfinite(X_flat).all(axis=1)
        X_flat = X_flat[mask]
        y_resid_flat = (y_true_flat - y_base_flat)[mask]
        # Fit booster
        booster = lgb.LGBMRegressor(
            n_estimators=600, learning_rate=0.03, subsample=0.8, colsample_bytree=0.8,
            max_depth=-1, reg_alpha=0.0, reg_lambda=0.0, min_child_samples=20, random_state=42)
        booster.fit(X_flat, y_resid_flat)
        print("Residual booster trained (LightGBM).")
    except Exception as e:
        # Deliberately broad: a missing lightgbm install or a fit failure only skips this stage.
        print(f"LightGBM residual booster skipped: {e}")

# ---- predict requested DELIVERY_DATE ----

# find base row (previous day EOD with price)
# The base row is the interval-24 row of the day before DELIVERY_DATE with a known price.
prev_day = DELIVERY_DATE - dt.timedelta(days=1)
base_row = df_lags[(pd.to_datetime(df_lags["OperatingDTM"]).dt.date == prev_day) & (df_lags["Interval"]==24) & df_lags["hb_houston"].notna()]
if base_row.empty:
    raise RuntimeError(f"No base EOD price for {prev_day}")
ts0 = pd.to_datetime(base_row.iloc[0]["ts"])
y0  = np.float32(base_row.iloc[0]["hb_houston"])

# encoder window
# Unlike training, an incomplete or NaN-containing window is an error here, not a skip.
win = df_lags.set_index("ts").loc[ts0 - pd.Timedelta(hours=W_PAST-1): ts0][enc_cols]
if win.shape[0] != W_PAST or win.isna().any().any():
    raise RuntimeError("Insufficient/NaN history for encoder window.")
Xe_pred = win.values.astype(np.float32)[None, ...]  # leading axis of size 1 = batch of one

# decoder future features (24 rows)
frows = df_future[(df_future["delivery_date"]==DELIVERY_DATE)].sort_values("delivery_interval")
if frows.shape[0] != H_FUT:
    raise RuntimeError("Spine missing 24 rows for delivery date.")
Xd_pred_full = frows[dec_future_cols].values.astype(np.float32)[None, ...]

# Apply the same keep mask used in training (drop columns that were all-NaN in train)
if 'dec_keep_mask' in globals():
    Xd_pred = Xd_pred_full[:, :, dec_keep_mask]
else:
    Xd_pred = Xd_pred_full

# impute + standardize pred using TRAIN stats
m = np.isnan(Xd_pred)
if m.any():
    Xd_pred[m] = np.broadcast_to(dec_med, Xd_pred.shape)[m]
Xe_pred_n = apply_standardizer(Xe_pred, enc_mu, enc_sd)
Xd_pred_n = apply_standardizer(Xd_pred, dec_mu, dec_sd)
# compute normalized scalar y0_n using spike-aware transform
y0_t = sgn_log1p(y0)
y0_n = np.float32(((y0_t - y_mu) / y_sd)).ravel()[0]

# Run the trained model on the single prediction sample without teacher forcing.
model.eval()
with torch.no_grad():
    # Cast the inputs to float32 explicitly, as Seq2SeqDataset does for training
    # batches: torch.from_numpy keeps the NumPy dtype, and standardizing with
    # float64 statistics would otherwise hand the float32 model a float64 tensor.
    yhat_n = model(torch.from_numpy(Xe_pred_n).float().to(DEVICE),
                   torch.from_numpy(Xd_pred_n).float().to(DEVICE),
                   torch.tensor([y0_n], dtype=torch.float32, device=DEVICE),
                   y_teacher=None, tf_prob=0.0).detach().cpu().numpy()[0]
# Undo the target standardization, then the signed-log transform.
yhat_t = (yhat_n * y_sd + y_mu).reshape(-1)
yhat = sgn_expm1(yhat_t)
# --- optional: add residual booster only on peak hours (H18–H21) ---
if booster is not None:
    # The booster was fitted on imputed, unstandardized decoder features, so it
    # is fed Xd_pred (not Xd_pred_n). Sizes follow H_FUT rather than a literal 24,
    # matching the row-count check and the forecast frame.
    Xd_pred_kept = Xd_pred.copy()  # (1,H_FUT,F_kept)
    resid_pred = booster.predict(Xd_pred_kept.reshape(H_FUT, Xd_pred_kept.shape[2]))
    hour_mask = np.zeros(H_FUT, dtype=np.float32)
    hour_mask[17:21] = 1.0  # apply only to intervals 18,19,20,21
    yhat = yhat + resid_pred * hour_mask

# replace any residual NaN with the day's median
# NaN predictions are replaced by the median of the remaining (non-NaN) hours of the same day.
# NOTE(review): if every hour is NaN the median is NaN too, so nothing is repaired — confirm acceptable.
if np.isnan(yhat).any():
    day_med = float(np.nanmedian(yhat))
    yhat = np.nan_to_num(yhat, nan=day_med)

# One row per interval 1..H_FUT for the requested delivery date.
forecast = pd.DataFrame({
    "OperatingDTM": [DELIVERY_DATE]*H_FUT,
    "Interval": list(range(1, H_FUT+1)),
    "hb_houston_pred": yhat.tolist()
})
forecast


Epoch 01  train_loss=0.3712  val_loss=0.2919
Epoch 02  train_loss=0.2066  val_loss=0.3048
Epoch 03  train_loss=0.1673  val_loss=0.2532
Epoch 04  train_loss=0.1640  val_loss=0.3110
Epoch 05  train_loss=0.1563  val_loss=0.2003
Epoch 06  train_loss=0.1477  val_loss=0.3941
Epoch 07  train_loss=0.1426  val_loss=1.0057
Epoch 08  train_loss=0.1434  val_loss=0.2375
Epoch 09  train_loss=0.1211  val_loss=0.3437
Epoch 10  train_loss=0.1160  val_loss=0.2721
Epoch 11  train_loss=0.1147  val_loss=0.1732
Epoch 12  train_loss=0.1046  val_loss=1.3323
Epoch 13  train_loss=0.1142  val_loss=0.2677
Epoch 14  train_loss=0.1034  val_loss=0.1915
Epoch 15  train_loss=0.1139  val_loss=0.1490
Epoch 16  train_loss=0.1015  val_loss=0.2421
Epoch 17  train_loss=0.1008  val_loss=0.2169
Epoch 18  train_loss=0.0946  val_loss=0.1399
Epoch 19  train_loss=0.0827  val_loss=0.2168
Epoch 20  train_loss=0.0862  val_loss=0.1368
Epoch 21  train_loss=0.0849  val_loss=0.2953
Epoch 22  train_loss=0.0754  val_loss=0.1221
Epoch 23  



Unnamed: 0,OperatingDTM,Interval,hb_houston_pred
0,2025-08-18,1,31.504337
1,2025-08-18,2,28.639776
2,2025-08-18,3,27.46483
3,2025-08-18,4,27.254873
4,2025-08-18,5,26.091288
5,2025-08-18,6,24.651674
6,2025-08-18,7,26.005545
7,2025-08-18,8,25.442276
8,2025-08-18,9,21.579998
9,2025-08-18,10,18.756807


In [2]:
# ==== PyTorch seq2seq with ROBUST preprocessing (median impute + safe standardize + safe MAE) ====
import duckdb, pandas as pd, numpy as np, datetime as dt, math
import torch, torch.nn as nn
from sklearn.metrics import mean_absolute_error

# ---------- run configuration ----------
# Path of the DuckDB file that every query in this cell reads from.
DB = "../ProjectMain/db/data.duckdb"
DELIVERY_DATE = pd.Timestamp("2025-08-18").date()   # change target date if needed
# Encoder look-back: the history window must hold exactly W_PAST rows spanning W_PAST hours.
W_PAST = 168
# Decoder horizon: a delivery day must contribute exactly H_FUT rows.
H_FUT = 24
# Samples whose delivery date is within HOLDOUT_DAYS of the latest sample date form the validation set.
HOLDOUT_DAYS = 60
HIDDEN = 128
EPOCHS = 30
LR = 1e-2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NumPy generator created with a fixed seed; torch is seeded separately.
rng = np.random.default_rng(42)
torch.manual_seed(42)

# ---------- helpers ----------
def fit_standardizer(X2d):
    """Return per-feature ``(mean, std)`` for a 2-D array, ignoring NaNs.

    X2d: array of shape (N, F).

    Both statistics are guarded so the result can always be fed to
    ``apply_standardizer``:
      * std that is ~0 (<= 1e-8) or non-finite becomes 1.0
      * mean that is non-finite (e.g. a column that is entirely NaN) becomes 0.0
    """
    import warnings  # local import keeps the helper self-contained

    with warnings.catch_warnings():
        # nanmean/nanstd emit RuntimeWarning on all-NaN columns; those columns
        # are handled explicitly below, so the warning carries no information.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mu = np.asarray(np.nanmean(X2d, axis=0))
        sd = np.asarray(np.nanstd(X2d, axis=0))
    # `sd <= 1e-8` alone is False for NaN, so test finiteness explicitly.
    mu = np.where(np.isfinite(mu), mu, 0.0).astype(mu.dtype, copy=False)
    sd = np.where(np.isfinite(sd) & (sd > 1e-8), sd, 1.0)
    return mu, sd

def apply_standardizer(X, mu, sd):
    """Center ``X`` by ``mu`` and scale by ``sd`` using ordinary broadcasting arithmetic."""
    centered = X - mu
    return centered / sd

# robust 3D imputer
def impute_inplace_3d(X, med_vec):
    """Overwrite NaN cells of ``X`` in place with the matching per-feature value.

    X: (N, T, F) or (N, W, F); med_vec: (F,), broadcast along the last axis.
    Returns None; ``X`` itself is modified.
    """
    nan_cells = np.isnan(X)
    if not nan_cells.any():
        return
    # copyto broadcasts med_vec to X's shape and writes only where the mask is True
    np.copyto(X, med_vec, where=nan_cells)

# sanity checker
def assert_finite(name, arr):
    """Raise RuntimeError, labelled with ``name``, when ``arr`` contains NaN or infinite entries."""
    if np.isfinite(arr).all():
        return
    n_nan = int(np.isnan(arr).sum())
    n_inf = int(np.isinf(arr).sum())
    bad = n_nan + n_inf
    raise RuntimeError(f"{name} has {bad} non-finite values")

# ---------- ensure weather-enhanced views (hist + fcst) exist & helper transforms ----------
# (Re)create the historical-weather view `wx_hist_enh`. Only this view is created
# here; `wx_fcst_enh` and `load_fcst_enh`, which the later query joins against,
# are not created in this cell.
# The view averages the five city temperature/humidity columns, takes their
# max-min spread, and adds a one-row lag difference ("ramp1") ordered by
# (OperatingDTM, Interval).
con_ensure = duckdb.connect(DB)
try:
    # Historical weather enhanced (to backfill when forecast is missing during training)
    con_ensure.execute("""
CREATE OR REPLACE VIEW wx_hist_enh AS
WITH agg AS (
  SELECT
    OperatingDTM, "interval" AS Interval,
    (hist_temp_the_woodlands_tx + hist_temp_katy_tx + hist_temp_friendswood_tx + hist_temp_baytown_tx + hist_temp_houston_tx)/5.0 AS temp_avg,
    (hist_hum_the_woodlands_tx  + hist_hum_katy_tx  + hist_hum_friendswood_tx  + hist_hum_baytown_tx  + hist_hum_houston_tx)/5.0  AS hum_avg,
    GREATEST(hist_temp_the_woodlands_tx, hist_temp_katy_tx, hist_temp_friendswood_tx, hist_temp_baytown_tx, hist_temp_houston_tx)
      - LEAST(hist_temp_the_woodlands_tx, hist_temp_katy_tx, hist_temp_friendswood_tx, hist_temp_baytown_tx, hist_temp_houston_tx) AS temp_spread,
    GREATEST(hist_hum_the_woodlands_tx, hist_hum_katy_tx, hist_hum_friendswood_tx, hist_hum_baytown_tx, hist_hum_houston_tx)
      - LEAST(hist_hum_the_woodlands_tx, hist_hum_katy_tx, hist_hum_friendswood_tx, hist_hum_baytown_tx, hist_hum_houston_tx) AS hum_spread
  FROM vw_historical_weather_by_city
)
SELECT
  a.*,
  a.temp_avg - LAG(a.temp_avg) OVER (ORDER BY OperatingDTM, Interval) AS temp_avg_ramp1,
  a.hum_avg  - LAG(a.hum_avg)  OVER (ORDER BY OperatingDTM, Interval) AS hum_avg_ramp1
FROM agg a;
""")
finally:
    # Release the database handle even when the DDL fails, so a retry of the
    # cell is not left with a dangling open connection.
    con_ensure.close()

# Signed log transforms for spikes
sgn = np.sign
abs_ = np.abs


def sgn_log1p(y):
    """Signed log compression: sign(y) * log1p(|y|)."""
    magnitude = np.log1p(abs_(y))
    return sgn(y) * magnitude


def sgn_expm1(z):
    """Signed exponential expansion: sign(z) * expm1(|z|); undoes ``sgn_log1p``."""
    magnitude = np.expm1(abs_(z))
    return sgn(z) * magnitude

# Hour-of-day weights (emphasize 17-21 local hours)
# Per-interval loss weights; position 0 is H1 and position 23 is H24.
# H1-H15 -> 1.0, H16 -> 1.5, H17-H18 -> 2.0, H19-H20 -> 5.0 (peak focus),
# H21 -> 3.0, H22-H24 -> 1.0
HOUR_WEIGHTS = np.array(
    [1.0] * 15 + [1.5, 2.0, 2.0, 5.0, 5.0, 3.0] + [1.0] * 3,
    dtype=np.float32,
)

# ---------- pull data ----------
# Load the two frames the rest of the cell works from:
#   df_lags   - price history with lag/rolling columns, ordered by ts (encoder side)
#   df_future - one row per (delivery date, interval) with target and decoder features
# The connection is closed in `finally` so a failing query does not leave the
# database handle open.
con = duckdb.connect(DB)
try:
    df_lags = con.execute("""
SELECT
  ts, OperatingDTM, Interval, hb_houston,
  p_lag1,p_lag2,p_lag3,p_lag6,p_lag12,p_lag24,p_lag48,p_lag72,p_lag168,
  dp1,dp24,p_roll24_mean,p_roll24_std,p_roll72_mean,p_roll168_mean
FROM vw_master_spine_lags
ORDER BY ts
""").df()
    df_lags["ts"] = pd.to_datetime(df_lags["ts"])

    df_future = con.execute("""
SELECT
  f.OperatingDTM AS delivery_date,
  f.Interval     AS delivery_interval,
  f.hb_houston   AS target,
  -- base future features
  f.wz_southcentral, f.wz_east, f.wz_west, f.wz_northcentral, f.wz_farwest, f.wz_north, f.wz_southern, f.wz_coast,
  f.lz_north, f.lz_west, f.lz_south, f.lz_houston,
  f.cal_hour AS cal_hour_f, f.cal_dow AS cal_dow_f, f.cal_is_weekend AS cal_is_weekend_f,
  f.cal_sin_hour AS cal_sin_hour_f, f.cal_cos_hour AS cal_cos_hour_f,
  f.cal_sin_dow  AS cal_sin_dow_f,  f.cal_cos_dow  AS cal_cos_dow_f,
  -- load-enhanced joins
  l.net_wz, l.net_lz, l.net_wz_ramp1, l.net_lz_ramp1, l.lz_houston_ramp1, l.wz_spread, l.lz_spread,
  -- weather: COALESCE forecast -> historical
  COALESCE(wf.temp_avg,      wh.temp_avg)      AS temp_avg,
  COALESCE(wf.temp_spread,   wh.temp_spread)   AS temp_spread,
  COALESCE(wf.temp_avg_ramp1,wh.temp_avg_ramp1)AS temp_avg_ramp1,
  COALESCE(wf.hum_avg,       wh.hum_avg)       AS hum_avg,
  COALESCE(wf.hum_spread,    wh.hum_spread)    AS hum_spread,
  COALESCE(wf.hum_avg_ramp1, wh.hum_avg_ramp1) AS hum_avg_ramp1,
  CASE WHEN wf.temp_avg IS NULL THEN 1 ELSE 0 END AS is_weather_proxy,
  -- additional ramps & day-over-day deltas to help with peaks (computed over time order)
  (l.net_wz      - LAG(l.net_wz,      3)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_wz_ramp3,
  (l.net_wz      - LAG(l.net_wz,      6)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_wz_ramp6,
  (l.net_lz      - LAG(l.net_lz,      3)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_lz_ramp3,
  (l.net_lz      - LAG(l.net_lz,      6)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_lz_ramp6,
  (l.lz_houston  - LAG(l.lz_houston,  3)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS lz_houston_ramp3,
  (l.lz_houston  - LAG(l.lz_houston,  6)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS lz_houston_ramp6,
  (l.net_wz      - LAG(l.net_wz,     24)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_wz_dod,
  (l.net_lz      - LAG(l.net_lz,     24)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS net_lz_dod,
  (l.lz_houston  - LAG(l.lz_houston, 24)  OVER (ORDER BY f.OperatingDTM, f.Interval)) AS lz_houston_dod
FROM vw_master_spine_ts f
LEFT JOIN load_fcst_enh l
  ON l.OperatingDTM = f.OperatingDTM AND l.Interval = f.Interval
LEFT JOIN wx_fcst_enh  wf
  ON wf.OperatingDTM = f.OperatingDTM AND wf.Interval = f.Interval
LEFT JOIN wx_hist_enh  wh
  ON wh.OperatingDTM = f.OperatingDTM AND wh.Interval = f.Interval
ORDER BY f.OperatingDTM, f.Interval
""").df()
finally:
    con.close()

# Normalize key types: delivery_date becomes datetime.date objects (compared
# against date values later), delivery_interval becomes int.
df_future["delivery_date"] = pd.to_datetime(df_future["delivery_date"]).dt.date
df_future["delivery_interval"] = df_future["delivery_interval"].astype(int)

# ---------- assemble samples (day-ahead) ----------
# Encoder inputs: columns read from df_lags for every row of the look-back window.
enc_cols = [
    "hb_houston","p_lag1","p_lag2","p_lag3","p_lag6","p_lag12","p_lag24","p_lag48","p_lag72","p_lag168",
    "dp1","dp24","p_roll24_mean","p_roll24_std","p_roll72_mean","p_roll168_mean"
]
# Decoder inputs: columns read from df_future for each interval of the delivery day.
# The order of this list fixes the feature axis of every decoder array built from it.
dec_future_cols = [
    "cal_hour_f","cal_dow_f","cal_is_weekend_f","cal_sin_hour_f","cal_cos_hour_f","cal_sin_dow_f","cal_cos_dow_f",
    "net_wz","net_lz","net_wz_ramp1","net_lz_ramp1","lz_houston_ramp1","wz_spread","lz_spread",
    "temp_avg","temp_spread","temp_avg_ramp1","hum_avg","hum_spread","hum_avg_ramp1",
    # NEW: multi-horizon ramps + day-over-day deltas
    "net_wz_ramp3","net_wz_ramp6","net_lz_ramp3","net_lz_ramp6",
    "lz_houston_ramp3","lz_houston_ramp6","net_wz_dod","net_lz_dod","lz_houston_dod",
    # Identify when weather came from historical backfill instead of forecast
    "is_weather_proxy"
]

# base rows = end-of-day with price
# One candidate sample per day: the interval-24 row with a known price is the
# "base" row, and the sample forecasts the following calendar day.
base_rows = df_lags[(df_lags["hb_houston"].notna()) & (df_lags["Interval"]==24)][["ts","OperatingDTM","hb_houston"]].copy()
base_rows["base_date"] = pd.to_datetime(base_rows["OperatingDTM"]).dt.date
base_rows["delivery_date"] = base_rows["base_date"] + dt.timedelta(days=1)

# only use delivery days with 24 actuals for training/val
has24 = df_future.groupby("delivery_date")["target"].apply(lambda s: s.notna().sum()==24)
valid_delivery_dates = set(has24[has24].index)

df_lags_idx = df_lags.set_index("ts").sort_index()

# Split the future frame by delivery date once, so each sample does a dict
# lookup instead of a boolean scan over the whole frame.
future_by_date = {
    d: g.sort_values("delivery_interval")
    for d, g in df_future.groupby("delivery_date", sort=False)
}

# Each sample is (X_enc, X_dec, y, y0, delivery_date):
#   X_enc (W_PAST, len(enc_cols)), X_dec (H_FUT, len(dec_future_cols)),
#   y (H_FUT,) targets, y0 the base-row price.
samples = []
for _, r in base_rows.iterrows():
    ddate = r["delivery_date"]
    ts0 = r["ts"]
    # encoder window
    start = ts0 - pd.Timedelta(hours=W_PAST-1)
    enc_df = df_lags_idx.loc[start:ts0][enc_cols]
    if enc_df.shape[0] != W_PAST:
        continue
    if enc_df.isna().any().any():
        continue  # strict: skip any encoder NaNs
    # decoder future rows
    fdf = future_by_date.get(ddate)
    if fdf is None or fdf.shape[0] != H_FUT:
        continue
    X_enc = enc_df.values.astype(np.float32)
    X_dec = fdf[dec_future_cols].values.astype(np.float32)
    y     = fdf["target"].values.astype(np.float32)
    y0    = np.float32(r["hb_houston"])
    samples.append((X_enc, X_dec, y, y0, ddate))

if not samples:
    raise RuntimeError("No samples assembled; check data coverage or reduce W_PAST.")

# split by date
# Distinct delivery dates over all assembled samples (not yet filtered by valid_delivery_dates).
dates = sorted({s[4] for s in samples})
# Validation covers delivery dates later than (latest sample date - HOLDOUT_DAYS);
# when there are HOLDOUT_DAYS or fewer distinct dates the cut falls back to the earliest date.
cut_date = dates[-1] - dt.timedelta(days=HOLDOUT_DAYS) if len(dates)>HOLDOUT_DAYS else dates[0]
# Only delivery days listed in valid_delivery_dates are admitted to either split.
train_idx = [i for i,s in enumerate(samples) if s[4] <= cut_date and s[4] in valid_delivery_dates]
val_idx   = [i for i,s in enumerate(samples) if s[4] >  cut_date and s[4] in valid_delivery_dates]

def stack(idxs):
    """Stack the selected entries of the module-level ``samples`` list along a new axis 0.

    Returns ``(Xe, Xd, Y, Y0)``: the first four fields of each chosen sample tuple,
    each stacked with ``np.stack`` in the order given by ``idxs``.
    """
    chosen = [samples[i] for i in idxs]
    Xe, Xd, Y, Y0 = (np.stack([sample[field] for sample in chosen], axis=0) for field in range(4))
    return Xe, Xd, Y, Y0

# Materialize the train arrays; validation arrays stay None when nothing is held out.
Xe_tr, Xd_tr, Y_tr, Y0_tr = stack(train_idx)
Xe_va, Xd_va, Y_va, Y0_va = stack(val_idx) if len(val_idx) > 0 else (None, None, None, None)

# ---------- diagnose + drop all-NaN decoder features, then impute (robust) ----------
# Flatten decoder train to 2D (one row per sample/step, one column per feature)
# to inspect per-feature missingness.
F_orig = Xd_tr.shape[2]
Xd_tr_2d_raw = Xd_tr.reshape(-1, F_orig)
nan_counts = np.isnan(Xd_tr_2d_raw).sum(axis=0)
# A feature is dropped only if it is NaN in every TRAIN row.
all_nan_cols = nan_counts == Xd_tr_2d_raw.shape[0]

if all_nan_cols.any():
    dropped_cols = [dec_future_cols[i] for i,flag in enumerate(all_nan_cols) if flag]
    print(f"Dropping {int(all_nan_cols.sum())} decoder feature(s) with all-NaN in TRAIN: {dropped_cols}")
else:
    dropped_cols = []

# Keep mask to be reused at prediction time
dec_keep_mask = ~all_nan_cols

# Apply mask to decoder tensors (boolean indexing returns new arrays)
Xd_tr = Xd_tr[:, :, dec_keep_mask]
if Xe_va is not None:
    Xd_va = Xd_va[:, :, dec_keep_mask]

# Recompute medians on KEPT features (TRAIN only, so validation never leaks into the fill values)
dec_med = np.nanmedian(Xd_tr.reshape(-1, Xd_tr.shape[2]), axis=0)
# Guard: if a median is still NaN for some reason, set to 0
dec_med = np.where(np.isnan(dec_med), 0.0, dec_med)

# Encoder median (rare but safe)
enc_med = np.nanmedian(Xe_tr.reshape(-1, Xe_tr.shape[2]), axis=0)
enc_med = np.where(np.isnan(enc_med), 0.0, enc_med)

# Impute in-place (3D-safe)
impute_inplace_3d(Xe_tr, enc_med)
impute_inplace_3d(Xd_tr, dec_med)
if Xe_va is not None:
    impute_inplace_3d(Xe_va, enc_med)
    impute_inplace_3d(Xd_va, dec_med)

# Targets should be complete for training/val. Use the raising helper instead of
# `assert`: assert statements are removed when Python runs with -O, which would let
# non-finite targets slip through, and the normalized arrays are already checked
# with assert_finite further down.
assert_finite("Y_tr", Y_tr)
if Y_va is not None:
    assert_finite("Y_va", Y_va)

# ---------- safe standardization ----------
# Feature scalers are fit on TRAIN arrays only, over their flattened (rows, features) view.
enc_mu, enc_sd = fit_standardizer(Xe_tr.reshape(-1, Xe_tr.shape[-1]))
dec_mu, dec_sd = fit_standardizer(Xd_tr.reshape(-1, Xd_tr.shape[-1]))

# ---- target transform (spike-aware) then standardize ----
# Targets go through sgn_log1p first; one scalar mean/std is then fit on all train targets.
Y_tr_t = sgn_log1p(Y_tr)
_y_stats = fit_standardizer(Y_tr_t.reshape(-1, 1))
y_mu = float(np.atleast_1d(_y_stats[0])[0])
y_sd = float(np.atleast_1d(_y_stats[1])[0])

# Train split
Xe_tr_n = apply_standardizer(Xe_tr, enc_mu, enc_sd)
Xd_tr_n = apply_standardizer(Xd_tr, dec_mu, dec_sd)
Y_tr_n = apply_standardizer(Y_tr_t, y_mu, y_sd).reshape(Y_tr.shape)
# The anchor price y0 is put on the same transformed + standardized scale as the targets.
Y0_tr_t = sgn_log1p(Y0_tr.reshape(-1, 1)).reshape(-1, 1)
Y0_tr_n = apply_standardizer(Y0_tr_t, y_mu, y_sd).reshape(-1)

# Validation split (only when a holdout exists), scaled with the TRAIN statistics
if Xe_va is not None:
    Xe_va_n = apply_standardizer(Xe_va, enc_mu, enc_sd)
    Xd_va_n = apply_standardizer(Xd_va, dec_mu, dec_sd)
    Y_va_t = sgn_log1p(Y_va)
    Y_va_n = apply_standardizer(Y_va_t, y_mu, y_sd).reshape(Y_va.shape)
    Y0_va_t = sgn_log1p(Y0_va.reshape(-1, 1))
    Y0_va_n = apply_standardizer(Y0_va_t, y_mu, y_sd).reshape(-1)

# Final safety: ensure everything is finite (checked in a fixed order; first failure raises)
_finite_checks = [
    ("Xe_tr_n", Xe_tr_n),
    ("Xd_tr_n", Xd_tr_n),
    ("Y_tr_n", Y_tr_n),
    ("Y0_tr_n", Y0_tr_n),
]
if Xe_va is not None:
    _finite_checks += [
        ("Xe_va_n", Xe_va_n),
        ("Xd_va_n", Xd_va_n),
        ("Y_va_n", Y_va_n),
        ("Y0_va_n", Y0_va_n),
    ]
for _label, _arr in _finite_checks:
    assert_finite(_label, _arr)

# ---------- torch datasets ----------
class Seq2SeqDataset(torch.utils.data.Dataset):
    """Dataset over four aligned NumPy arrays, each converted to a float32 tensor.

    Indexing returns the tuple ``(Xe[i], Xd[i], Y[i], Y0[i])``; the length is the
    size of ``Xe`` along its first axis.
    """

    def __init__(self, Xe, Xd, Y, Y0):
        for attr, array in (("Xe", Xe), ("Xd", Xd), ("Y", Y), ("Y0", Y0)):
            setattr(self, attr, torch.from_numpy(array).float())

    def __len__(self):
        return self.Xe.shape[0]

    def __getitem__(self, i):
        return tuple(tensor[i] for tensor in (self.Xe, self.Xd, self.Y, self.Y0))

bs = 32
# Training batches are reshuffled each epoch; validation keeps dataset order.
_train_ds = Seq2SeqDataset(Xe_tr_n, Xd_tr_n, Y_tr_n, Y0_tr_n)
train_loader = torch.utils.data.DataLoader(_train_ds, batch_size=bs, shuffle=True)
if Xe_va is None:
    val_loader = None
else:
    _val_ds = Seq2SeqDataset(Xe_va_n, Xd_va_n, Y_va_n, Y0_va_n)
    val_loader = torch.utils.data.DataLoader(_val_ds, batch_size=bs, shuffle=False)

# ---------- model ----------
class Encoder(nn.Module):
    """Single-layer batch-first GRU that summarizes a sequence into its final hidden state."""

    def __init__(self, in_dim, hidden):
        super().__init__()
        self.rnn = nn.GRU(input_size=in_dim, hidden_size=hidden, batch_first=True)

    def forward(self, x):
        """Take ``x`` of shape (B, W, E) and return the GRU's final hidden state, shape (1, B, H)."""
        # nn.GRU returns (per-step outputs, final hidden state); only the state is used.
        gru_result = self.rnn(x)
        return gru_result[1]

class Decoder(nn.Module):
    """Autoregressive GRU decoder producing one scalar per future step.

    At every step the GRU input is the step's feature vector concatenated with the
    previous target value (``y0`` for the first step), hence ``in_dim + 1`` inputs.
    """

    def __init__(self, in_dim, hidden):
        super().__init__()
        self.rnn = nn.GRU(in_dim + 1, hidden, batch_first=True)  # +1 for the previous y
        self.out = nn.Linear(hidden, 1)

    def forward(self, future_feats, h0, y0, teacher=None, tf_prob=0.5):
        """Unroll over the T steps of ``future_feats`` (B, T, D) and return predictions (B, T).

        ``h0`` is the initial GRU state and ``y0`` the value fed as "previous y" at
        step 0. When ``teacher`` is given, each step draws one number from the
        module-level ``rng`` and, if it is below ``tf_prob``, feeds ``teacher[:, t]``
        to the next step instead of the model's own prediction. The draw is made
        once per step, so the choice is shared by the whole batch.
        """
        batch, steps, _ = future_feats.shape
        prev_y = y0.view(batch, 1)
        state = h0
        step_preds = []
        for step in range(steps):
            step_input = torch.cat([future_feats[:, step, :], prev_y], dim=1).unsqueeze(1)  # (B,1,D+1)
            rnn_out, state = self.rnn(step_input, state)
            pred = self.out(rnn_out[:, -1, :]).squeeze(1)
            step_preds.append(pred)
            # Short-circuit keeps rng untouched when no teacher signal is supplied.
            use_teacher = (teacher is not None) and (rng.random() < tf_prob)
            next_source = teacher[:, step] if use_teacher else pred
            prev_y = next_source.view(batch, 1)
        return torch.stack(step_preds, dim=1)

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper: the encoder's output seeds the decoder's hidden state."""

    def __init__(self, enc_in, dec_in, hidden):
        super().__init__()
        self.enc = Encoder(enc_in, hidden)
        self.dec = Decoder(dec_in, hidden)

    def forward(self, x_enc, x_dec, y0, y_teacher=None, tf_prob=0.5):
        """Encode ``x_enc``, then decode over ``x_dec`` starting from ``y0``.

        ``y_teacher`` and ``tf_prob`` are passed through to the decoder unchanged.
        """
        context = self.enc(x_enc)
        return self.dec(x_dec, context, y0, teacher=y_teacher, tf_prob=tf_prob)

# Input widths come from the last axis of the normalized train tensors.
enc_in, dec_in = Xe_tr_n.shape[2], Xd_tr_n.shape[2]
model = Seq2Seq(enc_in, dec_in, HIDDEN)
model = model.to(DEVICE)
opt = torch.optim.Adam(params=model.parameters(), lr=LR)
loss_fn = nn.HuberLoss()

def huber_weighted(yhat, y, w, delta=1.0):
    """Return the mean of element-wise Huber losses, each multiplied by ``w``.

    Errors up to ``delta`` in magnitude are penalized quadratically (0.5 * e**2);
    the part of the error beyond ``delta`` is penalized linearly with slope ``delta``.
    The mean is taken over all elements after weighting (it is not divided by sum(w)).
    """
    abs_err = (yhat - y).abs()
    quad_part = abs_err.clamp(max=delta)   # portion of the error inside the quadratic zone
    lin_part = abs_err - quad_part         # overshoot beyond delta (zero for small errors)
    per_elem = quad_part ** 2 * 0.5 + lin_part * delta
    return torch.mean(per_elem * w)

def train_epoch():
    """Run one optimization pass over ``train_loader`` and return the mean training loss.

    Uses the module-level ``model``, ``opt``, ``train_loader``, ``HOUR_WEIGHTS`` and
    ``DEVICE``. The returned value is the batch losses averaged with batch-size
    weights, i.e. divided by the number of samples in the dataset.
    """
    model.train()
    running = 0.0
    for batch in train_loader:
        xenc, xdec, y, y0 = (tensor.to(DEVICE) for tensor in batch)
        opt.zero_grad()
        # Teacher forcing: the decoder may be fed the true previous target (p=0.5 per step).
        pred = model(xenc, xdec, y0, y_teacher=y, tf_prob=0.5)
        # time-step weights broadcast to the target's shape
        weights = torch.tensor(HOUR_WEIGHTS, device=y.device).unsqueeze(0).expand_as(y)
        batch_loss = huber_weighted(pred, y, weights)
        batch_loss.backward()
        # Clip the global gradient norm to 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        running += batch_loss.item() * xenc.size(0)
    return running / len(train_loader.dataset)

@torch.no_grad()
def eval_epoch():
    """Return the mean weighted Huber loss over ``val_loader``, or None if there is no holdout.

    Runs the module-level ``model`` in eval mode without teacher forcing, so the
    decoder consumes its own predictions. Batch losses are averaged with
    batch-size weights over the whole validation dataset.
    """
    if val_loader is None:
        return None
    model.eval()
    running = 0.0
    for batch in val_loader:
        xenc, xdec, y, y0 = (tensor.to(DEVICE) for tensor in batch)
        pred = model(xenc, xdec, y0, y_teacher=None, tf_prob=0.0)
        weights = torch.tensor(HOUR_WEIGHTS, device=y.device).unsqueeze(0).expand_as(y)
        running += huber_weighted(pred, y, weights).item() * xenc.size(0)
    return running / len(val_loader.dataset)

# Train for EPOCHS epochs, logging train loss and (when a holdout exists) validation loss.
for ep in range(1, EPOCHS + 1):
    tr, va = train_epoch(), eval_epoch()
    _log_line = f"Epoch {ep:02d}  train_loss={tr:.4f}"
    if va is not None:
        _log_line += f"  val_loss={va:.4f}"
    print(_log_line)

# ---- evaluate holdout MAE in real units (SAFE) ----
if val_loader is not None:
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for _batch in val_loader:
            xenc, xdec, y, y0 = (_t.to(DEVICE) for _t in _batch)
            # Free-running decode: no teacher signal at evaluation time.
            yhat_n = model(xenc, xdec, y0, y_teacher=None, tf_prob=0.0).detach().cpu().numpy()
            y_n = y.detach().cpu().numpy()
            # Undo the standardization, then the spike transform, for predictions and truths alike.
            yhat_t = yhat_n * y_sd + y_mu
            yt_t = y_n * y_sd + y_mu
            yhat = sgn_expm1(yhat_t).reshape(-1)
            yt = sgn_expm1(yt_t).reshape(-1)
            preds.append(yhat)
            trues.append(yt)
    preds = np.concatenate(preds)
    trues = np.concatenate(trues)
    # Score only pairs where both sides are finite, and report how many were excluded.
    mask = np.isfinite(preds) & np.isfinite(trues)
    dropped = int((~mask).sum())
    if dropped:
        print(f"Dropped {dropped} non-finite pairs from holdout scoring.")
    _holdout_mae = float(mean_absolute_error(trues[mask], preds[mask]))
    print({"holdout_MAE": _holdout_mae})

# ---- optional LightGBM residual booster (train on TRAIN residuals) ----
# Fits a gradient-boosted regressor on (true - seq2seq prediction) in real units, using the
# imputed, unscaled decoder features as inputs. Best-effort: on any failure the step is
# skipped and `booster` stays None, so the prediction code falls back to the plain seq2seq output.
USE_LGBM_RESIDUAL = True
booster = None
if USE_LGBM_RESIDUAL:
    try:
        import lightgbm as lgb
        # Build a non-shuffled loader to preserve order (predictions must align with Y_tr / Xd_tr)
        train_loader_eval = torch.utils.data.DataLoader(
            Seq2SeqDataset(Xe_tr_n, Xd_tr_n, Y_tr_n, Y0_tr_n), batch_size=64, shuffle=False)
        base_preds = []
        with torch.no_grad():
            model.eval()
            for xenc, xdec, y, y0 in train_loader_eval:
                xenc, xdec, y, y0 = xenc.to(DEVICE), xdec.to(DEVICE), y.to(DEVICE), y0.to(DEVICE)
                yhat_n = model(xenc, xdec, y0, y_teacher=None, tf_prob=0.0).detach().cpu().numpy()
                base_preds.append(yhat_n)
        base_preds = np.concatenate(base_preds, axis=0)  # (N_train, 24)
        # inverse transforms -> true $/MWh
        base_preds_t = (base_preds * y_sd + y_mu)
        base_preds_true = sgn_expm1(base_preds_t)
        y_true_tr = Y_tr  # already true units
        # flatten to one row per (sample, step)
        y_base_flat = base_preds_true.reshape(-1)
        y_true_flat = y_true_tr.reshape(-1)
        X_flat = Xd_tr.reshape(-1, Xd_tr.shape[2])  # kept raw decoder features
        # Train only on rows where features, base prediction and truth are all finite
        mask = np.isfinite(y_base_flat) & np.isfinite(y_true_flat) & np.isfinite(X_flat).all(axis=1)
        X_flat = X_flat[mask]
        y_resid_flat = (y_true_flat - y_base_flat)[mask]
        # Fit booster.
        # subsample_freq=1: LightGBM only applies row subsampling when the bagging frequency
        # is > 0; with the default of 0 the subsample=0.8 setting was silently ignored.
        _candidate = lgb.LGBMRegressor(
            n_estimators=600, learning_rate=0.03, subsample=0.8, subsample_freq=1,
            colsample_bytree=0.8, max_depth=-1, reg_alpha=0.0, reg_lambda=0.0,
            min_child_samples=20, random_state=42)
        _candidate.fit(X_flat, y_resid_flat)
        # Publish the model only after a successful fit. Binding `booster` before fit()
        # would leave an unfitted estimator behind on failure, and the later
        # `booster is not None` check would then call predict() on it and crash.
        booster = _candidate
        print("Residual booster trained (LightGBM).")
    except Exception as e:
        booster = None
        print(f"LightGBM residual booster skipped: {e}")

# ---- predict requested DELIVERY_DATE ----

# find base row (previous day EOD with price): interval 24 of the day before delivery
prev_day = DELIVERY_DATE - dt.timedelta(days=1)
base_row = df_lags[(pd.to_datetime(df_lags["OperatingDTM"]).dt.date == prev_day) & (df_lags["Interval"]==24) & df_lags["hb_houston"].notna()]
if base_row.empty:
    raise RuntimeError(f"No base EOD price for {prev_day}")
ts0 = pd.to_datetime(base_row.iloc[0]["ts"])
y0  = np.float32(base_row.iloc[0]["hb_houston"])

# encoder window: label slice covering the W_PAST hours that end at ts0 (inclusive).
# The index is sorted first, exactly as for the training windows: label slicing on an
# unsorted DatetimeIndex can raise or return rows out of chronological order, which
# would feed the encoder a sequence unlike anything it saw in training.
win_start = ts0 - pd.Timedelta(hours=W_PAST-1)
win = df_lags.set_index("ts").sort_index().loc[win_start:ts0][enc_cols]
if win.shape[0] != W_PAST or win.isna().any().any():
    raise RuntimeError("Insufficient/NaN history for encoder window.")
# Add a leading batch axis of size 1
Xe_pred = win.values.astype(np.float32)[None, ...]

# Future (decoder) features for the delivery date, ordered by delivery interval.
_on_delivery_day = df_future["delivery_date"] == DELIVERY_DATE
frows = df_future[_on_delivery_day].sort_values("delivery_interval")
if len(frows) != H_FUT:
    raise RuntimeError("Spine missing 24 rows for delivery date.")
# Leading axis of size 1 = batch dimension
Xd_pred_full = frows[dec_future_cols].values.astype(np.float32)[None, ...]

# Drop the same columns that were removed in training (all-NaN in TRAIN), when the mask exists.
Xd_pred = Xd_pred_full[:, :, dec_keep_mask] if 'dec_keep_mask' in globals() else Xd_pred_full

# Fill missing decoder values with the TRAIN medians, then scale with the TRAIN statistics.
m = np.isnan(Xd_pred)
if m.any():
    _fill = np.broadcast_to(dec_med, Xd_pred.shape)
    Xd_pred[m] = _fill[m]
Xe_pred_n = apply_standardizer(Xe_pred, enc_mu, enc_sd)
Xd_pred_n = apply_standardizer(Xd_pred, dec_mu, dec_sd)
# Anchor price: same spike-aware transform and target scaling as in training, as a float32 scalar.
y0_t = sgn_log1p(y0)
y0_n = np.float32(apply_standardizer(y0_t, y_mu, y_sd)).ravel()[0]

# Single-sample inference: free-running decode (no teacher), then keep the only batch row.
model.eval()
with torch.no_grad():
    _enc_in = torch.from_numpy(Xe_pred_n).to(DEVICE)
    _dec_in = torch.from_numpy(Xd_pred_n).to(DEVICE)
    _y0_in = torch.tensor([y0_n], dtype=torch.float32, device=DEVICE)
    yhat_n = model(_enc_in, _dec_in, _y0_in, y_teacher=None, tf_prob=0.0).detach().cpu().numpy()[0]
# Undo target standardization, then the spike-aware transform.
yhat_t = (yhat_n * y_sd + y_mu).reshape(-1)
yhat = sgn_expm1(yhat_t)

# --- optional: add residual booster only on peak hours (H18–H21) ---
if booster is not None:
    Xd_pred_kept = Xd_pred.copy()  # (1,24,F_kept)
    _booster_X = Xd_pred_kept.reshape(24, Xd_pred_kept.shape[2])
    resid_pred = booster.predict(_booster_X)
    hour_mask = np.zeros(24, dtype=np.float32)
    hour_mask[17:21] = 1.0  # zero-based positions 17..20 = intervals 18,19,20,21
    yhat = yhat + resid_pred * hour_mask

# Replace any remaining NaN with the median of the day's non-NaN predictions.
if np.isnan(yhat).any():
    day_med = float(np.nanmedian(yhat))
    yhat = np.nan_to_num(yhat, nan=day_med)

# One row per interval (1..H_FUT) for the delivery date.
forecast = pd.DataFrame({
    "OperatingDTM": [DELIVERY_DATE] * H_FUT,
    "Interval": list(range(1, H_FUT + 1)),
    "hb_houston_pred": yhat.tolist(),
})
forecast


Epoch 01  train_loss=0.3602  val_loss=0.4567
Epoch 02  train_loss=0.2156  val_loss=0.5112
Epoch 03  train_loss=0.1802  val_loss=0.3709
Epoch 04  train_loss=0.1574  val_loss=0.4478
Epoch 05  train_loss=0.1605  val_loss=0.2136
Epoch 06  train_loss=0.1471  val_loss=0.2455
Epoch 07  train_loss=0.1507  val_loss=0.2474
Epoch 08  train_loss=0.1323  val_loss=0.3780
Epoch 09  train_loss=0.1327  val_loss=0.2690
Epoch 10  train_loss=0.1202  val_loss=0.2851
Epoch 11  train_loss=0.1288  val_loss=0.4964
Epoch 12  train_loss=0.1281  val_loss=0.4486
Epoch 13  train_loss=0.1328  val_loss=0.1486
Epoch 14  train_loss=0.1413  val_loss=0.1517
Epoch 15  train_loss=0.1189  val_loss=0.2193
Epoch 16  train_loss=0.1006  val_loss=0.1624
Epoch 17  train_loss=0.1199  val_loss=0.1646
Epoch 18  train_loss=0.1102  val_loss=0.1511
Epoch 19  train_loss=0.1063  val_loss=0.2288
Epoch 20  train_loss=0.0901  val_loss=0.1795
Epoch 21  train_loss=0.0831  val_loss=0.2677
Epoch 22  train_loss=0.1042  val_loss=0.7247
Epoch 23  



Unnamed: 0,OperatingDTM,Interval,hb_houston_pred
0,2025-08-18,1,28.624001
1,2025-08-18,2,27.802418
2,2025-08-18,3,25.580601
3,2025-08-18,4,22.805994
4,2025-08-18,5,24.711779
5,2025-08-18,6,23.401142
6,2025-08-18,7,22.960243
7,2025-08-18,8,23.690546
8,2025-08-18,9,18.475698
9,2025-08-18,10,17.273062
