In [1]:
from pathlib import Path
import json, math, time, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DATA_DIR   = Path("../data/processed")
FIG_DIR    = Path("../reports/figures"); FIG_DIR.mkdir(parents=True, exist_ok=True)
TABLE_DIR  = Path("../reports/tables");  TABLE_DIR.mkdir(parents=True, exist_ok=True)
MODEL_DIR  = Path("../models");          MODEL_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_FP  = DATA_DIR / "train.parquet"
VAL_FP    = DATA_DIR / "val.parquet"
TEST_FP   = DATA_DIR / "test.parquet"
STATS_CSV = DATA_DIR / "standardization_stats.csv"
assert TRAIN_FP.exists() and VAL_FP.exists() and TEST_FP.exists(), "missing processed parquet files; run preprocessing."
assert STATS_CSV.exists(), "missing standardization_stats.csv; run preprocessing."

FEATURES = ["bh_mass","bh_acc","stellar_mass","sfr","halo_mass","vel_disp"]
NAME_MAP = {
    "bh_mass":"Black Hole Mass",
    "bh_acc":"BH Accretion Rate",
    "stellar_mass":"Stellar Mass",
    "sfr":"Star Formation Rate",
    "halo_mass":"Halo Mass",
    "vel_disp":"Velocity Dispersion",
}

plt.rcParams.update({
    "figure.figsize": (7.5,5.2),
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11
})

stats_df = pd.read_csv(STATS_CSV, index_col=0)
STATS = {k: {"mean": float(stats_df.loc[k,"mean"]), "std": float(stats_df.loc[k,"std"])} for k in stats_df.index}

CFG = {
    "T_IN": 8,
    "HORIZONS": [1,3,5],
    "BATCH": 256,
    "HIDDEN": 128,
    "LAYERS": 2,
    "DROPOUT": 0.10,
    "EPOCHS": 100,
    "LR": 1e-3,
    "WD": 1e-5,
    "GRAD_CLIP": 1.0,
    "PATIENCE": 12,
    "SEED": SEED,
}

DEVICE, len(FEATURES), list(stats_df.index), CFG


(device(type='cpu'),
 6,
 ['bh_mass', 'bh_acc', 'stellar_mass', 'sfr', 'halo_mass', 'vel_disp'],
 {'T_IN': 8,
  'HORIZONS': [1, 3, 5],
  'BATCH': 256,
  'HIDDEN': 128,
  'LAYERS': 2,
  'DROPOUT': 0.1,
  'EPOCHS': 100,
  'LR': 0.001,
  'WD': 1e-05,
  'GRAD_CLIP': 1.0,
  'PATIENCE': 12,
  'SEED': 1337})

In [2]:
class BHTimeSeriesDataset(Dataset):
    """
    Black Hole Evolution dataset for sequence modeling.
    Produces input sequence (T_IN snapshots) and targets at given horizons.
    """
    def __init__(self, df: pd.DataFrame, features, horizons, stats, t_in):
        self.df = df.copy()
        self.features = features
        self.horizons = horizons
        self.stats = stats
        self.t_in = t_in

        for f in self.features:
            mu, sigma = stats[f]["mean"], stats[f]["std"]
            self.df[f] = (self.df[f] - mu) / sigma

        # group by subhalo
        self.groups = [g for _, g in self.df.groupby("subhalo_id")]

        self.samples = []
        for g in self.groups:
            snaps = g["snapshot"].to_numpy()
            for i in range(len(snaps) - t_in - max(horizons) + 1):
                x_idx = slice(i, i + t_in)
                y_idx = [i + t_in - 1 + h for h in horizons]
                if y_idx[-1] < len(snaps):
                    self.samples.append((g.iloc[x_idx], g.iloc[y_idx]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x_df, y_df = self.samples[idx]
        x = x_df[self.features].to_numpy(dtype=np.float32)
        y = np.stack([y_df[self.features].to_numpy(dtype=np.float32) for y_df in [y_df]], axis=0)
        return torch.from_numpy(x), torch.from_numpy(y.squeeze(0))


def make_loaders(train_fp, val_fp, test_fp, features, horizons, stats, t_in, batch):
    train = pd.read_parquet(train_fp)
    val   = pd.read_parquet(val_fp)
    test  = pd.read_parquet(test_fp)

    dsets = {
        "train": BHTimeSeriesDataset(train, features, horizons, stats, t_in),
        "val":   BHTimeSeriesDataset(val, features, horizons, stats, t_in),
        "test":  BHTimeSeriesDataset(test, features, horizons, stats, t_in),
    }

    loaders = {
        k: DataLoader(v, batch_size=batch, shuffle=(k=="train"), drop_last=True)
        for k,v in dsets.items()
    }
    return dsets, loaders


dsets, loaders = make_loaders(
    TRAIN_FP, VAL_FP, TEST_FP,
    FEATURES, CFG["HORIZONS"], STATS, CFG["T_IN"], CFG["BATCH"]
)

for k,v in dsets.items():
    print(f"{k}: {len(v)} samples")

batch_x, batch_y = next(iter(loaders["train"]))
print("Input shape:", batch_x.shape, "Target shape:", batch_y.shape)


train: 5193 samples
val: 1112 samples
test: 1107 samples
Input shape: torch.Size([256, 8, 6]) Target shape: torch.Size([256, 3, 6])


In [3]:
class LSTMForecaster(nn.Module):
    def __init__(self, in_dim, hidden=128, layers=2, dropout=0.1, horizons=3, out_dim=6):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=hidden,
            num_layers=layers,
            dropout=dropout if layers > 1 else 0.0,
            batch_first=True,
        )
        self.head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, horizons * out_dim),
        )
        self.horizons = horizons
        self.out_dim = out_dim

    def forward(self, x):
        # x: [B, T_in, F]
        _, (hn, _) = self.lstm(x)     
        h = hn[-1]                    
        y = self.head(h)              
        y = y.view(-1, self.horizons, self.out_dim)  
        return y

model = LSTMForecaster(
    in_dim=len(FEATURES),
    hidden=CFG["HIDDEN"],
    layers=CFG["LAYERS"],
    dropout=CFG["DROPOUT"],
    horizons=len(CFG["HORIZONS"]),
    out_dim=len(FEATURES),
).to(DEVICE)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model: LSTMForecaster | params: {n_params:,} | device: {DEVICE}")


Model: LSTMForecaster | params: 220,562 | device: cpu


In [6]:
import pandas as pd
from sklearn.model_selection import train_test_split
from pathlib import Path

OUT_CSV = "../data/black_hole_evolution_tng100.csv"

if not Path(OUT_CSV).exists():
    raise FileNotFoundError(f"Processed dataset not found: {OUT_CSV}. "
                            "Run the preprocessing notebook first.")

df = pd.read_csv(OUT_CSV)

all_ids = df["subhalo_id"].unique()

train_ids, val_ids = train_test_split(
    all_ids,
    test_size=0.2,
    random_state=42
)

train_df = df[df["subhalo_id"].isin(train_ids)].reset_index(drop=True)
val_df   = df[df["subhalo_id"].isin(val_ids)].reset_index(drop=True)

print(f"Train BHs: {len(train_ids)}, Val BHs: {len(val_ids)}")
print(f"Train rows: {len(train_df)}, Val rows: {len(val_df)}")


Train BHs: 2000, Val BHs: 500
Train rows: 29931, Val rows: 7481


In [7]:
from torch.utils.data import DataLoader


assert 'BHTimeSeriesDataset' in globals(), "Run Cell 2 first to define BHTimeSeriesDataset."

dtr = DataLoader(
    BHTimeSeriesDataset(train_df, FEATURES, CFG["HORIZONS"], STATS, CFG["T_IN"]),
    batch_size=CFG["BATCH"], shuffle=True, drop_last=False
)
dva = DataLoader(
    BHTimeSeriesDataset(val_df, FEATURES, CFG["HORIZONS"], STATS, CFG["T_IN"]),
    batch_size=CFG["BATCH"], shuffle=False, drop_last=False
)


def run_epoch(loader, train_mode=True):
    model.train() if train_mode else model.eval()
    total, n = 0.0, 0
    with torch.set_grad_enabled(train_mode):
        for xb, yb in loader:
            xb = xb.to(DEVICE)            
            yb = yb.to(DEVICE)            
            if train_mode:
                opt.zero_grad(set_to_none=True)
            yp = model(xb)                
            loss = nn.functional.mse_loss(yp, yb)
            if train_mode:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), CFG["GRAD_CLIP"])
                opt.step()
            bs = xb.size(0)
            total += loss.item() * bs
            n += bs
    return total / max(n, 1)

opt   = torch.optim.AdamW(model.parameters(), lr=CFG["LR"], weight_decay=CFG["WD"])
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=4, verbose=False)

best_val   = float("inf")
best_path  = MODEL_DIR / "lstm_blackhole_best.pt"
history    = {"epoch": [], "train": [], "val": [], "lr": []}
pat_count  = 0

for epoch in range(1, CFG["EPOCHS"] + 1):
    tr_loss = run_epoch(dtr, train_mode=True)
    va_loss = run_epoch(dva, train_mode=False)
    sched.step(va_loss)

    
    improved = va_loss < best_val - 1e-6
    if improved:
        best_val = va_loss
        pat_count = 0
        torch.save(
            {"model": model.state_dict(),
             "config": {"T_IN": CFG["T_IN"], "HORIZONS": CFG["HORIZONS"], "FEATURES": FEATURES}},
            best_path
        )
    else:
        pat_count += 1

    history["epoch"].append(epoch)
    history["train"].append(tr_loss)
    history["val"].append(va_loss)
    history["lr"].append(opt.param_groups[0]["lr"])

    print(f"Epoch {epoch:03d} | Train {tr_loss:.5f} | Val {va_loss:.5f} | LR {opt.param_groups[0]['lr']:.2e}"
          + ("  [*]" if improved else ""))

    if pat_count >= CFG["PATIENCE"]:
        print(f"Early stopping at epoch {epoch} (no val improvement for {CFG['PATIENCE']} epochs).")
        break


hist_df = pd.DataFrame(history)
hist_df.to_csv(TABLE_DIR / "training_history.csv", index=False)

fig, ax = plt.subplots(figsize=(7.5,5.2))
ax.plot(hist_df["epoch"], hist_df["train"], label="Train")
ax.plot(hist_df["epoch"], hist_df["val"],   label="Validation")
ax.set_xlabel("Epoch"); ax.set_ylabel("MSE Loss")
ax.set_title("LSTM Convergence")
ax.legend()
plt.tight_layout()
fig.savefig(FIG_DIR / "lstm_convergence.png", dpi=300, bbox_inches="tight")
plt.close(fig)

best_path.as_posix(), best_val


Epoch 001 | Train 1.18640 | Val 0.51763 | LR 1.00e-03  [*]
Epoch 002 | Train 0.73178 | Val 0.41406 | LR 1.00e-03  [*]
Epoch 003 | Train 0.57950 | Val 0.38757 | LR 1.00e-03  [*]
Epoch 004 | Train 0.51872 | Val 0.33042 | LR 1.00e-03  [*]
Epoch 005 | Train 0.49525 | Val 0.31351 | LR 1.00e-03  [*]
Epoch 006 | Train 0.46796 | Val 0.30968 | LR 1.00e-03  [*]
Epoch 007 | Train 0.44961 | Val 0.31010 | LR 1.00e-03
Epoch 008 | Train 0.42623 | Val 0.27826 | LR 1.00e-03  [*]
Epoch 009 | Train 0.41541 | Val 0.29092 | LR 1.00e-03
Epoch 010 | Train 0.40129 | Val 0.30047 | LR 1.00e-03
Epoch 011 | Train 0.39719 | Val 0.29348 | LR 1.00e-03
Epoch 012 | Train 0.38667 | Val 0.30301 | LR 1.00e-03
Epoch 013 | Train 0.38329 | Val 0.29396 | LR 5.00e-04
Epoch 014 | Train 0.36608 | Val 0.27748 | LR 5.00e-04  [*]
Epoch 015 | Train 0.36115 | Val 0.27544 | LR 5.00e-04  [*]
Epoch 016 | Train 0.35560 | Val 0.27984 | LR 5.00e-04
Epoch 017 | Train 0.35816 | Val 0.29175 | LR 5.00e-04
Epoch 018 | Train 0.34696 | Val 0.275

('../models/lstm_blackhole_best.pt', 0.275090250948156)

In [8]:
test_df = pd.read_parquet(TEST_FP)


assert 'BHTimeSeriesDataset' in globals(), "Run Cell 2 first to define BHTimeSeriesDataset."
dte = DataLoader(
    BHTimeSeriesDataset(test_df, FEATURES, CFG["HORIZONS"], STATS, CFG["T_IN"]),
    batch_size=CFG["BATCH"], shuffle=False, drop_last=False
)


ckpt = torch.load(MODEL_DIR/"lstm_blackhole_best.pt", map_location=DEVICE)
model.load_state_dict(ckpt["model"])
model.eval()


sums = np.zeros((len(CFG["HORIZONS"]), len(FEATURES)), dtype=np.float64)
counts = np.zeros_like(sums)

with torch.no_grad():
    for xb, yb in dte:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)     
        yp = model(xb)                            
        err2 = (yp - yb).pow(2).cpu().numpy()     
        sums  += err2.sum(axis=0)
        counts+= err2.shape[0]

rmse = np.sqrt(sums / np.maximum(counts,1))
rmse_df = pd.DataFrame(rmse, columns=FEATURES, index=[f"H{h}" for h in CFG["HORIZONS"]])
rmse_df.to_csv(TABLE_DIR/"rmse_test_lstm.csv")
rmse_df


Unnamed: 0,bh_mass,bh_acc,stellar_mass,sfr,halo_mass,vel_disp
H1,228.613199,141.974056,0.663351,0.288987,0.824576,1.404387
H3,266.098185,96.950083,0.629554,0.089996,0.925593,1.499788
H5,290.478655,97.48536,0.686128,0.060783,1.090862,1.546049


In [9]:
def collect_preds(loader):
    Y_list, P_list = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE); yb = yb.to(DEVICE)
            yp = model(xb)
            Y_list.append(yb.cpu().numpy())
            P_list.append(yp.cpu().numpy())
    Y = np.concatenate(Y_list, axis=0)  
    P = np.concatenate(P_list, axis=0)
    return Y, P

Yt, Pt = collect_preds(dte)
h1 = 0


for j, feat in enumerate(FEATURES):
    y = Yt[:,h1,j].ravel()
    p = Pt[:,h1,j].ravel()
    lims = np.percentile(np.concatenate([y,p]), [1,99])
    fig, ax = plt.subplots(figsize=(5.8,5.2))
    ax.scatter(y, p, s=6, alpha=0.3)
    ax.plot(lims, lims, lw=2)
    ax.set_xlabel(f"True {NAME_MAP[feat]} (z-scored)")
    ax.set_ylabel(f"Predicted {NAME_MAP[feat]} (z-scored)")
    ax.set_title(f"Parity (H=1) — {NAME_MAP[feat]}")
    plt.tight_layout()
    fig.savefig(FIG_DIR/f"parity_H1_{feat}.png", dpi=300, bbox_inches="tight")
    plt.close(fig)


rmse_overall = [float(np.sqrt(((Yt[:,i,:]-Pt[:,i,:])**2).mean())) for i,_ in enumerate(CFG["HORIZONS"])]
fig, ax = plt.subplots(figsize=(6.6,4.8))
ax.plot(CFG["HORIZONS"], rmse_overall, marker="o")
ax.set_xlabel("Forecast Horizon (Δ snapshots)")
ax.set_ylabel("RMSE (z-scored, mean over features)")
ax.set_title("Error vs Horizon — LSTM")
plt.tight_layout()
fig.savefig(FIG_DIR/"rmse_vs_horizon_lstm.png", dpi=300, bbox_inches="tight")
plt.close(fig)

FIG_DIR, TABLE_DIR


(PosixPath('../reports/figures'), PosixPath('../reports/tables'))

In [10]:
def baseline_persistence(X, Y):
    last = X[:,-1:,:]                             
    P = np.repeat(last, repeats=len(CFG["HORIZONS"]), axis=1)
    rmse = np.sqrt(((P - Y)**2).mean(axis=0))     
    return rmse

def baseline_ridge(X, Y, lam=1e-2):
    B, T, F = X.shape
    H = Y.shape[1]
    Phi = X.reshape(B, T*F)
    Phi_ = np.concatenate([Phi, np.ones((B,1),dtype=np.float32)], axis=1)  # bias
    rmse = np.zeros((H,F), dtype=np.float32)
    A = Phi_.T @ Phi_ + lam*np.eye(Phi_.shape[1], dtype=np.float32)
    A_inv = np.linalg.pinv(A)
    for h in range(H):
        Yh = Y[:,h,:]                              
        W = np.zeros((F, Phi_.shape[1]), dtype=np.float32)
        for j in range(F):
            b = Phi_.T @ Yh[:,j:j+1]
            w = (A_inv @ b).ravel()
            W[j,:] = w
        P = (Phi_ @ W.T)                           
        rmse[h,:] = np.sqrt(((P - Yh)**2).mean(axis=0))
    return rmse


Xt_list, Yt_list = [], []
for xb, yb in dte:
    Xt_list.append(xb.numpy()); Yt_list.append(yb.numpy())
Xt = np.concatenate(Xt_list, axis=0)
Yt_np = np.concatenate(Yt_list, axis=0)

rmse_lstm = pd.read_csv(TABLE_DIR/"rmse_test_lstm.csv", index_col=0).values
rmse_pers = baseline_persistence(Xt, Yt_np)
rmse_ridge= baseline_ridge(Xt, Yt_np, lam=1e-2)

pd.DataFrame(rmse_pers, columns=FEATURES, index=[f"H{h}" for h in CFG["HORIZONS"]]).to_csv(TABLE_DIR/"rmse_test_persistence.csv")
pd.DataFrame(rmse_ridge, columns=FEATURES, index=[f"H{h}" for h in CFG["HORIZONS"]]).to_csv(TABLE_DIR/"rmse_test_ridge.csv")


mean_lstm  = rmse_lstm.mean(axis=1)
mean_pers  = rmse_pers.mean(axis=1)
mean_ridge = rmse_ridge.mean(axis=1)

fig, ax = plt.subplots(figsize=(6.6,4.8))
x = np.arange(len(CFG["HORIZONS"]))
ax.plot(x, mean_lstm,  marker="o", label="LSTM")
ax.plot(x, mean_pers,  marker="o", label="Persistence")
ax.plot(x, mean_ridge, marker="o", label="Ridge")
ax.set_xticks(x); ax.set_xticklabels([f"H{h}" for h in CFG["HORIZONS"]])
ax.set_ylabel("RMSE (z-scored, mean over features)")
ax.set_title("Model Comparison vs Horizon")
ax.legend()
plt.tight_layout()
fig.savefig(FIG_DIR/"model_comparison_rmse_vs_horizon.png", dpi=300, bbox_inches="tight")
plt.close(fig)

cmp_tbl = pd.DataFrame({
    "Horizon": [f"H{h}" for h in CFG["HORIZONS"]],
    "LSTM": mean_lstm,
    "Persistence": mean_pers,
    "Ridge": mean_ridge
})
cmp_tbl.to_csv(TABLE_DIR/"rmse_overall_by_horizon.csv", index=False)

FIG_DIR, TABLE_DIR


(PosixPath('../reports/figures'), PosixPath('../reports/tables'))