## Read from MIMIC csv files

In [23]:
import pandas as pd

In [20]:
# files can be downloaded from https://mimic.physionet.org/gettingstarted/dbsetup/
MED_FILE = 'PRESCRIPTIONS.csv'
DIAG_ICD_FILE = 'DIAGNOSES_ICD.csv' 
PROCEDURES_ICD_FILE = 'PROCEDURES_ICD.csv'

# drug code mapping files (already in ./data/)
ndc2atc_file = 'ndc2atc_level4.csv' 
cid_atc = 'drug-atc.csv'
ndc2rxnorm_file = 'ndc2rxnorm_mapping.txt'

# drug-drug interactions can be down https://www.dropbox.com/s/8os4pd2zmp2jemd/drug-DDI.csv?dl=0
ddi_file = 'drug-DDI.csv'

In [24]:
# 'PROCEDURES_ICD.csv'
def process_procedure():
    pro_pd = pd.read_csv(PROCEDURES_ICD_FILE, dtype={'ICD9_CODE':'category'})
    pro_pd.drop(columns=['ROW_ID'], inplace=True)
#     pro_pd = pro_pd[pro_pd['SEQ_NUM']<5]
#     def icd9_tree(x):
#         if x[0]=='E':
#             return x[:4] 
#         return x[:3]
#     pro_pd['ICD9_CODE'] = pro_pd['ICD9_CODE'].map(icd9_tree)
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'SEQ_NUM'], inplace=True)
    pro_pd.drop(columns=['SEQ_NUM'], inplace=True)
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.reset_index(drop=True, inplace=True)
    
    return pro_pd

process_procedure()

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,2,163353,9955
1,3,145834,9604
2,3,145834,9962
3,3,145834,8964
4,3,145834,9672
...,...,...,...
228674,99999,113369,8108
228675,99999,113369,8051
228676,99999,113369,8162
228677,99999,113369,9979


In [109]:
'PRESCRIPTIONS.csv'
def process_med():
    med_pd = pd.read_csv(MED_FILE)
    # filter
    med_pd.drop(columns=['ROW_ID','DRUG_TYPE','DRUG_NAME_POE','DRUG_NAME_GENERIC',
                     'FORMULARY_DRUG_CD','GSN','PROD_STRENGTH','DOSE_VAL_RX',
                     'DOSE_UNIT_RX','FORM_VAL_DISP','FORM_UNIT_DISP','FORM_UNIT_DISP',
                      'ROUTE','ENDDATE','DRUG'], axis=1, inplace=True)
    med_pd.drop(index = med_pd[med_pd['NDC'] == '0'].index, axis=0, inplace=True)
    med_pd.fillna(method='pad', inplace=True)
    med_pd.dropna(inplace=True)
    med_pd.drop_duplicates(inplace=True)
    med_pd['NDC'] = med_pd['NDC'].astype('int64').astype('category')
    med_pd['ICUSTAY_ID'] = med_pd['ICUSTAY_ID'].astype('int64')
    med_pd['STARTDATE'] = pd.to_datetime(med_pd['STARTDATE'], format='%Y-%m-%d %H:%M:%S')    
    med_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'], inplace=True)
    med_pd = med_pd.reset_index(drop=True)

    ##
    med_pd = filter_first24hour_med(med_pd)
    #     med_pd = med_pd.drop(columns=['STARTDATE'])
    
    med_pd = med_pd.drop(columns=['ICUSTAY_ID'])
    med_pd = med_pd.drop_duplicates()
    med_pd = med_pd.reset_index(drop=True)
    
    med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True)    
    med_pd = med_pd.merge(med_pd_lg2[['SUBJECT_ID']], on='SUBJECT_ID', how='inner')    
    
    med_pd = med_pd.reset_index(drop=True)

    return med_pd
    
def filter_first24hour_med(med_pd):
    med_pd_new = med_pd.drop(columns=['NDC'])
    med_pd_new = med_pd_new.groupby(by=['SUBJECT_ID','HADM_ID','ICUSTAY_ID']).head(1).reset_index(drop=True)
    med_pd_new = pd.merge(med_pd_new, med_pd, on=['SUBJECT_ID','HADM_ID','ICUSTAY_ID','STARTDATE'])
    med_pd_new = med_pd_new.drop(columns=['STARTDATE'])
    return med_pd_new

# visit > 2
def process_visit_lg2(med_pd):
    a = med_pd[['SUBJECT_ID', 'HADM_ID']].groupby(by='SUBJECT_ID')['HADM_ID'].unique().reset_index()
    a['HADM_ID_Len'] = a['HADM_ID'].map(lambda x:len(x))
    a = a[a['HADM_ID_Len'] > 1]
    return a 

In [110]:
med_pd = process_med()

  med_pd = pd.read_csv(MED_FILE)


In [111]:
med_pd.dtypes

SUBJECT_ID       int64
HADM_ID          int64
NDC           category
dtype: object

In [112]:
# among 'PRESCRIPTIONS.csv', return records within first 24 hours & more than 2 visits. 
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,NDC
0,17,161087,713016550
1,17,161087,904770418
2,17,161087,904404073
3,17,161087,904526161
4,17,161087,121075210
...,...,...,...
436126,99982,183791,51991045757
436127,99982,183791,409490234
436128,99982,183791,904404073
436129,99982,183791,63323026201


In [51]:
def process_diag():
    diag_pd = pd.read_csv(DIAG_ICD_FILE)
    diag_pd.dropna(inplace=True)
#     def icd9_tree(x):
#         if x[0]=='E':
#             return x[:4] 
#         return x[:3]
#     diag_pd['ICD9_CODE'] = diag_pd['ICD9_CODE'].map(icd9_tree)
#     diag_pd = diag_pd[diag_pd['SEQ_NUM'] < 5]
    diag_pd.drop(columns=['SEQ_NUM','ROW_ID'],inplace=True)
    diag_pd.drop_duplicates(inplace=True)
    diag_pd.sort_values(by=['SUBJECT_ID','HADM_ID'], inplace=True)
    return diag_pd.reset_index(drop=True)

process_diag()

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,2,163353,V3001
1,2,163353,V053
2,2,163353,V290
3,3,145834,0389
4,3,145834,78559
...,...,...,...
650935,99999,113369,75612
650936,99999,113369,7861
650937,99999,113369,4019
650938,99999,113369,25000


In [113]:
med_pd_tmp = med_pd.copy()

In [119]:
med_pd_tmp['NDC']

0           713016550
1           904770418
2           904404073
3           904526161
4           121075210
             ...     
436126    51991045757
436127      409490234
436128      904404073
436129    63323026201
436130    55390000401
Name: NDC, Length: 436131, dtype: category
Categories (4204, int64): [0, 2050101, 2140701, 2144401, ..., 79511050204, 87701071218, 87701083336, 87701089415]

In [134]:
med_pd_tmp[med_pd_tmp['NDC'] == 64100133]
# med_pd_tmp['NDC']
# 00338500341

Unnamed: 0,SUBJECT_ID,HADM_ID,NDC
48731,5696,114502,64100133
151182,17564,153774,64100133
174464,20140,119827,64100133
222165,25658,188188,64100133


In [143]:
def ndc2atc4(med_pd):
    with open(ndc2rxnorm_file, 'r') as f:
        ndc2rxnorm = eval(f.read())
    '''
    {'00338500341': u'1665046', '00064100133': u'545106', '00069077038': u'616287', '00085043104': u'746189', ...}
    '''

    ndc2rxnorm.pop('idx')
    # for k, v in ndc2rxnorm.items():
    #     print(k, v)
    #     print(int(k), v)

    ndc2rxnorm_ = {int(k): v for k, v in ndc2rxnorm.items()}
    
    med_pd['RXCUI'] = med_pd['NDC'].map(ndc2rxnorm_)
    med_pd.dropna(inplace=True)

    # print(med_pd)

    rxnorm2atc = pd.read_csv(ndc2atc_file)
    rxnorm2atc = rxnorm2atc.drop(columns=['YEAR','MONTH','NDC'])
    rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)
    med_pd.drop(index = med_pd[med_pd['RXCUI'].isin([''])].index, axis=0, inplace=True)
    
    med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')
    med_pd = med_pd.reset_index(drop=True)
    med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI'])
    med_pd.drop(columns=['NDC', 'RXCUI'], inplace=True)
    med_pd = med_pd.rename(columns={'ATC4':'NDC'})
    med_pd['NDC'] = med_pd['NDC'].map(lambda x: x[:4])
    med_pd = med_pd.drop_duplicates()    
    med_pd = med_pd.reset_index(drop=True)
    return med_pd

ndc2atc4_ = ndc2atc4(med_pd_tmp)

In [144]:
ndc2atc4_

Unnamed: 0,SUBJECT_ID,HADM_ID,NDC
0,17,161087,N02B
1,17,194023,N02B
2,23,152223,N02B
3,36,182104,N02B
4,103,130744,N02B
...,...,...,...
237373,97547,112445,N05A
237374,97547,127852,N05A
237375,97547,194938,N05A
237376,98920,115857,L02B


In [57]:
def filter_1000_most_pro(pro_pd):
    pro_count = pro_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    pro_pd = pro_pd[pro_pd['ICD9_CODE'].isin(pro_count.loc[:1000, 'ICD9_CODE'])]
    
    return pro_pd.reset_index(drop=True)    

def filter_2000_most_diag(diag_pd):
    diag_count = diag_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    diag_pd = diag_pd[diag_pd['ICD9_CODE'].isin(diag_count.loc[:1999, 'ICD9_CODE'])]
    
    return diag_pd.reset_index(drop=True)

def filter_300_most_med(med_pd):
    med_count = med_pd.groupby(by=['NDC']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    med_pd = med_pd[med_pd['NDC'].isin(med_count.loc[:299, 'NDC'])]
    
    return med_pd.reset_index(drop=True)

In [149]:
# get med and diag (visit>=2)
med_pd = process_med()

  med_pd = pd.read_csv(MED_FILE)


In [151]:
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,NDC
0,17,161087,713016550
1,17,161087,904770418
2,17,161087,904404073
3,17,161087,904526161
4,17,161087,121075210
...,...,...,...
436126,99982,183791,51991045757
436127,99982,183791,409490234
436128,99982,183791,904404073
436129,99982,183791,63323026201


In [152]:
med_pd = ndc2atc4(med_pd)
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,NDC
0,17,161087,N02B
1,17,194023,N02B
2,23,152223,N02B
3,36,182104,N02B
4,103,130744,N02B
...,...,...,...
237373,97547,112445,N05A
237374,97547,127852,N05A
237375,97547,194938,N05A
237376,98920,115857,L02B


In [154]:
med_pd = filter_300_most_med(med_pd)
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,NDC
0,17,161087,N02B
1,17,194023,N02B
2,23,152223,N02B
3,36,182104,N02B
4,103,130744,N02B
...,...,...,...
237373,97547,112445,N05A
237374,97547,127852,N05A
237375,97547,194938,N05A
237376,98920,115857,L02B


In [155]:
diag_pd = process_diag()
diag_pd = filter_2000_most_diag(diag_pd)


In [156]:
diag_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,2,163353,V3001
1,2,163353,V053
2,2,163353,V290
3,3,145834,0389
4,3,145834,78559
...,...,...,...
625429,99995,137810,41401
625430,99999,113369,7861
625431,99999,113369,4019
625432,99999,113369,25000


In [157]:
pro_pd = process_procedure()
#     pro_pd = filter_1000_most_pro(pro_pd)
pro_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,2,163353,9955
1,3,145834,9604
2,3,145834,9962
3,3,145834,8964
4,3,145834,9672
...,...,...,...
228674,99999,113369,8108
228675,99999,113369,8051
228676,99999,113369,8162
228677,99999,113369,9979


In [158]:
med_pd_key = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
diag_pd_key = diag_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
pro_pd_key = pro_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()

In [161]:
pro_pd_key

Unnamed: 0,SUBJECT_ID,HADM_ID
0,2,163353
1,3,145834
7,4,185777
10,5,178980
11,6,107064
...,...,...
228657,99985,176670
228662,99991,151118
228669,99992,197084
228671,99995,137810


In [162]:
combined_key = med_pd_key.merge(diag_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
combined_key = combined_key.merge(pro_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

In [163]:
combined_key

Unnamed: 0,SUBJECT_ID,HADM_ID
0,17,161087
1,17,194023
2,23,152223
3,36,182104
4,103,130744
...,...,...
15011,27714,179593
15012,59948,186194
15013,9911,152237
15014,16634,148327


In [164]:
diag_pd = diag_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
med_pd = med_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
pro_pd = pro_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

In [167]:
pro_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,17,161087,3731
1,17,161087,8872
2,17,161087,3893
3,17,194023,3571
4,17,194023,3961
...,...,...,...
68118,99982,151454,3527
68119,99982,151454,3961
68120,99982,183791,3721
68121,99982,183791,3897


In [168]:
# flatten and merge
diag_pd = diag_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()  
med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()
pro_pd = pro_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(columns={'ICD9_CODE':'PRO_CODE'})  
med_pd['NDC'] = med_pd['NDC'].map(lambda x: list(x))
pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))
data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
#     data['ICD9_CODE_Len'] = data['ICD9_CODE'].map(lambda x: len(x))
data['NDC_Len'] = data['NDC'].map(lambda x: len(x))

In [169]:
data

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE,NDC,PRO_CODE,NDC_Len
0,17,161087,"[4239, 5119, 78551, 4589, 311, 7220, 71946, 2724]","[N02B, A01A, A02B, A06A, B05C, A12A, A12C, C01...","[3731, 8872, 3893]",14
1,17,194023,"[7455, 45829, V1259, 2724]","[N02B, A01A, A02B, A06A, A12A, B05C, A12C, C01...","[3571, 3961, 8872]",15
2,21,109451,"[41071, 78551, 5781, 5849, 40391, 4280, 4592, ...","[A06A, C07A, A12A, A02A, J01M, C02A, B05C, B01...","[0066, 3761, 3950, 3606, 0042, 0047, 3895, 399...",17
3,21,111970,"[0388, 78552, 40391, 42731, 70709, 5119, 6823,...","[A06A, B05C, A12C, A07A, N02B, B01A, N06A, A01...","[3995, 8961, 0014]",17
4,23,124321,"[2252, 3485, 78039, 4241, 4019, 2720, 2724, V4...","[C07A, N02B, A02B, H03A, N03A, A01A, N05A, C09...",[0151],11
...,...,...,...,...,...,...
15011,99923,164914,"[45829, 4532, 2761, 5723, 4561, 45621, 5849, 7...","[N02B, B01A, A06A, J01M, A07A]","[5491, 4513]",5
15012,99923,192053,"[5712, 5856, 5724, 40391, 9974, 5601, 30393, V...","[A06A, A12A, A12C, N01A, C07A, B01A, A02B, A04...","[5059, 504, 5569, 0093]",18
15013,99982,112748,"[4280, 42823, 5849, 4254, 2763, 42731, 78729, ...","[A01A, C03C, A06A, A02B, A12C, B05C, C01A, B01...",[3721],11
15014,99982,151454,"[42823, 4254, 2875, 42731, 3970, 5303, 4280, V...","[C03C, A02B, A06A, C07A, C09C, A12B]","[3527, 3961]",6


In [148]:
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,NDC


In [None]:
# def process_all():
    
#     # get med and diag (visit>=2)
#     med_pd = process_med()
#     med_pd = ndc2atc4(med_pd)
# #     med_pd = filter_300_most_med(med_pd)
    
#     diag_pd = process_diag()
#     diag_pd = filter_2000_most_diag(diag_pd)
    
#     pro_pd = process_procedure()
# #     pro_pd = filter_1000_most_pro(pro_pd)
    
#     med_pd_key = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
#     diag_pd_key = diag_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
#     pro_pd_key = pro_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    
#     combined_key = med_pd_key.merge(diag_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
#     combined_key = combined_key.merge(pro_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    
#     diag_pd = diag_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
#     med_pd = med_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
#     pro_pd = pro_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

#     # flatten and merge
#     diag_pd = diag_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()  
#     med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()
#     pro_pd = pro_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(columns={'ICD9_CODE':'PRO_CODE'})  
#     med_pd['NDC'] = med_pd['NDC'].map(lambda x: list(x))
#     pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))
#     data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
#     data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
# #     data['ICD9_CODE_Len'] = data['ICD9_CODE'].map(lambda x: len(x))
#     data['NDC_Len'] = data['NDC'].map(lambda x: len(x))
#     return data

In [170]:
def statistics():
    print('#patients ', data['SUBJECT_ID'].unique().shape)
    print('#clinical events ', len(data))
    
    diag = data['ICD9_CODE'].values
    med = data['NDC'].values
    pro = data['PRO_CODE'].values
    
    unique_diag = set([j for i in diag for j in list(i)])
    unique_med = set([j for i in med for j in list(i)])
    unique_pro = set([j for i in pro for j in list(i)])
    
    print('#diagnosis ', len(unique_diag))
    print('#med ', len(unique_med))
    print('#procedure', len(unique_pro))
    
    avg_diag = 0
    avg_med = 0
    avg_pro = 0
    max_diag = 0
    max_med = 0
    max_pro = 0
    cnt = 0
    max_visit = 0
    avg_visit = 0

    for subject_id in data['SUBJECT_ID'].unique():
        item_data = data[data['SUBJECT_ID'] == subject_id]
        x = []
        y = []
        z = []
        visit_cnt = 0
        for index, row in item_data.iterrows():
            visit_cnt += 1
            cnt += 1
            x.extend(list(row['ICD9_CODE']))
            y.extend(list(row['NDC']))
            z.extend(list(row['PRO_CODE']))
        x = set(x)
        y = set(y)
        z = set(z)
        avg_diag += len(x)
        avg_med += len(y)
        avg_pro += len(z)
        avg_visit += visit_cnt
        if len(x) > max_diag:
            max_diag = len(x)
        if len(y) > max_med:
            max_med = len(y) 
        if len(z) > max_pro:
            max_pro = len(z)
        if visit_cnt > max_visit:
            max_visit = visit_cnt
        

        
    print('#avg of diagnoses ', avg_diag/ cnt)
    print('#avg of medicines ', avg_med/ cnt)
    print('#avg of procedures ', avg_pro/ cnt)
    print('#avg of vists ', avg_visit/ len(data['SUBJECT_ID'].unique()))
    

    print('#max of diagnoses ', max_diag)
    print('#max of medicines ', max_med)
    print('#max of procedures ', max_pro)
    print('#max of visit ', max_visit)

In [171]:
# data = process_all()
statistics()
data.to_pickle('data_final.pkl')
data.head()

#patients  (6350,)
#clinical events  15016
#diagnosis  1958
#med  145
#procedure 1426
#avg of diagnoses  10.514717634523175
#avg of medicines  8.793819925412892
#avg of procedures  3.8445657964837507
#avg of vists  2.3647244094488187
#max of diagnoses  128
#max of medicines  55
#max of procedures  50
#max of visit  29


Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE,NDC,PRO_CODE,NDC_Len
0,17,161087,"[4239, 5119, 78551, 4589, 311, 7220, 71946, 2724]","[N02B, A01A, A02B, A06A, B05C, A12A, A12C, C01...","[3731, 8872, 3893]",14
1,17,194023,"[7455, 45829, V1259, 2724]","[N02B, A01A, A02B, A06A, A12A, B05C, A12C, C01...","[3571, 3961, 8872]",15
2,21,109451,"[41071, 78551, 5781, 5849, 40391, 4280, 4592, ...","[A06A, C07A, A12A, A02A, J01M, C02A, B05C, B01...","[0066, 3761, 3950, 3606, 0042, 0047, 3895, 399...",17
3,21,111970,"[0388, 78552, 40391, 42731, 70709, 5119, 6823,...","[A06A, B05C, A12C, A07A, N02B, B01A, N06A, A01...","[3995, 8961, 0014]",17
4,23,124321,"[2252, 3485, 78039, 4241, 4019, 2720, 2724, V4...","[C07A, N02B, A02B, H03A, N03A, A01A, N05A, C09...",[0151],11


## Create Vocaboray for Medical Codes & Save Patient Record in pickle form

In [173]:
import dill
import pandas as pd
class Voc(object):
    def __init__(self):
        self.idx2word = {}
        self.word2idx = {}

    def add_sentence(self, sentence):
        for word in sentence:
            if word not in self.word2idx:
                self.idx2word[len(self.word2idx)] = word
                self.word2idx[word] = len(self.word2idx)
                
def create_str_token_mapping(df):
    diag_voc = Voc()
    med_voc = Voc()
    pro_voc = Voc()
    ## only for DMNC
#     diag_voc.add_sentence(['seperator', 'decoder_point'])
#     med_voc.add_sentence(['seperator', 'decoder_point'])
#     pro_voc.add_sentence(['seperator', 'decoder_point'])
    
    for index, row in df.iterrows():
        diag_voc.add_sentence(row['ICD9_CODE'])
        med_voc.add_sentence(row['NDC'])
        pro_voc.add_sentence(row['PRO_CODE'])
    
    dill.dump(obj={'diag_voc':diag_voc, 'med_voc':med_voc ,'pro_voc':pro_voc}, file=open('voc_final.pkl','wb'))
    return diag_voc, med_voc, pro_voc

def create_patient_record(df, diag_voc, med_voc, pro_voc):
    records = [] # (patient, code_kind:3, codes)  code_kind:diag, proc, med
    for subject_id in df['SUBJECT_ID'].unique():
        item_df = df[df['SUBJECT_ID'] == subject_id]
        patient = []
        for index, row in item_df.iterrows():
            admission = []
            admission.append([diag_voc.word2idx[i] for i in row['ICD9_CODE']])
            admission.append([pro_voc.word2idx[i] for i in row['PRO_CODE']])
            admission.append([med_voc.word2idx[i] for i in row['NDC']])
            patient.append(admission)
        records.append(patient) 
    dill.dump(obj=records, file=open('records_final.pkl', 'wb'))
    return records
        
    
path='data_final.pkl'
df = pd.read_pickle(path)
diag_voc, med_voc, pro_voc = create_str_token_mapping(df)
records = create_patient_record(df, diag_voc, med_voc, pro_voc)
len(diag_voc.idx2word), len(med_voc.idx2word), len(pro_voc.idx2word)

(1958, 145, 1426)

## DDI & Construct EHR Adj and DDI Adj

In [175]:
import pandas as pd
import numpy as np
from collections import defaultdict
import dill

# atc -> cid
ddi_file = 'drug-DDI.csv'
cid_atc = 'drug-atc.csv'
voc_file = 'voc_final.pkl'
data_path = 'records_final.pkl'
TOPK = 40 # topk drug-drug interaction

records =  dill.load(open(data_path, 'rb'))
cid2atc_dic = defaultdict(set)
med_voc = dill.load(open(voc_file, 'rb'))['med_voc']
med_voc_size = len(med_voc.idx2word)
med_unique_word = [med_voc.idx2word[i] for i in range(med_voc_size)]
atc3_atc4_dic = defaultdict(set)
for item in med_unique_word:
    atc3_atc4_dic[item[:4]].add(item)
    

with open(cid_atc, 'r') as f:
    for line in f:
        line_ls = line[:-1].split(',')
        cid = line_ls[0]
        atcs = line_ls[1:]
        for atc in atcs:
            if len(atc3_atc4_dic[atc[:4]]) != 0:
                cid2atc_dic[cid].add(atc[:4])
            
# ddi load
ddi_df = pd.read_csv(ddi_file)
# fliter sever side effect 
ddi_most_pd = ddi_df.groupby(by=['Polypharmacy Side Effect', 'Side Effect Name']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
ddi_most_pd = ddi_most_pd.iloc[-TOPK:,:]
# ddi_most_pd = pd.DataFrame(columns=['Side Effect Name'], data=['as','asd','as'])
fliter_ddi_df = ddi_df.merge(ddi_most_pd[['Side Effect Name']], how='inner', on=['Side Effect Name'])
ddi_df = fliter_ddi_df[['STITCH 1','STITCH 2']].drop_duplicates().reset_index(drop=True)

# weighted ehr adj 
ehr_adj = np.zeros((med_voc_size, med_voc_size))
for patient in records:
    for adm in patient:
        med_set = adm[2]
        for i, med_i in enumerate(med_set):
            for j, med_j in enumerate(med_set):
                if j<=i:
                    continue
                ehr_adj[med_i, med_j] = 1
                ehr_adj[med_j, med_i] = 1
dill.dump(ehr_adj, open('ehr_adj_final.pkl', 'wb'))  



# ddi adj
ddi_adj = np.zeros((med_voc_size,med_voc_size))
for index, row in ddi_df.iterrows():
    # ddi
    cid1 = row['STITCH 1']
    cid2 = row['STITCH 2']
    
    # cid -> atc_level3
    for atc_i in cid2atc_dic[cid1]:
        for atc_j in cid2atc_dic[cid2]:
            
            # atc_level3 -> atc_level4
            for i in atc3_atc4_dic[atc_i]:
                for j in atc3_atc4_dic[atc_j]:
                    if med_voc.word2idx[i] != med_voc.word2idx[j]:
                        ddi_adj[med_voc.word2idx[i], med_voc.word2idx[j]] = 1
                        ddi_adj[med_voc.word2idx[j], med_voc.word2idx[i]] = 1
dill.dump(ddi_adj, open('ddi_A_final.pkl', 'wb')) 
                        
print('complete!')

complete!
