In [1]:
from load_data import load_data
from add_reward import add_reward_df

In [2]:
df = load_data()
df = add_reward_df(df)

In [3]:
df = df.reset_index()

In [4]:
min_values = df.describe().loc['min']
max_values = df.describe().loc['max']

In [5]:
min_observation= min_values.loc[~min_values.index.isin(['SepsisLabel',
                                                        'patient', 
                                                        'zeros_reward',
                                                        'ones_reward'])].min()

max_observation = max_values.loc[~max_values.index.isin(['SepsisLabel',
                                                         'patient',
                                                         'zeros_reward',
                                                         'ones_reward'])].max()

In [6]:
min_observation

-25.545204389483803

In [7]:
max_observation

335.0

In [8]:
import random
import json
import gym
from gym import spaces
import pandas as pd
import numpy as np


class SepsisEnv(gym.Env):
    """A Sepsis environment for OpenAI gym"""
    metadata = {'render.modes': ['human']}
    
    def __init__(self, df):
        super(SepsisEnv, self).__init__()

        self.df = df
        self.reward_range = (-2.0, 1.0)
        # Only two possible actions, 0 for non-sepsis,
        # 1 for sepsis.
        n_actions = 2
        self.action_space = spaces.Discrete(n=n_actions)
        # Observation space is a feature vector of 41 vital signs, 
        # lab values, and other demographic information.
        self.observation_space = spaces.Box(
            low=-25.545204389483803, high=335.0, shape=(1, 42), dtype=np.float16)

    def _next_observation(self):
        obs = np.array([
        self.df.loc[self.current_step, ~df.columns.isin(['SepsisLabel',
                            'patient', 
                            'zeros_reward',
                            'ones_reward'])].values 

        ])

        return obs


    def step(self, action):
        # Execute one time step within the environment
        self.current_step += 1
        
        if action == 0:
            reward = self.df.loc[self.current_step, ['zeros_reward']]
        else:
            reward = self.df.loc[self.current_step, ['ones_reward']]

        done = False

        obs = self._next_observation()

        return obs, reward, done, {}

    def reset(self):
        # Reset the state of the environment to an initial state
        # Set the current step to a random point within the data frame
#         self.current_step = random.randint(
#             0, len(self.df.loc[:, 'HR'].values))
        self.current_step = 0

        return self._next_observation()

    def render(self, mode='human', close=False):
        # Render the environment to the screen
        print('current step' ,self.current_step)




In [9]:
import gym
import json
import datetime as dt
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2
import pandas as pd

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


## Muli-layer Perception Model 

In [33]:
env = DummyVecEnv([lambda: SepsisEnv(df)])
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=20000)
reward_list = []
for i in range(500): 
    obs = env.reset()
    action, _states = model.predict(obs)
    obs, rewards, done, info = env.step(action)
    reward_list.append(rewards)

--------------------------------------
| approxkl           | 0.0030465946  |
| clipfrac           | 0.0           |
| explained_variance | -6.22         |
| fps                | 247           |
| n_updates          | 1             |
| policy_entropy     | 0.6905858     |
| policy_loss        | -0.0132151125 |
| serial_timesteps   | 128           |
| time_elapsed       | 1.17e-05      |
| total_timesteps    | 128           |
| value_loss         | 0.0394027     |
--------------------------------------
--------------------------------------
| approxkl           | 0.0020915268  |
| clipfrac           | 0.0           |
| explained_variance | -4.47         |
| fps                | 526           |
| n_updates          | 2             |
| policy_entropy     | 0.6740954     |
| policy_loss        | -0.0040575666 |
| serial_timesteps   | 256           |
| time_elapsed       | 0.518         |
| total_timesteps    | 256           |
| value_loss         | 0.04418189    |
-------------------------

--------------------------------------
| approxkl           | 0.0015226941  |
| clipfrac           | 0.0           |
| explained_variance | -7.64         |
| fps                | 522           |
| n_updates          | 14            |
| policy_entropy     | 0.526848      |
| policy_loss        | -0.0043677427 |
| serial_timesteps   | 1792          |
| time_elapsed       | 3.58          |
| total_timesteps    | 1792          |
| value_loss         | 0.024510322   |
--------------------------------------
--------------------------------------
| approxkl           | 0.00023695099 |
| clipfrac           | 0.0           |
| explained_variance | 0.0598        |
| fps                | 491           |
| n_updates          | 15            |
| policy_entropy     | 0.49055883    |
| policy_loss        | 0.0012263645  |
| serial_timesteps   | 1920          |
| time_elapsed       | 3.83          |
| total_timesteps    | 1920          |
| value_loss         | 3.303101      |
-------------------------

--------------------------------------
| approxkl           | 8.724536e-05  |
| clipfrac           | 0.0           |
| explained_variance | -3.75         |
| fps                | 519           |
| n_updates          | 31            |
| policy_entropy     | 0.17336532    |
| policy_loss        | -0.0017316388 |
| serial_timesteps   | 3968          |
| time_elapsed       | 7.88          |
| total_timesteps    | 3968          |
| value_loss         | 0.006179292   |
--------------------------------------
-------------------------------------
| approxkl           | 0.0003427327 |
| clipfrac           | 0.001953125  |
| explained_variance | -2.08        |
| fps                | 543          |
| n_updates          | 32           |
| policy_entropy     | 0.1558189    |
| policy_loss        | -0.005548742 |
| serial_timesteps   | 4096         |
| time_elapsed       | 8.13         |
| total_timesteps    | 4096         |
| value_loss         | 0.009351094  |
-------------------------------------

---------------------------------------
| approxkl           | 5.8099217e-06  |
| clipfrac           | 0.0            |
| explained_variance | -0.572         |
| fps                | 528            |
| n_updates          | 48             |
| policy_entropy     | 0.08506122     |
| policy_loss        | -0.00066616037 |
| serial_timesteps   | 6144           |
| time_elapsed       | 12.2           |
| total_timesteps    | 6144           |
| value_loss         | 0.003046819    |
---------------------------------------
---------------------------------------
| approxkl           | 7.191309e-07   |
| clipfrac           | 0.0            |
| explained_variance | -2.88          |
| fps                | 567            |
| n_updates          | 49             |
| policy_entropy     | 0.068581976    |
| policy_loss        | -0.00018969818 |
| serial_timesteps   | 6272           |
| time_elapsed       | 12.4           |
| total_timesteps    | 6272           |
| value_loss         | 0.0017128773   |


--------------------------------------
| approxkl           | 9.9558514e-05 |
| clipfrac           | 0.001953125   |
| explained_variance | -11.3         |
| fps                | 520           |
| n_updates          | 65            |
| policy_entropy     | 0.055474237   |
| policy_loss        | -0.0044465815 |
| serial_timesteps   | 8320          |
| time_elapsed       | 16.5          |
| total_timesteps    | 8320          |
| value_loss         | 0.0035133262  |
--------------------------------------
---------------------------------------
| approxkl           | 2.1753307e-07  |
| clipfrac           | 0.0            |
| explained_variance | 0.0162         |
| fps                | 520            |
| n_updates          | 66             |
| policy_entropy     | 0.05344635     |
| policy_loss        | -5.7102006e-05 |
| serial_timesteps   | 8448           |
| time_elapsed       | 16.7           |
| total_timesteps    | 8448           |
| value_loss         | 2.6761994      |
-------------

----------------------------------------
| approxkl           | 8.322248e-08    |
| clipfrac           | 0.0             |
| explained_variance | -1.42           |
| fps                | 487             |
| n_updates          | 82              |
| policy_entropy     | 0.030875854     |
| policy_loss        | -0.000106689695 |
| serial_timesteps   | 10496           |
| time_elapsed       | 20.8            |
| total_timesteps    | 10496           |
| value_loss         | 0.0021314102    |
----------------------------------------
-------------------------------------
| approxkl           | 5.012696e-06 |
| clipfrac           | 0.0          |
| explained_variance | 0.0607       |
| fps                | 521          |
| n_updates          | 83           |
| policy_entropy     | 0.044533502  |
| policy_loss        | -6.57856e-05 |
| serial_timesteps   | 10624        |
| time_elapsed       | 21.1         |
| total_timesteps    | 10624        |
| value_loss         | 0.58647084   |
-----------

---------------------------------------
| approxkl           | 3.702001e-07   |
| clipfrac           | 0.0            |
| explained_variance | 0.0182         |
| fps                | 564            |
| n_updates          | 99             |
| policy_entropy     | 0.035632446    |
| policy_loss        | -5.1140436e-05 |
| serial_timesteps   | 12672          |
| time_elapsed       | 25.2           |
| total_timesteps    | 12672          |
| value_loss         | 3.1869402      |
---------------------------------------
-------------------------------------
| approxkl           | 9.146965e-05 |
| clipfrac           | 0.001953125  |
| explained_variance | -3.27        |
| fps                | 551          |
| n_updates          | 100          |
| policy_entropy     | 0.02517091   |
| policy_loss        | -0.001820255 |
| serial_timesteps   | 12800        |
| time_elapsed       | 25.4         |
| total_timesteps    | 12800        |
| value_loss         | 0.009585283  |
------------------------

---------------------------------------
| approxkl           | 2.4768906e-10  |
| clipfrac           | 0.0            |
| explained_variance | 0.00542        |
| fps                | 494            |
| n_updates          | 116            |
| policy_entropy     | 0.013109433    |
| policy_loss        | -3.7741847e-07 |
| serial_timesteps   | 14848          |
| time_elapsed       | 29.5           |
| total_timesteps    | 14848          |
| value_loss         | 0.6430007      |
---------------------------------------
---------------------------------------
| approxkl           | 2.1734749e-10  |
| clipfrac           | 0.0            |
| explained_variance | -2.6           |
| fps                | 503            |
| n_updates          | 117            |
| policy_entropy     | 0.011981013    |
| policy_loss        | -4.8986403e-06 |
| serial_timesteps   | 14976          |
| time_elapsed       | 29.7           |
| total_timesteps    | 14976          |
| value_loss         | 0.0095793065   |


---------------------------------------
| approxkl           | 6.613744e-12   |
| clipfrac           | 0.0            |
| explained_variance | -0.000833      |
| fps                | 474            |
| n_updates          | 132            |
| policy_entropy     | 0.008464232    |
| policy_loss        | -1.0826625e-08 |
| serial_timesteps   | 16896          |
| time_elapsed       | 33.6           |
| total_timesteps    | 16896          |
| value_loss         | 2.7572923      |
---------------------------------------
--------------------------------------
| approxkl           | 0.00016190679 |
| clipfrac           | 0.00390625    |
| explained_variance | -5.67         |
| fps                | 502           |
| n_updates          | 133           |
| policy_entropy     | 0.015165213   |
| policy_loss        | -0.0015574781 |
| serial_timesteps   | 17024         |
| time_elapsed       | 33.9          |
| total_timesteps    | 17024         |
| value_loss         | 0.0024020902  |
------------

--------------------------------------
| approxkl           | 7.3081935e-10 |
| clipfrac           | 0.0           |
| explained_variance | -5.68         |
| fps                | 532           |
| n_updates          | 149           |
| policy_entropy     | 0.010359919   |
| policy_loss        | -9.786105e-06 |
| serial_timesteps   | 19072         |
| time_elapsed       | 38.1          |
| total_timesteps    | 19072         |
| value_loss         | 0.00392071    |
--------------------------------------
---------------------------------------
| approxkl           | 3.1495992e-09  |
| clipfrac           | 0.0            |
| explained_variance | -3.12          |
| fps                | 534            |
| n_updates          | 150            |
| policy_entropy     | 0.010975818    |
| policy_loss        | -3.8600992e-06 |
| serial_timesteps   | 19200          |
| time_elapsed       | 38.4           |
| total_timesteps    | 19200          |
| value_loss         | 0.002622829    |
-------------

In [34]:
sum(reward_list)

array([-0.90000015], dtype=float32)

## Random Model

In [35]:
reward_list = []
for i in range(500): 
    obs = env.reset()
    action = np.random.choice([0,1], size=1)
    obs, rewards, done, info = env.step(action)
    reward_list.append(rewards)



In [36]:
sum(reward_list)

array([-12.250029], dtype=float32)