In [None]:
import numpy as np
import torch
import pickle
from ID3QNE_deepQnet import Dist_DQN

# ====== CONFIG ======
# Use Apple's Metal (MPS) backend when it is available, otherwise fall back to CPU.
# torch.backends.mps.is_available() is the documented availability check
# (PyTorch >= 1.12); the newer torch.mps module is missing on older releases,
# so querying it there raises instead of returning False.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
epochs = 100  # number of training epochs; change here if needed

# ====== Load Processed Patient Data ======
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only load a trusted, locally produced requiredFile.pkl.
with open('requiredFile.pkl', 'rb') as f:
    MIMICtable = pickle.load(f)

# Extract the training split from the pickled dict and convert each entry to a
# NumPy array via .to_numpy() (presumably pandas DataFrame/Series — confirm).
X, y, Xnext, Action = (
    MIMICtable[split_key].to_numpy()
    for split_key in ('X_train', 'y_train', 'Xnext_train', 'Action_train')
)

# Quick checks: action range / distinct values and reward value counts
print("Action stats:", np.min(Action), np.max(Action), "Unique:", np.unique(Action))
print("Reward stats:", np.unique(y, return_counts=True))

# ====== Initialize Model ======
# Size the action head from the largest action index rather than the number of
# distinct actions observed: if some action never occurs in the training split,
# len(np.unique(Action)) undercounts, and any index >= that count would
# presumably be out of range for the network's output.
# NOTE(review): assumes actions are 0-based integer indices (the min/max printed
# above should confirm this) — TODO confirm against the preprocessing step.
n_actions = int(np.max(Action)) + 1
model = Dist_DQN(state_dim=X.shape[1], n_actions=n_actions)

# ====== Prepare Batch ======
# The whole training set is handed to train_model as a single dict of arrays;
# the key names are presumably the ones Dist_DQN.train_model reads — verify.
batchs = {
    'state': X,
    'next_state': Xnext,
    'action': Action,
    'reward': y
}

# To start a fresh log on each run, uncomment the next line:
# open("training_log.txt", "w").close()

# Move the model to the selected device. The arrays in `batchs` remain NumPy
# arrays here; presumably train_model converts them to tensors — confirm.
model.to(device)

# ====== Training Loop ======
# Call train_model once per epoch (it receives the 0-based epoch index), echo
# the returned loss to stdout, and append the same value to training_log.txt.
for epoch in range(epochs):
    loss = model.train_model(batchs, epoch)
    epoch_num = epoch + 1          # 1-based number used only for display
    loss_txt = f"{loss:.4f}"       # format once, reuse for console and file

    print(f"Epoch {epoch_num}/{epochs} | Loss: {loss_txt}")

    # Reopened in append mode each epoch so every line is flushed when the
    # file is closed, rather than only at the end of training.
    with open("training_log.txt", "a") as log_file:
        log_file.write(f"Epoch {epoch_num} | Loss: {loss_txt}\n")

print("Training complete.")
