In [1]:
from pathlib import Path
import numpy as np, pandas as pd
import torch, torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# paths (consistent with earlier nbs)
DATA_DIR   = Path("../data/processed")
FIG_DIR    = Path("../reports/figures"); FIG_DIR.mkdir(parents=True, exist_ok=True)
TABLE_DIR  = Path("../reports/tables");  TABLE_DIR.mkdir(parents=True, exist_ok=True)
MODEL_DIR  = Path("../models")

FEATURES  = ["bh_mass","bh_acc","stellar_mass","sfr","halo_mass","vel_disp"]
NAME_MAP  = {
    "bh_mass":"Black Hole Mass","bh_acc":"BH Accretion Rate",
    "stellar_mass":"Stellar Mass","sfr":"Star Formation Rate",
    "halo_mass":"Halo Mass","vel_disp":"Velocity Dispersion",
}
HORIZONS  = [1,3,5]
T_IN      = 8

# load standardization stats
stats_df = pd.read_csv(DATA_DIR/"standardization_stats.csv", index_col=0)
STATS = {k: {"mean": float(stats_df.loc[k,"mean"]), "std": float(stats_df.loc[k,"std"])} for k in stats_df.index}

# load test split (already z-scored in prior eval, but we'll normalize here to be safe)
test = pd.read_parquet(DATA_DIR/"test.parquet").sort_values(["subhalo_id","snapshot"]).reset_index(drop=True)

# normalization
for f in FEATURES:
    mu, sd = STATS[f]["mean"], STATS[f]["std"]
    test[f] = (test[f]-mu)/sd

# build sequences identical to 03
def split_tracks(df):
    tracks=[]
    for tid,g in df.groupby("subhalo_id"):
        snaps = g["snapshot"].values
        X = g[FEATURES].values.astype(np.float32)
        tracks.append((int(tid), snaps, X))
    return tracks

def make_sequences(tracks, t_in=8, horizons=[1,3,5]):
    Xs, Ys = [], []
    for tid, snaps, X in tracks:
        if len(snaps) < t_in+max(horizons): continue
        i=0
        while i + t_in + max(horizons) <= len(snaps):
            seg_snaps = snaps[i:i+t_in+max(horizons)]
            if np.all(np.diff(seg_snaps)==1):
                xin = X[i:i+t_in]
                youts=[]
                for h in horizons:
                    youts.append(X[i+t_in-1+h])
                Xs.append(xin); Ys.append(np.stack(youts,0))
                i += 1
            else:
                brk = np.where(np.diff(seg_snaps)!=1)[0][0]
                i = i + brk + 1
    Xs = np.stack(Xs).astype(np.float32)
    Ys = np.stack(Ys).astype(np.float32)
    return Xs, Ys

Xt, Yt = make_sequences(split_tracks(test), t_in=T_IN, horizons=HORIZONS)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model definition must match 03
class LSTMForecaster(nn.Module):
    def __init__(self, in_dim, hidden=128, layers=2, dropout=0.1, horizons=3, out_dim=6):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, hidden, num_layers=layers, dropout=dropout if layers>1 else 0.0, batch_first=True)
        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, horizons*out_dim))
        self.horizons=horizons; self.out_dim=out_dim
    def forward(self, x):
        _, (hn, _) = self.lstm(x)
        h = hn[-1]
        y = self.head(h).view(-1, self.horizons, self.out_dim)
        return y

model = LSTMForecaster(in_dim=len(FEATURES), horizons=len(HORIZONS), out_dim=len(FEATURES)).to(DEVICE)
ckpt = torch.load(MODEL_DIR/"lstm_blackhole_best.pt", map_location=DEVICE)
model.load_state_dict(ckpt["model"]); model.eval();

Xt_t = torch.from_numpy(Xt).to(DEVICE)  # [N,T,F]
Yt_t = torch.from_numpy(Yt).to(DEVICE)  # [N,H,F]

Xt.shape, Yt.shape


((1079, 8, 6), (1079, 3, 6))

In [2]:
@torch.no_grad()
def rmse_overall(model, X, Y):
    P = model(X).cpu().numpy()
    return float(np.sqrt(((P - Y.cpu().numpy())**2).mean()))

base_rmse = rmse_overall(model, Xt_t, Yt_t)

rng = np.random.default_rng(0)
importances = []
for j, feat in enumerate(FEATURES):
    Xperm = Xt.copy()
    # permute feature j across samples but preserve time structure within a sample
    perm_idx = rng.permutation(Xperm.shape[0])
    Xperm[:, :, j] = Xperm[perm_idx, :, j]
    rmse_perm = rmse_overall(model, torch.from_numpy(Xperm).to(DEVICE), Yt_t)
    importances.append({"feature": feat, "delta_rmse": rmse_perm - base_rmse})

imp_df = pd.DataFrame(importances).sort_values("delta_rmse", ascending=False)
imp_df.to_csv(TABLE_DIR/"permutation_importance_overall.csv", index=False)
imp_df


Unnamed: 0,feature,delta_rmse
0,bh_mass,0.346268
1,bh_acc,0.02626
2,stellar_mass,0.005203
3,sfr,0.000587
4,halo_mass,0.000313
5,vel_disp,0.00013


In [3]:
fig, ax = plt.subplots(figsize=(7.8,4.8))
ax.bar(range(len(imp_df)), imp_df["delta_rmse"].values)
ax.set_xticks(range(len(imp_df)))
ax.set_xticklabels([NAME_MAP[f] for f in imp_df["feature"]], rotation=30, ha="right")
ax.set_ylabel("Δ RMSE (z-scored)")
ax.set_title("Permutation Feature Importance (Overall)")
plt.tight_layout()
fig.savefig(FIG_DIR/"perm_importance_overall.png", dpi=300, bbox_inches="tight")
plt.close(fig)
FIG_DIR/"perm_importance_overall.png"


PosixPath('../reports/figures/perm_importance_overall.png')

In [4]:
@torch.no_grad()
def overall_change_for_perturb(model, X, Y, j, ksig=1.0):
    # add +ksig to feature j across the sequence
    Xp = X.clone()
    Xm = X.clone()
    Xp[:,:,j] += ksig
    Xm[:,:,j] -= ksig
    Pp = model(Xp).cpu().numpy()
    Pm = model(Xm).cpu().numpy()
    # quantify effect size as average absolute change in predictions
    eff = float(np.mean(np.abs(Pp - Pm)/2.0))
    return eff

sens = []
for j, feat in enumerate(FEATURES):
    eff = overall_change_for_perturb(model, Xt_t, Yt_t, j, ksig=1.0)
    sens.append({"feature": feat, "effect": eff})

sens_df = pd.DataFrame(sens).sort_values("effect", ascending=False)
sens_df.to_csv(TABLE_DIR/"sensitivity_plusminus1sigma.csv", index=False)
sens_df


Unnamed: 0,feature,effect
4,halo_mass,0.010336
3,sfr,0.008526
2,stellar_mass,0.00746
1,bh_acc,0.007085
0,bh_mass,0.007018
5,vel_disp,0.006758


In [5]:
fig, ax = plt.subplots(figsize=(7.8,4.8))
ax.bar(range(len(sens_df)), sens_df["effect"].values)
ax.set_xticks(range(len(sens_df)))
ax.set_xticklabels([NAME_MAP[f] for f in sens_df["feature"]], rotation=30, ha="right")
ax.set_ylabel("Avg |Δ prediction| (z-scored)")
ax.set_title("Sensitivity to ±1σ Input Perturbations")
plt.tight_layout()
fig.savefig(FIG_DIR/"sensitivity_pm1sigma.png", dpi=300, bbox_inches="tight")
plt.close(fig)
FIG_DIR/"sensitivity_pm1sigma.png"


PosixPath('../reports/figures/sensitivity_pm1sigma.png')