In [1]:
# 17 — Setup forecasting & explainability

from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Paths — every directory hangs off DATA_ROOT, so relocating the data only
# requires changing that single constant.
DATA_ROOT = Path("../data")
PROC_DIR = DATA_ROOT / "processed"    # preprocessed arrays and the long-format CSV
STATS_DIR = DATA_ROOT / "stats"       # feature mean / std used for normalization
MODEL_DIR = DATA_ROOT / "models"      # trained checkpoints
FORECAST_DIR = DATA_ROOT.joinpath("forecasts", "extrapolated")
RESULTS_DIR = DATA_ROOT / "results"
FIG_DIR = RESULTS_DIR / "figures"
TAB_DIR = RESULTS_DIR / "tables"

# Create every output directory up front; already-existing ones are left alone.
for out_dir in (FORECAST_DIR, FIG_DIR, TAB_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

# Feature names — the order must match the last axis of features.npy
# (and, presumably, the order in which feat_mean / feat_std were computed).
FEATURES = [
    "bh_mass",
    "bh_acc",
    "stellar_mass",
    "sfr",
    "halo_mass",
    "vel_disp",
]
F = len(FEATURES)


def _load_npy(folder, filename):
    """Load the ``.npy`` file ``folder / filename`` with numpy."""
    return np.load(folder / filename)


# Row-aligned arrays (one row per subhalo/snapshot record) and the
# normalization statistics.
ids = _load_npy(PROC_DIR, "ids.npy")
snaps = _load_npy(PROC_DIR, "snapshots.npy").astype(int)
Xraw = _load_npy(PROC_DIR, "features.npy").astype(np.float32)
mean = _load_npy(STATS_DIR, "feat_mean.npy")
std = _load_npy(STATS_DIR, "feat_std.npy")

# Epsilon added to std so zero-variance features never divide by zero. The SAME
# value is used in both directions so to_real(to_norm(x)) round-trips
# (up to floating-point rounding), including for zero-variance features.
_NORM_EPS = 1e-8


def to_norm(x, mu=None, sigma=None):
    """Z-score ``x`` with the per-feature normalization statistics.

    Parameters
    ----------
    x : array-like
        Values in physical units; assumed to broadcast against the stats
        (feature axis last) — TODO confirm for new callers.
    mu, sigma : array-like, optional
        Statistics to use instead of the module-level ``mean`` / ``std``.

    Returns
    -------
    ndarray
        ``(x - mu) / (sigma + 1e-8)``.
    """
    mu = mean if mu is None else mu
    sigma = std if sigma is None else sigma
    return (x - mu) / (sigma + _NORM_EPS)


def to_real(z, mu=None, sigma=None):
    """Map normalized values back to physical units (exact inverse of ``to_norm``).

    Parameters
    ----------
    z : array-like
        Normalized values, e.g. model predictions.
    mu, sigma : array-like, optional
        Statistics to use instead of the module-level ``mean`` / ``std``.

    Returns
    -------
    ndarray
        ``z * (sigma + 1e-8) + mu`` — the epsilon must match ``to_norm``.
    """
    mu = mean if mu is None else mu
    sigma = std if sigma is None else sigma
    return z * (sigma + _NORM_EPS) + mu


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild LSTM exactly as trained.
# The submodule names (`lstm`, `head`) and the layer order inside `head` define
# the state_dict keys, so they must stay as-is for the checkpoint to load.
class BHSequenceLSTM(nn.Module):
    """Many-to-one LSTM: encode a window of feature vectors, predict one step.

    Parameters
    ----------
    input_size : int
        Number of features per time step.
    hidden : int
        LSTM hidden size, also the width of the MLP head.
    layers : int
        Number of stacked LSTM layers.
    dropout : float
        Dropout between stacked LSTM layers (only when ``layers > 1``) and
        inside the head.
    output_size : int, optional
        Prediction width; defaults to ``input_size``.
    """

    def __init__(self, input_size, hidden=128, layers=2, dropout=0.1, output_size=None):
        super().__init__()
        out_dim = input_size if output_size is None else output_size
        # LSTM dropout only acts between stacked layers, so disable it for a single layer.
        lstm_dropout = 0.0 if layers <= 1 else dropout
        self.lstm = nn.LSTM(
            input_size,
            hidden,
            num_layers=layers,
            batch_first=True,
            dropout=lstm_dropout,
        )
        self.head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        """Map a batch-first window ``x`` [B, W, F] to predictions [B, output_size]."""
        seq_out, _state = self.lstm(x)   # [B, W, H]
        last_step = seq_out[:, -1, :]    # hidden output at the final time step
        return self.head(last_step)

# Load the checkpoint onto the CPU first, then move the rebuilt model to `device`.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (newer torch versions support weights_only=True for tensor/primitive payloads).
CKPT_PATH = MODEL_DIR / "blackhole_evolution_lstm.pt"
ckpt = torch.load(CKPT_PATH, map_location="cpu")

# Architecture hyperparameters must match training; they are hard-coded here
# because only the window length is read back from the checkpoint.
model = BHSequenceLSTM(
    input_size=F,
    hidden=128,
    layers=2,
    dropout=0.1,
    output_size=F,
).to(device)
model.load_state_dict(ckpt["state_dict"])
model.eval()  # inference mode: dropout disabled

# Window length used in training when the checkpoint recorded it, else 8.
saved_window = ckpt.get("window", 8)
WINDOW = int(saved_window)
print(f"[OK] Loaded model. WINDOW={WINDOW}, FEATURES={FEATURES}")


[OK] Loaded model. WINDOW=8, FEATURES=['bh_mass', 'bh_acc', 'stellar_mass', 'sfr', 'halo_mass', 'vel_disp']


In [None]:
# 18 — FAST: one-step forecast for all subhalos (batched)

from pathlib import Path
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

# Assumes Section 17 already ran (model, device, FEATURES, WINDOW, ids, to_norm,
# to_real and the *_DIR constants are in scope).
K = 1  # number of future steps; this cell produces exactly one (used in file names)
csv_long = PROC_DIR / "black_hole_evolution_tng100.csv"
df_long = pd.read_csv(csv_long)

# Collect the last WINDOW snapshots of every subhalo listed in `ids` that has at
# least WINDOW rows in the CSV. One groupby pass replaces a boolean scan of the
# whole frame per subhalo; sort=True keeps the ascending-id order of np.unique.
sid_list, last_snap_list, Xw_real_list = [], [], []
df_known = df_long[df_long["subhalo_id"].isin(np.unique(ids))]
for sid, sub in df_known.groupby("subhalo_id", sort=True):
    if len(sub) < WINDOW:
        continue
    sub = sub.sort_values("snapshot")
    Xw_real = sub[FEATURES].iloc[-WINDOW:].to_numpy(dtype=np.float32)  # [W, F]
    sid_list.append(int(sid))
    last_snap_list.append(int(sub["snapshot"].iloc[-1]))
    Xw_real_list.append(Xw_real)

if not Xw_real_list:
    print("[WARN] No subhalos had ≥ WINDOW steps; nothing to forecast.")
else:
    # Batch-predict the next step: physical -> normalized -> model -> physical,
    # using the same helpers as the rest of the notebook.
    Xw_real_batch = np.stack(Xw_real_list, axis=0)                      # [N, W, F]
    Xw_t = torch.from_numpy(to_norm(Xw_real_batch)).float().to(device)  # [N, W, F]
    with torch.no_grad():
        y_norm = model(Xw_t)                                            # [N, F]
    y_real = to_real(y_norm.cpu().numpy())                              # back to physical
    # Clamp negative predictions to 0.
    # NOTE(review): only valid if every feature is non-negative in physical units
    # (e.g. not log-scaled) — confirm before trusting clamped values.
    y_real = np.maximum(y_real, 0.0)

    # Write one CSV per subhalo plus an aggregate table.
    FORECAST_DIR.mkdir(parents=True, exist_ok=True)
    out_paths, rows = [], []
    for sid, last_snap, pred in zip(sid_list, last_snap_list, y_real):
        # NOTE(review): assumes snapshot numbers are consecutive, so "next" is last + 1.
        next_snap = last_snap + 1
        out_df = pd.DataFrame([pred], columns=FEATURES)
        out_df.insert(0, "snapshot", next_snap)
        out_df.insert(0, "subhalo_id", sid)
        p = FORECAST_DIR / f"forecast_subhalo_{sid}_K{K}.csv"
        out_df.to_csv(p, index=False)
        out_paths.append(p)
        rows.append({
            "subhalo_id": sid,
            "snapshot": next_snap,
            **{f: float(v) for f, v in zip(FEATURES, pred)},
        })

    agg_df = pd.DataFrame(rows)
    agg_path = TAB_DIR / "forecast_one_step_aggregate.csv"
    agg_df.to_csv(agg_path, index=False)

    print(f"[OK] Wrote {len(out_paths)} one-step forecasts → {FORECAST_DIR}")
    print(f"[OK] Saved aggregate table → {agg_path}")

    # Sanity plot: mean predicted next-step VALUE per feature over all subhalos
    # (absolute values, not a change relative to the last observed step).
    mean_pred = agg_df[FEATURES].mean(axis=0).to_numpy()
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(FEATURES, mean_pred)
    ax.tick_params(axis="x", rotation=30)
    ax.set_ylabel("Mean predicted next-step value (physical units)")
    ax.set_title("One-step forecast (mean over subhalos)")
    figp = FIG_DIR / "forecast_one_step_mean.png"
    fig.tight_layout()
    fig.savefig(figp, dpi=200)
    print(f"[OK] Saved figure → {figp}")


In [None]:
# 19 — Permutation importance on one-step TEST windows (physical-unit ΔRMSE)

# Build the test split by subhalo (whole trajectories) so no subhalo contributes
# windows to both training and evaluation.
# NOTE(review): this only prevents leakage if it reproduces the split used at
# training time (same seed, same id ordering, same 85% cut) — confirm against
# the training notebook.
TEST_SPLIT_START = 0.85  # ids after this fraction of the shuffled list form the test set
rng = np.random.default_rng(42)
unique_ids = np.unique(ids)
rng.shuffle(unique_ids)
n = len(unique_ids)
test_ids = set(unique_ids[int(TEST_SPLIT_START * n):])

# Construct normalized one-step samples: input x[t-WINDOW:t], target x[t],
# emitted in shuffled-id order.
Xw_test, y_test = [], []
for sid in unique_ids:
    if sid not in test_ids:
        continue
    m = ids == sid
    order = np.argsort(snaps[m])
    x = to_norm(Xraw[m])[order]  # time-ordered, normalized trajectory
    if len(x) <= WINDOW:
        continue
    for t in range(WINDOW, len(x)):
        Xw_test.append(x[t - WINDOW:t, :])
        y_test.append(x[t, :])

# Fail with an actionable message instead of np.stack's "need at least one array".
if not Xw_test:
    raise ValueError(
        f"No test subhalo has more than WINDOW={WINDOW} snapshots; "
        "cannot build test windows for permutation importance."
    )
Xw_test = np.stack(Xw_test, axis=0).astype(np.float32)
y_test = np.stack(y_test, axis=0).astype(np.float32)

@torch.no_grad()
def predict_windows_norm(model, Xw, batch_size=None):
    """Run ``model`` on normalized windows and return normalized predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Called on a float32 tensor built from ``Xw``; expected in eval mode.
    Xw : np.ndarray
        Normalized input windows, presumably shaped [N, W, F].
    batch_size : int, optional
        If given, rows are fed in chunks of this size to bound device memory.
        ``None`` (default) runs a single forward pass.

    Returns
    -------
    np.ndarray
        Model outputs moved to the CPU, concatenated along axis 0.

    Raises
    ------
    ValueError
        If ``batch_size`` is given and smaller than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # Send inputs to wherever the model's parameters live rather than relying on
    # a global `device`; parameter-free models fall back to the CPU.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")
    xb_all = torch.from_numpy(Xw).float()
    n_rows = xb_all.shape[0]
    if batch_size is None or batch_size >= n_rows:
        return model(xb_all.to(target)).cpu().numpy()
    chunks = [
        model(xb_all[start:start + batch_size].to(target)).cpu().numpy()
        for start in range(0, n_rows, batch_size)
    ]
    return np.concatenate(chunks, axis=0)

# Baseline (no permutation), scored in physical units.
P_base = predict_windows_norm(model, Xw_test)          # normalized
P_base_r, y_test_r = to_real(P_base), to_real(y_test)  # physical
rmse_base = np.sqrt(np.mean((P_base_r - y_test_r) ** 2, axis=0))  # per target feature

# For each input feature, shuffle its ENTIRE window across samples with ONE
# permutation shared by every time step. Each permuted window is then still a
# real trajectory (temporal structure kept) but is decoupled from the sample's
# other inputs and its target. A separate permutation per time step would splice
# unrelated trajectories together and overstate importance.
N_PERM_REPEATS = 1  # raise to average ΔRMSE over several shuffles (less noise)
target_cols = [f"ΔRMSE→{tgt}" for tgt in FEATURES]
results, rel_means = [], []
for j_in, name in enumerate(FEATURES):
    deltas = []
    for _ in range(N_PERM_REPEATS):
        perm = rng.permutation(Xw_test.shape[0])
        Xw_perm = Xw_test.copy()
        Xw_perm[:, :, j_in] = Xw_test[perm, :, j_in]
        P_perm_r = to_real(predict_windows_norm(model, Xw_perm))
        rmse_perm = np.sqrt(np.mean((P_perm_r - y_test_r) ** 2, axis=0))
        deltas.append(rmse_perm - rmse_base)
    delta = np.mean(deltas, axis=0)
    results.append({
        "input_feature": name,
        **{col: float(d) for col, d in zip(target_cols, delta)},
    })
    # Scale-free companion metric: ΔRMSE relative to each target's baseline RMSE.
    rel_means.append(float(np.mean(delta / (rmse_base + 1e-12))))

imp_df = pd.DataFrame(results)
# NOTE: ΔRMSE_mean averages errors of targets measured in different physical
# units, so it can be dominated by the largest-scale targets; ΔRMSE_rel_mean
# is unit-free and better suited to comparing across targets.
imp_df["ΔRMSE_mean"] = imp_df[target_cols].mean(axis=1)
imp_df["ΔRMSE_rel_mean"] = rel_means
imp_df = imp_df.sort_values("ΔRMSE_mean", ascending=False).reset_index(drop=True)
imp_path = TAB_DIR / "permutation_importance_test_ranked.csv"
imp_df.to_csv(imp_path, index=False)

fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(imp_df["input_feature"], imp_df["ΔRMSE_mean"])
ax.tick_params(axis="x", rotation=30)
ax.set_ylabel("Mean ΔRMSE (physical units)")
ax.set_title("Permutation Importance (Test)")
fig_imp = FIG_DIR / "permutation_importance_test.png"
fig.tight_layout()
fig.savefig(fig_imp, dpi=200)

print(f"[OK] Saved importance table → {imp_path}")
print(f"[OK] Saved figure → {fig_imp}")
