## TODO:
1. write each label set to temp table, and then combine into target table

#### Labeling 

Here, we will call a pre-defined set of labeling functions to assign binary labels to a pre-defined cohort. A cohort is a table in which rows correspond to unique combinations of `person_id`, `window_start_field` (e.g., admit_date), and `window_end_field` (e.g., discharge_date).

In addition to the labels, each labeling function retrieves relevant supporting information, such as the death date for in-hospital mortality and the abnormal observation for lab-based labeling functions. Finally, the labeler assigns a `row_id` (e.g., prediction_id) to each row in the final table.
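To make the cohort and label concepts concrete, here is a minimal pandas sketch on toy data. It is not the `Labeler` implementation (which queries BigQuery); the toy death records, the helper `label_mortality`, the choice to keep `death_date` only for positive rows, and the running-integer `prediction_id` are all assumptions made for illustration.

```python
import pandas as pd

# Toy cohort: one row per (person_id, admit_date, discharge_date).
cohort = pd.DataFrame({
    "person_id": [1, 2, 3],
    "admit_date": pd.to_datetime(["2020-01-01", "2020-02-01", "2020-03-01"]),
    "discharge_date": pd.to_datetime(["2020-01-05", "2020-02-03", "2020-03-10"]),
})

# Hypothetical death records (person 2 has none).
deaths = pd.DataFrame({
    "person_id": [1, 3],
    "death_date": pd.to_datetime(["2020-01-04", "2021-06-01"]),
})


def label_mortality(cohort, deaths, start="admit_date", end="discharge_date"):
    """Assign 1 if death_date falls inside [start, end], else 0."""
    out = cohort.merge(deaths, on="person_id", how="left")
    # Comparisons against NaT evaluate to False, so missing deaths get 0.
    in_window = out["death_date"].between(out[start], out[end])
    out["mortality_label"] = in_window.astype(int)
    # Assumption: keep the death date only where it explains a positive label.
    out["death_date"] = out["death_date"].where(in_window)
    # Assumption: a simple running integer serves as the row identifier.
    out.insert(0, "prediction_id", range(len(out)))
    return out


labeled = label_mortality(cohort, deaths)
```

Only person 1 dies inside the admission window here, so that is the only row that receives a positive label and a non-null `death_date`.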

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
from datasets.labelers import Labeler

#### Instantiate Labeler

In [3]:
labeler = Labeler()



#### Configure Labeler
- `google_application_credentials`: location of the JSON file that stores the gcloud auth credentials. Default is "~/.config/gcloud/application_default_credentials.json", which is where the command `gcloud auth application-default login` writes them
- `gcloud_project`: gcloud project [default "som-nero-nigam-starr"]
- `dataset_project`: project in which OMOP CDM dataset is stored [default "som-nero-nigam-starr"]
- `rs_dataset_project`: project in which cohort table is stored and to which the label table will be written [default "som-nero-nigam-starr"]
- `dataset`: name of the OMOP CDM dataset [default "starr_omop_cdm5_deid_20210723"]
- `rs_dataset`: name of the dataset in which cohort table is stored and to which the label table will be written
- `cohort_name`: name of the cohort
- `target_table_name`: name of the label table to be created
- `row_id`: name of the identifier column that is added at the end to uniquely identify each row in the final table [default "prediction_id"]
- `window_start_field`: the field in the cohort table that specifies the start of the time window [default "admit_date"]
- `window_end_field`: the field in the cohort table that specifies the end of the time window [default "discharge_date"]
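Before configuring the labeler, it can help to confirm that the credentials file referenced by `google_application_credentials` actually exists. The following is a small standalone sketch using only the standard library; the helper names `credentials_path` and `credentials_available` are hypothetical and not part of the `Labeler` API.

```python
import os

DEFAULT_CREDENTIALS = "~/.config/gcloud/application_default_credentials.json"


def credentials_path(path=DEFAULT_CREDENTIALS):
    """Expand '~' and return the absolute path to the credentials file."""
    return os.path.abspath(os.path.expanduser(path))


def credentials_available(path=DEFAULT_CREDENTIALS):
    """Return True if the credentials file exists on disk."""
    return os.path.isfile(credentials_path(path))


if not credentials_available():
    print("No credentials found; run `gcloud auth application-default login`.")
```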


In [4]:
labeler.configure(
    rs_dataset="lguo_explore",
    cohort_name="test_refactor_admissions_rollup_filtered",
    target_table_name="test_refactor_admissions_rollup_filtered_labeled",
)

In [5]:
labeler.config

{'google_application_credentials': '/home/guolin1/.config/gcloud/application_default_credentials.json',
 'gcloud_project': 'som-nero-nigam-starr',
 'dataset_project': 'som-nero-nigam-starr',
 'rs_dataset_project': 'som-nero-nigam-starr',
 'dataset': 'starr_omop_cdm5_deid_20210723',
 'rs_dataset': 'lguo_explore',
 'cohort_name': 'test_refactor_admissions_rollup_filtered',
 'target_table_name': 'test_refactor_admissions_rollup_filtered_labeled',
 'row_id': 'prediction_id',
 'window_start_field': 'admit_date',
 'window_end_field': 'discharge_date',
 'temp_dataset': 'temp'}
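The project, dataset, and table settings combine into fully qualified BigQuery table names of the form `project.dataset.table`. The sketch below shows that composition for the configuration above; the helper `table_id` is hypothetical, and the assumption that the cohort table is addressed the same way as the label table is mine rather than something the notebook states.

```python
def table_id(project, dataset, table):
    """Compose a fully qualified BigQuery table name."""
    return f"{project}.{dataset}.{table}"


# Subset of the configuration printed above.
config = {
    "rs_dataset_project": "som-nero-nigam-starr",
    "rs_dataset": "lguo_explore",
    "cohort_name": "test_refactor_admissions_rollup_filtered",
    "target_table_name": "test_refactor_admissions_rollup_filtered_labeled",
}

# Table the labeler reads from (assumed naming).
cohort_table = table_id(
    config["rs_dataset_project"], config["rs_dataset"], config["cohort_name"]
)
# Table the labeler writes to.
target_table = table_id(
    config["rs_dataset_project"], config["rs_dataset"], config["target_table_name"]
)
```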

#### Take a look at available labelers

In [6]:
labeler.list_queries()

{'age': 'Age group labels according to 1)pediatric age group and 2)intervals',
 'sex': 'OMOP standard concepts for sex',
 'race': 'OMOP standard concepts for race',
 'mortality': '1 if death occured within the specified time window, 0 otherwise.',
 'los_7': '1 if length of the specified time window is at least 7 days, 0 otherwise',
 'icu_admission': '1 if admitted to ICU during specified time window, 0 otherwise',
 'readmission_30': '1 if readmission occurred within 30 days from the end of the specified time window, 0 otherwise',
 'hyperkalemia_lab': 'lab-based definition for hyperkalemia using blood potassium concentration (mmol/L). Thresholds: mild(>5.5),moderate(>6),severe(>7), and abnormal range.',
 'hypoglycemia_lab': 'lab-based definition for hypoglycemia using blood glucose concentration (mmol/L). Thresholds: mild(<3), moderate(<3.5), severe(<=3.9), and abnormal range.',
 'neutropenia_lab': 'lab-based definition for neutropenia based on neutrophils count (thousands/uL). Threshol

#### Using the labeling functions to obtain labels for each patient in the cohort
- By default, `create_label_table()` runs all labeling functions. Pass `labeler_ids` to run only selected labelers, or `exclude_labeler_ids` to skip some.
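The selection semantics can be sketched in plain Python. This is an illustration of how `labeler_ids` and `exclude_labeler_ids` might interact, not the library's actual code; the helper `select_labelers` and its error handling are assumptions.

```python
def select_labelers(available, labeler_ids=None, exclude_labeler_ids=None):
    """Return the labeler ids to run, preserving the order of `available`
    (or of `labeler_ids` when it is given)."""
    selected = list(available) if labeler_ids is None else list(labeler_ids)
    excluded = list(exclude_labeler_ids or [])
    # Assumption: unknown ids are treated as an error rather than ignored.
    unknown = [x for x in selected + excluded if x not in available]
    if unknown:
        raise ValueError(f"Unknown labeler ids: {unknown}")
    return [x for x in selected if x not in excluded]


# A few of the ids returned by labeler.list_queries() above.
available = ["age", "sex", "race", "mortality", "los_7",
             "icu_admission", "readmission_30"]

everything_but_readmission = select_labelers(
    available, exclude_labeler_ids=["readmission_30"]
)
only_two = select_labelers(available, labeler_ids=["mortality", "los_7"])
```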

In [7]:
# obtain all labels except for readmission
labeler.create_label_table(exclude_labeler_ids=['readmission_30'])

df = pd.read_gbq(
    "select * from `som-nero-nigam-starr.lguo_explore.test_refactor_admissions_rollup_filtered_labeled`",
    use_bqstorage_api=True
)

df.head(5)



| | person_id | admit_date | discharge_date | age_days | pediatric_age_group | age_group | sex | race | mortality_label | death_date | ... | anemia_dx_label | anemia_dx_start_datetime | hyperkalemia_dx_label | hyperkalemia_dx_start_datetime | hyponatremia_dx_label | hyponatremia_dx_start_datetime | thrombocytopenia_dx_label | thrombocytopenia_dx_start_datetime | neutropenia_dx_label | neutropenia_dx_start_datetime |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 30088240 | 2016-12-19 23:15:00 | 2016-12-22 19:05:00 | 1898 | middle childhood | [0,18) | MALE | Other | 0 | NaT | ... | 0 | NaT | 0 | NaT | 0 | NaT | 0 | NaT | 0 | NaT |
| 1 | 42129119 | 2019-10-16 23:16:00 | 2019-10-18 14:45:00 | 399 | toddler | [0,18) | MALE | Other | 0 | NaT | ... | 0 | NaT | 0 | NaT | 0 | NaT | 0 | NaT | 0 | NaT |
| 2 | 30196618 | 2015-02-26 19:00:00 | 2015-02-28 16:35:00 | 1 | term neonatal | [0,18) | FEMALE | Other | 0 | NaT | ... | 0 | NaT | 0 | NaT | 0 | NaT | 0 | NaT | 0 | NaT |
| 3 | 32177649 | 2019-01-09 15:08:00 | 2019-01-12 16:30:00 | 26212 | non-pediatric | [70,80) | MALE | White | 0 | NaT | ... | 0 | NaT | 0 | NaT | 0 | NaT | 0 | NaT | 0 | NaT |
| 4 | 44748669 | 2020-06-01 13:09:00 | 2020-06-10 15:43:00 | 0 | term neonatal | [0,18) | FEMALE | Other | 0 | NaT | ... | 0 | NaT | 0 | NaT | 0 | NaT | 0 | NaT | 0 | NaT |


In [8]:
# write table to disk
# df.to_parquet("path_to_parquet")

#### Let's take a look at outcome prevalence
- Note that labelers often return additional information alongside the label. For example, the mortality labeler returns the death date, and lab-based labelers return the observed abnormal lab measurement value, the time of the observation, and the min/max of the measurement.
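One way to separate the binary labels from this supporting information is to split the columns by name. The sketch below uses a toy frame with a few column names taken from the label table above; the values are made up, and the `_label` suffix rule is an assumption based on the columns shown.

```python
import pandas as pd

# Toy frame mimicking a few columns of the label table (values are made up).
df_toy = pd.DataFrame({
    "person_id": [1, 2, 3],
    "mortality_label": [0, 1, 0],
    "death_date": pd.to_datetime([None, "2020-02-02", None]),
    "anemia_dx_label": [1, 0, 0],
    "anemia_dx_start_datetime": pd.to_datetime(["2020-01-02 08:00", None, None]),
})

# Binary outcome columns end in "_label"; the rest is supporting information.
label_cols = [c for c in df_toy.columns if c.endswith("_label")]
aux_cols = [c for c in df_toy.columns if c not in label_cols and c != "person_id"]

# Inspect the supporting information for positive cases only.
positives = df_toy.loc[df_toy["mortality_label"] == 1, ["person_id", "death_date"]]
```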

In [9]:
outcomes = [
    col for col in df.columns 
    if 'label' in col
]

table = (
    df[outcomes]
    .sum()
    .reset_index()
    .rename(columns={'index':'Outcome',0:'No. Positive'})
)
table = table.assign(**{
    'Perc. Positive': (table['No. Positive']/df.shape[0]*100).round(1)
})

In [10]:
table

| | Outcome | No. Positive | Perc. Positive |
|---|---|---|---|
| 0 | mortality_label | 5434 | 1.9 |
| 1 | los_7_label | 52143 | 18.2 |
| 2 | icu_admission_label | 44716 | 15.6 |
| 3 | hyperkalemia_lab_mild_label | 17973 | 6.3 |
| 4 | hyperkalemia_lab_moderate_label | 10404 | 3.6 |
| 5 | hyperkalemia_lab_severe_label | 3987 | 1.4 |
| 6 | hyperkalemia_lab_abnormal_range_label | 16583 | 5.8 |
| 7 | hypoglycemia_lab_mild_label | 32475 | 11.4 |
| 8 | hypoglycemia_lab_moderate_label | 23683 | 8.3 |
| 9 | hypoglycemia_lab_severe_label | 16287 | 5.7 |
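Because the labels are binary, the same prevalence table can also be computed from the column sum and mean in one pass. The following is an alternative sketch on toy data, not the code used to produce the table above.

```python
import pandas as pd

# Toy binary labels (values are made up).
toy = pd.DataFrame({
    "mortality_label": [0, 1, 0, 0],
    "los_7_label": [1, 1, 0, 0],
})

# The sum of a 0/1 column is the positive count; its mean is the prevalence.
prevalence = (
    toy.agg(["sum", "mean"]).T
    .rename(columns={"sum": "No. Positive", "mean": "Perc. Positive"})
)
prevalence["No. Positive"] = prevalence["No. Positive"].astype(int)
prevalence["Perc. Positive"] = (prevalence["Perc. Positive"] * 100).round(1)
```

One difference to keep in mind: `mean` skips missing values, so if a label column contains nulls its denominator is the number of non-null rows, whereas dividing by `df.shape[0]` always uses the full cohort size.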


#### Let's take a look at how often each lab component was measured as well as the distributions

In [11]:
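# Lab component columns hold the min/max of the measurement for each row.
components = [x for x in df.columns if 'min' in x or 'max' in x]

# For each component, print the percentage of rows with a measurement,
# followed by the median and the 2.5th to 97.5th percentile range.
for comp in components:
    print(f"\
    {comp.split('_')[-1][:9]}: \t\
    {((~df[comp].isnull()).sum()/df.shape[0]*100).round(1)}%; \t\
    Median = {round(df[comp].median(),1)} \
    [{round(df[comp].quantile(0.025),1)} - {round(df[comp].quantile(0.975),1)}]\
    ")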

    potassium: 	    67.6%; 	    Median = 4.3     [3.5 - 6.8]    
    glucose: 	    74.6%; 	    Median = 5.2     [2.1 - 8.6]    
    neutrophi: 	    69.3%; 	    Median = 5.8     [0.8 - 15.0]    
    sodium: 	    67.6%; 	    Median = 136.0     [124.0 - 142.0]    
    creatinin: 	    68.7%; 	    Median = 79.6     [31.8 - 433.2]    
    hgb: 	    76.9%; 	    Median = 107.0     [61.0 - 156.0]    
    platelet: 	    75.6%; 	    Median = 185.0     [29.0 - 392.0]    
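If the summary is needed for further analysis rather than display, the same statistics can be collected into a DataFrame. This is a sketch on toy data; the column names `hyperkalemia_lab_max_potassium` and `hypoglycemia_lab_min_glucose` are hypothetical stand-ins for the real min/max columns, and `summarize` is an illustrative helper.

```python
import numpy as np
import pandas as pd

# Toy lab columns with hypothetical names; NaN means "not measured".
toy = pd.DataFrame({
    "hyperkalemia_lab_max_potassium": [4.0, 5.0, np.nan, 6.0],
    "hypoglycemia_lab_min_glucose": [5.0, np.nan, np.nan, 3.0],
})


def summarize(df, cols):
    """Return measurement rate, median, and 2.5/97.5 percentiles per column."""
    rows = []
    for col in cols:
        rows.append({
            "component": col.split("_")[-1],
            # Share of rows with a non-null measurement, in percent.
            "pct_measured": round(df[col].notna().mean() * 100, 1),
            "median": round(df[col].median(), 1),
            "q2.5": round(df[col].quantile(0.025), 1),
            "q97.5": round(df[col].quantile(0.975), 1),
        })
    return pd.DataFrame(rows).set_index("component")


summary = summarize(toy, list(toy.columns))
```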
