In [1]:
import pandas as pd
import numpy as np

import os
import argparse
import pickle
import joblib
import pdb
import warnings

In [9]:
# Root of the experiment data directory (the cohort parquet lives under it).
# Defaults to the original shared-scratch location; set the COHORT_PATH
# environment variable to point the notebook at another copy without editing
# this cell.
cohort_path = os.environ.get(
    "COHORT_PATH",
    "/local-scratch/nigam/projects/jlemmon/transfer_learning/experiments/data/",
)

# Binary outcome columns summarised below. The prevalence cells read each task
# column together with its companion f"{task}_fold_id" split column.
tasks = [
    # Non-lab outcomes
    'hospital_mortality',
    'sepsis',
    'LOS_7',
    'readmission_30',
    # Lab-derived labels (note: neutropenia has no "abnormal" variant)
    'hyperkalemia_lab_mild_label',
    'hyperkalemia_lab_moderate_label',
    'hyperkalemia_lab_severe_label',
    'hyperkalemia_lab_abnormal_label',
    'hypoglycemia_lab_mild_label',
    'hypoglycemia_lab_moderate_label',
    'hypoglycemia_lab_severe_label',
    'hypoglycemia_lab_abnormal_label',
    'neutropenia_lab_mild_label',
    'neutropenia_lab_moderate_label',
    'neutropenia_lab_severe_label',
    'hyponatremia_lab_mild_label',
    'hyponatremia_lab_moderate_label',
    'hyponatremia_lab_severe_label',
    'hyponatremia_lab_abnormal_label',
    'aki_lab_aki1_label',
    'aki_lab_aki2_label',
    'aki_lab_aki3_label',
    'aki_lab_abnormal_label',
    'anemia_lab_mild_label',
    'anemia_lab_moderate_label',
    'anemia_lab_severe_label',
    'anemia_lab_abnormal_label',
    'thrombocytopenia_lab_mild_label',
    'thrombocytopenia_lab_moderate_label',
    'thrombocytopenia_lab_severe_label',
    'thrombocytopenia_lab_abnormal_label',
]

In [3]:
def read_file(filename, columns=None, **kwargs):
    """Load a tabular file into a DataFrame, choosing the reader by extension.

    Parameters
    ----------
    filename : str or os.PathLike
        Path to a ``.parquet`` or ``.csv`` file. The extension is matched
        case-insensitively.
    columns : list of str, optional
        Subset of columns to load. Forwarded as ``columns=`` to
        ``pd.read_parquet`` and as ``usecols=`` to ``pd.read_csv``.
    **kwargs
        Extra keyword arguments forwarded unchanged to the pandas reader
        (e.g. ``engine='pyarrow'``).

    Returns
    -------
    pd.DataFrame

    Raises
    ------
    ValueError
        If the extension is neither ``.parquet`` nor ``.csv``. Previously the
        function silently returned ``None`` in that case.
    """
    # Echo the path so the cell output records which file was loaded.
    print(filename)
    load_extension = os.path.splitext(filename)[-1].lower()
    if load_extension == ".parquet":
        return pd.read_parquet(filename, columns=columns, **kwargs)
    if load_extension == ".csv":
        return pd.read_csv(filename, usecols=columns, **kwargs)
    raise ValueError(
        f"Unsupported file extension {load_extension!r} for {filename!r}; "
        "expected '.parquet' or '.csv'."
    )

In [4]:
# Load the split cohort table. Its `fold_id` and per-task `*_fold_id`
# columns drive the summaries in the cells below.
cohort_file = os.path.join(cohort_path, "cohort/cohort_split_no_nb.parquet")
cohort = read_file(cohort_file, engine='pyarrow')

/local-scratch/nigam/projects/jlemmon/transfer_learning/experiments/data/cohort/cohort_split_no_nb.parquet


In [5]:
# Split rows by the `adult_at_admission` flag (1 -> adult, 0 -> pediatric).
ad_cohort, ped_cohort = (
    cohort.query("adult_at_admission==1"),
    cohort.query("adult_at_admission==0"),
)

In [13]:
# Pediatric rows admitted in 2018 or later, counted per `fold_id`.
# count() tallies non-null `person_id` rows (not distinct patients); use
# .nunique() instead if unique patients are what's wanted.
# NOTE(review): the saved output under this cell lists raw rows rather than
# per-fold counts, so it is stale relative to this code -- re-run it.
recent_ped = ped_cohort.query("admission_year>=2018")
recent_ped.loc[:, ["person_id", "fold_id"]].groupby(["fold_id"]).count()

Unnamed: 0,person_id,fold_id
51,29939372,0
68,29941136,val
71,29941317,test
78,29941910,0
504,30044852,val
...,...,...
232284,85698351,test
232285,85699500,test
232286,85700562,test
232287,86280289,test


In [15]:
# Adult rows admitted in 2019 or earlier, with their fold assignment.
adult_to_2019 = ad_cohort.query("admission_year<=2019")
adult_to_2019.loc[:, ["person_id", "fold_id"]]

Unnamed: 0,person_id,fold_id
0,29936887,0
1,29936888,0
2,29936900,0
3,29936906,0
4,29936914,0
...,...,...
274191,43705036,0
274196,43705412,0
274216,43742869,0
274219,43743212,0


In [10]:
# Label prevalence per task in the adult cohort, split by each task's own
# fold column (fold "0" is reported as "train"). `prevalence` is a percentage.
# NOTE(review): `total_pats` counts rows in the fold, not distinct person_ids.
prev_records = []
for task in tasks:
    # Same column subset for every fold of this task, so select it once.
    task_df = ad_cohort[["person_id", f"{task}", f"{task}_fold_id"]]
    for fold in ["0", "val", "test"]:
        c_df = task_df.query(f"{task}_fold_id==@fold")
        # Positive count: rows in this fold whose label equals 1.
        s = c_df.query(f"{task}==1")[task].sum()
        prev_records.append({
            "task": task,
            "fold": "train" if fold == "0" else fold,
            "prevalence": s / len(c_df) * 100,
            "total_pos": s,
            "total_pats": len(c_df),
        })
# Build the frame once; pd.concat inside the loop copied all prior rows on
# every iteration (quadratic).
prev_df = pd.DataFrame(prev_records)
prev_df
        

                                   task   fold  prevalence  total_pos  \
0                    hospital_mortality  train    2.135311       3907   
1                    hospital_mortality    val    2.169691        458   
2                    hospital_mortality   test    2.009877        757   
3                                sepsis  train    2.927282       4734   
4                                sepsis    val    2.302214        417   
..                                  ...    ...         ...        ...   
88    thrombocytopenia_lab_severe_label    val    2.186002        456   
89    thrombocytopenia_lab_severe_label   test    2.219475        826   
90  thrombocytopenia_lab_abnormal_label  train   17.123963      27423   
91  thrombocytopenia_lab_abnormal_label    val   15.033719       2742   
92  thrombocytopenia_lab_abnormal_label   test   14.883495       4912   

    total_pats  
0       182971  
1        21109  
2        37664  
3       161720  
4        18113  
..         ...  
88  

In [41]:
# Label prevalence per task, split by age group (`adult_at_admission`:
# 0 -> "pediatric", 1 -> "adult") and by each task's own fold column
# (fold "0" is reported as "train"). `prevalence` is a percentage.
# NOTE(review): the saved adult rows of this table disagree with `prev_df`
# (e.g. hospital_mortality train: 173506 here vs 182971 there), although both
# select adult_at_admission == 1. The outputs appear to come from different
# cohort loads; Restart & Run All before trusting either table.
age_prev_records = []
for task in tasks:
    # Same column subset for every (age group, fold) pair, so select it once.
    task_df = cohort[["person_id", f"{task}", f"{task}_fold_id", "adult_at_admission"]]
    for aaa in [0, 1]:  # value of `adult_at_admission`
        for fold in ["0", "val", "test"]:
            c_df = task_df.query(f"{task}_fold_id==@fold and adult_at_admission==@aaa")
            # Positive count: rows in this (age group, fold) whose label equals 1.
            s = c_df.query(f"{task}==1")[task].sum()
            age_prev_records.append({
                "task": task,
                "age_group": "pediatric" if aaa == 0 else "adult",
                "fold": "train" if fold == "0" else fold,
                "prevalence": s / len(c_df) * 100,
                "total_pos": s,
                "total_pats": len(c_df),
            })
# Build the frame once instead of concatenating single-row frames in the loop.
age_prev_df = pd.DataFrame(age_prev_records)
age_prev_df
        

                  task  age_group   fold  prevalence  total_pos  total_pats
0   hospital_mortality  pediatric  train    0.794933        379       47677
1   hospital_mortality  pediatric    val    0.677442         57        8414
2   hospital_mortality  pediatric   test    0.652540         97       14865
3   hospital_mortality      adult  train    2.151511       3733      173506
4   hospital_mortality      adult    val    2.067116        632       30574
5   hospital_mortality      adult   test    2.009877        757       37664
6               sepsis  pediatric  train    1.564386        700       44746
7               sepsis  pediatric    val    1.395408        110        7883
8               sepsis  pediatric   test    1.402962        198       14113
9               sepsis      adult  train    2.839274       4340      152856
10              sepsis      adult    val    3.006265        811       26977
11              sepsis      adult   test    2.147734        692       32220
12          

In [44]:
# Column names of the loaded cohort (DataFrame.keys() returns the columns).
cohort.keys()

Index(['person_id', 'admit_date', 'discharge_date', 'admit_date_midnight',
       'discharge_date_midnight', 'hospital_mortality', 'death_date',
       'month_mortality', 'LOS_days', 'LOS_7', 'readmission_30',
       'readmission_window', 'icu_admission', 'icu_start_datetime',
       'age_in_years', 'age_group', 'aki_base_creatinine',
       'aki_max_creatinine', 'aki1_creatinine', 'aki1_creatinine_time',
       'aki1_label', 'aki2_creatinine', 'aki2_creatinine_time', 'aki2_label',
       'hg_min_glucose', 'hg_glucose', 'hg_glucose_time', 'hg_label',
       'np_min_neutrophils', 'np_500_neutrophils', 'np_500_neutrophils_time',
       'np_500_label', 'np_1000_neutrophils', 'np_1000_neutrophils_time',
       'np_1000_label', 'race_eth', 'gender_concept_name', 'race_eth_raw',
       'race_eth_gender', 'race_eth_age_group', 'race_eth_gender_age_group',
       'race_eth_raw_gender', 'race_eth_raw_age_group',
       'race_eth_raw_gender_age_group', 'prediction_id', 'fold_id',
       'adult_a

In [49]:
# Rows per `gender_concept_name` value (non-null person_id rows, not distinct patients).
cohort.groupby(['gender_concept_name'])[['person_id']].count()

Unnamed: 0_level_0,person_id
gender_concept_name,Unnamed: 1_level_1
FEMALE,173885
MALE,142295
No matching concept,17


In [50]:
# Rows per `age_group` bucket (non-null person_id rows, not distinct patients).
cohort.groupby(['age_group'])[['person_id']].count()

Unnamed: 0_level_0,person_id
age_group,Unnamed: 1_level_1
<18,71386
[18-30),28054
[30-45),56152
[45-55),32216
[55-65),42577
[65-75),44639
[75-91),41173


In [51]:
# Rows per `race_eth_raw` value (non-null person_id rows, not distinct patients).
cohort.groupby(['race_eth_raw'])[['person_id']].count()

Unnamed: 0_level_0,person_id
race_eth_raw,Unnamed: 1_level_1
American Indian or Alaska Native,768
Asian,56660
Black or African American,11085
Hispanic or Latino,66849
Native Hawaiian or Other Pacific Islander,3843
Other,28536
White,148456


In [None]:
# TODO: Check the list of newborn IDs and cross-check it against the CLMBR training IDs.