In [28]:
import pandas as pd
from tqdm import tqdm
import os
from typing import List, Dict
import datasets
import pickle

In [29]:
# Maps each task's key to its human-readable display name.
# Insertion order is significant: it follows the ordering used in the paper,
# and iterating this dict determines the row order of any derived table.
task_2_name: Dict[str, str] = dict(
    # Operational outcomes
    guo_los='Long LOS',
    guo_readmit='30-Day Readmission',
    guo_icu='ICU Admission',
    # Anticipating lab test results
    lab_thrombocytopenia='Thrombocytopenia',
    lab_hyperkalemia='Hyperkalemia',
    lab_hypoglycemia='Hypoglycemia',
    lab_hyponatremia='Hyponatremia',
    lab_anemia='Anemia',
    # Assignment of new diagnoses
    new_hypertension='Hypertension',
    new_hyperlipidemia='Hyperlipidemia',
    new_pancan='Pancreatic Cancer',
    new_celiac='Celiac',
    new_lupus='Lupus',
    new_acutemi='Acute MI',
    # Anticipating chest x-ray findings
    chexpert='Chest X-Ray',
)

# Maps each task's key to the kind of label value it carries:
# 'boolean', 'multiclass' or 'multilabel'. The tasks are grouped by value type
# and the groups are merged in order, so insertion order is boolean tasks,
# then multiclass tasks, then the single multilabel task.
task_2_value_type: Dict[str, str] = {
    **dict.fromkeys(
        [
            'new_pancan',
            'new_celiac',
            'new_lupus',
            'new_acutemi',
            'new_hypertension',
            'new_hyperlipidemia',
            'guo_los',
            'guo_readmit',
            'guo_icu',
        ],
        'boolean',
    ),
    **dict.fromkeys(
        [
            'lab_thrombocytopenia',
            'lab_hyperkalemia',
            'lab_hypoglycemia',
            'lab_hyponatremia',
            'lab_anemia',
        ],
        'multiclass',
    ),
    **dict.fromkeys(['chexpert'], 'multilabel'),
}

In [30]:
# Load every parquet shard matching the glob into a single `datasets.Dataset`.
# NOTE(review): both paths below are absolute paths on one developer's machine;
# consider moving them to a configurable base directory so the notebook can be
# re-run elsewhere.
dataset = datasets.Dataset.from_parquet('/Users/mwornow/Downloads/ehrshot-meds-standard-stanford/data/*.parquet')

# Unpickle the ontology object. The file is opened in a `with` block so the
# handle is closed deterministically instead of being left to the garbage
# collector.
# NOTE(review): `pickle.load` can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
with open('/Users/mwornow/Desktop/ehrshot-benchmark/assets/ontology_standard.pkl', 'rb') as ontology_file:
    ontology = pickle.load(ontology_file)

In [31]:
# CSV holding the patient split assignment. The per-split statistics computed
# later in this notebook read its `split` and `omop_person_id` columns.
path_to_splits = '/Users/mwornow/Downloads/som-nero-nigam-starr.starr_omop_cdm5_confidential_filtered_2024_02_12_ehrshot_release_dua/person_id_map/merged.csv'
df_splits = pd.read_csv(filepath_or_buffer=path_to_splits)
# Last expression of the cell: display (n_rows, n_columns) as a quick sanity check.
df_splits.shape

(6732, 2)

# EHRSHOT

Compute per-task label statistics for each of the train / val / test splits, as well as across all patients.

In [32]:
def _summarize_labels(df_labels: pd.DataFrame, task: str, task_name: str) -> Dict[str, object]:
    """Build one row of label statistics for a single task.

    Args:
        df_labels: Label rows for the task. Must contain a `patient_id` column
            and an `is_positive_label` column.
        task: Task key (e.g. 'guo_los').
        task_name: Human-readable task name (e.g. 'Long LOS').

    Returns:
        A dict with the task identifiers, the number of distinct patients, the
        number of patients with at least one positive label, the number of
        labels, and the number of positive labels.
    """
    return {
        'task' : task,
        'task_name' : task_name,
        'n_patients' : df_labels['patient_id'].nunique(),
        # A patient counts as positive if any of their labels is positive.
        'n_positive_patients' : df_labels.groupby('patient_id')['is_positive_label'].max().sum(),
        'n_labels' : df_labels.shape[0],
        'n_positive_labels' : df_labels['is_positive_label'].sum(),
    }


# Columns produced by `_summarize_labels`, in order. Passing them explicitly to
# the DataFrame constructor keeps the derived-column arithmetic below from
# raising a KeyError when every task was skipped and a result list is empty.
_RESULT_COLUMNS: List[str] = [
    'task',
    'task_name',
    'n_patients',
    'n_positive_patients',
    'n_labels',
    'n_positive_labels',
]

# One list of per-task rows for each split, plus 'all' for the unsplit labels.
results = {
    'train' : [],
    'test' : [],
    'val' : [],
    'all' : []
}

# The patient IDs belonging to each split do not depend on the task, so they
# are computed once here rather than once per task inside the loop.
split_2_patient_ids = {
    split_name: df_splits[df_splits['split'] == split_name]['omop_person_id']
    for split_name in ['train', 'test', 'val']
}

for task, task_name in tqdm(task_2_name.items()):
    path_to_task_csv: str = f"../../assets/labels/{task}_labels.csv"
    if not os.path.exists(path_to_task_csv):
        print(f"Skipping {task_name}")
        continue
    try:
        df = pd.read_csv(path_to_task_csv)
        # Ensure all four value columns exist, filling absent ones with ''.
        for value_column in ['boolean_value', 'integer_value', 'categorical_value', 'float_value']:
            if value_column not in df.columns:
                df[value_column] = ''
        # NOTE(review): this writes the column-normalized frame back over the
        # input label file, so running this statistics cell mutates the label
        # CSVs on disk. Kept as-is because other code may rely on the
        # normalized files; confirm whether it is still needed.
        df.to_csv(path_to_task_csv, index=False)
        if task_2_value_type[task] == "boolean":
            df['is_positive_label'] = df["boolean_value"]
        elif task_2_value_type[task] == "multiclass":
            # Any class above 0 is treated as a positive label.
            df['is_positive_label'] = df["integer_value"] > 0
        else:
            # Value types other than boolean / multiclass are not summarized.
            print(f"Skipping {task_name}")
            continue
    except Exception as e:
        # Best-effort: report the failure and carry on with the remaining tasks.
        print(f"Skipping {task_name}")
        print(e)
        continue

    # Splits
    for split in ['train', 'test', 'val']:
        df_split = df[df['patient_id'].isin(split_2_patient_ids[split])]
        results[split].append(_summarize_labels(df_split, task, task_name))

    # All
    results['all'].append(_summarize_labels(df, task, task_name))

# Turn each list of rows into a DataFrame and add the derived columns.
for key in list(results):
    results[key] = pd.DataFrame(results[key], columns=_RESULT_COLUMNS)
    results[key]['n_negative_labels'] = results[key]['n_labels'] - results[key]['n_positive_labels']
    results[key]['n_negative_patients'] = results[key]['n_patients'] - results[key]['n_positive_patients']
    results[key]['label_prevalence'] = results[key]['n_positive_labels'] / results[key]['n_labels']

100%|██████████| 15/15 [00:03<00:00,  4.22it/s]

Skipping Chest X-Ray





In [38]:
# All splits
results['all']

Unnamed: 0,task,task_name,n_patients,n_positive_patients,n_labels,n_positive_labels,n_negative_labels,n_negative_patients,label_prevalence
0,guo_los,Long LOS,4659,2332,14671,4938,9733,2327,0.336582
1,guo_readmit,30-Day Readmission,4514,1301,15545,3760,11785,3213,0.241878
2,guo_icu,ICU Admission,4508,756,14151,971,13180,3752,0.068617
3,lab_thrombocytopenia,Thrombocytopenia,6076,2605,187666,61513,126153,3471,0.327779
4,lab_hyperkalemia,Hyperkalemia,5949,1310,210417,4977,205440,4639,0.023653
5,lab_hypoglycemia,Hypoglycemia,5994,1422,334495,5072,329423,4572,0.015163
6,lab_hyponatremia,Hyponatremia,5940,3735,222539,62883,159656,2205,0.282571
7,lab_anemia,Anemia,6098,4308,193204,132557,60647,1790,0.686099
8,new_hypertension,Hypertension,2451,479,6046,865,5181,1972,0.14307
9,new_hyperlipidemia,Hyperlipidemia,2855,519,7916,979,6937,2336,0.123674


In [39]:
# Train
results['train']

Unnamed: 0,task,task_name,n_patients,n_positive_patients,n_labels,n_positive_labels,n_negative_labels,n_negative_patients,label_prevalence
0,guo_los,Long LOS,1625,855,5543,1941,3602,770,0.350171
1,guo_readmit,30-Day Readmission,1588,486,5888,1556,4332,1102,0.264266
2,guo_icu,ICU Admission,1584,287,5380,360,5020,1297,0.066914
3,lab_thrombocytopenia,Thrombocytopenia,2090,922,72029,23363,48666,1168,0.324355
4,lab_hyperkalemia,Hyperkalemia,2047,469,80394,1933,78461,1578,0.024044
5,lab_hypoglycemia,Hypoglycemia,2063,528,127586,2005,125581,1535,0.015715
6,lab_hyponatremia,Hyponatremia,2044,1310,85212,24723,60489,734,0.290135
7,lab_anemia,Anemia,2097,1495,73830,51098,22732,602,0.692103
8,new_hypertension,Hypertension,831,162,2092,333,1759,669,0.159178
9,new_hyperlipidemia,Hyperlipidemia,991,183,3002,382,2620,808,0.127249


In [40]:
# Val
results['val']

Unnamed: 0,task,task_name,n_patients,n_positive_patients,n_labels,n_positive_labels,n_negative_labels,n_negative_patients,label_prevalence
0,guo_los,Long LOS,1520,732,4437,1440,2997,788,0.324544
1,guo_readmit,30-Day Readmission,1468,422,4693,1043,3650,1046,0.222246
2,guo_icu,ICU Admission,1462,234,4258,304,3954,1228,0.071395
3,lab_thrombocytopenia,Thrombocytopenia,1983,814,56951,18371,38580,1169,0.322576
4,lab_hyperkalemia,Hyperkalemia,1940,430,63377,1443,61934,1510,0.022769
5,lab_hypoglycemia,Hypoglycemia,1957,437,101462,1618,99844,1520,0.015947
6,lab_hyponatremia,Hyponatremia,1936,1186,67513,18284,49229,750,0.270822
7,lab_anemia,Anemia,1994,1394,58791,40015,18776,600,0.680631
8,new_hypertension,Hypertension,816,158,1950,275,1675,658,0.141026
9,new_hyperlipidemia,Hyperlipidemia,924,170,2476,274,2202,754,0.110662


In [41]:
# Test
results['test']

Unnamed: 0,task,task_name,n_patients,n_positive_patients,n_labels,n_positive_labels,n_negative_labels,n_negative_patients,label_prevalence
0,guo_los,Long LOS,1514,745,4691,1557,3134,769,0.331912
1,guo_readmit,30-Day Readmission,1458,393,4964,1161,3803,1065,0.233884
2,guo_icu,ICU Admission,1462,235,4513,307,4206,1227,0.068026
3,lab_thrombocytopenia,Thrombocytopenia,2003,869,58686,19779,38907,1134,0.337031
4,lab_hyperkalemia,Hyperkalemia,1962,411,66646,1601,65045,1551,0.024022
5,lab_hypoglycemia,Hypoglycemia,1974,457,105447,1449,103998,1517,0.013742
6,lab_hyponatremia,Hyponatremia,1960,1239,69814,19876,49938,721,0.284699
7,lab_anemia,Anemia,2007,1419,60583,41444,19139,588,0.684086
8,new_hypertension,Hypertension,804,159,2004,257,1747,645,0.128244
9,new_hyperlipidemia,Hyperlipidemia,940,166,2438,323,2115,774,0.132486
