<a href="https://colab.research.google.com/github/russpv/SafeDrug/blob/main/process_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Data Preprocessing**

## Setup: mount Google Drive, install dependencies, import libraries

In [111]:
from google.colab import drive

# DATA_PATH / WORKING_PATH defined below point inside this Drive mount.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [112]:
#!pip install dnc
!pip install rdkit-pypi



In [113]:
import sklearn, dill
# Package versions this notebook was last run with (kept as a reference string, not checked):
'''
pandas: 1.3.0
dill: 0.3.4
torch: 1.8.0+cu111
rdkit: 2021.03.4
scikit-learn: 0.24.2
numpy: 1.21.1'''

import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import csv
from rdkit import Chem
from rdkit.Chem import BRICS
from collections import defaultdict

# set seed
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only affects
# subprocesses launched later; str hashing (and hence set iteration order) in this kernel
# is still randomized.
os.environ["PYTHONHASHSEED"] = str(seed)

# define data path
# Relative paths; presumably resolved from Colab's working directory, where Drive is
# mounted at ./drive — confirm if running elsewhere.
DATA_PATH = "drive/MyDrive/DL4H/Project/SAFEDRUG_lib/data/"
WORKING_PATH = "drive/MyDrive/DL4H/Project/SAFEDRUG/"

## Process data

### Process MIMIC

In [114]:
class Vocabulary(object):
    """Bidirectional token <-> integer-index mapping built from an iterable of tokens.

    Used to index the MIMIC code sets (ATC3, ICD9 diagnoses/procedures) and patient /
    admission ids. Indices follow the sorted order of the distinct tokens in ``tape``,
    so the same input always yields the same mapping.

    Attributes:
        word2index: dict token -> index.
        index2word: dict index -> token.
        vocab: list of tokens; position i holds the token whose index is i.
        size: number of tokens.
    """

    def __init__(self, tape=None):
        # `tape` is any iterable of hashable, mutually comparable tokens
        # (e.g. a pandas Series); None yields an empty vocabulary.
        self.word2index = {}
        self.index2word = {}
        self.vocab = sorted(set(tape)) if tape is not None else []
        self.size = len(self.vocab)
        for index, word in enumerate(self.vocab):
            self.word2index[word] = index
            self.index2word[index] = word

    def __getitem__(self, index):
        """Return the token stored at integer ``index``."""
        return self.index2word[index]

    def addword(self, word):
        """Append ``word`` with the next free index; no-op if it is already known."""
        if word in self.word2index:
            return
        index = len(self.word2index)
        self.word2index[word] = index
        self.index2word[index] = word
        self.vocab.append(word)  # keep vocab[i] == index2word[i]
        self.size = len(self.vocab)

    def printStats(self):
        """Print the number of tokens in the vocabulary."""
        print('Num. words:', len(self.word2index))

In [115]:
# The three MIMIC tables (prescriptions, diagnoses, procedures) are cleaned separately below and then merged on SUBJECT_ID and HADM_ID.

In [116]:
def process_prescriptions(file, filtertopn_drug=None, min_visits=None, rxnorm_RXCUI=None, RXCUI_ATC=None):
    '''
    Load and clean the MIMIC-III PRESCRIPTIONS table.

    Always:
      - drop rows whose NDC is '0'
      - sort by SUBJECT_ID, HADM_ID, ICUSTAY_ID, STARTDATE, ENDDATE (RECENT LAST)
      - forward fill missing values, then drop remaining NA rows and duplicates
    Optional steps (each runs only when its argument is not None):
      - filtertopn_drug: keep only the N most frequently prescribed NDC codes
      - min_visits: keep only patients with at least this many distinct admissions (HADM_ID)
      - rxnorm_RXCUI: path to a text file holding a Python dict literal NDC -> RXCUI;
        adds an int64 RXCUI column and drops rows without a mapping
      - RXCUI_ATC: path to the RXCUI -> ATC4 csv; keeps the latest (YEAR, MONTH) row per
        RXCUI, joins on RXCUI (so it needs rxnorm_RXCUI) and adds the column ATC3 holding
        the first 4 characters of the ATC code

    The top-N and min-visit filters preserve the row order of the initial sort.

    Returns:
        pandas DataFrame with one row per (deduplicated) prescription.
    '''
    import ast  # parses the NDC -> RXCUI dict literal; local so the function is self-contained

    # infer_datetime_format is not passed: it is deprecated and has no effect since pandas 2.0
    df = pd.read_csv(file, parse_dates=['STARTDATE', 'ENDDATE'],
                     dtype={'NDC': "category",
                            'ICUSTAY_ID': "object",
                            'HADM_ID': "int64",
                            'SUBJECT_ID': "int64"})\
        [['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE', 'ENDDATE', 'DRUG', 'NDC']]
    df = df[df.NDC != '0']  # filter out the zero drug code
    df = df.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE', 'ENDDATE'], ascending=True)
    df = df.drop(columns=['ENDDATE', 'ICUSTAY_ID']).reset_index(drop=True)

    # forward fill follows the sort above; rows still missing values afterwards are dropped
    df = df.ffill().dropna().drop_duplicates().reset_index(drop=True)

    print(f'base prescriptions df shape: {df.shape}')

    if filtertopn_drug is not None:
        # observed=True: count only NDC categories that actually occur
        drug_counts = df.groupby('NDC', observed=True).size()
        topn_ndcs = drug_counts.nlargest(filtertopn_drug).index
        df = df[df['NDC'].isin(topn_ndcs)].reset_index(drop=True)
        print(f'topn prescriptions df shape: {df.shape}')

    if min_visits is not None:
        # number of distinct admissions per patient
        visits_per_patient = df[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()\
            .groupby('SUBJECT_ID').size()
        kept_patients = visits_per_patient.index[visits_per_patient >= min_visits]
        df = df[df['SUBJECT_ID'].isin(kept_patients)].reset_index(drop=True)
        print(f'min_visits prescriptions df shape: {df.shape}')

    if rxnorm_RXCUI is not None:
        with open(rxnorm_RXCUI, 'r') as f:
            table = ast.literal_eval(f.read())
        df['RXCUI'] = df['NDC'].map(table)
        # NDCs mapped to '' and NDCs absent from the table (NaN) have no RXCUI
        df = df[df['RXCUI'].notna() & (df['RXCUI'] != '')].copy()  # about 2thou get dropped
        df['RXCUI'] = df['RXCUI'].astype('int64')
        print(f'rxnorm prescriptions df shape: {df.shape}')

    if RXCUI_ATC is not None:
        atc_map = pd.read_csv(RXCUI_ATC, parse_dates={'date': ['YEAR', 'MONTH']},
                              dtype={'RXCUI': 'int64'})
        # keep only the latest (YEAR, MONTH) entry for each RXCUI
        atc_map = atc_map.sort_values(by=['RXCUI', 'date'], ascending=True)\
            .drop(columns="NDC")\
            .dropna()\
            .drop_duplicates(subset='RXCUI', keep='last')

        df = pd.merge(df, atc_map, on='RXCUI').drop(columns=['date']).reset_index(drop=True)  # about 200thou get dropped
        # truncate to the first 4 characters (ATC third level)
        df['ATC4'] = df['ATC4'].to_numpy().astype('U4')
        df = df.rename(columns={'ATC4': 'ATC3'}).drop_duplicates().reset_index(drop=True)
        print(f'RXCUI_ATC prescriptions df shape: {df.shape}')

    return df

In [117]:
def process_MIMIC_file(file, filtertopn_ICD9=None):
    '''
    Load a MIMIC-III ICD9 table (e.g. DIAGNOSES_ICD or PROCEDURES_ICD).

    Rows with missing values or ICD9_CODE '0' are dropped. Rows are ordered by SUBJECT_ID,
    HADM_ID, SEQ_NUM (RECENT LAST), SEQ_NUM is removed, and duplicates are dropped.
    When filtertopn_ICD9 is given, only the N most frequent ICD9 codes are kept.

    Returns:
        DataFrame with columns SUBJECT_ID, HADM_ID, ICD9_CODE and a fresh RangeIndex.
    '''
    order_columns = ['SUBJECT_ID', 'HADM_ID', 'SEQ_NUM', 'ICD9_CODE']
    raw = pd.read_csv(file, dtype={'ICD9_CODE': 'object',
                                   'SUBJECT_ID': 'int64',
                                   'HADM_ID': 'int64'})
    codes = raw[order_columns].dropna()
    codes = codes[codes['ICD9_CODE'] != '0']
    codes = (codes.sort_values(by=order_columns, ascending=True)
                  .drop(columns=['SEQ_NUM'])
                  .drop_duplicates()
                  .reset_index(drop=True))

    if filtertopn_ICD9 is None:
        return codes

    # frequency of each code, then keep the N most common
    top_codes = (codes.groupby('ICD9_CODE').size()
                      .reset_index(name='count_col')
                      .nlargest(filtertopn_ICD9, 'count_col'))
    return (pd.merge(left=codes, right=top_codes, on='ICD9_CODE', how='inner')
              .drop(columns=['count_col'])
              .reset_index(drop=True))

In [118]:
def get_mutualfilter(medications_df, diagnoses_df, procedures_df):
    '''Restrict the three MIMIC tables to admissions that appear in all of them.

    Args:
        medications_df: prescriptions with SUBJECT_ID, HADM_ID and ATC3 columns.
        diagnoses_df: diagnoses with SUBJECT_ID, HADM_ID and ICD9_CODE columns.
        procedures_df: procedures with SUBJECT_ID, HADM_ID and ICD9_CODE columns.

    Returns:
        7-tuple:
          - medication, diagnosis and procedure tables with one row per admission, whose
            code column holds the array of unique codes in order of appearance;
          - the mutual (SUBJECT_ID, HADM_ID) keys;
          - the three row-level tables filtered to those keys.
    '''
    key_columns = ['SUBJECT_ID', 'HADM_ID']

    # admissions that occur in every table
    mutual_keys = medications_df[key_columns].drop_duplicates()
    for other_df in (diagnoses_df, procedures_df):
        mutual_keys = pd.merge(mutual_keys, other_df[key_columns].drop_duplicates(), how='inner')

    # keep only rows belonging to a mutual admission
    medications_df, diagnoses_df, procedures_df = [
        table.merge(mutual_keys, how='inner')
        for table in (medications_df, diagnoses_df, procedures_df)
    ]
    print(f'mutual filtered df shapes - med: {medications_df.shape} diag: {diagnoses_df.shape} proc: {procedures_df.shape}')

    def _codes_per_admission(table, code_column):
        # one row per admission; the code column becomes an array of its unique values
        return table.groupby(by=key_columns)[code_column].unique().reset_index()

    medications_textcombine = _codes_per_admission(medications_df, 'ATC3')
    diagnoses_textcombine = _codes_per_admission(diagnoses_df, 'ICD9_CODE')
    procedures_textcombine = _codes_per_admission(procedures_df, 'ICD9_CODE')
    print(f'groupedby df shapes - med: {medications_textcombine.shape} diag: {diagnoses_textcombine.shape} proc: {procedures_textcombine.shape}')

    return (medications_textcombine, diagnoses_textcombine, procedures_textcombine,
            mutual_keys, medications_df, diagnoses_df, procedures_df)

In [119]:
def make_flattable(med_df, diag_df, proc_df):
    '''Join per-admission medication, diagnosis and procedure rows into one table.

    All three inputs are joined (inner) on SUBJECT_ID and HADM_ID. The procedure
    ICD9_CODE column is renamed to PROD_ICD9_CODE in the result so it does not collide
    with the diagnosis ICD9_CODE column. The input frames are not modified.

    Returns:
        DataFrame with columns SUBJECT_ID, HADM_ID, the medication columns, ICD9_CODE
        (diagnoses) and PROD_ICD9_CODE (procedures).
    '''
    keys = ['SUBJECT_ID', 'HADM_ID']
    # rename on a copy; inplace=True would silently alter the caller's frame
    procedures = proc_df.rename(columns={"ICD9_CODE": "PROD_ICD9_CODE"})
    alldata = med_df.merge(diag_df, on=keys, how='inner')
    alldata = alldata.merge(procedures, on=keys, how='inner')

    return alldata

In [120]:
if __name__ == "__main__":
    import ast  # process_prescriptions uses ast.literal_eval on the NDC -> RXCUI mapping file

    # raw MIMIC-III tables
    data_medications_file = DATA_PATH + 'PRESCRIPTIONS.csv'
    data_diagnoses_file = DATA_PATH + 'DIAGNOSES_ICD.csv'
    data_procedures_file = DATA_PATH + 'PROCEDURES_ICD.csv'  

    # code mapping tables: NDC -> RXCUI (dict literal) and RXCUI -> ATC4 (csv)
    map_CID_RXCUI_file = DATA_PATH + 'ndc2rxnorm_mapping.txt'
    map_RXCUI_atc4_file = DATA_PATH + 'ndc2atc_level4.csv'

    # processed output
    data_processed_ehr_file = DATA_PATH + 'processed/ehr.pkl'
    vocabs_file  = DATA_PATH + 'processed/vocabs.pkl'
    matrix_ddi_graph_file = DATA_PATH + 'processed/ddiadj.pkl'
    matrix_ehr_graph_file = DATA_PATH + 'processed/ehradj.pkl' # for GAMENet
    map_ATC_SMILES_file = DATA_PATH +  'processed/atc2SMILES.pkl'
    matrix_h_mask_file = DATA_PATH + 'processed/hmask.pkl'

    # do initial ETL
    # medications: top 300 NDC codes, patients with at least 2 admissions, NDC -> RXCUI -> ATC3
    # diagnoses: top 2000 ICD9 codes; procedures: all ICD9 codes
    medications_df = process_prescriptions(data_medications_file, filtertopn_drug=300, min_visits=2, rxnorm_RXCUI=map_CID_RXCUI_file, RXCUI_ATC=map_RXCUI_atc4_file)
    diagnoses_df = process_MIMIC_file(data_diagnoses_file, 2000)
    procedures_df = process_MIMIC_file(data_procedures_file)
    # Row counts observed on the author's run, kept as a disabled (string) check:
    '''
    assert len(medications_df) == 647333 # reduce 2mil
    assert len(diagnoses_df) == 625434 # reduce a few 10thou
    assert len(procedures_df) == 228679 # reduce a few thou
    '''



base prescriptions df shape: (2970017, 5)
topn prescriptions df shape: (2207681, 5)
min_visits prescriptions df shape: (830116, 5)
rxnorm prescriptions df shape: (828676, 6)
RXCUI_ATC prescriptions df shape: (647333, 7)


In [121]:
def get_statistics(data):
    '''Print cohort statistics for the flattened EHR table.

    Args:
        data: DataFrame with one row per admission and columns SUBJECT_ID, ICD9_CODE,
            ATC3 and PROD_ICD9_CODE; the three code columns hold iterables of codes.

    The "#avg of diagnoses/medicines/procedures" lines divide, for each code type, the sum
    over patients of the number of distinct codes seen across that patient's admissions by
    the total number of admissions. "#avg of visits" is admissions per patient.
    Nothing is returned; everything is printed.
    '''
    print('#patients ', data['SUBJECT_ID'].unique().shape)
    print('#clinical events ', len(data))

    unique_diag = {code for codes in data['ICD9_CODE'].values for code in codes}
    unique_med = {code for codes in data['ATC3'].values for code in codes}
    unique_pro = {code for codes in data['PROD_ICD9_CODE'].values for code in codes}

    print('#diagnosis ', len(unique_diag))
    print('#med ', len(unique_med))
    print('#procedure', len(unique_pro))

    avg_diag = avg_med = avg_pro = cnt = 0
    max_diag = max_med = max_pro = max_visit = 0

    # groupby visits each patient's rows once instead of re-scanning the table per patient
    for _, item_data in data.groupby('SUBJECT_ID', sort=False):
        visit_cnt = len(item_data)
        cnt += visit_cnt
        x = {code for codes in item_data['ICD9_CODE'] for code in codes}
        y = {code for codes in item_data['ATC3'] for code in codes}
        z = {code for codes in item_data['PROD_ICD9_CODE'] for code in codes}
        avg_diag += len(x)
        avg_med += len(y)
        avg_pro += len(z)
        max_diag = max(max_diag, len(x))
        max_med = max(max_med, len(y))
        max_pro = max(max_pro, len(z))
        max_visit = max(max_visit, visit_cnt)

    if cnt == 0:
        print('#no admissions: averages are undefined')
        return

    print('#avg of diagnoses ', avg_diag / cnt)
    print('#avg of medicines ', avg_med / cnt)
    print('#avg of procedures ', avg_pro / cnt)
    print('#avg of visits ', cnt / len(data['SUBJECT_ID'].unique()))

    print('#max of diagnoses ', max_diag)
    print('#max of medicines ', max_med)
    print('#max of procedures ', max_pro)
    print('#max of visit ', max_visit)

In [122]:
if __name__ == "__main__":
    # Mutually filter all three MIMIC tables so every admission appears in each of them,
    # then combine them into one table with one row per admission.
    filtered_meds, filtered_diags, filtered_procs, mutual_keys, med_df, diag_df, proc_df = get_mutualfilter(medications_df, diagnoses_df, procedures_df)
    alldata = make_flattable(filtered_meds, filtered_diags, filtered_procs)

    # send the post-filter columns to the vocab creator
    med_vocab = Vocabulary(tape=med_df['ATC3'])
    diagnoses_vocab = Vocabulary(tape=diag_df['ICD9_CODE'])
    procedures_vocab = Vocabulary(tape=proc_df['ICD9_CODE'])
    patient_vocab = Vocabulary(tape=mutual_keys['SUBJECT_ID'])
    hospadm_vocab = Vocabulary(tape=mutual_keys['HADM_ID'])

    # convert flat table codes to index
    alldata['ATC3'] = [[med_vocab.word2index[token] for token in row] for row in alldata.ATC3]
    alldata['ICD9_CODE'] = [[diagnoses_vocab.word2index[token] for token in row] for row in alldata.ICD9_CODE]
    alldata['PROD_ICD9_CODE'] = [[procedures_vocab.word2index[token] for token in row] for row in alldata.PROD_ICD9_CODE]
    alldata['SUBJECT_ID'] = alldata.SUBJECT_ID.map(patient_vocab.word2index)
    alldata['HADM_ID'] = alldata.HADM_ID.map(hospadm_vocab.word2index)

    print(f'Unique drugs: {med_vocab.size}')
    print(f'Unique diagnoses: {diagnoses_vocab.size}')
    print(f'Unique procedures: {procedures_vocab.size}')
    print(f'Unique patients: {patient_vocab.size}')
    print(f'Unique admissions: {hospadm_vocab.size}')

    # combine each admission's code lists into [diag, proc, med]; columns are selected by name
    # because integer positional access on a row Series (x[0]) is deprecated in pandas 2.x
    alldata['value'] = [[diag, proc, med] for diag, proc, med
                        in zip(alldata['ICD9_CODE'], alldata['PROD_ICD9_CODE'], alldata['ATC3'])]

    # Rollup into one list of admissions per patient and serialize
    # NOTE(review): admissions are ordered by HADM_ID, not by admission date; confirm whether
    # HADM_ID order is chronological before treating visits as a time sequence.
    grouped_data = alldata.groupby(by=['SUBJECT_ID']).agg({"value": lambda x: list(x)})
    nested_list = grouped_data.to_numpy().tolist()
    finallist = [i[0] for i in nested_list]  # (patient, visit, [diag, proc, med])

    assert finallist[0][0][0] == grouped_data['value'].iloc[0][0][0] \
        == alldata.loc[alldata.SUBJECT_ID == 0, 'value'].iloc[0][0]

    print(f'Total records: {len(finallist)}')

    with open(data_processed_ehr_file, 'wb') as f:
        dill.dump(obj=finallist, file=f)
    with open(vocabs_file, 'wb') as f:
        dill.dump(obj={'diag_vocab': diagnoses_vocab, 'med_vocab': med_vocab, 'pro_vocab': procedures_vocab}, file=f)

mutual filtered df shapes - med: (601390, 7) diag: (204802, 3) proc: (68114, 3)
groupedby df shapes - med: (15008, 3) diag: (15008, 3) proc: (15008, 3)
Unique drugs: 57
Unique diagnoses: 1957
Unique procedures: 1425
Unique patients: 6344
Unique admissions: 15008
Total records: 6344


In [123]:
# Print cohort statistics of the flattened, index-encoded table (one row per admission).
get_statistics(alldata)

#patients  (6344,)
#clinical events  15008
#diagnosis  1957
#med  57
#procedure 1425
#avg of diagnoses  10.518923240938166
#avg of medicines  8.894922707889126
#avg of procedures  3.846015458422175
#avg of visits  2.3656998738965953
#max of diagnoses  128
#max of medicines  49
#max of procedures  50
#max of visit  29


### Process DDI

In [124]:
# Create the DDI adjacency matrix and the molecule inputs:
#   1 Get the top-N (default 40) most frequent side effects in TWOSIDES (process_ddi)
#     Note: the side effect is a string name under the Polypharmacy Side Effect code, many to one
#   2 Map the interacting drug CIDs to ATC codes and build the DDI adjacency matrix D
#   3 Decompose DrugBank SMILES into BRICS substructures for the H mask
#   4 Ensure ATC Third Level coding (4 chars)
def make_CID_ATC_map(map_CID_ATClong_file, med_vocab, expected_vocab_size=None, expected_map_size=None):
    '''Map each drug CID to the ATC third-level codes (first 4 chars) present in med_vocab.

    Each csv row is read as [CID, ATC code, ATC code, ...]. Codes are truncated to 4
    characters and kept only if they are in ``med_vocab.word2index``; CIDs left with no
    codes are omitted. Blank lines are skipped; a repeated CID keeps its last row.

    Args:
        map_CID_ATClong_file: path to the CID -> ATC csv (e.g. drug-atc.csv).
        med_vocab: object with a ``word2index`` dict of ATC third-level codes.
        expected_vocab_size: optional sanity check on len(med_vocab.word2index)
            (this notebook's MIMIC run gave 57); None skips it.
        expected_map_size: optional sanity check on the number of mapped CIDs
            (this notebook's MIMIC run gave 389); None skips it.

    Returns:
        dict CID -> set of ATC third-level codes.
    '''
    map_CID_ATC = {}
    with open(map_CID_ATClong_file, 'r', newline='') as f:
        for row in csv.reader(f):
            if not row:
                continue
            atc3_codes = {token[:4] for token in row[1:] if token[:4] in med_vocab.word2index}
            if atc3_codes:  # omit CIDs whose ATC codes are all outside the vocabulary
                map_CID_ATC[row[0]] = atc3_codes

    if expected_vocab_size is not None:
        assert len(med_vocab.word2index) == expected_vocab_size
    if expected_map_size is not None:
        assert len(map_CID_ATC) == expected_map_size
    return map_CID_ATC

def process_ddi(external_ddi_file, map_CID_ATClong_file, TOP_N_SIDES=40):
    '''Return the distinct drug pairs that share one of the TOP_N_SIDES most frequent side effects.

    Args:
        external_ddi_file: TWOSIDES-style csv with columns 'STITCH 1', 'STITCH 2',
            'Polypharmacy Side Effect' and 'Side Effect Name'.
        map_CID_ATClong_file: unused; kept so existing callers keep working.
        TOP_N_SIDES: number of most frequent (code, name) side-effect groups to keep.

    Returns:
        DataFrame with columns 'STITCH 1', 'STITCH 2', deduplicated, with a fresh RangeIndex.
    '''
    dtypes = {
        "STITCH 1": "category",
        "STITCH 2": "category",
        "Polypharmacy Side Effect": "category",
        "Side Effect Name": "category",  # Set 'category' to save mem (key must match the header exactly)
    }
    ddi_df = pd.read_csv(external_ddi_file, dtype=dtypes)

    # get top n: row count per (code, name); observed=True keeps the two categorical keys
    # from expanding into their full Cartesian product of mostly-empty groups
    ddi_topsides_df = ddi_df.groupby(['Polypharmacy Side Effect', 'Side Effect Name'],
                                     observed=True, as_index=False)['STITCH 1'].count()\
        .rename(columns={'STITCH 1': 'count'})\
        .nlargest(TOP_N_SIDES, 'count')

    # get drug CIDs that match the top n side-effect names
    # QUESTION: Grouping by side effect code achieves same impact no?
    top_names = ddi_topsides_df['Side Effect Name'].tolist()
    ddi_filter_df = ddi_df[ddi_df['Side Effect Name'].isin(top_names)]

    ddi_topsidesdrugpairs_df = ddi_filter_df[['STITCH 1', 'STITCH 2']].drop_duplicates().reset_index(drop=True)  # 818796 rows raw, 62791 after dedupe
    # NOTE: error in paper: ddi_most_pd = ddi_most_pd.iloc[-TOPN_SIDES:,:] (=bottom N) should be ddi_most_pd = ddi_most_pd.iloc[:TOPN_SIDES,:]

    return ddi_topsidesdrugpairs_df

In [125]:
def make_ddi_adj_matrix(ddi_topsidesdrugpairs_df, map_CID_ATC, med_vocab=None, output_file=None, save=True):
    '''Build the symmetric ATC-level drug-drug interaction adjacency matrix.

    For every (STITCH 1, STITCH 2) pair whose CIDs both appear in map_CID_ATC, every ATC
    code of the first drug is linked to every ATC code of the second (CID -> ATC is
    one-to-many), and both codes also get a self loop.

    Args:
        ddi_topsidesdrugpairs_df: DataFrame with 'STITCH 1' and 'STITCH 2' CID columns.
        map_CID_ATC: dict CID -> set of ATC third-level codes, all present in med_vocab.
        med_vocab: Vocabulary of ATC codes (uses ``size`` and ``word2index``); None falls
            back to the notebook-level ``med_vocab``.
        output_file: dill output path; None falls back to the notebook-level
            ``matrix_ddi_graph_file``.
        save: set False to skip writing the matrix to disk.

    Returns:
        numpy array of 0/1 floats, shape (med_vocab.size, med_vocab.size).
    '''
    if med_vocab is None:
        med_vocab = globals()['med_vocab']
    size = med_vocab.size

    ddi_adj_matrix = np.zeros((size, size))
    drug1_list = ddi_topsidesdrugpairs_df['STITCH 1'].to_numpy().tolist()
    drug2_list = ddi_topsidesdrugpairs_df['STITCH 2'].to_numpy().tolist()
    for CID1, CID2 in zip(drug1_list, drug2_list):
        if CID1 not in map_CID_ATC or CID2 not in map_CID_ATC:
            continue
        indices1 = [med_vocab.word2index[ATC] for ATC in map_CID_ATC[CID1]]
        indices2 = [med_vocab.word2index[ATC] for ATC in map_CID_ATC[CID2]]
        for index1 in indices1:
            for index2 in indices2:
                ddi_adj_matrix[index1, index1] = 1  # self loop
                ddi_adj_matrix[index2, index2] = 1  # self loop
                ddi_adj_matrix[index1, index2] = 1
                ddi_adj_matrix[index2, index1] = 1  # undirected

    assert ddi_adj_matrix.shape == (size, size)
    print(f'Shape of DDI adjacency matrix: {ddi_adj_matrix.shape}')

    if save:
        if output_file is None:
            output_file = globals()['matrix_ddi_graph_file']
        with open(output_file, 'wb') as f:
            dill.dump(obj=ddi_adj_matrix, file=f)

    return ddi_adj_matrix

In [126]:
### construct drug adjacency for set of all patients for GAMENet
def make_ehr_adj_matrix(med_vocab, final_ehr, output_file=None, save=True):
    '''Build the drug co-occurrence adjacency matrix used by GAMENet.

    Entry [i, j] is 1 when drugs i and j are prescribed in the same admission at least
    once; this includes i == j, so every prescribed drug gets a self loop.

    Args:
        med_vocab: Vocabulary of ATC codes (only ``size`` is used).
        final_ehr: list of patients; each patient is a list of admissions
            [diag, proc, med], where med is an iterable of medication indices.
        output_file: dill output path; None falls back to the notebook-level
            ``matrix_ehr_graph_file``.
        save: set False to skip writing the matrix to disk.

    Returns:
        numpy array of 0/1 floats, shape (med_vocab.size, med_vocab.size).
    '''
    size = med_vocab.size
    ehr_adj_matrix = np.zeros((size, size))
    for patient in final_ehr:
        for admission in patient:
            meds = list(admission[2])
            if meds:
                # mark every ordered pair of co-prescribed drugs, including (d, d)
                ehr_adj_matrix[np.ix_(meds, meds)] = 1

    assert ehr_adj_matrix.shape == (size, size)
    print(f'Shape of ehr adjacency matrix: {ehr_adj_matrix.shape}')

    if save:
        if output_file is None:
            output_file = globals()['matrix_ehr_graph_file']
        with open(output_file, 'wb') as f:
            dill.dump(ehr_adj_matrix, f)

    return ehr_adj_matrix

## Process molecules

In [127]:
### Goal1: make the mask H of size substructures vs drugs
### Goal2: make a dict of molecules for MPNN
def get_ATC_Drugname_map(med_df):
    '''Collect the set of drug text names recorded under each ATC3 code.

    Args:
        med_df: DataFrame with 'ATC3' and 'DRUG' columns.

    Returns:
        dict ATC3 code -> set of DRUG names; keys appear in order of first occurrence.
    '''
    names_by_atc = {}
    for atc_code, drug_name in med_df[['ATC3', 'DRUG']].values:
        # setdefault creates the set on first sight of a code; the name is added whole
        # (set(drug_name) would split the string into characters)
        names_by_atc.setdefault(atc_code, set()).add(drug_name)
    return names_by_atc

def get_ATC_SMILES_map(drugbank, med_df):
    '''Attach up to 3 DrugBank SMILES strings to every ATC3 code in med_df.

    Args:
        drugbank: path to a DrugBank csv with 'title', 'name' and 'moldb_smiles' columns.
        med_df: prescriptions DataFrame with 'ATC3' and 'DRUG' columns.

    Returns:
        (result_map, map_drugbank_SMILES):
          result_map: dict ATC3 -> list of at most 3 SMILES strings, chosen from the ATC's
            drug names in sorted order; ATC codes with no match are omitted.
          map_drugbank_SMILES: dict drug name -> SMILES built from the 'name' column
            (later rows overwrite earlier ones), plus 'title' alt-names that are not
            already keys when their row is read.

    NOTE: if MIMIC contains both a name and its alt-name under the same ATC code, the same
    SMILES can appear twice in that ATC's list (e.g. check 'N06A').
    '''
    map_ATC_drugname = get_ATC_Drugname_map(med_df)
    db_df = pd.read_csv(drugbank, dtype={'name': 'category', 'moldb_smiles': 'category'})

    map_drugbank_SMILES = {}
    for drug_altname, drugname, smiles in db_df[['title', 'name', 'moldb_smiles']].values:
        if isinstance(smiles, str):  # missing SMILES are read as NaN
            map_drugbank_SMILES[drugname] = smiles  # drugname to smiles is 1:1
            if drug_altname not in map_drugbank_SMILES:
                # add altnames; there are some MIMIC names in 'title' but not 'name' columns
                map_drugbank_SMILES[drug_altname] = smiles

    result_map = {}
    for ATC, drugnames in map_ATC_drugname.items():  # DATA ATCS
        # iterate names in sorted order: the iteration order of a set of str changes between
        # interpreter runs (hash randomization), which made the [:3] choice non-reproducible
        group_upto_3 = [map_drugbank_SMILES[name] for name in sorted(drugnames, key=str)
                        if name in map_drugbank_SMILES][:3]
        if group_upto_3:
            result_map[ATC] = group_upto_3

    return result_map, map_drugbank_SMILES

In [128]:
def get_chemical_subgroups(ATC_SMILES_map):
    '''Decompose each drug's SMILES strings into BRICS fragments.

    Args:
        ATC_SMILES_map: dict ATC code -> list of SMILES strings.

    Returns:
        (subgroup_set, map_ATC_BRIC):
          subgroup_set: set of every fragment produced for any drug.
          map_ATC_BRIC: dict ATC code -> union of the fragments of all of its SMILES
            (previously only the last SMILES of each code was kept).
        SMILES that RDKit cannot parse are skipped, and their count is printed.
    '''
    map_ATC_BRIC = {}
    unparsable = []
    for ATC, smiles_list in ATC_SMILES_map.items():
        fragments = set()
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:  # RDKit returns None for SMILES it cannot parse
                unparsable.append((ATC, smiles))
                continue
            fragments.update(BRICS.BRICSDecompose(mol))  # returns a set of fragment SMILES
        map_ATC_BRIC[ATC] = fragments

    subgroup_set = set().union(*map_ATC_BRIC.values())
    if unparsable:
        print(f'Skipped {len(unparsable)} SMILES that RDKit could not parse')

    return subgroup_set, map_ATC_BRIC

In [129]:
### Get the H mask: one row per drug (M ATC codes) x one column per BRICS substructure (S)
def make_H_mask(med_vocab, subgroup_set, map_ATC_BRIC=None, output_file=None, save=True):
    '''Build the drug-by-substructure mask H of shape (M, S).

    H[i, j] = 1 when substructure j occurs in the BRICS fragments of the drug with index i
    in ``med_vocab``. Columns follow ``sorted(subgroup_set)`` so the column layout is the
    same on every run (iteration order of a set of str changes between interpreter runs).
    The shape is M x S; a consumer that needs S x M has to transpose it.

    Args:
        med_vocab: Vocabulary of ATC codes (uses ``vocab``, ``word2index`` and ``size``).
        subgroup_set: set of all substructures; must contain every fragment in map_ATC_BRIC.
        map_ATC_BRIC: dict ATC code -> set of substructures; None falls back to the
            notebook-level ``map_ATC_BRIC``.
        output_file: dill output path; None falls back to the notebook-level
            ``matrix_h_mask_file``.
        save: set False to skip writing the mask to disk.

    Returns:
        numpy array of 0/1 floats, shape (med_vocab.size, len(subgroup_set)).

    Raises:
        KeyError: if an ATC code in med_vocab has no entry in map_ATC_BRIC.
    '''
    if map_ATC_BRIC is None:
        map_ATC_BRIC = globals()['map_ATC_BRIC']

    missing = [ATC for ATC in med_vocab.vocab if ATC not in map_ATC_BRIC]
    if missing:
        raise KeyError(f'No BRICS substructures for ATC codes: {missing}')

    # dict lookup instead of list.index(), which scanned the fragment list for every hit
    column_of = {subgroup: j for j, subgroup in enumerate(sorted(subgroup_set))}
    H_mask = np.zeros((med_vocab.size, len(column_of)))

    for ATC in med_vocab.vocab:
        row = med_vocab.word2index[ATC]
        for subgroup in map_ATC_BRIC[ATC]:
            H_mask[row, column_of[subgroup]] = 1

    print(f'Shape of H mask matrix: {H_mask.shape}')

    if save:
        if output_file is None:
            output_file = globals()['matrix_h_mask_file']
        with open(output_file, 'wb') as f:
            dill.dump(H_mask, f)

    return H_mask

In [130]:
if __name__ == "__main__":
    TOP_N_SIDES = 40  # define top n side effects to filter DDIs
    map_CID_ATClong_file = DATA_PATH + 'drug-atc.csv'
    external_ddi_file = DATA_PATH + 'drug-DDI.csv'
    external_drugbank_file = DATA_PATH + 'drugbank_drugs_info.csv'

    # distinct (STITCH 1, STITCH 2) CID pairs with a top-N side effect (~63,000 CID matches)
    ddi_topsides_df = process_ddi(external_ddi_file, map_CID_ATClong_file, TOP_N_SIDES)

    map_CID_ATC = make_CID_ATC_map(map_CID_ATClong_file, med_vocab)
    map_ATC_SMILES, map_drugbank_SMILES = get_ATC_SMILES_map(external_drugbank_file, med_df)
    subgroup_set, map_ATC_BRIC = get_chemical_subgroups(map_ATC_SMILES)

    # each builder below also dill-dumps its matrix into the processed/ folder
    matrix_ddi_adj = make_ddi_adj_matrix(ddi_topsides_df, map_CID_ATC)
    matrix_ehr_adj = make_ehr_adj_matrix(med_vocab, finallist)
    matrix_H_mask = make_H_mask(med_vocab, subgroup_set)

    with open(map_ATC_SMILES_file, 'wb') as f:
        dill.dump(map_ATC_SMILES, f)
    print('Matrices made.')

  # Remove the CWD from sys.path while we load stuff.


Shape of DDI adjacency matrix: (57, 57)
Shape of ehr adjacency matrix: (57, 57)
Shape of H mask matrix: (57, 177)
Matrices made.


In [131]:
if __name__ == "__main__":
    # Spot-check that one ATC code survives every stage of the pipeline.
    test = 'B02B'
    assert test in med_vocab.vocab
    assert (med_df['ATC3'] == test).any()  # look in the ATC3 column only, not DRUG as well
    assert test in map_ATC_SMILES
    assert test in map_ATC_BRIC