In [36]:
import os
import pandas as pd

In [2]:
# Absolute path to the cohort split table (parquet); the next cell loads it into `df`.
dir_cohort = '/local-scratch/nigam/projects/jlemmon/transfer_learning/experiments/data/cohort/cohort_split.parquet'

In [24]:
# Load the cohort table and add `age_in_days`: the whole-day difference
# between `admit_date` and `birth_datetime` on each row.
df = (
    pd.read_parquet(dir_cohort)
    .assign(age_in_days=lambda d: (d['admit_date'] - d['birth_datetime']).dt.days)
)

# Each row must be a unique patient: the row count has to equal the number of
# distinct `person_id` values.
assert len(df) == df['person_id'].nunique()

In [25]:
# Report the cohort size together with the earliest and latest admission timestamps.
print(
    f"n = {len(df)}; "
    f"min admission:{df['admit_date'].min()}; "
    f"max admission:{df['admit_date'].max()}"
)

n = 316197; min admission:2008-01-01 06:01:00; max admission:2022-08-28 16:37:00


In [33]:
# Split the cohort by age at admission using the same two query expressions:
# adults are 18+ years; the pediatric group is under 18 years and at least
# 28 days old.
adults, peds = (
    df.query(age_filter)
    for age_filter in (
        "age_in_years>=18",
        "age_in_years<18 and age_in_days>=28",
    )
)

In [34]:
# Rows per `fold_id` value, adults first and then peds (displayed as a 2-tuple).
tuple(group['fold_id'].value_counts() for group in (adults, peds))

(0       185370
 test     38035
 val      21406
 Name: fold_id, dtype: int64,
 0       18621
 test     4959
 val      2851
 Name: fold_id, dtype: int64)

#### CLMBR pretraining 

In [44]:
# Directory holding the pretraining patient-ID list files (train/val, "ad" and
# "ped" variants) that are read with `read_file` in the next cell.
dir_pretrain_cohort = "/local-scratch/nigam/projects/jlemmon/transfer_learning/experiments/data/pretrain_cohort"

def read_file(path):
    """Return the non-empty lines of the text file at `path`, in file order.

    The content is split on '\\n' only. Empty lines are dropped; no other
    whitespace is stripped from the remaining lines.
    """
    with open(path, "r") as handle:
        lines = handle.read().split('\n')
    # Empty strings are falsy, so filter(None, ...) drops exactly the blank lines.
    return list(filter(None, lines))

In [48]:
# Load the four pretraining patient-ID lists (adult / pediatric x train / val).
# NOTE(review): the recorded output shows identical train and val lengths
# within each group -- confirm the val files are not copies of the train lists.
adults_train, adults_val, peds_train, peds_val = (
    read_file(os.path.join(dir_pretrain_cohort, id_file))
    for id_file in (
        'train_patient_ids_ad.txt',
        'val_patient_ids_ad.txt',
        'train_patient_ids_ped.txt',
        'val_patient_ids_ped.txt',
    )
)

# Number of IDs in each list, in the order unpacked above.
list(map(len, (adults_train, adults_val, peds_train, peds_val)))

[185370, 185370, 20858, 20858]

In [46]:
# Size of the adult pretraining training list.
# The original cell evaluated `len(x)`, but no visible cell binds `x`
# (comprehension variables do not leak), so it raised NameError on a fresh
# kernel. `adults_train` matches the recorded value of 185370.
len(adults_train)

185370

#### Final task N for each split

In [50]:
# Switch to the `cohort_split_no_nb` cohort file. This rebinds both
# `dir_cohort` and `df`, so every cell below works on this table.
dir_cohort = '/local-scratch/nigam/projects/jlemmon/transfer_learning/experiments/data/cohort/cohort_split_no_nb.parquet'
df = (
    pd.read_parquet(dir_cohort)
    .assign(age_in_days=lambda d: (d['admit_date'] - d['birth_datetime']).dt.days)
)

# Each row must be a unique patient: the row count has to equal the number of
# distinct `person_id` values.
assert len(df) == df['person_id'].nunique()

In [None]:
# Age-based group definitions, written as DataFrame.query expressions.
filters = {
    'adults':"age_in_years>=18",
    'peds':"age_in_years<18 and age_in_days>=28"
}

# Tasks to tabulate; each one has a `<task>_fold_id` column in `df`.
tasks = [
    'hospital_mortality', 'LOS_7', 'readmission_30', 'sepsis', 
    'aki_lab_aki3_label', 'hyperkalemia_lab_severe_label', 
    'hypoglycemia_lab_severe_label','hyponatremia_lab_severe_label',
    'anemia_lab_severe_label','neutropenia_lab_severe_label',
    'thrombocytopenia_lab_severe_label'
]

# Display names for task columns.
titles = {
    'hospital_mortality':'Hospital Mortality', 
    'sepsis':'Sepsis', 
    'LOS_7':'Long LOS', 
    'readmission_30':'30-day Readmission', 
    'aki_lab_aki1_label':'Acute Kidney Injury',
    'aki_lab_aki3_label':'Acute Kidney Injury',
    'hyperkalemia_lab_mild_label':'Hyperkalemia',
    'hyperkalemia_lab_severe_label':'Hyperkalemia',
    'hypoglycemia_lab_mild_label': 'Hypoglycemia',
    'hypoglycemia_lab_severe_label':'Hypoglycemia',
    'hyponatremia_lab_mild_label':'Hyponatremia',
    'hyponatremia_lab_severe_label':'Hyponatremia',
    'neutropenia_lab_mild_label':'Neutropenia',
    'neutropenia_lab_severe_label':'Neutropenia',
    'anemia_lab_mild_label':'Anemia',
    'anemia_lab_severe_label':'Anemia',
    'thrombocytopenia_lab_mild_label':'Thrombocytopenia',
    'thrombocytopenia_lab_severe_label':'Thrombocytopenia'
}

tasks_renamed = [titles[x] for x in tasks]

# Raw fold labels -> split names used in the exported tables. Any other fold
# value (e.g. 'ignore') is dropped by the query below.
split_labels = {'0':'Training','val':'Validation','test':'Test'}

# Collect one small frame per (group, task) and concatenate once at the end
# instead of growing `table` inside the loop.
frames = []
for group,f in filters.items():
    group_df = df.query(f)  # filter once per group rather than once per task
    for task in tasks:
        fold_id = f"{task}_fold_id"
        frames.append(
            group_df[fold_id]
            .value_counts()
            # Name the index and the counts explicitly: the default names
            # produced by value_counts()/reset_index() differ between pandas
            # 1.x and 2.x, which broke the former rename(columns=...) step.
            .rename_axis('Split')
            .rename('N_admissions')
            .reset_index()
            .assign(Task=titles[task],Group=group)
            .replace({'Split':split_labels})
            .query("Split==['Training','Validation','Test']")
        )
table = pd.concat(frames)

# One CSV per (group, split) holding the task name and its count.
os.makedirs("tables", exist_ok=True)  # to_csv does not create missing directories
for group in ['adults','peds']:
    for split in ['Training','Validation','Test']:
        table.query("Group==@group and Split==@split")[['Task','N_admissions']].to_csv(f"tables/splits_{group}_{split}.csv",index=False)

In [120]:
# Age-based group definitions, written as DataFrame.query expressions.
filters = {
    'adults':"age_in_years>=18",
    'peds':"age_in_years<18 and age_in_days>=28"
}

# Tasks to report; each one has a `<task>_fold_id` column in `df`.
tasks = [
    'hospital_mortality', 'LOS_7', 'readmission_30', 'sepsis', 
    'aki_lab_aki3_label', 'hyperkalemia_lab_severe_label', 
    'hypoglycemia_lab_severe_label','hyponatremia_lab_severe_label',
    'anemia_lab_severe_label','neutropenia_lab_severe_label',
    'thrombocytopenia_lab_severe_label'
]

# Display names for task columns.
titles = {
    'hospital_mortality':'Hospital Mortality', 
    'sepsis':'Sepsis', 
    'LOS_7':'Long LOS', 
    'readmission_30':'30-day Readmission', 
    'aki_lab_aki1_label':'Acute Kidney Injury',
    'aki_lab_aki3_label':'Acute Kidney Injury',
    'hyperkalemia_lab_mild_label':'Hyperkalemia',
    'hyperkalemia_lab_severe_label':'Hyperkalemia',
    'hypoglycemia_lab_mild_label': 'Hypoglycemia',
    'hypoglycemia_lab_severe_label':'Hypoglycemia',
    'hyponatremia_lab_mild_label':'Hyponatremia',
    'hyponatremia_lab_severe_label':'Hyponatremia',
    'neutropenia_lab_mild_label':'Neutropenia',
    'neutropenia_lab_severe_label':'Neutropenia',
    'anemia_lab_mild_label':'Anemia',
    'anemia_lab_severe_label':'Anemia',
    'thrombocytopenia_lab_mild_label':'Thrombocytopenia',
    'thrombocytopenia_lab_severe_label':'Thrombocytopenia'
}

tasks_renamed = [titles[x] for x in tasks]

# NOTE(review): this rebinds `table` to an empty frame and nothing in this
# cell fills it -- confirm whether the reset is intended before removing it.
table = pd.DataFrame()
for group,f in filters.items():
    group_df = df.query(f)  # filter once per group rather than once per task

    # Baseline: rows of this group whose hospital-mortality fold is 'ignore'.
    n_death_discharge = (group_df['hospital_mortality_fold_id']=='ignore').sum()

    for task in tasks:
        fold_id = f"{task}_fold_id"
        n_excluded = (group_df[fold_id]=='ignore').sum()
        # Rows ignored for this task beyond the hospital-mortality baseline.
        # (The previous f-string ended in a stray ",}" and was a SyntaxError.)
        print(f"{task}:{n_excluded-n_death_discharge}")

hospital_mortality:0
LOS_7:0
readmission_30:6382
sepsis:29691
aki_lab_aki3_label:3687
hyperkalemia_lab_severe_label:706
hypoglycemia_lab_severe_label:671
hyponatremia_lab_severe_label:2317
anemia_lab_severe_label:3036
neutropenia_lab_severe_label:445
thrombocytopenia_lab_severe_label:2920
hospital_mortality:0
LOS_7:0
readmission_30:255
sepsis:2231
aki_lab_aki3_label:174
hyperkalemia_lab_severe_label:108
hypoglycemia_lab_severe_label:127
hyponatremia_lab_severe_label:41
anemia_lab_severe_label:486
neutropenia_lab_severe_label:27
thrombocytopenia_lab_severe_label:221
