In [1]:
import numpy as np
import pandas as pd
import os
import glob
import torch
import torch.nn.functional as F
import joblib
import itertools
import scipy
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
import warnings
import string
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss, recall_score, precision_score
from prediction_utils.util import df_dict_concat, yaml_read
from matplotlib.ticker import FormatStrFormatter

In [2]:
# Outcome columns of the cohort that this notebook summarises per task.
tasks = [
    'hospital_mortality',
    'LOS_7',
    'readmission_30',
]

# Experiment names.
experiment_name_baseline = 'baseline_tuning_fold_1_10'
experiment_name_fair = 'fair_tuning_fold_1_10'

# Project root; both input paths below are derived from it.
project_dir = "/share/pi/nigam/projects/spfohl/cohorts/admissions/starr_20200523"
row_id_map_path = os.path.join(
    project_dir,
    'merged_features_binary/features_sparse/features_row_id_map.parquet',
)
cohort_path = os.path.join(project_dir, 'cohort', 'cohort.parquet')

In [3]:
# Cohort columns whose values define the groups in the cohort table.
attributes = [
    'gender_concept_name',
    'age_group',
    'race_eth',
]

In [4]:
# Load the cohort table and join it with the row-id map.
# merge() is called without `on`, so pandas joins (inner) on every column
# name the two frames have in common.
row_id_map = pd.read_parquet(row_id_map_path)
cohort = pd.read_parquet(cohort_path).merge(row_id_map)

### Generate the cohort table

In [5]:
### Cohort table
# Reshape the cohort to long format in two passes:
#   inner melt - one row per (cohort row, task), outcome value in `labels`;
#   outer melt - one row per (cohort row, task, attribute), with the
#                attribute's value in `group`.
cohort_df_long = pd.melt(
    pd.melt(
        cohort,
        id_vars=['person_id'] + attributes,
        value_vars=tasks,
        var_name='task',
        value_name='labels',
    ),
    id_vars=['person_id', 'task', 'labels'],
    value_vars=attributes,
    var_name='attribute',
    value_name='group',
)

In [6]:
# Outcome prevalence (mean of `labels`) for every (attribute, group), with one
# column per task. Pivoting on both index keys at once yields the
# (attribute, group) x task layout directly, without a per-attribute
# groupby.apply round-trip.
cohort_statistics_df = (
    cohort_df_long
    .groupby(['task', 'attribute', 'group'])
    .agg(prevalence=('labels', 'mean'))
    .reset_index()
    .pivot_table(index=['attribute', 'group'], columns='task', values='prevalence')
    .reset_index()
)

# Number of rows in each (attribute, group). The count is taken per task with
# the built-in 'size' aggregation (counts rows, including missing labels);
# every task contributes the same rows to cohort_df_long, so dropping the task
# column and de-duplicating leaves one row per (attribute, group).
group_size_df = (
    cohort_df_long
    .groupby(['task', 'attribute', 'group'])
    .agg(size=('labels', 'size'))
    .reset_index()
    .drop(columns='task')
    .drop_duplicates()
)

# Join sizes onto prevalences with explicit keys; validate= turns an
# unexpected duplicate (attribute, group) into an error instead of silently
# duplicating rows in the table.
cohort_statistics_df = cohort_statistics_df.merge(
    group_size_df,
    on=['attribute', 'group'],
    validate='one_to_one',
)

# Final layout: rows indexed by (attribute, group); columns are the group
# size followed by the prevalence of each task in `tasks` order.
cohort_statistics_df = (
    cohort_statistics_df
    .set_index(['attribute', 'group'])
    [['size'] + tasks]
)

In [7]:
# Display the assembled cohort table: group size and per-task prevalence,
# indexed by (attribute, group).
cohort_statistics_df

Unnamed: 0_level_0,Unnamed: 1_level_0,size,hospital_mortality,LOS_7,readmission_30
attribute,group,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
age_group,[18-30),23042,0.006814,0.175072,0.04609
age_group,[30-45),43432,0.005963,0.129536,0.039648
age_group,[45-55),27394,0.017778,0.205227,0.052712
age_group,[55-65),35703,0.025096,0.227432,0.055794
age_group,[65-75),36084,0.028378,0.234204,0.054844
age_group,[75-91),32989,0.040013,0.237807,0.054533
gender_concept_name,FEMALE,112713,0.016085,0.165935,0.045177
gender_concept_name,MALE,85923,0.027117,0.244335,0.057098
gender_concept_name,No matching concept,8,0.0,0.125,0.0
race_eth,Asian,29460,0.020876,0.171317,0.053632


In [8]:
## Write to LaTeX
# The table is written as LaTeX source into the figures directory, which is
# created first if it does not exist yet.
table_path = './../figures/starr_20200523'
os.makedirs(table_path, exist_ok=True)

with open(os.path.join(table_path, 'cohort_table.txt'), 'w') as fp:
    # Drop the `attribute` index level so rows are labelled by group only.
    cohort_statistics_df.droplevel('attribute').to_latex(
        fp,
        float_format='%.3g',
        index_names=False,
        index=True,
    )