In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

In [None]:
def create_pivot_lab_table(df):
    """Pivot long-format lab rows into one column per lab item.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with columns ``icustay_id``, ``charttime``,
        ``item_name`` and ``valuenum``.

    Returns
    -------
    pandas.DataFrame
        One row per (``icustay_id``, ``charttime``) pair and one column per
        distinct ``item_name``. Repeated measurements of the same item at the
        same key are averaged. Rows are sorted by ``icustay_id`` then
        ``charttime``; both keys are ordinary columns, not the index.
    """
    key_columns = ['icustay_id', 'charttime']

    wide = df.pivot_table(values='valuenum',
                          index=key_columns,
                          columns=['item_name'],
                          aggfunc='mean')

    return wide.reset_index().sort_values(key_columns)


def handle_outliers(df):
    """Blank out lab values that fall outside fixed plausibility bounds.

    Values strictly below the lower bound (when one is given) or strictly
    above the upper bound are replaced with NaN. The frame is modified in
    place and also returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Wide lab table containing every column listed in the bounds table
        below; a missing column raises ``KeyError``.

    Returns
    -------
    pandas.DataFrame
        The same object that was passed in.
    """
    # (column, lower bound or None, upper bound); both bounds are exclusive,
    # i.e. a value equal to a bound is kept.
    bounds = (
        ('Potassium', 1, 15),
        ('Sodium', 95, 178),
        ('Chloride', 70, 150),
        ('Creatinine', None, 150),
        ('Magnesium', None, 10),
        ('Calcium', None, 20),
        ('Hgb', None, 20),
        ('WBC', None, 500),
        ('Platelet', None, 2000),
    )

    for column, lower, upper in bounds:
        values = df[column]
        implausible = values > upper
        if lower is not None:
            implausible = (values < lower) | implausible
        df.loc[implausible, column] = np.nan

    return df


def resample_table(df, time_interval='1H'):
    """Average each ICU stay's measurements into fixed-width time bins.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with an ``icustay_id`` column, a datetime ``charttime`` column
        and numeric measurement columns.
    time_interval : str, pandas.DateOffset or pandas.Timedelta, default '1H'
        Bin width passed to ``resample``.

    Returns
    -------
    pandas.DataFrame
        Per-stay, per-bin means indexed by (``icustay_id``, ``charttime``).
        Bins in which every measurement is NaN are removed.
    """
    # Use the caller's interval; it was previously ignored in favour of a
    # hard-coded '1H'.
    binned = (df.set_index('charttime')
                .groupby('icustay_id')
                .resample(time_interval)
                .mean())

    # Depending on the pandas version the grouping column may or may not be
    # carried through as a data column; it is redundant with the index level.
    binned = binned.drop(columns='icustay_id', errors='ignore')

    return binned.dropna(axis=0, how='all')



def cal_time_interval(df):
    """Add per-stay elapsed time and time-to-next-row columns, in hours.

    Adds three columns:

    * ``Initial Event Time`` - earliest ``charttime`` of the row's stay.
    * ``time_offset`` - hours between ``charttime`` and that earliest time
      (forward-filled within the stay where it cannot be computed).
    * ``time_delta`` - hours from this row to the next row of the same stay;
      NaN on the last row of each stay.

    The frame is modified in place and returned. ``icustay_id`` ends up as
    the first column and the index is reset to a fresh RangeIndex. Rows are
    expected to be in chronological order within each stay; stays themselves
    do not need to be contiguous.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``icustay_id`` and datetime ``charttime`` columns.

    Returns
    -------
    pandas.DataFrame
        The same object that was passed in.
    """
    df.set_index('icustay_id', inplace=True)

    # transform() keeps the original row order, so every value lines up with
    # its own row even when stays are interleaved (a groupby.apply result is
    # returned in group order and would be misassigned positionally).
    first_event = df.groupby(level='icustay_id')['charttime'].transform('min')
    df['Initial Event Time'] = first_event.to_numpy()

    elapsed = (df['charttime'] - df['Initial Event Time']) / np.timedelta64(1, 'h')
    df['time_offset'] = elapsed.to_numpy()
    df['time_offset'] = (df.groupby(level='icustay_id')['time_offset']
                           .ffill()
                           .to_numpy())

    # Shift within each stay so a row never sees the next stay's offset.
    next_offset = (df.groupby(level='icustay_id')['time_offset']
                     .shift(-1)
                     .to_numpy())
    df['time_delta'] = next_offset - df['time_offset'].to_numpy()

    df.reset_index(inplace=True)

    return df


def visualize_time_length(df):
    """Print the 95th percentile of rows per stay and plot their histogram.

    Counts non-null ``charttime`` entries per ``icustay_id``, prints the 0.95
    quantile of those counts, and shows a 50-bin histogram restricted to
    stays with fewer than 200 rows. Returns nothing.
    """
    rows_per_stay = df.groupby('icustay_id')['charttime'].count()

    count = np.quantile(rows_per_stay.to_numpy(), 0.95)
    print(f'95% patients have at most {count} test orderings')

    # Stays with 200 or more rows are left out of the plot only; they are
    # still part of the quantile printed above.
    plotted_counts = rows_per_stay.loc[rows_per_stay < 200].to_numpy()

    plt.figure(figsize=(8, 6))
    plt.hist(plotted_counts, bins=50)
    plt.xlabel('Test Ordering Counts')
    plt.show()
    
    
def cal_test_delta(df):
    """Add ``Hgb_delta``: absolute change from the stay's previous Hgb value.

    For every row with a non-missing ``Hgb``, the delta is the absolute
    difference to the previous non-missing ``Hgb`` of the same
    ``icustay_id``. Rows with missing ``Hgb`` and the first measured row of
    each stay get 0. The column is added to ``df`` in place and ``df`` is
    returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``icustay_id`` and ``Hgb`` columns, in chronological order
        within each stay.

    Returns
    -------
    pandas.DataFrame
        The same object that was passed in.
    """
    measured = df['Hgb'].notna().to_numpy()
    delta = np.zeros(len(df), dtype=float)

    if measured.any():
        # Work on the measured rows only, with a fresh positional index.
        # groupby.diff() keeps row order, so results map back by position
        # even when stays are interleaved in the frame.
        measured_hgb = pd.Series(df['Hgb'].to_numpy()[measured], dtype=float)
        measured_ids = df['icustay_id'].to_numpy()[measured]
        delta[measured] = (measured_hgb.groupby(measured_ids)
                                       .diff()
                                       .abs()
                                       .fillna(0)
                                       .to_numpy())

    df['Hgb_delta'] = delta

    return df


def truncate_table(df):
    """Select stays and rows usable for Hgb modelling.

    Three filters are applied:

    1. Stays with any ``sysbp`` reading below 90 are removed entirely.
    2. Within each remaining stay, rows after the last non-missing ``Hgb``
       are removed.
    3. Stays with one or zero non-missing ``Hgb`` values are removed.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``icustay_id``, ``sysbp`` and ``Hgb`` columns, in
        chronological order within each stay.

    Returns
    -------
    pandas.DataFrame
        The surviving rows in their original order and with their original
        index labels.
    """
    # 1. Remove stays with any systolic blood pressure below 90.
    low_sbp_ids = df.loc[df['sysbp'] < 90, 'icustay_id'].unique()
    kept = df[~df['icustay_id'].isin(low_sbp_ids)]

    # Everything below is computed positionally on the filtered rows only, so
    # rows of removed stays are never looked up again.
    stay_ids = kept['icustay_id'].to_numpy()
    has_hgb = pd.Series(kept['Hgb'].notna().to_numpy().astype(int))
    per_stay = has_hgb.groupby(stay_ids)

    hgb_total = per_stay.transform('sum')
    # Number of Hgb values at or after each row within its stay.
    hgb_remaining = hgb_total - per_stay.cumsum() + has_hgb

    # 2. A row survives while at least one Hgb value is still to come.
    # 3. A stay survives only with more than one Hgb value. Step 2 never
    #    removes a measured row, so the pre-truncation count is the same.
    keep = ((hgb_remaining > 0) & (hgb_total > 1)).to_numpy()

    return kept[keep]


def fix_time_length(df, max_len=5, max_gap_hours=120):
    """Give every stay exactly ``max_len`` rows, starting at its first dense row.

    For each ``icustay_id``:

    * Leading rows are skipped until the first row whose ``time_delta`` is at
      most ``max_gap_hours``. Stays with no such row are dropped.
    * At most ``max_len`` rows are kept from that point.
    * Shorter stays are padded with rows that carry only the stay's
      ``icustay_id`` and the ``charttime`` of its first kept row; every other
      column is NaN in the padding rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``icustay_id``, ``charttime`` and ``time_delta`` columns.
    max_len : int, default 5
        Number of rows per stay in the result.
    max_gap_hours : float, default 120
        Largest ``time_delta`` accepted for the first kept row of a stay.

    Returns
    -------
    pandas.DataFrame
        Concatenation of the per-stay blocks in ``icustay_id`` order. Each
        block is indexed ``0 .. max_len - 1``, so index labels repeat across
        stays. An empty frame with the input's columns is returned when no
        stay qualifies.
    """
    blocks = []

    # Iterate explicitly instead of using groupby.apply with a side-effecting
    # callback: each stay is visited exactly once and the result shape does
    # not depend on how apply decides to combine the pieces.
    for _, stay in df.groupby('icustay_id'):
        dense = (stay['time_delta'] <= max_gap_hours).to_numpy()
        if not dense.any():
            # No row is followed closely enough by another one: drop the stay.
            continue

        start = int(dense.argmax())
        window = stay.iloc[start:start + max_len].reset_index(drop=True)

        n_missing = int(max_len - window['charttime'].count())
        if n_missing > 0:
            padding = pd.DataFrame({
                'icustay_id': [window['icustay_id'].iloc[0]] * n_missing,
                'charttime': [window['charttime'].iloc[0]] * n_missing,
            })
            window = pd.concat([window, padding], ignore_index=True)

        blocks.append(window)

    if not blocks:
        return df.iloc[0:0].copy()

    return pd.concat(blocks)



def get_normal_mask(df):
    """Add Hgb reference bounds and the normal/stable label columns.

    Columns added to ``df`` in place:

    * ``Hgb_normal_low`` - 11 for ``admission_age`` < 5, 11.5 for ages 5-11,
      12.0 otherwise (including missing age); overridden to 13.0 for rows
      with ``admission_age`` > 14 and ``gender == 'M'``.
    * ``Hgb_normal_high`` - 14.0 for ``admission_age`` < 5, else 16.0. This
      bound is stored but not used by the masks below.
    * ``Hgb_normal_mask`` - 1 if ``Hgb >= Hgb_normal_low``, 0 if below,
      -1 if ``Hgb`` is missing.
    * ``Hgb_stable_mask`` - compares each measured row's normal mask with the
      previous measured row of the same stay: 1 if unchanged, 0 if it changed
      in either direction, -1 if ``Hgb`` is missing or there is no previous
      measurement in the stay.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``icustay_id``, ``admission_age``, ``gender`` and ``Hgb``
        columns, in chronological order within each stay.

    Returns
    -------
    pandas.DataFrame
        The same object that was passed in.
    """
    age = df['admission_age']

    # Conditions are tried in order, so the second one only applies to ages
    # of at least 5. A missing age fails both and takes the default.
    df['Hgb_normal_low'] = np.select([age < 5, age <= 11], [11.0, 11.5], default=12.0)
    df['Hgb_normal_high'] = np.where(age < 5, 14.0, 16.0)
    df.loc[(age > 14) & (df['gender'] == 'M'), 'Hgb_normal_low'] = 13.0

    # Normal mask (label)
    hgb = df['Hgb']
    at_or_above_low = (hgb >= df['Hgb_normal_low']).to_numpy().astype(int)
    df['Hgb_normal_mask'] = np.where(hgb.isna().to_numpy(), -1, at_or_above_low)

    # Stable mask (label)
    normal = df['Hgb_normal_mask'].to_numpy()
    measured = normal != -1
    stable = np.full(len(df), -1.0)

    if measured.any():
        # Operate on measured rows with a fresh positional index.
        # groupby.diff() keeps row order, so results map back by position
        # even when stays are interleaved in the frame.
        measured_mask = pd.Series(normal[measured], dtype=float)
        measured_ids = df['icustay_id'].to_numpy()[measured]
        change = measured_mask.groupby(measured_ids).diff().to_numpy()

        # NaN: first measurement of the stay. 0: same state as before.
        # Anything else: Normal -> Abnormal or Abnormal -> Normal.
        stable[measured] = np.where(np.isnan(change), -1.0,
                                    np.where(change == 0, 1.0, 0.0))

    df['Hgb_stable_mask'] = stable

    return df


def fill_missing_values(df):
    """Return a copy of ``df`` with missing values replaced by sentinels.

    Lab columns, vital-sign columns (taken from the module-level
    ``lab_names`` and ``vital_names`` lists) and ``time_delta`` are filled
    with 0. ``race_code``, ``gender_code`` and ``admission_age`` are filled
    with -1. The input frame is left untouched.
    """
    filled = df.copy()

    fill_plan = (
        ([*lab_names, *vital_names, 'time_delta'], 0),
        (['race_code', 'gender_code', 'admission_age'], -1),
    )
    for columns, sentinel in fill_plan:
        for name in columns:
            filled[name] = filled[name].fillna(sentinel)

    return filled


def get_non_nan_mask(df):
    """Build a 0/1 observation mask for the lab and vital-sign columns.

    Returns a new frame holding ``icustay_id``, ``charttime`` and one column
    per entry of the module-level ``lab_names`` and ``vital_names`` lists.
    Each measurement column is 1.0 where ``df`` has a value and 0.0 where it
    is NaN. The two key columns are copied unchanged.
    """
    key_columns = ['icustay_id', 'charttime']
    mask = df[key_columns + lab_names + vital_names].copy()

    # Everything after the two key columns is a measurement.
    measurement_columns = mask.columns[len(key_columns):]
    mask[measurement_columns] = mask[measurement_columns].notna().astype('float')

    return mask

In [None]:
# Directory holding the input CSV extracts (outputs are written here too).
input_dir = 'mimic_data/'

# Load the demographics, vital-sign and lab tables.
demog, vitals, labs = (
    pd.read_csv(input_dir + file_name)
    for file_name in ('demog.csv', 'vitals.csv', 'labs.csv')
)

In [None]:
# Lab items kept from labs.csv; these become the wide lab columns.
lab_names = [
    'BUN', 'Calcium', 'Chloride', 'Creatinine',
    'HCO3', 'Hgb', 'Magnesium', 'Phosphate',
    'Platelet', 'Potassium', 'Sodium', 'WBC',
]
# Vital-sign columns kept from vitals.csv.
vital_names = ['heartrate', 'sysbp', 'diasbp', 'resprate', 'spo2']
# Demographic feature columns.
demog_names = ['admission_age', 'gender_code', 'race_code']

In [None]:
# Keep only the lab items and vital-sign columns used downstream. The explicit
# .copy() makes each result an independent frame, so the charttime assignments
# below do not write into a filtered slice (SettingWithCopyWarning).
labs = labs[labs['item_name'].isin(lab_names)].copy()
vitals = vitals[['icustay_id','charttime']+vital_names].copy()

# Parse timestamps so the tables can be resampled on charttime.
labs['charttime'] = pd.to_datetime(labs['charttime'])
vitals['charttime'] = pd.to_datetime(vitals['charttime'])

In [None]:
# Long -> wide lab table, then blank out implausible lab values.
lab_pivot = handle_outliers(create_pivot_lab_table(labs))

In [None]:
# Bring labs and vitals onto the same per-stay time bins.
lab_pivot_resample = resample_table(lab_pivot)
vital_resample = resample_table(vitals)

# Left join: keep every bin that has a lab value and attach vitals where present.
df_comb = (
    lab_pivot_resample
    .merge(vital_resample, on=['icustay_id', 'charttime'], how='left')
    .reset_index()
    .sort_values(['icustay_id', 'charttime'])
)

df_comb = cal_time_interval(df_comb)

In [None]:
# Remove stays with low systolic pressure, rows after the last Hgb and stays
# with too few Hgb values; then cut/pad every stay to a fixed number of rows.
df_comb_trunc = truncate_table(df_comb)
# Optional diagnostic, disabled by default:
# visualize_time_length(df_comb_trunc)
df_comb_fix = fix_time_length(df_comb_trunc)

In [None]:
# Encode the categorical demographics as integer codes (missing values get
# code -1), then attach demographics to every row of the matching stay.
demog = demog.assign(
    gender_code=pd.Categorical(demog['gender']).codes,
    race_code=pd.Categorical(demog['ethnicity_grouped']).codes,
)
df_comb_all = df_comb_fix.merge(demog, on='icustay_id', how='left')

In [None]:
# Add the Hgb delta and the normal/stable label columns; keep NaNs in
# df_comb_all and produce the filled copy separately.
df_comb_all = df_comb_all.pipe(cal_test_delta).pipe(get_normal_mask)
df_final = fill_missing_values(df_comb_all)

In [None]:
# Build the observation mask from df_comb_all rather than df_final:
# fill_missing_values works on a copy, so NaNs here still mark the
# measurements that were not taken.
non_nan_mask = get_non_nan_mask(df_comb_all)

In [None]:
# Persist the model-ready table and its observation mask next to the inputs.
for frame, file_name in ((df_final, 'lab_test_data.csv'),
                         (non_nan_mask, 'non_nan_mask.csv')):
    frame.to_csv(input_dir + file_name, index=False)