# In [None]:  (Jupyter cell marker left over from the notebook export)
# train_runner_multi_my.py
# Multi-experiment runner (45 features by default).
# Requires: requiredFile_{FEATURE_SET}.pkl generated by prepare_data_my_approach.py

import os, pickle, random
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Which prepared data file to use; set to 43 or 45.
FEATURE_SET = 45
# Number of training passes per experiment.
EPOCHS = 100
# Discretisation of the action space: fluid bins x vasopressor bins.
N_FLUID = 5
N_VASO = 5
N_ACTIONS = N_FLUID * N_VASO

# Root folder for every artefact written by this script (created on import).
OUTDIR = "runs"
os.makedirs(OUTDIR, exist_ok=True)

# One entry per experiment: a unique name plus the hyper-parameters to try.
EXPERIMENTS = [
    {"name": f"exp1_baseline_{FEATURE_SET}", "gamma": 0.99, "lr": 1e-5, "seed": 42},
    {"name": f"exp2_gamma095_{FEATURE_SET}", "gamma": 0.95, "lr": 1e-5, "seed": 42},
    {"name": f"exp3_higher_lr_{FEATURE_SET}", "gamma": 0.99, "lr": 5e-5, "seed": 42},
    {"name": f"exp4_seed7_{FEATURE_SET}", "gamma": 0.99, "lr": 1e-5, "seed": 7},
]

from ID3QNE_deepQnet import Dist_DQN

def set_seed(seed):
    """Seed the Python, NumPy and (when importable) PyTorch random generators.

    Args:
        seed: Seed value passed unchanged to every generator.

    PyTorch is treated as optional: if it cannot be imported, only ``random``
    and NumPy are seeded. Errors raised by the seeding calls themselves are
    not suppressed, so a bad seed is reported instead of silently ignored.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Keep the try body to the import alone so that only "torch is not
    # installed" is tolerated, not failures inside the seeding calls.
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def load_data(feature_set: int):
    """Load the prepared train/test arrays for one feature set.

    Reads ``requiredFile_{feature_set}.pkl`` (path relative to the current
    working directory) and re-keys its contents under the short names used by
    the rest of this script.

    Args:
        feature_set: Number used to select the pickle file name.

    Returns:
        dict with keys ``Xtr``, ``Atr``, ``Rtr``, ``Xntr``, ``Dtr`` (cast to
        int32), ``Xte``, and the optional ``Survival_test`` /
        ``ExpectedReturn_test`` (``None`` when absent from the pickle).

    Raises:
        FileNotFoundError: The pickle file does not exist.
        KeyError: The pickle lacks one of the required entries.
    """
    pkl = f"requiredFile_{feature_set}.pkl"
    if not os.path.exists(pkl):
        raise FileNotFoundError(
            f"{pkl} not found. Run: python3 prepare_data_my_approach.py --csv finalData.csv --vaso vaso.csv --feature-set {feature_set}"
        )
    # NOTE(security): unpickling can execute arbitrary code. Only load files
    # produced by your own data-preparation step, never untrusted ones.
    with open(pkl, "rb") as f:
        obj = pickle.load(f)

    required = ("Xtrain", "Actiontrain", "Rewardtrain", "Xnext_train", "Done_train", "Xtest")
    missing = [key for key in required if key not in obj]
    if missing:
        # Still a KeyError, but one that names the file and every absent key.
        raise KeyError(f"{pkl} is missing required keys: {missing}")

    done = obj["Done_train"]
    if not hasattr(done, "astype"):
        # Accept plain sequences (e.g. a list of bools) as well as arrays.
        done = np.asarray(done)

    return dict(
        Xtr=obj["Xtrain"],
        Atr=obj["Actiontrain"],
        Rtr=obj["Rewardtrain"],
        Xntr=obj["Xnext_train"],
        Dtr=done.astype(np.int32),
        Xte=obj["Xtest"],
        Survival_test=obj.get("Survival_test", None),
        ExpectedReturn_test=obj.get("ExpectedReturn_test", None),
    )

def model_train_adapter(model, S, A, R, Snext, Done, epoch):
    """Run one training call on ``model`` and return its loss as a float.

    The inputs are converted to NumPy arrays and packed into the 8-tuple
    ``(S, Snext, A, A_next, R, Done, bloc_num, SOFAS)`` handed to
    ``model.train``.

    Args:
        model: Object exposing ``train(batch, epoch)`` or ``train(batch)``.
        S, Snext: States and next states (converted to float arrays).
        A: Actions (converted to int64).
        R: Rewards (converted to float).
        Done: Terminal flags (converted to int64).
        epoch: Epoch index forwarded to ``model.train`` when it is accepted.

    Returns:
        ``float(loss)`` where ``loss`` is whatever ``model.train`` returned.

    Raises:
        TypeError: If the returned loss cannot be converted to ``float``.
    """
    S = np.asarray(S, dtype=float)
    Snext = np.asarray(Snext, dtype=float)
    A = np.asarray(A, dtype=np.int64)
    R = np.asarray(R, dtype=float)
    Done = np.asarray(Done, dtype=np.int64)
    # Placeholders so the tuple has the eight entries the trainer unpacks:
    # next action mirrors the taken action, the other two are all zeros.
    # NOTE(review): presumably the real trainer uses bloc_num/SOFAS - confirm
    # that zero-filled values are acceptable for it.
    A_next = A.copy()
    bloc_num = np.zeros_like(A, dtype=np.int64)
    SOFAS = np.zeros_like(A, dtype=np.float32)
    batch8 = (S, Snext, A, A_next, R, Done, bloc_num, SOFAS)
    try:
        loss = model.train(batch8, epoch)
    except TypeError:
        # Trainer variant that does not take an epoch argument.
        loss = model.train(batch8)
    # Convert outside the try block: a loss that float() rejects must not be
    # mistaken for a signature mismatch and trigger a second training call.
    return float(loss)

def predict_actions(model, X):
    """Query ``model.get_action`` for every row of ``X``.

    Args:
        model: Object exposing ``get_action(row)``.
        X: Indexable with a ``shape`` attribute; rows are read as ``X[i]``.

    Returns:
        1-D int64 array holding one action per row, in row order.
    """
    n_rows = X.shape[0]
    # Rows are evaluated one at a time, in order, exactly once each.
    chosen = (int(model.get_action(X[row])) for row in range(n_rows))
    return np.fromiter(chosen, dtype=np.int64, count=n_rows)

def action_to_bins(action, n_fluid=5):
    """Split a flat action index into its ``(fluid, vaso)`` bin pair.

    The fluid bin is the remainder of ``action`` by ``n_fluid`` and the
    vasopressor bin is the floor quotient; both are returned as ``int``.
    """
    vaso_bin, fluid_bin = divmod(action, n_fluid)
    return int(fluid_bin), int(vaso_bin)

def plot_training_loss(exp_dir, exp_name, losses):
    """Save a line chart of the per-epoch training loss.

    Args:
        exp_dir: Directory the PNG is written into.
        exp_name: Experiment name, used in the title and the file name.
        losses: Sequence of loss values, one per epoch (x axis starts at 1).

    Returns:
        Path of the written ``{exp_name}_training_loss.png`` file.
    """
    out_path = os.path.join(exp_dir, f"{exp_name}_training_loss.png")
    epochs = range(1, len(losses) + 1)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(epochs, losses)
    ax.set_title(f"{exp_name} – Training Loss Over Time")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    fig.tight_layout()

    fig.savefig(out_path, dpi=160)
    plt.close(fig)
    return out_path

def plot_train_action_distribution(exp_dir, exp_name, Atr):
    """Save a bar chart counting how often each action occurs in ``Atr``.

    Args:
        exp_dir: Directory the PNG is written into.
        exp_name: Experiment name, used in the title and the file name.
        Atr: Training actions; compared element-wise against each action id.

    Returns:
        Path of the written ``{exp_name}_train_action_dist.png`` file.
    """
    out_path = os.path.join(exp_dir, f"{exp_name}_train_action_dist.png")

    # One tally per action id 0..N_ACTIONS-1; ids never seen stay at zero.
    counts = np.zeros(N_ACTIONS, dtype=float)
    for action_id in range(N_ACTIONS):
        counts[action_id] = np.sum(Atr == action_id)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(np.arange(N_ACTIONS), counts)
    ax.set_title(f"{exp_name} – Training Action Distribution (Clinician)")
    ax.set_xlabel("Action (0..24)")
    ax.set_ylabel("Count")
    fig.tight_layout()

    fig.savefig(out_path, dpi=160)
    plt.close(fig)
    return out_path

def plot_policy_combined(exp_dir, exp_name, actions_pred):
    """Save a three-panel summary of the actions chosen by the agent.

    Panels: fluid-bin proportions, vasopressor-bin proportions, and a
    vasopressor-by-fluid heatmap of joint proportions.

    Args:
        exp_dir: Directory the PNG is written into.
        exp_name: Experiment name, used in the titles and the file name.
        actions_pred: Iterable of flat action indices.

    Returns:
        Path of the written ``{exp_name}_policy_combined.png`` file.
    """
    # Decode each action once, and with the module-level N_FLUID so that the
    # decoding always agrees with the histogram and heatmap sizes below
    # (previously the helper's own default was used, and it was called twice).
    decoded = [action_to_bins(a, N_FLUID) for a in actions_pred]
    fb = np.array([fluid for fluid, _ in decoded], dtype=int)
    vb = np.array([vaso for _, vaso in decoded], dtype=int)

    def prop(bins, k):
        """Return the share of entries falling in each of the k bins."""
        c = np.array([np.sum(bins == i) for i in range(k)], dtype=float)
        total = c.sum()
        # With no actions at all, keep the all-zero vector (avoid 0/0).
        return c / total if total > 0 else c

    fluid_prop = prop(fb, N_FLUID)
    vaso_prop = prop(vb, N_VASO)

    # Joint counts, rows = vasopressor bin, columns = fluid bin.
    heat = np.zeros((N_VASO, N_FLUID), dtype=float)
    for f, v in zip(fb, vb):
        heat[v, f] += 1.0
    if heat.sum() > 0:
        heat /= heat.sum()

    fig = plt.figure(figsize=(18, 5.5))

    ax1 = fig.add_subplot(1, 3, 1)
    ax1.bar(np.arange(N_FLUID), fluid_prop)
    ax1.set_ylim(0, 1)
    ax1.set_title(f"{exp_name}\nFluid Distribution (Agent on Test)")
    ax1.set_xlabel("Fluid Bin (0..4)")
    ax1.set_ylabel("Proportion")

    ax2 = fig.add_subplot(1, 3, 2)
    ax2.bar(np.arange(N_VASO), vaso_prop)
    ax2.set_ylim(0, 1)
    ax2.set_title(f"{exp_name}\nVasopressor Distribution (Agent on Test)")
    ax2.set_xlabel("Vasopressor Bin (0..4)")
    ax2.set_ylabel("Proportion")

    ax3 = fig.add_subplot(1, 3, 3)
    # vmax is floored at a tiny positive value so an all-zero heatmap still
    # gets a valid colour range.
    im = ax3.imshow(heat, origin="lower", cmap="coolwarm", vmin=0, vmax=max(1e-12, heat.max()))
    ax3.set_xticks(range(N_FLUID))
    ax3.set_yticks(range(N_VASO))
    ax3.set_xlabel("Fluid Bin")
    ax3.set_ylabel("Vasopressor Bin")
    ax3.set_title(f"{exp_name}\n5×5 Action Heatmap (Agent on Test)")
    cb = fig.colorbar(im, ax=ax3, fraction=0.046, pad=0.04)
    cb.set_label("Proportion")

    fig.tight_layout()
    fp = os.path.join(exp_dir, f"{exp_name}_policy_combined.png")
    fig.savefig(fp, dpi=160)
    plt.close(fig)
    return fp

def plot_survival_by_agent_action(exp_dir, exp_name, actions_pred, survival_flags):
    """Save a bar chart of the mean survival flag per agent-chosen action.

    Args:
        exp_dir: Directory the PNG is written into.
        exp_name: Experiment name, used in the title and the file name.
        actions_pred: Agent actions; flattened to 1-D.
        survival_flags: Survival indicators aligned with ``actions_pred``
            (flattened to 1-D), or ``None``.

    Returns:
        Path of the written ``{exp_name}_survival_by_action.png`` file, or
        ``None`` when no flags are given or the two inputs differ in length.
    """
    if survival_flags is None:
        return None

    actions = np.asarray(actions_pred).reshape(-1)
    survived = np.asarray(survival_flags).reshape(-1)
    if actions.shape[0] != survived.shape[0]:
        return None

    # Mean flag per action; actions the agent never chose stay NaN here.
    rates = np.full(N_ACTIONS, np.nan)
    for action_id in range(N_ACTIONS):
        chosen = actions == action_id
        if chosen.sum() > 0:
            rates[action_id] = np.mean(survived[chosen])

    out_path = os.path.join(exp_dir, f"{exp_name}_survival_by_action.png")
    fig, ax = plt.subplots(figsize=(12, 6))
    # NaN rates are drawn as zero-height bars.
    ax.bar(np.arange(N_ACTIONS), np.nan_to_num(rates, nan=0.0))
    ax.set_ylim(0, 1)
    ax.set_title(f"{exp_name} – Survival by Agent Action (Test)")
    ax.set_xlabel("Agent Action (0..24)")
    ax.set_ylabel("Survival Rate")
    fig.tight_layout()
    fig.savefig(out_path, dpi=160)
    plt.close(fig)
    return out_path

def rowwise_expected_return(R):
    """Return the NaN-ignoring mean of the rewards as a plain float.

    If the largest absolute reward is at least 20, the rewards are first
    rescaled with ``(R + 24) / 48`` (mapping -24..+24 onto 0..1) and the mean
    of the rescaled values is returned; otherwise the raw mean is returned.

    Args:
        R: Array-like of rewards; flattened to 1-D.

    Returns:
        The mean as ``float``; NaN when ``R`` is empty.
    """
    R = np.asarray(R).reshape(-1)
    if R.size == 0:
        # np.nanmax raises ValueError on an empty array; report "no data".
        return float("nan")
    if np.nanmax(np.abs(R)) >= 20:   # account for +/-24 terminal scaling
        Rn = (R + 24.0) / 48.0
        return float(np.nanmean(Rn))
    return float(np.nanmean(R))

def compute_survival(flags):
    """Return the NaN-ignoring mean of ``flags``, or NaN when ``flags`` is None.

    ``flags`` may have any shape; it is flattened before averaging.
    """
    if flags is not None:
        flat = np.asarray(flags).ravel()
        return float(np.nanmean(flat))
    return np.nan

def main():
    """Train every configured experiment, write its plots and a summary table.

    For each entry of ``EXPERIMENTS`` this seeds the RNGs, builds a
    ``Dist_DQN``, trains it for ``EPOCHS`` calls on the full training arrays
    (logging the loss to ``{name}_training_log.txt``), and writes the loss,
    action-distribution, policy and survival plots into ``OUTDIR/{name}``.
    Afterwards a CSV of summary rows is written, plus a best-effort PNG
    rendering of the same table.
    """
    data = load_data(FEATURE_SET)
    Xtr, Atr, Rtr, Xntr, Dtr = data["Xtr"], data["Atr"], data["Rtr"], data["Xntr"], data["Dtr"]
    Xte = data["Xte"]
    survival_flags = data["Survival_test"]
    exp_return_te = data["ExpectedReturn_test"]
    state_dim = Xtr.shape[1]

    # Both summary metrics depend only on the loaded data, not on the trained
    # model, so they are computed once instead of once per experiment.
    # NOTE(review): when ExpectedReturn_test is absent the fallback averages
    # the *training* rewards, yet the column is still called exp_return_test -
    # confirm that this is intended.
    if exp_return_te is not None:
        exp_ret = float(np.nanmean(exp_return_te))
    else:
        exp_ret = rowwise_expected_return(Rtr)
    surv = compute_survival(survival_flags)

    rows = []
    for cfg in EXPERIMENTS:
        name, gamma, lr, seed = cfg["name"], cfg["gamma"], cfg["lr"], cfg["seed"]
        set_seed(seed)
        exp_dir = os.path.join(OUTDIR, name)
        os.makedirs(exp_dir, exist_ok=True)
        print(f"\n=== {name} | gamma={gamma}, lr={lr}, seed={seed}, epochs={EPOCHS} ===")

        # model: try keyword arguments first, then fall back to positional
        # ones for a constructor that uses different parameter names.
        try:
            model = Dist_DQN(state_dim=state_dim, num_actions=N_ACTIONS, gamma=gamma, lr=lr, seed=seed)
        except TypeError:
            model = Dist_DQN(state_dim, N_ACTIONS, gamma, lr, seed)

        # train: one adapter call per epoch, echoed to stdout and the log file
        losses = []
        with open(os.path.join(exp_dir, f"{name}_training_log.txt"), "w") as lg:
            for ep in range(EPOCHS):
                loss = model_train_adapter(model, Xtr, Atr, Rtr, Xntr, Dtr, ep)
                losses.append(loss)
                line = f"Epoch {ep+1} | Loss: {loss:.4f}"
                print(line)
                lg.write(line + "\n")

        # plots
        plot_training_loss(exp_dir, name, losses)
        plot_train_action_distribution(exp_dir, name, Atr)
        actions_pred = predict_actions(model, Xte)
        plot_policy_combined(exp_dir, name, actions_pred)
        plot_survival_by_agent_action(exp_dir, name, actions_pred, survival_flags)

        # summary row
        rows.append(dict(experiment=name, epochs=EPOCHS, gamma=gamma, lr=f"{lr:.1E}", seed=seed,
                         exp_return_test=exp_ret, survival_test_rowwise=surv))

    df = pd.DataFrame(rows)
    df.to_csv(os.path.join(OUTDIR, f"summary_metrics_{FEATURE_SET}.csv"), index=False)

    # small PNG snapshot: best effort, the CSV above is the real output.
    # A failure is reported rather than hidden, and the figure is always
    # closed so it cannot leak.
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(18, 4))
        ax.axis("off")
        ax.set_title(f"summary_metrics_{FEATURE_SET}", fontsize=16, pad=12)
        table = ax.table(cellText=df.values, colLabels=df.columns, cellLoc='center', loc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.6)
        fig.tight_layout()
        fig.savefig(os.path.join(OUTDIR, f"summary_metrics_{FEATURE_SET}.png"), dpi=160)
    except Exception as exc:
        print(f"Warning: could not write summary PNG ({type(exc).__name__}: {exc})")
    finally:
        if fig is not None:
            plt.close(fig)

    # Report the configured output folder rather than a hard-coded one.
    print(f"\nDone. See {OUTDIR}/summary_metrics_{FEATURE_SET}.csv and plots in {OUTDIR}/*_{FEATURE_SET}/")

# Run the experiments only when this file is executed as a script, so that
# importing it does not start a training run.
if __name__ == "__main__":
    main()
