In [1]:
import numpy as np
from tqdm import tqdm
import gymnasium as gym
from gymnasium.wrappers import FlattenObservation, RecordEpisodeStatistics
import torch
from torch.utils.tensorboard import SummaryWriter
from matplotlib import pyplot as plt
import matplotlib

from clinic_environment import ClinicEnv
from clinic_agent import ClinicDQNAgent, ReplayMemory, Transition

In [2]:
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

<contextlib.ExitStack at 0x7c22d7af0170>

In [3]:
learning_rate = 1e-4
n_episodes = 40_000
start_epsilon = 1.0
epsilon_decay = start_epsilon / (n_episodes / 2)  # reduce the exploration over tim
final_epsilon = 0.1

In [4]:
writer = SummaryWriter()

In [5]:
clinic_capacity = np.array([1, 2])
clinic_travel_times = np.array([[0, 10], [10, 0]])
patient_times = np.array([30, 40, 50])
num_nurses = 3

clinic_env = ClinicEnv(clinic_capacity, clinic_travel_times, patient_times, num_nurses)
# clinic_env = RecordEpisodeStatistics(FlattenObservation(unwrapped_clinic_env))
# clinic_env.get_valid_actions = unwrapped_clinic_env.get_valid_actions

In [6]:
agent = ClinicDQNAgent(
    clinic_env, 
    learning_rate=learning_rate, 
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    final_epsilon=final_epsilon,
    n_iter=n_episodes,
    batch_size=256,
    device="cuda",
    writer=writer,
)

In [7]:
def play_episode(env, agent, randomize: bool = True, update_model: bool = True):
    obs, info = env.reset()
    done = False

    total_reward = 0
    while not done:
        action = agent.get_action(obs, randomize=randomize)
        next_obs, reward, terminated, truncated, info = env.step(action.item())

        if update_model:
            agent.update(obs, action, reward, terminated, next_obs)
        else:
            print(action.item())

        done = terminated or truncated
        obs = next_obs
        total_reward += reward

    return total_reward

In [None]:
for i in tqdm(range(n_episodes)):
    total_reward = play_episode(clinic_env, agent)
    writer.add_scalar("total reward", total_reward, i)

    agent.decay_epsilon()
    # agent.update_lr()

    # writer.add_scalar("Learning rate", agent.scheduler.get_last_lr()[0], i)

  7%|██████████▋                                                                                                                                      | 2947/40000 [01:17<40:20, 15.31it/s]

In [9]:
obs, info = clinic_env.reset()

obs_arr = gym.spaces.utils.flatten(clinic_env.nonterminal_normalized_observation_space, clinic_env.normalize_state(obs))
obs_tensor = torch.tensor(obs_arr, dtype=torch.float32, device="cuda").unsqueeze(0)
agent.policy_net(obs_tensor)

tensor([[0.0203, 0.0300, 0.0089, 0.0069, 0.0222, 0.0128]], device='cuda:0',
       grad_fn=<AddBackward0>)

In [10]:
obs, info = clinic_env.reset()

action = agent.get_action(obs, randomize=False)
print((action, obs))
obs, _, _, _, _ = clinic_env.step(action.item())

action = agent.get_action(obs, randomize=False)
print((action, obs))
obs, _, _, _, _ = clinic_env.step(action.item())

action = agent.get_action(obs, randomize=False)
print((action, obs))
obs, _, _, _, _ = clinic_env.step(action.item())


(tensor([[1]], device='cuda:0'), (0, {'nurse_turn': 0, 'nurses': ({'location': 1, 'operating_minutes_left': 0.0, 'traveling_minutes_left': 0.0, 'status': <NurseStatus.IDLE: 1>}, {'location': 1, 'operating_minutes_left': 0.0, 'traveling_minutes_left': 0.0, 'status': <NurseStatus.IDLE: 1>}, {'location': 1, 'operating_minutes_left': 0.0, 'traveling_minutes_left': 0.0, 'status': <NurseStatus.IDLE: 1>}), 'patients': ({'status': 1, 'treatment_time': 30.0, 'minutes_in_treatment': 0.0, 'treated_at': 0}, {'status': 1, 'treatment_time': 40.0, 'minutes_in_treatment': 0.0, 'treated_at': 0}, {'status': 1, 'treatment_time': 50.0, 'minutes_in_treatment': 0.0, 'treated_at': 0}), 'clinics': ({'capacity': 1.0, 'num_patients': 0.0}, {'capacity': 2.0, 'num_patients': 0.0})}))
(tensor([[4]], device='cuda:0'), (0, {'nurse_turn': 1, 'nurses': ({'location': 1, 'operating_minutes_left': 15.0, 'traveling_minutes_left': 0.0, 'status': <NurseStatus.IN_OPERATION: 2>}, {'location': 1, 'operating_minutes_left': 0.0,

In [55]:
1.0 / play_episode(clinic_env, agent, randomize=False, update_model=False)

1
4
5
0
0
0
0
0
4
1
1
0
0
0
4
0
0
4
4
2
4
5
0
0
0
0
4
3
5
0
0
0
0
0
5
2
0
5
0
4
0
0
0
0
0
0
0
0
0
3
0
0
0
0
4
0
0


19.0

In [16]:
torch.save(agent.policy_net.state_dict(), "policy_net.pt")
torch.save(agent.target_net.state_dict(), "target_net.pt")