In [1]:
from load_data import load_data
from add_reward import add_reward_df

In [2]:
# Load the dataset and pass it straight through the reward helper.
# NOTE(review): add_reward_df presumably adds the `zeros_reward` /
# `ones_reward` columns that are read further down — confirm in add_reward.py.
df = add_reward_df(load_data())

In [3]:
# Move the current index into an ordinary column and renumber the rows 0..n-1.
# NOTE: this cell is not idempotent — every re-run inserts another index
# column, so re-execute the loading cell first if it has to be run again.
df = df.reset_index(drop=False)

In [4]:
# Summarise the frame once and reuse the result: describe() scans every
# numeric column, so calling it separately for 'min' and 'max' doubles the work.
summary_stats = df.describe()

# Per-column minimum and maximum (one entry per column that describe() covers).
min_values = summary_stats.loc['min']
max_values = summary_stats.loc['max']

In [5]:
# Columns that are not part of the observation vector; they are left out
# when looking for the overall value range.
non_feature_columns = ['SepsisLabel', 'patient', 'zeros_reward', 'ones_reward']

# Smallest of the per-column minima over the remaining columns.
keep_for_min = ~min_values.index.isin(non_feature_columns)
min_observation = min_values.loc[keep_for_min].min()

# Largest of the per-column maxima over the remaining columns.
keep_for_max = ~max_values.index.isin(non_feature_columns)
max_observation = max_values.loc[keep_for_max].max()

In [6]:
# Display the overall minimum; the same number is used as the lower bound of
# the SepsisEnv observation space.
min_observation

-25.545204389483803

In [7]:
# Display the overall maximum; the same number is used as the upper bound of
# the SepsisEnv observation space.
max_observation

335.0

In [8]:
import random
import json
import gym
from gym import spaces
import pandas as pd
import numpy as np


class SepsisEnv(gym.Env):
    """A Sepsis environment for OpenAI gym.

    Each row of the wrapped data frame is one time step. An observation is
    the row's feature columns (everything except ``NON_FEATURE_COLUMNS``)
    as an array of shape ``(1, n_features)``. The agent chooses action ``0``
    (non-sepsis) or ``1`` (sepsis) and receives the value stored in the
    ``zeros_reward`` / ``ones_reward`` column of the row it was shown.

    Rows are addressed by position, so the frame may carry any index.

    NOTE(review): an episode simply walks down the frame from a random start
    row; it is not cut at patient boundaries. Confirm that this is intended.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing the feature columns plus ``zeros_reward`` and
        ``ones_reward``.
    """
    metadata = {'render.modes': ['human']}

    # Label / bookkeeping columns that never appear in an observation.
    NON_FEATURE_COLUMNS = ('SepsisLabel', 'patient', 'zeros_reward', 'ones_reward')

    # Value range of the feature columns, taken from the exploratory cells
    # earlier in the notebook.
    OBS_LOW = -25.545204389483803
    OBS_HIGH = 335.0

    def __init__(self, df):
        super(SepsisEnv, self).__init__()

        self.df = df
        self.reward_range = (-2.0, 1.0)

        # Only two possible actions: 0 for non-sepsis, 1 for sepsis.
        self.action_space = spaces.Discrete(n=2)

        # Boolean mask over the columns selecting the feature columns. It is
        # built from the frame this environment owns, not from a notebook
        # global, so the environment works with whatever frame it is given.
        self._feature_mask = ~self.df.columns.isin(list(self.NON_FEATURE_COLUMNS))
        n_features = int(self._feature_mask.sum())

        # The width is derived from the data so the declared space always
        # matches what _next_observation() returns. float32 is used instead of
        # float16 because half precision keeps only about three significant
        # decimal digits of each feature.
        self.observation_space = spaces.Box(
            low=self.OBS_LOW, high=self.OBS_HIGH,
            shape=(1, n_features), dtype=np.float32)

        # Position of the row currently shown to the agent; set by reset().
        self.current_step = 0

    def _next_observation(self):
        """Return the feature values of the current row, shape (1, n_features)."""
        row = self.df.iloc[self.current_step, self._feature_mask]
        return np.asarray(row.values, dtype=np.float32).reshape(1, -1)

    def step(self, action):
        """Score ``action`` against the current row and advance one row.

        Parameters
        ----------
        action : int
            ``0`` reads the reward from ``zeros_reward``; any other value
            reads it from ``ones_reward``.

        Returns
        -------
        tuple
            ``(obs, reward, done, info)`` where ``reward`` is a plain float
            and ``done`` becomes ``True`` once the last row has been scored.
            On that final step ``obs`` repeats the last row, so a valid
            observation is always returned.
        """
        reward_column = 'zeros_reward' if action == 0 else 'ones_reward'

        # The action was chosen from the observation of the *current* row, so
        # the reward has to be read before the pointer moves on. A scalar
        # column lookup yields a number rather than a one-element Series.
        reward = float(self.df[reward_column].iloc[self.current_step])

        self.current_step += 1

        # Finish the episode when the data is exhausted instead of indexing
        # past the end of the frame.
        n_rows = len(self.df)
        done = self.current_step >= n_rows
        if done:
            self.current_step = n_rows - 1

        obs = self._next_observation()

        return obs, reward, done, {}

    def reset(self):
        """Start a new episode at a random row and return its observation.

        Raises
        ------
        ValueError
            If the data frame has no rows.
        """
        n_rows = len(self.df)
        if n_rows == 0:
            raise ValueError('SepsisEnv needs a data frame with at least one row')

        # randrange excludes the upper bound; randint(0, n_rows) could return
        # n_rows itself, which is one past the last valid row.
        self.current_step = random.randrange(n_rows)

        return self._next_observation()

    def render(self, mode='human', close=False):
        """Print the position of the current row."""
        print('current step' ,self.current_step)




In [11]:
import gym
import json
import datetime as dt
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2
import pandas as pd

In [10]:
# Number of environment steps used for training, and number of steps the
# trained policy is rolled out for afterwards.
TRAIN_TIMESTEPS = 20000
ROLLOUT_STEPS = 2000


def make_sepsis_env():
    """Build one SepsisEnv over the notebook's data frame."""
    return SepsisEnv(df)


# DummyVecEnv takes a list of environment factories; a single one is used here.
env = DummyVecEnv([make_sepsis_env])

# Train a PPO2 agent with an MLP policy.
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=TRAIN_TIMESTEPS)

# Roll the trained policy out, printing the current row after every step.
obs = env.reset()
for i in range(ROLLOUT_STEPS):
    action, _states = model.predict(obs)
    obs, rewards, done, info = env.step(action)
    env.render()





Instructions for updating:
Use keras.layers.flatten instead.


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


--------------------------------------
| approxkl           | 0.0005859758  |
| clipfrac           | 0.0           |
| explained_variance | -2.04         |
| fps                | 235           |
| n_updates          | 1             |
| policy_entropy     | 0.6925658     |
| policy_loss        | -0.0029517147 |
| serial_timesteps   | 128           |
| time_elapsed       | 1.19e-05      |
| total_timesteps    | 128           |
| value_loss         | 0.06827691    |
--------------------------------------
---------------------------------------
| approxkl           | 0.0007327878   |
| clipfrac           | 0.0            |
| explained_variance | -0.0984        |
| fps                | 574            |
| n_updates          | 2              |
| policy_entropy     | 0.687957       |
| policy_loss        | -0.00033928873 |
| serial_t

--------------------------------------
| approxkl           | 0.00076092605 |
| clipfrac           | 0.0           |
| explained_variance | -21.1         |
| fps                | 530           |
| n_updates          | 9             |
| policy_entropy     | 0.45643207    |
| policy_loss        | -0.005715921  |
| serial_timesteps   | 1152          |
| time_elapsed       | 2.18          |
| total_timesteps    | 1152          |
| value_loss         | 0.042638525   |
--------------------------------------
--------------------------------------
| approxkl           | 0.0005644725  |
| clipfrac           | 0.0           |
| explained_variance | -18.5         |
| fps                | 518           |
| n_updates          | 10            |
| policy_entropy     | 0.4242148     |
| policy_loss        | -0.0060725035 |
| serial_timesteps   | 1280          |
| time_elapsed       | 2.42          |
| total_timesteps    | 1280          |
| value_loss         | 0.02453394    |
-------------------------

-------------------------------------
| approxkl           | 0.0004320044 |
| clipfrac           | 0.005859375  |
| explained_variance | -8.11        |
| fps                | 590          |
| n_updates          | 26           |
| policy_entropy     | 0.16622108   |
| policy_loss        | -0.00820413  |
| serial_timesteps   | 3328         |
| time_elapsed       | 6.4          |
| total_timesteps    | 3328         |
| value_loss         | 0.005887539  |
-------------------------------------
---------------------------------------
| approxkl           | 4.4406985e-05  |
| clipfrac           | 0.0            |
| explained_variance | -5.65          |
| fps                | 525            |
| n_updates          | 27             |
| policy_entropy     | 0.16878814     |
| policy_loss        | -0.00057377154 |
| serial_timesteps   | 3456           |
| time_elapsed       | 6.62           |
| total_timesteps    | 3456           |
| value_loss         | 0.006950955    |
--------------------------

--------------------------------------
| approxkl           | 1.3219941e-05 |
| clipfrac           | 0.0           |
| explained_variance | -8.15         |
| fps                | 549           |
| n_updates          | 43            |
| policy_entropy     | 0.073386356   |
| policy_loss        | -0.001955037  |
| serial_timesteps   | 5504          |
| time_elapsed       | 10.5          |
| total_timesteps    | 5504          |
| value_loss         | 0.0036316258  |
--------------------------------------
---------------------------------------
| approxkl           | 2.7096276e-05  |
| clipfrac           | 0.0            |
| explained_variance | -5.28          |
| fps                | 549            |
| n_updates          | 44             |
| policy_entropy     | 0.06837256     |
| policy_loss        | -0.00061274786 |
| serial_timesteps   | 5632           |
| time_elapsed       | 10.8           |
| total_timesteps    | 5632           |
| value_loss         | 0.007045415    |
-------------

--------------------------------------
| approxkl           | 1.496384e-05  |
| clipfrac           | 0.0           |
| explained_variance | -8.75         |
| fps                | 511           |
| n_updates          | 59            |
| policy_entropy     | 0.024660956   |
| policy_loss        | -0.0002661359 |
| serial_timesteps   | 7552          |
| time_elapsed       | 14.5          |
| total_timesteps    | 7552          |
| value_loss         | 0.0016101969  |
--------------------------------------
---------------------------------------
| approxkl           | 2.3360752e-05  |
| clipfrac           | 0.0            |
| explained_variance | 0.0225         |
| fps                | 500            |
| n_updates          | 60             |
| policy_entropy     | 0.028699094    |
| policy_loss        | -0.00070426567 |
| serial_timesteps   | 7680           |
| time_elapsed       | 14.7           |
| total_timesteps    | 7680           |
| value_loss         | 0.009982053    |
-------------

--------------------------------------
| approxkl           | 4.073894e-07  |
| clipfrac           | 0.0           |
| explained_variance | -3.8          |
| fps                | 511           |
| n_updates          | 75            |
| policy_entropy     | 0.020147275   |
| policy_loss        | -0.0002537201 |
| serial_timesteps   | 9600          |
| time_elapsed       | 18.6          |
| total_timesteps    | 9600          |
| value_loss         | 0.0036189116  |
--------------------------------------
--------------------------------------
| approxkl           | 0.00041446424 |
| clipfrac           | 0.0078125     |
| explained_variance | -4.68         |
| fps                | 497           |
| n_updates          | 76            |
| policy_entropy     | 0.029360488   |
| policy_loss        | -0.007292578  |
| serial_timesteps   | 9728          |
| time_elapsed       | 18.8          |
| total_timesteps    | 9728          |
| value_loss         | 0.0042488165  |
-------------------------

-------------------------------------
| approxkl           | 2.504851e-05 |
| clipfrac           | 0.0          |
| explained_variance | -3.2         |
| fps                | 503          |
| n_updates          | 92           |
| policy_entropy     | 0.02453153   |
| policy_loss        | -0.000855789 |
| serial_timesteps   | 11776        |
| time_elapsed       | 22.8         |
| total_timesteps    | 11776        |
| value_loss         | 0.001584404  |
-------------------------------------
--------------------------------------
| approxkl           | 4.0215644e-05 |
| clipfrac           | 0.0           |
| explained_variance | -2.16         |
| fps                | 554           |
| n_updates          | 93            |
| policy_entropy     | 0.019817818   |
| policy_loss        | -0.0006225342 |
| serial_timesteps   | 11904         |
| time_elapsed       | 23            |
| total_timesteps    | 11904         |
| value_loss         | 0.001364381   |
--------------------------------------

-------------------------------------
| approxkl           | 5.396963e-08 |
| clipfrac           | 0.0          |
| explained_variance | -3.94        |
| fps                | 614          |
| n_updates          | 109          |
| policy_entropy     | 0.013502684  |
| policy_loss        | 6.93223e-05  |
| serial_timesteps   | 13952        |
| time_elapsed       | 27.1         |
| total_timesteps    | 13952        |
| value_loss         | 0.0005771639 |
-------------------------------------
--------------------------------------
| approxkl           | 5.6358687e-05 |
| clipfrac           | 0.001953125   |
| explained_variance | -5.67         |
| fps                | 584           |
| n_updates          | 110           |
| policy_entropy     | 0.016552113   |
| policy_loss        | -0.0007900002 |
| serial_timesteps   | 14080         |
| time_elapsed       | 27.3          |
| total_timesteps    | 14080         |
| value_loss         | 0.0007951219  |
--------------------------------------

---------------------------------------
| approxkl           | 5.873223e-10   |
| clipfrac           | 0.0            |
| explained_variance | -23.4          |
| fps                | 558            |
| n_updates          | 126            |
| policy_entropy     | 0.005697978    |
| policy_loss        | -3.1022355e-06 |
| serial_timesteps   | 16128          |
| time_elapsed       | 31             |
| total_timesteps    | 16128          |
| value_loss         | 0.002425821    |
---------------------------------------
--------------------------------------
| approxkl           | 0.00014237258 |
| clipfrac           | 0.001953125   |
| explained_variance | -11.9         |
| fps                | 517           |
| n_updates          | 127           |
| policy_entropy     | 0.008133255   |
| policy_loss        | -0.0022964033 |
| serial_timesteps   | 16256         |
| time_elapsed       | 31.2          |
| total_timesteps    | 16256         |
| value_loss         | 0.003222824   |
------------

---------------------------------------
| approxkl           | 1.947207e-09   |
| clipfrac           | 0.0            |
| explained_variance | -64.1          |
| fps                | 500            |
| n_updates          | 143            |
| policy_entropy     | 0.00618615     |
| policy_loss        | -1.3639452e-05 |
| serial_timesteps   | 18304          |
| time_elapsed       | 35.2           |
| total_timesteps    | 18304          |
| value_loss         | 0.0005497271   |
---------------------------------------
---------------------------------------
| approxkl           | 5.766973e-09   |
| clipfrac           | 0.0            |
| explained_variance | -36            |
| fps                | 522            |
| n_updates          | 144            |
| policy_entropy     | 0.0073080286   |
| policy_loss        | -4.6955654e-05 |
| serial_timesteps   | 18432          |
| time_elapsed       | 35.4           |
| total_timesteps    | 18432          |
| value_loss         | 0.0006914686   |


current step 146448
current step 146449
current step 146450
current step 146451
current step 146452
current step 146453
current step 146454
current step 146455
current step 146456
current step 146457
current step 146458
current step 146459
current step 146460
current step 146461
current step 146462
current step 146463
current step 146464
current step 146465
current step 146466
current step 146467
current step 146468
current step 146469
current step 146470
current step 146471
current step 146472
current step 146473
current step 146474
current step 146475
current step 146476
current step 146477
current step 146478
current step 146479
current step 146480
current step 146481
current step 146482
current step 146483
current step 146484
current step 146485
current step 146486
current step 146487
current step 146488
current step 146489
current step 146490
current step 146491
current step 146492
current step 146493
current step 146494
current step 146495
current step 146496
current step 146497


current step 146920
current step 146921
current step 146922
current step 146923
current step 146924
current step 146925
current step 146926
current step 146927
current step 146928
current step 146929
current step 146930
current step 146931
current step 146932
current step 146933
current step 146934
current step 146935
current step 146936
current step 146937
current step 146938
current step 146939
current step 146940
current step 146941
current step 146942
current step 146943
current step 146944
current step 146945
current step 146946
current step 146947
current step 146948
current step 146949
current step 146950
current step 146951
current step 146952
current step 146953
current step 146954
current step 146955
current step 146956
current step 146957
current step 146958
current step 146959
current step 146960
current step 146961
current step 146962
current step 146963
current step 146964
current step 146965
current step 146966
current step 146967
current step 146968
current step 146969


current step 147414
current step 147415
current step 147416
current step 147417
current step 147418
current step 147419
current step 147420
current step 147421
current step 147422
current step 147423
current step 147424
current step 147425
current step 147426
current step 147427
current step 147428
current step 147429
current step 147430
current step 147431
current step 147432
current step 147433
current step 147434
current step 147435
current step 147436
current step 147437
current step 147438
current step 147439
current step 147440
current step 147441
current step 147442
current step 147443
current step 147444
current step 147445
current step 147446
current step 147447
current step 147448
current step 147449
current step 147450
current step 147451
current step 147452
current step 147453
current step 147454
current step 147455
current step 147456
current step 147457
current step 147458
current step 147459
current step 147460
current step 147461
current step 147462
current step 147463


current step 147906
current step 147907
current step 147908
current step 147909
current step 147910
current step 147911
current step 147912
current step 147913
current step 147914
current step 147915
current step 147916
current step 147917
current step 147918
current step 147919
current step 147920
current step 147921
current step 147922
current step 147923
current step 147924
current step 147925
current step 147926
current step 147927
current step 147928
current step 147929
current step 147930
current step 147931
current step 147932
current step 147933
current step 147934
current step 147935
current step 147936
current step 147937
current step 147938
current step 147939
current step 147940
current step 147941
current step 147942
current step 147943
current step 147944
current step 147945
current step 147946
current step 147947
current step 147948
current step 147949
current step 147950
current step 147951
current step 147952
current step 147953
current step 147954
current step 147955
