In [None]:
import polars as pl
import pandas as pd
import datetime
from datetime import timedelta
import random
import copy
import numpy as np
import time
import traceback
import shap
from xgbse import (
    XGBSEDebiasedBCE,
#     XGBSEStackedWeibull,
#     XGBSEKaplanNeighbors,
#     XGBSEKaplanTree
)
from xgbse.metrics import concordance_index
from xgbse.converters import (
    convert_data_to_xgb_format,
    convert_to_structured
)

from ai_knowledge.ai_knowledge.base.convert import *
import yaml
from sklearn.model_selection import train_test_split

# from PA_GE_test.rule_execution import rules_check
from sklearn.inspection import PartialDependenceDisplay
import matplotlib.pyplot as plt

# Date the notebook is run; in the visible cells it is only referenced by commented-out dated paths.
today_date = datetime.date.today()
# Lab tests are kept only when their calendar year is strictly greater than this value
# (it is interpolated into the `extract(year from test_date) > {clip_year}` SQL filters).
clip_year = 2019

In [None]:
import os

# Folder of a previous feature-building run; the inference artefacts are written to a sub-folder of it.
output_dir = 'data_backup/output_feature_nadir_multiprogression_2024-05-02/'
# output_dir = f'data_backup/output_feature_nadir_multiprogression_{today_date}/'
output_dir_infer = f'{output_dir}inference/'

# exist_ok=True makes this cell re-runnable and also creates missing parent folders;
# the previous os.listdir(output_dir) probe raised FileNotFoundError when output_dir was absent.
os.makedirs(output_dir_infer, exist_ok=True)

# NOTE(review): this handle is not closed in the visible cells - confirm it is closed later in the notebook.
text_file = open(output_dir_infer+'output_infer.txt','w')

In [None]:
# Each YAML file holds a list of records; every list is re-indexed into a dict for direct lookup.

# Lab definitions, indexed by each record's 'key' field.
with open('./ai_knowledge/ai_knowledge/base/keys_labs.yaml','r') as file:
    lab_kb = {entry['key']: entry for entry in yaml.safe_load(file)}

# Vital-sign definitions, indexed by each record's 'key' field.
with open('./ai_knowledge/ai_knowledge/base/keys_vitals.yaml','r') as file:
    vital_kb = {entry['key']: entry for entry in yaml.safe_load(file)}

# Unit records, indexed by each record's 'base' field.
with open('./ai_knowledge/ai_knowledge/base/units.yaml','r') as file:
    unite_conv_map = {entry['base']: entry for entry in yaml.safe_load(file)}

In [None]:
def get_stats(feature_name, values):
    """Summarise one feature's values as a polars DataFrame.

    :param feature_name: Label stored in the 'feature' column.
    :param values: Object exposing min/max/mean/median/mode/std (called here with a polars Series).
    :return: DataFrame with columns feature, min, max, mean, median, mode, std.
    """
    summary = {
        'feature': feature_name,
        'min': values.min(),
        'max': values.max(),
        'mean': values.mean(),
        'median': values.median(),
        # NOTE(review): mode() can yield several values, in which case the frame may have several rows - confirm.
        'mode': values.mode(),
        'std': values.std(),
    }
    return pl.DataFrame(summary)

# Accumulates the per-feature summary frames produced by get_stats.
stat_df_list = []

In [None]:
import warnings


def warn(*args, **kwargs):
    """No-op replacement for ``warnings.warn``: accepts any arguments and does nothing."""
    return None


# Swap the module attribute so later ``warnings.warn(...)`` calls are silently ignored.
warnings.warn = warn

In [None]:
# Wall-clock start; the elapsed time is printed after the feature table is assembled.
start = time.time()

In [None]:
import os

# DB connection details.
# Every setting can be overridden through an environment variable so that credentials never
# have to be typed into (and accidentally committed with) the notebook. The defaults are the
# values that were previously hard-coded, so behaviour is unchanged when nothing is set.
host = os.environ.get("REDSHIFT_HOST", "c0-eurekarwd-web.crdektqhzze9.us-east-1.redshift.amazonaws.com")
port = os.environ.get("REDSHIFT_PORT", "5439")
dbname = os.environ.get("REDSHIFT_DBNAME", "c0_eurekarwd")
user = os.environ.get("REDSHIFT_USER", "")
password = os.environ.get("REDSHIFT_PASSWORD", "")

In [None]:
# schema = 'dts_cdm_master'
# Warehouse schema that every query in this notebook reads from.
schema = 'dts_cdm_master_20240501s'

In [None]:
from urllib.parse import quote

# Percent-encode the credentials so that any reserved URI character (not only '@') is safe
# inside the connection string.
db_conn_uri = "redshift://{user}:{pwd}@{server}:{port}/{db}".format(
    user=quote(user, safe=''), pwd=quote(password, safe=''), server=host, port=port, db=dbname)

# Cohort: condition rows whose standard name looks like multiple myeloma or whose standard code
# starts with C90.0, diagnosed after `clip_year`.
# FIX: the code column is lower-cased, so the pattern has to be lower-case too; the previous
# upper-case 'C90.0%' pattern could never match under a case-sensitive LIKE.
# NOTE(review): this can admit patients that were previously matched by name only - confirm
# this is the intended cohort before comparing against earlier runs.
chai_patid_query = (
    "select chai_patient_id,diagnosis_date,diagnosis_code_standard_code,diagnosis_code_standard_name "
    "from {schema}.condition "
    "where (lower(diagnosis_code_standard_name) like '%multiple%myeloma%' "
    "or lower(diagnosis_code_standard_code) like 'c90.0%') "
    "and extract(year from diagnosis_date) > {clip_year}"
).format(schema=schema, clip_year=clip_year)
pat_id_df = pl.read_database(chai_patid_query,db_conn_uri)

In [None]:
# pat_id_df = pat_id_df.with_columns(pl.col('diagnosis_date').cast(pl.Date))
# rules_check(pat_id_df,'condition','qcca_cdm_master','condition_rule_checks.xlsx')

In [None]:
# Keep one row per patient: after the ascending sort, keep='first' retains the earliest diagnosis date.
_diagnoses_sorted = pat_id_df.select(['chai_patient_id','diagnosis_date']).sort(
    by=['chai_patient_id','diagnosis_date'], descending=False
)
_first_diagnosis = _diagnoses_sorted.unique(subset=['chai_patient_id'], keep='first')
pat_id_df_updated = _first_diagnosis.drop_nulls()

In [None]:
pat_id_df_updated.shape

In [None]:
# Distinct patient ids rendered as a comma-separated list of single-quoted SQL literals for `IN (...)`.
_distinct_patient_ids = pat_id_df_updated.select(['chai_patient_id']).unique().to_series().to_list()
pat_ids = ','.join("'" + str(patient_id) + "'" for patient_id in _distinct_patient_ids)

In [None]:
# Lab/vital key -> list of codes matched against `test_name_standard_code`
# (presumably LOINC codes - confirm). Only the serum M-protein, urine M-protein and
# serum free light chain codes are actually queried in the visible cells.
lab_dict = {'m_protein_in_serum':['33358-3','51435-6','35559-4','94400-9','50796-2'],
            'm_protein_in_urine':['42482-0','35560-2'],
            'ca':['17861-6','49765-1'],
            'serum_free_light':['36916-5','33944-0'],
            'hemoglobin_in_blood':['718-7','20509-6','30313-1','48725-6'],
            'neutrophils_count':['751-8','26499-4','768-2','30451-9','753-4'],
            'lymphocytes_count':['26474-7','732-8','731-0'],
            'platelets':['777-3','26515-7','53800-9','49497-1','778-1'],
            'na':['2951-2','2955-3'],
            'mg':['21377-7','19123-9'],
            'cl':['2075-0'],
            'phos' : ['2777-1'],
            'hr' : ['8867-4'],
            'dbp' : ['8462-4'],
            'ecog' : ['89262-0'],
            'k' : ['2823-3','2828-2']
           }

In [None]:
## For OMOP
# patient_test_df = pl.read_csv('patient_test.csv')
# patient_test_df = patient_test_df.with_columns(pl.col('measurement_date').str.to_datetime("%Y-%m-%d").cast(pl.Date))
# patient_test_df = patient_test_df.rename({'person_id':'chai_patient_id','measurement_date':'test_date','value_as_number':'test_value_numeric',
#                                           'measurement_unit_source_name':'test_unit_standard_name','concept_code':'test_name_standard_code',
#                                           'concept_name':'test_name_standard_name'})

In [None]:
# Serum M-protein results for the cohort, restricted to years after `clip_year`.
# The code list comes from lab_dict so the query and the later key mapping cannot drift apart.
_serum_codes = ','.join("'" + code + "'" for code in lab_dict['m_protein_in_serum'])
pat_serum_test_query = "select chai_patient_id,source_labid,test_date,test_name_standard_code,test_name_standard_name,test_value_numeric,test_unit_standard_name from {schema}.patient_test where (chai_patient_id in ({pat_ids})) and (test_name_standard_code in ({codes})) and (extract(year from test_date) > {clip_year})".format(pat_ids=pat_ids,schema=schema,codes=_serum_codes,clip_year=clip_year)
pat_serum_test_df = pl.read_database(pat_serum_test_query,db_conn_uri)

In [None]:
##For OMOP
# pat_serum_test_df = pd.read_csv('data_backup/serum_patient_test_202312141244.csv')
# pat_serum_test_df = pl.from_pandas(pat_serum_test_df)

In [None]:
# pat_serum_test_df = pat_serum_test_df.filter(pl.col('chai_patient_id').apply(lambda x:x in pat_ids))
# Attach each patient's first diagnosis date and keep only serum tests taken on or after it.
# pat_serum_test_df = pat_serum_test_df.filter(pl.col('chai_patient_id').apply(lambda x:x in pat_ids))
# pat_serum_test_df = pat_serum_test_df.with_columns(pl.col('test_date').apply(lambda x:datetime.datetime.strptime(x,'%Y-%m-%d %H:%M:%S').date()))
pat_serum_test_df = (
    pat_serum_test_df
    .join(pat_id_df_updated, on='chai_patient_id', how='inner')
    .filter(pl.col('test_date') >= pl.col('diagnosis_date'))
)
pat_serum_test_df.shape

In [None]:
# Urine M-protein results for the cohort, restricted to years after `clip_year`.
# The code list comes from lab_dict so the query and the later key mapping cannot drift apart.
_urine_codes = ','.join("'" + code + "'" for code in lab_dict['m_protein_in_urine'])
pat_urine_test_query = "select chai_patient_id,source_labid,test_date,test_name_standard_code,test_name_standard_name,test_value_numeric,test_unit_standard_name from {schema}.patient_test where (chai_patient_id in ({pat_ids})) and (test_name_standard_code in ({codes})) and (extract(year from test_date) > {clip_year})".format(pat_ids=pat_ids,schema=schema,codes=_urine_codes,clip_year=clip_year)

pat_urine_test_df = pl.read_database(pat_urine_test_query,db_conn_uri)

In [None]:
##For OMOP
# pat_urine_test_df = pd.read_csv('data_backup/urine_patient_test_202312141245.csv')
# pat_urine_test_df = pl.from_pandas(pat_urine_test_df)
# pat_urine_test_df = pat_urine_test_df.filter(pl.col('chai_patient_id').apply(lambda x:x in pat_ids))

In [None]:
# Attach each patient's first diagnosis date and keep only urine tests taken on or after it.
# pat_urine_test_df = pat_urine_test_df.with_columns(pl.col('test_date').apply(lambda x:datetime.datetime.strptime(x,'%Y-%m-%d %H:%M:%S').date()))
pat_urine_test_df = (
    pat_urine_test_df
    .join(pat_id_df_updated, on='chai_patient_id', how='inner')
    .filter(pl.col('test_date') >= pl.col('diagnosis_date'))
)
pat_urine_test_df.shape

In [None]:
# Keep serum rows that have both a value and a unit, then lower-case the unit name.
# any_horizontal over a single expression was a no-op wrapper, and the native string
# expression replaces a per-row Python lambda.
pat_serum_test_df_cast = pat_serum_test_df.filter(
        pl.col('test_value_numeric').is_not_null() & pl.col('test_unit_standard_name').is_not_null()
    ).with_columns(pl.col('test_unit_standard_name').str.to_lowercase())

In [None]:
# Keep urine rows that have both a value and a unit, then lower-case the unit name.
# any_horizontal over a single expression was a no-op wrapper, and the native string
# expression replaces a per-row Python lambda.
pat_urine_test_df_cast = pat_urine_test_df.filter(
        pl.col('test_value_numeric').is_not_null() & pl.col('test_unit_standard_name').is_not_null()
    ).with_columns(pl.col('test_unit_standard_name').str.to_lowercase())

In [None]:
pat_urine_test_df_cast.shape

In [None]:
# Tag each row with its lab key; there is no otherwise() branch, so rows whose code is not
# in the list get a null key.
# NOTE(review): the later select([...]) on these frames does not include 'key', so the column
# added here is discarded; the key is rebuilt on the concatenated frame further down.
pat_serum_test_df_cast = pat_serum_test_df_cast.with_columns(
    pl.when(pl.col('test_name_standard_code').is_in(lab_dict['m_protein_in_serum']))
    .then(pl.lit('m_protein_in_serum'))
    .alias('key')
)

pat_urine_test_df_cast = pat_urine_test_df_cast.with_columns(
    pl.when(pl.col('test_name_standard_code').is_in(lab_dict['m_protein_in_urine']))
    .then(pl.lit('m_protein_in_urine'))
    .alias('key')
)

In [None]:
print(pat_urine_test_df_cast.shape)
print(pat_serum_test_df_cast.shape)

In [None]:
def convert_unit(df):
    """Convert lab values to the standard unit registered for each lab key.

    For every distinct non-null value of ``df['key']``:
      * the standard unit is read from ``lab_kb[key]['attributes']['units']``;
      * rows of that key whose unit is a key of ``unite_conv_map[std_unit]['convert']`` have
        ``test_value_numeric`` passed through the converter named by that mapping entry;
      * rows of that key that are already in, or were converted to, the standard unit get
        ``test_unit_standard_name`` set to the standard unit;
      * rows of that key in any other unit get a null unit and are dropped.

    :param df: polars DataFrame with columns 'key', 'test_value_numeric' and 'test_unit_standard_name'.
    :return: DataFrame with converted values and units.
    :raises KeyError: if a key present in ``df`` is not registered in ``lab_kb``.
    """
    lab_test = df['key'].unique().drop_nulls()
    for key in lab_test:
        if key not in lab_kb.keys():
            raise KeyError(f'Lab test {key} not registered')
        std_unit = lab_kb[key]['attributes']['units']
        unit_keys = list(unite_conv_map[std_unit]['convert'].keys())
        # The mapping value is a dotted name; only its last component is kept and resolved with
        # eval() in this module's namespace (presumably a function brought in by the star import
        # from the convert module - confirm).
        # NOTE(review): eval() runs text taken from units.yaml; make sure that file is trusted,
        # or replace the eval with an explicit name -> function lookup.
        df = df.with_columns(
                pl.struct(['test_value_numeric','test_unit_standard_name','key'])
                          .apply(lambda x: eval(unite_conv_map[std_unit]['convert'][x['test_unit_standard_name']].split('.')[-1])(x['test_value_numeric'])\
                                 if (x['key']==key) & (x['test_unit_standard_name'] in (unit_keys))\
                                 else x['test_value_numeric']                                                      
                                ).alias('test_value_numeric')
        )
        # Relabel the unit after the value conversion. drop_nulls applies to the whole frame, so
        # any row with a null unit is removed here, whichever key it belongs to.
        df = df.with_columns(
                        pl.when((pl.col('key')==key) & ((pl.col('test_unit_standard_name')==std_unit) | (pl.col('test_unit_standard_name').is_in(unite_conv_map[std_unit]['convert'].keys()))))
                .then(pl.lit(std_unit))
                .when(pl.col('key')==key)
                .then(None)
                .otherwise(pl.col('test_unit_standard_name'))
                .alias('test_unit_standard_name'))\
        .drop_nulls('test_unit_standard_name')
    return df

In [None]:
##OMOP
# pat_serum_test_df_cast = patient_test_df.filter(pl.col('key')=='m_protein_in_serum')
# pat_urine_test_df_cast = patient_test_df.filter(pl.col('key')=='m_protein_in_urine')

In [None]:
# Keep the modelling columns and truncate the test timestamp to a calendar date.
# A native cast replaces the per-row Python `x.date()` call and gives a Date column
# regardless of how many rows the frame has.
pat_serum_test_df_cast = pat_serum_test_df_cast.select(['chai_patient_id','test_date','test_value_numeric','test_name_standard_code','test_unit_standard_name'])\
    .with_columns(pl.col('test_date').cast(pl.Date))

In [None]:
# Keep the modelling columns and truncate the test timestamp to a calendar date.
# A native cast replaces the per-row Python `x.date()` call and gives a Date column
# regardless of how many rows the frame has.
pat_urine_test_df_cast = pat_urine_test_df_cast.select(['chai_patient_id','test_date','test_value_numeric','test_name_standard_code','test_unit_standard_name'])\
    .with_columns(pl.col('test_date').cast(pl.Date))

In [None]:
print(pat_urine_test_df_cast.shape)
print(pat_serum_test_df_cast.shape)

## Serum Free light chain

In [None]:
# Serum free light chain results for the cohort, restricted to years after `clip_year`.
# The code list comes from lab_dict so the query and the later key mapping cannot drift apart.
_free_light_codes = ','.join("'" + code + "'" for code in lab_dict['serum_free_light'])
pat_free_light_query = "select chai_patient_id,source_labid,test_date,test_name_standard_code,test_name_standard_name,test_value_numeric,test_unit_standard_name from {schema}.patient_test where (chai_patient_id in ({pat_ids})) and (test_name_standard_code in ({codes})) and (extract(year from test_date) > {clip_year})".format(pat_ids=pat_ids,schema=schema,codes=_free_light_codes,clip_year=clip_year)

pat_free_light_df = pl.read_database(pat_free_light_query,db_conn_uri)

In [None]:
##OMOP
# pat_free_light_df = pd.read_csv('data_backup/FLC_patient_test_202312141247.csv')
# pat_free_light_df = pl.from_pandas(pat_free_light_df)
# pat_serum_test_df = pat_serum_test_df.filter(pl.col('chai_patient_id').apply(lambda x:x in pat_ids))

In [None]:
# Attach each patient's first diagnosis date and keep only free light chain tests taken on or after it.
# pat_free_light_df = pat_free_light_df.with_columns(pl.col('test_date').apply(lambda x:datetime.datetime.strptime(x,'%Y-%m-%d %H:%M:%S').date()))
pat_free_light_df = (
    pat_free_light_df
    .join(pat_id_df_updated, on='chai_patient_id', how='inner')
    .filter(pl.col('test_date') >= pl.col('diagnosis_date'))
)
pat_free_light_df.shape

In [None]:
pat_free_light_df.shape

In [None]:
# pat_free_light_df = patient_test_df.filter(pl.col('key')=='serum_free_light')

In [None]:
# Keep the modelling columns and truncate the test timestamp to a calendar date.
# A native cast replaces the per-row Python `x.date()` call and gives a Date column
# regardless of how many rows the frame has.
pat_free_light_df = pat_free_light_df.select(['chai_patient_id','test_date','test_value_numeric','test_name_standard_code','test_unit_standard_name'])\
    .with_columns(pl.col('test_date').cast(pl.Date))

In [None]:
pat_free_light_df.shape

In [None]:
# Stack urine, serum and free light chain results into one long table (same five columns each).
# Unlike the other two, the free light chain frame has not yet been filtered for null value/unit here.
lab_test_df = pl.concat([pat_urine_test_df_cast,pat_serum_test_df_cast,pat_free_light_df])

In [None]:
print(pat_urine_test_df_cast.shape)
print(pat_serum_test_df_cast.shape)
print(pat_free_light_df.shape)

In [None]:
# Lab keys in matching priority order: the first key whose code list contains the row's code wins,
# and a row matching none of them gets a null key.
_keyed_labs = ['m_protein_in_serum', 'm_protein_in_urine', 'ca', 'serum_free_light',
               'hemoglobin_in_blood', 'neutrophils_count', 'lymphocytes_count', 'platelets',
               'na', 'mg', 'cl', 'phos', 'k']
_code_col = pl.col('test_name_standard_code')
_unit_col = pl.col('test_unit_standard_name')

_key_expr = pl.when(_code_col.is_in(lab_dict[_keyed_labs[0]])).then(pl.lit(_keyed_labs[0]))
for _lab_key in _keyed_labs[1:]:
    _key_expr = _key_expr.when(_code_col.is_in(lab_dict[_lab_key])).then(pl.lit(_lab_key))

# Strip comparison signs, use '.' as the decimal mark, then parse as float; strict=False makes
# unparseable text null instead of raising.
_numeric_value_expr = (
    pl.col('test_value_numeric')
    .str.replace('<','').str.replace('>','').str.replace(',','.')
    .cast(pl.Float64, strict=False)
)

lab_test_df = lab_test_df.with_columns(
    _key_expr.alias('key'),
    _unit_col.apply(lambda x:x.lower()),
    _numeric_value_expr,
)
lab_test_df = lab_test_df.drop_nulls(subset='test_value_numeric')

# hr / dbp / ecog rows without a unit receive the placeholder unit 'valid'.
_unitless_codes = lab_dict['hr']+lab_dict['dbp']+lab_dict['ecog']
lab_test_df = lab_test_df.with_columns(
    pl.when((_code_col.is_in(_unitless_codes)) & (_unit_col.is_null()))
    .then(pl.lit('valid'))
    .otherwise(_unit_col)
    .alias('test_unit_standard_name')
)
lab_test_df.shape

In [None]:
# lab_test_df = lab_test_df.with_columns(pl.col('test_date').cast(pl.Datetime))
# rules_check(lab_test_df,'patient_test_raw','qcca_cdm_master','patient_test_raw_rule_checks.xlsx')

In [None]:
# Harmonise values to each lab key's standard unit; convert_unit drops rows whose unit cannot be mapped.
lab_test_df = convert_unit(lab_test_df)
lab_test_df.shape

In [None]:
# lab_test_df = lab_test_df.sort(['chai_patient_id','test_date','test_value_numeric'],descending=False)\
#     .unique(subset=['chai_patient_id','test_date','test_name_standard_code'],keep='first')

In [None]:
# lab_test_df = lab_test_df.with_columns(pl.col('test_date').cast(pl.Datetime))
# rules_check(lab_test_df,'patient_test_std','qcca_cdm_master','patient_test_std_rule_checks.xlsx')

In [None]:
# lab_test_df.shape

In [None]:
# One summary frame per distinct test code, appended in the order unique() yields the codes.
stat_df_list.extend(
    get_stats(
        code,
        lab_test_df.filter(pl.col('test_name_standard_code')==code)['test_value_numeric'],
    )
    for code in lab_test_df['test_name_standard_code'].unique()
)

In [None]:
# Add separate keys for the two codes that 'serum_free_light' lists together, so each can be
# selected on its own below.
lab_dict.update({'serum_free_light_kappa':['36916-5'],'serum_free_light_lambda':['33944-0']})
# Index -> lab key; the feature-building loop further down iterates over these and skips index 2 ('ca').
lab_class_map =  {0:'m_protein_in_serum',1:'m_protein_in_urine',2:'ca',3:'serum_free_light_kappa',4:'serum_free_light_lambda'}

# lab_test_df = lab_test_df.with_columns(
#                                     pl.when(pl.col('test_name_standard_code').is_in(lab_dict['m_protein_in_serum']))
#                                     .then(pl.lit(0))
#                                     .when(pl.col('test_name_standard_code').is_in(lab_dict['m_protein_in_urine']))
#                                     .then(pl.lit(1))
#                                     .when(pl.col('test_name_standard_code').is_in(lab_dict['ca']))
#                                     .then(pl.lit(2))
#                                     .when(pl.col('test_name_standard_code').is_in(lab_dict['serum_free_light_kappa']))
#                                     .then(pl.lit(3))
#                                     .when(pl.col('test_name_standard_code').is_in(lab_dict['serum_free_light_lambda']))
#                                     .then(pl.lit(4))
#                                     .alias('test_name_standard_code')
# )

In [None]:
# Collapse repeated results of the same code on the same day for a patient into their median.
# The key and unit columns are not carried past this point.
_daily_median = lab_test_df.groupby(by=['chai_patient_id','test_date','test_name_standard_code']).agg(
    pl.col('test_value_numeric').median()
)
lab_test_df = _daily_median.select(['chai_patient_id','test_date','test_value_numeric','test_name_standard_code'])

In [None]:
lab_test_df['chai_patient_id'].unique().shape

## Criteria

In [None]:
output_dir

In [None]:
# Progression-criteria labels written by an earlier run (note the file date differs from the folder date).
criteria_df = pl.read_csv(f'{output_dir}criteria_multiprogression_label_2024-04-30.csv')
# criteria_df = pl.read_csv(f'{output_dir}criteria_multiprogression_label_{today_date}.csv')

# Parse 'min_date' (YYYY-MM-DD text) into a Date column with the native parser instead of a
# per-row Python strptime lambda.
criteria_df = criteria_df.with_columns(pl.col('min_date').str.strptime(pl.Date, '%Y-%m-%d'))

In [None]:
criteria_df.head()

In [None]:
# Keep a full copy (including 'criteria_1') to join back onto the final feature table later,
# then work without that column.
c1_df = copy.deepcopy(criteria_df)
criteria_df = criteria_df.drop('criteria_1')

In [None]:
criteria_df.shape

In [None]:
criteria_df['chai_patient_id'].unique().shape

In [None]:
def download_dod(source_schema, patients_ids):
    """
    Fetch the recorded date of death for the given patients.

    Only rows with a non-null date of death whose year is 1950 or later are returned.

    :param source_schema: Source schema name
    :param patients_ids: Iterable of patient IDs to look up
    :return: A DataFrame with chai_patient_id and date_of_death
    """
    quoted_ids = ', '.join(["'" + str(pat) + "'" for pat in patients_ids])
    dod_query = '''Select chai_patient_id, date_of_death from {}.patient where chai_patient_id in ({}) 
        and (date_of_death is not null and date_part(year, date_of_death) >= 1950)'''.format(
            source_schema, quoted_ids)
    return pl.read_database(dod_query,db_conn_uri)

def download_last_activity_date_dod(source_schema, patients_ids):
    """
    Return the most recent recorded activity date per patient.

    For each clinical table listed in ``date_fields`` the per-patient maximum of its date column
    is queried; the results are stacked and reduced to one row per patient holding the overall
    maximum, truncated to a calendar date. (Date of death is not used here, despite the function
    name: that part is disabled.)

    :param source_schema: Source schema name
    :param patients_ids: Iterable of patient IDs for whom the last activity date is required
    :return: A DataFrame with columns chai_patient_id and last_activity_date (Date)
    """
    # table name -> date column used as that table's activity date
    date_fields = {'medication': 'med_start_date', 'patient_test': 'test_date', 'tumor_exam': 'exam_date',
                   'care_goal': 'treatment_plan_start_date', 'surgery': 'surgery_date', 'radiation': 'rad_start_date',
                   'condition': 'diagnosis_date', 'adverse_event': 'adverse_event_date', 'encounter': 'encounter_date',
                   'disease_status': 'assessment_date', 'staging': 'stage_date'}

    # Build the quoted id list once; it is identical for every table.
    pats_list = ', '.join(["'" + str(pat) + "'" for pat in patients_ids])

    # Collect the per-table frames and concatenate once, instead of growing a frame inside the
    # loop starting from an empty, column-less DataFrame.
    per_table_max = []
    for table_name, column_name in date_fields.items():
        last_date_query = '''Select distinct chai_patient_id, max({column_name}) as max_date
                                                       from {source_schema}.{table_name} where {column_name} is not null and 
                                                       chai_patient_id in ({pats_list}) group by chai_patient_id '''.format(
            column_name=column_name,
            table_name=table_name,
            source_schema=source_schema,
            pats_list=pats_list)
        per_table_max.append(pl.read_database(last_date_query,db_conn_uri))

    visits = pl.concat(per_table_max)

    # FIX: the previous sort(descending=True) + unique(['chai_patient_id']) relied on unique()
    # keeping the first row of each group, which is not guaranteed when `keep` is left at its
    # default. An explicit per-patient max is deterministic.
    last_medical_activity = visits.groupby('chai_patient_id').agg(
        pl.col('max_date').max().alias('last_activity_date')
    )

    # Truncate to a calendar date with a native cast rather than a per-row Python call.
    last_medical_activity = last_medical_activity.with_columns(pl.col('last_activity_date').cast(pl.Date))
    return last_medical_activity

In [None]:
def _cap_at_today(activity_date):
    """Return today's date for dates in the future, otherwise the date itself."""
    today = datetime.date.today()
    return today if (activity_date > today) else activity_date


# Last recorded activity per patient that has lab data, with future-dated entries pulled back to today.
_lab_patient_ids = set(list(lab_test_df['chai_patient_id']))
last_date_df = download_last_activity_date_dod(schema, _lab_patient_ids)
last_date_df = last_date_df.with_columns(pl.col('last_activity_date').apply(_cap_at_today))

In [None]:
# Restrict to patients present in the lab table.
# NOTE(review): the download above was already queried with this same id set, so this filter
# looks redundant - confirm before removing.
last_date_df = last_date_df.filter(pl.col('chai_patient_id').is_in(set(list(lab_test_df['chai_patient_id']))))

In [None]:
last_date_df['chai_patient_id'].unique().shape

In [None]:
last_date_df.shape

In [None]:
# Rebuild the earliest-diagnosis table, this time with the diagnosis timestamp truncated to a
# calendar date and restricted to patients that have lab data.
# A native cast replaces the per-row `x.date()` lambda, and the id column is passed to is_in
# directly instead of through a set(list(...)) round-trip.
pat_id_df_updated = (
    pat_id_df.select(['chai_patient_id','diagnosis_date'])
    .sort(by=['chai_patient_id','diagnosis_date'], descending=False)
    .unique(subset=['chai_patient_id'], keep='first')
    .drop_nulls()
    .with_columns(pl.col('diagnosis_date').cast(pl.Date))
    .filter(pl.col('chai_patient_id').is_in(lab_test_df['chai_patient_id']))
)

In [None]:
# Keep criteria rows only for patients that have lab data.
criteria_df = criteria_df.filter(pl.col('chai_patient_id').is_in(lab_test_df['chai_patient_id']))

In [None]:
# Outer-join so that patients without any criteria row are still kept (their min_date is null),
# together with their diagnosis date and last activity date.
# NOTE(review): how outer joins name/coalesce the key column differs between polars versions - confirm.
criteria_df_merged = criteria_df.join(pat_id_df_updated,on='chai_patient_id',how='outer').join(last_date_df,on='chai_patient_id',how='outer')
criteria_df_merged.shape

In [None]:
# Require last activity strictly after diagnosis; rows where either date is null do not pass the comparison.
criteria_df_merged = criteria_df_merged.filter(pl.col('diagnosis_date')<pl.col('last_activity_date'))

In [None]:
# Keep a criteria date only if it lies strictly between diagnosis and last activity;
# rows without a criteria date (null min_date) are kept as they are.
criteria_df_merged = criteria_df_merged.filter(((pl.col('last_activity_date')>pl.col('min_date')) & (pl.col('diagnosis_date')<pl.col('min_date'))) | pl.col('min_date').is_null())

In [None]:
# Per-patient event list in a single 'min_date' column: criteria dates, the diagnosis date and
# the last activity date; duplicates are removed.
_criteria_dates = criteria_df_merged[['chai_patient_id','min_date']].drop_nulls()
_diagnosis_as_event = criteria_df_merged[['chai_patient_id','diagnosis_date']].rename({'diagnosis_date':'min_date'})
_last_activity_as_event = criteria_df_merged[['chai_patient_id','last_activity_date']].rename({'last_activity_date':'min_date'})

criteria_df = pl.concat([_criteria_dates, _diagnosis_as_event, _last_activity_as_event])
criteria_df = criteria_df.unique()
criteria_df.shape

In [None]:
criteria_df.unique().shape

In [None]:
def get_time_windows(date_list):
    """Turn an ordered sequence of event dates into consecutive (start, end) windows.

    Window ``i`` runs from ``date_list[i]`` to ``date_list[i+1]``.

    :param date_list: Indexable sequence of dates, expected in ascending order, with at least two entries.
    :return: dict of equal-length lists:
        start_date / end_date - window bounds;
        final_selection - True for every window except the last one;
        progression_count - the window's index (0, 1, 2, ...);
        prg_cnt_2_years - number of dates at positions 1..i-1 lying within the 730 days up to
        and including the window's end date.
    :raises IndexError: if fewer than two dates are supplied.
    """
    n_windows = len(date_list) - 1

    def _in_lookback(candidate, window_end):
        # Same two comparisons as the window test: not after the end, not older than 730 days.
        return (candidate <= window_end) and (candidate >= (window_end - timedelta(days=730)))

    start_date = [date_list[idx] for idx in range(n_windows)]
    end_date = [date_list[idx + 1] for idx in range(n_windows)]
    progression_count = list(range(n_windows))
    # NOTE(review): positions 1..i-1 leave out the window's own start date (position i), whereas
    # progression_count equals i - confirm this asymmetry is intended.
    prg_cnt_2_years = [
        sum(1 for prior in range(1, idx) if _in_lookback(date_list[prior], date_list[idx + 1]))
        for idx in range(n_windows)
    ]

    final_selection = [True] * n_windows
    final_selection[-1] = False  # the last window is flagged False

    return {
        'start_date': start_date,
        'end_date': end_date,
        'final_selection': final_selection,
        'progression_count': progression_count,
        'prg_cnt_2_years': prg_cnt_2_years,
    }

# Build one row per (patient, window): sort the event dates, gather them into a list per patient,
# drop patients with a single date, expand each list through get_time_windows (one struct of
# parallel lists), then unnest and explode those lists into rows.
# NOTE(review): this assumes the group-by aggregation keeps the sorted order inside each list - confirm.
criteria_df = criteria_df.sort(by=['chai_patient_id','min_date'],descending=False)\
                        .groupby(['chai_patient_id']).agg(pl.col('min_date'))\
                        .filter(pl.col('min_date').apply(lambda x:len(x)>1))\
                        .with_columns(pl.col('min_date').apply(lambda x:get_time_windows(x)).alias('windows'))\
                        .unnest('windows')\
                        .drop('min_date')\
                        .explode(['start_date','end_date','final_selection','progression_count','prg_cnt_2_years'])

In [None]:
# Keep only the window that strictly contains 2023-06-17.
# NOTE(review): the same hard-coded date is used again as 'random_date' below; the two must stay in sync.
criteria_df = criteria_df.filter((pl.col('start_date')<datetime.date(2023,6,17)) & (pl.col('end_date')>datetime.date(2023,6,17)))

In [None]:
criteria_df.shape

In [None]:
criteria_df[['chai_patient_id']].unique().shape

In [None]:
# random.seed(123)
def _window_length_days(window):
    """Number of days between a window's start_date and end_date."""
    return (window['end_date']-window['start_date']).days


# random.seed(123)
# 'diff' is the window length in days; 'random_date' is a fixed date for every row
# (the random draw inside the window is disabled below).
criteria_df = criteria_df.with_columns(pl.struct(['start_date','end_date']).apply(_window_length_days).alias('diff'))
criteria_df = criteria_df.with_columns(pl.lit(datetime.date(2023,6,17)).alias('random_date'))
# criteria_df = criteria_df.with_columns(pl.struct(['diff','start_date']).apply(lambda x:x['start_date']+timedelta(days=random.randint(1,x['diff']))).alias('random_date'))

In [None]:
# label = number of days from 'random_date' to the end of the window.
criteria_df = criteria_df.with_columns(pl.struct(['random_date','end_date']).apply(lambda x:(x['end_date']-x['random_date']).days).alias('label'))

In [None]:
criteria_df.head()

In [None]:
# Attach the selected window (start_date, random_date) to every lab row; patients without a window are dropped.
lab_test_df = lab_test_df.join(criteria_df.select(['chai_patient_id','start_date','random_date']), on ='chai_patient_id',how='inner')
lab_test_df['chai_patient_id'].unique().shape

In [None]:
lab_test_df.head()

In [None]:
lab_test_df.shape

In [None]:
# Keep tests taken after the window start and up to (and including) 'random_date'.
lab_test_df = lab_test_df.filter((pl.col('test_date')>pl.col('start_date')) & (pl.col('test_date')<=pl.col('random_date')))

In [None]:
# lab_test_df = lab_test_df.with_columns(pl.struct(['test_date','start_date']).apply(lambda x:(x['test_date']-x['start_date']).days).alias('test_diff'))
lab_test_df = lab_test_df.sort(by=['chai_patient_id','start_date','test_name_standard_code','test_date'],descending=False)

print(lab_test_df['test_name_standard_code'].unique().shape)
lab_test_df.shape

In [None]:
# lab_test_df = lab_test_df.filter(pl.col('test_diff')<=pl.col('random_point'))
# Nadir = lowest value per (patient, window, test code), taken over all tests between the window
# start and 'random_date' (this runs before the 90-day recency filter below).
min_df = lab_test_df.groupby(['chai_patient_id','start_date','test_name_standard_code']).agg(pl.col('test_value_numeric').min().alias('nadir_value'))

In [None]:
min_df.shape

In [None]:
# Keep only tests taken less than 90 days before 'random_date'.
_recency_window = timedelta(days=90)
lab_test_df = lab_test_df.filter(
    pl.struct(['test_date','random_date']).apply(
        lambda row: row['test_date'] > (row['random_date'] - _recency_window)
    )
)
print(lab_test_df['chai_patient_id'].unique().shape)
lab_test_df.shape

In [None]:
# Most recent value per (patient, window, test code): sort newest first and keep the first row of each group.
latest_lab_test_df = lab_test_df.sort(by=['chai_patient_id','start_date','test_name_standard_code','test_date'],descending=True)\
                        .unique(subset=['chai_patient_id','start_date','test_name_standard_code'],keep='first')\
                        .select(['chai_patient_id','start_date','test_name_standard_code','test_value_numeric','test_date'])
# NOTE(review): this displays the shape of lab_test_df, not of latest_lab_test_df - confirm which was intended.
lab_test_df.shape

In [None]:
# Pair each nadir with the latest value of the same (patient, window, test code) and derive the
# absolute and relative change from nadir. 0.0001 is added to the denominator, which avoids a
# division by zero when the nadir is 0.
_window_lab_keys = ['chai_patient_id','start_date','test_name_standard_code']
_latest_values = latest_lab_test_df.select(_window_lab_keys + ['test_value_numeric','test_date'])
_value_and_nadir = pl.struct(['test_value_numeric','nadir_value'])

nadir_df = min_df.join(_latest_values, on=_window_lab_keys, how='inner')
nadir_df = nadir_df.with_columns(
    _value_and_nadir.apply(lambda row: row['test_value_numeric']-row['nadir_value']).alias('abs_change_from_nadir')
)
nadir_df = nadir_df.with_columns(
    _value_and_nadir.apply(
        lambda row: (row['test_value_numeric']-row['nadir_value'])/(row['nadir_value']+0.0001)
    ).alias('perc_change_from_nadir')
)

In [None]:
nadir_df.head()

In [None]:
nadir_df.shape

In [None]:
# thresold_map = {'m_protein_in_serum':0.5,'m_protein_in_urine':200,'serum_free_light_kappa':10,'serum_free_light_lambda':10}
# feat_lab_list = list(lab_class_map.keys())
# final_stat_df = pl.DataFrame({'chai_patient_id': list(set(lab_test_df['chai_patient_id']))})

# for lab in feat_lab_list:
#     if lab==2:
#         continue
#     lab_name = lab_class_map[lab]
#     temp_df = lab_test_df.filter(pl.col('test_name_standard_code')==lab)\
#                         .select(['chai_patient_id','test_value_numeric'])\
#                         .rename({'test_value_numeric':lab_name})
#     if temp_df.shape[0]>0:
# #         temp_df = temp_df.with_columns(pl.col(lab_name).apply(lambda x:x-thresold_map[lab_name] if (x>=thresold_map[lab_name]) else None).alias(f'{lab_name}_base_diff'))
#         final_stat_df = final_stat_df.join(temp_df,on='chai_patient_id',how='left')
#     else:
#         final_stat_df = final_stat_df.with_columns(pl.lit(None).alias(lab_name))
# #         final_stat_df = final_stat_df.with_columns(pl.lit(None).alias(lab_name),pl.lit(None).alias(f'{lab_name}_base_diff'))
        
#     temp_df = nadir_df.filter(pl.col('test_name_standard_code')==lab)\
#                         .select(['chai_patient_id','abs_change_from_nadir','perc_change_from_nadir'])\
#                         .rename({'abs_change_from_nadir':f'abs_change_from_nadir_{lab_name}',
#                                  'perc_change_from_nadir':f'perc_change_from_nadir_{lab_name}'})
#     if temp_df.shape[0]>0:
#         final_stat_df = final_stat_df.join(temp_df,on='chai_patient_id',how='left')
#     else:
#         final_stat_df = final_stat_df.with_columns(pl.lit(None).alias(f'abs_change_from_nadir_{lab_name}'),\
#                                                   pl.lit(None).alias(f'perc_change_from_nadir_{lab_name}'))
#################################        
        
# Assemble the wide feature table: one row per (patient, window start) and, for every lab key
# except 'ca', three columns - latest value, absolute change from nadir, relative change from nadir.
feat_lab_list = list(lab_class_map.keys())
# (A frame built from patient ids alone used to be assigned here and was immediately
# overwritten by the next statement; that dead assignment has been removed.)
final_stat_df = lab_test_df[['chai_patient_id','start_date']].unique()

for lab in feat_lab_list:
    # Index 2 is 'ca' in lab_class_map; it is not used as a feature.
    if lab==2:
        continue
    lab_name = lab_class_map[lab]
    lab_codes = lab_dict[lab_name]

    # Latest value of this lab per window; when several codes map to the lab, the newest test wins.
    temp_df = lab_test_df.filter(pl.col('test_name_standard_code').is_in(lab_codes))
    temp_df = temp_df.sort(by=['chai_patient_id','start_date','test_date'],descending=True)\
                    .unique(subset=['chai_patient_id','start_date'],keep='first')\
                    .select(['chai_patient_id','start_date','test_value_numeric'])\
                    .rename({'test_value_numeric':lab_name})
    if temp_df.shape[0]>0:
        final_stat_df = final_stat_df.join(temp_df,on=['chai_patient_id','start_date'],how='left')
    else:
        # No data at all for this lab: still create the column so the schema stays fixed.
        final_stat_df = final_stat_df.with_columns(pl.lit(None).alias(lab_name))

    # Change-from-nadir features for this lab, taken from the code with the newest test.
    temp_df = nadir_df.filter(pl.col('test_name_standard_code').is_in(lab_codes))
    temp_df = temp_df.sort(by=['chai_patient_id','start_date','test_date'],descending=True)\
                    .unique(subset=['chai_patient_id','start_date'],keep='first')\
                    .select(['chai_patient_id','start_date','abs_change_from_nadir','perc_change_from_nadir'])\
                    .rename({'abs_change_from_nadir':f'abs_change_from_nadir_{lab_name}',
                                 'perc_change_from_nadir':f'perc_change_from_nadir_{lab_name}'})
    if temp_df.shape[0]>0:
        final_stat_df = final_stat_df.join(temp_df,on=['chai_patient_id','start_date'],how='left')
    else:
        final_stat_df = final_stat_df.with_columns(pl.lit(None).alias(f'abs_change_from_nadir_{lab_name}'),\
                                                  pl.lit(None).alias(f'perc_change_from_nadir_{lab_name}'))

In [None]:
final_stat_df.shape

In [None]:
# Add the window-level columns (end date, label, selection flag, window length, progression counts).
final_stat_df = final_stat_df.join(criteria_df[['chai_patient_id','start_date','end_date','label','final_selection','diff','progression_count','prg_cnt_2_years']],on=['chai_patient_id','start_date'],how='inner')

In [None]:
# Bring back the original criteria columns (including 'criteria_1') for windows whose end date
# equals a criteria date; other windows get nulls from the left join.
final_stat_df = final_stat_df.join(c1_df,left_on=['chai_patient_id','end_date'],right_on=['chai_patient_id','min_date'],how='left')

In [None]:
# Seconds elapsed since `start` was recorded near the top of the notebook.
end = time.time()
print(end-start)

In [None]:
# Turn floating-point NaN into null so that the null checks below treat both the same way.
final_stat_df = final_stat_df.fill_nan(None)

In [None]:
# Drop rows where all four latest-value features are null.
final_stat_df = final_stat_df.filter(~pl.all_horizontal(pl.col(['m_protein_in_serum','m_protein_in_urine','serum_free_light_kappa','serum_free_light_lambda'
                            ]).is_null()))
final_stat_df.shape

In [None]:
def to_pandas(df):
    """
    Convert a Polars DataFrame to a pandas DataFrame with explicit dtypes.

    After `df.to_pandas()`, columns are re-cast according to their Polars
    dtype: Boolean columns to 'bool', Float64 and Null columns to 'float'.

    ARGUMENTS:
    ----------
    df: Polars.DataFrame: Polars Data Frame object to convert

    RETURNS:
    --------
    pandas_df: Pandas.DataFrame
    """
    pandas_df = df.to_pandas()

    # (target pandas dtype, Polars dtypes cast to it), applied in this order.
    # NOTE(review): astype('bool') coerces missing entries by truthiness
    # (None becomes False) - confirm that is acceptable for nullable flags.
    cast_plan = (
        ('bool', (pl.Boolean,)),
        ('float', (pl.Float64, pl.Null)),
    )
    for target_dtype, polars_dtypes in cast_plan:
        for polars_dtype in polars_dtypes:
            for col in df.select(pl.col(polars_dtype)).columns:
                pandas_df[col] = pandas_df[col].astype(target_dtype)

    return pandas_df

# From here on final_stat_df is a pandas DataFrame (it was a Polars frame above),
# so this cell cannot be re-run without re-running the cells that build it.
final_stat_df = to_pandas(final_stat_df)

In [None]:
# Bracketed unit/property tags removed from column names, in priority order.
# Only the FIRST tag found in a name is removed (plus one double-space collapse);
# a name with none of these tags just has any '[' / ']' characters stripped.
UNIT_TAGS = (
    '[#/volume]',
    '[Mass/volume]',
    '[mass/volume]',
    '[Mass/time]',
    '[Enzymatic activity/volume]',
    '[Mass Ratio]',
    '[Presence]',
    '[Mass/time in 24 hour Urine by Electrophoresis_0]',
    '[moles/volume]',
    '[interpretation]',
)

for col in final_stat_df.columns:
    first_tag = next((tag for tag in UNIT_TAGS if tag in col), None)
    if first_tag is not None:
        cleaned = col.replace(first_tag,'').replace('  ',' ')
    elif ('[' in col) or (']' in col):
        cleaned = col.replace('[','').replace(']','')
    else:
        continue
    final_stat_df = final_stat_df.rename(columns={col:cleaned})

In [None]:
# Patient-level 80/20 split: every row of a patient lands on the same side.
# A dedicated seeded generator keeps the split reproducible across kernel
# restarts without touching the global `random` state.
SPLIT_SEED = 42
patient_ids = list(final_stat_df['chai_patient_id'].unique())
train_pat = random.Random(SPLIT_SEED).sample(patient_ids,int(len(patient_ids)*0.8))
is_train_row = final_stat_df['chai_patient_id'].isin(train_pat)
final_stat_df_train = final_stat_df[is_train_row]
final_stat_df_test = final_stat_df[~is_train_row]

In [None]:
final_stat_df.head()

In [None]:
x = final_stat_df.drop(['chai_patient_id','start_date','end_date','final_selection','label','diff','criteria_1','progression_count'],axis=1)
y = convert_to_structured((final_stat_df['label']),final_stat_df['final_selection'])

In [None]:
x.head()

In [None]:
x.dropna(how='all').shape

In [None]:
print(x.shape)
text_file.write(f'x shape : {x.shape} \n')

In [None]:
text_file.write(f'label distribution : \n{pd.DataFrame(y)["c1"].value_counts()} \n\n')

In [None]:
text_file.write('Stats for duration between dx and event/censor date \n')
text_file.write(f"{final_stat_df['diff'].describe(percentiles=[0.1,0.25,0.50,0.75,0.9])} \n\n")

In [None]:
text_file.write('Stats for duration between random point and event/censor date \n')
text_file.write(f"{final_stat_df['label'].describe(percentiles=[0.1,0.25,0.50,0.75,0.9])} \n\n")

In [None]:
print(f'c-index : {concordance_index(y, xgbse_model.predict(x))}')
text_file.write(f'c-index : {concordance_index(y, xgbse_model.predict(x))} \n\n')

In [None]:
output_dir

In [None]:
import pickle
import joblib
date = '05_02_2024'
# # event_prob_df = pd.read_csv(f'data_backup/event_prob_df_{date}.csv')
# x_train = pd.read_csv(f'{output_dir}x_train_{date}.csv')
# x_test = pd.read_csv(f'{output_dir}x_test_{date}.csv')
# y_train = np.load(f'{output_dir}y_train_{date}.npy')
# y_test = np.load(f'{output_dir}y_test_{date}.npy')
# final_stat_df = pd.read_csv(f'{output_dir}final_stat_df_{date}.csv')

with open(f'{output_dir}model_{date}.pkl','rb') as f:
    xgbse_model = pickle.load(f) 

In [None]:
# final_stat_df = final_stat_df.drop(['start_date','end_date','final_selection','label','diff','criteria_1','progression_count'],axis=1)
prediction_df = 1-pd.DataFrame(xgbse_model.predict(x))

In [None]:
prediction_df.shape

In [None]:
result_df = pd.concat([final_stat_df,prediction_df],axis=1)
result_df.shape

In [None]:
# Convert the Polars criteria table to pandas only if it has not been converted
# yet, so re-running this cell does not fail with AttributeError on a pandas frame.
if isinstance(criteria_df, pl.DataFrame):
    criteria_df = criteria_df.to_pandas()

# NOTE(review): the merge key is the patient id alone. If a patient has several
# start_date rows this becomes a many-to-many merge, and the overlapping columns
# get _x/_y suffixes - confirm that one row per patient is guaranteed here.
# NOTE(review): result_df is overwritten, so re-running this cell merges again.
result_df = criteria_df.merge(result_df,on=['chai_patient_id'],how = 'inner')
result_df.head()

In [None]:
# Persist the merged predictions under the inference output directory.
# NOTE(review): the '17_06_2023_index' part of the file name is hard-coded -
# confirm it still describes the index date used for this run.
result_df.to_csv(f'{output_dir_infer}Inference_prediction_17_06_2023_index_{date}.csv',index=False)

In [None]:
# Flush and close the inference log (output_infer.txt). Once closed, the handle
# rejects further writes; any later logging needs the file reopened in append mode.
text_file.close()

In [None]:
from sklearn.metrics import PrecisionRecallDisplay,ConfusionMatrixDisplay,confusion_matrix,roc_auc_score
import matplotlib.pyplot as plt

# NOTE(review): `event_prob_df` and `x_test` are not created in the visible part
# of this notebook (their loaders are commented out in the model-loading cell);
# they must exist in the kernel for this cell to run - confirm their origin.

# The log handle may already have been closed by an earlier cell; reopen it in
# append mode so the writes below do not raise "I/O operation on closed file".
if text_file.closed:
    text_file = open(output_dir_infer+'output_infer.txt','a')

text_file.write('GT Label counts at different prediction windows \n')
combined_test_info_df = event_prob_df.join(x_test.reset_index()).merge(final_stat_df.reset_index()[['index','chai_patient_id','label','final_selection']],on='index',how='inner')
# The loop variable stays named `i`: a later metrics cell reads it.
for i in [30,60,90,120,150,180]:
    window_label_col = f'final_selection_{i}'
    # An event only counts as positive for window i when label <= i; rows with
    # label > i are forced to False (vectorised form of the former row-wise apply).
    combined_test_info_df[window_label_col] = combined_test_info_df['final_selection'].mask(
        combined_test_info_df['label'] > i, False
    )
    window_label_counts = combined_test_info_df[window_label_col].value_counts()
    print(window_label_counts)
    print(f"ROC AUC Score : {roc_auc_score(combined_test_info_df[window_label_col], combined_test_info_df[i])}")
    text_file.write(f'Label counts for window : {i} \n')
    text_file.write(f"{window_label_counts} \n\n")
    PrecisionRecallDisplay.from_predictions(combined_test_info_df[window_label_col], combined_test_info_df[i], name="PA_model", plot_chance_level=True)
    plt.title(f'window : {i}')
    plt.savefig(f'{output_dir}/PRC_Window_{i}.png')
    # plt.plot() with no arguments draws nothing; show the figure instead.
    plt.show()

text_file.close()

In [None]:
# text_file.close()

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay,confusion_matrix
import matplotlib.pyplot as plt

# One confusion-matrix figure per prediction window at a single fixed threshold.
thresold = 0.1
pred_col = f'pred_{thresold}'
for window in (30,60,90,120,150,180):
    # Positive prediction when the window's probability exceeds the threshold.
    combined_test_info_df[pred_col] = combined_test_info_df[window] > thresold
    cm = confusion_matrix(
        combined_test_info_df[f'final_selection_{window}'],
        combined_test_info_df[pred_col],
    )
    disp = ConfusionMatrixDisplay(cm)
    disp.plot()
    plt.title(f'Thresold : {thresold}, Window : {window}')
    plt.savefig(f'{output_dir}/CF_thresold_{thresold}_Window_{window}.png')

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay,confusion_matrix

# Build one row per prediction window holding, for every threshold, the
# confusion-matrix counts and precision / recall / F-score of both classes.
metric_list = []
for window in [30,60,90,120,150,180]:
    window_metrics = {'window': window}
    for thresold in [0.2,0.3,0.4,0.5,0.6,0.7,0.8]:
        combined_test_info_df[f'pred_{thresold}'] = combined_test_info_df[window] > thresold
        # Ground truth must be the column of THIS window. The previous code
        # indexed it with a stale loop variable left over from another cell,
        # so every window was scored against the same labels.
        # labels=[False, True] guarantees a 2x2 matrix even if a class is absent.
        cm = confusion_matrix(
            combined_test_info_df[f'final_selection_{window}'],
            combined_test_info_df[f'pred_{thresold}'],
            labels=[False, True],
        )
        tn,fp,fn,tp = cm.ravel()
        # Counts are numpy integers, so a zero denominator yields nan, not an exception.
        p1 = tp/(tp+fp)
        r1 = tp/(tp+fn)
        f1 = 2*p1*r1/(p1+r1)
        p0 = tn/(tn+fn)
        r0 = tn/(tn+fp)
        f0 = 2*p0*r0/(p0+r0)
        window_metrics.update({
            f'tn_{thresold}' : tn,
            f'fp_{thresold}' : fp,
            f'fn_{thresold}' : fn,
            f'tp_{thresold}' : tp,
            f'p1_{thresold}' : p1,
            f'r1_{thresold}' : r1,
            f'f1_{thresold}' : f1,
            f'p0_{thresold}' : p0,
            f'r0_{thresold}' : r0,
            f'f0_{thresold}' : f0,
        })
    metric_list.append(window_metrics)
# One DataFrame built from all records at once instead of concatenating in a loop.
metric_df = pd.DataFrame(metric_list)

In [None]:
# Save the per-window metric table.
# NOTE(review): this goes to data_backup/ directly rather than to output_dir
# like the other artefacts - confirm that is intended.
metric_df.to_csv(f'data_backup/metric_df_{date}.csv',index=False)

In [None]:
# Histogram of event_prob_df's values for each window column, saved and shown.
for i in (30,60,90,120,150,180):
    print(i)
    hist_ax = event_prob_df[i].hist(bins=100)
    hist_ax.figure.savefig(f'{output_dir}predicted_prob_distribution_{i}.png')
    plt.show()