In [None]:
import time
from datetime import timedelta
a = time.time()  # wall-clock start time of this notebook run; presumably used later for elapsed-time reporting — confirm

import os
import json
import pickle
import numpy as np
import pandas as pd
import urllib.request
from comorbidipy import comorbidity
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import pairwise_distances
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
import gc

%load_ext Cython


The Cython extension is already loaded. To reload it, use:
  %reload_ext Cython


In [None]:
# Run a full garbage-collection pass to reclaim unreachable objects still held by the kernel.
gc.collect()

In [3]:

def load_pickle(file, folder):
    """Load and return the pickled object stored at ``{folder}/{file}``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the object has been read, even if unpickling raises.

    Security note: ``pickle.load`` can execute arbitrary code; only use this on
    trusted, locally produced files.
    """
    with open(f'{folder}/{file}', 'rb') as fh:
        df = pickle.load(fh)
    return df

In [4]:
def get_site_data(site):
    """Count distinct patients and encounters across a site's five EHR domains.

    Loads ``{site[:2]}{domain}_all.pkl`` from the ``EHR`` folder for the demo,
    dx, rx, lab and vital domains (printing each shape), then unions the
    ``patid`` and ``encounterid`` values over all five frames.

    NOTE: a later cell redefines ``get_site_data``; once that cell runs, this
    counting version is shadowed.

    Returns:
        (number of distinct patids, number of distinct encounterids)
    """
    prefix = site[:2]
    frames = []
    for domain in ('demo', 'dx', 'rx', 'lab', 'vital'):
        frame = load_pickle(f'{prefix}{domain}_all.pkl', 'EHR')
        print(f'Load {domain} data for site ', site, frame.shape)
        frames.append(frame)

    persons = set()
    encounters = set()
    for frame in frames:
        persons.update(frame.patid.unique().tolist())
        encounters.update(frame.encounterid.unique().tolist())

    return len(persons), len(encounters)

#### 1. Load all of the site data (already done)

In [None]:
import gc
def get_site_data(site):
    """Load a site's five EHR domain frames and drop ``encounterid`` from the event domains.

    Files are read as ``{site[:2]}{domain}_all.pkl`` from the ``EHR`` folder
    for demo, dx, rx, lab and vital, printing each shape. ``encounterid`` is
    dropped in place from dx, rx, lab and vital; demographics keep it.

    Returns:
        (demo_site, dx_site, rx_site, lab_site, vital_site)
    """
    prefix = site[:2]
    loaded = []
    for domain in ('demo', 'dx', 'rx', 'lab', 'vital'):
        frame = load_pickle(f'{prefix}{domain}_all.pkl', 'EHR')
        print(f'Load {domain} data for site ', site, frame.shape)
        loaded.append(frame)

    demo_site, *event_frames = loaded
    for frame in event_frames:
        frame.drop(columns='encounterid', inplace=True)

    return (demo_site, *event_frames)


def get_domains_for_40_persons(demo_site, dx_site, rx_site, lab_site, vital_site, comparison_date='1975-01-01'):
    """Keep, in every domain, only patients born on or before `comparison_date`.

    Args:
        demo_site: demographics with ``patid`` and datetime ``birth_date`` columns.
        dx_site, rx_site, lab_site, vital_site: domain frames with a ``patid`` column.
        comparison_date: birth-date cutoff (anything ``pd.Timestamp`` accepts).
            Defaults to 1975-01-01, the value this function always used.

    Returns:
        (demo_site, dx_site, rx_site, lab_site, vital_site), each filtered.
    """
    comparison_date = pd.Timestamp(comparison_date)
    demo_site = demo_site[demo_site['birth_date'] <= comparison_date]
    demo_40 = demo_site.patid.unique().tolist()
    print('Find demo <= comparison date', comparison_date, len(demo_40))

    print('Pick >=40 persons rows from demographics', demo_site.shape)

    dx_site = dx_site[dx_site.patid.isin(demo_40)]
    print('Pick >=40 persons rows from dx', dx_site.shape)

    rx_site = rx_site[rx_site.patid.isin(demo_40)]
    print('Pick >=40 persons rows from rx', rx_site.shape)

    lab_site = lab_site[lab_site.patid.isin(demo_40)]
    print('Pick >=40 persons rows from lab', lab_site.shape)

    vital_site = vital_site[vital_site.patid.isin(demo_40)]
    print('Pick >=40 persons rows from vital', vital_site.shape)
    gc.collect()
    return demo_site, dx_site, rx_site, lab_site, vital_site


def get_all_dates(dx_site, rx_site, lab_site, vital_site, min_span_days=1):
    """Collect each patient's dated EHR events and keep patients whose record spans long enough.

    Each domain contributes its own date column (dx: ``dx_date_fill``,
    rx: ``rx_start_date``, lab: ``specimen_date``, vital: ``measure_date``).
    The dates are stacked under a common column ``ALL_AGES`` (despite the name
    it holds dates), de-duplicated, and gathered into one list per patient.

    Args:
        dx_site, rx_site, lab_site, vital_site: frames with ``patid`` and the
            date column listed above. Empty frames are skipped.
        min_span_days: minimum number of days between a patient's first and
            last event for the patient to be kept. Defaults to 1, the original
            threshold. NOTE(review): downstream code names this cohort
            ">= 1 year of EHR"; pass 365 if a one-year span is intended.

    Returns:
        DataFrame with columns ``patid`` and ``ALL_AGES`` (list of that
        patient's distinct event dates). Has the same columns but no rows when
        no domain has any rows.
    """
    sources = (
        (dx_site, 'dx_date_fill'),
        (rx_site, 'rx_start_date'),
        (lab_site, 'specimen_date'),
        (vital_site, 'measure_date'),
    )

    merged_ages_list = []
    for frame, date_col in sources:
        print('df shape', frame.shape)
        if frame.shape[0] > 0:
            merged_ages_list.append(frame[['patid', date_col]].rename(columns={date_col: "ALL_AGES"}))

    if not merged_ages_list:
        # pd.concat([]) raises ValueError; "no patients" is the meaningful answer.
        print('--No dated rows in any domain')
        return pd.DataFrame(columns=['patid', 'ALL_AGES'])

    merged_ages = pd.concat(merged_ages_list, ignore_index=True).drop_duplicates()
    print('--Merge age rows', merged_ages.shape)

    merged_ages = merged_ages.groupby('patid')['ALL_AGES'].apply(list).reset_index(name='ALL_AGES')
    print(f"--There are {len(merged_ages['patid'])} unique patients across all the files.")

    if merged_ages.empty:
        # .dt below needs a timedelta Series, which an empty apply() does not produce.
        return merged_ages

    span_days = merged_ages['ALL_AGES'].apply(lambda dates: max(dates) - min(dates)).dt.days
    merged_ages = merged_ages.loc[span_days >= min_span_days]

    all_patients = merged_ages['patid'].unique().tolist()
    print(f"--There are {len(all_patients)} patients with at least {min_span_days} day(s) between first and last record across all files.")
    return merged_ages

def get_domains_for_1year_EHR_persons(ehr_1_year_persons,demo_site , dx_site, rx_site, lab_site, vital_site):
    """Restrict all five domains to the patient ids in `ehr_1_year_persons`.

    Each frame is filtered on ``patid`` and its resulting shape printed.

    Returns:
        (demo_site, dx_site, rx_site, lab_site, vital_site), each filtered.
    """
    def _restrict(frame, message):
        kept = frame[frame['patid'].isin(ehr_1_year_persons)]
        print(message, kept.shape)
        return kept

    demo_kept = _restrict(demo_site, 'Pick >=1 year EHR  persons rows from demo')
    dx_kept = _restrict(dx_site, 'Pick >=1 year EHR  persons rows from dx')
    rx_kept = _restrict(rx_site, 'Pick >=1 year EHR  persons rows from rx')
    lab_kept = _restrict(lab_site, 'Pick >=1 year EHR  persons rows from lab')
    vital_kept = _restrict(vital_site, 'Pick >=1 year EHR  persons rows from vital')

    gc.collect()
    return demo_kept, dx_kept, rx_kept, lab_kept, vital_kept


def get_domains_for_dx_persons(demo_site, dx_site, rx_site, lab_site, vital_site):
    """Restrict the non-dx domains to patients that have at least one diagnosis row.

    ``dx_site`` is returned unchanged (the same object); demo, rx, lab and
    vital are filtered on ``patid``. Shapes are printed along the way.

    Returns:
        (demo_site, dx_site, rx_site, lab_site, vital_site)
    """
    dx_persons = dx_site.patid.unique().tolist()
    print('Find dx_persons', len(dx_persons))
    print('Pick dx persons rows in dx', dx_site.shape)

    def _only_dx_persons(frame, message):
        kept = frame[frame['patid'].isin(dx_persons)]
        print(message, kept.shape)
        return kept

    demo_kept = _only_dx_persons(demo_site, 'Pick dx persons rows from demographics')
    rx_kept = _only_dx_persons(rx_site, 'Pick dx persons rows from rx')
    lab_kept = _only_dx_persons(lab_site, 'Pick dx persons rows from lab')
    vital_kept = _only_dx_persons(vital_site, 'Pick dx persons rows from vital')

    gc.collect()
    return demo_kept, dx_site, rx_kept, lab_kept, vital_kept



# Load the ADRD diagnosis and medication code lists from the Florida EHR paper.
ADRD_dx_med_codes = pd.read_csv("./ADRD_dx_med_codes.csv")


def _select_codes(code_type, concept_pattern):
    """Return the ``Code`` values with ``Code_type == code_type`` whose ``Concept`` matches regex `concept_pattern`.

    Rows with a missing ``Concept`` never match (``na=False``) instead of
    propagating NaN into the boolean mask. The result has a fresh 0..n-1 index.
    """
    concept_hit = ADRD_dx_med_codes['Concept'].str.contains(concept_pattern, na=False)
    type_hit = ADRD_dx_med_codes['Code_type'] == code_type
    return ADRD_dx_med_codes.loc[type_hit & concept_hit, 'Code'].reset_index(drop=True)


# Case definition: the four named ADRD subtypes, joined into one regex alternation.
ADRD_STRINGS = ["Alzheimer's disease", "Vascular dementia", "Frontotemporal dementia", "Lewy Body Dementia"]
ADRD_STRINGS = '|'.join(ADRD_STRINGS)

ADRD_ICD9 = _select_codes('ICD-9', ADRD_STRINGS)
ADRD_ICD10 = _select_codes('ICD-10', ADRD_STRINGS)
display('ADRD ICD9 and ICD10 codes:', ' '.join(ADRD_ICD9.tolist()), ' '.join(ADRD_ICD10.tolist()))

# Codes a control must NOT have: the ADRD subtypes plus any other dementia / dementia-causing condition.
ADRD_AND_OTHER_CONDITIONS = ADRD_STRINGS +'|'+ '|'.join(["Dementia", "Conditions cause dementia"])

ADRD_AND_OTHER_ICD9 = _select_codes('ICD-9', ADRD_AND_OTHER_CONDITIONS)
ADRD_AND_OTHER_ICD10 = _select_codes('ICD-10', ADRD_AND_OTHER_CONDITIONS)
display('ADRD and other dementia ICD9 and ICD10 codes:',' '.join(ADRD_AND_OTHER_ICD9.tolist()), ' '.join(ADRD_AND_OTHER_ICD10.tolist()))

# Anti-dementia medications are matched by exact concept name (not regex).
ANTI_DEMENTIA_RXCUI = ADRD_dx_med_codes[(ADRD_dx_med_codes['Concept'] == "Anti-dementia medications") & (ADRD_dx_med_codes['Code_type'] == 'RXCUI')].Code.reset_index(drop=True)

ANTI_DEMENTIA_RXCUI_list = ANTI_DEMENTIA_RXCUI.tolist()
print('ANTI_DEMENTIA_RXCUI_list:', ANTI_DEMENTIA_RXCUI_list)



def get_dx_case_control(dx_site, rx_site, ADRD_ICD9, ADRD_ICD10, ADRD_AND_OTHER_ICD9, ADRD_AND_OTHER_ICD10, ANTI_DEMENTIA_RXCUI_list ):
    """Define ADRD cases and controls from diagnosis and medication rows.

    Cases: patients with any ICD-9 (``dx_type == 9``) code in `ADRD_ICD9` or any
    ICD-10 (``dx_type == 10``) code in `ADRD_ICD10`.
    Controls: patients in `dx_site` with no code from the broader
    `ADRD_AND_OTHER_ICD9` / `ADRD_AND_OTHER_ICD10` lists and no prescription
    whose ``rxnorm_cui`` is in `ANTI_DEMENTIA_RXCUI_list`.

    Returns:
        dx_cases: list of case patids.
        dx_rx_controls: set of control patids.
        dx_site_adrd: the ADRD diagnosis rows of the cases.
        rx_site_adrd: the anti-dementia prescription rows.
    """
    ############### CASES ##################
    dx_site9 = dx_site[dx_site['dx_type'] == 9]
    dx_site10 = dx_site[dx_site['dx_type'] == 10]
    print('--dx_site9', dx_site9.shape)
    print('--dx_site10', dx_site10.shape)

    dx_site_adrd9 = dx_site9[dx_site9['dx'].isin(ADRD_ICD9)]
    dx_site_adrd10 = dx_site10[dx_site10['dx'].isin(ADRD_ICD10)]
    print('--dx_site_adrd9', dx_site_adrd9.shape)
    print('--dx_site_adrd10', dx_site_adrd10.shape)

    dx_site_adrd = pd.concat([dx_site_adrd9, dx_site_adrd10], axis=0)
    print('--dx_site_adrd', dx_site_adrd.shape)

    dx_cases = dx_site_adrd.patid.unique().tolist()

    ############### CONTROLS ##################
    dx_site_adrd_large9 = dx_site9[dx_site9['dx'].isin(ADRD_AND_OTHER_ICD9)]
    dx_site_adrd_large10 = dx_site10[dx_site10['dx'].isin(ADRD_AND_OTHER_ICD10)]
    print('--dx_site_adrd_large9', dx_site_adrd_large9.shape)
    print('--dx_site_adrd_large10', dx_site_adrd_large10.shape)

    dx_site_adrd_large = pd.concat([dx_site_adrd_large9, dx_site_adrd_large10], axis=0)
    print('--dx_site_adrd_large', dx_site_adrd_large.shape)

    dx_cases_large = dx_site_adrd_large.patid.unique().tolist()
    dx_controls = list(set(dx_site['patid'].tolist()) - set(dx_cases_large))
    print(f"Number of cases: {len(set(dx_cases))}")
    print(f"--Number of controls: {len(set(dx_controls))}")

    # Match against the caller-supplied RXCUI list, not a module-level global.
    rx_site_adrd = rx_site[rx_site['rxnorm_cui'].isin(ANTI_DEMENTIA_RXCUI_list)]

    med_cases = set(rx_site_adrd['patid'].tolist())

    # Controls: dx patients who are neither a (broad) dx case nor an rx case.
    dx_rx_controls = set(dx_controls) - med_cases

    print(f"Number of rx cases: {len(set(med_cases))}")
    print('\t--Intersection of rx and dx cases', len(set(dx_cases) & set(med_cases)))
    print('\t--Intersection of rx and dx large cases', len(set(dx_cases_large) & set(med_cases)))

    print(f"Number of dx rx controls: {len(set(dx_rx_controls))}")

    return dx_cases, dx_rx_controls, dx_site_adrd, rx_site_adrd


def get_domains_for_cases_controls(demo_site, dx_site, rx_site, lab_site, vital_site, cases, controls):
    """Restrict all five domains to patients that are either a case or a control.

    Args:
        demo_site, dx_site, rx_site, lab_site, vital_site: frames with ``patid``.
        cases, controls: iterables of patids; their union is kept.

    Returns:
        (demo_site, dx_site, rx_site, lab_site, vital_site), each filtered.
    """
    cases_controls = set(cases) | set(controls)
    demo_site = demo_site[demo_site.patid.isin(cases_controls)]
    print('Pick cases_controls persons rows from demographics', demo_site.shape)

    dx_site = dx_site[dx_site.patid.isin(cases_controls)]
    print('Pick cases_controls persons rows from dx', dx_site.shape)

    rx_site = rx_site[rx_site.patid.isin(cases_controls)]
    print('Pick cases_controls persons rows from rx', rx_site.shape)

    lab_site = lab_site[lab_site.patid.isin(cases_controls)]
    print('Pick cases_controls persons rows from lab', lab_site.shape)

    vital_site = vital_site[vital_site.patid.isin(cases_controls)]
    print('Pick cases_controls persons rows from vital', vital_site.shape)
    gc.collect()
    return demo_site, dx_site, rx_site, lab_site, vital_site


def get_case_index_and_compare(dx_site_adrd, rx_site_adrd, all_dates, dx_cases):
    """Compute each case's index date and attach the case's first and last EHR dates.

    The index date is the earliest of (a) the first ADRD diagnosis
    (``dx_date_fill``) and (b) the first anti-dementia prescription
    (``rx_start_date``), where prescriptions are only considered for patients in
    `dx_cases`.

    Args:
        dx_site_adrd: ADRD diagnosis rows (``patid``, ``dx_date_fill``).
        rx_site_adrd: anti-dementia prescription rows (``patid``, ``rx_start_date``).
        all_dates: per-patient frame with ``patid`` and ``ALL_AGES`` (list of dates).
        dx_cases: case patids.

    Returns:
        Inner merge on ``patid`` of the cases' ``all_dates`` rows (plus
        ``EARLIEST_DATE``/``LAST_DATE``) with their ``INDEX_DATE``.
    """
    cases_first_ADRD_df = dx_site_adrd.loc[dx_site_adrd.groupby('patid')['dx_date_fill'].idxmin()]
    rx_site_adrd_in_dx = rx_site_adrd[rx_site_adrd.patid.isin(dx_cases)]
    med_cases_first_ADRD_df = rx_site_adrd_in_dx.loc[rx_site_adrd_in_dx.groupby('patid')['rx_start_date'].idxmin()]

    cases_first_ADRD_df = cases_first_ADRD_df.rename(columns={'dx_date_fill': 'first_date'})
    med_cases_first_ADRD_df = med_cases_first_ADRD_df.rename(columns={'rx_start_date': 'first_date'})

    merged_cases_index = pd.concat([cases_first_ADRD_df, med_cases_first_ADRD_df], axis=0)
    merged_cases_index = merged_cases_index.groupby('patid')['first_date'].min().reset_index()
    merged_cases_index = merged_cases_index.rename(columns={'first_date': 'INDEX_DATE'})
    print('--First occur dx rows', cases_first_ADRD_df.shape)
    print('--First occur rx rows', med_cases_first_ADRD_df.shape)
    print('--Index date for unique cases: ', merged_cases_index.patid.nunique())

    # .copy(): new columns are added below; writing into a filtered slice of
    # all_dates would be chained assignment (SettingWithCopyWarning).
    merged_ages_cases = all_dates[all_dates['patid'].isin(dx_cases)].copy()

    merged_ages_cases['EARLIEST_DATE'] = merged_ages_cases['ALL_AGES'].apply(min)
    merged_ages_cases['LAST_DATE'] = merged_ages_cases['ALL_AGES'].apply(max)

    merged_cases_compare = pd.merge(merged_ages_cases, merged_cases_index, on="patid", how="inner")
    print('Assert same:', merged_cases_compare.patid.nunique(), merged_cases_index.patid.nunique(), merged_ages_cases.patid.nunique())
    return merged_cases_compare


def get_control_dates(all_dates, dx_rx_controls ):
    """Build per-control first/last EHR dates and a provisional index date.

    Args:
        all_dates: per-patient frame with ``patid`` and ``ALL_AGES`` (list of dates).
        dx_rx_controls: iterable of control patids.

    Returns:
        The controls' rows of `all_dates` (fresh 0..n-1 index) with added
        ``EARLIEST_DATE``, ``LAST_DATE`` and
        ``Last_EHR_minus_one_INDEX_DATE`` (= LAST_DATE minus one year).

    Raises:
        AssertionError: if any control's LAST_DATE is not after its EARLIEST_DATE.
    """
    merged_controls_dates = all_dates[all_dates['patid'].isin(set(dx_rx_controls))].reset_index(drop=True)
    print('--all date rows for controls:', merged_controls_dates.shape)
    print('--all dates for controls unique :', merged_controls_dates.patid.nunique())

    merged_controls_dates['EARLIEST_DATE'] = merged_controls_dates['ALL_AGES'].apply(min)
    merged_controls_dates['LAST_DATE'] = merged_controls_dates['ALL_AGES'].apply(max)

    # The check fails if ANY control has a zero-length record.
    assert (merged_controls_dates['LAST_DATE'] > merged_controls_dates['EARLIEST_DATE']).all(), \
        "Error: some controls have LAST_DATE not after EARLIEST_DATE"

    merged_controls_dates['Last_EHR_minus_one_INDEX_DATE'] = merged_controls_dates['LAST_DATE'] - pd.DateOffset(years=1)

    print('--control number', merged_controls_dates.patid.nunique())
    return merged_controls_dates


def get_cci_with_ICD(_input_dx, col, _caselist, icd, temp=True):
    """Compute per-patient comorbidity flags and score from diagnosis rows via ``comorbidipy.comorbidity``.

    Args:
        _input_dx: diagnosis rows with a ``patid`` column and codes in column `col`.
        col: name of the code column.
        _caselist: patids that must appear in the output when `temp` is True
            (any iterable, including a set).
        icd: ICD version passed through to ``comorbidity`` (e.g. ``'icd9'``, ``'icd10'``).
        temp: if True, patids in `_caselist` with no rows in `_input_dx` are
            appended with every non-``patid`` column set to 0.0.

    Returns:
        The ``comorbidity`` output (id column ``patid``), plus the zero rows when `temp` is True.
    """
    # For cases, the caller passes only diagnoses before the first ADRD diagnosis.
    _comorbidities_df = comorbidity(df=_input_dx.copy(), id='patid', code=col, age=None, icd=icd)

    if temp:
        # Patients with no qualifying diagnoses get a comorbidity score of 0 (and 0 for each comorbidity).
        # list(): np.setdiff1d would treat a set as one opaque object element.
        _no_dx_before_first_ADRD = np.setdiff1d(list(_caselist), _input_dx['patid'].unique())

        # Zero every column, then overwrite the id column, so this works wherever
        # ``patid`` sits in the output and the zeros are numeric rather than object.
        temp_df = pd.DataFrame(0.0, index=range(len(_no_dx_before_first_ADRD)), columns=_comorbidities_df.columns)
        temp_df['patid'] = _no_dx_before_first_ADRD

        _comorbidities_df = pd.concat([_comorbidities_df, temp_df])
    return _comorbidities_df


def get_merged_cci(cci1, cc2): # we compute cci based on ICD10 and ICD9 respectively, here combine the results from two cci results.
    """Merge two per-patient comorbidity frames (ICD-9- and ICD-10-based) and recompute a weighted score.

    For every patient present in either input, each column is the element-wise
    max of the two inputs (a patient missing from one input contributes 0).
    The last column of the combined frame is then dropped — presumably the
    library-computed score; TODO confirm against the comorbidipy output — and
    ``comorbidity_score`` is recomputed from ``weights`` below. Finally
    ``2 * dementia`` is subtracted, which cancels dementia's weight of 2.

    Args:
        cci1, cc2: frames with a ``patid`` column and one column per key of ``weights``.

    Returns:
        DataFrame with ``patid``, the remaining input columns, and ``comorbidity_score``.
    """
    cci = cci1.set_index('patid').combine(cc2.set_index('patid'), func=lambda s1, s2: s1.combine(s2, max), fill_value=0)
    # NOTE(review): these look like the Quan et al. (2011) updated Charlson weights,
    # which give HIV/AIDS a weight of 4 — confirm aids=2 is intended.
    weights = dict(
                ami=0,
                chf=2,
                pvd=0,
                cevd=0,
                dementia=2,
                copd=1,
                rheumd=1,
                pud=0,
                mld=2,
                diab=0,
                diabwc=1,
                hp=2,
                rend=1,
                canc=2,
                msld=4,
                metacanc=6,
                aids=2,
            )

    # .copy(): a column is added below; writing into an iloc slice is chained assignment.
    cci_ = cci.iloc[:, :-1].copy()
    weight_series = pd.Series(weights)
    # Select the weighted columns explicitly so any extra output columns are ignored
    # (multiplying by a plain dict requires an exact column count).
    cci_['comorbidity_score'] = cci_[weight_series.index].mul(weight_series).sum(axis=1)

    cci_['comorbidity_score'] = cci_['comorbidity_score'] - 2 * cci_['dementia']
    cci_ = cci_.reset_index()

    return cci_

def get_all_cci(dx_site, merged_cases_compare, merged_controls_dates, dx_cases, dx_rx_controls ):
    """Compute comorbidity scores for cases and controls from diagnoses before each patient's index date.

    Cases use ``INDEX_DATE`` from `merged_cases_compare`; controls use
    ``Last_EHR_minus_one_INDEX_DATE`` from `merged_controls_dates`. Only
    diagnoses strictly before that date are scored. ICD-9 (``dx_type == 9``)
    and ICD-10 (``dx_type == 10``) codes are scored separately and merged with
    ``get_merged_cci``. Patients with no qualifying diagnosis get a score of 0.

    Returns:
        DataFrame with ``patid``, ``comorbidity_score`` (float16) and ``target``
        (1 for cases, 0 for controls).
    """
    dx_site9 = dx_site[dx_site['dx_type'] == 9]
    dx_site10 = dx_site[dx_site['dx_type'] == 10]
    print('--dx_site9 rows:', dx_site9.shape)
    print('--dx_site10 rows:', dx_site10.shape)

    ############### CASES ##################
    case_index = merged_cases_compare[['patid', 'INDEX_DATE']]
    cases_before_first_dx9 = pd.merge(dx_site9, case_index, on='patid', how='right')
    print('--merge case index and dx_site9 for unique cases:', cases_before_first_dx9.patid.nunique())
    cases_before_first_dx10 = pd.merge(dx_site10, case_index, on='patid', how='right')
    print('--merge case index and dx_site10 for unique cases:', cases_before_first_dx10.patid.nunique())

    # Vectorised strict "before index" filter. Rows added by the right merge for
    # patients without a dx of this type carry NaT, which compares False and drops out.
    cases_before_first_dx9 = cases_before_first_dx9[cases_before_first_dx9['dx_date_fill'] < cases_before_first_dx9['INDEX_DATE']]
    print('--dx 9 rows before case index:', cases_before_first_dx9.patid.nunique())
    cases_before_first_dx10 = cases_before_first_dx10[cases_before_first_dx10['dx_date_fill'] < cases_before_first_dx10['INDEX_DATE']]
    print('--dx 10 rows before case index:', cases_before_first_dx10.patid.nunique())

    case_ids = list(dx_cases)
    cci_df_icd9 = get_cci_with_ICD(cases_before_first_dx9, 'dx', case_ids, 'icd9')
    cci_df_icd10 = get_cci_with_ICD(cases_before_first_dx10, 'dx', case_ids, 'icd10')

    # get_merged_cci takes an element-wise max, so argument order does not matter.
    cases_dx_comorbidities_df = get_merged_cci(cci_df_icd9, cci_df_icd10)
    cases_dx_comorbidities_df['target'] = 1

    ############### CONTROLS ##################
    control_index = merged_controls_dates[['patid', 'Last_EHR_minus_one_INDEX_DATE']]
    controls_before_index9 = pd.merge(dx_site9, control_index, on='patid', how='right')
    controls_before_index10 = pd.merge(dx_site10, control_index, on='patid', how='right')
    print('\n--merge control index and dx_site9 for unique controls:', controls_before_index9.patid.nunique())
    print('--merge control index and dx_site10 for unique controls:', controls_before_index10.patid.nunique())

    # Same strict before-index window as for cases, so controls are not scored
    # on diagnoses recorded after their index date.
    controls_before_index9 = controls_before_index9[controls_before_index9['dx_date_fill'] < controls_before_index9['Last_EHR_minus_one_INDEX_DATE']]
    print('--dx 9 rows before control index:', controls_before_index9.patid.nunique())
    controls_before_index10 = controls_before_index10[controls_before_index10['dx_date_fill'] < controls_before_index10['Last_EHR_minus_one_INDEX_DATE']]
    print('--dx 10 rows before control index:', controls_before_index10.patid.nunique())

    # temp=True back-fills zero scores for controls with no pre-index diagnosis,
    # mirroring the case branch; list() because the control ids arrive as a set.
    control_ids = list(dx_rx_controls)
    control_cci_df_icd9 = get_cci_with_ICD(controls_before_index9, 'dx', control_ids, 'icd9')
    control_cci_df_icd10 = get_cci_with_ICD(controls_before_index10, 'dx', control_ids, 'icd10')

    controls_dx_comorbidities_df = get_merged_cci(control_cci_df_icd9, control_cci_df_icd10)
    controls_dx_comorbidities_df['target'] = 0

    comorbidities = pd.concat([cases_dx_comorbidities_df, controls_dx_comorbidities_df])

    comorbidities['comorbidity_score'] = comorbidities['comorbidity_score'].astype('float16')
    print('Overall cci :', comorbidities.patid.nunique())

    comorbidities = comorbidities[['patid', 'comorbidity_score', 'target']]
    return comorbidities

In [6]:
def each_site_process(site):
    """Run the cohort-construction pipeline for one site.

    Steps:
    1. Load the site's EHR domains.
    2. Keep patients born on or before the cutoff in ``get_domains_for_40_persons``.
    3. Keep patients passing the date-span filter of ``get_all_dates``.
    4. Keep patients with at least one dx row.
    5. Define ADRD cases and controls, and restrict all domains to them.
    6. Compute case index dates, control dates, and comorbidity scores.

    Returns:
        (dx_cases, dx_rx_controls, demo_site, dx_site, rx_site, lab_site,
         vital_site, merged_cases_compare, merged_controls_dates, comorbidity_scores)
    """
    demo_site, dx_site, rx_site, lab_site, vital_site = get_site_data(site)
    # RXCUIs are compared as strings against the anti-dementia code list.
    rx_site['rxnorm_cui'] = rx_site['rxnorm_cui'].astype('str')
    # Filters: birth date <= 1975-01-01, date-span filter from get_all_dates, and presence in dx_site.
    demo_site, dx_site, rx_site, lab_site, vital_site = get_domains_for_40_persons(demo_site, dx_site, rx_site, lab_site, vital_site)
    all_dates = get_all_dates(dx_site, rx_site, lab_site, vital_site)
    year1_ehr_persons = all_dates.patid.unique().tolist()
    demo_site , dx_site, rx_site, lab_site, vital_site = get_domains_for_1year_EHR_persons(year1_ehr_persons, demo_site , dx_site, rx_site, lab_site, vital_site)
    demo_site, dx_site, rx_site, lab_site, vital_site = get_domains_for_dx_persons(demo_site, dx_site, rx_site, lab_site, vital_site)

    # Define cases/controls.
    dx_cases, dx_rx_controls, dx_site_adrd, rx_site_adrd = get_dx_case_control(dx_site, rx_site, ADRD_ICD9, ADRD_ICD10, ADRD_AND_OTHER_ICD9, ADRD_AND_OTHER_ICD10, ANTI_DEMENTIA_RXCUI_list)
    demo_site, dx_site, rx_site, lab_site, vital_site = get_domains_for_cases_controls(demo_site, dx_site, rx_site, lab_site, vital_site, dx_cases, dx_rx_controls)

    # Index dates and comorbidities.
    merged_cases_compare = get_case_index_and_compare(dx_site_adrd, rx_site_adrd, all_dates, dx_cases)
    merged_controls_dates = get_control_dates(all_dates, dx_rx_controls)
    comorbidity_scores = get_all_cci(dx_site, merged_cases_compare, merged_controls_dates, dx_cases, dx_rx_controls)
    # Returns: the case/control id lists, their domain data, case/control dates, and comorbidity scores.
    return dx_cases, dx_rx_controls, demo_site, dx_site, rx_site, lab_site, vital_site, merged_cases_compare, merged_controls_dates, comorbidity_scores


### Extract each site's EHR data and save it as a checkpoint


In [None]:
# Set to True to recompute a site even when its checkpoint file already exists.
OVERWRITE_SITE_CHECKPOINTS = False

os.makedirs('./Middle', exist_ok=True)
for site in [ 'wcm', 'columbia', 'montefiore', 'mshs', 'nyu']:
    checkpoint_path = f'./Middle/{site}_site_resources.pkl'
    if os.path.exists(checkpoint_path) and not OVERWRITE_SITE_CHECKPOINTS:
        print('Skip ', site, '(checkpoint exists)')
        continue
    print('Process ', site)
    dx_cases, dx_rx_controls, demo_site, dx_site, rx_site, lab_site, vital_site,merged_cases_compare, merged_controls_dates, comorbidity_scores = each_site_process(site)
    site_resources = [dx_cases, dx_rx_controls, demo_site, dx_site, rx_site, lab_site, vital_site,merged_cases_compare, merged_controls_dates, comorbidity_scores]
    # Save the checkpoint the next cell loads; without it a fresh run cannot proceed.
    with open(checkpoint_path, 'wb') as f:
        pickle.dump(site_resources, f)

In [None]:
import gc

hold_out_portion = 0.5
dx_cases_all = []
dx_rx_controls_all = []

demo_site_all = []
dx_site_all = []
rx_site_all = []
lab_site_all = []
vital_site_all = []
merged_cases_compare_all = []
merged_controls_dates_all = []
comorbidity_scores_all = []
gc.collect()
for site in ['wcm', 'columbia', 'nyu', 'mshs', 'montefiore']:
    print('Load from site', site)
    # The with-block closes the checkpoint file as soon as it has been read.
    with open(f'./Middle/{site}_site_resources.pkl', 'rb') as fh:
        (dx_cases, dx_rx_controls, demo_site, dx_site, rx_site, lab_site,
         vital_site, merged_cases_compare, merged_controls_dates,
         comorbidity_scores) = pickle.load(fh)
    print(f'--Include cases from site {site}', len(dx_cases))
    print(f'--Include controls from site {site}', len(dx_rx_controls))

    dx_cases_all.append(dx_cases)
    dx_rx_controls_all.append(dx_rx_controls)
    demo_site_all.append(demo_site)
    dx_site_all.append(dx_site)
    # rx/lab/vital frames are not kept here (their lists stay empty), presumably to save memory.
    merged_cases_compare_all.append(merged_cases_compare)
    merged_controls_dates_all.append(merged_controls_dates)
    comorbidity_scores_all.append(comorbidity_scores)
    del dx_cases, dx_rx_controls, demo_site, dx_site, rx_site, lab_site,\
        vital_site, merged_cases_compare, merged_controls_dates, \
            comorbidity_scores
    print('Collect garbage')
    gcsum = gc.collect()
    print('--gc', gcsum)

cases_all = [i for idlist in dx_cases_all for i in idlist]
controls_all = [i for idlist in dx_rx_controls_all for i in idlist]

print('All cases ', len(cases_all), len(set(cases_all)))
print('All control ', len(controls_all), len(set(controls_all)))



In [None]:
def _dump_pickle(obj, path):
    """Pickle `obj` to `path`, closing the file handle even if pickling fails."""
    with open(path, 'wb') as fh:
        pickle.dump(obj, fh)


os.makedirs('./Middle', exist_ok=True)
_dump_pickle(cases_all, './Middle/cases_all.pkl')
_dump_pickle(controls_all, './Middle/controls_all.pkl')

merged_cases_compare_all_df = pd.concat(merged_cases_compare_all, axis=0)
merged_controls_dates_all_df = pd.concat(merged_controls_dates_all, axis=0)
print("--Merge cases compare shape: ", merged_cases_compare_all_df.shape)
print("--Merge controls dates shape: ", merged_controls_dates_all_df.shape)

_dump_pickle(merged_cases_compare_all_df, './Middle/merged_cases_compare_all_df.pkl')
_dump_pickle(merged_controls_dates_all_df, './Middle/merged_controls_dates_all_df.pkl')

# comorbidity_scores_all
comorbidity_scores_all_df = pd.concat(comorbidity_scores_all, axis=0)
_dump_pickle(comorbidity_scores_all_df, './Middle/comorbidity_scores_all_df.pkl')


In [9]:

# Resume point: reload the merged checkpoints written above. pd.read_pickle
# opens and closes each file itself (pickles are trusted, locally produced files).
merged_cases_compare_all_df = pd.read_pickle('./Middle/merged_cases_compare_all_df.pkl')
merged_controls_dates_all_df = pd.read_pickle('./Middle/merged_controls_dates_all_df.pkl')

comorbidity_scores_all_df = pd.read_pickle('./Middle/comorbidity_scores_all_df.pkl')

cases_all = pd.read_pickle('./Middle/cases_all.pkl')
controls_all = pd.read_pickle('./Middle/controls_all.pkl')

demo_site_all_df = pd.read_pickle('./MiddleFeatures/demo_site_all_df.pkl')


### 2. Random 50% split of cases and controls into hold-out and test sets


In [None]:

import random
from sklearn.utils import check_random_state

# Split configuration. Defined here (same value as the reload cell below) so this
# cell also runs on a fresh kernel resumed from the checkpoint-load cell.
hold_out_portion = 0.5
SPLIT_SEED = 52

np.random.seed(SPLIT_SEED)
random.seed(SPLIT_SEED)
random_state = check_random_state(SPLIT_SEED)

hold_out_control = np.random.choice(controls_all, int(hold_out_portion * (len(controls_all))), replace=False)
hold_out_case = np.random.choice(cases_all, int(hold_out_portion * (len(cases_all))), replace=False)

# sorted(): set iteration order of string ids depends on PYTHONHASHSEED, so an
# unsorted list would make the saved split (and any order-dependent use) irreproducible.
test_control = sorted(set(controls_all) - set(hold_out_control))
test_case = sorted(set(cases_all) - set(hold_out_case))

print('Hold out control', len(hold_out_control), 'Hold out case', len(hold_out_case))
print('Hold out example', hold_out_control[-2:], hold_out_case[-2:])
print('Test control', len(test_control), 'Test case', len(test_case))
print( 'Test example', test_control[-2:], test_case[-2:])


splits = (hold_out_case, hold_out_control, test_case, test_control)

os.makedirs('./Middle/splits', exist_ok=True)
with open(f"./Middle/splits/data_splits_portion_{str(hold_out_portion).split('.')[-1]}.pkl", 'wb') as f:
    pickle.dump(splits, f)

In [None]:
# Reload the saved hold-out / test split for this hold-out portion.
hold_out_portion = 0.5
split_path = f"./Middle/splits/data_splits_portion_{str(hold_out_portion).split('.')[-1]}.pkl"
with open(split_path, 'rb') as split_file:
    splits = pickle.load(split_file)

hold_out_case, hold_out_control, test_case, test_control = splits

In [None]:


# Build each id set once; the hold-out sets are reused for several frames below.
hold_out_case_ids = set(hold_out_case)
hold_out_control_ids = set(hold_out_control)

matchcv_merged_cases_compare = merged_cases_compare_all_df[merged_cases_compare_all_df['patid'].isin(hold_out_case_ids)]
matchcv_merged_controls_ages = merged_controls_dates_all_df[merged_controls_dates_all_df['patid'].isin(hold_out_control_ids)]
print('Split date information of matchcv for: \n\tcase', matchcv_merged_cases_compare.shape, '\n\tcontrol', matchcv_merged_controls_ages.shape)

test_merged_cases_compare = merged_cases_compare_all_df[merged_cases_compare_all_df['patid'].isin(set(test_case))]
test_merged_controls_ages = merged_controls_dates_all_df[merged_controls_dates_all_df['patid'].isin(set(test_control))]
print('Split date information of test data for: \n\tcase:', test_merged_cases_compare.shape, '\n\tcontrol', test_merged_controls_ages.shape)

matchcv_demo_site_all_cases = demo_site_all_df[demo_site_all_df['patid'].isin(hold_out_case_ids)]
matchcv_demo_site_all_controls = demo_site_all_df[demo_site_all_df['patid'].isin(hold_out_control_ids)]
print('Split demo information of matchcv for: \n\tcase', matchcv_demo_site_all_cases.shape, '\n\tcontrol', matchcv_demo_site_all_controls.shape)

# Use comorbidity_scores_all_df as loaded from the checkpoint (or built in the save
# cell). It is not rebuilt from the per-site list, which only exists if the
# site-loading cell ran in this kernel.
matchcv_comorbidity_scores_all_cases = comorbidity_scores_all_df[comorbidity_scores_all_df['patid'].isin(hold_out_case_ids)]
matchcv_comorbidity_scores_all_controls = comorbidity_scores_all_df[comorbidity_scores_all_df['patid'].isin(hold_out_control_ids)]
print('Split comorbidity information of matchcv for: \n\tcase', matchcv_comorbidity_scores_all_cases.shape, '\n\tcontrol', matchcv_comorbidity_scores_all_controls.shape)



#### 2.3 Get initial matches for hold-out cases from hold-out controls

In [None]:
from joblib import Parallel, delayed
import multiprocessing
from itertools import chain
import itertools
import multiprocessing
import pandas as pd
from datetime import datetime
from tqdm import tqdm

def initial_match_and_closer_date(caselist, case_index_date_dict, case_birth_date_dict, pre_sliced_data, index_limit=6, birth_limit=1):
    """For every case, find controls with a similar birth date and a visit near the case index date.

    Args:
        caselist: case patids; each must be a key of both dicts below.
        case_index_date_dict: case patid -> index date (Timestamp).
        case_birth_date_dict: case patid -> birth date (Timestamp).
        pre_sliced_data: control rows with ``patid``, ``birth_date`` and
            ``ALL_AGES`` (one visit date per row, i.e. already exploded).
        index_limit: half-width, in months, of the visit window around the
            case index date (default 6).
        birth_limit: half-width, in years, of the birth-date window (default 1).

    Returns:
        One list per case, in `caselist` order. Each list holds
        ``(caseid, controlid, visit_date)`` tuples, where ``visit_date`` is the
        control's in-window visit closest to the case index date.
    """

    def get_possible_controls_and_nearest_date(caseid, col_all_control_info):
        index_lo, index_hi, birth_lo, birth_hi, index_date = precomputed_ranges[caseid]
        filtered_df = col_all_control_info[
            (col_all_control_info["birth_date"] >= birth_lo) &
            (col_all_control_info["birth_date"] <= birth_hi) &
            (col_all_control_info['ALL_AGES'] >= index_lo) &
            (col_all_control_info['ALL_AGES'] <= index_hi)
        ].reset_index(drop=True)

        filtered_df['abs_diff'] = (filtered_df['ALL_AGES'] - index_date).abs()
        # For each control keep its visit closest to the case index date.
        # .copy(): a column is assigned next; avoid writing into a .loc selection.
        nearest_dates_df = filtered_df.loc[
            filtered_df.groupby('patid')['abs_diff'].idxmin()
        ].copy()
        nearest_dates_df['caseid'] = caseid
        nearest_dates_df = nearest_dates_df[['caseid', 'patid', 'ALL_AGES']].rename(columns={'patid': 'controlid'})
        return list(nearest_dates_df.itertuples(index=False, name=None))

    # Per case: (index window low, index window high, birth window low, birth window high, index date).
    precomputed_ranges = {
        caseid: (
            case_index_date_dict[caseid] - pd.DateOffset(months=index_limit),
            case_index_date_dict[caseid] + pd.DateOffset(months=index_limit),
            case_birth_date_dict[caseid] - pd.DateOffset(years=birth_limit),
            case_birth_date_dict[caseid] + pd.DateOffset(years=birth_limit),
            case_index_date_dict[caseid]
        )
        for caseid in tqdm(caselist, total=len(caselist))
    }
    print('Finish case range preparation!')

    all_results = []
    for caseid in tqdm(caselist, total=len(caselist)):
        all_results.append(get_possible_controls_and_nearest_date(caseid, pre_sliced_data))

    return all_results

In [None]:
# Lookup tables: case patid -> index date, case patid -> birth date.
case_index_date_dict = dict(zip(matchcv_merged_cases_compare['patid'], matchcv_merged_cases_compare['INDEX_DATE']))
case_birth_date_dict = dict(zip(matchcv_demo_site_all_cases['patid'], matchcv_demo_site_all_cases['birth_date']))

birth_demo = matchcv_demo_site_all_controls[['patid', 'birth_date']]
print('--control date for unique controls: ', matchcv_merged_controls_ages.patid.nunique())
print('--control demo for unique controls: ', birth_demo.patid.nunique())

# Join control birth dates onto their EHR dates, then give one row per visit date.
all_control_info = birth_demo.merge(matchcv_merged_controls_ages, on='patid', how='inner')
print('--merge control date with control demo for unique controls: ', all_control_info.patid.nunique())
all_control_info = all_control_info[['patid', 'EARLIEST_DATE', 'birth_date', 'ALL_AGES']]
all_control_info_explode = all_control_info.explode('ALL_AGES')
print('--exploded to rows: ', all_control_info_explode.shape)


#### 2.3.1 Run the initial match: pair each hold-out case with hold-out controls whose birth date is within 1 year of the case's and who have a visit within 6 months of the case's index date


In [None]:

print('Finish control dates sampling preparation!')
results = initial_match_and_closer_date(hold_out_case, case_index_date_dict, case_birth_date_dict, all_control_info_explode)
# Total candidate (case, control) pairs across all cases.
n_initial_matches = sum(map(len, results))
print('average inital match control per case: ', n_initial_matches / len(hold_out_case))



In [None]:
# Checkpoint the raw initial-match results; the with-block closes the file handle.
with open('./Middle/results.pkl', 'wb') as fh:
    pickle.dump(results, fh)

In [None]:

# Flatten the per-case lists of (caseid, controlid, nearest_date) tuples into one frame.
results_melt = [tri for result in results for tri in result]
# Naming the columns in the constructor also works when there are no matches at all
# (assigning .columns on a column-less frame would raise).
case_control_initial = pd.DataFrame(results_melt, columns=['caseid', 'controlid', 'nearest_date'])

In [None]:
# Checkpoint the initial case-control pairs; the with-block closes the file handle.
with open('./Middle/case_control_initial.pkl', 'wb') as fh:
    pickle.dump(case_control_initial, fh)
print(case_control_initial.head())

In [None]:
# Resume point: reload the initial case-control pairs (pd.read_pickle closes the file itself).
case_control_initial = pd.read_pickle('./Middle/case_control_initial.pkl')

### 3. Prepare for propensity-score (PS) matching

##### 3.1 Functions and classes

In [None]:

from sklearn.neighbors import NearestNeighbors

from random import sample 
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.compose import ColumnTransformer

def filter_ehr_enough_and_elder_control_with_closer_index(caseid_list, input_case_control_initial, all_control_info, window_length, min_age=50):
    """Keep initial case-control pairs whose control has enough prior EHR and is old enough at the matched visit.

    Args:
        caseid_list: case patids whose pairs should be considered.
        input_case_control_initial: pairs with ``caseid``, ``controlid`` and ``nearest_date``.
        all_control_info: control rows with ``patid``, ``EARLIEST_DATE`` and ``birth_date``.
        window_length: years. A control is kept only if
            ``EARLIEST_DATE + (window_length + 1) years <= nearest_date``.
        min_age: minimum control age in years at ``nearest_date`` (default 50).

    Returns:
        Surviving pairs with ``caseid``, ``controlid``, ``nearest_date``,
        ``EARLIEST_DATE`` and ``birth_date``.
    """
    print('\t--case-control-initial shape ', input_case_control_initial.shape)

    case_control_initial = input_case_control_initial[input_case_control_initial['caseid'].isin(set(caseid_list))]
    control_ehr = case_control_initial.merge(all_control_info[['patid', 'EARLIEST_DATE', 'birth_date']], left_on='controlid', right_on='patid', how='inner').drop('patid', axis=1)
    print('\t--merge case-control-initial with all_control_info shape ', control_ehr.shape, '| unique controls:', control_ehr.controlid.nunique())

    # Require at least window_length + 1 years of history before the matched visit.
    control_ehr_length = control_ehr[control_ehr['EARLIEST_DATE'] + pd.DateOffset(years=window_length + 1) <= control_ehr['nearest_date']]
    print('\t--filter with length of years:', window_length + 1, 'Resulting in', control_ehr_length.shape, '| unique control:', control_ehr_length.controlid.nunique())

    # Require the control to be at least min_age years old at the matched visit.
    control_ehr_elder = control_ehr_length[control_ehr_length['birth_date'] + pd.DateOffset(years=min_age) <= control_ehr_length['nearest_date']]
    print(f'\t--filter with elder {min_age}, Resulting in', control_ehr_elder.shape, '| unique control:', control_ehr_elder.controlid.nunique())

    del control_ehr_length, control_ehr, case_control_initial
    print('collect gc:', gc.collect())
    return control_ehr_elder


class Matcher_PS(object):

    def __init__(self,  subfolder='middle_files'):
        self.subfolder = subfolder

    def get_psm_input(self, corm_score_case, corm_score_control, transformed_demo, initial_match_df):
    
        def prepare_psm_case_input( yearcase):
            
            casedemos = transformed_demo[transformed_demo['patid'].isin(yearcase)]
 
            casecorms = corm_score_case[['patid', 'comorbidity_score']]

            case_demo_corm = casedemos.merge(casecorms, on='patid', how='left')
            assert case_demo_corm.patid.nunique() == casedemos.patid.nunique() 

            case_demo_corm['is_cases'] = 1
            return case_demo_corm
        

        def prepare_psm_control_input( yearcontrol):
            control_demos = transformed_demo[transformed_demo['patid'].isin(yearcontrol)]

            control_cormdf = corm_score_control[['patid', 'comorbidity_score']]
                
            control_corm_demo = control_demos.merge(control_cormdf, on='patid', how='left')
            assert control_demos.patid.nunique() == control_corm_demo.patid.nunique() 

            control_corm_demo['is_cases'] = 0
            return control_corm_demo
        
        year_case = initial_match_df.caseid.unique().tolist()
        year_control = initial_match_df.controlid.unique().tolist()

        case_psminput = prepare_psm_case_input(year_case)
        control_psminput = prepare_psm_control_input( year_control)
        
        cat_psminput = pd.concat([case_psminput, control_psminput], axis=0)
        del case_psminput, control_psminput
        return cat_psminput
    
    
    def propensity_score(self, covariate_df, cols):
        mod = LogisticRegression(class_weight="balanced")
        mod.fit(covariate_df[cols], covariate_df["is_cases"])
        covariate_df["propensity_score"] = mod.predict_proba(covariate_df[cols])[:, 1]
        return covariate_df
    
    
    def match_case_to_controls_on_propensity_score(self, propensity_df, case_match_control_dict, max_matches, case_index_date_df, prior_cases):

        cases_propensity_df = propensity_df[propensity_df["is_cases"]==1].reset_index(drop=True)
        controls_propensity_df = propensity_df[propensity_df["is_cases"]==0].reset_index(drop=True)
        
        num_cases = cases_propensity_df.shape[0]

        # case_str_ids = cases_propensity_df['patid'].astype(str).values

        controls_id_to_index = dict(zip(controls_propensity_df["patid"], controls_propensity_df.index.values)) # order is right
        # map control int to index int

        psm_match_dict = {}
        marked_as_used = set()
        
        if prior_cases is not None:
            sorted_keys = [key for key, value in sorted(case_match_control_dict.items() , key=lambda item: len(item[1])) if key not in prior_cases]
            prior_ids = sorted(prior_cases, key=lambda item: len(case_match_control_dict[item]))

            sorted_keys = prior_ids + sorted_keys
        else:
            sorted_keys = [key for key, value in sorted(case_match_control_dict.items() , key=lambda item: len(item[1]))]
        print('sorted keys!')


        controls_ps_values = controls_propensity_df['propensity_score'].values

        nn = NearestNeighbors(n_neighbors=max_matches)


        for counti,  case_str_id in tqdm(enumerate(sorted_keys), total=len(sorted_keys)):
            potential_controls = case_match_control_dict[case_str_id] # read id
            
            if len(potential_controls) <= max_matches: 
                # print('--- case:', case_str_id)
                print('ori potential < max:', case_str_id)


            potential_controls = list( set(potential_controls) - marked_as_used)
            # [pc for pc in potential_controls if pc not in marked_as_used]

            if len(potential_controls) > 0:
                potential_control_indices = np.vectorize(controls_id_to_index.get)(potential_controls)
                # potential control id to control index, ordered by control propensity df 
                
                if potential_control_indices.shape[0] > max_matches:
                    # control_df_potential = controls_propensity_df.iloc[potential_control_indices]  # get by index
                    control_df_potential = controls_ps_values[potential_control_indices]  # get by index
                    # print(control_df_potential.shape)
                    nn.fit(control_df_potential.reshape(control_df_potential.shape[0], -1))
                    # nn.fit(control_df_potential['propensity_score'].values.reshape(control_df_potential.shape[0], -1))

                    _, matched_controls_indices = nn.kneighbors(cases_propensity_df[cases_propensity_df['patid']== case_str_id]['propensity_score'].values.reshape( -1, 1))

                    matched_potential_indices = potential_control_indices[matched_controls_indices[0]] # the position in index

                    matched_potential_ids = controls_propensity_df.iloc[matched_potential_indices]["patid"].astype(str).values

                    marked_as_used.update(matched_potential_ids)
                
                else: 
                    # print('--- case', case_str_id)
                    # print('after removing the used: ', potential_control_indices.shape[0],'| used: ', len(marked_as_used))
                    matched_potential_ids = potential_controls
                    marked_as_used.update(matched_potential_ids)
                psm_match_dict[case_str_id] = matched_potential_ids

            else:
                 psm_match_dict[case_str_id] = []

            if counti % 5000 ==0:
                gcsum = gc.collect()
                print('Colelct', gcsum)


        psm_match_df = pd.DataFrame.from_dict(psm_match_dict, orient="index", columns=[f'psm_control_{str(i)}' for i in range(1, max_matches + 1)])
        psm_match_df.index.name = "case_id"
        psm_match_df.sort_index(inplace=True)
        
        index_dates = case_index_date_df[case_index_date_df["patid"].isin(psm_match_df.index.values.astype(str))]
        index_dates.sort_values("patid", inplace=True)

        psm_match_df.insert(0, "case_index_date", index_dates["INDEX_DATE"].values)

        return psm_match_df

       
    def get_matched_case_control_demo(self, corm_case, corm_control, transformed_demo, initial_match_dict, max_matches, index_case,  prior_cases=None):
        # the input should be specific to the year 
        psm_input = self.get_psm_input(corm_case, corm_control, transformed_demo, initial_match_dict)

        colsinput = [ i for i in psm_input.columns if i not in ['is_cases', 'patid'] ]

        propensity_pred =  self.propensity_score(psm_input, cols=colsinput)

        simple_match_dict = initial_match_dict.groupby(initial_match_dict['caseid'].map(str))['controlid'].apply(list).to_dict()
        
        propensity_match = self.match_case_to_controls_on_propensity_score\
        (propensity_pred, simple_match_dict, max_matches=max_matches, case_index_date_df=index_case, prior_cases=prior_cases)
        

        return propensity_match

#### 3.2 Perform propensity-score matching at a 1:10 case:control ratio

In [None]:
def _encode_psm_demographics(democase, democontrol):
    """Encode the pooled, de-duplicated case + control demographics as PS covariates.

    ``birth_year`` (from ``birth_date``) is standardised. ``sex_at_birth`` (from ``sex``),
    ``race`` and ``eth`` (from ``hispanic``) are one-hot encoded with the first level
    dropped; values outside the listed codes become ``'Other/unknown'``.

    Returns:
        Dense DataFrame of the encoded features plus a ``patid`` column.
    """
    def _keep_or_other(series, allowed):
        # Anything not in `allowed` (including missing values) maps to 'Other/unknown'.
        return series.apply(lambda x: x if x in allowed else 'Other/unknown')

    demo = pd.concat([democase, democontrol], axis=0).drop_duplicates()
    # .assign returns a new frame, avoiding column writes on a drop_duplicates result.
    demo = demo.assign(
        birth_year=demo['birth_date'].dt.year,
        sex_at_birth=_keep_or_other(demo['sex'], ['F', 'M']),
        race=_keep_or_other(demo['race'], ['03', '02', '05', '07', '06', '04', '01']),
        eth=_keep_or_other(demo['hispanic'], ['N', 'Y']),
    )

    num_features = ['birth_year']
    cat_inputs = ['sex_at_birth', 'race', 'eth']
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), num_features),
            ('cat', OneHotEncoder(drop='first'), cat_inputs),
        ])
    ft_demo = preprocessor.fit_transform(demo)
    if hasattr(ft_demo, "toarray"):  # sparse output -> dense
        ft_demo = ft_demo.toarray()

    cat_features = preprocessor.named_transformers_['cat'].get_feature_names_out(cat_inputs)
    df_transformed = pd.DataFrame(ft_demo, columns=num_features + list(cat_features))
    # Output rows follow the row order of `demo`, so ids are attached positionally.
    df_transformed['patid'] = demo['patid'].to_numpy()
    return df_transformed


def get_psm_match(input_case_corm, input_control_corm, democase, democontrol, ratio, initial_match, caseindex, all_control_dates, yearlist=None):
    """Run propensity-score matching separately for each EHR-history window in ``yearlist``.

    Parameters
    ----------
    input_case_corm, input_control_corm : comorbidity-score tables (``patid``,
        ``comorbidity_score``) for cases and controls.
    democase, democontrol : demographics with ``patid``, ``birth_date``, ``sex``,
        ``race`` and ``hispanic``.
    ratio : maximum number of controls matched per case.
    initial_match : candidate pairs with ``caseid``, ``controlid`` and ``nearest_date``.
    caseindex : case table with ``patid``, ``EARLIEST_DATE`` and ``INDEX_DATE``.
    all_control_dates : control table with ``patid``, ``EARLIEST_DATE`` and ``birth_date``.
    yearlist : history windows in years, processed in reverse order. Defaults to none.

    For a window ``year``:

    - Eligible cases have EHR starting at least ``year + 1`` years (1 day when
      ``year == 0``) before ``INDEX_DATE``, and are at least 50 at ``INDEX_DATE``.
    - Candidate controls are filtered by
      ``filter_ehr_enough_and_elder_control_with_closer_index`` and must have a
      comorbidity score.

    Returns
    -------
    (psm_match_collection, match_control_index), both keyed ``'match_year{year}'``:
    the PSM result per window, and the filtered candidate pairs
    (``caseid``, ``controlid``, ``nearest_date``) it was built from.
    """
    obs_years_list = [] if yearlist is None else list(yearlist)

    controls_with_cci = input_control_corm.patid.unique().tolist()
    print('--Controls with cci', len(controls_with_cci))
    df_transformed = _encode_psm_demographics(democase, democontrol)  # all case + control demographics
    case_with_demo = caseindex.merge(democase, on='patid', how='inner')  # cases with index info and demographics
    print('Input case index, ', caseindex.shape, '| case demo', democase.shape, '| case with demo', case_with_demo.shape)

    ps_matcher = Matcher_PS(subfolder='Middle')
    psm_match_collection = {}
    match_control_index = {}
    for year in reversed(obs_years_list):
        result_key = 'match_year{}'.format(year)
        print('-------------- Year', year)
        print('\nStart with cases: ', case_with_demo.patid.nunique())

        # Cases: enough EHR history before INDEX_DATE for this window.
        if year == 0:
            # NOTE(review): cases need only 1 day of history at year 0, while the
            # control filter still requires window_length + 1 = 1 year. Confirm intended.
            case_with_demo_ehr = case_with_demo[case_with_demo['EARLIEST_DATE'] + pd.DateOffset(days=1) <= case_with_demo['INDEX_DATE']]
            print('--Case_with ehr length day >= ',  1, '| unique: ', case_with_demo_ehr.patid.nunique() )
        else:
            case_with_demo_ehr = case_with_demo[case_with_demo['EARLIEST_DATE'] + pd.DateOffset(years=year + 1) <= case_with_demo['INDEX_DATE']]
            print('--Case_with ehr length year >= ', year + 1, '| unique: ', case_with_demo_ehr.patid.nunique() )

        # Cases: at least 50 years old at INDEX_DATE.
        case_with_demo_elder = case_with_demo_ehr[case_with_demo_ehr['birth_date'] + pd.DateOffset(years=50) <= case_with_demo_ehr['INDEX_DATE']]
        case_year = case_with_demo_elder['patid'].unique().tolist()
        print('--Case_with index age >= 50', '| unique: ', len(case_year))

        print('\nStart with initial match rows', initial_match.shape )
        initial_match_1 = initial_match[initial_match['caseid'].isin(set(case_year))]
        print('--pick initial match rows within cases of the year', initial_match_1.shape)

        filter_initial_match = filter_ehr_enough_and_elder_control_with_closer_index(case_year, initial_match_1, all_control_dates, year)
        print('--pick initial match rows within control having enough EHR and elder, drop ', initial_match_1.shape[0]-filter_initial_match.shape[0])

        cci_initial_match_filter = filter_initial_match[filter_initial_match['controlid'].isin(controls_with_cci)]
        print('--pick initial match rows within control having cci, drop ', filter_initial_match.shape[0] - cci_initial_match_filter.shape[0])

        match_control_index[result_key] = cci_initial_match_filter[['caseid', 'controlid', 'nearest_date']]

        control_year = cci_initial_match_filter.controlid.unique().tolist()
        case_year = cci_initial_match_filter.caseid.unique().tolist()

        corm_case_year = input_case_corm[input_case_corm['patid'].isin(case_year)]
        corm_control_year = input_control_corm[input_control_corm['patid'].isin(control_year)]
        print('\nPick corm rows for cases and controls of the year', corm_case_year.shape ,corm_control_year.shape )

        print('\nCase of the year: ', len(case_year), '| Control of the year', len(control_year))

        _psm_control_case_df = ps_matcher.get_matched_case_control_demo(
            corm_case_year, corm_control_year, df_transformed, cci_initial_match_filter,
            max_matches=ratio, index_case=caseindex)

        psm_match_collection[result_key] = _psm_control_case_df

        display(_psm_control_case_df.isnull().sum())

    return psm_match_collection, match_control_index



In [None]:

# 1:10 propensity-score matching for history windows of 0, 1, 2, 5 and 10 years.
psm1, match_control_index1 = get_psm_match(
    matchcv_comorbidity_scores_all_cases,
    matchcv_comorbidity_scores_all_controls,
    matchcv_demo_site_all_cases,
    matchcv_demo_site_all_controls,
    ratio=10,
    initial_match=case_control_initial,
    caseindex=matchcv_merged_cases_compare[['patid', 'EARLIEST_DATE', 'INDEX_DATE']],
    all_control_dates=all_control_info[['patid', 'EARLIEST_DATE', 'birth_date']],
    yearlist=[0, 1, 2, 5, 10],
)


In [None]:
# Persist the per-window PSM results; `with` guarantees each file is flushed and closed.
os.makedirs('./PSM_results', exist_ok=True)
with open('./PSM_results/years_psm_ratio10.pkl', 'wb') as fh:
    pickle.dump(psm1, fh)
with open('./PSM_results/years_match_control_index.pkl', 'wb') as fh:
    pickle.dump(match_control_index1, fh)

PSM matching is complete above.

Below, look up each matched control's index date and save it alongside the PSM match DataFrame.

In [3]:
# Reload the saved PSM results. pickle can execute code on load, so only open trusted files.
with open('./PSM_results/years_match_control_index.pkl', 'rb') as fh:
    match_control_index1 = pickle.load(fh)

with open('./PSM_results/years_psm_ratio10.pkl', 'rb') as fh:
    psm1 = pickle.load(fh)


In [None]:


def get_control_index(row, control_col, _nearest_age_dict):
    """Return the index date of the control in ``row[control_col]`` for case ``row.name``.

    Looks up the key ``(case id, control id)`` in ``_nearest_age_dict`` and
    returns ``None`` when the pair is absent.
    """
    key = (row.name, row[control_col])
    return _nearest_age_dict.get(key, None)


def _add_control_index_columns(psm_match, control_pairs):
    """Return a copy of ``psm_match`` with a ``<psm_control_k>_index`` column per control slot.

    ``control_pairs`` has ``caseid``, ``controlid`` and ``nearest_date``. Each new
    column holds the ``nearest_date`` of the (case, control) pair in that slot,
    or a missing value when the slot is empty or the pair is not found.
    """
    with_index = psm_match.copy()
    dict_date = dict(zip(zip(control_pairs['caseid'], control_pairs['controlid']), control_pairs['nearest_date']))
    print('dict obtained: ', len(dict_date))

    case_ids = with_index.index.tolist()
    for control_col in [c for c in psm_match.columns if 'psm_control_' in c]:
        print('--', control_col)
        # Same lookup as a row-wise apply of get_control_index, but vectorised.
        # It also handles frames with no rows, where apply(axis=1) returns a DataFrame.
        with_index[control_col + '_index'] = [dict_date.get(key) for key in zip(case_ids, psm_match[control_col])]
    return with_index


psm_match_years_with_control_index = {}

for year, psm_match in psm1.items():
    print(year)
    psm_match_years_with_control_index[year] = _add_control_index_columns(psm_match, match_control_index1[year])

In [25]:
# Save PSM matches with control index dates; `with` guarantees the file is flushed and closed.
with open('./PSM_results/years_psm_ratio10_index_date.pkl', 'wb') as fh:
    pickle.dump(psm_match_years_with_control_index, fh)
