In [1]:
import pandas as pd
import stride_ml
import os
from IPython.display import display

In [2]:
# Project root. The original absolute location is kept as the default so
# existing runs are unaffected; set FAIRNESS_MLHC_PROJECT_DIR in the
# environment to point the notebook at a different checkout without editing
# this cell.
project_dir = os.environ.get(
    'FAIRNESS_MLHC_PROJECT_DIR',
    '/labs/shahlab/spfohl/fairness_MLHC',
)
# Directory under the project root that holds the label files.
label_path = os.path.join(project_dir, 'labels')

In [3]:
# Read the label table from labels.csv inside the label directory.
labels_csv = os.path.join(label_path, 'labels.csv')
label_df = pd.read_csv(labels_csv)

In [4]:
# Output directory for the generated LaTeX tables. It is created if missing
# and reused as-is when it already exists.
table_path = './latex_tables'
os.makedirs(table_path, exist_ok=True)

In [5]:
# Report the table dimensions as a (rows, columns) tuple, then preview the
# leading rows as the cell's displayed value.
n_rows, n_cols = label_df.shape
print((n_rows, n_cols))
label_df.head(5)

(129633, 10)


Unnamed: 0.1,Unnamed: 0,patient_id,day_index,los,mortality,age,gender,race_eth,label,split
0,0,9235026,72,False,False,"[30, 45)",Female,Other,"[30, 45)",train
1,1,9235064,35,False,False,"[30, 45)",Female,Other,"[30, 45)",train
2,2,9235136,552,False,False,"[45, 65)",Female,White,"[45, 65)",train
3,3,9235148,691,False,False,"[45, 65)",Female,Hispanic,"[45, 65)",train
4,4,9235161,17,False,False,"[65, 89)",Female,Asian,"[65, 89)",train


In [6]:
# Columns of the label table by which patients are grouped in the summaries.
sensitive_variables = [
    'race_eth',
    'gender',
    'age',
]

In [7]:
# Per-subgroup event rates: for each sensitive variable, the mean of the
# `los` and `mortality` columns within each of its groups. The built-in
# groupby mean replaces a per-group Python lambda and produces the same table.
# The per-variable frames are stacked with the variable name as the outer
# index level; `keys` is passed explicitly so the order follows
# `sensitive_variables`.
event_rates = pd.concat(
    [
        label_df.groupby(sensitive_variable)[['los', 'mortality']].mean()
        for sensitive_variable in sensitive_variables
    ],
    keys=sensitive_variables,
)
# Overall rates as a single-row frame (row index 0, one column per outcome)
# so it can later be placed side by side with other single-row frames.
event_rates_all = label_df[['los', 'mortality']].mean().to_frame().T
display(event_rates)

Unnamed: 0,Unnamed: 1,los,mortality
race_eth,Asian,0.186888,0.025308
race_eth,Black,0.238947,0.020185
race_eth,Hispanic,0.195741,0.018837
race_eth,Other,0.199564,0.021719
race_eth,Unknown,0.201225,0.072363
race_eth,White,0.204202,0.021139
gender,Female,0.166933,0.017766
gender,Male,0.245112,0.029067
gender,Other,0.0,0.0
age,"[18, 30)",0.179714,0.007128


In [8]:
# Number of distinct patient_id values within each group of each sensitive
# variable. Selecting the column before aggregating avoids running a Python
# lambda over whole sub-frames (which recent pandas flags because the lambda
# also receives the grouping column). `dropna=False` keeps the counting rule
# of `len(unique())`, where a missing id would count as one distinct value.
patient_counts = pd.concat(
    [
        label_df.groupby(sensitive_variable)['patient_id']
        .nunique(dropna=False)
        .to_frame('count')
        for sensitive_variable in sensitive_variables
    ],
    keys=sensitive_variables,
)
# Overall distinct-patient count as a one-element Series indexed by the
# column name 'patient_id'.
patient_counts_all = label_df[['patient_id']].nunique(dropna=False)
print(patient_counts_all)
display(patient_counts)

patient_id    129633
dtype: int64


Unnamed: 0,Unnamed: 1,count
race_eth,Asian,17465
race_eth,Black,5202
race_eth,Hispanic,21978
race_eth,Other,11004
race_eth,Unknown,3593
race_eth,White,70391
gender,Female,72556
gender,Male,57076
gender,Other,1
age,"[18, 30)",15291


In [9]:
# Build the single "All" row of the summary: overall event rates, overall
# patient count, and a Group label, joined column-wise. The pieces are aligned
# on their row index, so the count is re-indexed from scratch and the label
# frame is created with index 0.
overall_count = pd.DataFrame(patient_counts_all).reset_index(drop=True)
overall_count.columns = ['count']
patient_counts_all = overall_count

group_label = pd.DataFrame({'Group': 'All'}, index=[0])
summary_df_all = pd.concat(
    [event_rates_all, patient_counts_all, group_label],
    axis=1,
)

display(summary_df_all)

Unnamed: 0,los,mortality,count,Group
0,0.201353,0.022741,129633,All


In [10]:
# Combine rates and counts per subgroup on their shared two-level index, then
# turn the index levels into ordinary `sensitive` / `Group` columns.
subgroup_rows = (
    event_rates
    .merge(patient_counts, left_index=True, right_index=True)
    .rename_axis(['sensitive', 'Group'])
    .reset_index()
)

# Append the overall "All" row; it has no `sensitive` value.
with_overall = pd.concat(
    [subgroup_rows, summary_df_all],
    ignore_index=True,
    sort=False,
)

# Exclude the gender == 'Other' row, then drop the helper column that was only
# needed to tell it apart from the race_eth 'Other' row.
is_gender_other = (with_overall['sensitive'] == 'gender') & (with_overall['Group'] == 'Other')
summary_df = with_overall.loc[~is_gender_other].drop(columns='sensitive')

display(summary_df)

Unnamed: 0,Group,los,mortality,count
0,Asian,0.186888,0.025308,17465
1,Black,0.238947,0.020185,5202
2,Hispanic,0.195741,0.018837,21978
3,Other,0.199564,0.021719,11004
4,Unknown,0.201225,0.072363,3593
5,White,0.204202,0.021139,70391
6,Female,0.166933,0.017766,72556
7,Male,0.245112,0.029067,57076
9,"[18, 30)",0.179714,0.007128,15291
10,"[30, 45)",0.13979,0.006776,27155


In [11]:
# Give the outcome and count columns their presentation names. An explicit
# old -> new mapping is used instead of assigning a positional list, so each
# label stays attached to the right column even if the column order changes
# upstream; `Group` already has its final name. Re-running the cell is a no-op.
summary_df = summary_df.rename(columns={
    'los': 'LOS > 7 days',
    'mortality': 'Hospital Mortality',
    'count': 'Count',
})

In [12]:
# Show the summary table as the cell's output to check the relabelled columns.
summary_df

Unnamed: 0,Group,LOS > 7 days,Hospital Mortality,Count
0,Asian,0.186888,0.025308,17465
1,Black,0.238947,0.020185,5202
2,Hispanic,0.195741,0.018837,21978
3,Other,0.199564,0.021719,11004
4,Unknown,0.201225,0.072363,3593
5,White,0.204202,0.021139,70391
6,Female,0.166933,0.017766,72556
7,Male,0.245112,0.029067,57076
9,"[18, 30)",0.179714,0.007128,15291
10,"[30, 45)",0.13979,0.006776,27155


In [15]:
# Export the summary as a LaTeX table to summary.txt in the table directory.
# Columns are written in the listed order, floats with three decimals, and
# the row index is omitted.
latex_columns = ['Group', 'Count', 'LOS > 7 days', 'Hospital Mortality']
summary_file = os.path.join(table_path, 'summary.txt')
with open(summary_file, 'w') as fp:
    summary_df.to_latex(
        buf=fp,
        columns=latex_columns,
        float_format='%.3f',
        index=False,
    )