In [1]:
import numpy as np
import pandas as pd
from collections import Counter
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import regularizers
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
tf.autograph.set_verbosity(0)
from tqdm import tqdm
tqdm.pandas()
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import os

# Diagnosis & Parameters

In [2]:
diagnosis = 'death'

In [3]:
parameters = {
    'dx_offset' : 60,
    'pos_examples' : 10000,
    'neg_examples' : 10000,
    'ratio' : .5,           #pos/all
    'seq_length' : 100,
    'vocab_size' : 1000,
    'window' : [60*4, 60*0],
    'baseline' : [0, 360],
}

print(parameters['window'])

data_path = '/Users/tobymanders/Documents/insight_project/data/'
image_path = '/Users/tobymanders/Documents/insight_project/figures/'

[240, 0]


# Functions

In [4]:
def save_fig(fig_name, tight_layout=True):
    path = os.path.join(image_path, fig_name + '.png')
    print("Saving figure", fig_name)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format='png', dpi=300)

In [5]:
def check_col_complete(dataframe):
    len_ = len(dataframe)

    incomplete = []
    for col in dataframe.columns:
        complete = dataframe[col].count() == len_
        if complete:
            print('{: >23}: COMPLETE'.format(col))
        else:
            missing = len_ - dataframe[col].count()
            incomplete.append(col)
            print('{: >23}: MISSING {} VALUES ({:.1f}%)'.format(col, missing, missing*100/len_))
    
    print('\nTotal number of items:', len_)

    return incomplete

In [6]:
def tokenize_text(x):
    tokenized = [word_to_ID[word] if word in word_to_ID else 
                                         word_to_ID['<unk>'] for word in x.split(';')]
    return tokenized
            

In [7]:
def pad_text(x):
    rem = parameters['seq_length'] - len(x)
    if rem > 0:
        pad = [word_to_ID['<unk>']]*abs(rem)
        x.extend(pad)
        return x
    else:
        return x[-parameters['seq_length']:]
    

# Load Corpus

In [8]:
dx = pd.read_csv(data_path + 'diagnosis.csv')
dx.sample(5)

Unnamed: 0,diagnosisid,patientunitstayid,activeupondischarge,diagnosisoffset,diagnosisstring,icd9code,diagnosispriority
1725766,38724593,3034620,False,54298,renal|electrolyte imbalance|hypophosphatemia,"275.3, E83.30",Other
669810,13945610,987893,False,841,pulmonary|disorders of acid base|respiratory a...,"276.2, E87.2",Other
2619217,45071904,3213367,False,3864,renal|disorder of kidney|acute renal failure,"584.9, N17.9",Other
397287,9937872,496091,False,1876,renal|disorder of acid base|metabolic acidosis...,"276.2, E87.2",Other
543547,11400434,737622,False,16053,neurologic|post-neurosurgery|post craniotomy,,Primary


In [9]:
dx[dx.diagnosisstring.str.contains('acute')].diagnosisstring.value_counts().nlargest(15)

pulmonary|respiratory failure|acute respiratory failure                                                      97836
renal|disorder of kidney|acute renal failure                                                                 65313
pulmonary|respiratory failure|acute respiratory distress                                                     26774
cardiovascular|chest pain / ASHD|acute coronary syndrome|acute myocardial infarction (no ST elevation)       15686
pulmonary|disorders of the airways|acute COPD exacerbation                                                   13088
hematology|bleeding and red blood cell disorders|anemia|acute blood loss anemia                              10727
cardiovascular|chest pain / ASHD|acute coronary syndrome                                                      9808
cardiovascular|ventricular disorders|acute pulmonary edema                                                    8177
cardiovascular|chest pain / ASHD|acute coronary syndrome|acute myocardial infarc

In [10]:
dx.diagnosisstring.value_counts().nlargest(15)

pulmonary|respiratory failure|acute respiratory failure            97836
renal|disorder of kidney|acute renal failure                       65313
endocrine|glucose metabolism|diabetes mellitus                     44491
neurologic|altered mental status / pain|change in mental status    41034
pulmonary|pulmonary infections|pneumonia                           39729
cardiovascular|ventricular disorders|congestive heart failure      37467
cardiovascular|vascular disorders|hypertension                     37328
cardiovascular|shock / hypotension|hypotension                     33766
pulmonary|respiratory failure|hypoxemia                            33515
cardiovascular|shock / hypotension|sepsis                          32509
hematology|bleeding and red blood cell disorders|anemia            31667
cardiovascular|arrhythmias|atrial fibrillation                     29475
pulmonary|respiratory failure|acute respiratory distress           26774
neurologic|altered mental status / pain|pain       

In [11]:
d = pd.read_csv(data_path + 'patient.csv')    
d = d[['patientunitstayid', 'unitdischargestatus', 'unitdischargeoffset', 'age']]
dx = d.merge(dx, on='patientunitstayid')
dx.head()

Unnamed: 0,patientunitstayid,unitdischargestatus,unitdischargeoffset,age,diagnosisid,activeupondischarge,diagnosisoffset,diagnosisstring,icd9code,diagnosispriority
0,141168,Expired,3596,70,4222318,False,72,cardiovascular|chest pain / ASHD|coronary arte...,"414.00, I25.10",Other
1,141168,Expired,3596,70,3370568,True,118,cardiovascular|ventricular disorders|cardiomyo...,,Other
2,141168,Expired,3596,70,4160941,False,72,pulmonary|disorders of the airways|COPD,"491.20, J44.9",Other
3,141168,Expired,3596,70,4103261,True,118,pulmonary|disorders of the airways|COPD,"491.20, J44.9",Other
4,141168,Expired,3596,70,3545241,True,118,cardiovascular|ventricular disorders|congestiv...,"428.0, I50.9",Other


In [12]:
# Remove patients with short stays and pts < 16 y/o
stay_min = parameters['window'][0] + parameters['baseline'][1]
dx = dx[dx.unitdischargeoffset>stay_min]
dx['age'] = dx.age.apply(lambda x: np.float(89) if x=='> 89' else np.float(x));
dx = dx[dx.age>15]

In [13]:
dx.sample(5)

Unnamed: 0,patientunitstayid,unitdischargestatus,unitdischargeoffset,age,diagnosisid,activeupondischarge,diagnosisoffset,diagnosisstring,icd9code,diagnosispriority
1607373,2873139,Alive,38955,42.0,30944381,False,28333,cardiovascular|ventricular disorders|congestiv...,"428.0, I50.9",Primary
2197306,3094887,Alive,12950,51.0,42885979,False,9636,neurologic|altered mental status / pain|pain,,Other
823082,1103855,Alive,13432,42.0,15314792,False,2340,pulmonary|pulmonary infections|pneumonia|commu...,"486, J18.9",Primary
1633545,2898361,Alive,3691,44.0,31190701,True,17,burns/trauma|trauma - CNS|spinal cord injury|c...,"952.00, S14.1",Primary
1795711,3043496,Alive,4410,58.0,37212486,False,1533,infectious diseases|chest/pulmonary infections...,"486, J18.9",Other


In [14]:
cols = ['patientunitstayid', 'diagnosisoffset']

if diagnosis=='death':
    dx_nonevents = dx[dx.unitdischargestatus=='Alive'].copy()[cols].sample(frac=1)
    dx_events = dx[dx.unitdischargestatus=='Expired'].copy()[cols]
    dx_events.sort_values(by='diagnosisoffset', inplace=True)
    dx_events.drop_duplicates('patientunitstayid', keep='last', inplace=True)
    dx_nonevents.drop_duplicates('patientunitstayid', keep='last', inplace=True)
else:
    dx_nonevents = dx[~dx.diagnosisstring.str.contains(diagnosis)].copy()[cols]
    dx_events = dx[dx.diagnosisstring.str.contains(diagnosis)].copy()[cols]
    dx_events.sort_values(by='diagnosisoffset', inplace=True)
    dx_events = dx_events[dx_events.diagnosisoffset>parameters['dx_offset']]
    dx_events.drop_duplicates('patientunitstayid', keep='first', inplace=True)
    dx_nonevents.drop_duplicates('patientunitstayid', keep='first', inplace=True)


In [15]:
print(f'Positive patients: {len(dx_events)}\nNegative patients: {len(dx_nonevents)}')

Positive patients: 8392
Negative patients: 154570


In [16]:
pos_samples = len(dx_events)
neg_samples = len(dx_nonevents)
print(f"Positive examples: {len(dx_events)}")
print(f"Negative examples: {len(dx_nonevents)}")

Positive examples: 8392
Negative examples: 154570


In [17]:
if pos_samples < parameters['pos_examples']:
    parameters['pos_examples'] = pos_samples
    
parameters['neg_examples'] = int((parameters['pos_examples'] - parameters['ratio'] * 
                                 parameters['pos_examples']) / parameters['ratio'])
if neg_samples < parameters['neg_examples']:
    parameters['neg_examples'] = neg_samples

In [18]:
# Sample patient events
# import pickle as pkl

# special_neg_list = pkl.load(open(data_path + 'saved_data/' + 'special_pts.pkl', 'rb'))

In [19]:
dx_events = dx_events.sample(parameters['pos_examples'])

# dx_nonevents_special = dx_nonevents[dx_nonevents.patientunitstayid.isin(special_neg_list)]

# dx_nonevents = pd.concat([dx_nonevents_special, dx_nonevents.sample(parameters['neg_examples'])])

dx_nonevents = dx_nonevents.sample(parameters['neg_examples'])


print(f'Events: {len(dx_events)}\nNonevents: {len(dx_nonevents)}')

Events: 8392
Nonevents: 8392


In [20]:
dx_nonevents.drop_duplicates(inplace=True)

In [21]:
# Create a dictionary of all events and timestamps
all_events = pd.concat([dx_events, dx_nonevents])

# Load Data and Extract Features

## PMH Diagnoses

In [22]:
pmh = pd.read_csv(data_path + 'pastHistory.csv')

In [23]:
pmh = pmh[['patientunitstayid', 'pasthistoryoffset', 'pasthistorypath', 'pasthistoryvalue', 'pasthistoryvaluetext']]
pmh = pmh[~(pmh['pasthistoryvalue'].isin(['Performed', 'No Health Problems', 'clinical diagnosis', '']))]
pmh['dx'] = pmh['pasthistoryvalue'].apply(lambda x: x.split('  -')[0].split(' -')[0]
                                          .split('-')[0].split(' requiring')[0])

In [24]:
print(f'Past medical diagnoses: {len(pmh)}\nUnique diagnoses: {len(pmh.dx.unique())}')

Past medical diagnoses: 846771
Unique diagnoses: 166


In [25]:
# TODO: Replace PMH diagnoses with ICD-9 Codes
pmh.dx.value_counts().nlargest(5)

hypertension                  156386
COPD                           47591
CHF                            47534
insulin dependent diabetes     42150
atrial fibrillation            37422
Name: dx, dtype: int64

## Treatments

In [26]:
treatment = pd.read_csv(data_path + 'treatment.csv')

In [27]:
treatment.drop(['treatmentid', 'activeupondischarge'], axis=1, inplace=True)

In [28]:
treatment['treatment'] = treatment['treatmentstring'].apply(lambda x: x.split('|')[-1])

In [29]:
most_common_tx = list(treatment.treatment.value_counts().nlargest(500).index)

In [30]:
treatment = treatment[treatment.treatment.isin(most_common_tx)]

In [31]:
treatment = treatment[treatment.patientunitstayid.isin(all_events.patientunitstayid)]

In [32]:
treatment.drop('treatmentstring', axis=1, inplace=True)

In [33]:
len(treatment.treatment.unique())

500

In [34]:
treatment.head()

Unnamed: 0,patientunitstayid,treatmentoffset,treatment
552,242544,12433,non-invasive ventilation
553,242544,5736,stress ulcer prophylaxis
554,242544,12433,compression stockings
555,242544,31,oxygen therapy (> 60%)
556,242544,31,vasodilator


## ICU Diagnoses

In [35]:
# TODO: REPLACE NANS WITH ACTUAL VALUES
dx.drop(['diagnosisid', 'activeupondischarge', 'diagnosispriority'], axis=1, inplace=True)
dx.drop_duplicates(inplace=True)

dx['icd9'] = dx['icd9code'].apply(lambda x: str(x).split(',')[0].split('.')[0])

dx.dropna(subset=['icd9code'], axis=0, inplace=True)

print(f'Total diagnoses: {len(dx)}')
print(f'Unique diagnoses: {len(dx.icd9.unique())}')

Total diagnoses: 2217731
Unique diagnoses: 398


In [36]:
# TODO: check medication dependent, procedural coronary i... category

dx_to_icd = {
    'hypertension' : '401',
    'CHF' : '428', 
    'COPD' : '496',
    'insulin dependent diabetes' : '250',
    'renal insufficiency' : '585',
    'atrial fibrillation' : '427',
    'MI' : '410',
    'medication dependent' : '304', #questionable
    'renal failure' :'586',
    'hypothyroidism' : '244',
    'asthma' : '493', 
    'peripheral vascular disease' : '443',
    'stroke' : '434',
    'procedural coronary intervention' : '036', #questionable
    'home oxygen' : '093', #questionable
    'CABG' : '414',
    'peptic ulcer disease' : '533',
    'dementia' : '294',
    'DVT' : '453',
    'respiratory failure' : '518',
    'mechanical ventilation' : '096',
    'pneumonia' : '486',
    'pneumonitis' : '507',
    'hypotension' : '458',
    'cardiovascular symptoms' : '785',
    'unknown cause' : '799',
    'sepsis' : '038',
    'adverse effects' : '995',
    'pleurisy' : '511',
    'gi hemorrhage' : '578',
    'chronic bronchitis' : '491', 
    'malnutrition' : '263'
}

In [37]:
treatment.columns

Index(['patientunitstayid', 'treatmentoffset', 'treatment'], dtype='object')

In [38]:
dx.columns

Index(['patientunitstayid', 'unitdischargestatus', 'unitdischargeoffset',
       'age', 'diagnosisoffset', 'diagnosisstring', 'icd9code', 'icd9'],
      dtype='object')

In [39]:
# Merge treatments with diagnoses
tx_merge = treatment.copy()
tx_merge['id'] = tx_merge.pop('patientunitstayid')
tx_merge['offset'] = tx_merge.pop('treatmentoffset')
tx_merge['string'] = tx_merge.pop('treatment')

dx_merge = dx.copy()[['patientunitstayid', 'diagnosisoffset', 'icd9']]
dx_merge['id'] = dx_merge.pop('patientunitstayid')
dx_merge['offset'] = dx_merge.pop('diagnosisoffset')
dx_merge['string'] = dx_merge.pop('icd9')

merge = pd.concat([tx_merge, dx_merge])
merge = merge[merge.id.isin(all_events.patientunitstayid)]

In [40]:
len(merge)

959889

In [41]:
if diagnosis=='death':
    last_offset = merge.sort_values(by='offset')
    last_offset.drop_duplicates(subset=['id'], keep='last', inplace=True)
    last_offset.set_index('id', inplace=True)
    all_events_offset = pd.concat([dx_events, dx_nonevents])[['diagnosisoffset', 'patientunitstayid']]
    last_offset2 = pd.concat([all_events_offset.set_index('patientunitstayid'), last_offset], axis=1)
    last_offset2['last'] = last_offset2[['diagnosisoffset', 'offset']].max(axis=1)
    dx_offset_dict = pd.Series(last_offset2['last'].values,index=last_offset2.index).to_dict()

    dx_events['diagnosisoffset'] = dx_events.patientunitstayid.apply(lambda x: dx_offset_dict[x])
    dx_nonevents['diagnosisoffset'] = dx_nonevents.patientunitstayid.apply(lambda x: dx_offset_dict[x])
    
    print('offsets updated.')

offsets updated.


In [42]:
# dx_offset_dict = pd.Series(all_events.diagnosisoffset.values,index=all_events.patientunitstayid).to_dict()

In [43]:
df = pd.DataFrame(columns=['patientunitstayid', 'dx_seq'])


for event_df in [dx_events, dx_nonevents]:
    for patientunitstayid in event_df.patientunitstayid:
            
        dx_offset = int(event_df[event_df['patientunitstayid']==patientunitstayid].diagnosisoffset - parameters['window'][1])
        
        pmh_seq = list(pmh[pmh['patientunitstayid'] == patientunitstayid].sort_values(by='pasthistoryoffset').dx)
        
        seq = list(merge[(merge['id'] == patientunitstayid) & (merge['offset'] <= dx_offset)]
                   .sort_values(by='offset').string)
                
        for i, d in enumerate(pmh_seq):
            try:
                pmh_seq[i] = dx_to_icd[d]
            except:
                continue
            
        df = df.append({'patientunitstayid' : patientunitstayid,
                        'dx_seq' : ';'.join(pmh_seq + seq)}, ignore_index=True)

In [44]:
tot_seq_len = []
for pt in df.dx_seq:
    tot_seq_len.append(len(pt.split(';')))

In [45]:
np.mean(tot_seq_len)

63.89948760724499

In [46]:
allcodes = ';'.join(df['dx_seq'].values).split(';')

In [47]:
counter = sorted(Counter(allcodes).items(), 
                 key=lambda x: -x[1])[:parameters['vocab_size']-1] # sort by frequency

# Remove 'words' appearing less than n times.
min_word_ct = 20
counter2 = []
for word, count in counter:
    if count > min_word_ct:
        counter2.append((word, count))
    else:
        pass

counter2.append(('<unk>', 1))

words, _ = list(zip(*counter2))
word_to_ID = dict(zip(words, range(len(words))))

In [None]:
# Save word_to_ID
import pickle as pkl
pkl.dump(word_to_ID, open('../data/samples/word_to_ID.pkl', 'wb'))

In [49]:
df['dx_seq_tok'] = df['dx_seq'].apply(tokenize_text).apply(pad_text)

In [50]:
np.array(list(df['dx_seq_tok']))

array([[ 74, 509,   1, ..., 905, 905, 905],
       [ 50,   2,  50, ..., 905, 905, 905],
       [273, 273, 273, ..., 905, 905, 905],
       ...,
       [ 37,   2,  37, ..., 905, 905, 905],
       [  2,   5,  17, ..., 905, 905, 905],
       [  2,  86,   9, ..., 905, 905, 905]])

## Periodic Vitals

In [51]:
# LOAD VITALS. THIS IS LONG
vitals1 = pd.read_csv(data_path + 'vitalPeriodic.csv')
vitals1.sample(5)

Unnamed: 0,vitalperiodicid,patientunitstayid,observationoffset,temperature,sao2,heartrate,respiration,cvp,etco2,systemicsystolic,systemicdiastolic,systemicmean,pasystolic,padiastolic,pamean,st1,st2,st3,icp
69547830,1110298806,1602245,15451,,100.0,77.0,18.0,,,92.0,59.0,69.0,,,,,,,
97412845,1522511204,2416950,5031,,93.0,85.0,19.0,,,,,,,,,0.3,-0.2,-0.6,
111378293,1691877517,2785347,31954,,100.0,108.0,17.0,,,,,,,,,,,,
18770031,287732002,482732,92,,96.0,70.0,11.0,,,,,,,,,-0.05,0.05,0.1,
141897911,2072067299,3237502,9425,36.3,100.0,64.0,16.0,,,,,,,,,0.2,0.6,0.3,


In [52]:
# # Reduce the size of the dataframe.
vitals1 = vitals1[vitals1['patientunitstayid'].isin(all_events.patientunitstayid)]

In [53]:
check_col_complete(vitals1);

        vitalperiodicid: COMPLETE
      patientunitstayid: COMPLETE
      observationoffset: COMPLETE
            temperature: MISSING 15985694 VALUES (86.9%)
                   sao2: MISSING 1567210 VALUES (8.5%)
              heartrate: MISSING 135811 VALUES (0.7%)
            respiration: MISSING 2113480 VALUES (11.5%)
                    cvp: MISSING 14739888 VALUES (80.1%)
                  etco2: MISSING 17419084 VALUES (94.7%)
       systemicsystolic: MISSING 12943893 VALUES (70.4%)
      systemicdiastolic: MISSING 12943959 VALUES (70.4%)
           systemicmean: MISSING 12899297 VALUES (70.1%)
             pasystolic: MISSING 17849297 VALUES (97.0%)
            padiastolic: MISSING 17849305 VALUES (97.0%)
                 pamean: MISSING 17841816 VALUES (97.0%)
                    st1: MISSING 11641470 VALUES (63.3%)
                    st2: MISSING 11212181 VALUES (60.9%)
                    st3: MISSING 11805125 VALUES (64.2%)
                    icp: MISSING 17985648 VALUES 

In [54]:
vitals1.head()

Unnamed: 0,vitalperiodicid,patientunitstayid,observationoffset,temperature,sao2,heartrate,respiration,cvp,etco2,systemicsystolic,systemicdiastolic,systemicmean,pasystolic,padiastolic,pamean,st1,st2,st3,icp
0,37376747,141168,2059,,,92.0,,30.0,,,,,,,,,,,
1,37404957,141168,1289,,,118.0,,,,,,,,,,,,,
2,37385871,141168,1794,,91.0,78.0,,,,,,,,,,,,,
3,37401664,141168,1374,,90.0,118.0,,,,,,,,,,,,,
4,37377404,141168,2039,,98.0,92.0,,33.0,,,,,,,,,,,


In [55]:
vitals1['event_offset'] = vitals1['patientunitstayid'].apply(lambda x: dx_offset_dict[x])

In [56]:
vitals1['window'] = vitals1['observationoffset'].between(
    vitals1['event_offset'] - parameters['window'][0], vitals1['event_offset'] - parameters['window'][1])
vitals1['baseline'] = vitals1['observationoffset'].between(*parameters['baseline'])
vitals1['window2'] = vitals1['observationoffset']<(parameters['window'][1])

In [57]:
vitals1 = vitals1[vitals1.window | vitals1.baseline | vitals1.window2]

In [58]:
vitals1.columns

Index(['vitalperiodicid', 'patientunitstayid', 'observationoffset',
       'temperature', 'sao2', 'heartrate', 'respiration', 'cvp', 'etco2',
       'systemicsystolic', 'systemicdiastolic', 'systemicmean', 'pasystolic',
       'padiastolic', 'pamean', 'st1', 'st2', 'st3', 'icp', 'event_offset',
       'window', 'baseline', 'window2'],
      dtype='object')

In [59]:
vitals1.drop(['etco2', 'pasystolic', 'padiastolic', 'pamean', 'icp'], axis=1, inplace=True)

In [60]:
# Extract features
def create_feature(feature, period, operation, df):
    new_feat = df[df[period]==True][feature].groupby(df['patientunitstayid']).transform(operation)
    return new_feat

In [61]:
vitals1['hr_min_w1'] = create_feature('heartrate', "window", 'min', vitals1)
vitals1['hr_max_w1'] = create_feature('heartrate', "window", 'max', vitals1)
vitals1['hr_mean_bl'] = create_feature('heartrate', "baseline", 'mean', vitals1)

vitals1['resp_min_w1'] = create_feature('respiration', "window", 'min', vitals1)
vitals1['resp_max_w1'] = create_feature('respiration', "window", 'max', vitals1)
vitals1['resp_mean_bl'] = create_feature('respiration', "baseline", 'mean', vitals1)

vitals1['sao2_min_w1'] = create_feature('sao2', "window", 'min', vitals1)

In [62]:
vitals1 = vitals1.groupby(vitals1['patientunitstayid'].values).transform('max')

In [63]:
vitals1['hr_chg_bl'] = vitals1['hr_max_w1'] - vitals1['hr_mean_bl']
vitals1['resp_chg_bl'] = vitals1['resp_max_w1'] - vitals1['resp_mean_bl']

In [64]:
len(vitals1)

1528209

In [65]:
vitals1.drop_duplicates(inplace=True)

In [66]:
vitals1[['patientunitstayid', 'hr_min_w1', 'hr_max_w1', ]].sample(5)

Unnamed: 0,patientunitstayid,hr_min_w1,hr_max_w1
41473968,1058516,71.0,91.0
100686498,2521804,73.0,92.0
15365663,397519,70.0,98.0
82239941,1806575,52.0,80.0
32824844,914030,101.0,114.0


In [67]:
vitals1['label'] = vitals1['patientunitstayid'].isin(dx_events.patientunitstayid)

In [68]:
vitals1.groupby('label')['hr_max_w1'].describe()

Unnamed: 0_level_0,count,mean,std,min,25%,50%,75%,max
label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
False,7666.0,94.259066,20.468484,0.0,80.0,92.0,106.0,213.0
True,7905.0,106.279317,24.978268,0.0,89.0,105.0,122.0,243.0


In [69]:
len(vitals1)

16460

In [70]:
vitals1.columns

Index(['vitalperiodicid', 'patientunitstayid', 'observationoffset',
       'temperature', 'sao2', 'heartrate', 'respiration', 'cvp',
       'systemicsystolic', 'systemicdiastolic', 'systemicmean', 'st1', 'st2',
       'st3', 'event_offset', 'window', 'baseline', 'window2', 'hr_min_w1',
       'hr_max_w1', 'hr_mean_bl', 'resp_min_w1', 'resp_max_w1', 'resp_mean_bl',
       'sao2_min_w1', 'hr_chg_bl', 'resp_chg_bl', 'label'],
      dtype='object')

In [71]:
vitals_columns = ['patientunitstayid', 'hr_min_w1',
       'hr_max_w1', 'hr_mean_bl', 'resp_min_w1', 'resp_max_w1', 'resp_mean_bl',
       'sao2_min_w1', 'hr_chg_bl', 'resp_chg_bl', 'label']
vitals1 = vitals1[vitals_columns]
vitals1.columns

Index(['patientunitstayid', 'hr_min_w1', 'hr_max_w1', 'hr_mean_bl',
       'resp_min_w1', 'resp_max_w1', 'resp_mean_bl', 'sao2_min_w1',
       'hr_chg_bl', 'resp_chg_bl', 'label'],
      dtype='object')

In [72]:
check_col_complete(vitals1);

      patientunitstayid: COMPLETE
              hr_min_w1: MISSING 889 VALUES (5.4%)
              hr_max_w1: MISSING 889 VALUES (5.4%)
             hr_mean_bl: MISSING 141 VALUES (0.9%)
            resp_min_w1: MISSING 2277 VALUES (13.8%)
            resp_max_w1: MISSING 2277 VALUES (13.8%)
           resp_mean_bl: MISSING 1293 VALUES (7.9%)
            sao2_min_w1: MISSING 1473 VALUES (8.9%)
              hr_chg_bl: MISSING 1008 VALUES (6.1%)
            resp_chg_bl: MISSING 2466 VALUES (15.0%)
                  label: COMPLETE

Total number of items: 16460


## Labs

In [73]:
labs = pd.read_csv(data_path + 'lab.csv')

In [74]:
len(labs)

39132531

In [75]:
labs = labs[['patientunitstayid', 'labresultoffset', 'labtypeid', 'labname', 'labresult']]

In [76]:
labs = labs[labs['labtypeid']!=2]

In [77]:
labs_to_keep = ['bedside glucose',
 'potassium',
 'sodium',
 'glucose',
 'Hgb',
 'chloride',
 'Hct',
 'creatinine',
 'BUN',
 'calcium',
 'bicarbonate',
 'platelets x 1000',
 'WBC x 1000',
 'RBC',
 'MCV',
 'MCHC',
 'MCH',
 'RDW',
 'anion gap',
 'MPV',
 'magnesium',
 '-lymphs',
 '-monos',
 '-eos',
 '-polys',
 '-basos',
 'albumin',
 'AST (SGOT)',
 'ALT (SGPT)',
 'total protein',
 'alkaline phos.',
 'total bilirubin',
 'phosphate',
 'paO2',
 'paCO2',
 'pH',
 'PT - INR',
 'HCO3',
'lactate', 'albumin', 'total bilirubin', 'urinary specific gravity', 'FiO2'
]   ###Note platelets

In [78]:
labs = labs[labs.labname.isin(labs_to_keep)]

In [79]:
len(labs)

35165816

In [80]:
labs = labs[labs.patientunitstayid.isin(dx_offset_dict)]

In [81]:
len(labs)

3919164

In [82]:
labs['event_offset'] = labs['patientunitstayid'].apply(lambda x: dx_offset_dict[x])

In [83]:
labs["window"] = labs['labresultoffset'].between(
    labs['event_offset'] - parameters['window'][0], labs['event_offset'] - parameters['window'][1])
labs["baseline"] = labs['labresultoffset'].between(*parameters['baseline'])
labs['window2'] = labs['labresultoffset']<parameters['window'][1]

In [84]:
labs.sample(10)

Unnamed: 0,patientunitstayid,labresultoffset,labtypeid,labname,labresult,event_offset,window,baseline,window2
28406075,2725251,23426,1,BUN,13.0,26677.0,False,False,False
14482412,1153574,1454,3,Hct,37.7,22.0,False,False,False
21415957,1774325,501,1,AST (SGOT),1629.0,777.0,False,False,False
38538328,3337984,2275,1,BUN,25.0,4434.0,False,False,False
17253866,1558329,-248,3,-basos,0.5,200.0,False,False,True
36103631,3166045,1382,1,creatinine,0.88,967.0,False,False,False
26791250,2599024,41,3,Hct,46.7,500.0,False,True,False
29421750,2782681,11659,3,MCH,30.4,14492.0,False,False,False
28326127,2721222,6401,3,MCV,90.0,1241.0,False,False,False
33638649,3057548,10211,1,alkaline phos.,80.0,22499.0,False,False,False


In [85]:
labs = labs[labs.window|labs.baseline|labs.window2]

In [86]:
len(labs)

1032942

In [87]:
labs.sample(10)

Unnamed: 0,patientunitstayid,labresultoffset,labtypeid,labname,labresult,event_offset,window,baseline,window2
16438009,1450449,-36,3,WBC x 1000,12.0,87.0,True,False,True
22538003,1828415,-137,7,FiO2,100.0,9.0,True,False,True
5083998,529632,14570,1,calcium,8.3,14721.0,True,False,False
29551013,2789230,46,7,paO2,68.0,5715.0,False,True,False
12813525,1084184,-1165,7,paO2,88.0,9176.0,False,False,True
28038485,2705992,-223,1,BUN,46.0,5941.0,False,False,True
38072265,3245672,-6,3,WBC x 1000,0.3,955.0,False,False,True
33245802,3036192,-1898,1,glucose,122.0,5444.0,False,False,True
7899117,817898,604,3,Hct,30.4,617.0,True,False,False
35057376,3132909,-61,1,glucose,237.0,18276.0,False,False,True


In [88]:
# Extract features
def create_lab(lab, period, operation, df):
    new_feat = df[(df[period]==True) & (df['labname'].str.contains(lab))]['labresult'].groupby(df['patientunitstayid']).transform(operation)
    return new_feat

In [89]:
labs['gluc_mean_bl'] = create_lab('glucose', "baseline", 'mean', labs)
labs['gluc_max_w1'] = create_lab('glucose', "window", 'max', labs)
labs['gluc_min_w1'] = create_lab('glucose', "window", 'min', labs)

labs['k_mean_bl'] = create_lab('potassium', "baseline", 'mean', labs)
labs['k_min_w2'] = create_lab('potassium', "window2", 'min', labs)
labs['k_max_w2'] = create_lab('potassium', "window2", 'max', labs)

labs['na_mean_bl'] = create_lab('sodium', "baseline", 'mean', labs)
labs['na_min_w2'] = create_lab('sodium', "window2", 'min', labs)
labs['na_max_w2'] = create_lab('sodium', "window2", 'max', labs)

labs['hgb_mean_bl'] = create_lab('Hgb', "baseline", 'mean', labs)
labs['hgb_min_w2'] = create_lab('Hgb', "window2", 'min', labs)

labs['cl_mean_bl'] = create_lab('chloride', "baseline", 'mean', labs)
labs['cl_min_w2'] = create_lab('chloride', "window2", 'min', labs)
labs['cl_max_w2'] = create_lab('chloride', "window2", 'max', labs)

labs['hct_min_w2'] = create_lab('Hct', "window2", 'min', labs)

labs['crt_min_w2'] = create_lab('creatinine', "window2", 'min', labs)
labs['crt_mean_w2'] = create_lab('creatinine', "window2", 'mean', labs)
labs['crt_max_w1'] = create_lab('creatinine', "window", 'max', labs)

labs['bun_mean_w1'] = create_lab('BUN', "window", 'mean', labs)

labs['ca_min_w2'] = create_lab('calcium', "window2", 'min', labs)
labs['ca_max_w2'] = create_lab('calcium', "window2", 'max', labs)

labs['bicarb_mean_w1'] = create_lab('bicarbonate', "window", 'mean', labs)
labs['wbc_max_w2'] = create_lab('WBC x 1000', "window", 'max', labs)
labs['angap_max_w1'] = create_lab('anion gap', "window", 'max', labs)
labs['angap_max_w2'] = create_lab('anion gap', "window2", 'max', labs)


labs['hco3_min_w2'] = create_lab('HCO3', "window2", 'min', labs)
labs['hco3_max_w2'] = create_lab('HCO3', "window2", 'max', labs)
labs['pao2_min_w2'] = create_lab('paO2', "window2", 'min', labs)
labs['paco2_max_w2'] = create_lab('paCO2', "window2", 'mean', labs)

labs['ph_min_w2'] = create_lab('pH', "window2", 'min', labs)
labs['ph_max_w2'] = create_lab('pH', "window2", 'max', labs)
labs['inr_max_w2'] = create_lab('PT - INR', "window2", 'max', labs)
labs['lymphs_max_w2'] = create_lab('-lymphs', "window2", 'max', labs)
labs['lact_max_w2'] = create_lab('lactate', "window2", 'max', labs)
labs['alb_min_w2'] = create_lab('albumin', "window2", 'min', labs)
labs['tbili_max_w2'] = create_lab('total bilirubin', "window2", 'max', labs)
labs['usg_max_w2'] = create_lab('urinary specific gravity', "window2", 'max', labs)
labs['fio2_mean_w2'] = create_lab('FiO2', "window2", 'mean', labs)
labs['plt_min_w2'] = create_lab('platelets x 1000', 'window2', 'min', labs)
labs['rbc_max_w2'] = create_lab('RBC', "window2", 'max', labs)

In [90]:
check_col_complete(labs);

      patientunitstayid: COMPLETE
        labresultoffset: COMPLETE
              labtypeid: COMPLETE
                labname: COMPLETE
              labresult: MISSING 3796 VALUES (0.4%)
           event_offset: COMPLETE
                 window: COMPLETE
               baseline: COMPLETE
                window2: COMPLETE
           gluc_mean_bl: MISSING 1005086 VALUES (97.3%)
            gluc_max_w1: MISSING 1018103 VALUES (98.6%)
            gluc_min_w1: MISSING 1018103 VALUES (98.6%)
              k_mean_bl: MISSING 1020939 VALUES (98.8%)
               k_min_w2: MISSING 1005877 VALUES (97.4%)
               k_max_w2: MISSING 1005877 VALUES (97.4%)
             na_mean_bl: MISSING 1021513 VALUES (98.9%)
              na_min_w2: MISSING 1006425 VALUES (97.4%)
              na_max_w2: MISSING 1006425 VALUES (97.4%)
            hgb_mean_bl: MISSING 1022695 VALUES (99.0%)
             hgb_min_w2: MISSING 1007953 VALUES (97.6%)
             cl_mean_bl: MISSING 1023363 VALUES (99.1%)
    

In [91]:
labs = labs.groupby(labs['patientunitstayid'].values).transform('max')

In [92]:
check_col_complete(labs);

      patientunitstayid: COMPLETE
        labresultoffset: COMPLETE
              labtypeid: COMPLETE
                labname: COMPLETE
              labresult: MISSING 13 VALUES (0.0%)
           event_offset: COMPLETE
                 window: COMPLETE
               baseline: COMPLETE
                window2: COMPLETE
           gluc_mean_bl: MISSING 202768 VALUES (19.6%)
            gluc_max_w1: MISSING 389190 VALUES (37.7%)
            gluc_min_w1: MISSING 389190 VALUES (37.7%)
              k_mean_bl: MISSING 373983 VALUES (36.2%)
               k_min_w2: MISSING 144903 VALUES (14.0%)
               k_max_w2: MISSING 144903 VALUES (14.0%)
             na_mean_bl: MISSING 391065 VALUES (37.9%)
              na_min_w2: MISSING 143313 VALUES (13.9%)
              na_max_w2: MISSING 143313 VALUES (13.9%)
            hgb_mean_bl: MISSING 431281 VALUES (41.8%)
             hgb_min_w2: MISSING 148706 VALUES (14.4%)
             cl_mean_bl: MISSING 421081 VALUES (40.8%)
              cl_m

In [93]:
labs['bun_crt_rat'] = labs['bun_mean_w1'] / labs['crt_mean_w2']
labs['crt_chg'] = labs['crt_max_w1'] / labs['crt_min_w2']
labs['hgb_chg'] = labs['hgb_mean_bl'] - labs['hgb_min_w2']
labs['k_chg'] = labs['k_max_w2'] - labs['k_mean_bl']
labs['pao2_fio2_rat'] = labs['pao2_min_w2'] / labs['fio2_mean_w2']

In [94]:
labs.columns

Index(['patientunitstayid', 'labresultoffset', 'labtypeid', 'labname',
       'labresult', 'event_offset', 'window', 'baseline', 'window2',
       'gluc_mean_bl', 'gluc_max_w1', 'gluc_min_w1', 'k_mean_bl', 'k_min_w2',
       'k_max_w2', 'na_mean_bl', 'na_min_w2', 'na_max_w2', 'hgb_mean_bl',
       'hgb_min_w2', 'cl_mean_bl', 'cl_min_w2', 'cl_max_w2', 'hct_min_w2',
       'crt_min_w2', 'crt_mean_w2', 'crt_max_w1', 'bun_mean_w1', 'ca_min_w2',
       'ca_max_w2', 'bicarb_mean_w1', 'wbc_max_w2', 'angap_max_w1',
       'angap_max_w2', 'hco3_min_w2', 'hco3_max_w2', 'pao2_min_w2',
       'paco2_max_w2', 'ph_min_w2', 'ph_max_w2', 'inr_max_w2', 'lymphs_max_w2',
       'lact_max_w2', 'alb_min_w2', 'tbili_max_w2', 'usg_max_w2',
       'fio2_mean_w2', 'plt_min_w2', 'rbc_max_w2', 'bun_crt_rat', 'crt_chg',
       'hgb_chg', 'k_chg', 'pao2_fio2_rat'],
      dtype='object')

In [95]:
len(labs)

1032942

In [96]:
labs_cols_to_keep = ['patientunitstayid', 
       'gluc_mean_bl', 'gluc_max_w1', 'gluc_min_w1', 'k_mean_bl', 'k_min_w2',
       'k_max_w2', 'na_mean_bl', 'na_min_w2', 'na_max_w2', 'hgb_mean_bl',
       'hgb_min_w2', 'cl_mean_bl', 'cl_min_w2', 'cl_max_w2', 'hct_min_w2',
       'crt_min_w2', 'crt_mean_w2', 'crt_max_w1', 'bun_mean_w1', 'ca_min_w2',
       'ca_max_w2', 'bicarb_mean_w1', 'wbc_max_w2', 'angap_max_w1',
       'angap_max_w2', 'hco3_min_w2', 'hco3_max_w2', 'pao2_min_w2',
       'paco2_max_w2', 'ph_min_w2', 'ph_max_w2', 'inr_max_w2', 'lymphs_max_w2',
       'lact_max_w2', 'alb_min_w2', 'tbili_max_w2', 'usg_max_w2',
       'fio2_mean_w2', 'plt_min_w2', 'rbc_max_w2',
       'bun_crt_rat', 'crt_chg', 'hgb_chg', 'k_chg', 'pao2_fio2_rat',
       ]
labs = labs[labs_cols_to_keep]

In [97]:
len(labs)

1032942

In [98]:
labs.drop_duplicates(inplace=True)

In [99]:
check_col_complete(labs);

      patientunitstayid: COMPLETE
           gluc_mean_bl: MISSING 4235 VALUES (26.1%)
            gluc_max_w1: MISSING 7389 VALUES (45.5%)
            gluc_min_w1: MISSING 7389 VALUES (45.5%)
              k_mean_bl: MISSING 7528 VALUES (46.3%)
               k_min_w2: MISSING 5058 VALUES (31.1%)
               k_max_w2: MISSING 5058 VALUES (31.1%)
             na_mean_bl: MISSING 7758 VALUES (47.8%)
              na_min_w2: MISSING 5045 VALUES (31.1%)
              na_max_w2: MISSING 5045 VALUES (31.1%)
            hgb_mean_bl: MISSING 8405 VALUES (51.7%)
             hgb_min_w2: MISSING 5056 VALUES (31.1%)
             cl_mean_bl: MISSING 8195 VALUES (50.4%)
              cl_min_w2: MISSING 5298 VALUES (32.6%)
              cl_max_w2: MISSING 5298 VALUES (32.6%)
             hct_min_w2: MISSING 4997 VALUES (30.8%)
             crt_min_w2: MISSING 5325 VALUES (32.8%)
            crt_mean_w2: MISSING 5325 VALUES (32.8%)
             crt_max_w1: MISSING 11778 VALUES (72.5%)
           

In [100]:
len(labs.dropna())

38

## Aperiodic Vitals

In [101]:
vitals2 = pd.read_csv(data_path + 'vitalAperiodic.csv')
vitals2.sample(5)

Unnamed: 0,vitalaperiodicid,patientunitstayid,observationoffset,noninvasivesystolic,noninvasivediastolic,noninvasivemean,paop,cardiacoutput,cardiacinput,svr,svri,pvr,pvri
14120241,286466451,2077877,966,128.0,67.0,77.0,,,,,,,
9452455,201989994,1353592,29,68.0,55.0,60.0,,,,,,,
9464546,199145642,1356344,250,138.0,75.0,91.0,,,,,,,
10065086,217502352,1549100,608,110.0,58.0,70.0,,,,,,,
8495003,186595707,1162039,443,74.0,44.0,52.0,,,,,,,


In [102]:
# # Reduce the size of the dataframe.
vitals2 = vitals2[vitals2['patientunitstayid'].isin(dx_offset_dict)]

In [103]:
check_col_complete(vitals2);

       vitalaperiodicid: COMPLETE
      patientunitstayid: COMPLETE
      observationoffset: COMPLETE
    noninvasivesystolic: MISSING 335374 VALUES (10.0%)
   noninvasivediastolic: MISSING 334848 VALUES (9.9%)
        noninvasivemean: MISSING 311040 VALUES (9.2%)
                   paop: MISSING 3364185 VALUES (99.8%)
          cardiacoutput: MISSING 3263455 VALUES (96.8%)
           cardiacinput: MISSING 3351104 VALUES (99.4%)
                    svr: MISSING 3159304 VALUES (93.7%)
                   svri: MISSING 3343067 VALUES (99.2%)
                    pvr: MISSING 3367087 VALUES (99.9%)
                   pvri: MISSING 3367112 VALUES (99.9%)

Total number of items: 3369979


In [104]:
vitals2.head()

Unnamed: 0,vitalaperiodicid,patientunitstayid,observationoffset,noninvasivesystolic,noninvasivediastolic,noninvasivemean,paop,cardiacoutput,cardiacinput,svr,svri,pvr,pvri
0,4295739,141168,349,,,79.0,,,,,,,
1,4295737,141168,123,106.0,68.0,81.0,,,,,,,
2,4295741,141168,1398,,,27.0,,,,,,,
3,4295740,141168,441,,,62.0,,,,,,,
4,4295738,141168,138,111.0,62.0,82.0,,,,,,,


In [105]:
vitals2['event_offset'] = vitals2['patientunitstayid'].apply(lambda x: dx_offset_dict[x] if x in dx_offset_dict else 0)

In [106]:
vitals2["window"] = vitals2['observationoffset'].between(
    vitals2['event_offset'] - parameters['window'][0], vitals2['event_offset'] - parameters['window'][1])
vitals2["baseline"] = vitals2['observationoffset'].between(*parameters['baseline'])

In [107]:
vitals2 = vitals2[vitals2.window | vitals2.baseline]

In [108]:
vitals2.drop(['vitalaperiodicid', 'paop', 'cardiacoutput', 'cardiacinput', 'svr', 'svri', 'pvr', 'pvri'], axis=1, inplace=True)

In [109]:
vitals2.sample(5)

Unnamed: 0,patientunitstayid,observationoffset,noninvasivesystolic,noninvasivediastolic,noninvasivemean,event_offset,window,baseline
19527771,2886590,98,174.0,87.0,109.0,50.0,False,True
2146558,349313,2303,86.0,64.0,69.0,2401.0,True,False
4972923,768101,3011,130.0,62.0,78.0,3223.0,True,False
18497530,2793461,5183,124.0,60.0,69.0,5196.0,True,False
24669797,3330658,137,88.0,62.0,69.0,143.0,True,True


In [110]:
# Extract features

vitals2['noninv_mean_w1'] = create_feature('noninvasivemean', "window", 'mean', vitals2)
vitals2['noninv_mean_bl'] = create_feature('noninvasivemean', "baseline", 'mean', vitals2)

vitals2['noninv_syst_max_wl'] = create_feature('noninvasivesystolic', "window", 'max', vitals2)
vitals2['noninv_syst_min_wl'] = create_feature('noninvasivesystolic', "window", 'min', vitals2)

vitals2['noninv_dias_max_wl'] = create_feature('noninvasivediastolic', "window", 'max', vitals2)
vitals2['noninv_dias_min_wl'] = create_feature('noninvasivediastolic', "window", 'min', vitals2)

In [111]:
vitals2[vitals2.baseline].sample(20)

Unnamed: 0,patientunitstayid,observationoffset,noninvasivesystolic,noninvasivediastolic,noninvasivemean,event_offset,window,baseline,noninv_mean_w1,noninv_mean_bl,noninv_syst_max_wl,noninv_syst_min_wl,noninv_dias_max_wl,noninv_dias_min_wl
16132486,2460014,193,103.0,64.0,76.0,29.0,False,True,,84.04,,,,
4531764,690068,340,109.0,66.0,75.0,18.0,False,True,,70.142857,,,,
23117880,3131270,58,141.0,55.0,77.0,2705.0,False,True,,85.777778,,,,
3686753,519526,317,115.0,84.0,91.0,10660.0,False,True,,86.761905,,,,
15049444,2289432,344,133.0,79.0,101.0,57.0,False,True,,92.608696,,,,
5065715,787147,190,99.0,74.0,80.0,5075.0,False,True,,84.407407,,,,
412791,178663,177,84.0,48.0,61.0,117.0,False,True,,64.333333,,,,
20189642,2922143,184,100.0,60.0,71.0,25.0,False,True,,69.157895,,,,
9427424,1348518,157,90.0,57.0,64.0,17666.0,False,True,,60.518519,,,,
4930263,760891,126,132.0,97.0,105.0,1854.0,False,True,,81.117647,,,,


In [112]:
vitals2 = vitals2.groupby(vitals2['patientunitstayid'].values).transform('max')

In [113]:
vitals2['noninv_chg'] = vitals2['noninv_mean_w1'] - vitals2['noninv_mean_bl']

In [114]:
vitals2.drop_duplicates(inplace=True)

In [115]:
vitals2_columns = ['patientunitstayid', 'noninv_mean_w1', 'noninv_mean_bl', 'noninv_chg',
       'noninv_syst_max_wl', 'noninv_syst_min_wl', 'noninv_dias_max_wl',
       'noninv_dias_min_wl']
vitals2 = vitals2[vitals2_columns]
vitals2.columns

Index(['patientunitstayid', 'noninv_mean_w1', 'noninv_mean_bl', 'noninv_chg',
       'noninv_syst_max_wl', 'noninv_syst_min_wl', 'noninv_dias_max_wl',
       'noninv_dias_min_wl'],
      dtype='object')

## Patient Data

In [116]:
# Unit type categorical: unittype, admitsource, ethnicity, gender
features3 = pd.read_csv(data_path + 'patient.csv')

In [117]:
check_col_complete(features3);

      patientunitstayid: COMPLETE
patienthealthsystemstayid: COMPLETE
                 gender: MISSING 134 VALUES (0.1%)
                    age: MISSING 95 VALUES (0.0%)
              ethnicity: MISSING 2290 VALUES (1.1%)
             hospitalid: COMPLETE
                 wardid: COMPLETE
      apacheadmissiondx: MISSING 22996 VALUES (11.4%)
        admissionheight: MISSING 4215 VALUES (2.1%)
    hospitaladmittime24: COMPLETE
    hospitaladmitoffset: COMPLETE
    hospitaladmitsource: MISSING 49464 VALUES (24.6%)
  hospitaldischargeyear: COMPLETE
hospitaldischargetime24: COMPLETE
hospitaldischargeoffset: COMPLETE
hospitaldischargelocation: MISSING 2033 VALUES (1.0%)
hospitaldischargestatus: MISSING 1751 VALUES (0.9%)
               unittype: COMPLETE
        unitadmittime24: COMPLETE
        unitadmitsource: MISSING 1090 VALUES (0.5%)
        unitvisitnumber: COMPLETE
           unitstaytype: COMPLETE
        admissionweight: MISSING 16718 VALUES (8.3%)
        dischargeweight: MISSING

In [118]:
# # Reduce the size of the dataframe.
features3 = features3[(features3['patientunitstayid'].isin(dx_events.patientunitstayid)) |
                                       (features3['patientunitstayid'].isin(dx_nonevents.patientunitstayid))]

features3 = features3[['patientunitstayid', 'gender', 'age', 'admissionweight', 'unittype', 'unitadmitsource', 'ethnicity']]

In [119]:
features3['age'] = features3.age.apply(lambda x: np.float(89) if x=='> 89' else np.float(x));

In [120]:
# from pandas.plotting import scatter_matrix
# import matplotlib.pyplot as plt

# attributes = ['age', 'admissionweight']
# scatter_matrix(features3[attributes], figsize=(12, 8));

# Merge All Features

In [121]:
len(dx_nonevents)

8392

In [122]:
vitals1 = vitals1.set_index('patientunitstayid')
vitals2 = vitals2.set_index('patientunitstayid')
features3 = features3.set_index('patientunitstayid')
labs = labs.set_index('patientunitstayid')

In [123]:
vitals1 = vitals1.loc[~vitals1.index.duplicated(keep='first')]
vitals2 = vitals2.loc[~vitals2.index.duplicated(keep='first')]
features3 = features3.loc[~features3.index.duplicated(keep='first')]
labs = labs.loc[~labs.index.duplicated(keep='first')]

features = pd.concat([vitals1, vitals2], axis=1)
features = pd.concat([features, features3], axis=1)
features = pd.concat([features, labs], axis=1)

In [124]:
df = df.set_index('patientunitstayid')
df = df.loc[~df.index.duplicated(keep='first')]

In [125]:
np.array(list(df.dx_seq_tok))

array([[ 74, 509,   1, ..., 905, 905, 905],
       [ 50,   2,  50, ..., 905, 905, 905],
       [273, 273, 273, ..., 905, 905, 905],
       ...,
       [ 37,   2,  37, ..., 905, 905, 905],
       [  2,   5,  17, ..., 905, 905, 905],
       [  2,  86,   9, ..., 905, 905, 905]])

In [126]:
df_merged = pd.concat([df, features], axis=1)

In [127]:
df_merged['label'] = df_merged.index.isin(dx_events.patientunitstayid)

In [128]:
df_merged.groupby('label').count()

Unnamed: 0_level_0,dx_seq,dx_seq_tok,hr_min_w1,hr_max_w1,hr_mean_bl,resp_min_w1,resp_max_w1,resp_mean_bl,sao2_min_w1,hr_chg_bl,...,tbili_max_w2,usg_max_w2,fio2_mean_w2,plt_min_w2,rbc_max_w2,bun_crt_rat,crt_chg,hgb_chg,k_chg,pao2_fio2_rat
label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
False,8392,8392,7666,7666,8135,6916,6916,7471,7410,7613,...,3737,2157,1442,5504,5482,1475,1488,2060,2363,1319
True,8392,8392,7905,7905,8184,7267,7267,7696,7577,7839,...,4076,2303,2545,5491,5500,1717,1717,2764,3158,2370


In [129]:
# df_merged = pd.read_csv('../data/saved_data/death8392_8416240_merged_data.csv')

In [130]:
df_merged.head()

Unnamed: 0_level_0,dx_seq,dx_seq_tok,hr_min_w1,hr_max_w1,hr_mean_bl,resp_min_w1,resp_max_w1,resp_mean_bl,sao2_min_w1,hr_chg_bl,...,tbili_max_w2,usg_max_w2,fio2_mean_w2,plt_min_w2,rbc_max_w2,bun_crt_rat,crt_chg,hgb_chg,k_chg,pao2_fio2_rat
patientunitstayid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
141168,AS;428;401;AICD;586;rheumatoid arthritis;427;4...,"[253, 6, 2, 255, 86, 294, 1, 24, 86, 1, 6, 253...",,,134.44898,,,,,,...,,,,,,,,,,
141297,s/p renal transplant;>= 20 mg prednisone per d...,"[524, 657, 86, 8, 1, 0, 11, 20, 9, 905, 905, 9...",134.0,138.0,119.152778,32.0,48.0,29.880597,98.0,18.847222,...,,,50.0,,,,,,,1.18
141314,785;441;441;344;518;441;441;441;441;427;441;44...,"[7, 217, 217, 378, 0, 217, 217, 217, 217, 1, 2...",94.0,120.0,109.41791,,,7.217391,67.0,10.58209,...,,,,,,,,,,
141360,427;436,"[1, 184, 905, 905, 905, 905, 905, 905, 905, 90...",96.0,96.0,87.53125,21.0,28.0,18.9375,99.0,8.46875,...,0.9,,,239.0,5.55,9.923664,1.0,,,
141448,585;colon;453;276;451;790;780;584,"[9, 340, 191, 4, 203, 19, 10, 5, 905, 905, 905...",94.0,94.0,75.685714,19.0,19.0,14.342857,100.0,18.314286,...,0.6,1.019,,320.0,4.07,,,,,


# Clean and Preprocess Features

In [131]:
# df_merged = pd.read_csv('/Users/tobymanders/Documents/insight_project/data/saved_data/death10000_10000_merged_data.csv')

In [132]:
len(df_merged)

16784

In [133]:
df_merged = df_merged.dropna(thresh=37)

In [134]:
len(df_merged)

11389

In [135]:
##### IMPORTANT: FITTING TEMPORARILY OFF

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

cat_encoder = OneHotEncoder()

cat_attribs = ['gender', 'unittype', 'unitadmitsource', 'ethnicity']
num_attribs = ['hr_min_w1', 'hr_max_w1', 'hr_mean_bl',
       'resp_min_w1', 'resp_max_w1', 'resp_mean_bl', 'sao2_min_w1',
       'hr_chg_bl', 'resp_chg_bl', 'noninv_mean_w1', 'noninv_mean_bl',
       'noninv_chg', 'noninv_syst_max_wl', 'noninv_syst_min_wl',
       'noninv_dias_max_wl', 'noninv_dias_min_wl', 'age', 'admissionweight', 
       'gluc_mean_bl', 'gluc_max_w1', 'gluc_min_w1', 'k_mean_bl', 'k_min_w2',
       'k_max_w2', 'na_mean_bl', 'na_min_w2', 'na_max_w2', 'hgb_mean_bl',
       'hgb_min_w2', 'cl_mean_bl', 'cl_min_w2', 'cl_max_w2', 'hct_min_w2',
       'crt_min_w2', 'crt_mean_w2', 'crt_max_w1', 'bun_mean_w1', 'ca_min_w2',
       'ca_max_w2', 'bicarb_mean_w1', 'wbc_max_w2', 'angap_max_w1',
       'angap_max_w2', 'hco3_min_w2', 'hco3_max_w2', 'pao2_min_w2',
       'paco2_max_w2', 'ph_min_w2', 'ph_max_w2', 'inr_max_w2', 'lymphs_max_w2',
       'lact_max_w2', 'alb_min_w2', 'tbili_max_w2', 'usg_max_w2',
       'fio2_mean_w2', 'plt_min_w2', 'rbc_max_w2', 'bun_crt_rat', 'crt_chg',
       'hgb_chg', 'k_chg', 'pao2_fio2_rat']

num_pipeline = Pipeline([
        ('imputer', SimpleImputer(strategy="median")),
        ('std_scaler', StandardScaler()),
    ])

full_pipeline = ColumnTransformer([
        ("cat", cat_encoder, cat_attribs),
        ("num", num_pipeline, num_attribs)],
        remainder='passthrough'
    )

# full_pipeline = pkl.load(open('../models/full_pipeline.pkl', 'rb'))

df_merged['patientunitstayid'] = df_merged.index
final_cols = ['dx_seq_tok'] + cat_attribs + num_attribs + ['label'] + ['patientunitstayid']
df_merged = df_merged[final_cols]
df_merged.replace([np.inf, -np.inf], np.nan, inplace=True)
df_merged.dropna(subset=cat_attribs, inplace=True)
full_pipeline.fit(df_merged)
df_merged_prepared = full_pipeline.transform(df_merged,)

In [136]:
df_merged_prepared.shape

(11305, 96)

In [137]:
# Save pipeline
import pickle as pkl
pkl.dump(full_pipeline, open('../models/full_pipeline.pkl', 'wb'))

# ML

In [138]:
df_merged.to_csv(data_path + 'saved_data/'+ diagnosis + str(len(dx_events)) + '_' + 
                 str(len(dx_nonevents)) + '_' + str(parameters['window'][0]) + 
                 '_merged_data.csv', index=False, header=True)

In [139]:
y = df_merged_prepared[:, -2].astype(int)

In [140]:
np.mean(y)

0.5092436974789916

In [141]:
# Stratified split

from sklearn.model_selection import StratifiedShuffleSplit

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)
for train_index, test_index in sss.split(df_merged_prepared[:, :-2], y):
    X_train, X_test = df_merged_prepared[:, :-2][train_index], df_merged_prepared[:, :-2][test_index]
    y_train, y_test = y[train_index], y[test_index]
    
    
split = StratifiedShuffleSplit(n_splits=1, test_size=0.1)
for train_index, test_index in sss.split(X_train, y_train):
    X_train, X_val = X_train[train_index], X_train[test_index]
    y_train, y_val = y_train[train_index], y_train[test_index]

In [142]:
X_train.shape

(7235, 94)

In [143]:
X_train_A, X_train_B = np.array(list(X_train[:,-1])), X_train[:,:-1].astype(float)
X_test_A, X_test_B = np.array(list(X_test[:,-1])), X_test[:,:-1].astype(float)
X_val_A, X_val_B = np.array(list(X_val[:,-1])), X_val[:,:-1].astype(float)

In [144]:
X_train_B

array([[ 1.        ,  0.        ,  0.        , ..., -0.04895947,
        -0.0976859 , -0.15923073],
       [ 1.        ,  0.        ,  0.        , ..., -0.86805727,
        -4.3726781 , -0.10215299],
       [ 0.        ,  1.        ,  0.        , ..., -0.04895947,
        -0.0976859 , -0.10215299],
       ...,
       [ 0.        ,  1.        ,  0.        , ..., -0.04895947,
        -0.0976859 , -0.10215299],
       [ 0.        ,  1.        ,  0.        , ..., -0.04895947,
         1.54654187, -0.10215299],
       [ 1.        ,  0.        ,  0.        , ..., -0.04895947,
        -0.0976859 , -0.10215299]])

In [145]:
np.mean(y_test)

0.5090667846085802

## Model 1: DNN

In [146]:
# tf.keras.backend.clear_session()

# model_1 = tf.keras.Sequential([
#     tf.keras.layers.Dense(128, activation='relu'),
#     tf.keras.layers.Dropout(rate=0.5),
#     tf.keras.layers.Dense(128, activation='relu'),
#     tf.keras.layers.Dropout(rate=0.5),
#     tf.keras.layers.Dense(128, activation='relu'),
#     tf.keras.layers.Dropout(rate=0.5),
#     tf.keras.layers.Dense(128, activation='relu'),
#     tf.keras.layers.Dropout(rate=0.5),
#     tf.keras.layers.Dense(128, activation='relu'),
#     tf.keras.layers.Dropout(rate=0.5),
#     tf.keras.layers.Dense(128, activation='relu'),
    
#     tf.keras.layers.Dense(1, activation='sigmoid')
# ])

# model_1.compile(loss='binary_crossentropy',
#               optimizer='adam',
#               metrics=[tf.keras.metrics.AUC()])

# es1 = EarlyStopping(monitor='val_loss', mode='min', patience=20, restore_best_weights=True)

# # Numcat only
# history1 = model_1.fit(X_train_B,
#                     y_train,
#                     epochs=200,
#                     batch_size=512,
#                     validation_data=(X_val_B, y_val),
#                     callbacks=[es1], verbose=0)

In [147]:
# results = model_1.evaluate(X_test_B, y_test, verbose=0)
# print(results)

## Model 2: RNN

In [148]:
# model = tf.keras.Sequential([
#     tf.keras.layers.Dense(32),
#     tf.keras.layers.Dense(32, activation='relu'),
#     tf.keras.layers.Dropout(rate=0.4),
#     tf.keras.layers.Dense(1, activation='sigmoid')
# ])

In [149]:
# model = tf.keras.Sequential([
#     tf.keras.layers.Embedding(parameters['vocab_size'], 32),
#     tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
#     tf.keras.layers.Dense(32, activation='relu'),
#     tf.keras.layers.Dropout(rate=0.4),
#     tf.keras.layers.Dense(1, activation='sigmoid')
# ])

In [150]:
# parameters['vocab_size'] = 1000

In [151]:

# tf.keras.backend.clear_session()

# input_A = tf.keras.layers.Input(shape=(parameters['seq_length'],), name='seq_input')
# embedding = tf.keras.layers.Embedding(parameters['vocab_size'], 32)(input_A)
# hidden1_A = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(16))(embedding)
# hidden_drop = tf.keras.layers.Dropout(rate=0.3)(hidden1_A)
# hidden2_A = tf.keras.layers.Dense(32, activation='relu')(hidden_drop)

# input_B = tf.keras.layers.Input(shape=(X_train_B.shape[1],), name='feature_input')
# hidden1_B = tf.keras.layers.Dense(128, kernel_regularizer=regularizers.l2(0.01),
#                 activity_regularizer=regularizers.l1(0.01))(input_B)
# hidden_drop1B = tf.keras.layers.Dropout(rate=0.3)(hidden1_B)
# hidden2_B = tf.keras.layers.Dense(128, kernel_regularizer=regularizers.l2(0.01),
#                 activity_regularizer=regularizers.l1(0.01))(hidden_drop1B)

# concat = tf.keras.layers.Concatenate()([hidden2_A, hidden2_B])
# common = tf.keras.layers.Dense(64, activation='relu')(concat)
# drop = tf.keras.layers.Dropout(rate=0.3)(common)
# output = tf.keras.layers.Dense(1, name="output", activation='sigmoid')(drop)
# aux_output = tf.keras.layers.Dense(1, activation='sigmoid', name="aux_output")(hidden1_B)

# model = tf.keras.Model(inputs=[input_A, input_B], outputs=[output, aux_output])

In [152]:
# model.compile(loss=['binary_crossentropy','binary_crossentropy'],
#               optimizer='adam',
#               loss_weights=[0.9, 0.1],
#               metrics=[tf.keras.metrics.AUC(), 'accuracy'])

# model.compile(loss='binary_crossentropy',
#               optimizer='adam',
#               metrics=[tf.keras.metrics.AUC()])

# model.compile(loss='binary_crossentropy',
#               optimizer='adam',
#               metrics=['accuracy'])

In [153]:
# model.summary()

In [154]:
# es = EarlyStopping(monitor='val_loss', mode='min', patience=5, restore_best_weights=True)

In [155]:
# # Full model
# history = model.fit([X_train_A, X_train_B],
#                     [y_train, y_train],
#                     epochs=30,
#                     batch_size=64,
#                     validation_data=([X_val_A, X_val_B], [y_val, y_val]),
#                     callbacks=[es], verbose=1)

In [156]:
# # Full model
# results = model.evaluate([X_test_A, X_test_B], [y_test, y_test], verbose=0)
# print(results)

In [157]:
# import matplotlib.pyplot as plt

# def plot_graphs(history, string):
#     plt.plot(history.history[string])
#     plt.plot(history.history['val_'+string])
#     plt.xlabel("Epochs")
#     plt.ylabel(string)
#     plt.legend([string, 'val_'+string])
#     plt.title(f'Epoch vs. {string}')

In [158]:
# auc = list(history.history.keys())[3]
# history.history.keys()

In [159]:
# plt.figure(figsize=(10,4))
# plt.subplot(1,2,1)
# plot_graphs(history, auc)

# plt.subplot(1,2,2)
# plot_graphs(history, 'loss')

# fig_name = diagnosis + str(len(dx_events)) + '_' + str(len(dx_nonevents)) + 'epochs_vs_accuracy'
# save_fig(fig_name)

In [160]:
# from sklearn.metrics import roc_curve
# from sklearn.metrics import auc

# y_pred = model.predict([X_test_A, X_test_B], batch_size=5000)[0].ravel()
# fpr_keras, tpr_keras, thresholds_keras = roc_curve(y_test, y_pred)
# auc_keras = auc(fpr_keras, tpr_keras)

# plt.style.use('seaborn')
# plt.figure(figsize=(8,8))
# plt.plot([0, 1], [0, 1], 'k--')
# plt.plot(fpr_keras, tpr_keras, label='AUC = {:.3f}'.format(auc_keras))
# plt.legend(loc='best')
# plt.title('ROC Curve')
# plt.xlabel('False positive rate')
# plt.ylabel('True positive rate')

# fig_name = diagnosis + str(len(dx_events)) + '_' + str(len(dx_nonevents)) + 'AUC' + '{:.3f}'.format(auc_keras)
# save_fig(fig_name)

In [161]:
# model_2 = tf.keras.Sequential([
#     tf.keras.layers.Embedding(parameters['vocab_size'], 64),
#     tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(16)),
#     tf.keras.layers.Dense(128, activation='relu'),
#     tf.keras.layers.Dropout(rate=0.5),
#     tf.keras.layers.Dense(128, activation='relu'),
#     tf.keras.layers.Dropout(rate=0.1),
#     tf.keras.layers.Dense(1, activation='sigmoid')
# ])

# model_2.compile(loss='binary_crossentropy',
#               optimizer='adam',
#               metrics=[tf.keras.metrics.AUC()])


# es2 = EarlyStopping(monitor='val_loss', mode='min', patience=10, restore_best_weights=True)


# # Sequential only
# history2 = model_2.fit(X_train_A,
#                     y_train,
#                     epochs=50,
#                     batch_size=512,
#                     validation_data=(X_val_A, y_val),
#                     callbacks=[es2], verbose=1)

In [162]:
# results = model_2.evaluate(X_test_A, y_test, verbose=0)
# print(results)

In [163]:
# # # Numcat only
# y_pred2 = model_2.predict(X_test_A, batch_size=5000).ravel()

In [164]:
# # Remove last layer of models

# model_1.pop()
# model_2.pop()

## Model 3: DNN + RNN

In [165]:
# input_A = tf.keras.layers.Input(shape=(parameters['seq_length'],), name='seq_input')
# seq_model = model_2(input_A)

# input_B = tf.keras.layers.Input(shape=(X_train_B.shape[1],), name='feature_input')
# catnum_model = model_1(input_B)

# concat = tf.keras.layers.Concatenate()([seq_model, catnum_model])
# hidden1 = tf.keras.layers.Dense(128, activation='relu')(concat)
# drop1 = tf.keras.layers.Dropout(rate=0.5)(hidden1)

# output = tf.keras.layers.Dense(1, name="output", activation='sigmoid')(drop1)

# model_3 = tf.keras.Model(inputs=[input_A, input_B], outputs=[output])

In [166]:
# model_3.compile(loss='binary_crossentropy',
#               optimizer='adam',
#               metrics=[tf.keras.metrics.AUC(), 'accuracy'])

In [167]:
# es3 = EarlyStopping(monitor='val_loss', mode='min', patience=10, restore_best_weights=True)

# history3 = model_3.fit([X_train_A, X_train_B],
#                     y_train,
#                     epochs=30,
#                     batch_size=1024,
#                     validation_data=([X_val_A, X_val_B],  y_val),
#                     callbacks=[es3])

In [168]:
# results = model_3.evaluate([X_test_A, X_test_B], y_test, verbose=0)
# print(results)

In [169]:
# y_pred3 = model_3.predict([X_test_A, X_test_B], batch_size=5000).ravel()

In [170]:
# from sklearn.metrics import accuracy_score

# print(accuracy_score(y_pred4, y_test))

In [171]:
params = {
    'hidden_layers' : 1,
    'n_units' : 64,
    'learning_rate' : 0.0005,
    'rate' :  0.5,
    'embed_size' : 32,
    'LSTM_units' : 16,
    'rnn_hlayers' : 2,
    'n_units_rnn_dnn' : 16,
    'batch_size' : 64, #test
    'comb_hlayers' : 2,
    'comb_width' : 3,
    'activation' : 'elu'
}


model_1 = tf.keras.Sequential()

for i in range(params['hidden_layers']):
    model_1.add(tf.keras.layers.Dense(params['n_units'], activation='elu'))
    model_1.add(tf.keras.layers.Dropout(rate=params['rate']))

model_1.add(tf.keras.layers.Dense(1, activation='sigmoid'))

model_1.compile(loss='binary_crossentropy',
              optimizer=tf.keras.optimizers.Adam(learning_rate=params['learning_rate']),
              metrics=[tf.keras.metrics.AUC()])

es1 = EarlyStopping(monitor='val_loss', mode='min', patience=20, restore_best_weights=True)

# Numcat only
history1 = model_1.fit(X_train_B,
                    y_train,
                    epochs=200,
                    batch_size=params['batch_size'],
                    validation_data=(X_val_B, y_val),
                    callbacks=[es1], verbose=0)

## Model 2: RNN

model_2 = tf.keras.Sequential([
    tf.keras.layers.Embedding(1000, params['embed_size']),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['LSTM_units'])),

])

for j in range(params['rnn_hlayers']):
    model_2.add(tf.keras.layers.Dense(params['n_units_rnn_dnn'], activation=params['activation']))
    model_2.add(tf.keras.layers.Dropout(rate=params['rate']))

model_2.add(tf.keras.layers.Dense(1, activation='sigmoid'))

model_2.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=[tf.keras.metrics.AUC()])


es2 = EarlyStopping(monitor='val_loss', mode='min', patience=10, restore_best_weights=True)


# Sequential only
history2 = model_2.fit(X_train_A,
                    y_train,
                    epochs=50,
                    batch_size=params['batch_size'],
                    validation_data=(X_val_A, y_val),
                    callbacks=[es2], verbose=0)

# Remove last layer of models

model_1.pop()
model_2.pop()

## Model 3: DNN + RNN

input_A = tf.keras.layers.Input(shape=(100,), name='seq_input')
seq_model = model_2(input_A)

input_B = tf.keras.layers.Input(shape=(X_train_B.shape[1],), name='feature_input')
catnum_model = model_1(input_B)

concat = tf.keras.layers.Concatenate()([seq_model, catnum_model])

hlayers = tf.keras.Sequential()
for k in range(params['comb_hlayers']):
    hlayers.add(tf.keras.layers.Dense(params['comb_width'], activation=params['activation']))
    hlayers.add(tf.keras.layers.Dropout(rate=params['rate']))

hidden1 = hlayers(concat)

output = tf.keras.layers.Dense(1, name="output", activation='sigmoid')(hidden1)
aux_output = tf.keras.layers.Dense(1, activation='sigmoid', name="aux_output")(catnum_model)

model_3 = tf.keras.Model(inputs=[input_A, input_B], outputs=[output, aux_output])

model_3.compile(loss=['binary_crossentropy','binary_crossentropy'],
              optimizer='adam',
              loss_weights=[0.9, 0.1],
              metrics=[tf.keras.metrics.AUC(), 'accuracy'])

es3 = EarlyStopping(monitor='val_loss', mode='min', patience=10, restore_best_weights=True)

history3 = model_3.fit([X_train_A, X_train_B],
                    [y_train, y_train],
                    epochs=30,
                    batch_size=params['batch_size'],
                    validation_data=([X_val_A, X_val_B],  [y_val, y_val]),
                    callbacks=[es3], verbose=0)

In [172]:
y_pred3 = model_3.predict([X_test_A, X_test_B])

In [173]:
# from sklearn.metrics import roc_curve
# from sklearn.metrics import auc
# import matplotlib.pyplot as plt
# %matplotlib inline

# fpr_keras, tpr_keras, thresholds_keras = roc_curve(y_test, y_pred3)
# auc_keras = auc(fpr_keras, tpr_keras)

# plt.style.use('seaborn')
# plt.figure(figsize=(8,8))
# plt.plot([0, 1], [0, 1], 'k--')
# plt.plot(fpr_keras, tpr_keras, label='AUC = {:.2f}'.format(auc_keras))
# plt.legend(loc='upper left', fontsize='xx-large')
# plt.title('Receiver Operating Characteristic (ROC) Curve')
# plt.xlabel('False positive rate')
# plt.ylabel('True positive rate')

# fig_name = diagnosis + str(len(dx_events)) + '_' + str(len(dx_nonevents)) + 'AUC' + '{:.3f}'.format(auc_keras)
# save_fig(fig_name)

In [174]:
model_3.save(f'../models/{diagnosis}_model.h5')

# Look at Embedding Layer

In [175]:
embeddings = model_2.layers[0].get_weights()[0]

In [176]:
words_vis_list = [item[0] for item in counter2[:100]]

In [177]:
word_inds = [word_to_ID[wrd] for wrd in words_vis_list]

In [178]:
embeddings_words = embeddings[word_inds,:]

In [179]:
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
pca.fit(embeddings_words)
embed2d = pca.transform(embeddings_words)

In [180]:
icd_to_dx = {v: k for k, v in dx_to_icd.items()}

In [181]:
import seaborn as sns

sns.set_style('white')

x_2d = embed2d[:, 0]
y_2d = embed2d[:, 1]

# colors = ['C0', 'C2', 'black']

fig, ax = plt.subplots(figsize=(14,14))

p1=sns.regplot(x=x_2d, y=y_2d, fit_reg=False, marker="o", color="skyblue", scatter_kws={'s':0})
 
for i, item in enumerate(words_vis_list):
    if item in icd_to_dx:
        txt = icd_to_dx[item]
    else:
        txt = item
    p1.text(x_2d[i], y_2d[i], txt, fontsize=8, weight='ultralight') 

p1.set(xticks=[], yticks=[]);
save_fig('word_pca_vis')

NameError: name 'plt' is not defined

In [None]:
def specificity(y_true, y_pred):
    """
    param:
    y_pred - Predicted labels
    y_true - True labels 
    Returns:
    Specificity score
    """
    neg_y_true = 1 - y_true
    neg_y_pred = 1 - y_pred
    fp = np.sum(neg_y_true * y_pred)
    tn = np.sum(neg_y_true * neg_y_pred)
    specificity = tn / (tn + fp + tf.keras.backend.epsilon())
    return specificity

In [None]:
specificity(y_test, y_pred4)

In [None]:
from sklearn.metrics import confusion_matrix, classification_report

threshold = 0.5

y_pred4 = y_pred3 > threshold
print(classification_report(y_test, y_pred4))

In [None]:
new_preds = model_3.predict([X_train_A, X_train_B])
new_preds = [pred[0] for pred in new_preds]
time_preds_slice = pd.DataFrame(zip(y_train, new_preds), columns=['expired', 'pred'])

In [None]:
time_preds_slice['timepoint'] = -int(parameters['window'][0]/60)


time_preds_slice.sample(10)

In [None]:
import seaborn as sns

sns.boxplot(data=time_preds_slice, x='expired', y='pred',  width=0.2)

In [None]:
old_sample = pd.read_csv('../data/saved_data/sample_preds.csv')

In [None]:
np.mean(time_preds_slice[time_preds_slice.expired==0]['pred'])

In [None]:
new_sample = pd.concat([old_sample, time_preds_slice], ignore_index=True)

In [None]:
# new_sample.to_csv('../data/saved_data/sample_preds.csv', index=False, header=True)