# In [None]:
# analyze_agent_policy_combined.py
import pickle, numpy as np, torch, matplotlib.pyplot as plt
from matplotlib import cm, colors
from mpl_toolkits.mplot3d import Axes3D  # noqa
from ID3QNE_deepQnet import Dist_DQN

# Load the pickled experiment bundle.
# NOTE: unpickling can execute arbitrary code -- only use files you trust.
with open("requiredFile.pkl", "rb") as f:
    D = pickle.load(f)

# Feature matrix of the test split; its column count is the state dimension.
X_test = D["X_test"].to_numpy()
state_dim = X_test.shape[1]
n_actions = D.get("nbins", 25)  # falls back to 25 when "nbins" is absent

# Rebuild the network, restore the saved weights on CPU, and switch the
# Q-network to evaluation mode.
model = Dist_DQN(state_dim=state_dim, n_actions=n_actions)
saved_weights = torch.load("model.pt", map_location="cpu")
model.q_net.load_state_dict(saved_weights)
model.q_net.eval()

# Greedy policy: for every test row take the index of the largest Q-value.
# Gradients are not needed for inference, so they are disabled.
with torch.no_grad():
    test_states = torch.tensor(X_test, dtype=torch.float32)
    q_vals = model.q_net(test_states)
    actions = q_vals.argmax(1).cpu().numpy()

# Decode each flat action index into its two 5-level components:
# action = 5 * vaso + fluid.
fluid = actions % 5
vaso = actions // 5
N = len(actions)

# Marginal share of test rows that fall into each bin. max(N, 1) keeps the
# division well-defined when there are no rows at all.
fluid_prop = np.bincount(fluid, minlength=5).astype(float) / max(N, 1)
vaso_prop = np.bincount(vaso, minlength=5).astype(float) / max(N, 1)

# Joint counts with rows = vasopressor bin and columns = fluid bin.
# np.add.at accumulates once per occurrence of an index pair (a plain
# `heat[vaso, fluid] += 1` would count repeated pairs only once), so this
# matches an element-by-element loop without iterating in Python and without
# rebinding unrelated module-level names.
heat = np.zeros((5, 5), float)
np.add.at(heat, (vaso, fluid), 1.0)
heat_prop = heat / max(N, 1)

# Discard any figures that are still open, then create the 1x3 canvas.
plt.close("all")
fig = plt.figure(figsize=(15, 4.8), constrained_layout=True)


def _draw_bin_proportions(ax, proportions, tick_prefix, xlabel, title):
    """Draw one five-bar proportion chart on *ax*.

    Tick labels are ``tick_prefix`` followed by the bin index; the y-axis is
    fixed to the range 0..1.
    """
    bins = range(5)
    ax.bar(bins, proportions)
    ax.set_xticks(bins)
    ax.set_xticklabels([f"{tick_prefix}{i}" for i in bins])
    ax.set_ylim(0, 1)
    ax.set_xlabel(xlabel, fontsize=9)
    ax.set_ylabel("Proportion", fontsize=9)
    ax.tick_params(labelsize=9)
    ax.set_title(title, pad=10, fontsize=11)
    ax.grid(True, axis="y", alpha=0.3)


# Left panel: distribution over fluid bins.
ax1 = fig.add_subplot(1, 3, 1)
_draw_bin_proportions(
    ax1,
    fluid_prop,
    "F",
    "Fluid Bin (0..4)",
    "Agent Policy\nFluid Distribution (Test, Proportion)",
)

# Middle panel: distribution over vasopressor bins.
ax2 = fig.add_subplot(1, 3, 2)
_draw_bin_proportions(
    ax2,
    vaso_prop,
    "V",
    "Vasopressor Bin (0..4)",
    "Agent Policy\nVasopressor Distribution (Test, Proportion)",
)

# Right panel: 3-D bar chart of the joint (vasopressor, fluid) proportions.
ax3 = fig.add_subplot(1, 3, 3, projection="3d")

# Bar anchor grid. With meshgrid's default "xy" indexing, ravel() yields
# x = column index and y = row index, i.e. the same element order as
# heat_prop.ravel() (rows = vasopressor bin, columns = fluid bin).
_x = np.arange(5)
_y = np.arange(5)
_xx, _yy = np.meshgrid(_x, _y)
x = _xx.ravel()
y = _yy.ravel()
z = np.zeros_like(x, dtype=float)  # every bar starts at height 0
dx = np.full_like(x, 0.6, float)   # bar footprint is 0.6 x 0.6
dy = np.full_like(y, 0.6, float)
dz = heat_prop.ravel()             # bar heights

# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is the lookup that works across versions.
cmap = plt.get_cmap("coolwarm")

# Colour bars by height. If every proportion is 0, use 1.0 as the upper bound
# so the colour scale does not collapse to a zero-width range.
tallest = dz.max()
norm = colors.Normalize(vmin=dz.min(), vmax=tallest if tallest > 0 else 1.0)
bar_colors = cmap(norm(dz))

# Shift anchors by half the bar width so bars are centred on the integer ticks.
ax3.bar3d(x - 0.3, y - 0.3, z, dx, dy, dz, color=bar_colors, shade=True)
ax3.set_xticks(range(5))
ax3.set_xticklabels([f"F{i}" for i in range(5)], fontsize=8)
ax3.set_yticks(range(5))
ax3.set_yticklabels([f"V{i}" for i in range(5)], fontsize=8)
ax3.set_zlabel("Proportion", fontsize=9)
ax3.set_xlabel("Fluid Bin", fontsize=9)
ax3.set_ylabel("Vasopressor Bin", fontsize=9)
ax3.set_title("Agent Policy\n5×5 Action Heatmap (Test, Proportion)", pad=10, fontsize=11)

# bar3d returns no mappable usable by colorbar, so build a stand-alone one
# that shares the same norm and colormap.
mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
mappable.set_array([])
cb = fig.colorbar(mappable, ax=ax3, fraction=0.046, pad=0.08)
cb.set_label("Proportion", fontsize=9)

# Write the combined figure to disk, display it, then report the file name.
output_png = "policy_combined_like_paper_proportional.png"
fig.savefig(output_png, dpi=220)
plt.show()
print(f"Saved {output_png}")
