In [2]:
# Filter tensorflow version warnings
import os
# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # or any {'0', '1', '2'}
import warnings
# https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
import tensorflow as tf
tf.get_logger().setLevel('INFO')
tf.autograph.set_verbosity(0)
import logging
tf.get_logger().setLevel(logging.ERROR)

import gym
import re
import altair as alt
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.deepq import DQN, MlpPolicy as DQN_MlpPolicy, LnMlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2, A2C


from env.SepsisEnv import SepsisEnv
from load_data import load_data
from add_reward import add_reward_df, add_end_episode_df
import pandas as pd
import numpy as np

from tqdm import tqdm

In [3]:
df = load_data()
df = add_reward_df(df)
df = add_end_episode_df(df)

In [4]:
df = df.reset_index()

In [54]:
total_timesteps = 20_000
iterations = 50_000

In [55]:
def train_model(env, model, total_timesteps, iterations):
    model.learn(total_timesteps=total_timesteps)
    reward_list = []
    obs = env.reset()
    patient_list = []
    patient_count = 0
    for _ in tqdm(range(iterations)):
        action, _states = model.predict(obs)
        obs, rewards, done, info = env.step(action)
        reward_list.append(rewards[0])
        patient_list.append(patient_count)
        if done:
            patient_count += 1           
            obs = env.reset()
        # env.render()
    model_name = re.sub(r'\W+', '', str(model.__class__).split('.')[-1])
    policy_name = re.sub(r'\W+', '', str(model.policy).split('.')[-1])
    print('Model: ', model_name)
    print('Policy: ', policy_name)
    print('Total patients: ', patient_count)
    print('Total reward:', sum(reward_list))
    
    name = model_name + ' ' + policy_name
    return (name, reward_list), patient_list

In [56]:
def train_baseline_models(df, iterations, constant=False):
    reward_list = []
    patient_list = []
    env = DummyVecEnv([lambda: SepsisEnv(df)])
    obs = env.reset()
    patient_count = 0
    for _ in tqdm(range(iterations)): 
        if constant:
            obs, rewards, done, info = env.step(np.array([0]))            
        else:
            action = np.random.choice([0,1], size=1)
            obs, rewards, done, info = env.step(action)
        reward_list.append(rewards[0])
        patient_list.append(patient_count)
        if done:
            patient_count += 1
            obs = env.reset()
    if constant:
        name = 'Constant'
        print(f'Model: {name}')
    else:
        name = 'Random'
        print(f'Model: {name}')
    print('Total patients: ', patient_count)
    print('Total reward:', sum(reward_list))
    
    return (name, reward_list)

In [57]:
env = DummyVecEnv([lambda: SepsisEnv(df)])

In [58]:
models = [
    PPO2(MlpPolicy, env, verbose=0),
    PPO2(MlpLstmPolicy, env, nminibatches=1, verbose=0),
    PPO2(MlpLnLstmPolicy, env, nminibatches=1, verbose=0),
    A2C(MlpPolicy, env, lr_schedule='constant'),
    A2C(MlpLstmPolicy, env, lr_schedule='constant'),
    DQN(env=env,
        policy=DQN_MlpPolicy,
        learning_rate=1e-3,
        buffer_size=50000,
        exploration_fraction=0.1,
        exploration_final_eps=0.02,
        ),
    DQN(env=env,
        policy=LnMlpPolicy,
        learning_rate=1e-3,
        buffer_size=50000,
        exploration_fraction=0.1,
        exploration_final_eps=0.02,
        )
]

In [59]:
rewards = []
patients = []
for model in models:
    env = DummyVecEnv([lambda: SepsisEnv(df)])
    reward_list = train_model(env=env, model=model, total_timesteps=total_timesteps, iterations=iterations)
    rewards.append(reward_list[0])
    patients.append(reward_list[1])

100%|██████████| 50000/50000 [01:49<00:00, 455.73it/s]


Model:  PPO2
Policy:  MlpPolicy
Total patients:  1316
Total reward: -1137.4277982078493


100%|██████████| 50000/50000 [02:01<00:00, 412.75it/s]


Model:  PPO2
Policy:  MlpLstmPolicy
Total patients:  1316
Total reward: -1146.8500203303993


100%|██████████| 50000/50000 [02:09<00:00, 385.34it/s]


Model:  PPO2
Policy:  MlpLnLstmPolicy
Total patients:  1316
Total reward: -1141.0055759213865


100%|██████████| 50000/50000 [01:52<00:00, 444.38it/s]


Model:  A2C
Policy:  MlpPolicy
Total patients:  1316
Total reward: -1134.566686861217


100%|██████████| 50000/50000 [02:01<00:00, 410.23it/s]


Model:  A2C
Policy:  MlpLstmPolicy
Total patients:  1316
Total reward: -1194.2555766105652


100%|██████████| 50000/50000 [01:50<00:00, 451.28it/s]


Model:  DQN
Policy:  MlpPolicy
Total patients:  1316
Total reward: -788.1666802763939


100%|██████████| 50000/50000 [01:54<00:00, 436.26it/s]

Model:  DQN
Policy:  LnMlpPolicy
Total patients:  1316
Total reward: -918.9055719338357





In [60]:
reward_list = train_baseline_models(df=df, iterations=iterations, constant=False)
rewards.append(reward_list)
reward_list = train_baseline_models(df=df, iterations=iterations, constant=True)
rewards.append(reward_list)

100%|██████████| 50000/50000 [01:28<00:00, 565.45it/s]
  0%|          | 118/50000 [00:00<01:24, 588.75it/s]

Model: Random
Total patients:  1316
Total reward: -1365.5666901543736


100%|██████████| 50000/50000 [01:25<00:00, 584.26it/s]

Model: Constant
Total patients:  1316
Total reward: -1138.8889092504978





In [61]:
reward_df = pd.DataFrame(dict(rewards))

reward_df['patients'] = patients[0]

reward_df = reward_df.groupby('patients').sum()

pivot = pd.melt(reward_df.cumsum().reset_index(),id_vars='patients', value_vars=[c for c in reward_df.columns])

pivot.value = pivot.value.round()

In [62]:
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [64]:
alt.Chart(pivot).mark_line().encode(
    x=alt.X('patients', axis=alt.Axis(title='Patients Observed')),
    y=alt.Y('value', axis=alt.Axis(title='Cumulative Rewards')),
    color=alt.Color('variable', legend=alt.Legend(title='Model and Policy'))
)

## Interactive Charts

In [65]:
highlight = alt.selection(type='single', on='mouseover',
                          fields=['variable'], nearest=True)

alt.Chart(pivot).mark_line().encode(
    x='patients',
    y='value',
    color='variable'
)

base = alt.Chart(pivot).mark_line().encode(
    x=alt.X('patients', axis=alt.Axis(title='Patients Observed')),
    y=alt.Y('value', axis=alt.Axis(title='Cumulative Rewards')),
    color=alt.Color('variable', legend=alt.Legend(title='Model and Policy'))
)
points = base.mark_circle().encode(
    opacity=alt.value(0)
).add_selection(
    highlight
).properties(
    width=600
)

lines = base.mark_line().encode(
    size=alt.condition(~highlight, alt.value(1), alt.value(3))
)

points + lines


In [71]:
# Create a selection that chooses the nearest point & selects based on x-value
nearest = alt.selection(type='single', nearest=True, on='mouseover',
                        fields=['patients'], empty='none')

# The basic line
line = alt.Chart(pivot).mark_line(interpolate='basis').encode(
    x=alt.X('patients', axis=alt.Axis(title='Patients Observed')),
    y=alt.Y('value', axis=alt.Axis(title='Cumulative Rewards')),
    color=alt.Color('variable', legend=alt.Legend(title='Model and Policy'))
)

# Transparent selectors across the chart. This is what tells us
# the x-value of the cursor
selectors = alt.Chart(pivot).mark_point().encode(
    x='patients:Q',
    opacity=alt.value(0),
).add_selection(
    nearest
)

# Draw points on the line, and highlight based on selection
points = line.mark_point().encode(
    opacity=alt.condition(nearest, alt.value(1), alt.value(0))
)

# Draw text labels near the points, and highlight based on selection
text = line.mark_text(align='left', dx=5, dy=-5).encode(
    text=alt.condition(nearest, 'value:Q', alt.value(' '))
)

# Draw a rule at the location of the selection
rules = alt.Chart(pivot).mark_rule(color='gray').encode(
    x='patients:Q',
).transform_filter(
    nearest
)

# Put the five layers into a chart and bind the data
alt.layer(
    line, selectors, points, rules, text
).properties(
    width=600, height=300
)