In [None]:
import polars as pl
import pandas as pd
import datetime
import random
import copy
import numpy as np
import time
import traceback
from xgbse import (
    XGBSEDebiasedBCE,
)
from xgbse.metrics import concordance_index
from xgbse.converters import (
    convert_data_to_xgb_format,
    convert_to_structured
)

from ai_knowledge.ai_knowledge.base.convert import *
import yaml
from sklearn.model_selection import train_test_split

from PA_GE_test.rule_execution import rules_check
from sklearn.inspection import PartialDependenceDisplay

today_date = datetime.date.today()

In [None]:
with open('./ai_knowledge/ai_knowledge/base/keys_labs.yaml','r') as file:
    lab_kb = yaml.safe_load(file)
    
lab_kb = {c['key']:c for c in lab_kb}

with open('./ai_knowledge/ai_knowledge/base/units.yaml','r') as file:
    unite_conv_map = yaml.safe_load(file)
    
unite_conv_map = {c['base']:c for c in unite_conv_map}

In [None]:
def get_stats(feature_name, values):
    """
    Summarise one feature as a Polars DataFrame.

    ARGUMENTS:
    ----------
    feature_name: label stored in the 'feature' column
    values: series-like object exposing min/max/mean/median/mode/std methods

    RETURNS:
    --------
    Polars.DataFrame with columns feature, min, max, mean, median, mode, std

    NOTE(review): `values.mode()` is passed through unchanged; if it yields
    more than one value the frame will presumably have one row per mode -
    confirm this is intended.
    """
    summary = {
        'feature': feature_name,
        'min': values.min(),
        'max': values.max(),
        'mean': values.mean(),
        'median': values.median(),
        'mode': values.mode(),
        'std': values.std(),
    }
    return pl.DataFrame(summary)
    
stat_df_list = []

In [None]:
import warnings


def warn(*args, **kwargs):
    """No-op replacement for `warnings.warn`: accepts anything, emits nothing."""
    return None


# Globally silence every warning raised through `warnings.warn` for the rest of the session.
setattr(warnings, 'warn', warn)

In [None]:
start = time.time()

In [None]:
#DB Connection Details
# SECURITY: credentials must never be stored in the notebook. They are read from
# environment variables, falling back to an interactive prompt. The password that
# was previously committed here should be considered leaked and rotated.
import getpass
import os

host = os.environ.get("REDSHIFT_HOST", "c0-eurekarwd-web.crdektqhzze9.us-east-1.redshift.amazonaws.com")
port = os.environ.get("REDSHIFT_PORT", "5439")
dbname = os.environ.get("REDSHIFT_DBNAME", "c0_eurekarwd")
user = os.environ.get("REDSHIFT_USER") or input("Redshift user: ")
password = os.environ.get("REDSHIFT_PASSWORD") or getpass.getpass("Redshift password: ")

In [None]:
schema = 'dts_cdm_master'
# schema = 'qcca_cdm_master'

In [None]:
from urllib.parse import quote

# Percent-encode the whole user-info part of the URI. The previous code escaped only '@',
# so a password containing ':', '/', '#', '?' or '%' produced an invalid connection string.
db_conn_uri = "redshift://{user}:{pwd}@{server}:{port}/{db}".format(
    user=quote(user, safe=''),
    pwd=quote(password, safe=''),
    server=host,
    port=port,
    db=dbname,
)

# Multiple-myeloma diagnoses (by name or by C90.0* code) recorded after 2019.
chai_patid_query = "select chai_patient_id,diagnosis_date,diagnosis_code_standard_code,diagnosis_code_standard_name from {schema}.condition where (lower(diagnosis_code_standard_name) like '%multiple%myeloma%' or lower(diagnosis_code_standard_code) like 'C90.0%') and extract(year from diagnosis_date) > 2019".format(schema=schema)
pat_id_df = pl.read_database(chai_patid_query,db_conn_uri)

In [None]:
pat_id_df.select(['chai_patient_id']).unique().shape

In [None]:
# Distinct patient ids rendered as a single-quoted, comma-separated SQL IN-list.
pat_ids = ','.join(
    "'" + str(patient_id) + "'"
    for patient_id in pat_id_df.select(['chai_patient_id']).unique().to_series().to_list()
)

In [None]:
# Feature key -> list of `test_name_standard_code` values pooled under that key.
# The keys are used below to tag rows (column 'key') and to filter lab queries.
# NOTE(review): the codes look like LOINC identifiers - confirm against the source vocabulary.
lab_dict = {'m_protein_in_serum':['33358-3','51435-6','35559-4','94400-9','33647-9','50796-2','56766-9','44932-2','50792-1'],
            'm_protein_in_urine':['42482-0','40661-1','35560-2'],
            'ca':['17861-6','49765-1'],
            'serum_free_light':['36916-5','33944-0','11051-0','11050-2'],
            'hemoglobin_in_blood':['718-7','20509-6','30313-1','48725-6'],
            'neutrophils_count':['751-8','26499-4','768-2','30451-9','753-4'],
            'lymphocytes_count':['26474-7','732-8','731-0'],
            'platelets':['777-3','26515-7','53800-9','49497-1','778-1'],
            'na':['2951-2','2955-3'],
            'mg':['21377-7','19123-9'],
            'cl':['2075-0'],
            'phos' : ['2777-1'],
            'hr' : ['8867-4'],
            'dbp' : ['8462-4'],
            'ecog' : ['89262-0'],
            'k' : ['2823-3','2828-2']
           }

In [None]:
# Serum M-protein results after 2019 for the selected patients.
# The code filter is derived from lab_dict['m_protein_in_serum'] instead of a hand-copied
# list (which had drifted: '33647-9' was listed twice), so the query and the key
# tagging below always use the same codes.
pat_serum_test_query = (
    "select chai_patient_id,test_date,test_name_standard_code,test_name_standard_name,test_value_numeric,test_unit_standard_name "
    "from {schema}.patient_test "
    "where (chai_patient_id in ({pat_ids})) "
    "and (test_name_standard_code in ({codes})) "
    "and (extract(year from test_date) > 2019)"
).format(
    pat_ids=pat_ids,
    schema=schema,
    codes=','.join("'" + code + "'" for code in lab_dict['m_protein_in_serum']),
)
pat_serum_test_df = pl.read_database(pat_serum_test_query,db_conn_uri)

In [None]:
# Urine M-protein results after 2019 for the selected patients.
# The code filter comes from lab_dict['m_protein_in_urine'] so it cannot drift from
# the codes used for key tagging.
pat_urine_test_query = (
    "select chai_patient_id,test_date,test_name_standard_code,test_name_standard_name,test_value_numeric,test_unit_standard_name "
    "from {schema}.patient_test "
    "where (chai_patient_id in ({pat_ids})) "
    "and (test_name_standard_code in ({codes})) "
    "and (extract(year from test_date) > 2019)"
).format(
    pat_ids=pat_ids,
    schema=schema,
    codes=','.join("'" + code + "'" for code in lab_dict['m_protein_in_urine']),
)
pat_urine_test_df = pl.read_database(pat_urine_test_query,db_conn_uri)

In [None]:
# Keep only rows that have both a value and a unit, then lower-case the unit.
# `any_horizontal` around a single boolean expression was a no-op, and the native
# string expression avoids calling a Python lambda for every row.
pat_serum_test_df_cast = pat_serum_test_df.filter(
    pl.col('test_value_numeric').is_not_null() & pl.col('test_unit_standard_name').is_not_null()
).with_columns(pl.col('test_unit_standard_name').str.to_lowercase())

In [None]:
# Keep only rows that have both a value and a unit, then lower-case the unit.
# `any_horizontal` around a single boolean expression was a no-op, and the native
# string expression avoids calling a Python lambda for every row.
pat_urine_test_df_cast = pat_urine_test_df.filter(
    pl.col('test_value_numeric').is_not_null() & pl.col('test_unit_standard_name').is_not_null()
).with_columns(pl.col('test_unit_standard_name').str.to_lowercase())

In [None]:
pat_urine_test_df_cast.shape

In [None]:
# Tag each row with its feature key (null when the code is not in the list).
# The constant is wrapped in pl.lit(): newer Polars releases parse a bare string
# passed to then() as a column name, which would raise a column-not-found error.
pat_serum_test_df_cast = pat_serum_test_df_cast.with_columns(
    pl.when(pl.col('test_name_standard_code').is_in(lab_dict['m_protein_in_serum']))
    .then(pl.lit('m_protein_in_serum'))
    .alias('key')
)

pat_urine_test_df_cast = pat_urine_test_df_cast.with_columns(
    pl.when(pl.col('test_name_standard_code').is_in(lab_dict['m_protein_in_urine']))
    .then(pl.lit('m_protein_in_urine'))
    .alias('key')
)

In [None]:
print(pat_urine_test_df_cast.shape)
print(pat_serum_test_df_cast.shape)

In [None]:
def convert_unit(df):
    """
    Convert lab values to the standard unit registered for each lab key.

    For every distinct non-null value of the 'key' column, the standard unit is
    looked up in `lab_kb` and the available source-unit converters in
    `unite_conv_map`. Values whose unit has a converter are converted; rows of
    that key whose unit is neither the standard unit nor convertible are dropped.

    ARGUMENTS:
    ----------
    df: Polars.DataFrame with columns 'key', 'test_value_numeric' and
        'test_unit_standard_name'

    RETURNS:
    --------
    Polars.DataFrame with converted values and standardised unit names

    RAISES:
    -------
    KeyError: if a key present in the data is not registered in `lab_kb`
    """
    lab_test = df['key'].unique().drop_nulls()
    for key in lab_test:
        if key not in lab_kb.keys():
            raise KeyError(f'Lab test {key} not registered')
        std_unit = lab_kb[key]['attributes']['units']
        unit_keys = list(unite_conv_map[std_unit]['convert'].keys())
        # Step 1 - convert the value. The converter is named by the last dotted
        # component of the mapping entry and resolved with eval() in this
        # notebook's namespace (presumably supplied by the star import from
        # ai_knowledge...convert - verify).
        # NOTE(review): eval() on text loaded from a YAML file executes arbitrary
        # code if that file is ever untrusted; consider an explicit name -> function lookup.
        # Rows of other keys, or with a unit that has no converter, keep their value.
        df = df.with_columns(
                pl.struct(['test_value_numeric','test_unit_standard_name','key'])
                          .apply(lambda x: eval(unite_conv_map[std_unit]['convert'][x['test_unit_standard_name']].split('.')[-1])(x['test_value_numeric'])\
                                 if (x['key']==key) & (x['test_unit_standard_name'] in (unit_keys))\
                                 else x['test_value_numeric']
                                ).alias('test_value_numeric')
        )
        # Step 2 - relabel the unit: rows of this key that were already in the
        # standard unit or have just been converted get `std_unit`; remaining rows
        # of this key get a null unit; rows of other keys are left untouched.
        # The trailing drop_nulls then removes every row with a null unit - that
        # includes rows of other keys (or with a null key) whose unit was already null.
        df = df.with_columns(
                        pl.when((pl.col('key')==key) & ((pl.col('test_unit_standard_name')==std_unit) | (pl.col('test_unit_standard_name').is_in(unite_conv_map[std_unit]['convert'].keys()))))
                .then(pl.lit(std_unit))
                .when(pl.col('key')==key)
                .then(None)
                .otherwise(pl.col('test_unit_standard_name'))
                .alias('test_unit_standard_name'))\
        .drop_nulls('test_unit_standard_name')
    return df

In [None]:
# Keep the modelling columns and truncate the timestamp to a calendar date.
# A native cast replaces the per-row `x.date()` lambda, and the trailing backslash
# that continued the statement into commented-out code has been removed.
# (assumes test_date arrives as a datetime column - TODO confirm)
pat_serum_test_df_cast = pat_serum_test_df_cast.select(
    ['chai_patient_id','test_date','test_value_numeric','test_name_standard_code','test_unit_standard_name']
).with_columns(pl.col('test_date').cast(pl.Date))

In [None]:
# Keep the modelling columns and truncate the timestamp to a calendar date.
# A native cast replaces the per-row `x.date()` lambda, and the trailing backslash
# that continued the statement into commented-out code has been removed.
# (assumes test_date arrives as a datetime column - TODO confirm)
pat_urine_test_df_cast = pat_urine_test_df_cast.select(
    ['chai_patient_id','test_date','test_value_numeric','test_name_standard_code','test_unit_standard_name']
).with_columns(pl.col('test_date').cast(pl.Date))

In [None]:
print(pat_urine_test_df_cast.shape)
print(pat_serum_test_df_cast.shape)

## Calcium

In [None]:
# Calcium results after 2019 for the selected patients.
# The code filter comes from lab_dict['ca'] so it cannot drift from the codes
# used for key tagging.
pat_cal_test_query = (
    "select chai_patient_id,source_labid,test_date,test_name_standard_code,test_name_standard_name,test_value_numeric,test_unit_standard_name "
    "from {schema}.patient_test "
    "where (chai_patient_id in ({pat_ids})) "
    "and (test_name_standard_code in ({codes})) "
    "and (extract(year from test_date) > 2019)"
).format(
    pat_ids=pat_ids,
    schema=schema,
    codes=','.join("'" + code + "'" for code in lab_dict['ca']),
)
pat_cal_test_df = pl.read_database(pat_cal_test_query,db_conn_uri)

In [None]:
# Keep the modelling columns and truncate the timestamp to a calendar date.
# A native cast replaces the per-row `x.date()` lambda, and the trailing backslash
# that continued the statement into commented-out code has been removed.
# (assumes test_date arrives as a datetime column - TODO confirm)
pat_cal_test_df = pat_cal_test_df.select(
    ['chai_patient_id','test_date','test_value_numeric','test_name_standard_code','test_unit_standard_name']
).with_columns(pl.col('test_date').cast(pl.Date))

In [None]:
pat_cal_test_df.shape

## Serum Free light chain

In [None]:
# Serum free light chain results after 2019 for the selected patients.
# The code filter comes from lab_dict['serum_free_light'] so it cannot drift from
# the codes used for key tagging.
pat_free_light_query = (
    "select chai_patient_id,source_labid,test_date,test_name_standard_code,test_name_standard_name,test_value_numeric,test_unit_standard_name "
    "from {schema}.patient_test "
    "where (chai_patient_id in ({pat_ids})) "
    "and (test_name_standard_code in ({codes})) "
    "and (extract(year from test_date) > 2019)"
).format(
    pat_ids=pat_ids,
    schema=schema,
    codes=','.join("'" + code + "'" for code in lab_dict['serum_free_light']),
)
pat_free_light_df = pl.read_database(pat_free_light_query,db_conn_uri)

In [None]:
pat_free_light_df.shape

In [None]:
# Keep the modelling columns and truncate the timestamp to a calendar date.
# A native cast replaces the per-row `x.date()` lambda, and the trailing backslash
# that continued the statement into commented-out code has been removed.
# (assumes test_date arrives as a datetime column - TODO confirm)
pat_free_light_df = pat_free_light_df.select(
    ['chai_patient_id','test_date','test_value_numeric','test_name_standard_code','test_unit_standard_name']
).with_columns(pl.col('test_date').cast(pl.Date))

In [None]:
pat_free_light_df.shape

In [None]:
lab_test_df = pl.concat([pat_urine_test_df_cast,pat_serum_test_df_cast,pat_free_light_df,pat_cal_test_df])

In [None]:
print(pat_urine_test_df_cast.shape)
print(pat_serum_test_df_cast.shape)
print(pat_free_light_df.shape)
print(pat_cal_test_df.shape)

In [None]:
# Normalise the combined lab frame:
#   * 'key'  - feature key derived from the test code (null if the code is unmapped)
#   * unit   - lower-cased
#   * value  - '<' / '>' stripped, ',' read as a decimal point, then cast to float;
#              values that still fail to parse become null and are dropped.
# Constants passed to then() are wrapped in pl.lit(): newer Polars releases parse a
# bare string there as a column name, which would raise a column-not-found error.
# NOTE(review): replacing ',' with '.' treats the comma as a decimal separator;
# a thousands separator such as '1,000' would be read as 1.0 - confirm the source format.
lab_test_df = lab_test_df.with_columns(
    pl.when(pl.col('test_name_standard_code').is_in(lab_dict['m_protein_in_serum']))
    .then(pl.lit('m_protein_in_serum'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['m_protein_in_urine']))
    .then(pl.lit('m_protein_in_urine'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['ca']))
    .then(pl.lit('ca'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['serum_free_light']))
    .then(pl.lit('serum_free_light'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['hemoglobin_in_blood']))
    .then(pl.lit('hemoglobin_in_blood'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['neutrophils_count']))
    .then(pl.lit('neutrophils_count'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['lymphocytes_count']))
    .then(pl.lit('lymphocytes_count'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['platelets']))
    .then(pl.lit('platelets'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['na']))
    .then(pl.lit('na'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['mg']))
    .then(pl.lit('mg'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['cl']))
    .then(pl.lit('cl'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['phos']))
    .then(pl.lit('phos'))
    .when(pl.col('test_name_standard_code').is_in(lab_dict['k']))
    .then(pl.lit('k'))
    .alias('key'),
    pl.col('test_unit_standard_name').str.to_lowercase(),
    pl.col('test_value_numeric').str.replace('<','').str.replace('>','').str.replace(',','.')
    .cast(pl.Float64, strict=False)
).drop_nulls(subset='test_value_numeric')\
.with_columns(
    # Heart rate, diastolic BP and ECOG rows with no unit get the placeholder
    # unit 'valid' instead of a null unit.
    pl.when((pl.col('test_name_standard_code').is_in(lab_dict['hr']+lab_dict['dbp']+lab_dict['ecog'])) & (pl.col('test_unit_standard_name').is_null()))
    .then(pl.lit('valid'))
    .otherwise(pl.col('test_unit_standard_name'))
    .alias('test_unit_standard_name')
)
lab_test_df.shape

In [None]:
lab_test_df = convert_unit(lab_test_df)
lab_test_df.shape

In [None]:
lab_test_df = lab_test_df.sort(['chai_patient_id','test_date','test_value_numeric'],descending=False)\
    .unique(subset=['chai_patient_id','test_date','test_name_standard_code'],keep='first')

In [None]:
# Split the free-light-chain codes into kappa and lambda groups.
lab_dict.update({'serum_free_light_kappa':['36916-5','11050-2'],'serum_free_light_lambda':['33944-0','11051-0']})
# Integer class id -> feature name used as a model column.
lab_class_map =  {0:'m_protein_in_serum',1:'m_protein_in_urine',2:'ca',3:'serum_free_light_kappa',4:'serum_free_light_lambda'}

# Replace every test code by its integer class id (null when the code belongs to no class).
# The when/then chain is assembled from lab_class_map in insertion order.
class_code_expr = None
for class_id, lab_key in lab_class_map.items():
    matches_class = pl.col('test_name_standard_code').is_in(lab_dict[lab_key])
    if class_code_expr is None:
        class_code_expr = pl.when(matches_class).then(pl.lit(class_id))
    else:
        class_code_expr = class_code_expr.when(matches_class).then(pl.lit(class_id))

lab_test_df = lab_test_df.with_columns(class_code_expr.alias('test_name_standard_code'))

In [None]:
lab_test_df['chai_patient_id'].unique().shape

In [None]:
lab_test_df.shape

## Criteria

In [None]:
criteria_df = pl.read_csv('../Dhaval/data_backup/criteria_2023-11-10.csv')

In [None]:
criteria_df['chai_patient_id'].unique().shape

In [None]:
def download_dod(source_schema, patients_ids):
    """
    Fetch the recorded date of death for the given patients.

    Only patients with a non-null date of death whose year is 1950 or later are
    returned. The query runs against the module-level `db_conn_uri`.

    :param source_schema: Source schema name
    :param patients_ids: Patient IDs to look up
    :return: A DataFrame with patient ID and date of death
    """
    quoted_ids = ', '.join(["'" + str(pat) + "'" for pat in patients_ids])
    dod_query = '''Select chai_patient_id, date_of_death from {}.patient where chai_patient_id in ({}) 
        and (date_of_death is not null and date_part(year, date_of_death) >= 1950)'''.format(
            source_schema, quoted_ids)
    return pl.read_database(dod_query,db_conn_uri)

def download_last_activity_date_dod(source_schema, patients_ids):
    """
    Return each patient's last activity date, capped by the date of death.

    For every table listed in `date_fields` the latest non-null date per patient
    is queried. The overall latest date per patient is then combined with the
    date of death (from `download_dod`) by taking the earlier of the two.

    :param source_schema: Source schema name
    :param patients_ids: Patient IDs for whom the last activity date is required
    :return: A DataFrame with patient ID and last_activity_date
    """
    dod_df = download_dod(source_schema, patients_ids)

    # table name -> date column that counts as "activity" in that table
    date_fields = {'medication': 'med_start_date', 'patient_test': 'test_date', 'tumor_exam': 'exam_date',
                   'care_goal': 'treatment_plan_start_date', 'surgery': 'surgery_date', 'radiation': 'rad_start_date',
                   'condition': 'diagnosis_date', 'adverse_event': 'adverse_event_date', 'encounter': 'encounter_date',
                   'disease_status': 'assessment_date', 'staging': 'stage_date'}

    # The IN-list is the same for every table, so build it once.
    pats_list = ', '.join(["'" + str(pat) + "'" for pat in patients_ids])

    # Collect the per-table results and concatenate once instead of growing a frame in the loop.
    per_table_frames = []
    for key, value in date_fields.items():
        last_date_query = '''Select distinct chai_patient_id, max({column_name}) as max_date
                                                       from {source_schema}.{table_name} where {column_name} is not null and 
                                                       chai_patient_id in ({pats_list}) group by chai_patient_id '''.format(
            column_name=value,
            table_name=key,
            source_schema=source_schema,
            pats_list=pats_list)
        per_table_frames.append(pl.read_database(last_date_query,db_conn_uri))

    visits = pl.concat(per_table_frames)

    # Sort latest-first and keep the FIRST row per patient. Without keep='first'
    # unique() may keep any of the duplicates, so the latest date was not guaranteed.
    last_medical_activity = visits.sort(['chai_patient_id', 'max_date'], descending=True)\
                                        .unique(subset=['chai_patient_id'], keep='first')\
                                        .join(dod_df, on="chai_patient_id", how="left")
    # Activity recorded after death is capped at the date of death.
    last_medical_activity = last_medical_activity.with_columns(pl.min_horizontal('date_of_death','max_date')\
                                                               .alias('last_activity_date')).drop(['date_of_death', 'max_date'])

    last_medical_activity = last_medical_activity.with_columns(pl.col('last_activity_date').apply(lambda x:x.date()))
    return last_medical_activity

In [None]:
last_date_df = download_last_activity_date_dod(schema,set(list(criteria_df['chai_patient_id'])))

In [None]:
last_date_df.shape

In [None]:
criteria_df = criteria_df.join(last_date_df,on='chai_patient_id', how='left')
# min_date is read from CSV, so parse it to a date when it is still a string.
criteria_df = criteria_df.with_columns(
    pl.col('min_date').apply(lambda x:datetime.datetime.strptime(x, '%Y-%m-%d').date() if isinstance(x,str) else x)
)
# Selected patients use min_date as their end point; everyone else keeps the last activity date.
# The alias is explicit: previously the result replaced 'last_activity_date' only because
# the struct expression happened to list that column first.
criteria_df = criteria_df.with_columns(
    pl.struct(['last_activity_date','min_date','final_selection'])
    .apply(lambda x: x['min_date'] if (x['final_selection']==True) else x['last_activity_date'])
    .alias('last_activity_date')
)
# Clamp dates in the future to today; "today" is taken once so all rows share the same cut-off.
criteria_df = criteria_df.with_columns(
    pl.col('last_activity_date').apply(lambda x, today=datetime.date.today(): today if (x>today) else x)
)

In [None]:
pat_id_df_updated = pat_id_df.select(['chai_patient_id','diagnosis_date']).sort(by =['chai_patient_id','diagnosis_date'],descending=False).unique(subset=['chai_patient_id'],keep='first').drop_nulls()

In [None]:
pat_id_df_updated['chai_patient_id'].unique().shape

In [None]:
criteria_df = criteria_df.join(pat_id_df_updated,on='chai_patient_id',how='inner')
criteria_df['chai_patient_id'].unique().shape

In [None]:
criteria_df = criteria_df.with_columns(pl.col('diagnosis_date').apply(lambda x:datetime.datetime.strptime(x, '%Y-%m-%d %H:%M:%S').date() if isinstance(x,str) else x.date()))

In [None]:
criteria_df = criteria_df.with_columns(pl.struct(['last_activity_date','diagnosis_date']).apply(lambda x:(x['last_activity_date']-x['diagnosis_date']).days).alias('diff'))\
                                        .filter(pl.col('diff')>=0)

In [None]:
random.seed(123)
criteria_df = criteria_df.with_columns(pl.col('diff').apply(lambda x:random.randint(0,x)).alias('random_point'))
criteria_df = criteria_df.with_columns(pl.struct(['diff','random_point']).apply(lambda x:x['diff']-x['random_point']).alias('label'))

In [None]:
criteria_df['chai_patient_id'].unique().shape

In [None]:
lab_test_df = lab_test_df.join(criteria_df.select(['chai_patient_id','diagnosis_date','random_point']), on ='chai_patient_id',how='inner')
lab_test_df['chai_patient_id'].unique().shape

In [None]:
lab_test_df = lab_test_df.with_columns(pl.struct(['test_date','diagnosis_date']).apply(lambda x:(x['test_date']-x['diagnosis_date']).days).alias('test_diff'))\
                            .sort(by=['chai_patient_id','test_name_standard_code','test_date'],descending=False)

print(lab_test_df['test_name_standard_code'].unique().shape)
lab_test_df.shape

In [None]:
lab_test_df = lab_test_df.filter(pl.col('test_diff')<=pl.col('random_point'))
lab_test_df = lab_test_df.sort(by=['chai_patient_id','test_name_standard_code','test_date'],descending=True)\
                        .unique(subset=['chai_patient_id','test_name_standard_code'],keep='first')\
                        .select(['chai_patient_id','test_name_standard_code','test_value_numeric'])

In [None]:
# Test code -> human-readable test name.
# The literal previously listed '33647-9' and '33944-0' twice; Python keeps only the
# last value for a repeated key, so the dead first entries were removed. The resulting
# dict (keys, values and order) is unchanged.
test_code_map = {
    '33358-3': 'Protein.monoclonal [Mass/volume] in Serum or Plasma by Electrophoresis',
    '51435-6': 'Protein.monoclonal band 1 [Mass/volume] in Serum or Plasma by Electrophoresis',
    '35559-4': 'Protein.monoclonal band 2 [Mass/volume] in Serum or Plasma by Electrophoresis',
    '94400-9': 'Protein.monoclonal [Presence] in Serum or Plasma',
    '33647-9': 'Protein.monoclonal/Protein.total in Serum or Plasma by Electrophoresis',
    '50796-2': 'Protein.monoclonal band 3 [Mass/volume] in Serum or Plasma by Electrophoresis',
    '56766-9': 'protein.monoclonal band 1/protein.total in serum or plasma by electrophoresis',
    '44932-2': 'Protein.monoclonal band 2/Protein.total in Serum or Plasma by Electrophoresis',
    '50792-1': 'Protein.monoclonal band 3/Protein.total in Serum or Plasma by Electrophoresis',
    '42482-0': 'Protein.monoclonal [Mass/time] in 24 hour Urine by Electrophoresis',
    '40661-1': 'Protein.monoclonal [Mass/volume] in Urine by Electrophoresis',
    '35560-2': 'Protein.monoclonal [Mass/volume] in Urine',
    '36916-5': 'Kappa light chains.free [Mass/volume] in Serum',
    '33944-0': 'Lambda light chains.free [Mass/volume] in Serum or Plasma',
    '11051-0': 'Lambda light chains [Mass/volume] in Serum or Plasma',
    '11050-2': 'Kappa light chains [Mass/volume] in Serum or Plasma',
    '17861-6': 'Calcium [Mass/volume] in Serum or Plasma',
    '49765-1': 'Calcium [Mass/volume] in Blood',
    '8867-4': 'heart rate',
    '8462-4': 'diastolic blood pressure',
    '718-7': 'hemoglobin [mass/volume] in blood',
    '26515-7': 'platelets [#/volume] in blood',
    '89262-0': 'ecog performance status [interpretation]',
    '20509-6': 'Hemoglobin [Mass/volume] in Blood by calculation',
    '30313-1': 'Hemoglobin [Mass/volume] in Arterial blood',
    '48725-6': 'Hemoglobin [Mass/volume] in Blood --pre therapeutic phlebotomy',
    '751-8': 'Neutrophils [#/volume] in Blood by Automated count',
    '26499-4': 'Neutrophils [#/volume] in Blood',
    '768-2': 'Segmented neutrophils [#/volume] in Blood by Manual count',
    '30451-9': 'Segmented neutrophils [#/volume] in Blood',
    '753-4': 'Neutrophils [#/volume] in Blood by Manual count',
    '26474-7': 'Lymphocytes [#/volume] in Blood',
    '732-8': 'Lymphocytes [#/volume] in Blood by Manual count',
    '731-0': 'Lymphocytes [#/volume] in Blood by Automated count',
    '777-3': 'Platelets [#/volume] in Blood by Automated count',
    '53800-9': 'Platelets panel - Blood by Automated count',
    '49497-1': 'Platelets [#/volume] in Blood by Estimate',
    '778-1': 'Platelets [#/volume] in Blood by Manual count',
    '2955-3': 'sodium [moles/volume] in urine',
    '2951-2': 'sodium [moles/volume] in serum or plasma',
    '2823-3': 'potassium [moles/volume] in serum or plasma',
    '2828-2': 'potassium [moles/volume] in urine',
    '2075-0': 'chloride [moles/volume] in serum or plasma',
    '21377-7': 'magnesium [mass/volume] in blood',
    '19123-9': 'magnesium [mass/volume] in serum or plasma',
    '2777-1': 'phosphate [mass/volume] in serum or plasma',
}

In [None]:
        feat_lab_list = list(lab_class_map.keys())
final_stat_df = pl.DataFrame({'chai_patient_id': list(set(lab_test_df['chai_patient_id']))})

for lab in feat_lab_list:
    lab_name = lab_class_map[lab]
    temp_df = lab_test_df.filter(pl.col('test_name_standard_code')==lab)\
                        .select(['chai_patient_id','test_value_numeric'])\
                        .rename({'test_value_numeric':lab_name})
    if temp_df.shape[0]>0:
        final_stat_df = final_stat_df.join(temp_df,on='chai_patient_id',how='left')
    else:
        final_stat_df = final_stat_df.with_columns(pl.lit(None).alias(lab_name))

In [None]:
final_stat_df.shape

In [None]:
final_stat_df = final_stat_df.join(criteria_df[['chai_patient_id','label','final_selection']],on='chai_patient_id',how='inner')

In [None]:
final_stat_df.shape

In [None]:
def to_pandas(df):
    """
    Convert a Polars DataFrame to pandas and normalise a few column dtypes.

    ARGUMENTS:
    ----------
    df: Polars.DataFrame: frame to convert

    RETURNS:
    --------
    pandas_df: Pandas.DataFrame in which Polars Boolean columns are cast to
        'bool' and Polars Float64 / Null columns are cast to 'float'

    NOTE(review): casting to 'bool' presumably turns missing values into False -
    confirm that is acceptable for the boolean columns passed in.
    """
    converted = df.to_pandas()

    # Boolean columns may arrive as object dtype after conversion; force plain bool.
    for name in df.select(pl.col(pl.Boolean)).columns:
        converted[name] = converted[name].astype('bool')

    # Float columns, and all-null columns, are both stored as float.
    numeric_like = df.select(pl.col(pl.Float64)).columns + df.select(pl.col(pl.Null)).columns
    for name in numeric_like:
        converted[name] = converted[name].astype('float')

    return converted

final_stat_df = to_pandas(final_stat_df)

In [None]:
final_stat_df.shape

In [None]:
x = final_stat_df.drop(['chai_patient_id','final_selection','label'],axis=1)
y = convert_to_structured((final_stat_df['label']),final_stat_df['final_selection'])

In [None]:
# Unit/qualifier tokens removed from feature names, in the order they are checked.
_UNIT_TOKENS = (
    '[#/volume]',
    '[Mass/volume]',
    '[mass/volume]',
    '[Mass/time]',
    '[Enzymatic activity/volume]',
    '[Mass Ratio]',
    '[Presence]',
    '[Mass/time in 24 hour Urine by Electrophoresis_0]',
    '[moles/volume]',
    '[interpretation]',
)


def _clean_feature_name(name):
    """Return `name` without unit tokens and without any remaining '[' or ']'."""
    cleaned = name
    for token in _UNIT_TOKENS:
        if token in cleaned:
            cleaned = cleaned.replace(token, '').replace('  ', ' ')
    return cleaned.replace('[', '').replace(']', '')


# Rename every column in one pass. The previous loop kept testing the ORIGINAL column
# name after the first rename, so later renames were silent no-ops: a name with two
# tokens, or with other brackets left over, was only partially cleaned.
# (Brackets are presumably stripped because the model rejects them in feature names - TODO confirm.)
x = x.rename(columns={col: _clean_feature_name(col) for col in x.columns})

In [None]:
x_train, x_test, y_train, y_test = train_test_split(x,y, test_size=0.2,random_state=123)

In [None]:
x_train_list = []
x_test_list = []
x_train_list.append(x_train.shape)
x_test_list.append(x_test.shape)
print(x_train_list)
print(x_test_list)

In [None]:
## pre selected params for models ##

PARAMS_XGB_AFT = {
    'objective': 'survival:aft',
    'eval_metric': 'aft-nloglik',
    'aft_loss_distribution': 'normal',
    'aft_loss_distribution_scale': 1.0,
    'tree_method': 'hist',
    'learning_rate': 0.005,
    'max_depth': 16,
    'booster':'dart',
    'subsample': 0.8,
    'min_child_weight': 30,
    'colsample_bynode':0.8
}

xgbse_model = XGBSEDebiasedBCE(PARAMS_XGB_AFT)
xgbse_model.fit(x_train,y_train,time_bins=np.array([30,60,90,120]))

In [None]:
pred = []
pred.append(xgbse_model.predict(x_test))
for col in pred[0].columns:
    stat_df_list.append(get_stats(f'prob_{col}',pred[0][col]))
    
event_prob_df = 1-pred[0]
event_prob_df.to_csv(f'data_backup/all_feature_prob_{today_date}.csv')
slelected_pt_list = []
for col in event_prob_df.columns:
    selected_df = event_prob_df[event_prob_df[col]>=0.8]
    main_index = list(x_test.reset_index().iloc[list(selected_df.index)]['index'])
    slelected_pt_list.append(list(final_stat_df.iloc[main_index]['chai_patient_id']))
    print(f'{col} : {selected_df.shape}')

In [None]:
# from sklearn.metrics import PrecisionRecallDisplay,ConfusionMatrixDisplay,confusion_matrix
# import matplotlib.pyplot as plt
# # combined_test_info_df = event_prob_df.join(x_test).merge(final_stat_df[['index','chai_patient_id','label','final_selection']],on='index',how='inner')

# combined_test_info_df = event_prob_df.join(x_test.reset_index()).merge(final_stat_df.reset_index()[['index','chai_patient_id','label','final_selection']],on='index',how='inner')
# for i in [30,60,90,120]:
#     combined_test_info_df[f'final_selection_{i}'] = combined_test_info_df.apply(lambda x: False if x['label']>i else x['final_selection'],axis=1)
#     print(combined_test_info_df[f'final_selection_{i}'].value_counts())
#     PrecisionRecallDisplay.from_predictions(combined_test_info_df[f'final_selection_{i}'], combined_test_info_df[i], name="PA_model", plot_chance_level=True)
#     plt.title(f'window : {i}')
#     plt.plot()

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay,confusion_matrix
import matplotlib.pyplot as plt

# `combined_test_info_df` was only ever defined in commented-out code, so this cell
# raised NameError on a fresh kernel. It is rebuilt here: event probabilities joined
# positionally with the test features, then matched back to patient id / label /
# selection flag through the original row index.
combined_test_info_df = event_prob_df.join(x_test.reset_index()).merge(
    final_stat_df.reset_index()[['index','chai_patient_id','label','final_selection']],
    on='index',
    how='inner',
)

# Confusion matrix of the 30-day event probability against final_selection at each threshold.
for thresold in [0.3,0.2]:
    combined_test_info_df[f'pred_{thresold}'] = combined_test_info_df[30].apply(lambda x:x>thresold)
    cm = confusion_matrix(combined_test_info_df['final_selection'],combined_test_info_df[f'pred_{thresold}'])
    disp = ConfusionMatrixDisplay(cm)
    disp.plot()
    plt.title(f'Thresold : {thresold}, Window : {30}')
    plt.show()

In [None]:
train_cindex_list = []
test_cindex_list = []
train_cindex_list.append(concordance_index(y_train, xgbse_model.predict(x_train)))
test_cindex_list.append(concordance_index(y_test, xgbse_model.predict(x_test)))

In [None]:
print(f'Train c-index list : {train_cindex_list}')
print(f'Test c-index list: {test_cindex_list}')
print(f'Train c-index average : {np.mean(train_cindex_list)}')
print(f'Test c-index average: {np.mean(test_cindex_list)}')

In [None]:
omop_feature_df = pd.read_csv('../Jaina/dts_patient_availability_3/dts_patient_availability/src/lab_featuer_group_label.csv')

In [None]:
omop_feature_df.head()

In [None]:
# Patient ids of the held-out rows. x_test.index holds index LABELS, so select with .loc
# (identical to the old .iloc lookup while final_stat_df keeps its default RangeIndex).
test_pat_ids = list(final_stat_df.loc[list(x_test.index)]['chai_patient_id'])
# OMOP person_id N corresponds to chai_patient_id 'CH' + N. A vectorised isin() replaces
# the per-row lambda that scanned the whole id list for every row.
omop_feature_df_test = omop_feature_df[('CH' + omop_feature_df['person_id'].astype(str)).isin(test_pat_ids)]

In [None]:
omop_y_test = convert_to_structured((omop_feature_df_test['label']),omop_feature_df_test['final_selection'])

In [None]:
pred_omop_y_test = xgbse_model.predict(omop_feature_df_test[['m_protein_in_serum','m_protein_in_urine','ca','serum_free_light_kappa','serum_free_light_lambda']])

concordance_index(omop_y_test,pred_omop_y_test)

In [None]:
(1-pred_omop_y_test[30]).hist(bins=100)

In [None]:
evennt_prob_omop_test = 1- pred_omop_y_test
for col in evennt_prob_omop_test.columns:
    selected_df = evennt_prob_omop_test[evennt_prob_omop_test[col]>=0.8]
    print(selected_df.shape)

In [None]:
combined_omop_test_df = evennt_prob_omop_test.join(omop_feature_df_test.reset_index())

In [None]:
for thresold in [0.2,0.3]:
    combined_omop_test_df[f'pred_{thresold}'] = combined_omop_test_df[30].apply(lambda x:x>thresold)
    cm = confusion_matrix(combined_omop_test_df['final_selection'],combined_omop_test_df[f'pred_{thresold}'])
    disp = ConfusionMatrixDisplay(cm)
    disp.plot()
    plt.title(f'Thresold : {thresold}, Window : {30}')
    plt.show()

In [None]:
final_stat_df.head()

In [None]:
feature_importance_df = pd.DataFrame(xgbse_model.feature_importances_.items(),columns=['feature','importance'])
feature_importance_df = feature_importance_df.sort_values('importance',ascending=False)
# feature_importance_df.iloc[0:50]

In [None]:
feature_importance_df_list = []
feature_importance_df_list.append(feature_importance_df)

In [None]:
feature_importance_df =feature_importance_df.sort_values(by='importance',ascending=False)
combine_feature_importance_df_subset = feature_importance_df[:20]
combine_feature_importance_df_subset = combine_feature_importance_df_subset.sort_values(by='importance',ascending=True)

In [None]:
import matplotlib.pyplot as plt

In [None]:
feature_importance_df.shape

In [None]:
x.describe()

In [None]:
x.shape

In [None]:
x.describe().to_csv('data_backup/describe_notebook_11_10_2023_v2.csv',index=False)

In [None]:
final_stat_df[['chai_patient_id','final_selection']].to_csv('data_backup/final_selection_11_10_2023_v1.csv',index=False)

In [None]:
pd.DataFrame(y_train)['c2'].describe()

In [None]:
f, ax = plt.subplots(figsize=(20,10))
plt.barh(combine_feature_importance_df_subset['feature'],combine_feature_importance_df_subset['importance'])
# plt.xticks(range(len(combine_feature_importance_df_subset['feature'])), combine_feature_importance_df_subset['average'],rotation='vertical')
plt.show()

In [None]:
x_train.shape

In [None]:
# combine_prob_dist_df = pred[0].reset_index().merge(pred[1].reset_index(),on='index',how='outer')
# combine_prob_dist_df = combine_prob_dist_df.merge(pred[2].reset_index(),on='index',how='outer')

In [None]:
# combine_prob_dist_df.head()

In [None]:
# combine_prob_dist_df['avg_30'] = combine_prob_dist_df.apply(lambda x:(x['30_x']+x['30_y']+x[30])/3,axis=1)
# combine_prob_dist_df['avg_60'] = combine_prob_dist_df.apply(lambda x:(x['60_x']+x['60_y']+x[60])/3,axis=1)
# combine_prob_dist_df['avg_90'] = combine_prob_dist_df.apply(lambda x:(x['90_x']+x['90_y']+x[90])/3,axis=1)
# combine_prob_dist_df['avg_120'] = combine_prob_dist_df.apply(lambda x:(x['120_x']+x['120_y']+x[120])/3,axis=1)

In [None]:
# combine_prob_dist_df['avg_120'].hist(bins=100)

In [None]:
(1-pred[0])[120].hist(bins=100)

In [None]:
import shap

In [None]:
explainer = shap.Explainer(xgbse_model.predict, x_train)
shap_values = explainer(x_train)

In [None]:
# Beeswarm of SHAP values for the rows listed in `selected_df`, for output index 3
# (presumably the 120-day bin, the 4th of the fitted time bins - confirm).
# NOTE(review): `shap_values` is computed on x_train, but `selected_df` was last
# assigned in the OMOP threshold loop, so its index refers to rows of a different
# frame. The rows plotted here are likely not the intended ones - verify.
shap.plots.beeswarm(shap_values[list(selected_df.index),:,3],max_display=20)