# Dermatology Triage â€” Environment Smoke Test (Colab)
**What this notebook does**
- Implements a small Gym-style `ClinicEnv` with synthetic patients (no images).
- Tests the environment with random actions.
- Renders frames as RGB arrays and saves a short `random_demo.mp4`.
- Shows basic observation & reward sanity checks.

**Notes**
- This notebook is for *testing the environment only* (first mandatory step).
- Later you'll run training scripts (PPO/DQN/etc.) using the same env.


## Install dependencies

In [1]:
# Install dependencies
!pip install --quiet gym==0.26.5 imageio matplotlib numpy

[31mERROR: Could not find a version that satisfies the requirement gym==0.26.5 (from versions: 0.0.2, 0.0.3, 0.0.4, 0.0.5, 0.0.6, 0.0.7, 0.1.0, 0.1.1, 0.1.2, 0.1.3, 0.1.4, 0.1.5, 0.1.6, 0.1.7, 0.2.0, 0.2.1, 0.2.2, 0.2.3, 0.2.4, 0.2.5, 0.2.6, 0.2.7, 0.2.8, 0.2.9, 0.2.10, 0.2.11, 0.2.12, 0.3.0, 0.4.0, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.5, 0.4.6, 0.4.8, 0.4.9, 0.4.10, 0.5.0, 0.5.1, 0.5.2, 0.5.3, 0.5.4, 0.5.5, 0.5.6, 0.5.7, 0.6.0, 0.7.0, 0.7.1, 0.7.2, 0.7.3, 0.7.4, 0.8.0.dev0, 0.8.0, 0.8.1, 0.8.2, 0.9.0, 0.9.1, 0.9.2, 0.9.3, 0.9.4, 0.9.5, 0.9.6, 0.9.7, 0.10.0, 0.10.1, 0.10.2, 0.10.3, 0.10.4, 0.10.5, 0.10.8, 0.10.9, 0.10.11, 0.11.0, 0.12.0, 0.12.1, 0.12.4, 0.12.5, 0.12.6, 0.13.0, 0.13.1, 0.14.0, 0.15.3, 0.15.4, 0.15.6, 0.15.7, 0.16.0, 0.17.0, 0.17.1, 0.17.2, 0.17.3, 0.18.0, 0.18.3, 0.19.0, 0.20.0, 0.21.0, 0.22.0, 0.23.0, 0.23.1, 0.24.0, 0.24.1, 0.25.0, 0.25.1, 0.25.2, 0.26.0, 0.26.1, 0.26.2)[0m[31m
[0m[31mERROR: No matching distribution found for gym==0.26.5[0m[31m
[0m

## Imports & helpers

In [2]:
import gym
import numpy as np
import random
from gym import spaces
import matplotlib.pyplot as plt
from typing import Tuple, Dict, Any
import imageio
from IPython.display import HTML, display
import base64
import os

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)


## ClinicEnv implementation

In [9]:
class ClinicEnv(gym.Env):
    """
    Minimal Gym env for Dermatology Triage (smoke-test).
    Observation:
        [age_norm, duration_norm, fever_flag, infection_flag,
         symptom_embed_0..7 (8 dims),
         room_avail (0/1), queue_len_norm, time_of_day_norm]
        => total dim = 1+1+1+1+8+1+1+1 = 14
    Actions (discrete, 8):
        0 send_doctor
        1 send_nurse
        2 remote_advice
        3 escalate_priority
        4 defer_patient
        5 idle
        6 open_room
        7 close_room
    Reward:
        +1 correct triage (mild/mod)
        +2 correct doctor for severe
        +3 correct: critical + prioritized fast
        -1.5 incorrect triage
        -0.01 * total_wait increment (penalty per step)
        -0.05 * num_open_rooms (resource cost)
    """
    metadata = {"render_modes": ["rgb_array"], "render_fps": 6}

    def __init__(self, seed: int = 0, max_steps: int = 500):
        super().__init__()
        self.seed(seed)
        self.max_steps = max_steps
        # observation space
        obs_low = np.array([0.0]*14, dtype=np.float32)
        obs_high = np.array([1.0]*14, dtype=np.float32)
        self.observation_space = spaces.Box(obs_low, obs_high, dtype=np.float32)
        # action space
        self.action_space = spaces.Discrete(8)
        # internal state
        self.step_count = 0
        self.num_open_rooms = 1
        self.queue = []  # patients waiting (list of patient dicts)
        self.current_patient = None
        self.total_wait = 0.0
        self.last_render = None
        self.reset()

    def seed(self, seed=None):
        self._seed = seed
        random.seed(seed)
        np.random.seed(seed)
        return [seed]

    def _sample_patient(self):
        # severity: 0=mild,1=moderate,2=severe,3=critical (hidden)
        severity = np.random.choice([0,1,2,3], p=[0.4,0.35,0.2,0.05])
        age_norm = np.clip(np.random.normal(0.5, 0.15), 0.0, 1.0)
        duration_norm = np.clip(np.random.exponential(0.5), 0.0, 1.0)
        fever_flag = 1.0 if (np.random.rand() < (0.05 + 0.2*severity)) else 0.0
        infection_flag = 1.0 if (np.random.rand() < (0.05 + 0.25*severity)) else 0.0
        # symptom embedding correlated with severity
        base = 0.2 + 0.25*severity
        symptom_embed = np.clip(np.random.normal(loc=base, scale=0.08, size=(8,)), 0.0, 1.0)
        patient = {
            "severity": int(severity),
            "age_norm": float(age_norm),
            "duration_norm": float(duration_norm),
            "fever_flag": float(fever_flag),
            "infection_flag": float(infection_flag),
            "symptom_embed": symptom_embed,
            "wait_time": 0.0
        }
        return patient

    def _form_observation(self, patient):
        vec = [
            patient["age_norm"],
            patient["duration_norm"],
            patient["fever_flag"],
            patient["infection_flag"],
        ]
        vec += list(patient["symptom_embed"])
        vec += [1.0 if self.num_open_rooms>0 else 0.0,  # room_avail
                np.clip(len(self.queue)/10.0, 0.0, 1.0),  # queue_len_norm
                np.clip(self.step_count / self.max_steps, 0.0, 1.0)]  # time_of_day_norm
        return np.array(vec, dtype=np.float32)

    def reset(self, seed=None, options=None):
        if seed is not None:
            self.seed(seed)
        self.step_count = 0
        self.num_open_rooms = 1
        self.queue = [self._sample_patient() for _ in range(3)]  # warm start queue
        self.current_patient = None
        self.total_wait = 0.0
        # spawn initial current patient
        self._maybe_spawn_next()
        obs = self._form_observation(self.current_patient)
        return obs

    def _maybe_spawn_next(self):
        if self.current_patient is None and len(self.queue) > 0:
            self.current_patient = self.queue.pop(0)
        elif self.current_patient is None:
            # if no queued patients, create one
            self.current_patient = self._sample_patient()

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict[str,Any]]:
        """
        Apply action to current patient and progress environment by 1 step.
        """
        assert self.action_space.contains(action), "Invalid action"
        self.step_count += 1
        done = False
        info = {}
        patient = self.current_patient
        # default reward components
        reward = 0.0

        # Determine correct triage for reward logic (ground truth)
        correct_action = None
        if patient["severity"] == 0:  # mild
            correct_action = 2  # remote_advice
        elif patient["severity"] == 1:  # moderate
            correct_action = 1  # nurse
        elif patient["severity"] == 2:  # severe
            correct_action = 0  # doctor
        else:  # critical
            correct_action = 3  # escalate_priority (doctor + priority)

        # Reward triage match
        if action == correct_action:
            if patient["severity"] == 0:
                reward += 1.0
            elif patient["severity"] == 1:
                reward += 1.0
            elif patient["severity"] == 2:
                reward += 2.0
            else:
                # critical handled correctly
                # bonus if wait was small
                if patient["wait_time"] < 5.0:
                    reward += 3.0
                else:
                    reward += 2.0
        else:
            reward += -1.5  # penalty for incorrect triage

        # action side-effects
        if action == 6:  # open_room
            self.num_open_rooms += 1
        if action == 7 and self.num_open_rooms > 0:  # close_room
            self.num_open_rooms -= 1
        if action == 4:  # defer_patient -> patient goes to end of queue
            patient["wait_time"] += 1.0
            self.queue.append(patient)
            self.current_patient = None
        else:
            # treat patient -> spawn next (non-defer actions treat patient and remove)
            self.current_patient = None

        # wait penalty: increment wait time for everyone in queue
        wait_increment = 0.01 * len(self.queue)
        for p in self.queue:
            p["wait_time"] += 1.0
        self.total_wait += wait_increment
        reward += -0.01 * wait_increment
        # resource cost
        reward += -0.05 * self.num_open_rooms

        # spawn next patient in queue to be current
        self._maybe_spawn_next()

        # observation
        obs = self._form_observation(self.current_patient)
        # termination
        if self.step_count >= self.max_steps:
            done = True
        # info include some bookkeeping for evaluation
        info["current_severity"] = int(patient["severity"])
        info["correct_action"] = int(correct_action)
        info["num_open_rooms"] = int(self.num_open_rooms)
        info["queue_length"] = len(self.queue)
        return obs, float(reward), done, info

    def render(self, mode="rgb_array"):
      """
      Pure NumPy 240x360 RGB render (Colab-safe, SB3-safe)

      Shows:
      - Severity bar (color changes by severity)
      - Text-like indicators (simple colored bars)
      """
      H, W = 240, 360
      canvas = np.ones((H, W, 3), dtype=np.uint8) * 255  # white

      # --- Severity bar ---
      sev = self.current_patient["severity"] if self.current_patient else 0
      sev_norm = sev / 3.0
      color = np.array([
          int(255 * sev_norm),          # R increases with severity
          int(180 * (1 - sev_norm)),    # G decreases with severity
          60                             # constant B
      ], dtype=np.uint8)

      # Draw bar
      canvas[20:200, 20:60] = color

      # --- Draw queue length indicator (blue bars) ---
      q_len = len(self.queue)
      q_h = min(q_len * 15, 150)
      canvas[20:20+q_h, 80:100] = [80, 80, 255]

      # --- Draw open room indicator (green bars) ---
      r = self.num_open_rooms
      r_h = min(r * 20, 150)
      canvas[20:20+r_h, 120:140] = [50, 200, 50]

      # Save frame
      self.last_render = canvas
      return canvas


## Sanity-check: create env and print obs shape

In [14]:
# create env and sample observation
env = ClinicEnv(seed=42, max_steps=200)
obs = env.reset()
print("Observation shape:", obs.shape)
print("Sample observation (first 6 dims):", obs[:6])
print("Action space:", env.action_space)
print("Observation space:", env.observation_space)

Observation shape: (15,)
Sample observation (first 6 dims): [0.5295292 0.7461227 0.        0.        1.        0.9637095]
Action space: Discrete(8)
Observation space: Box(0.0, 1.0, (14,), float32)


## Run a single random episode and print steps

In [15]:
obs = env.reset()
done = False
total_reward = 0.0
step = 0
while not done and step < 50:
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    total_reward += reward
    if step < 5:  # print first few step infos
        print(f"step {step}: action {action}, reward {reward:.3f}, info {info}")
    step += 1
print("Total reward (partial run):", total_reward)

step 0: action 3, reward -1.550, info {'current_severity': 2, 'correct_action': 0, 'num_open_rooms': 1, 'queue_length': 1}
step 1: action 3, reward -1.550, info {'current_severity': 0, 'correct_action': 2, 'num_open_rooms': 1, 'queue_length': 0}
step 2: action 1, reward -1.550, info {'current_severity': 0, 'correct_action': 2, 'num_open_rooms': 1, 'queue_length': 0}
step 3: action 6, reward -1.600, info {'current_severity': 0, 'correct_action': 2, 'num_open_rooms': 2, 'queue_length': 0}
step 4: action 3, reward 2.900, info {'current_severity': 3, 'correct_action': 3, 'num_open_rooms': 2, 'queue_length': 0}
Total reward (partial run): -67.40089999999998


## Random-play demo: produce frames and save mp4

In [16]:
# Run 3 random episodes, capture frames, save as mp4
out_fname = "random_demo.mp4"
writer = imageio.get_writer(out_fname, fps=6)

n_episodes = 3
max_frames_per_episode = 80
for ep in range(n_episodes):
    obs = env.reset()
    done = False
    frames = 0
    while not done and frames < max_frames_per_episode:
        frame = env.render(mode="rgb_array")
        writer.append_data(frame)

        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
        frames += 1
print("Finished episodes; closing writer.")
writer.close()




Finished episodes; closing writer.


## Display the MP4 inline

In [17]:
# Display video inline in Colab
from IPython.display import HTML
mp4 = open('random_demo.mp4','rb').read()
data_url = "data:video/mp4;base64," + base64.b64encode(mp4).decode()
HTML(f"""
<video width=640 controls>
  <source src="{data_url}" type="video/mp4">
</video>
""")


## Quick evaluation function

In [18]:
# small evaluate function to run random policy for multiple episodes and collect stats
def evaluate_random(env, episodes=20, max_steps=200):
    stats = {"episode_reward":[],"avg_queue_len":[],"avg_rooms":[]}
    for e in range(episodes):
        obs = env.reset()
        done=False
        total_reward = 0.0
        qsum = 0
        rsum = 0
        steps = 0
        while not done and steps < max_steps:
            action = env.action_space.sample()
            obs, reward, done, info = env.step(action)
            total_reward += reward
            qsum += info.get("queue_length",0)
            rsum += info.get("num_open_rooms",0)
            steps += 1
        stats["episode_reward"].append(total_reward)
        stats["avg_queue_len"].append(qsum/steps)
        stats["avg_rooms"].append(rsum/steps)
    return stats

stats = evaluate_random(env, episodes=10, max_steps=100)
import pandas as pd
df = pd.DataFrame(stats)
print(df.describe())


       episode_reward  avg_queue_len  avg_rooms
count       10.000000      10.000000   10.00000
mean      -120.766400       0.015000    2.33300
std          7.364786       0.008498    1.49903
min       -133.751400       0.010000    0.43000
25%       -125.189000       0.010000    1.55500
50%       -119.501300       0.010000    1.79500
75%       -114.526600       0.017500    2.75500
max       -112.650900       0.030000    5.15000


  return datetime.utcnow().replace(tzinfo=utc)
