In [2]:
import warnings
warnings.filterwarnings("ignore", message="Gym has been unmaintained since")

import gymnasium as gym
import numpy as np
from gymnasium import spaces


class HospitalEnv(gym.Env):
    def __init__(self, num_doctors=3, max_steps=100):
        super().__init__()

        self.num_doctors = num_doctors
        self.max_steps = max_steps

        # Observation space
        # [free_doctors, doctor_timers..., max_wait_red, max_wait_yellow, max_wait_green,
        #  len_red, len_yellow, len_green, total_queue_length]
        obs_size = 1 + self.num_doctors + 7
        obs_low = np.zeros(obs_size, dtype=np.float32)
        obs_high = np.array(
            [self.num_doctors] +                # free doctors
            [200] * self.num_doctors +          # doctor timers
            [200, 200, 200] +                   # max waits
            [10, 20, 30, 50],                   # queue lengths & total queue
            dtype=np.float32
        )
        self.observation_space = spaces.Box(low=obs_low, high=obs_high, dtype=np.float32)

        # Action space: RED=0, YELLOW=1, GREEN=2
        self.action_space = spaces.Discrete(3)

        # Queues + timers
        self.red_queue = []
        self.yellow_queue = []
        self.green_queue = []
        self.doctor_timers = np.zeros(self.num_doctors)
        self.current_step = 0

        # Reward coefficients
        self.C = {"red": 10.0, "yellow": 1.0, "green": 0.2}

        # Critical wait thresholds
        self.thresholds = {"red": 5, "yellow": 15, "green": 30}
        self.critical_penalty_values = {"red": 100, "yellow": 20, "green": 5}
        self.critical_penalty = {"red": 0, "yellow": 0, "green": 0}

        # Service time ranges (minutes)
        self.service_ranges = {
            "red": (8, 15),
            "yellow": (4, 8),
            "green": (2, 3)
        }

        # Arrival rates (Poisson λ per minute)
        self.arrival_lambda = {"red": 3, "yellow": 2, "green": 1}

        # Queue limit
        self.max_queue_length = 30

    # ---------------------------
    # Reset
    # ---------------------------
    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.red_queue = []
        self.yellow_queue = []
        self.green_queue = []
        self.doctor_timers[:] = 0
        self.current_step = 0
        self.critical_penalty = {"red": 0, "yellow": 0, "green": 0}
        return self._get_obs(), {}

    # ---------------------------
    # Observation
    # ---------------------------
    def _get_obs(self):
        total_queue = len(self.red_queue) + len(self.yellow_queue) + len(self.green_queue)
        obs = np.array([
            np.sum(self.doctor_timers == 0),          # free doctors
            *self.doctor_timers,                      # doctor timers
            max(self.red_queue) if self.red_queue else 0,
            max(self.yellow_queue) if self.yellow_queue else 0,
            max(self.green_queue) if self.green_queue else 0,
            len(self.red_queue),
            len(self.yellow_queue),
            len(self.green_queue),
            total_queue
        ], dtype=np.float32)
        return obs

    # ---------------------------
    # Sample service time
    # ---------------------------
    def _sample_service_time(self, action):
        queue_map = {0: "red", 1: "yellow", 2: "green"}
        low, high = self.service_ranges[queue_map[action]]
        return np.random.randint(low, high + 1)

    # ---------------------------
    # Add arrivals
    # ---------------------------
    def _add_arrivals(self, service_time):
        new_red = np.random.poisson(self.arrival_lambda["red"] * service_time)
        new_yellow = np.random.poisson(self.arrival_lambda["yellow"] * service_time)
        new_green = np.random.poisson(self.arrival_lambda["green"] * service_time)

        total_current = len(self.red_queue) + len(self.yellow_queue) + len(self.green_queue)
        total_new = new_red + new_yellow + new_green
        available_space = max(0, self.max_queue_length - total_current)

        # Scale arrivals if exceeding max queue length
        if total_new > available_space:
            factor = available_space / total_new
            new_red = int(new_red * factor)
            new_yellow = int(new_yellow * factor)
            new_green = int(new_green * factor)

        self.red_queue.extend([0] * new_red)
        self.yellow_queue.extend([0] * new_yellow)
        self.green_queue.extend([0] * new_green)

    # ---------------------------
    # Step
    # ---------------------------
    def step(self, action):
        self.current_step += 1

        # Find free doctor
        free_doctors = np.where(self.doctor_timers == 0)[0]
        if len(free_doctors) == 0:
            min_timer = min(t for t in self.doctor_timers if t > 0)
            self.doctor_timers = np.maximum(0, self.doctor_timers - min_timer)
            free_doctors = np.where(self.doctor_timers == 0)[0]

        doctor = free_doctors[0]

        queue_map = {0: self.red_queue, 1: self.yellow_queue, 2: self.green_queue}
        queue_name_map = {0: "red", 1: "yellow", 2: "green"}
        queue = queue_map[action]
        cat_name = queue_name_map[action]

        # Serve a patient
        if len(queue) == 0:
            service_time = 0
        else:
            wait_time = queue.pop(0)
            service_time = self._sample_service_time(action)
            self.doctor_timers[doctor] = service_time

            # Increment waiting time for all other patients by 1-5 minutes randomly
            self.red_queue = [w + np.random.randint(1, 6) for w in self.red_queue]
            self.yellow_queue = [w + np.random.randint(1, 6) for w in self.yellow_queue]
            self.green_queue = [w + np.random.randint(1, 6) for w in self.green_queue]

            # Reduce timers of other doctors
            for i in range(self.num_doctors):
                if i != doctor:
                    self.doctor_timers[i] = max(0, self.doctor_timers[i] - service_time)

        # Add arrivals
        self._add_arrivals(service_time)

        # Reward
        MaxWait_Red = max(self.red_queue) if self.red_queue else 0
        MaxWait_Yellow = max(self.yellow_queue) if self.yellow_queue else 0
        MaxWait_Green = max(self.green_queue) if self.green_queue else 0

        reward = 10 - (
            self.C["red"] * MaxWait_Red +
            self.C["yellow"] * MaxWait_Yellow +
            self.C["green"] * MaxWait_Green
        )

        # Critical penalties
        for q_name, q_list in [("red", self.red_queue),
                               ("yellow", self.yellow_queue),
                               ("green", self.green_queue)]:
            if any(w > self.thresholds[q_name] for w in q_list):
                self.critical_penalty[q_name] += self.critical_penalty_values[q_name]

        reward -= sum(self.critical_penalty.values())

        truncated = self.current_step >= self.max_steps
        return self._get_obs(), reward, False, truncated, {}


In [3]:
import warnings
warnings.filterwarnings("ignore", message="Gym has been unmaintained since")

import sys
sys.path.append("../env")  # to read hospital_env.py


import gymnasium as gym
import numpy as np
import torch

from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

# ------------------------------
# Set seeds for reproducibility
# ------------------------------
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Create wrapped environment ---
def make_env():
    env = HospitalEnv()
    env.reset(seed=SEED)   # seed the environment here
    env = Monitor(env)      # important for SB3 logging
    return env

env = DummyVecEnv([make_env])

# --- Create DQN agent ---
model = DQN(
    "MlpPolicy",       # Fully connected NN
    env,
    learning_rate=5e-4,
    gamma=0.95,
    batch_size=64,
    buffer_size=50000,
    exploration_initial_eps=1.0,
    exploration_final_eps=0.1,
    exploration_fraction=0.1,   # epsilon decay
    target_update_interval=1000,
    verbose=1,
    seed=SEED                  # seed SB3 agent
)

# --- Train agent ---
model.learn(total_timesteps=50000)

# --- Save trained model ---
model.save("../models/dqn_hospital_sb3")
print("Model saved to models/dqn_hospital_sb3.zip")

# --- Evaluation ---
eval_env = DummyVecEnv([make_env])
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=50)
print(f"Mean reward: {mean_reward:.2f} ± {std_reward:.2f}")

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


Using cpu device
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 100      |
|    ep_rew_mean      | 1e+03    |
|    exploration_rate | 0.928    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 651      |
|    time_elapsed     | 0        |
|    total_timesteps  | 400      |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 7.62     |
|    n_updates        | 74       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 100      |
|    ep_rew_mean      | 1e+03    |
|    exploration_rate | 0.856    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 560      |
|    time_elapsed     | 1        |
|    total_timesteps  | 800      |
| train/              |          |
|    learning_rate    | 0.0005   |
|    loss             | 0.00704  |
|  



Mean reward: 1000.00 ± 0.00
