In [1]:
import pandas as pd
from tqdm import tqdm
import os
from typing import List, Dict

In [35]:
# Human-readable display name for each benchmark task, keyed by task id.
# Insertion order follows the ordering used in the paper (grouped by task family).
task_2_name: Dict[str, str] = {
    # -- Operational outcomes --
    'guo_los': 'Long LOS',
    'guo_readmission': '30-Day Readmission',
    'guo_icu': 'ICU Admission',
    # -- Anticipating lab test results --
    'lab_thrombocytopenia': 'Thrombocytopenia',
    'lab_hyperkalemia': 'Hyperkalemia',
    'lab_hypoglycemia': 'Hypoglycemia',
    'lab_hyponatremia': 'Hyponatremia',
    'lab_anemia': 'Anemia',
    # -- Assignment of new diagnoses --
    'new_hypertension': 'Hypertension',
    'new_hyperlipidemia': 'Hyperlipidemia',
    'new_pancan': 'Pancreatic Cancer',
    'new_celiac': 'Celiac',
    'new_lupus': 'Lupus',
    'new_acutemi': 'Acute MI',
    # -- Anticipating chest x-ray findings --
    'chexpert': 'Chest X-Ray',
}

# How each task's label `value` is encoded, keyed by task id.
# Built per value type so the grouping is explicit; insertion order is
# boolean tasks, then multiclass tasks, then the single multilabel task.
task_2_value_type: Dict[str, str] = {
    **dict.fromkeys(
        [
            'new_pancan',
            'new_celiac',
            'new_lupus',
            'new_acutemi',
            'new_hypertension',
            'new_hyperlipidemia',
            'guo_los',
            'guo_readmission',
            'guo_icu',
        ],
        'boolean',
    ),
    **dict.fromkeys(
        [
            'lab_thrombocytopenia',
            'lab_hyperkalemia',
            'lab_hypoglycemia',
            'lab_hyponatremia',
            'lab_anemia',
        ],
        'multiclass',
    ),
    'chexpert': 'multilabel',
}

In [36]:
# Input locations, given relative to the directory the notebook is run from.
# Event table: one row per event, with `patient_id` and `visit_id` columns.
path_to_data_csv: str = '../EHRSHOT_ASSETS/data/ehrshot.csv'
# Labels directory: holds `all_labels.csv` plus one `<task>/labeled_patients.csv` per task.
path_to_labels_dir: str = '../EHRSHOT_ASSETS/benchmark/'
# Split assignment: maps each `omop_person_id` to its `split` (train / val / test).
path_to_splits_csv: str = '../EHRSHOT_ASSETS/splits/person_id_map.csv'

# Overall Stats

In [10]:
# Load the full event table and the patient -> split assignment.
# `low_memory=False` makes pandas infer each column's dtype from the whole file
# rather than chunk by chunk; chunked inference on a large CSV can emit a
# DtypeWarning and leave a column holding a mix of types.
df_dataset = pd.read_csv(path_to_data_csv, low_memory=False)
df_split = pd.read_csv(path_to_splits_csv)

  df_dataset = pd.read_csv(path_to_data_csv)


In [13]:
# Headline sizes: every row of the event table is one event; patients and
# visits are counted as distinct IDs.
n_events = len(df_dataset.index)
n_unique_patients = df_dataset['patient_id'].nunique()
n_unique_visits = df_dataset['visit_id'].nunique()

# Number of distinct patients assigned to each split.
patients_per_split = {
    split_name: df_split.loc[df_split['split'] == split_name, 'omop_person_id'].nunique()
    for split_name in ('train', 'val', 'test')
}

print("# of events:", n_events)
print("# of patients:", n_unique_patients)
print("# of visits:", n_unique_visits)
print("# of train patients", patients_per_split['train'])
print("# of val patients", patients_per_split['val'])
print("# of test patients", patients_per_split['test'])

# of events: 41661637
# of patients: 6739
# of visits: 921499
# of train patients 2295
# of val patients 2232
# of test patients 2212


# Label Stats

In [None]:
# Read `all_labels.csv` from the labels directory
# (presumably the labels of every task combined — confirm against the asset docs).
path_to_all_labels_csv = os.path.join(path_to_labels_dir, 'all_labels.csv')
df_labels = pd.read_csv(path_to_all_labels_csv)

In [17]:
# Total number of rows in the loaded label table.
print("# of labels:", len(df_labels.index))

# of labels: 406379


In [37]:
# `value` that marks a multilabel row as negative; any other value counts as positive.
# NOTE(review): presumably the bitmask of the "No Finding" class (8192 == 1 << 13) —
# confirm against the labeler that produced these CSVs.
MULTILABEL_NEGATIVE_VALUE: int = 8192

# Splits reported separately; an extra 'all' entry covers every labeled patient.
SPLITS: List[str] = ['train', 'test', 'val']

# Columns of one per-task statistics record (also used so that an empty
# result still yields a frame with the expected columns).
STAT_COLUMNS: List[str] = [
    'task',
    'task_name',
    'n_patients',
    'n_positive_patients',
    'n_labels',
    'n_positive_labels',
]


def _summarize_task_labels(df_task: pd.DataFrame, task: str, task_name: str) -> Dict[str, object]:
    """Count patients and labels (total and positive) in one task's label frame.

    Args:
        df_task: Label rows with a `patient_id` column and an `is_positive_label` column.
        task: Task identifier (key of `task_2_name`).
        task_name: Display name of the task.

    Returns:
        A record with the keys listed in `STAT_COLUMNS`. A patient counts as
        positive when at least one of their labels is positive.
    """
    return {
        'task': task,
        'task_name': task_name,
        'n_patients': df_task['patient_id'].nunique(),
        'n_positive_patients': df_task.groupby('patient_id')['is_positive_label'].max().sum(),
        'n_labels': df_task.shape[0],
        'n_positive_labels': df_task['is_positive_label'].sum(),
    }


# The patient IDs belonging to each split do not depend on the task, so look them up once.
split_2_patient_ids = {
    split: df_split.loc[df_split['split'] == split, 'omop_person_id']
    for split in SPLITS
}

# Per-split lists of statistics records, filled task by task.
records = {
    'train': [],
    'test': [],
    'val': [],
    'all': [],
}
for task, task_name in tqdm(task_2_name.items()):
    # os.path.join (as used for all_labels.csv) works with or without a trailing
    # slash on `path_to_labels_dir`.
    path_to_task_csv: str = os.path.join(path_to_labels_dir, task, 'labeled_patients.csv')
    if not os.path.exists(path_to_task_csv):
        print(f"Skipping {task_name} @ {path_to_task_csv}")
        continue
    df = pd.read_csv(path_to_task_csv)

    # Derive a positive/negative flag from the raw label value.
    # `.get` lets a task without a declared value type reach the skip branch
    # below instead of raising KeyError.
    value_type = task_2_value_type.get(task)
    if value_type == "boolean":
        df['is_positive_label'] = df["value"]
    elif value_type == "multiclass":
        # Class 0 is treated as negative; every other class as positive.
        df['is_positive_label'] = df["value"] > 0
    elif value_type == "multilabel":
        df['is_positive_label'] = df["value"] != MULTILABEL_NEGATIVE_VALUE
    else:
        print(f"Skipping {task_name}")
        continue

    # Splits
    for split in SPLITS:
        df_ = df[df['patient_id'].isin(split_2_patient_ids[split])]
        records[split].append(_summarize_task_labels(df_, task, task_name))

    # All
    records['all'].append(_summarize_task_labels(df, task, task_name))

# One statistics table per split, plus derived negative counts and prevalence.
# Passing `columns` keeps the derived-column step working even when every task was skipped.
results = {}
for key, rows in records.items():
    df_stats = pd.DataFrame(rows, columns=STAT_COLUMNS)
    df_stats['n_negative_labels'] = df_stats['n_labels'] - df_stats['n_positive_labels']
    df_stats['n_negative_patients'] = df_stats['n_patients'] - df_stats['n_positive_patients']
    df_stats['label_prevalence'] = df_stats['n_positive_labels'] / df_stats['n_labels']
    results[key] = df_stats

100%|██████████| 15/15 [00:00<00:00, 19.30it/s]


In [38]:
# All: per-task label statistics over every labeled patient, regardless of split
results['all']

Unnamed: 0,task,task_name,n_patients,n_positive_patients,n_labels,n_positive_labels,n_negative_labels,n_negative_patients,label_prevalence
0,guo_los,Long LOS,3855,1271,6995,1767,5228,2584,0.252609
1,guo_readmission,30-Day Readmission,3718,474,7003,911,6092,3244,0.130087
2,guo_icu,ICU Admission,3617,266,6491,290,6201,3351,0.044677
3,lab_thrombocytopenia,Thrombocytopenia,6063,2566,179618,59718,119900,3497,0.332472
4,lab_hyperkalemia,Hyperkalemia,5931,1289,200170,4769,195401,4642,0.023825
5,lab_hypoglycemia,Hypoglycemia,5974,1379,318164,4721,313443,4595,0.014838
6,lab_hyponatremia,Hyponatremia,5921,3692,212837,60708,152129,2229,0.285232
7,lab_anemia,Anemia,6086,4271,184880,127496,57384,1815,0.689615
8,new_hypertension,Hypertension,2328,386,3764,516,3248,1942,0.137088
9,new_hyperlipidemia,Hyperlipidemia,2650,410,4442,566,3876,2240,0.12742


In [39]:
# Train: per-task label statistics restricted to patients in the train split
results['train']

Unnamed: 0,task,task_name,n_patients,n_positive_patients,n_labels,n_positive_labels,n_negative_labels,n_negative_patients,label_prevalence
0,guo_los,Long LOS,1377,464,2569,681,1888,913,0.265084
1,guo_readmission,30-Day Readmission,1337,164,2608,370,2238,1173,0.141871
2,guo_icu,ICU Admission,1306,107,2402,113,2289,1199,0.047044
3,lab_thrombocytopenia,Thrombocytopenia,2084,906,68776,22714,46062,1178,0.330261
4,lab_hyperkalemia,Hyperkalemia,2038,456,76349,1829,74520,1582,0.023956
5,lab_hypoglycemia,Hypoglycemia,2054,511,122108,1904,120204,1543,0.015593
6,lab_hyponatremia,Hyponatremia,2035,1294,81336,23877,57459,741,0.29356
7,lab_anemia,Anemia,2092,1484,70501,49028,21473,608,0.695423
8,new_hypertension,Hypertension,792,129,1259,182,1077,663,0.144559
9,new_hyperlipidemia,Hyperlipidemia,923,137,1684,205,1479,786,0.121734


In [40]:
# Val: per-task label statistics restricted to patients in the val split
results['val']

Unnamed: 0,task,task_name,n_patients,n_positive_patients,n_labels,n_positive_labels,n_negative_labels,n_negative_patients,label_prevalence
0,guo_los,Long LOS,1240,395,2231,534,1697,845,0.239355
1,guo_readmission,30-Day Readmission,1191,159,2206,281,1925,1032,0.12738
2,guo_icu,ICU Admission,1157,84,2052,92,1960,1073,0.044834
3,lab_thrombocytopenia,Thrombocytopenia,1981,807,54504,17867,36637,1174,0.327811
4,lab_hyperkalemia,Hyperkalemia,1935,428,60168,1386,58782,1507,0.023036
5,lab_hypoglycemia,Hypoglycemia,1950,433,95488,1449,94039,1517,0.015175
6,lab_hyponatremia,Hyponatremia,1930,1174,64473,17557,46916,756,0.272316
7,lab_anemia,Anemia,1992,1379,56224,38498,17726,613,0.684725
8,new_hypertension,Hypertension,781,128,1247,175,1072,653,0.140337
9,new_hyperlipidemia,Hyperlipidemia,863,140,1441,189,1252,723,0.131159


In [41]:
# Test: per-task label statistics restricted to patients in the test split
results['test']

Unnamed: 0,task,task_name,n_patients,n_positive_patients,n_labels,n_positive_labels,n_negative_labels,n_negative_patients,label_prevalence
0,guo_los,Long LOS,1238,412,2195,552,1643,826,0.251481
1,guo_readmission,30-Day Readmission,1190,151,2189,260,1929,1039,0.118776
2,guo_icu,ICU Admission,1154,75,2037,85,1952,1079,0.041728
3,lab_thrombocytopenia,Thrombocytopenia,1998,853,56338,19137,37201,1145,0.339682
4,lab_hyperkalemia,Hyperkalemia,1958,405,63653,1554,62099,1553,0.024414
5,lab_hypoglycemia,Hypoglycemia,1970,435,100568,1368,99200,1535,0.013603
6,lab_hyponatremia,Hyponatremia,1956,1224,67028,19274,47754,732,0.287551
7,lab_anemia,Anemia,2002,1408,58155,39970,18185,594,0.687301
8,new_hypertension,Hypertension,755,129,1258,159,1099,626,0.126391
9,new_hyperlipidemia,Hyperlipidemia,864,133,1317,172,1145,731,0.1306
