In [None]:
import os
import re
import json
import gc
import dill
import pickle
import warnings
import urllib.request
from functools import reduce
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

from sklearn.model_selection import train_test_split, StratifiedKFold, RandomizedSearchCV, GridSearchCV
from sklearn.impute import SimpleImputer
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

warnings.filterwarnings('ignore')


### 1.1 Load data

In [None]:
hold_out_portion = 0.5
ratio = 10
# File-name suffix for the hold-out portion, e.g. 0.5 -> "5". Computed once so that the
# f-strings below do not nest same-type quotes (a SyntaxError before Python 3.12).
portion_tag = str(hold_out_portion).split('.')[-1]

# NOTE: pickle.load can execute arbitrary code; only load these files from a trusted source.
with open(f"PSM_results/years_psm_ratio{ratio}.pkl", 'rb') as fh:
    psm_match_years = pickle.load(fh)
with open(f"PSM_results/years_psm_ratio{ratio}_index_date.pkl", 'rb') as fh:
    psm_match_years_with_dates = pickle.load(fh)
with open(f'Middle/splits/data_splits_portion_{portion_tag}.pkl', 'rb') as fh:
    data_splits = pickle.load(fh)

hold_out_case, hold_out_control, test_case, test_control = data_splits

print('Hold_out_case:', hold_out_case.shape, hold_out_case[:2])
print('Hold_out_control:', hold_out_control.shape, hold_out_control[:2])
print('Test_case:', len(test_case), test_case[:2])
print('Test_control:', len(test_control), test_control[:2])



In [None]:
# Per-patient date tables. Columns used later: cases -> patid, INDEX_DATE, EARLIEST_DATE;
# controls -> patid, EARLIEST_DATE, Last_EHR_minus_one_INDEX_DATE.
with open('Middle/merged_cases_compare_all_df.pkl', 'rb') as fh:
    cases_ages_index = pickle.load(fh)
with open('Middle/merged_controls_dates_all_df.pkl', 'rb') as fh:
    control_ages = pickle.load(fh)

### 1.2 Load features

In [None]:
# Raw feature sources; each file handle is closed as soon as its object is loaded.
with open('./MiddleFeatures/processed_dx_enc_phe.pkl', 'rb') as fh:
    dxdata = pickle.load(fh)  # diagnoses (patid, phecode, dx_date_fill, ...)
with open('./MiddleFeatures/processed_rx_ing.pkl', 'rb') as fh:
    rxdata = pickle.load(fh)  # medications (patid, rxcui_ing, rx_start_date, ...)
with open('./MiddleFeatures/part_processed_lab_flag.pkl', 'rb') as fh:
    labdata = pickle.load(fh)  # labs (patid, specimen_date, Class, flag, ...)
with open('./MiddleFeatures/processed_vital_continuous.pkl', 'rb') as fh:
    vitaldata = pickle.load(fh)  # vitals (patid, measure_date, ht, wt, systolic, diastolic, ...)


In [None]:
def convert_codelist():
    """Collect ADRD/dementia ICD-9 and ICD-10 codes and the phecodes they map to.

    Reads ``ADRD_dx_med_codes.csv`` (columns Code, Code_type, Concept) and the PheWAS map
    ``Phecode_map_v1_2_icd9_icd10cm_09_30_2024.csv`` (columns ICD, Phecode, Flag) from the
    current working directory. Codes are selected when their Concept mentions one of the
    ADRD/dementia concept strings below.

    Returns:
        (ICD9_codes, ICD10_codes, ICD9_phecodes, ICD10_phecodes): lists. ICD codes without a
        phecode mapping produce missing values in the merge; those missing values are dropped
        so that callers can use the phecodes directly as ``str.startswith`` prefixes.
        Phecode lists are de-duplicated in first-seen order.

    Side effect: sets the global pandas options display.max_columns=None and
    display.max_rows=100.
    """
    # NOTE(review): downstream only the phecode lists are used.
    # dtype=str keeps every code a string (no int/str mixing across parse chunks).
    ADRD_codes = pd.read_csv("ADRD_dx_med_codes.csv", dtype={"Code": str})
    # This code is stored without its decimal point in the source list.
    ADRD_codes.loc[ADRD_codes["Code"] == "33111", "Code"] = "331.11"

    concept_pattern = '|'.join([
        "Alzheimer's disease", "Vascular dementia", "Frontotemporal dementia",
        "Lewy Body Dementia", "Dementia", "Conditions cause dementia",
    ])
    # na=False: a missing Concept simply does not match instead of poisoning the mask.
    is_dementia_concept = ADRD_codes['Concept'].str.contains(concept_pattern, na=False)
    ADRD_and_dementia_ICD9 = ADRD_codes[(ADRD_codes['Code_type'] == 'ICD-9') & is_dementia_concept]
    ADRD_and_dementia_ICD10 = ADRD_codes[(ADRD_codes['Code_type'] == 'ICD-10') & is_dementia_concept]
    print('ADRD_and_dementia_strings related ICD9 and ICD10 codes', ADRD_and_dementia_ICD10.shape[0] + ADRD_and_dementia_ICD9.shape[0])

    ICD_to_Phewas = pd.read_csv("Phecode_map_v1_2_icd9_icd10cm_09_30_2024.csv", dtype={'ICD': str, 'Phecode': str})
    ICD9_to_Phewas = ICD_to_Phewas.loc[ICD_to_Phewas['Flag'] == 9, ["ICD", "Phecode"]]
    ICD10_to_Phewas = ICD_to_Phewas.loc[ICD_to_Phewas['Flag'] == 10, ["ICD", "Phecode"]]

    # Left merge: codes without a phecode mapping get NaN in ICD/Phecode and are dropped below.
    icd9_phe = ADRD_and_dementia_ICD9.merge(ICD9_to_Phewas, left_on='Code', right_on='ICD', how='left')
    icd10_phe = ADRD_and_dementia_ICD10.merge(ICD10_to_Phewas, left_on='Code', right_on='ICD', how='left')

    ADRD_and_dementia_ICD9_phecodes = icd9_phe['Phecode'].dropna().unique().tolist()
    ADRD_and_dementia_ICD10_phecodes = icd10_phe['Phecode'].dropna().unique().tolist()

    ICD9_codes = icd9_phe['ICD'].dropna().tolist()
    ICD10_codes = icd10_phe['ICD'].dropna().tolist()

    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', 100)

    return ICD9_codes, ICD10_codes, ADRD_and_dementia_ICD9_phecodes, ADRD_and_dementia_ICD10_phecodes

# ICD lists are kept for reference; the phecode lists drive the dx filter in the next cell.
code9, code10, phelist9, phelist10 = convert_codelist()
_adrd_phecodes = set(phelist9 + phelist10)
print(len(_adrd_phecodes), _adrd_phecodes)

ADRD_and_dementia_strings related ICD9 and ICD10 codes 118
26 {'433.5', '290.3', '401.3', '433', '295.3', '290.2', '349', '290.1', '348', '331.1', '290.16', '317.1', '332', '291.4', '292.2', '433.2', '290.12', '433.32', '433.12', '316', '290.13', '290.11', '433.3', '290', '331.9', '331'}


In [None]:
# ADRD/dementia phecodes; each is used as a prefix, so sub-codes (e.g. 290 -> 290.1x) match too.
# Missing values are skipped because str.startswith only accepts string prefixes.
phecode_9_10 = {code for code in phelist9 + phelist10 if pd.notna(code)}

# na=False: a row with a missing phecode is not an ADRD row. Without it the mask becomes
# object-typed with NaN and the `~` below fails.
is_adrd_dx = dxdata['phecode'].str.startswith(tuple(phecode_9_10), na=False)
dxdata_noadrd = dxdata[~is_adrd_dx]
print('--Before :', dxdata.shape)
print('--Drop ADRD diagnosis phecodes from dxdata:', dxdata_noadrd.shape)
print('--Drop ratio:', 1 - dxdata_noadrd.shape[0] / dxdata.shape[0])
del dxdata
gc.collect()

In [None]:
ADRD_dx_med_codes = pd.read_csv("./ADRD_dx_med_codes.csv")

# Ingredient-level RXCUIs of anti-dementia drugs; prescriptions of these are removed so
# they cannot leak the outcome into the features.
is_anti_dementia_rx = (ADRD_dx_med_codes['Concept'] == "Anti-dementia medications") & (ADRD_dx_med_codes['Code_type'] == 'RXCUI')
ANTI_DEMENTIA_RXCUI = ADRD_dx_med_codes.loc[is_anti_dementia_rx, 'Code'].reset_index(drop=True)
print('Before', rxdata.shape)

ANTI_DEMENTIA_RXCUI_list = ANTI_DEMENTIA_RXCUI.unique().tolist()
ANTI_DEMENTIA_RXCUI_codes = [int(i) for i in ANTI_DEMENTIA_RXCUI_list]

# Values that cannot be parsed become NaN and would make astype(int) raise,
# so drop them (and report how many) before casting.
rxdata['rxcui_ing'] = pd.to_numeric(rxdata['rxcui_ing'], errors='coerce')
n_unparsed_rxcui = int(rxdata['rxcui_ing'].isna().sum())
if n_unparsed_rxcui:
    print('--Drop rows with non-numeric rxcui_ing:', n_unparsed_rxcui)
    rxdata = rxdata.dropna(subset=['rxcui_ing'])
rxdata['rxcui_ing'] = rxdata['rxcui_ing'].astype(int)

rxdata_noadrd = rxdata[~rxdata['rxcui_ing'].isin(ANTI_DEMENTIA_RXCUI_codes)]

print('--Drop Anti-dementia medications from rxdata ingredient:', rxdata_noadrd.shape)
print('--Drop ratio:', 1 - rxdata_noadrd.shape[0] / rxdata.shape[0])
del rxdata
gc.collect()

In [None]:
print('--all noduplicate lab', labdata.shape)

# Lab classes excluded from the feature set.
_uncommon_lab_classes = ['free t3, serum', 'urine urea nitrogen']
_keep_lab_rows = ~labdata['Class'].isin(_uncommon_lab_classes)
labdata_common = labdata.loc[_keep_lab_rows]

print('--Removing uncommon lab', labdata_common.shape)
print('--Drop ratio:', 1 - labdata_common.shape[0] / labdata.shape[0])

del labdata
gc.collect()

In [None]:
# Keep vital-sign rows with at least one positive measurement; missing (NaN) and
# non-positive values do not count. numpy/pandas are already imported in the first cell.
VITAL_COLUMNS = ['ht', 'wt', 'diastolic', 'systolic']
vitaldata = vitaldata[vitaldata[VITAL_COLUMNS].gt(0).any(axis=1)]

In [None]:
def get_record_span(df):
    """Return the span, in years (days / 365), between the earliest and latest ``record_date``.

    Dates are compared at calendar-day resolution (time of day ignored).
    The input frame is left unmodified.
    """
    record_days = pd.to_datetime(df["record_date"]).dt.date
    return (record_days.max() - record_days.min()).days / 365


def get_early_date(df, col='record_date'):
    """Return the earliest calendar date (a ``datetime.date``) found in ``df[col]``.

    The column is parsed with ``pd.to_datetime``; the input frame is left unmodified.
    """
    return pd.to_datetime(df[col]).dt.date.min()


def process_features(input_cases, input_controls, years, input_dxdata, input_rxdata, input_labdata, input_vitaldata, common):
    """Build a 0/1 feature matrix and case/control labels for one prediction window.

    Only records dated on or before ``INDEX_DATE - offset`` are used. The offset is
    ``years[0]`` years, or 1 day when ``years[0] == 0``. A patient is kept only if the earliest
    such record is at least ``years[0] + 1`` years before INDEX_DATE ("early" patients).

    Args:
        input_cases, input_controls: DataFrames with at least ``patid`` and ``INDEX_DATE``.
        years: sequence whose first element is the prediction window in years.
        input_dxdata: diagnoses with ``patid``, ``phecode``, ``dx_date_fill``.
        input_rxdata: medications with ``patid``, ``rxcui_ing``, ``rx_start_date``.
        input_labdata: labs with ``patid``, ``specimen_date``, ``Class``, ``flag``.
        input_vitaldata: vitals with ``patid``, ``measure_date``, ``ht``, ``wt``,
            ``systolic``, ``diastolic``.
        common: if True, drop dx/rx/lab codes seen in fewer than max(0.1% of patients, 2)
            distinct patients.

    Returns:
        (X, y, early_ids):
            X has one row per early patient. Columns are ``patid``, indicator columns
            suffixed ``_dx`` / ``_rx`` / ``_measurement``, and ``age_at_prediction_window``.
            y is a 0/1 numpy array aligned with X's rows (1 = case).
            early_ids lists the early patients.

    Raises:
        TypeError: if input_cases or input_controls is not a DataFrame.
    """

    def _prediction_offset(prediction_window):
        """Gap between the last usable record and INDEX_DATE (1 day for a 0-year window)."""
        return pd.DateOffset(years=prediction_window) if prediction_window > 0 else pd.DateOffset(days=1)

    def _index_date_map(cases, controls):
        """Map patid -> INDEX_DATE; a control entry overrides a case entry with the same patid."""
        if not (isinstance(cases, pd.DataFrame) and isinstance(controls, pd.DataFrame)):
            raise TypeError("cases and controls must be DataFrames with 'patid' and 'INDEX_DATE' columns")
        # dict(zip(...)) rather than set_index().to_dict(): avoids converting every other column.
        # For duplicated patids the last row wins, as with to_dict().
        case_map = dict(zip(cases["patid"], cases["INDEX_DATE"]))
        control_map = dict(zip(controls["patid"], controls["INDEX_DATE"]))
        return case_map | control_map

    def _min_occurrence(all_patient_ids, label):
        """Minimum distinct-patient count for a code to be 'common': 0.1% of the cohort, at least 2."""
        min_limit = int(len(set(all_patient_ids)) * 0.001)
        if min_limit < 2:
            min_limit = 2
            print(f'\t--{label} set limit to 2')
        return min_limit

    def _before_cutoff(frame, date_col, person_id_to_index_date_map, date_offset):
        """Keep rows dated on/before the patient's INDEX_DATE - date_offset (index reset)."""
        cutoff = frame["patid"].map(person_id_to_index_date_map) - date_offset
        return frame[frame[date_col] <= cutoff].reset_index(drop=True)

    def _keep_common_codes(frame, code_col, all_patient_ids, label):
        """Drop codes seen in too few distinct patients; return (rows, kept codes, limit)."""
        min_limit = _min_occurrence(all_patient_ids, label)
        n_patients = frame.groupby(code_col)['patid'].nunique()
        keep = n_patients[n_patients >= min_limit].index.to_list()
        return frame[frame[code_col].isin(keep)], keep, min_limit

    def _process_dx(diagnosis, all_patient_ids, person_id_to_index_date_map, date_offset, common):
        """Restrict diagnoses to the cohort and window; return (first dx date per patient, rows)."""
        dx_tmp = diagnosis[diagnosis["patid"].isin(all_patient_ids)].reset_index(drop=True)
        print("\n#### Processing the diganosis data ...", dx_tmp.shape)

        dx_tmp = _before_cutoff(dx_tmp, "dx_date_fill", person_id_to_index_date_map, date_offset)
        print('\t--Dropped EHR after offset, now rows and unique people: ', dx_tmp.shape, dx_tmp['patid'].nunique())
        if common:
            dx_tmp, phe_keep, min_limit = _keep_common_codes(dx_tmp, 'phecode', all_patient_ids, 'DXdata')
            print('\t-- Only keep {} common phecode with occurrence >='.format(len(phe_keep)), min_limit, ', remains rows:', dx_tmp.shape)
        print('\tGet first date...')
        dx_first = dx_tmp.groupby("patid")["dx_date_fill"].min().reset_index(name="early")
        return dx_first, dx_tmp

    def _process_rx(medication, all_patient_ids, person_id_to_index_date_map, date_offset, common):
        """Restrict medications to the cohort and window; return (first rx date per patient, rows)."""
        rx_tmp = medication[medication["patid"].isin(all_patient_ids)].reset_index(drop=True)
        print("\n#### Processing the medication data ...", rx_tmp.shape)

        rx_tmp = _before_cutoff(rx_tmp, "rx_start_date", person_id_to_index_date_map, date_offset)
        print('\t--Dropped EHR after offset, now ', rx_tmp.shape, rx_tmp['patid'].nunique())
        if common:
            rx_tmp, rxc_keep, min_limit = _keep_common_codes(rx_tmp, 'rxcui_ing', all_patient_ids, 'RXdata')
            print('\t-- Only keep {} common rxcui_ing with occurrence >='.format(len(rxc_keep)), min_limit, ', remains rows:', rx_tmp.shape)
        print('\tGet first date...')
        rx_first = rx_tmp.groupby("patid")["rx_start_date"].min().reset_index(name="early")
        return rx_first, rx_tmp

    def _process_ms_common(measurement, all_patient_ids, person_id_to_index_date_map, date_offset, common):
        """Return the most recent record per (patid, lab Class) within the window."""
        measurement_tmp = measurement[measurement["patid"].isin(all_patient_ids)].reset_index(drop=True)
        print("\n#### Processing the measurement data ...", measurement_tmp.shape)

        measurement_tmp = _before_cutoff(measurement_tmp, "specimen_date", person_id_to_index_date_map, date_offset)
        print('\t--Dropped EHR after offset, now ', measurement_tmp.shape, measurement_tmp['patid'].nunique())
        if common:
            # Same minimum-occurrence rule (clamped to 2) as diagnoses and medications.
            measurement_tmp, lab_keep, min_limit = _keep_common_codes(measurement_tmp, 'Class', all_patient_ids, 'LABdata')
            print('\t-- Only keep {} common lab classes with occurrence >='.format(len(lab_keep)), min_limit, ', remains rows:', measurement_tmp.shape)

        print('\tGetting recent measurements......')
        measurement_tmp = measurement_tmp.sort_values(by=["patid", "Class", "specimen_date"], ascending=[True, True, False])
        # NOTE(review): GroupBy.first() takes the first non-null value per column. If the latest
        # record's flag is missing, the flag can come from an older record than specimen_date.
        # Confirm this is intended.
        measurement_tmp = measurement_tmp.groupby(["patid", "Class"]).first().reset_index()
        print('\t-- Recent measurements', measurement_tmp.shape)
        return measurement_tmp[['patid', 'specimen_date', 'Class', 'flag']]

    def _flag(values, low, high):
        """'ablow' below low, 'abhigh' above high, otherwise 'nor'; a None bound is not checked."""
        flags = pd.Series('nor', index=values.index, dtype=object)
        if low is not None:
            flags[values < low] = 'ablow'
        if high is not None:
            flags[values > high] = 'abhigh'
        return flags

    def _process_vital_common(vitals, all_patient_ids, person_id_to_index_date_map, date_offset, common):
        """Turn each patient's latest vitals into flagged bmi and sbp/dbp records.

        Returns (bmi_records, bp_records), each with columns patid, specimen_date, Class, flag.
        """
        # Normal ranges: systolic <= 120, diastolic <= 80, 18.5 <= BMI <= 25.
        range_dict = {
            'sbp': {'low': None, 'high': 120},
            'dbp': {'low': None, 'high': 80},
            'bmi': {'low': 18.5, 'high': 25},
        }

        vitals = vitals[vitals["patid"].isin(all_patient_ids)].reset_index(drop=True)
        print("\n#### Processing the vital data ...", vitals.shape)

        vitals = _before_cutoff(vitals, "measure_date", person_id_to_index_date_map, date_offset)
        print('\t--Dropped EHR after offset, now ', vitals.shape, vitals['patid'].nunique())

        # Latest non-missing value of each vital per patient.
        latest = {}
        for col in ['ht', 'wt', 'systolic', 'diastolic']:
            vital_col = vitals.loc[vitals[col].notna(), ['patid', 'measure_date', col]]
            vital_col = vital_col.sort_values(by=["patid", "measure_date"], ascending=[True, False])
            vital_col = vital_col.groupby(["patid"]).first().reset_index()
            print('----keep last date', col, vital_col.shape)
            latest[col] = vital_col

        print('--computing bmi')
        # BMI from the latest height and weight; the record is dated by the weight measurement.
        bmi = latest['ht'][['patid', 'ht']].merge(latest['wt'][['patid', 'measure_date', 'wt']], on='patid', how='inner')
        bmi = bmi[(bmi['ht'] > 0) & (bmi['wt'] > 0)]
        # Heuristic units (TODO confirm): ht is treated as inches (*2.54/100 -> metres);
        # wt > 77 is treated as pounds (*0.4536 -> kg), otherwise as kg.
        wt_kg = bmi['wt'].where(bmi['wt'] <= 77, bmi['wt'] * 0.4536)
        bmi = bmi.assign(value=wt_kg / ((bmi['ht'] * 2.54 / 100) ** 2), Class='bmi')
        bmi['flag'] = _flag(bmi['value'], range_dict['bmi']['low'], range_dict['bmi']['high'])
        vitals_bmi_format = bmi[['patid', 'measure_date', 'Class', 'flag']].rename(columns={'measure_date': 'specimen_date'})
        print('-- build bmi df', vitals_bmi_format.shape)

        bp_frames = []
        for col, cls in [('systolic', 'sbp'), ('diastolic', 'dbp')]:
            bp = latest[col].rename(columns={col: 'value_as_number'})
            bp['flag'] = _flag(bp['value_as_number'], range_dict[cls]['low'], range_dict[cls]['high'])
            bp['Class'] = cls
            bp_frames.append(bp)
        vitals_bp = pd.concat(bp_frames, axis=0)
        print('-- build bp df', vitals_bp.shape)

        vitals_bp_format = vitals_bp[['patid', 'measure_date', 'Class', 'flag']].rename(columns={'measure_date': 'specimen_date'})
        return vitals_bmi_format, vitals_bp_format

    def _process_vital_ms_common(measurement, vitals, all_patient_ids, person_id_to_index_date_map, date_offset, common):
        """Combine lab and vital records; return (first record day per patient, combined rows)."""
        mea_tmp = _process_ms_common(measurement, all_patient_ids, person_id_to_index_date_map, date_offset, common)
        vital_bmi, vital_bp = _process_vital_common(vitals, all_patient_ids, person_id_to_index_date_map, date_offset, common)

        all_tmp = pd.concat([mea_tmp, vital_bmi, vital_bp], axis=0, ignore_index=True)

        print('\tGet first date...')
        # Only the retained latest-per-class lab rows and the vital rows contribute here.
        # Time of day is dropped (calendar day), computed in one vectorized pass.
        record_day = pd.to_datetime(all_tmp["specimen_date"]).dt.normalize()
        meas_first = record_day.groupby(all_tmp["patid"]).min().reset_index(name="early")
        return meas_first, all_tmp

    def _construct_feature_matrix4(cases, controls, diagnosis, medication, measurement=None, vital_data=None, demographics=None, prediction_window=None, CP_num=None, further_lblist=None, common=False):
        """Filter all sources to the window and find the early patients.

        Returns (early_ids, dx rows, rx rows, lab+vital rows).
        """
        print(f"----****Constructing the feature matrix for computable phenotype {CP_num} for a {prediction_window}-year prediction window ...")

        person_id_to_index_date_map = _index_date_map(cases, controls)
        all_patient_ids = set(cases['patid'].tolist() + controls['patid'].tolist())
        print('\tall patients: ', len(all_patient_ids))

        date_offset = _prediction_offset(prediction_window)
        date_offset_early = pd.DateOffset(years=prediction_window + 1)

        print(f"\nThere are a total of {cases.shape[0] + controls.shape[0]} samples, with {cases.shape[0]} cases and {controls.shape[0]} controls.")
        print('\t-- define date offset', date_offset)
        print('\t-- define date offset_early', date_offset_early, '\n')

        dx_first, dx_tmp = _process_dx(diagnosis, all_patient_ids, person_id_to_index_date_map, date_offset, common)
        rx_first, rx_tmp = _process_rx(medication, all_patient_ids, person_id_to_index_date_map, date_offset, common)
        ms_first, ms_tmp = _process_vital_ms_common(measurement, vital_data, all_patient_ids, person_id_to_index_date_map, date_offset, common)

        early_all = pd.concat([dx_first, rx_first, ms_first])
        early_all['early'] = pd.to_datetime(early_all['early'])
        early_all = early_all.groupby('patid', as_index=False)['early'].min()

        # Early patients: first usable record at least (prediction_window + 1) years before INDEX_DATE.
        early_cutoff = early_all["patid"].map(person_id_to_index_date_map) - date_offset_early
        early_ids = early_all.loc[early_all["early"] <= early_cutoff, 'patid'].unique().tolist()
        print('|| Early patients: ', len(early_ids))
        del early_all, dx_first, rx_first, ms_first
        gc.collect()
        return early_ids, dx_tmp, rx_tmp, ms_tmp

    def _indicator_matrix(records, code_col, suffix, patient_frame, short_name, long_name):
        """One row per patient in patient_frame, one 0/1 column per code named ``<code><suffix>``."""
        wide = (
            records[["patid", code_col]]
            .assign(_present=1)
            .pivot_table(index="patid", columns=code_col, values="_present", aggfunc="first")
            .reset_index()
            .rename_axis(None, axis=1)
        )
        print(f'#### Pivot to {short_name} feature dataframe', wide.shape)
        # Outer merge adds patients without records for this source; their indicators become 0.
        wide = pd.merge(wide, patient_frame, on="patid", how="outer")
        print(f'\t-- outer merge with patients with no {short_name} records', wide.shape)
        code_cols = wide.columns[1:]
        wide[code_cols] = wide[code_cols].notna().astype(int)
        print(f'\t-- {short_name}_tmp column examples', code_cols[:9])
        wide = wide.rename(columns={col: str(col) + suffix for col in code_cols})
        print(f"-- The dimension of the processed {long_name} data is {wide.shape}.\n")
        gc.collect()
        return wide

    def _pivot_feature_matrix4(early_ids, cases, controls, dx_tmp, rx_tmp, measurement_tmp=None, demographics=None, prediction_window=None, CP_num=1):
        """Pivot the filtered records of early patients into (features_df, targets)."""
        person_id_to_index_date_map = _index_date_map(cases, controls)
        all_patient_ids = set(cases['patid'].tolist() + controls['patid'].tolist())
        # Drop patients without a long enough observation window (not in early_ids).
        all_patient_ids = list(all_patient_ids & set(early_ids))
        print('\nRestrict all patients to only include Early patients: ', len(all_patient_ids))

        date_offset = _prediction_offset(prediction_window)
        print('\t-- date offset', date_offset)

        dx_tmp = dx_tmp[dx_tmp['patid'].isin(early_ids)]
        rx_tmp = rx_tmp[rx_tmp['patid'].isin(early_ids)]
        measurement_tmp = measurement_tmp[measurement_tmp['patid'].isin(early_ids)]
        print('|| After early patients selection:', dx_tmp.shape, rx_tmp.shape, measurement_tmp.shape, '\n')

        patient_frame = pd.DataFrame({"patid": all_patient_ids})
        dx_wide = _indicator_matrix(dx_tmp, "phecode", "_dx", patient_frame, "dx", "diagnosis")
        rx_wide = _indicator_matrix(rx_tmp, "rxcui_ing", "_rx", patient_frame, "rx", "medication")
        # Lab/vital features are "<Class>_<flag>" indicators, e.g. "bmi_abhigh".
        measurement_flags = measurement_tmp.assign(Class_flag=measurement_tmp['Class'] + '_' + measurement_tmp['flag'])
        ms_wide = _indicator_matrix(measurement_flags, "Class_flag", "_measurement", patient_frame, "ms", "measurement")

        print("Generating the final feature matrix ...")
        features_df = reduce(lambda left, right: pd.merge(left, right, on=['patid'], how='outer'), [dx_wide, rx_wide, ms_wide])
        del dx_wide, rx_wide, ms_wide
        gc.collect()
        print(f"The dimension of the feature matrix is {features_df.shape}.\n")

        # NOTE(review): despite its name this column holds a date (INDEX_DATE minus the
        # prediction window), not an age. Confirm whether an age was intended.
        features_df["age_at_prediction_window"] = features_df["patid"].map(person_id_to_index_date_map) - date_offset

        targets = np.where(features_df["patid"].isin(cases["patid"]), 1, 0)  # 1 = case, 0 = control
        print('\tCases: ', targets.sum(), 'Controls: ', len(targets) - targets.sum(), 'Ratio: ', (len(targets) - targets.sum()) / targets.sum())
        object_cols = [col for col in features_df.dtypes[features_df.dtypes == 'object'].index if col != "patid"]
        for col in object_cols:
            features_df[col] = pd.to_numeric(features_df[col], errors="coerce")
        gc.collect()
        print("--------------------------- Done -------------------------------------\n")
        return features_df, targets

    _early_patient_ids, _dx_tmp, _rx_tmp, _ms_tmp = _construct_feature_matrix4(
        input_cases, input_controls, input_dxdata, input_rxdata, input_labdata, input_vitaldata,
        prediction_window=years[0], CP_num=1, common=common)
    X, y = _pivot_feature_matrix4(
        _early_patient_ids, input_cases, input_controls,
        _dx_tmp, _rx_tmp, _ms_tmp, None,
        years[0], 1)
    return X, y, _early_patient_ids


def further_filter_controls(_feature, _target, _early, input_psm, years):
    """Drop the matched controls of cases that did not survive the early-patient filter.

    For each prediction window, every case in ``input_psm[f'match_year{window}']`` (its index
    holds the case patids) that is missing from ``_early[window]`` is considered removed. All
    control ids in that case's ``psm_control_*`` columns (the ``*_index`` date columns are
    skipped) are then removed from the feature frame, targets and early-id list. This keeps the
    case:control matching consistent.

    Args:
        _feature: {window: DataFrame with a 'patid' column}.
        _target: {window: 0/1 array aligned row-by-row with the feature frame}.
        _early: {window: list of early patient ids}.
        input_psm: {f'match_year{window}': PSM DataFrame indexed by case patid}.
        years: prediction windows to process (iterated in reverse order).

    Returns:
        (features, targets, early_ids) dicts keyed by window. When no case was removed the
        original objects are returned unchanged.
    """
    _feature_drop = {}
    _target_drop = {}
    _early_drop = {}

    for prediction_window in reversed(years):
        print('Dealing with prediction window: ', prediction_window)
        px = _feature[prediction_window]
        tx = _target[prediction_window]
        ex = _early[prediction_window]

        print('\tBefore the processing, total ids:', len(tx), '| cases:', tx.sum(), '| control:', len(tx) - tx.sum(), '| Control/Case ratio:', (len(tx) - tx.sum()) / tx.sum())

        psm = input_psm[f'match_year{prediction_window}']

        # Cases present in the PSM table but not among the early patients were removed upstream.
        case_to_remove = list(set(psm.index.tolist()) - set(ex))
        print('\tRemoved cases:', case_to_remove)

        if case_to_remove:
            psmcol = [c for c in psm.columns if c.startswith('psm_control_') and 'index' not in c]
            matched_controls = psm.loc[case_to_remove, psmcol].to_numpy().ravel().tolist()
            # Empty match slots may hold None or NaN; neither is a control id.
            control_to_remove = [c for c in matched_controls if pd.notna(c)]
            print('\t\tWill remove controls:', len(control_to_remove))

            # Positional boolean mask, so targets stay aligned with feature rows whatever px's index is.
            keep_rows = ~px['patid'].isin(control_to_remove).to_numpy()
            px_drop = px[keep_rows]
            tx_drop = np.asarray(tx)[keep_rows]
            control_set = set(control_to_remove)
            ex_drop = [ei for ei in ex if ei not in control_set]
            print('\tAfter the processing: total ids:', px_drop.shape, '| cases:', tx_drop.sum(), '| control:', len(tx_drop) - tx_drop.sum(), '| Control/Case ratio:', (len(tx_drop) - tx_drop.sum()) / tx_drop.sum())
        else:
            # Nothing was removed upstream for this window: pass everything through.
            px_drop = px
            tx_drop = tx
            ex_drop = ex

        _target_drop[prediction_window] = tx_drop
        _early_drop[prediction_window] = ex_drop
        _feature_drop[prediction_window] = px_drop
        print('------------\n')
    return _feature_drop, _target_drop, _early_drop
    

In [None]:
matched_f = {}
matched_t = {}
matched_e = {}

for year in reversed([0, 1, 2, 5, 10]):
    psm_match = psm_match_years_with_dates[f'match_year{year}']

    # Cases of this window (PSM index = case patid) with their rows from the case date table.
    case_id_year = psm_match.index.tolist()
    case_date_info = cases_ages_index[cases_ages_index['patid'].isin(case_id_year)]

    # Each matched-control slot k has an id column psm_control_k and a date column
    # psm_control_k_index. Count the slots from the table instead of hard-coding 10, so other
    # matching ratios work. Slots are visited in order 1..k.
    n_control_slots = sum(1 for col in psm_match.columns if col.startswith('psm_control_') and 'index' not in col)
    control_and_index_list = []
    for match_number in range(1, n_control_slots + 1):
        id_col = f'psm_control_{match_number}'
        control_and_index = psm_match.loc[:, [id_col, f'{id_col}_index']].reset_index(drop=True)
        control_and_index.columns = ['patid', 'INDEX_DATE']
        control_and_index_list.append(control_and_index)

    # Drop empty control slots (missing patid).
    control_index_info = pd.concat(control_and_index_list, axis=0).dropna(subset=['patid'])
    print('--control_date_info shape', control_index_info.shape)
    control_id_year = control_index_info.patid.unique().tolist()

    print('Case for this year:', len(case_id_year), '| Control for this year:', len(control_id_year), '| Total ids:', len(set(case_id_year + control_id_year)))

    # NOTE(review): a control matched to several cases appears once per match here, while
    # process_features keeps a single INDEX_DATE per patid. Confirm this is intended.
    f, t, e = process_features(case_date_info, control_index_info, [year], dxdata_noadrd, rxdata_noadrd, labdata_common, vitaldata, common=False)
    matched_f[year] = f
    matched_t[year] = t
    matched_e[year] = e


In [None]:
# Remove the matched controls of cases lost in the early-patient filter, for every window.
matchf_drop, matcht_drop, matche_drop = further_filter_controls(
    matched_f,
    matched_t,
    matched_e,
    psm_match_years_with_dates,
    [10, 5, 2, 1, 0],
)


In [None]:
# Suffix computed once: nesting same-type quotes inside an f-string is a SyntaxError before
# Python 3.12. Each file is closed after writing.
_portion_tag = str(hold_out_portion).split('.')[-1]
for _prefix, _obj in (('matched_f_drop', matchf_drop), ('matched_t_drop', matcht_drop), ('matched_e_drop', matche_drop)):
    with open(f'./MiddleFeatures/{_prefix}_portion_{_portion_tag}_ratio_{ratio}.pkl', 'wb') as _fh:
        pickle.dump(_obj, _fh)

In [None]:
# hold_out_portion / ratio come from the configuration cell. Re-assigning them here would
# silently save files under names that no longer match the loaded split.
_portion_tag = str(hold_out_portion).split('.')[-1]
for _prefix, _obj in (('matched_f', matched_f), ('matched_t', matched_t), ('matched_e', matched_e)):
    with open(f'./MiddleFeatures/{_prefix}_portion_{_portion_tag}_ratio_{ratio}.pkl', 'wb') as _fh:
        pickle.dump(_obj, _fh)

In [None]:
unmatched_f = {}
unmatched_t = {}
unmatched_e = {}

# Test-split patients (the 50% held for testing). These selections do not depend on the
# prediction window, so they are done once outside the loop.
test_case_all = cases_ages_index[cases_ages_index['patid'].isin(test_case)]
test_control_all = control_ages[control_ages['patid'].isin(test_control)]

for year in [0, 1, 2, 5, 10]:
    # Require EARLIEST_DATE + (year + 1) years <= index date, i.e. at least one year of
    # history before the prediction point.
    required_span = pd.DateOffset(years=year + 1)

    test_case_date = test_case_all[test_case_all['EARLIEST_DATE'] + required_span <= test_case_all['INDEX_DATE']]
    test_case_id = test_case_date['patid'].unique().tolist()

    # Test controls are not matched, so their index date is Last_EHR_minus_one_INDEX_DATE.
    test_control_date = test_control_all[test_control_all['EARLIEST_DATE'] + required_span <= test_control_all['Last_EHR_minus_one_INDEX_DATE']]
    test_control_date = test_control_date.rename(columns={'Last_EHR_minus_one_INDEX_DATE': 'INDEX_DATE'})
    test_control_id = test_control_date['patid'].unique().tolist()

    print('Case for this year:', len(test_case_id), '| Control for this year:', len(test_control_id), '| Total ids:', len(set(test_case_id + test_control_id)))

    f, t, e = process_features(test_case_date, test_control_date, [year], dxdata_noadrd, rxdata_noadrd, labdata_common, vitaldata, common=False)
    unmatched_f[year] = f
    unmatched_t[year] = t
    unmatched_e[year] = e
    

In [None]:
# The source tables are no longer needed once all feature matrices are built; release them.
del labdata_common
del dxdata_noadrd
del rxdata_noadrd
del vitaldata
gc.collect()

0

In [None]:
# hold_out_portion comes from the configuration cell (re-assigning it here could mislabel the file).
_portion_tag = str(hold_out_portion).split('.')[-1]
with open(f'./MiddleFeatures/test_f_portion_{_portion_tag}.pkl', 'wb') as _fh:
    pickle.dump(unmatched_f, _fh)


In [None]:
# Suffix computed once (same-type quotes nested in an f-string need Python 3.12+); files are closed after writing.
_portion_tag = str(hold_out_portion).split('.')[-1]
for _prefix, _obj in (('test_t', unmatched_t), ('test_e', unmatched_e)):
    with open(f'./MiddleFeatures/{_prefix}_portion_{_portion_tag}.pkl', 'wb') as _fh:
        pickle.dump(_obj, _fh)