# Build State Space

In [1]:
import pandas as pd
from collections import Counter

import numpy as np

from multiprocessing import Pool

from sklearn.impute import KNNImputer
from sklearn.impute import SimpleImputer

from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split

import math

from tqdm import tqdm
import time

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
physio_data = pd.read_csv('../icu_data/mimic_iv/physio_df_v5.csv')

In [4]:
physio_data

Unnamed: 0,subject_id,hadm_id,stay_id,time,icu_starttime,icu_endtime,los,discharge_fail,readmission,readmission_count,...,INR,M,discharge_action,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,weight
0,10000032,29079034,39553978,2180-07-23 23:50:47,2180-07-23 14:00:00,2180-07-23 23:50:47,0.410266,0,0,0,...,1.34,0,1,88.900000,54.100000,62.300000,37.203704,96.300000,14.666667,39.326426
1,10000690,25860671,37081114,2150-11-03 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,1.25,0,0,110.833333,50.250000,65.833333,36.444444,96.615385,15.000000,55.156787
2,10000690,25860671,37081114,2150-11-03 19:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,1.45,0,0,107.083333,51.083333,64.500000,36.833333,92.909091,13.000000,55.156787
3,10000690,25860671,37081114,2150-11-04 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,1.59,0,0,118.250000,52.416667,68.000000,36.148148,97.000000,13.000000,55.300000
4,10000690,25860671,37081114,2150-11-04 19:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,1.35,0,0,138.250000,61.083333,80.083333,36.185185,95.916667,15.000000,55.156787
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
849011,19999840,21033226,38978960,2164-09-17 16:35:15,2164-09-12 09:26:28,2164-09-17 16:35:15,5.297766,0,0,0,...,1.20,1,1,91.500000,49.000000,52.625000,37.111111,93.000000,3.000000,77.500000
849012,19999987,23865745,36195440,2145-11-03 10:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,1.10,0,0,109.916667,73.000000,80.416667,37.574074,99.538462,7.833333,88.600000
849013,19999987,23865745,36195440,2145-11-03 22:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,1.10,0,0,106.785714,65.285714,76.857143,37.296296,97.230769,8.666667,94.200000
849014,19999987,23865745,36195440,2145-11-04 10:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,1.10,0,0,118.272727,71.636364,80.818182,37.569444,96.500000,12.833333,60.000000


In [5]:
physio_data = physio_data.drop_duplicates()

In [6]:
physio_data.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 849013 entries, 0 to 849015
Data columns (total 54 columns):
 #   Column                          Non-Null Count   Dtype  
---  ------                          --------------   -----  
 0   subject_id                      849013 non-null  int64  
 1   hadm_id                         849013 non-null  int64  
 2   stay_id                         849013 non-null  int64  
 3   time                            849013 non-null  object 
 4   icu_starttime                   849013 non-null  object 
 5   icu_endtime                     849013 non-null  object 
 6   los                             849013 non-null  float64
 7   discharge_fail                  849013 non-null  int64  
 8   readmission                     849013 non-null  int64  
 9   readmission_count               849013 non-null  int64  
 10  death_in_ICU                    849013 non-null  int64  
 11  death_out_ICU                   849013 non-null  int64  
 12  age             

In [7]:
len(pd.unique(physio_data['stay_id']))

85042

In [8]:
# Rows whose discharge_action flag equals 1 (used below to look for stays
# that carry more than one such row).
m = physio_data.loc[physio_data['discharge_action'].eq(1)]

In [9]:
# Second and later discharge rows of the same stay_id (the first one per stay
# is not flagged by `duplicated`).
repeated_stay = m.duplicated(subset=['stay_id'], keep='first')
duplicates = m.loc[repeated_stay]

In [10]:
duplicates

Unnamed: 0,subject_id,hadm_id,stay_id,time,icu_starttime,icu_endtime,los,discharge_fail,readmission,readmission_count,...,INR,M,discharge_action,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,weight
790966,19350933,21370988,38685898,2162-04-16 15:22:13,2162-04-15 15:22:13,2162-04-16 15:22:13,1.0,0,0,0,...,1.0,1,1,125.833333,64.5,79.0,36.388889,96.0,15.0,107.0


In [11]:
m[m['stay_id'] == 38685898]

Unnamed: 0,subject_id,hadm_id,stay_id,time,icu_starttime,icu_endtime,los,discharge_fail,readmission,readmission_count,...,INR,M,discharge_action,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,weight
790965,19350933,21370988,38685898,2162-04-16 15:22:13,2162-04-15 15:22:13,2162-04-16 15:22:13,1.0,0,0,0,...,1.0,1,1,125.833333,64.5,79.0,36.388889,96.0,15.0,106.775557
790966,19350933,21370988,38685898,2162-04-16 15:22:13,2162-04-15 15:22:13,2162-04-16 15:22:13,1.0,0,0,0,...,1.0,1,1,125.833333,64.5,79.0,36.388889,96.0,15.0,107.0


In [12]:
physio_data[physio_data['subject_id'] == 19350933]

Unnamed: 0,subject_id,hadm_id,stay_id,time,icu_starttime,icu_endtime,los,discharge_fail,readmission,readmission_count,...,INR,M,discharge_action,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,weight
790964,19350933,21370988,38685898,2162-04-16 03:22:13,2162-04-15 15:22:13,2162-04-16 15:22:13,1.0,0,0,0,...,1.0,1,0,128.25,59.916667,76.833333,36.777778,95.785714,15.0,106.775557
790965,19350933,21370988,38685898,2162-04-16 15:22:13,2162-04-15 15:22:13,2162-04-16 15:22:13,1.0,0,0,0,...,1.0,1,1,125.833333,64.5,79.0,36.388889,96.0,15.0,106.775557
790966,19350933,21370988,38685898,2162-04-16 15:22:13,2162-04-15 15:22:13,2162-04-16 15:22:13,1.0,0,0,0,...,1.0,1,1,125.833333,64.5,79.0,36.388889,96.0,15.0,107.0


In [13]:
# Remove repeated discharge rows (a second row with discharge_action == 1 for
# the same stay_id). The rows are derived from the data instead of a
# hard-coded index label, so this cell does not depend on the CSV row order
# and is a harmless no-op when re-run. On the current file it selects exactly
# the one row inspected above (index label 790966).
discharge_rows = physio_data[physio_data['discharge_action'] == 1]
repeated_discharge_index = discharge_rows.index[discharge_rows.duplicated(subset=['stay_id'])]
physio_data = physio_data.drop(index=repeated_discharge_index)

In [14]:
physio_data[physio_data['subject_id'] == 19350933]

Unnamed: 0,subject_id,hadm_id,stay_id,time,icu_starttime,icu_endtime,los,discharge_fail,readmission,readmission_count,...,INR,M,discharge_action,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,weight
790964,19350933,21370988,38685898,2162-04-16 03:22:13,2162-04-15 15:22:13,2162-04-16 15:22:13,1.0,0,0,0,...,1.0,1,0,128.25,59.916667,76.833333,36.777778,95.785714,15.0,106.775557
790965,19350933,21370988,38685898,2162-04-16 15:22:13,2162-04-15 15:22:13,2162-04-16 15:22:13,1.0,0,0,0,...,1.0,1,1,125.833333,64.5,79.0,36.388889,96.0,15.0,106.775557


In [15]:
physio_data = physio_data.reset_index(drop = True)

In [16]:
physio_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 849012 entries, 0 to 849011
Data columns (total 54 columns):
 #   Column                          Non-Null Count   Dtype  
---  ------                          --------------   -----  
 0   subject_id                      849012 non-null  int64  
 1   hadm_id                         849012 non-null  int64  
 2   stay_id                         849012 non-null  int64  
 3   time                            849012 non-null  object 
 4   icu_starttime                   849012 non-null  object 
 5   icu_endtime                     849012 non-null  object 
 6   los                             849012 non-null  float64
 7   discharge_fail                  849012 non-null  int64  
 8   readmission                     849012 non-null  int64  
 9   readmission_count               849012 non-null  int64  
 10  death_in_ICU                    849012 non-null  int64  
 11  death_out_ICU                   849012 non-null  int64  
 12  age             

In [None]:
# physio_data['RR'] = np.nan
# physio_data['TV'] = np.nan

# for i in range(len(physio_data)):
#     if physio_data['Tidal Volume (spontaneous)'].iloc[i] != 0:
#         physio_data['TV'].iloc[i] = physio_data['Tidal Volume (spontaneous)'].iloc[i]
#         if physio_data['Respiratory Rate (spontaneous)'].iloc[i] != 0:
#             physio_data['RR'].iloc[i] = physio_data['Respiratory Rate (spontaneous)'].iloc[i]
#         else:
#             physio_data['RR'].iloc[i] = physio_data['Respiratory Rate'].iloc[i]
#     else:
#         if physio_data['Respiratory Rate (spontaneous)'].iloc[i] != 0:
#             physio_data['TV'].iloc[i] = physio_data['Tidal Volume (observed)'].iloc[i]
#             physio_data['RR'].iloc[i] = physio_data['Respiratory Rate (spontaneous)'].iloc[i]
#         else:
#             physio_data['TV'].iloc[i] = physio_data['Tidal Volume (spontaneous)'].iloc[i]
#             physio_data['RR'].iloc[i] = physio_data['Respiratory Rate (spontaneous)'].iloc[i]

In [17]:
# Derive one respiratory-rate column (RR) and one tidal-volume column (TV)
# from the raw ventilation columns. Rows are split on whether the two
# "(spontaneous)" columns are non-zero; NaN compares as != 0, so a missing
# spontaneous value is carried through unchanged.
spont_tv = physio_data['Tidal Volume (spontaneous)']
spont_rr = physio_data['Respiratory Rate (spontaneous)']

mask_spont_tv = spont_tv != 0      # spontaneous tidal volume is non-zero
mask_spont_rr = spont_rr != 0      # spontaneous respiratory rate is non-zero
mask_observed_tv = spont_tv == 0   # exact complement of mask_spont_tv

# RR: use the spontaneous rate, except where the spontaneous tidal volume is
# non-zero but the spontaneous rate is zero -> take 'Respiratory Rate'.
physio_data['RR'] = spont_rr.mask(
    mask_spont_tv & ~mask_spont_rr, physio_data['Respiratory Rate']
).astype('float64')

# TV: use the spontaneous tidal volume, except where it is zero while the
# spontaneous rate is non-zero -> take 'Tidal Volume (observed)'.
physio_data['TV'] = spont_tv.mask(
    mask_observed_tv & mask_spont_rr, physio_data['Tidal Volume (observed)']
).astype('float64')

# NOTE(review): when both spontaneous columns are 0, RR and TV both stay 0
# (they are copied from the spontaneous columns). Confirm this is intended
# rather than a fallback to 'Respiratory Rate' / 'Tidal Volume (observed)'.

In [18]:
# physio_data.info()

In [19]:
# Raw respiratory-rate / tidal-volume columns that are removed after the
# combined RR and TV columns have been added.
raw_ventilation_columns = [
    'Respiratory Rate',
    'Respiratory Rate (spontaneous)',
    'Respiratory Rate (Set)',
    'Respiratory Rate (Total)',
    'Tidal Volume (spontaneous)',
    'Tidal Volume (set)',
    'Tidal Volume (observed)',
]
physio_data_1 = physio_data.drop(columns=raw_ventilation_columns).copy()

In [20]:
# physio_data_1.info()

In [21]:
# physio_data_1.columns

In [22]:
physio_data_1[physio_data_1['subject_id'] == 16133115].tail(50)

Unnamed: 0,subject_id,hadm_id,stay_id,time,icu_starttime,icu_endtime,los,discharge_fail,readmission,readmission_count,...,discharge_action,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,weight,RR,TV
519368,16133115,24673862,37990758,2121-03-22 07:23:49,2121-02-18 14:53:49,2121-03-22 23:19:48,32.351377,0,1,6,...,0,103.0,47.0,66.0,36.555556,99.0,15.0,99.5,23.0,0.000284
519369,16133115,24673862,37990758,2121-03-22 08:08:49,2121-02-18 14:53:49,2121-03-22 23:19:48,32.351377,0,1,6,...,0,109.0,49.0,71.0,37.0,91.0,15.0,99.5,23.0,0.000284
519370,16133115,24673862,37990758,2121-03-22 08:53:49,2121-02-18 14:53:49,2121-03-22 23:19:48,32.351377,0,1,6,...,0,109.0,49.0,71.0,37.0,91.0,15.0,99.5,23.0,0.000284
519371,16133115,24673862,37990758,2121-03-22 09:38:49,2121-02-18 14:53:49,2121-03-22 23:19:48,32.351377,0,1,6,...,0,106.0,48.0,67.0,37.0,100.0,15.0,99.5,23.0,0.000284
519372,16133115,24673862,37990758,2121-03-22 10:23:49,2121-02-18 14:53:49,2121-03-22 23:19:48,32.351377,0,1,6,...,0,96.0,41.0,59.0,37.0,100.0,15.0,99.5,23.0,0.000284
519373,16133115,24673862,37990758,2121-03-22 11:08:49,2121-02-18 14:53:49,2121-03-22 23:19:48,32.351377,0,1,6,...,0,97.0,44.0,62.0,37.0,97.0,15.0,99.5,23.0,0.000284
519374,16133115,24673862,37990758,2121-03-22 11:53:49,2121-02-18 14:53:49,2121-03-22 23:19:48,32.351377,0,1,6,...,0,97.0,44.0,62.0,37.0,97.0,15.0,99.5,23.0,0.000284
519375,16133115,24673862,37990758,2121-03-22 12:38:49,2121-02-18 14:53:49,2121-03-22 23:19:48,32.351377,0,1,6,...,0,96.0,45.0,63.0,36.777778,99.0,15.0,99.5,23.0,0.000284
519376,16133115,24673862,37990758,2121-03-22 13:23:49,2121-02-18 14:53:49,2121-03-22 23:19:48,32.351377,0,1,6,...,0,107.0,47.0,69.0,36.777778,100.0,15.0,99.5,23.0,0.000284
519377,16133115,24673862,37990758,2121-03-22 14:08:49,2121-02-18 14:53:49,2121-03-22 23:19:48,32.351377,0,1,6,...,0,103.0,43.0,63.0,36.777778,100.0,15.0,99.5,23.0,0.000284


In [23]:
Counter(physio_data_1[physio_data_1['subject_id'] == 16133115]['readmission_count'])

Counter({6: 1036, 3: 641, 5: 156, 2: 128, 0: 118, 4: 44, 1: 14})

## Modify the DataFrame for RL training

- Keep only ICU stays of at most 15 days (any patient with a stay longer than 15 days is removed entirely)

In [24]:
# Remove every subject that has at least one row with los > 15.0; all rows of
# such a subject are dropped, not only the long stay itself.
long_stay_subjects = physio_data_1.loc[physio_data_1['los'] > 15.0, 'subject_id'].unique()
has_long_stay = physio_data_1['subject_id'].isin(long_stay_subjects)
physio_data_2 = physio_data_1[~has_long_stay].copy()
# Alternative without the length-of-stay filter:
# physio_data_2 = physio_data_1.copy()

In [25]:
# physio_data_2 = physio_data_2[physio_data_2['los'] >= 1.0].copy()

In [26]:
physio_data_2 = physio_data_2.reset_index(drop = True)

In [27]:
physio_data_2.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 581977 entries, 0 to 581976
Data columns (total 49 columns):
 #   Column                    Non-Null Count   Dtype  
---  ------                    --------------   -----  
 0   subject_id                581977 non-null  int64  
 1   hadm_id                   581977 non-null  int64  
 2   stay_id                   581977 non-null  int64  
 3   time                      581977 non-null  object 
 4   icu_starttime             581977 non-null  object 
 5   icu_endtime               581977 non-null  object 
 6   los                       581977 non-null  float64
 7   discharge_fail            581977 non-null  int64  
 8   readmission               581977 non-null  int64  
 9   readmission_count         581977 non-null  int64  
 10  death_in_ICU              581977 non-null  int64  
 11  death_out_ICU             581977 non-null  int64  
 12  age                       581977 non-null  int64  
 13  Heart Rate                581977 non-null  f

- Keep only first-time ICU admissions (**this filter is currently disabled: no requirement on readmission count for now**).

In [28]:
# physio_data_2.head(50)

In [29]:
# physio_data_3 = physio_data_2[physio_data_2['readmission_count'] == 0].copy()
physio_data_3 = physio_data_2.copy()

In [30]:
physio_data_3.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 581977 entries, 0 to 581976
Data columns (total 49 columns):
 #   Column                    Non-Null Count   Dtype  
---  ------                    --------------   -----  
 0   subject_id                581977 non-null  int64  
 1   hadm_id                   581977 non-null  int64  
 2   stay_id                   581977 non-null  int64  
 3   time                      581977 non-null  object 
 4   icu_starttime             581977 non-null  object 
 5   icu_endtime               581977 non-null  object 
 6   los                       581977 non-null  float64
 7   discharge_fail            581977 non-null  int64  
 8   readmission               581977 non-null  int64  
 9   readmission_count         581977 non-null  int64  
 10  death_in_ICU              581977 non-null  int64  
 11  death_out_ICU             581977 non-null  int64  
 12  age                       581977 non-null  int64  
 13  Heart Rate                581977 non-null  f

In [31]:
physio_data_3 = physio_data_3.reset_index(drop = True)

In [32]:
physio_data_3

Unnamed: 0,subject_id,hadm_id,stay_id,time,icu_starttime,icu_endtime,los,discharge_fail,readmission,readmission_count,...,discharge_action,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,weight,RR,TV
0,10000032,29079034,39553978,2180-07-23 23:50:47,2180-07-23 14:00:00,2180-07-23 23:50:47,0.410266,0,0,0,...,1,88.900000,54.100000,62.300000,37.203704,96.300000,14.666667,39.326426,5.960000,0.000485
1,10000690,25860671,37081114,2150-11-03 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,0,110.833333,50.250000,65.833333,36.444444,96.615385,15.000000,55.156787,9.300000,0.000524
2,10000690,25860671,37081114,2150-11-03 19:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,0,107.083333,51.083333,64.500000,36.833333,92.909091,13.000000,55.156787,8.500000,0.000445
3,10000690,25860671,37081114,2150-11-04 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,0,118.250000,52.416667,68.000000,36.148148,97.000000,13.000000,55.300000,7.900000,0.000375
4,10000690,25860671,37081114,2150-11-04 19:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,0,138.250000,61.083333,80.083333,36.185185,95.916667,15.000000,55.156787,14.800000,0.000337
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
581972,19999840,21033226,38978960,2164-09-17 16:35:15,2164-09-12 09:26:28,2164-09-17 16:35:15,5.297766,0,0,0,...,1,91.500000,49.000000,52.625000,37.111111,93.000000,3.000000,77.500000,24.666667,0.000635
581973,19999987,23865745,36195440,2145-11-03 10:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,0,109.916667,73.000000,80.416667,37.574074,99.538462,7.833333,88.600000,19.615385,0.000473
581974,19999987,23865745,36195440,2145-11-03 22:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,0,106.785714,65.285714,76.857143,37.296296,97.230769,8.666667,94.200000,9.666667,0.000641
581975,19999987,23865745,36195440,2145-11-04 10:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,0,118.272727,71.636364,80.818182,37.569444,96.500000,12.833333,60.000000,9.666667,0.000641


In [33]:
Counter(physio_data_3['readmission_count'])

Counter({0: 398871,
         1: 101361,
         2: 41570,
         3: 20216,
         4: 10856,
         5: 5877,
         6: 1653,
         7: 278,
         8: 208,
         13: 130,
         32: 103,
         27: 70,
         28: 59,
         34: 54,
         33: 49,
         9: 46,
         19: 46,
         10: 44,
         12: 43,
         24: 42,
         31: 42,
         14: 40,
         15: 36,
         30: 35,
         16: 32,
         26: 31,
         29: 27,
         35: 25,
         23: 21,
         11: 17,
         20: 17,
         21: 17,
         22: 17,
         17: 15,
         25: 15,
         18: 14})

In [34]:
Counter(physio_data_3['readmission'])

Counter({0: 398871, 1: 183106})

In [35]:
# drop_patient_list = []

# for i in tqdm(range(len(physio_data_3))):
#     if physio_data_3['readmission_count'].iloc[i] >= 6:
#         drop_patient_list.append(physio_data_3['subject_id'].iloc[i])

In [36]:
# len(drop_patient_list)

In [37]:
# subject_id of every row with readmission_count >= 6. One entry per matching
# row, so the same subject_id can appear many times in the list.
six_or_more_readmissions = physio_data_3['readmission_count'].ge(6)
drop_patient_list = physio_data_3.loc[six_or_more_readmissions, 'subject_id'].tolist()

In [38]:
# Drop all rows of the subjects collected in drop_patient_list.
is_dropped_patient = physio_data_3['subject_id'].isin(drop_patient_list)
physio_data_3 = physio_data_3.loc[~is_dropped_patient].copy()

In [39]:
physio_data_3 = physio_data_3.reset_index(drop = True)

In [40]:
physio_data_3.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 572692 entries, 0 to 572691
Data columns (total 49 columns):
 #   Column                    Non-Null Count   Dtype  
---  ------                    --------------   -----  
 0   subject_id                572692 non-null  int64  
 1   hadm_id                   572692 non-null  int64  
 2   stay_id                   572692 non-null  int64  
 3   time                      572692 non-null  object 
 4   icu_starttime             572692 non-null  object 
 5   icu_endtime               572692 non-null  object 
 6   los                       572692 non-null  float64
 7   discharge_fail            572692 non-null  int64  
 8   readmission               572692 non-null  int64  
 9   readmission_count         572692 non-null  int64  
 10  death_in_ICU              572692 non-null  int64  
 11  death_out_ICU             572692 non-null  int64  
 12  age                       572692 non-null  int64  
 13  Heart Rate                572692 non-null  f

In [41]:
Counter(physio_data_3['readmission_count'])

Counter({0: 398457, 1: 101001, 2: 41104, 3: 19197, 4: 9398, 5: 3535})

- Mark the decision epoch in the dataset

In [42]:
physio_data_3.head(50)

Unnamed: 0,subject_id,hadm_id,stay_id,time,icu_starttime,icu_endtime,los,discharge_fail,readmission,readmission_count,...,discharge_action,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,weight,RR,TV
0,10000032,29079034,39553978,2180-07-23 23:50:47,2180-07-23 14:00:00,2180-07-23 23:50:47,0.410266,0,0,0,...,1,88.9,54.1,62.3,37.203704,96.3,14.666667,39.326426,5.96,0.000485
1,10000690,25860671,37081114,2150-11-03 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,0,110.833333,50.25,65.833333,36.444444,96.615385,15.0,55.156787,9.3,0.000524
2,10000690,25860671,37081114,2150-11-03 19:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,0,107.083333,51.083333,64.5,36.833333,92.909091,13.0,55.156787,8.5,0.000445
3,10000690,25860671,37081114,2150-11-04 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,0,118.25,52.416667,68.0,36.148148,97.0,13.0,55.3,7.9,0.000375
4,10000690,25860671,37081114,2150-11-04 19:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,0,138.25,61.083333,80.083333,36.185185,95.916667,15.0,55.156787,14.8,0.000337
5,10000690,25860671,37081114,2150-11-05 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,0,122.75,59.083333,75.0,36.111111,96.25,15.0,55.3,11.6,0.000303
6,10000690,25860671,37081114,2150-11-05 19:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,0,142.461538,81.153846,91.230769,37.018519,94.230769,13.0,55.3,14.516667,0.00044
7,10000690,25860671,37081114,2150-11-06 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,0,124.416667,66.5,80.416667,36.62963,98.416667,13.666667,55.3,12.5,0.000339
8,10000690,25860671,37081114,2150-11-06 17:03:17,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,1,115.666667,59.333333,72.444444,36.518519,93.555556,15.0,55.156787,6.6,0.000569
9,10000980,26913865,39765666,2189-06-27 20:38:27,2189-06-27 08:42:00,2189-06-27 20:38:27,0.497535,0,0,0,...,1,142.454545,83.272727,97.545455,36.740741,98.909091,15.0,76.022019,27.0,0.000528


In [43]:
# Number the rows of each (stay_id, readmission_count) group 1, 2, 3, ... in
# their current row order.
# Earlier variant, grouping by stay only:
# physio_data_3['epoch'] = physio_data_3.groupby(['stay_id']).cumcount() + 1
rows_per_stay = physio_data_3.groupby(['stay_id', 'readmission_count'])
physio_data_3['epoch'] = rows_per_stay.cumcount().add(1)

In [44]:
physio_data_3[['stay_id', 'discharge_action', 'epoch']].head(50)

Unnamed: 0,stay_id,discharge_action,epoch
0,39553978,1,1
1,37081114,0,1
2,37081114,0,2
3,37081114,0,3
4,37081114,0,4
5,37081114,0,5
6,37081114,0,6
7,37081114,0,7
8,37081114,1,8
9,39765666,1,1


- Delete the cases where the patient died during the ICU stay (**no requirement on death_out_ICU for now**)

In [45]:
physio_data_3[['stay_id', 'death_in_ICU', 'discharge_action', 'epoch']].head(50)

Unnamed: 0,stay_id,death_in_ICU,discharge_action,epoch
0,39553978,0,1,1
1,37081114,0,0,1
2,37081114,0,0,2
3,37081114,0,0,3
4,37081114,0,0,4
5,37081114,0,0,5
6,37081114,0,0,6
7,37081114,0,0,7
8,37081114,0,1,8
9,39765666,0,1,1


In [46]:
# Flag rows with readmission_count == 0 and death_in_ICU == 1.
# id_delete is stored as a float column: 1.0 for flagged rows, 0.0 otherwise.
condition_1 = physio_data_3['readmission_count'].eq(0) & physio_data_3['death_in_ICU'].eq(1)
physio_data_3['id_delete'] = condition_1.astype(float)

In [47]:
physio_data_3

Unnamed: 0,subject_id,hadm_id,stay_id,time,icu_starttime,icu_endtime,los,discharge_fail,readmission,readmission_count,...,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,weight,RR,TV,epoch,id_delete
0,10000032,29079034,39553978,2180-07-23 23:50:47,2180-07-23 14:00:00,2180-07-23 23:50:47,0.410266,0,0,0,...,54.100000,62.300000,37.203704,96.300000,14.666667,39.326426,5.960000,0.000485,1,0.0
1,10000690,25860671,37081114,2150-11-03 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,50.250000,65.833333,36.444444,96.615385,15.000000,55.156787,9.300000,0.000524,1,0.0
2,10000690,25860671,37081114,2150-11-03 19:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,51.083333,64.500000,36.833333,92.909091,13.000000,55.156787,8.500000,0.000445,2,0.0
3,10000690,25860671,37081114,2150-11-04 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,52.416667,68.000000,36.148148,97.000000,13.000000,55.300000,7.900000,0.000375,3,0.0
4,10000690,25860671,37081114,2150-11-04 19:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,61.083333,80.083333,36.185185,95.916667,15.000000,55.156787,14.800000,0.000337,4,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
572687,19999840,21033226,38978960,2164-09-17 16:35:15,2164-09-12 09:26:28,2164-09-17 16:35:15,5.297766,0,0,0,...,49.000000,52.625000,37.111111,93.000000,3.000000,77.500000,24.666667,0.000635,11,1.0
572688,19999987,23865745,36195440,2145-11-03 10:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,73.000000,80.416667,37.574074,99.538462,7.833333,88.600000,19.615385,0.000473,1,0.0
572689,19999987,23865745,36195440,2145-11-03 22:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,65.285714,76.857143,37.296296,97.230769,8.666667,94.200000,9.666667,0.000641,2,0.0
572690,19999987,23865745,36195440,2145-11-04 10:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,71.636364,80.818182,37.569444,96.500000,12.833333,60.000000,9.666667,0.000641,3,0.0


In [48]:
# Keep only the rows that were not flagged through id_delete.
flagged_for_deletion = physio_data_3['id_delete'] == 1
physio_data_3 = physio_data_3.loc[~flagged_for_deletion].copy()

In [49]:
# physio_data_4 = physio_data_3[physio_data_3['death_in_ICU'] == 0].copy()
physio_data_4 = physio_data_3.copy()

In [50]:
physio_data_4 = physio_data_4.reset_index(drop = True)

In [51]:
physio_data_4.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 529907 entries, 0 to 529906
Data columns (total 51 columns):
 #   Column                    Non-Null Count   Dtype  
---  ------                    --------------   -----  
 0   subject_id                529907 non-null  int64  
 1   hadm_id                   529907 non-null  int64  
 2   stay_id                   529907 non-null  int64  
 3   time                      529907 non-null  object 
 4   icu_starttime             529907 non-null  object 
 5   icu_endtime               529907 non-null  object 
 6   los                       529907 non-null  float64
 7   discharge_fail            529907 non-null  int64  
 8   readmission               529907 non-null  int64  
 9   readmission_count         529907 non-null  int64  
 10  death_in_ICU              529907 non-null  int64  
 11  death_out_ICU             529907 non-null  int64  
 12  age                       529907 non-null  int64  
 13  Heart Rate                529907 non-null  f

In [52]:
# physio_data_4.to_csv('../data_output/physio_data_sta.csv', index = False)

In [53]:
physio_data_4[['stay_id', 'death_in_ICU', 'discharge_action', 'epoch']].head(50)

Unnamed: 0,stay_id,death_in_ICU,discharge_action,epoch
0,39553978,0,1,1
1,37081114,0,0,1
2,37081114,0,0,2
3,37081114,0,0,3
4,37081114,0,0,4
5,37081114,0,0,5
6,37081114,0,0,6
7,37081114,0,0,7
8,37081114,0,1,8
9,39765666,0,1,1


- Preparation for the RL training

In [54]:
physio_data_4.columns

Index(['subject_id', 'hadm_id', 'stay_id', 'time', 'icu_starttime',
       'icu_endtime', 'los', 'discharge_fail', 'readmission',
       'readmission_count', 'death_in_ICU', 'death_out_ICU', 'age',
       'Heart Rate', 'Arterial O2 pressure', 'Hemoglobin',
       'Arterial CO2 Pressure', 'PH (Venous)', 'Hematocrit (serum)', 'WBC',
       'Chloride (serum)', 'Creatinine (serum)', 'Glucose (serum)',
       'Magnesium', 'Sodium (serum)', 'PH (Arterial)', 'Inspired O2 Fraction',
       'Arterial Base Excess', 'BUN', 'Ionized Calcium', 'Total Bilirubin',
       'Glucose (whole blood)', 'Potassium (serum)', 'HCO3 (serum)',
       'Platelet Count', 'Prothrombin time', 'PTT', 'INR', 'M',
       'discharge_action', 'Blood Pressure Systolic',
       'Blood Pressure Diastolic', 'Blood Pressure Mean', 'Temperature C',
       'SaO2', 'GCS score', 'weight', 'RR', 'TV', 'epoch', 'id_delete'],
      dtype='object')

In [55]:
len(physio_data_4.columns)

51

In [56]:
def compute_qsofa(row):
    """Return the qSOFA score (0-3) for one observation.

    One point is counted for each criterion that holds:
      * ``row['RR'] >= 22``
      * ``row['Blood Pressure Systolic'] <= 100``
      * ``row['GCS score'] < 15``

    ``row`` only needs to support key lookup for those three names
    (a pandas row Series or a plain dict both work). A NaN value
    compares False and therefore contributes no point.
    """
    criteria_met = (
        row['RR'] >= 22,
        row['Blood Pressure Systolic'] <= 100,
        row['GCS score'] < 15,
    )
    return sum(1 for met in criteria_met if met)

# Vectorized qSOFA: one point per criterion, using the same thresholds as
# compute_qsofa (RR >= 22, systolic BP <= 100, GCS < 15). Column-wise
# comparisons avoid a Python-level call per row; NaN compares False, so the
# result matches the row-wise version.
physio_data_4['qSOFA'] = (
    (physio_data_4['RR'] >= 22).astype('int64')
    + (physio_data_4['Blood Pressure Systolic'] <= 100).astype('int64')
    + (physio_data_4['GCS score'] < 15).astype('int64')
)

In [57]:
physio_data_4

Unnamed: 0,subject_id,hadm_id,stay_id,time,icu_starttime,icu_endtime,los,discharge_fail,readmission,readmission_count,...,Blood Pressure Mean,Temperature C,SaO2,GCS score,weight,RR,TV,epoch,id_delete,qSOFA
0,10000032,29079034,39553978,2180-07-23 23:50:47,2180-07-23 14:00:00,2180-07-23 23:50:47,0.410266,0,0,0,...,62.300000,37.203704,96.300000,14.666667,39.326426,5.960000,0.000485,1,0.0,2
1,10000690,25860671,37081114,2150-11-03 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,65.833333,36.444444,96.615385,15.000000,55.156787,9.300000,0.000524,1,0.0,0
2,10000690,25860671,37081114,2150-11-03 19:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,64.500000,36.833333,92.909091,13.000000,55.156787,8.500000,0.000445,2,0.0,1
3,10000690,25860671,37081114,2150-11-04 07:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,68.000000,36.148148,97.000000,13.000000,55.300000,7.900000,0.000375,3,0.0,1
4,10000690,25860671,37081114,2150-11-04 19:37:00,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,0,0,...,80.083333,36.185185,95.916667,15.000000,55.156787,14.800000,0.000337,4,0.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529902,19999828,25744818,36075953,2149-01-10 13:11:02,2149-01-08 18:12:00,2149-01-10 13:11:02,1.790995,0,0,0,...,98.000000,36.888889,98.250000,15.000000,67.900000,4.700000,0.000583,4,0.0,0
529903,19999987,23865745,36195440,2145-11-03 10:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,80.416667,37.574074,99.538462,7.833333,88.600000,19.615385,0.000473,1,0.0,1
529904,19999987,23865745,36195440,2145-11-03 22:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,76.857143,37.296296,97.230769,8.666667,94.200000,9.666667,0.000641,2,0.0,1
529905,19999987,23865745,36195440,2145-11-04 10:59:00,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,0,0,...,80.818182,37.569444,96.500000,12.833333,60.000000,9.666667,0.000641,3,0.0,1


In [58]:
# Identifier, timing, action and outcome columns (plus qSOFA); these are
# copied into state_id_table.
name_first = ['subject_id', 'hadm_id', 'stay_id', 'time', 'discharge_action', 'epoch', 
              'icu_starttime', 'icu_endtime', 'los', 
              'discharge_fail', 
              'readmission', 'readmission_count', 'death_in_ICU', 'death_out_ICU', 'qSOFA']

# Columns copied into rl_cont_state_table (the state features).
# Note: 'readmission_count' is listed here AND in name_first, so it ends up
# in both tables.
name_second = ['age', 'M', 'weight', 'Heart Rate', 'Arterial O2 pressure', 'Hemoglobin',
               'Arterial CO2 Pressure', 'PH (Venous)', 'Hematocrit (serum)', 'WBC',
               'Chloride (serum)', 'Creatinine (serum)', 'Glucose (serum)',
               'Magnesium', 'Sodium (serum)', 'PH (Arterial)', 'Inspired O2 Fraction',
               'Arterial Base Excess', 'BUN', 'Ionized Calcium', 'Total Bilirubin',
               'Glucose (whole blood)', 'Potassium (serum)', 'HCO3 (serum)',
               'Platelet Count', 'Prothrombin time', 'PTT', 'INR', 
               'Blood Pressure Systolic', 'Blood Pressure Diastolic', 'Blood Pressure Mean', 'Temperature C',
               'SaO2', 'GCS score', 'RR', 'TV', 'readmission_count']

In [59]:
[len(name_first), len(name_second)]

[15, 37]

In [60]:
# Split the prepared frame column-wise into two independent copies:
# the id/outcome table and the state-feature table.
state_id_table = physio_data_4.loc[:, name_first].copy()
rl_cont_state_table = physio_data_4.loc[:, name_second].copy()

- First cost: mortality risk

In [61]:
# death = 1 when the row is flagged as a death either inside or outside the ICU.
condition = state_id_table['death_in_ICU'].eq(1) | state_id_table['death_out_ICU'].eq(1)

state_id_table['death'] = 0
state_id_table.loc[condition, 'death'] = 1

In [62]:
# Mortality cost is charged only on discharge rows (discharge_action == 1):
# 1 if the combined `death` flag is set, otherwise 0. Every non-discharge row
# keeps the default of 0.
# (An earlier variant used a cost of 100 and handled death_in_ICU and
# death_out_ICU separately; the combined binary flag below replaced it.)
is_discharge_row = state_id_table['discharge_action'].eq(1)
condition_1 = is_discharge_row & state_id_table['death'].eq(1)
condition_2 = is_discharge_row & state_id_table['death'].ne(1)

state_id_table['mortality_costs'] = 0
state_id_table.loc[condition_1, 'mortality_costs'] = 1
state_id_table.loc[condition_2, 'mortality_costs'] = 0

# condition_3 = (state_id_table['discharge_action'] == 1) & (state_id_table['death'] != 1) & (state_id_table['readmission'] == 1)
# state_id_table.loc[condition_3, 'mortality_costs'] = 0

In [63]:
Counter(state_id_table['mortality_costs'])

Counter({0: 522457, 1: 7450})

In [64]:
state_id_table[['mortality_costs']].describe().T

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
mortality_costs,529907.0,0.014059,0.117735,0.0,0.0,0.0,0.0,1.0


In [65]:
Counter(state_id_table['discharge_fail'])

Counter({0: 440912, 1: 88995})

In [66]:
state_id_table[['stay_id', 'discharge_action', 'discharge_fail', 'mortality_costs']].head(50)

Unnamed: 0,stay_id,discharge_action,discharge_fail,mortality_costs
0,39553978,1,0,0
1,37081114,0,0,0
2,37081114,0,0,0
3,37081114,0,0,0
4,37081114,0,0,0
5,37081114,0,0,0
6,37081114,0,0,0
7,37081114,0,0,0
8,37081114,1,0,0
9,39765666,1,0,0


In [67]:
len(pd.unique(state_id_table['stay_id']))

73180

- Second cost: discharge failure


In [68]:
# Discharge-failure cost is charged only on discharge rows
# (discharge_action == 1): 1 if discharge_fail is set, otherwise 0.
# Every non-discharge row keeps the default of 0.
is_discharge_row = state_id_table['discharge_action'].eq(1)
condition_1 = is_discharge_row & state_id_table['discharge_fail'].eq(1)
condition_2 = is_discharge_row & state_id_table['discharge_fail'].ne(1)

state_id_table['discharge_fail_costs'] = 0
state_id_table.loc[condition_1, 'discharge_fail_costs'] = 1
state_id_table.loc[condition_2, 'discharge_fail_costs'] = 0

In [69]:
state_id_table[['discharge_fail_costs']].describe().T

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
discharge_fail_costs,529907.0,0.01889,0.136137,0.0,0.0,0.0,0.0,1.0


- Third cost: Length-of-Stay (LOS) in ICU

In [70]:
state_id_table.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 529907 entries, 0 to 529906
Data columns (total 18 columns):
 #   Column                Non-Null Count   Dtype  
---  ------                --------------   -----  
 0   subject_id            529907 non-null  int64  
 1   hadm_id               529907 non-null  int64  
 2   stay_id               529907 non-null  int64  
 3   time                  529907 non-null  object 
 4   discharge_action      529907 non-null  int64  
 5   epoch                 529907 non-null  int64  
 6   icu_starttime         529907 non-null  object 
 7   icu_endtime           529907 non-null  object 
 8   los                   529907 non-null  float64
 9   discharge_fail        529907 non-null  int64  
 10  readmission           529907 non-null  int64  
 11  readmission_count     529907 non-null  int64  
 12  death_in_ICU          529907 non-null  int64  
 13  death_out_ICU         529907 non-null  int64  
 14  qSOFA                 529907 non-null  int64  
 15  

In [71]:
# Convert the three timestamp columns from object dtype to datetime64.
for _ts_column in ('time', 'icu_starttime', 'icu_endtime'):
    state_id_table[_ts_column] = pd.to_datetime(state_id_table[_ts_column])

In [72]:
# state_id_table['los_costs'] = 0
# r = 0

# for i in range(len(state_id_table)):

#     if state_id_table['discharge_action'].iloc[i] == 0:
#         # r = ((state_id_table['time'].iloc[i+1]) - (state_id_table['time'].iloc[i]))/pd.Timedelta(hours = 1)
#         r = 12.0 * (0.5 ** (state_id_table['readmission_count'].iloc[i]))
#         state_id_table['los_costs'].iloc[i] = r

#     else:
#         r = 0
#         state_id_table['los_costs'].iloc[i] = r

In [73]:
# Length-of-stay cost per epoch:
#   * discharge rows (discharge_action != 0) cost 0.0
#   * every other row costs 12.0 * 0.5 ** min(readmission_count, 4), i.e. the
#     cost halves with each prior readmission and stops shrinking after four.
# NOTE(review): 12.0 presumably reflects the 12-hour spacing between epochs - confirm.
discharge_action_zero_mask = state_id_table['discharge_action'].eq(0)
capped_readmissions = state_id_table.loc[discharge_action_zero_mask, 'readmission_count'].clip(upper = 4)

state_id_table['los_costs'] = 0.0
state_id_table.loc[discharge_action_zero_mask, 'los_costs'] = 12.0 * (0.5 ** capped_readmissions)

In [74]:
# # Calculate time differences, handling the last row appropriately
# state_id_table['time_diff'] = state_id_table['time'].diff().shift(-1)  # Shift to get next - current
# state_id_table['time_diff'] = state_id_table['time_diff'].fillna(pd.Timedelta(0))  # Handle last row

# # Calculate los_costs using vectorized operations
# state_id_table['los_costs'] = np.where(
#     state_id_table['discharge_action'] == 0,
#     state_id_table['time_diff'] / pd.Timedelta(hours=1),
#     0
# )

In [75]:
state_id_table[['stay_id', 'time', 'epoch', 'discharge_action', 'discharge_fail', 'mortality_costs', 'discharge_fail_costs', 'los_costs']].head(50)

Unnamed: 0,stay_id,time,epoch,discharge_action,discharge_fail,mortality_costs,discharge_fail_costs,los_costs
0,39553978,2180-07-23 23:50:47,1,1,0,0,0,0.0
1,37081114,2150-11-03 07:37:00,1,0,0,0,0,12.0
2,37081114,2150-11-03 19:37:00,2,0,0,0,0,12.0
3,37081114,2150-11-04 07:37:00,3,0,0,0,0,12.0
4,37081114,2150-11-04 19:37:00,4,0,0,0,0,12.0
5,37081114,2150-11-05 07:37:00,5,0,0,0,0,12.0
6,37081114,2150-11-05 19:37:00,6,0,0,0,0,12.0
7,37081114,2150-11-06 07:37:00,7,0,0,0,0,12.0
8,37081114,2150-11-06 17:03:17,8,1,0,0,0,0.0
9,39765666,2189-06-27 20:38:27,1,1,0,0,0,0.0


In [76]:
state_id_table[['los_costs']].describe().T

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
los_costs,529907.0,7.972219,4.797476,0.0,3.0,12.0,12.0,12.0


In [77]:
Counter(state_id_table['los_costs'])

Counter({12.0: 292531,
         6.0: 92902,
         0.0: 73180,
         3.0: 39614,
         1.5: 18865,
         0.75: 12815})

In [78]:
# Rescale los_costs to the [0, 1] range with a freshly fitted MinMaxScaler.
# fit_transform returns an (n_rows, 1) array; its single column becomes
# los_costs_scaled.
scaler = MinMaxScaler()
state_id_table['los_costs_scaled'] = scaler.fit_transform(state_id_table[['los_costs']])[:, 0]

In [79]:
state_id_table[['los_costs_scaled']].describe().T

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
los_costs_scaled,529907.0,0.664352,0.39979,0.0,0.25,1.0,1.0,1.0


In [80]:
Counter(state_id_table['los_costs_scaled'])

Counter({1.0: 292531,
         0.5: 92902,
         0.0: 73180,
         0.25: 39614,
         0.125: 18865,
         0.0625: 12815})

- Identify safe action space

In [81]:
Counter(state_id_table['qSOFA'])

Counter({1: 249014, 0: 196028, 2: 77631, 3: 7234})

In [82]:
# Summary of discharge_action within each qSOFA level, rounded to 2 decimals.
summary_functions = ['count', 'mean', 'std', 'median', 'min', 'max']
discharge_by_qsofa = state_id_table.groupby('qSOFA')['discharge_action']
stats = discharge_by_qsofa.agg(summary_functions).round(2)

In [83]:
print(stats)

        count  mean   std  median  min  max
qSOFA                                      
0      196028  0.22  0.42     0.0    0    1
1      249014  0.10  0.30     0.0    0    1
2       77631  0.06  0.24     0.0    0    1
3        7234  0.06  0.23     0.0    0    1


In [84]:
# Summary of qSOFA within each discharge_action value, rounded to 2 decimals.
summary_functions = ['count', 'mean', 'std', 'median', 'min', 'max']
qsofa_by_discharge = state_id_table.groupby('discharge_action')['qSOFA']
stats = qsofa_by_discharge.agg(summary_functions).round(2)

In [85]:
print(stats)

                   count  mean   std  median  min  max
discharge_action                                      
0                 456727  0.85  0.73     1.0    0    3
1                  73180  0.49  0.65     0.0    0    3


In [86]:
# Summary of readmission_count within each qSOFA level, rounded to 2 decimals.
summary_functions = ['count', 'mean', 'std', 'median', 'min', 'max']
readmissions_by_qsofa = state_id_table.groupby('qSOFA')['readmission_count']
stats = readmissions_by_qsofa.agg(summary_functions).round(2)

In [87]:
print(stats)

        count  mean   std  median  min  max
qSOFA                                      
0      196028  0.45  0.87     0.0    0    5
1      249014  0.53  0.96     0.0    0    5
2       77631  0.85  1.19     0.0    0    5
3        7234  1.30  1.31     1.0    0    5


In [88]:
# Summary of qSOFA within each readmission_count value, rounded to 2 decimals.
summary_functions = ['count', 'mean', 'std', 'median', 'min', 'max']
qsofa_by_readmissions = state_id_table.groupby('readmission_count')['qSOFA']
stats = qsofa_by_readmissions.agg(summary_functions).round(2)

In [89]:
print(stats)

                    count  mean   std  median  min  max
readmission_count                                      
0                  355672  0.73  0.69     1.0    0    3
1                  101001  0.89  0.77     1.0    0    3
2                   41104  0.98  0.80     1.0    0    3
3                   19197  1.02  0.81     1.0    0    3
4                    9398  1.13  0.84     1.0    0    3
5                    3535  1.18  0.83     1.0    0    3


In [90]:
# safe_action is 1.0 when qSOFA is 0 or 1 and 0.0 when it is 2 or 3.
# Rows with any other qSOFA value are not assigned.
safe_condition = state_id_table['qSOFA'].isin([0, 1])
unsafe_condition = state_id_table['qSOFA'].isin([2, 3])

for _mask, _flag in ((safe_condition, 1.0), (unsafe_condition, 0.0)):
    state_id_table.loc[_mask, 'safe_action'] = _flag

In [91]:
# condition = (state_id_table['qSOFA'] <= 1)

# state_id_table['qSOFA_greedy_action'] = 0.0
# state_id_table.loc[condition, 'qSOFA_greedy_action'] = 1.0

In [92]:
state_id_table

Unnamed: 0,subject_id,hadm_id,stay_id,time,discharge_action,epoch,icu_starttime,icu_endtime,los,discharge_fail,...,readmission_count,death_in_ICU,death_out_ICU,qSOFA,death,mortality_costs,discharge_fail_costs,los_costs,los_costs_scaled,safe_action
0,10000032,29079034,39553978,2180-07-23 23:50:47,1,1,2180-07-23 14:00:00,2180-07-23 23:50:47,0.410266,0,...,0,0,0,2,0,0,0,0.0,0.0,0.0
1,10000690,25860671,37081114,2150-11-03 07:37:00,0,1,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,...,0,0,0,0,0,0,0,12.0,1.0,1.0
2,10000690,25860671,37081114,2150-11-03 19:37:00,0,2,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,...,0,0,0,1,0,0,0,12.0,1.0,1.0
3,10000690,25860671,37081114,2150-11-04 07:37:00,0,3,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,...,0,0,0,1,0,0,0,12.0,1.0,1.0
4,10000690,25860671,37081114,2150-11-04 19:37:00,0,4,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0,...,0,0,0,0,0,0,0,12.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529902,19999828,25744818,36075953,2149-01-10 13:11:02,1,4,2149-01-08 18:12:00,2149-01-10 13:11:02,1.790995,0,...,0,0,0,0,0,0,0,0.0,0.0,1.0
529903,19999987,23865745,36195440,2145-11-03 10:59:00,0,1,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,...,0,0,0,1,0,0,0,12.0,1.0,1.0
529904,19999987,23865745,36195440,2145-11-03 22:59:00,0,2,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,...,0,0,0,1,0,0,0,12.0,1.0,1.0
529905,19999987,23865745,36195440,2145-11-04 10:59:00,0,3,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0,...,0,0,0,1,0,0,0,12.0,1.0,1.0


In [93]:
rl_cont_state_table.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 529907 entries, 0 to 529906
Data columns (total 37 columns):
 #   Column                    Non-Null Count   Dtype  
---  ------                    --------------   -----  
 0   age                       529907 non-null  int64  
 1   M                         529907 non-null  int64  
 2   weight                    529907 non-null  float64
 3   Heart Rate                529907 non-null  float64
 4   Arterial O2 pressure      529907 non-null  float64
 5   Hemoglobin                529907 non-null  float64
 6   Arterial CO2 Pressure     529907 non-null  float64
 7   PH (Venous)               529907 non-null  float64
 8   Hematocrit (serum)        529907 non-null  float64
 9   WBC                       529907 non-null  float64
 10  Chloride (serum)          529907 non-null  float64
 11  Creatinine (serum)        529907 non-null  float64
 12  Glucose (serum)           529907 non-null  float64
 13  Magnesium                 529907 non-null  f

In [94]:
# Spell out the abbreviated column names (RR, TV) in the state table.
full_column_names = {'RR': 'Respiratory Rate', 'TV': 'Tidal Volume'}
rl_cont_state_table = rl_cont_state_table.rename(columns = full_column_names)

In [95]:
# Cast the integer-typed state columns to float so the whole state table is float64.
for _int_column in ('age', 'M', 'readmission_count'):
    rl_cont_state_table[_int_column] = rl_cont_state_table[_int_column].astype(float)

In [96]:
rl_cont_state_table.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 529907 entries, 0 to 529906
Data columns (total 37 columns):
 #   Column                    Non-Null Count   Dtype  
---  ------                    --------------   -----  
 0   age                       529907 non-null  float64
 1   M                         529907 non-null  float64
 2   weight                    529907 non-null  float64
 3   Heart Rate                529907 non-null  float64
 4   Arterial O2 pressure      529907 non-null  float64
 5   Hemoglobin                529907 non-null  float64
 6   Arterial CO2 Pressure     529907 non-null  float64
 7   PH (Venous)               529907 non-null  float64
 8   Hematocrit (serum)        529907 non-null  float64
 9   WBC                       529907 non-null  float64
 10  Chloride (serum)          529907 non-null  float64
 11  Creatinine (serum)        529907 non-null  float64
 12  Glucose (serum)           529907 non-null  float64
 13  Magnesium                 529907 non-null  f

In [97]:
# Cast the action, outcome and cost columns of the id table to float.
_float_columns = (
    'discharge_action',
    'discharge_fail',
    'mortality_costs',
    'discharge_fail_costs',
    'los_costs',
    'los_costs_scaled',
)
for _float_column in _float_columns:
    state_id_table[_float_column] = state_id_table[_float_column].astype(float)

In [98]:
state_id_table.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 529907 entries, 0 to 529906
Data columns (total 21 columns):
 #   Column                Non-Null Count   Dtype         
---  ------                --------------   -----         
 0   subject_id            529907 non-null  int64         
 1   hadm_id               529907 non-null  int64         
 2   stay_id               529907 non-null  int64         
 3   time                  529907 non-null  datetime64[ns]
 4   discharge_action      529907 non-null  float64       
 5   epoch                 529907 non-null  int64         
 6   icu_starttime         529907 non-null  datetime64[ns]
 7   icu_endtime           529907 non-null  datetime64[ns]
 8   los                   529907 non-null  float64       
 9   discharge_fail        529907 non-null  float64       
 10  readmission           529907 non-null  int64         
 11  readmission_count     529907 non-null  int64         
 12  death_in_ICU          529907 non-null  int64         
 13 

In [99]:
# scaler = StandardScaler()
scaler = MinMaxScaler()

In [100]:
var_list = rl_cont_state_table.columns.tolist()

In [101]:
# Build the scaled copy of the state table.
# NOTE(review): the scaler is fitted on every row of rl_cont_state_table, and
# the train/val/test subject split is only made in a later cell, so the fitted
# min/max also reflect validation and test rows. Confirm this is intended, or
# fit on the training rows only.
rl_cont_state_table_scaled = rl_cont_state_table.copy()
rl_cont_state_table_scaled[var_list] = scaler.fit_transform(rl_cont_state_table[var_list])
# Restore the M column to its unscaled values.
rl_cont_state_table_scaled['M'] = rl_cont_state_table['M'].copy()
# Keep the unscaled readmission count next to the scaled one.
rl_cont_state_table_scaled['readmission_count_original'] = rl_cont_state_table['readmission_count'].copy()

In [102]:
rl_cont_state_table_scaled.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 529907 entries, 0 to 529906
Data columns (total 38 columns):
 #   Column                      Non-Null Count   Dtype  
---  ------                      --------------   -----  
 0   age                         529907 non-null  float64
 1   M                           529907 non-null  float64
 2   weight                      529907 non-null  float64
 3   Heart Rate                  529907 non-null  float64
 4   Arterial O2 pressure        529907 non-null  float64
 5   Hemoglobin                  529907 non-null  float64
 6   Arterial CO2 Pressure       529907 non-null  float64
 7   PH (Venous)                 529907 non-null  float64
 8   Hematocrit (serum)          529907 non-null  float64
 9   WBC                         529907 non-null  float64
 10  Chloride (serum)            529907 non-null  float64
 11  Creatinine (serum)          529907 non-null  float64
 12  Glucose (serum)             529907 non-null  float64
 13  Magnesium     

In [103]:
rl_cont_state_table_scaled

Unnamed: 0,age,M,weight,Heart Rate,Arterial O2 pressure,Hemoglobin,Arterial CO2 Pressure,PH (Venous),Hematocrit (serum),WBC,...,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,Respiratory Rate,Tidal Volume,readmission_count,readmission_count_original
0,0.465753,0.0,0.140851,0.608696,0.388503,0.456281,0.395322,0.483544,0.469474,0.505747,...,0.215960,0.380167,0.264902,0.636311,0.491650,0.972222,0.117709,0.549241,0.0,0.0
1,0.931507,0.0,0.277019,0.405100,0.530778,0.502513,0.541053,0.541772,0.484211,0.285441,...,0.434746,0.321214,0.318746,0.301226,0.518596,1.000000,0.184120,0.603056,0.0,0.0
2,0.931507,0.0,0.277019,0.451087,0.275928,0.502513,0.513450,0.481013,0.484211,0.285441,...,0.397340,0.333974,0.298427,0.472855,0.201942,0.833333,0.168213,0.493565,0.0,0.0
3,0.931507,0.0,0.278251,0.307971,0.622549,0.552764,0.559298,0.703797,0.561404,0.231801,...,0.508728,0.354391,0.351763,0.170461,0.551456,0.833333,0.156283,0.395556,0.0,0.0
4,0.931507,0.0,0.277019,0.448370,0.637126,0.552764,0.585965,0.521519,0.561404,0.231801,...,0.708229,0.487099,0.535899,0.186807,0.458900,1.000000,0.293478,0.342778,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529902,0.383562,0.0,0.386633,0.480978,0.577246,0.623116,0.381287,0.455696,0.701754,0.442529,...,0.543641,0.895432,0.808928,0.497373,0.658252,1.000000,0.092656,0.684213,0.0,0.0
529903,0.534247,0.0,0.564689,0.568980,0.493413,0.874372,0.573099,0.556962,0.880702,0.492337,...,0.425603,0.669573,0.540979,0.799766,0.768335,0.402778,0.389224,0.531667,0.0,0.0
529904,0.534247,0.0,0.612858,0.538880,0.493413,0.874372,0.573099,0.577215,0.880702,0.492337,...,0.394371,0.551448,0.486736,0.677175,0.571173,0.472222,0.191410,0.765278,0.0,0.0
529905,0.534247,0.0,0.318679,0.769022,0.277844,0.814070,0.385965,0.632911,0.824561,0.442529,...,0.508955,0.648692,0.547097,0.797723,0.508738,0.819444,0.191410,0.765278,0.0,0.0


In [104]:
Counter(rl_cont_state_table_scaled['M'])

Counter({1.0: 304934, 0.0: 224973})

In [105]:
Counter(rl_cont_state_table_scaled['readmission_count'])

Counter({0.0: 355672,
         0.2: 101001,
         0.4: 41104,
         0.6000000000000001: 19197,
         0.8: 9398,
         1.0: 3535})

In [106]:
# Scaling leaves one level of readmission_count as 0.6000000000000001 instead
# of 0.6 (a floating-point artifact). Snap it back to 0.6.
# A tolerance comparison is used rather than exact equality against the
# printed repr, so the clean-up still applies if the scaler's last-bit
# rounding differs; the other levels are at least 0.2 away and are untouched.
condition = (rl_cont_state_table_scaled['readmission_count'] - 0.6).abs() < 1e-9

rl_cont_state_table_scaled.loc[condition, 'readmission_count'] = 0.6

In [107]:
Counter(rl_cont_state_table_scaled['readmission_count'])

Counter({0.0: 355672,
         0.2: 101001,
         0.4: 41104,
         0.6: 19197,
         0.8: 9398,
         1.0: 3535})

In [108]:
Counter(rl_cont_state_table_scaled['readmission_count_original'])

Counter({0.0: 355672,
         1.0: 101001,
         2.0: 41104,
         3.0: 19197,
         4.0: 9398,
         5.0: 3535})

In [109]:
# train_subject_id, test_subject_id = train_test_split(pd.unique(state_id_table['subject_id']), test_size = 0.15, random_state = 1000)

In [111]:
# Split at the subject level, so all rows of one subject_id land in the same
# partition: 75% of subjects for training, and the remaining 25% halved into
# validation and test. Both calls use random_state = 5 for reproducibility.
all_subject_ids = pd.unique(state_id_table['subject_id'])

train_subject_id, temp_subject_id = train_test_split(all_subject_ids, test_size = 0.25, random_state = 5)
val_subject_id, test_subject_id = train_test_split(temp_subject_id, test_size = 0.50, random_state = 5)

In [112]:
# [len(train_subject_id), len(test_subject_id)]

In [113]:
[len(train_subject_id), len(val_subject_id), len(test_subject_id)]

[38867, 6478, 6478]

In [114]:
[len(pd.unique(state_id_table['subject_id'])), 
 len(pd.unique(state_id_table['stay_id']))]

[51823, 73180]

In [115]:
state_id_table

Unnamed: 0,subject_id,hadm_id,stay_id,time,discharge_action,epoch,icu_starttime,icu_endtime,los,discharge_fail,...,readmission_count,death_in_ICU,death_out_ICU,qSOFA,death,mortality_costs,discharge_fail_costs,los_costs,los_costs_scaled,safe_action
0,10000032,29079034,39553978,2180-07-23 23:50:47,1.0,1,2180-07-23 14:00:00,2180-07-23 23:50:47,0.410266,0.0,...,0,0,0,2,0,0.0,0.0,0.0,0.0,0.0
1,10000690,25860671,37081114,2150-11-03 07:37:00,0.0,1,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0.0,...,0,0,0,0,0,0.0,0.0,12.0,1.0,1.0
2,10000690,25860671,37081114,2150-11-03 19:37:00,0.0,2,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
3,10000690,25860671,37081114,2150-11-04 07:37:00,0.0,3,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
4,10000690,25860671,37081114,2150-11-04 19:37:00,0.0,4,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0.0,...,0,0,0,0,0,0.0,0.0,12.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529902,19999828,25744818,36075953,2149-01-10 13:11:02,1.0,4,2149-01-08 18:12:00,2149-01-10 13:11:02,1.790995,0.0,...,0,0,0,0,0,0.0,0.0,0.0,0.0,1.0
529903,19999987,23865745,36195440,2145-11-03 10:59:00,0.0,1,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
529904,19999987,23865745,36195440,2145-11-03 22:59:00,0.0,2,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
529905,19999987,23865745,36195440,2145-11-04 10:59:00,0.0,3,2145-11-02 22:59:00,2145-11-04 21:29:30,1.937847,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0


In [116]:
# Training partition: pick the id-table rows whose subject is in the training
# set, then take the rows with the same index labels from the raw and scaled
# state tables.
train_row_mask = state_id_table['subject_id'].isin(train_subject_id.tolist())
id_table_train = state_id_table.loc[train_row_mask].copy()
mv_train_index = id_table_train.index
id_index_list = mv_train_index.tolist()

rl_table_train = rl_cont_state_table.loc[id_index_list].copy()
rl_table_train_scaled = rl_cont_state_table_scaled.loc[id_index_list].copy()

In [117]:
# Validation partition: pick the id-table rows whose subject is in the
# validation set, then take the rows with the same index labels from the raw
# and scaled state tables.
val_row_mask = state_id_table['subject_id'].isin(val_subject_id.tolist())
id_table_val = state_id_table.loc[val_row_mask].copy()
mv_val_index = id_table_val.index
id_index_list = mv_val_index.tolist()

rl_table_val = rl_cont_state_table.loc[id_index_list].copy()
rl_table_val_scaled = rl_cont_state_table_scaled.loc[id_index_list].copy()

In [118]:
# Test partition: pick the id-table rows whose subject is in the test set,
# then take the rows with the same index labels from the raw and scaled
# state tables.
test_row_mask = state_id_table['subject_id'].isin(test_subject_id.tolist())
id_table_test = state_id_table.loc[test_row_mask].copy()
mv_test_index = id_table_test.index
id_index_list = mv_test_index.tolist()

rl_table_test = rl_cont_state_table.loc[id_index_list].copy()
rl_table_test_scaled = rl_cont_state_table_scaled.loc[id_index_list].copy()

In [119]:
id_table_train

Unnamed: 0,subject_id,hadm_id,stay_id,time,discharge_action,epoch,icu_starttime,icu_endtime,los,discharge_fail,...,readmission_count,death_in_ICU,death_out_ICU,qSOFA,death,mortality_costs,discharge_fail_costs,los_costs,los_costs_scaled,safe_action
0,10000032,29079034,39553978,2180-07-23 23:50:47,1.0,1,2180-07-23 14:00:00,2180-07-23 23:50:47,0.410266,0.0,...,0,0,0,2,0,0.0,0.0,0.0,0.0,0.0
1,10000690,25860671,37081114,2150-11-03 07:37:00,0.0,1,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0.0,...,0,0,0,0,0,0.0,0.0,12.0,1.0,1.0
2,10000690,25860671,37081114,2150-11-03 19:37:00,0.0,2,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
3,10000690,25860671,37081114,2150-11-04 07:37:00,0.0,3,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
4,10000690,25860671,37081114,2150-11-04 19:37:00,0.0,4,2150-11-02 19:37:00,2150-11-06 17:03:17,3.893252,0.0,...,0,0,0,0,0,0.0,0.0,12.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529898,19999625,25304202,31070865,2139-10-11 18:21:28,1.0,2,2139-10-10 19:18:00,2139-10-11 18:21:28,0.960741,0.0,...,0,0,0,1,0,0.0,0.0,0.0,0.0,1.0
529899,19999828,25744818,36075953,2149-01-09 06:12:00,0.0,1,2149-01-08 18:12:00,2149-01-10 13:11:02,1.790995,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
529900,19999828,25744818,36075953,2149-01-09 18:12:00,0.0,2,2149-01-08 18:12:00,2149-01-10 13:11:02,1.790995,0.0,...,0,0,0,0,0,0.0,0.0,12.0,1.0,1.0
529901,19999828,25744818,36075953,2149-01-10 06:12:00,0.0,3,2149-01-08 18:12:00,2149-01-10 13:11:02,1.790995,0.0,...,0,0,0,0,0,0.0,0.0,12.0,1.0,1.0


In [120]:
rl_table_train

Unnamed: 0,age,M,weight,Heart Rate,Arterial O2 pressure,Hemoglobin,Arterial CO2 Pressure,PH (Venous),Hematocrit (serum),WBC,...,INR,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,Respiratory Rate,Tidal Volume,readmission_count
0,52.0,0.0,39.326426,96.500000,92.100000,9.04,36.400000,7.361,28.08,13.25,...,1.34,88.900000,54.100000,62.300000,37.203704,96.300000,14.666667,5.96,0.000485,0.0
1,86.0,0.0,55.156787,77.769231,121.800000,9.50,42.630000,7.384,28.50,7.50,...,1.25,110.833333,50.250000,65.833333,36.444444,96.615385,15.000000,9.30,0.000524,0.0
2,86.0,0.0,55.156787,82.000000,68.600000,9.50,41.450000,7.360,28.50,7.50,...,1.45,107.083333,51.083333,64.500000,36.833333,92.909091,13.000000,8.50,0.000445,0.0
3,86.0,0.0,55.300000,68.833333,140.957143,10.00,43.410000,7.448,30.70,6.10,...,1.59,118.250000,52.416667,68.000000,36.148148,97.000000,13.000000,7.90,0.000375,0.0
4,86.0,0.0,55.156787,81.750000,144.000000,10.00,44.550000,7.376,30.70,6.10,...,1.35,138.250000,61.083333,80.083333,36.185185,95.916667,15.000000,14.80,0.000337,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529898,81.0,1.0,50.394071,68.400000,92.933333,10.40,37.533333,7.298,30.00,7.70,...,1.20,112.100000,55.300000,69.000000,37.037037,99.400000,11.000000,7.55,0.000425,0.0
529899,46.0,0.0,67.900000,103.000000,116.600000,11.50,39.600000,7.360,36.80,13.10,...,1.20,104.357143,75.000000,84.714286,36.986111,97.846154,14.000000,4.00,0.000514,0.0
529900,46.0,0.0,67.900000,97.083333,126.800000,11.50,33.600000,7.350,36.80,13.10,...,1.20,106.833333,73.666667,84.583333,36.666667,99.416667,15.000000,12.30,0.000394,0.0
529901,46.0,0.0,67.700000,86.833333,100.000000,10.70,38.400000,7.350,34.70,11.60,...,1.20,109.333333,74.083333,85.166667,36.805556,97.416667,15.000000,16.10,0.000433,0.0


In [121]:
id_table_test

Unnamed: 0,subject_id,hadm_id,stay_id,time,discharge_action,epoch,icu_starttime,icu_endtime,los,discharge_fail,...,readmission_count,death_in_ICU,death_out_ICU,qSOFA,death,mortality_costs,discharge_fail_costs,los_costs,los_costs_scaled,safe_action
20,10002013,23581541,39060235,2160-05-18 22:00:53,0.0,1,2160-05-18 10:00:53,2160-05-19 17:33:33,1.314352,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
21,10002013,23581541,39060235,2160-05-19 10:00:53,0.0,2,2160-05-18 10:00:53,2160-05-19 17:33:33,1.314352,0.0,...,0,0,0,0,0,0.0,0.0,12.0,1.0,1.0
22,10002013,23581541,39060235,2160-05-19 17:33:33,1.0,3,2160-05-18 10:00:53,2160-05-19 17:33:33,1.314352,0.0,...,0,0,0,0,0,0.0,0.0,0.0,0.0,1.0
336,10004401,29988601,32773003,2144-01-27 10:28:04,0.0,1,2144-01-26 22:28:04,2144-02-06 13:44:15,10.636238,1.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
337,10004401,29988601,32773003,2144-01-27 22:28:04,0.0,2,2144-01-26 22:28:04,2144-02-06 13:44:15,10.636238,1.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529640,19997293,28847872,31877557,2124-01-02 01:48:04,0.0,8,2123-12-29 01:48:04,2124-01-03 23:16:34,5.894792,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
529641,19997293,28847872,31877557,2124-01-02 13:48:04,0.0,9,2123-12-29 01:48:04,2124-01-03 23:16:34,5.894792,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
529642,19997293,28847872,31877557,2124-01-03 01:48:04,0.0,10,2123-12-29 01:48:04,2124-01-03 23:16:34,5.894792,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0
529643,19997293,28847872,31877557,2124-01-03 13:48:04,0.0,11,2123-12-29 01:48:04,2124-01-03 23:16:34,5.894792,0.0,...,0,0,0,1,0,0.0,0.0,12.0,1.0,1.0


In [122]:
rl_table_test

Unnamed: 0,age,M,weight,Heart Rate,Arterial O2 pressure,Hemoglobin,Arterial CO2 Pressure,PH (Venous),Hematocrit (serum),WBC,...,INR,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,Respiratory Rate,Tidal Volume,readmission_count
20,53.0,0.0,95.798630,95.555556,156.128571,11.233333,45.000000,7.370167,30.5,19.200000,...,1.1,116.666667,68.222222,84.777778,37.100000,97.333333,9.333333,17.222222,0.000534,0.0
21,53.0,0.0,104.100000,94.000000,124.822857,10.900000,45.000000,7.382000,31.2,17.900000,...,1.1,106.083333,56.416667,70.090909,37.571429,96.636364,15.000000,14.083333,0.000572,0.0
22,53.0,0.0,96.000000,92.666667,168.280000,10.900000,45.000000,7.380000,31.2,17.900000,...,1.1,112.500000,63.500000,75.000000,36.555556,95.666667,15.000000,14.333333,0.000574,0.0
336,82.0,1.0,75.840582,70.400000,86.666667,8.433333,42.333333,7.340000,25.3,23.266667,...,1.4,96.400000,37.000000,50.600000,36.788889,95.000000,15.000000,18.000000,0.000474,0.0
337,82.0,1.0,76.000000,78.916667,111.000000,8.500000,46.600000,7.340000,25.9,23.050000,...,1.4,111.000000,40.916667,59.333333,36.259259,97.500000,11.000000,15.000000,0.000412,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529640,76.0,1.0,107.300000,82.000000,110.500000,9.300000,39.500000,7.345000,28.9,7.600000,...,1.3,108.166667,75.750000,83.583333,36.537037,99.000000,14.000000,14.500000,0.000545,0.0
529641,76.0,1.0,107.300000,81.750000,110.500000,10.100000,39.500000,7.345000,31.5,11.000000,...,1.3,109.000000,68.333333,77.000000,36.361111,98.000000,14.000000,14.500000,0.000545,0.0
529642,76.0,1.0,107.300000,82.090909,110.500000,10.100000,39.500000,7.345000,31.5,11.000000,...,1.3,109.333333,69.250000,78.916667,36.592593,98.250000,14.000000,14.500000,0.000545,0.0
529643,76.0,1.0,107.093071,82.250000,110.500000,9.800000,39.500000,7.345000,31.1,12.300000,...,1.3,108.166667,74.250000,82.416667,36.481481,97.909091,14.000000,14.500000,0.000545,0.0


- Divide the dataset according to the discharge outcome (success/fail). This step is currently disabled: the cells in this section are commented out.

In [123]:
# # state_id_cont = state_id_table[state_id_table['discharge_action'] == 0]
# # state_id_disch = state_id_table[state_id_table['discharge_action'] == 1]

# state_id_success = state_id_table[state_id_table['discharge_fail'] == 0]
# state_id_failure = state_id_table[state_id_table['discharge_fail'] == 1]

In [124]:
# id_table_success = state_id_success[state_id_success['subject_id'].isin(train_subject_id.tolist())].copy()
# id_table_failure = state_id_failure[state_id_failure['subject_id'].isin(train_subject_id.tolist())].copy()

# mv_train_index_success = id_table_success.index
# mv_train_index_failure = id_table_failure.index

# id_index_list_success = mv_train_index_success.tolist()
# id_index_list_failure = mv_train_index_failure.tolist()

# rl_table_train_success = rl_cont_state_table.loc[id_index_list_success].copy()
# rl_table_train_failure = rl_cont_state_table.loc[id_index_list_failure].copy()

# rl_table_train_success_scaled = rl_cont_state_table_scaled.loc[id_index_list_success].copy()
# rl_table_train_failure_scaled = rl_cont_state_table_scaled.loc[id_index_list_failure].copy()

In [125]:
# id_table_success

In [126]:
# id_table_cont = state_id_cont[state_id_cont['subject_id'].isin(train_subject_id.tolist())].copy()
# id_table_disch = state_id_disch[state_id_disch['subject_id'].isin(train_subject_id.tolist())].copy()

# mv_train_index_cont = id_table_cont.index
# mv_train_index_disch = id_table_disch.index

# id_index_list_cont = mv_train_index_cont.tolist()
# id_index_list_disch = mv_train_index_disch.tolist()

# rl_table_train_cont = rl_cont_state_table.loc[id_index_list_cont].copy()
# rl_table_train_disch = rl_cont_state_table.loc[id_index_list_disch].copy()

# rl_table_train_cont_scaled = rl_cont_state_table_scaled.loc[id_index_list_cont].copy()
# rl_table_train_disch_scaled = rl_cont_state_table_scaled.loc[id_index_list_disch].copy()

- Data output: write the state tables and the train/validation/test splits to `../data_output/` as CSV files.

In [127]:
# state_id_table.to_csv('../data_output/state_id_table_v5.csv', index = False)
# rl_cont_state_table.to_csv('../data_output/rl_cont_state_table_v5.csv', index = False)
# rl_cont_state_table_scaled.to_csv('../data_output/rl_cont_state_table_scaled_v5.csv', index = False)

# id_table_train.to_csv('../data_output/id_table_train_v5.csv', index = False)
# rl_table_train.to_csv('../data_output/rl_table_train_v5.csv', index = False)
# rl_table_train_scaled.to_csv('../data_output/rl_table_train_scaled_v5.csv', index = False)

# id_table_cont.to_csv('../data_output/id_table_cont_v5.csv', index = False)
# rl_table_train_cont.to_csv('../data_output/rl_table_train_cont_v5.csv', index = False)
# rl_table_train_cont_scaled.to_csv('../data_output/rl_table_train_cont_scaled_v5.csv', index = False)

# id_table_disch.to_csv('../data_output/id_table_disch_v5.csv', index = False)
# rl_table_train_disch.to_csv('../data_output/rl_table_train_disch_v5.csv', index = False)
# rl_table_train_disch_scaled.to_csv('../data_output/rl_table_train_disch_scaled_v5.csv', index = False)

In [128]:
# id_table_test.to_csv('../data_output/id_table_test_v5.csv', index = False)
# rl_table_test.to_csv('../data_output/rl_table_test_v5.csv', index = False)
# rl_table_test_scaled.to_csv('../data_output/rl_table_test_scaled_v5.csv', index = False)

In [129]:
# Write the three full tables to ../data_output as v15 CSV files.
# index=False keeps the DataFrame index out of the files.
for _table, _csv_path in (
    (state_id_table, '../data_output/state_id_table_v15.csv'),
    (rl_cont_state_table, '../data_output/rl_cont_state_table_v15.csv'),
    (rl_cont_state_table_scaled, '../data_output/rl_cont_state_table_scaled_v15.csv'),
):
    _table.to_csv(_csv_path, index=False)

In [130]:
# Write the three training tables to ../data_output as v15 CSV files.
# index=False keeps the DataFrame index out of the files.
for _table, _csv_path in (
    (id_table_train, '../data_output/id_table_train_v15.csv'),
    (rl_table_train, '../data_output/rl_table_train_v15.csv'),
    (rl_table_train_scaled, '../data_output/rl_table_train_scaled_v15.csv'),
):
    _table.to_csv(_csv_path, index=False)

In [131]:
# Write the three validation tables to ../data_output as v15 CSV files.
# index=False keeps the DataFrame index out of the files.
for _table, _csv_path in (
    (id_table_val, '../data_output/id_table_val_v15.csv'),
    (rl_table_val, '../data_output/rl_table_val_v15.csv'),
    (rl_table_val_scaled, '../data_output/rl_table_val_scaled_v15.csv'),
):
    _table.to_csv(_csv_path, index=False)

In [132]:
# Write the three test tables to ../data_output as v15 CSV files.
# index=False keeps the DataFrame index out of the files.
for _table, _csv_path in (
    (id_table_test, '../data_output/id_table_test_v15.csv'),
    (rl_table_test, '../data_output/rl_table_test_v15.csv'),
    (rl_table_test_scaled, '../data_output/rl_table_test_scaled_v15.csv'),
):
    _table.to_csv(_csv_path, index=False)

In [134]:
# Show rl_cont_state_table_scaled as the cell output (bare last expression) for a visual check.
rl_cont_state_table_scaled

Unnamed: 0,age,M,weight,Heart Rate,Arterial O2 pressure,Hemoglobin,Arterial CO2 Pressure,PH (Venous),Hematocrit (serum),WBC,...,Blood Pressure Systolic,Blood Pressure Diastolic,Blood Pressure Mean,Temperature C,SaO2,GCS score,Respiratory Rate,Tidal Volume,readmission_count,readmission_count_original
0,0.465753,0.0,0.140851,0.608696,0.388503,0.456281,0.395322,0.483544,0.469474,0.505747,...,0.215960,0.380167,0.264902,0.636311,0.491650,0.972222,0.117709,0.549241,0.0,0.0
1,0.931507,0.0,0.277019,0.405100,0.530778,0.502513,0.541053,0.541772,0.484211,0.285441,...,0.434746,0.321214,0.318746,0.301226,0.518596,1.000000,0.184120,0.603056,0.0,0.0
2,0.931507,0.0,0.277019,0.451087,0.275928,0.502513,0.513450,0.481013,0.484211,0.285441,...,0.397340,0.333974,0.298427,0.472855,0.201942,0.833333,0.168213,0.493565,0.0,0.0
3,0.931507,0.0,0.278251,0.307971,0.622549,0.552764,0.559298,0.703797,0.561404,0.231801,...,0.508728,0.354391,0.351763,0.170461,0.551456,0.833333,0.156283,0.395556,0.0,0.0
4,0.931507,0.0,0.277019,0.448370,0.637126,0.552764,0.585965,0.521519,0.561404,0.231801,...,0.708229,0.487099,0.535899,0.186807,0.458900,1.000000,0.293478,0.342778,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529902,0.383562,0.0,0.386633,0.480978,0.577246,0.623116,0.381287,0.455696,0.701754,0.442529,...,0.543641,0.895432,0.808928,0.497373,0.658252,1.000000,0.092656,0.684213,0.0,0.0
529903,0.534247,0.0,0.564689,0.568980,0.493413,0.874372,0.573099,0.556962,0.880702,0.492337,...,0.425603,0.669573,0.540979,0.799766,0.768335,0.402778,0.389224,0.531667,0.0,0.0
529904,0.534247,0.0,0.612858,0.538880,0.493413,0.874372,0.573099,0.577215,0.880702,0.492337,...,0.394371,0.551448,0.486736,0.677175,0.571173,0.472222,0.191410,0.765278,0.0,0.0
529905,0.534247,0.0,0.318679,0.769022,0.277844,0.814070,0.385965,0.632911,0.824561,0.442529,...,0.508955,0.648692,0.547097,0.797723,0.508738,0.819444,0.191410,0.765278,0.0,0.0


In [133]:
# Build a LaTeX table of per-feature summary statistics for rl_cont_state_table.
# describe() returns one column per feature; transpose so each feature is a row.
summary_stats = rl_cont_state_table.describe().T

# Constant label column kept on the frame.
# NOTE(review): it is not emitted in the LaTeX table, because `columns` below
# lists only the four statistics — confirm whether it is used later.
summary_stats.insert(0, "Category", "Clinical Information")

latex_table = summary_stats.to_latex(
    index=True,
    columns=["mean", "std", "min", "max"],
    header=["Mean", "SD", "Min", "Max"],
    float_format="%.6f",
    # One spec for the index (feature name) plus one per exported statistic:
    # five columns in total, matching what is actually written.
    column_format="lcccc",
    caption="Summary statistics of the study samples.",
    label="tab:summary_stats",
    longtable=False,
    # Escape LaTeX-special characters in the row labels; feature names that
    # contain "_" would otherwise produce a table that does not compile.
    escape=True,
)

# print() rather than rich display: the raw LaTeX source is meant to be copied.
print(latex_table)

\begin{table}
\centering
\caption{Summary statistics of the study samples.}
\label{tab:summary_stats}
\begin{tabular}{llcccc}
\toprule
{} &       Mean &        SD &        Min &        Max \\
\midrule
age                      &  63.557636 & 15.984381 &  18.000000 &  91.000000 \\
M                        &   0.575448 &  0.494275 &   0.000000 &   1.000000 \\
weight                   &  80.170375 & 19.881834 &  22.951755 & 139.207385 \\
Heart Rate               &  85.395203 & 16.069415 &  40.500000 & 132.500000 \\
Arterial O2 pressure     & 112.238637 & 33.211877 &  11.000000 & 219.750000 \\
Hemoglobin               &   9.779127 &  1.770941 &   4.500000 &  14.450000 \\
Arterial CO2 Pressure    &  40.578616 &  6.397797 &  19.500000 &  62.250000 \\
PH (Venous)              &   7.374831 &  0.052287 &   7.170000 &   7.565000 \\
Hematocrit (serum)       &  29.897038 &  5.064912 &  14.700000 &  43.200000 \\
WBC                      &  10.954266 &  4.822153 &   0.050000 &  26.150000 \\
Chloride 