# CS598 DL4H Final Project Demo Notebook
#### by Rui Zou and Yueming Pang, May 2023

In this notebook, we demonstrate the key steps involved in reproducing the research paper titled "GAMENet: Graph Augmented MEmory Networks for Recommending Medication Combination". Our demo covers various stages, including data preprocessing, dataset splitting, and running different base models along with the GAMENet model to obtain results.

#### Data Preprocessing
To prepare the dataset for our project, we perform several data preprocessing steps such as cleaning, filtering, and feature engineering. We also explore the dataset to understand the distribution of features and print out the statistics of the dataset.

#### Dataset Splitting
We discuss the dataset splitting method used in the paper and examine its effect in our experiments. We then propose an additional step that makes the split more robust. We split the dataset into training, validation, and test sets, train the models using the training and validation sets, and compare their performance on the test set.

#### Base Models and GAMENet
We run several base models, including nearest visit, LR and several neural network-based models, to establish a baseline for comparison. We then implement the GAMENet model and compare its performance with the baseline models.

## Part 1: Data Preprocessing
In this part, we process the patient EHR data, including medication, diagnosis, and procedure data, so that they can be fed into the models later. At the end of this section, we print the statistics of the processed data.

We cleaned up the original code in the authors' repo, removed redundancies, added functions, and added docstrings and comments.

### Load Libraries

In [1]:
import pandas as pd
import functools
import dill
import numpy as np
from collections import defaultdict
import warnings
warnings.filterwarnings("ignore")

### Read from MIMIC csv files

In [2]:
# Files can be downloaded from https://mimic.physionet.org/gettingstarted/dbsetup/
med_file = 'PRESCRIPTIONS.csv'
diag_file = 'DIAGNOSES_ICD.csv'
procedure_file = 'PROCEDURES_ICD.csv'

# drug code mapping files (already in ./data/)
ndc2atc_file = 'ndc2atc_level4.csv' 
cid_atc = 'drug-atc.csv'
ndc_rxnorm_file = 'ndc2rxnorm_mapping.txt'

# drug-drug interaction file can be downloaded from https://www.dropbox.com/s/8os4pd2zmp2jemd/drug-DDI.csv?dl=0
ddi_file = 'drug-DDI.csv'

### Preprocess the MIMIC CSV files, filter the medications used in the first 24 hours, and print the statistics

In [3]:
def convert_to_list(x):
    """
    Convert the input `x` into a list.
    
    Args:
        x (str or iterable): The input to be converted into a list.
    
    Returns:
        list: `[x]` if `x` is a string (so that it is not split into characters);
        otherwise `list(x)`, a new list containing the elements of `x`.
    """
    if isinstance(x, str):
        return [x]
    else:
        return list(x)

def process_procedure(procedure_file):
    """
    Read and process the procedure CSV file.
    
    Args:
        procedure_file: str, the path to the CSV file containing procedure data
    
    Returns:
        pro_pd: pandas DataFrame, the processed procedure data
    """
    # Read the CSV file and set data types
    pro_pd = pd.read_csv(procedure_file, dtype={'ICD9_CODE':'category'})

    # Drop unnecessary columns and duplicates
    pro_pd = pro_pd.drop(columns=['ROW_ID', 'SEQ_NUM']).drop_duplicates()

    # Sort the data by SUBJECT_ID, HADM_ID, and ICD9_CODE
    pro_pd = pro_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ICD9_CODE'])

    # Reset the index
    pro_pd = pro_pd.reset_index(drop=True)

    return pro_pd

def process_med(med_file):
    """
    Processes medication data from a CSV file and returns a cleaned and filtered pandas DataFrame.

    Args:
        med_file (str): A string specifying the path and file name of the CSV file containing medication data.

    Returns:
        pandas DataFrame: A cleaned and filtered pandas DataFrame containing medication data.
    """
    med_pd = pd.read_csv(med_file, dtype={'NDC': 'category'})
    
    # filter and clean data
    med_pd.drop(columns=['ROW_ID','DRUG_TYPE','DRUG_NAME_POE','DRUG_NAME_GENERIC',
                          'FORMULARY_DRUG_CD','GSN','PROD_STRENGTH','DOSE_VAL_RX',
                          'DOSE_UNIT_RX','FORM_VAL_DISP','FORM_UNIT_DISP',
                          'ROUTE','ENDDATE','DRUG'], inplace=True)
    med_pd.drop(med_pd[med_pd['NDC'] == '0'].index, axis=0, inplace=True)
    med_pd.ffill(inplace=True)  # forward-fill missing values (same as fillna(method='pad'))
    med_pd.dropna(inplace=True)
    med_pd['ICUSTAY_ID'] = med_pd['ICUSTAY_ID'].astype('int64')
    med_pd['STARTDATE'] = pd.to_datetime(med_pd['STARTDATE'], format='%Y-%m-%d %H:%M:%S')    
    med_pd.drop_duplicates(inplace=True)
    
    # sort by columns
    med_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'], inplace=True)
    med_pd.reset_index(drop=True, inplace=True)
    
    def filter_first24hour_med(med_pd):
        """
        Keep, for each ICU stay, only the medications prescribed on the earliest STARTDATE.

        `med_pd` must already be sorted by STARTDATE within each ICU stay, so that
        `keep='first'` selects the earliest date.

        Args:
            med_pd: pandas.DataFrame
                The medication data to be processed.

        Returns:
            pandas.DataFrame
                The medication rows (with NDC codes) whose STARTDATE equals the earliest
                STARTDATE of their (SUBJECT_ID, HADM_ID, ICUSTAY_ID) group.
        """
        # Find the earliest STARTDATE of each ICU stay
        med_pd_new = med_pd.drop(columns=['NDC'])
        med_pd_new = med_pd_new.drop_duplicates(subset=['SUBJECT_ID','HADM_ID','ICUSTAY_ID'], keep='first')

        # Merge with original dataframe to keep the NDC code
        med_pd_new = pd.merge(med_pd_new, med_pd, on=['SUBJECT_ID','HADM_ID','ICUSTAY_ID','STARTDATE'])
        med_pd_new = med_pd_new.drop(columns=['STARTDATE'])

        return med_pd_new

    med_pd = filter_first24hour_med(med_pd)

    # Drop the 'ICUSTAY_ID' column from med_pd
    med_pd = med_pd.drop(columns=['ICUSTAY_ID'])

    # Drop duplicates from med_pd
    med_pd = med_pd.drop_duplicates().reset_index(drop=True)
    
    # keep patients with at least 2 visits
    def process_visit_lg2(med_pd):
        """
        Find the admissions of patients who have more than one visit (HADM_ID).
        
        Args:
            med_pd : pandas.DataFrame
                The input DataFrame containing medication data.
            
        Returns:
            pandas.DataFrame
                One row per (SUBJECT_ID, HADM_ID) pair, restricted to patients with more than one
                HADM_ID. A column 'HADM_ID_Len' holds the number of HADM_IDs per SUBJECT_ID.
        """
        # Get unique HADM_IDs per SUBJECT_ID
        unique_hadm_ids = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
        subject_counts = unique_hadm_ids['SUBJECT_ID'].value_counts()

        # Filter SUBJECT_IDs with visit counts > 1
        subjects_with_multiple_visits = subject_counts[subject_counts > 1].index
        # .copy() so that the column assignment below does not write to a slice
        visits_lg2 = unique_hadm_ids[unique_hadm_ids['SUBJECT_ID'].isin(subjects_with_multiple_visits)].copy()

        # Add HADM_ID count as a new column
        visits_lg2['HADM_ID_Len'] = visits_lg2.groupby('SUBJECT_ID')['HADM_ID'].transform('count')

        # Return result
        return visits_lg2

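    med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True)
    # Merge on unique SUBJECT_IDs; med_pd_lg2 has one row per visit, which would otherwise duplicate rows
    med_pd = med_pd.merge(med_pd_lg2[['SUBJECT_ID']].drop_duplicates(), on='SUBJECT_ID', how='inner')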
    
    return med_pd.reset_index(drop=True)

def process_diag(diag_file):
    """
    Reads in the diagnosis data from the provided file, removes any rows with missing data, unnecessary columns,
    and duplicate rows. The resulting DataFrame is sorted by subject ID and hospital admission ID and the index is reset.
    
    Args:
        diag_file: string, path to the diagnosis data file
    
    Returns:
        diag_pd: pandas DataFrame, diagnosis data processed and sorted by subject ID and hospital admission ID
    """
    # Read in the diagnosis data from the provided file
    diag_pd = pd.read_csv(diag_file)
    
    # Drop any rows with missing data
    diag_pd.dropna(inplace=True)
    
    # Remove unnecessary columns
    diag_pd.drop(columns=['SEQ_NUM','ROW_ID'], inplace=True)
    
    # Drop any duplicate rows
    diag_pd.drop_duplicates(inplace=True)
    
    # Sort the data by subject ID and hospital admission ID
    diag_pd.sort_values(by=['SUBJECT_ID','HADM_ID'], inplace=True)
    
    # Reset the index of the DataFrame and return it
    return diag_pd.reset_index(drop=True)

def ndc2atc4(med_pd):
    """
    Converts the NDC codes in the medication DataFrame `med_pd` to ATC4 codes using external data files.

    Args:
        med_pd (pandas DataFrame): DataFrame containing medication data with NDC codes.

    Returns:
        pandas DataFrame: DataFrame containing medication data with ATC4 codes.
    """
    with open(ndc_rxnorm_file, 'r') as f:
        ndc2rxnorm = eval(f.read())
    
    med_pd['RXCUI'] = med_pd['NDC'].map(ndc2rxnorm)
    med_pd.dropna(subset=['RXCUI'], inplace=True)

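    # Two columns are read, so there is nothing to squeeze into a Series
    rxnorm2atc = pd.read_csv(ndc2atc_file, usecols=['RXCUI', 'ATC4'])
    rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)

    # Drop empty RXCUI values; .copy() avoids assigning to a slice below
    med_pd = med_pd[~med_pd['RXCUI'].isin([''])].copy()
    med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')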

    med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI'])
    med_pd.drop(columns=['NDC', 'RXCUI'], inplace=True)

    med_pd['NDC'] = med_pd['ATC4'].str[:4]
    med_pd.drop_duplicates(inplace=True)
    med_pd.reset_index(drop=True, inplace=True)

    return med_pd

def filter_2000_most_diag(diag_pd):
    """
    Filter the diagnosis DataFrame to keep only the rows with the top 2000 most frequent ICD9 codes.
    
    Args:
        diag_pd : pandas.DataFrame
            DataFrame containing diagnosis data with 'ICD9_CODE' column.
    
    Returns:
        pandas.DataFrame
            DataFrame with only the rows containing ICD9 codes in the top 2000 most frequent list.
    """
    # Get the top 2000 most frequent ICD9 codes (value_counts sorts by count, descending)
    top_2000_icd9_codes = diag_pd['ICD9_CODE'].value_counts().index[:2000].tolist()
    
    # Filter the DataFrame to keep only the rows with ICD9 codes in the top 2000
    diag_pd = diag_pd[diag_pd['ICD9_CODE'].isin(top_2000_icd9_codes)]
    
    return diag_pd.reset_index(drop=True)

def process_all():
    """
    This function processes medication, diagnosis, and procedure data and returns a pandas DataFrame with the processed data.
    
    Returns:
        data (pandas DataFrame): A DataFrame containing processed medication, diagnosis, and procedure data, merged on unique SUBJECT_ID and HADM_ID. The DataFrame has the following columns:
            - SUBJECT_ID: unique identifier for each patient
            - HADM_ID: unique identifier for each hospital admission
            - ICD9_CODE: list of diagnosis codes associated with the admission
            - NDC: list of medication codes associated with the admission (ATC codes truncated to 4 characters by `ndc2atc4`, not raw NDC codes)
            - PRO_CODE: list of procedure codes associated with the admission
            - NDC_Len: number of unique medication codes associated with the admission
    """
    # get med and diag (visit>=2)
    medication_pd = process_med(med_file)
    medication_pd = ndc2atc4(medication_pd)

    diagnosis_pd = process_diag(diag_file)
    diagnosis_pd = filter_2000_most_diag(diagnosis_pd)

    procedure_pd = process_procedure(procedure_file)

    medication_pd_key = medication_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    diagnosis_pd_key = diagnosis_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    procedure_pd_key = procedure_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()

    combined_key = functools.reduce(lambda x, y: pd.merge(x, y, on=['SUBJECT_ID', 'HADM_ID'], how='inner'), 
                                    [medication_pd_key, diagnosis_pd_key, procedure_pd_key])
    
    diagnosis_pd = diagnosis_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    medication_pd = medication_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    procedure_pd = procedure_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    # flatten and merge
    diagnosis_pd = diagnosis_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()  
    medication_pd = medication_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()
    procedure_pd = procedure_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(columns={'ICD9_CODE':'PRO_CODE'})  
    medication_pd['NDC'] = medication_pd['NDC'].map(convert_to_list)
    procedure_pd['PRO_CODE'] = procedure_pd['PRO_CODE'].map(convert_to_list)
    
    data = functools.reduce(lambda x, y: pd.merge(x, y, on=['SUBJECT_ID', 'HADM_ID'], how='inner'), 
                            [diagnosis_pd, medication_pd, procedure_pd])
    
    data['NDC_Len'] = data['NDC'].map(len)
    return data

def statistics():
    """
    This function prints statistics of the processed data held in the global DataFrame `data`. Codes are de-duplicated per patient (across all of that patient's visits) before they are counted, so the averages and maxima below are based on per-patient unique code counts.

        Prints:
        - #patients: number of unique patients
        - #clinical events: total number of visits (rows of `data`)
        - #diagnosis: number of unique diagnosis codes
        - #med: number of unique medication codes
        - #procedure: number of unique procedure codes
        - #avg of diagnoses: sum of per-patient unique diagnosis counts divided by the total number of visits
        - #avg of medicines: sum of per-patient unique medication counts divided by the total number of visits
        - #avg of procedures: sum of per-patient unique procedure counts divided by the total number of visits
        - #avg of vists: average number of visits per patient
        - #max of diagnoses: maximum number of unique diagnoses for a single patient
        - #max of medicines: maximum number of unique medications for a single patient
        - #max of procedures: maximum number of unique procedures for a single patient
        - #max of visit: maximum number of visits for a single patient
    """
    print('#patients ', data['SUBJECT_ID'].nunique())
    print('#clinical events ', len(data))
    
    unique_diag = data['ICD9_CODE'].explode().nunique()
    unique_med = data['NDC'].explode().nunique()
    unique_pro = data['PRO_CODE'].explode().nunique()
    
    print('#diagnosis ', unique_diag)
    print('#med ', unique_med)
    print('#procedure', unique_pro)
    
    avg_diag = 0
    avg_med = 0
    avg_pro = 0
    max_diag = 0
    max_med = 0
    max_pro = 0
    cnt = 0
    max_visit = 0
    avg_visit = 0

    for subject_id in data['SUBJECT_ID'].unique():
        item_data = data[data['SUBJECT_ID'] == subject_id]
        x = []
        y = []
        z = []
        visit_cnt = 0
        for index, row in item_data.iterrows():
            visit_cnt += 1
            cnt += 1
            x.extend(list(row['ICD9_CODE']))
            y.extend(list(row['NDC']))
            z.extend(list(row['PRO_CODE']))
        x = set(x)
        y = set(y)
        z = set(z)
        avg_diag += len(x)
        avg_med += len(y)
        avg_pro += len(z)
        avg_visit += visit_cnt
        if len(x) > max_diag:
            max_diag = len(x)
        if len(y) > max_med:
            max_med = len(y) 
        if len(z) > max_pro:
            max_pro = len(z)
        if visit_cnt > max_visit:
            max_visit = visit_cnt
        
    print('#avg of diagnoses ', avg_diag / cnt)
    print('#avg of medicines ', avg_med / cnt)
    print('#avg of procedures ', avg_pro / cnt)
    print('#avg of vists ', avg_visit / len(data['SUBJECT_ID'].unique()))

    print('#max of diagnoses ', max_diag)
    print('#max of medicines ', max_med)
    print('#max of procedures ', max_pro)
    print('#max of visit ', max_visit)
    
    
data = process_all()
statistics()
data.to_pickle('data_final.pkl')

#patients  6350
#clinical events  15016
#diagnosis  1958
#med  145
#procedure 1426
#avg of diagnoses  10.514318060735215
#avg of medicines  8.80420884389984
#avg of procedures  3.8445657964837507
#avg of vists  2.3647244094488187
#max of diagnoses  128
#max of medicines  55
#max of procedures  50
#max of visit  29
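
To make the medication filter inside `process_med` concrete, the sketch below applies the same `drop_duplicates` plus `merge` pattern to a small made-up table. The rows and values are hypothetical; only the column names follow the notebook. It suggests that the filter keeps every drug whose `STARTDATE` equals the earliest start date of its ICU stay.

```python
import pandas as pd

# Hypothetical toy prescriptions: one admission, one ICU stay, two start dates.
toy = pd.DataFrame({
    'SUBJECT_ID': [1, 1, 1],
    'HADM_ID': [10, 10, 10],
    'ICUSTAY_ID': [100, 100, 100],
    'STARTDATE': pd.to_datetime(['2101-01-01', '2101-01-01', '2101-01-03']),
    'NDC': ['A', 'B', 'C'],
})
toy = toy.sort_values(['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'])

# Earliest start date per ICU stay (rows are sorted, so 'first' is the earliest).
first_dates = (
    toy.drop(columns=['NDC'])
       .drop_duplicates(subset=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID'], keep='first')
)

# Merging back on the date keeps every drug prescribed on that earliest date.
first_day = pd.merge(first_dates, toy,
                     on=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'])
kept_codes = sorted(first_day['NDC'])
print(kept_codes)
```

Drug `C`, which starts two days later, is dropped, while both drugs that share the earliest date are kept.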


### Create vocabulary for medical codes and save patient record in pickle form

In [7]:
class Voc:
    """
    A vocabulary class that maps words to indices and vice versa.
    """
    
    def __init__(self):
        """
        Initialize the vocabulary with empty word-to-index and index-to-word mappings.
        """
        self.idx2word = {}
        self.word2idx = {}

    def add_sentence(self, sentence):
        """
        Add a sentence to the vocabulary.
        
        Args:
            sentence (list): A list of words in the sentence.
        """
        for word in sentence:
            if word not in self.word2idx:
                self.idx2word[len(self.word2idx)] = word
                self.word2idx[word] = len(self.word2idx)
                
def create_str_token_mapping(df):
    """
    Create string to token mappings for diagnosis, medication, and procedure codes.
    
    Args:
        df (pandas.DataFrame): A DataFrame containing the medical records data.
        
    Returns:
        Tuple[Voc, Voc, Voc]: A tuple of three Voc objects, one for each code type.
    """
    diag_voc = Voc()
    med_voc = Voc()
    pro_voc = Voc()
    
    for index, row in df.iterrows():
        diag_voc.add_sentence(row['ICD9_CODE'])
        med_voc.add_sentence(row['NDC'])
        pro_voc.add_sentence(row['PRO_CODE'])
    
    with open('voc_final.pkl', 'wb') as f:
        dill.dump({'diag_voc': diag_voc, 'med_voc': med_voc, 'pro_voc': pro_voc}, f)
        
    return diag_voc, med_voc, pro_voc

def create_patient_record(df, diag_voc, med_voc, pro_voc):
    """
    Create a patient record data structure from the medical records data.
    
    Args:
        df (pandas.DataFrame): A DataFrame containing the medical records data.
        diag_voc (Voc): A Voc object for diagnosis codes.
        med_voc (Voc): A Voc object for medication codes.
        pro_voc (Voc): A Voc object for procedure codes.
        
    Returns:
        list: A list of patient records, where each record is a list of admissions, and each admission is a
        list of three lists of code indices (one for each code type).
    """
    records = [] # (patient, code_kind:3, codes)  code_kind:diag, proc, med
    for subject_id in df['SUBJECT_ID'].unique():
        item_df = df[df['SUBJECT_ID'] == subject_id]
        patient = []
        for index, row in item_df.iterrows():
            admission = []
            admission.append([diag_voc.word2idx[i] for i in row['ICD9_CODE']])
            admission.append([pro_voc.word2idx[i] for i in row['PRO_CODE']])
            admission.append([med_voc.word2idx[i] for i in row['NDC']])
            patient.append(admission)
        records.append(patient) 
    
    with open('records_final.pkl', 'wb') as f:
        dill.dump(records, f)
        
    return records
        
    
path='data_final.pkl'
df = pd.read_pickle(path)
diag_voc, med_voc, pro_voc = create_str_token_mapping(df)
records = create_patient_record(df, diag_voc, med_voc, pro_voc)
print(len(diag_voc.idx2word), len(med_voc.idx2word), len(pro_voc.idx2word))

1958 145 1426
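
As a quick illustration of how the vocabulary turns code strings into integer indices, here is a minimal sketch with a stand-in class. `ToyVoc` and the medication codes are hypothetical; the class only mirrors the two dictionaries of `Voc`.

```python
class ToyVoc:
    """Minimal stand-in for the notebook's Voc class (same two dictionaries)."""

    def __init__(self):
        self.idx2word = {}
        self.word2idx = {}

    def add_sentence(self, sentence):
        # Assign the next free index to every code not seen before.
        for word in sentence:
            if word not in self.word2idx:
                idx = len(self.word2idx)
                self.idx2word[idx] = word
                self.word2idx[word] = idx


# Hypothetical medication codes for two visits of one patient.
visits = [['A01A', 'B01A'], ['B01A', 'C03C']]

voc = ToyVoc()
for codes in visits:
    voc.add_sentence(codes)

# Each visit becomes a list of integer indices, as in create_patient_record.
encoded = [[voc.word2idx[c] for c in codes] for codes in visits]
print(voc.word2idx)
print(encoded)
```

Indices are assigned in order of first appearance, so the mapping depends on the row order of the input DataFrame.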


### Construct DDI, EHR Adj and DDI Adj data

In [8]:
# atc -> cid
ddi_file = 'drug-DDI.csv'
cid_atc = 'drug-atc.csv'
voc_file = 'voc_final.pkl'
data_path = 'records_final.pkl'
TOPK = 40  # number of side-effect types kept when filtering the DDI table

with open(data_path, 'rb') as f:
    records = dill.load(f)
cid2atc_dic = defaultdict(set)
with open(voc_file, 'rb') as f:
    med_voc = dill.load(f)['med_voc']
med_voc_size = len(med_voc.idx2word)
med_unique_word = [med_voc.idx2word[i] for i in range(med_voc_size)]
atc3_atc4_dic = defaultdict(set)

for item in med_unique_word:
    atc3_atc4_dic[item[:4]].add(item)
    
with open(cid_atc, 'r') as f:
    for line in f:
        line_ls = line[:-1].split(',')
        cid = line_ls[0]
        atcs = line_ls[1:]
        for atc in atcs:
            if len(atc3_atc4_dic[atc[:4]]) != 0:
                cid2atc_dic[cid].add(atc[:4])
 
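# Load the DDI data. After sorting by count in descending order, .tail(TOPK) keeps
# the TOPK least frequent side-effect types.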
ddi_df = pd.read_csv(ddi_file)
ddi_topk_pd = (
    ddi_df.groupby(['Polypharmacy Side Effect', 'Side Effect Name'])
    .size()
    .reset_index(name='count')
    .sort_values('count', ascending=False)
    .tail(TOPK)
)
ddi_df = (
    ddi_df[ddi_df['Side Effect Name'].isin(ddi_topk_pd['Side Effect Name'].tolist())]
    .drop_duplicates(subset=['STITCH 1', 'STITCH 2'])
    .reset_index(drop=True)
)

# EHR co-occurrence adjacency (binary: 1 if two drugs appear together in any visit)
ehr_adj = np.zeros((med_voc_size, med_voc_size))
for patient in records:
    for adm in patient:
        med_set = adm[2]
        for i, med_i in enumerate(med_set):
            for j, med_j in enumerate(med_set):
                if j<=i:
                    continue
                ehr_adj[med_i, med_j] = 1
                ehr_adj[med_j, med_i] = 1
with open('ehr_adj_final.pkl', 'wb') as f:
    dill.dump(ehr_adj, f)

# ddi adj
ddi_adj = np.zeros((med_voc_size,med_voc_size))
for index, row in ddi_df.iterrows():
    # ddi
    cid1 = row['STITCH 1']
    cid2 = row['STITCH 2']
    
    # cid -> atc_level3
    for atc_i in cid2atc_dic[cid1]:
        for atc_j in cid2atc_dic[cid2]:
            
            # atc_level3 -> atc_level4
            for i in atc3_atc4_dic[atc_i]:
                for j in atc3_atc4_dic[atc_j]:
                    if med_voc.word2idx[i] != med_voc.word2idx[j]:
                        ddi_adj[med_voc.word2idx[i], med_voc.word2idx[j]] = 1
                        ddi_adj[med_voc.word2idx[j], med_voc.word2idx[i]] = 1
with open('ddi_A_final.pkl', 'wb') as f:
    dill.dump(ddi_adj, f)
                        
print('Complete!')

Complete!
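
The EHR adjacency construction above can also be written with `itertools.combinations`, which yields each unordered pair of drugs in a visit once. The sketch below is an illustrative equivalent on made-up records, not the code used for the reported results; `toy_records` and `n_med` are hypothetical.

```python
from itertools import combinations

import numpy as np

# Hypothetical toy records: patient -> visits -> [diag, proc, med] index lists.
toy_records = [
    [[[0], [0], [0, 1]], [[1], [0], [1, 2]]],  # one patient, two visits
    [[[0], [1], [3]]],                          # a single-drug visit adds no edge
]
n_med = 4

adj = np.zeros((n_med, n_med))
for patient in toy_records:
    for adm in patient:
        # adm[2] is the medication list of the visit
        for med_i, med_j in combinations(adm[2], 2):
            adj[med_i, med_j] = 1
            adj[med_j, med_i] = 1
print(adj)
```

Because entries are set to 1 rather than incremented, the matrix records whether two drugs ever co-occur, not how often.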


## Part 2: Explore the splitting method's impact on DDI rates
We initially had difficulty reproducing the paper's results, so we examined the splitting method in the authors' repo. Using custom code, we calculated the DDI rate of each dataset: all data, train data, test data, and eval data. When the data is split directly by index, the DDI rates differ noticeably across the datasets (about 0.091 for train versus about 0.078 for test).

In [19]:
data_path = '../data/records_final.pkl'
data = dill.load(open(data_path, 'rb'))   

split_point = int(len(data) * 2 / 3)                  # 4233
data_train = data[:split_point]                       # 4233 use 67% data for training

eval_len = int(len(data[split_point:]) / 2)           # 1058 divide the rest 33% into eval and test
data_test = data[split_point:split_point + eval_len]  # 1058
data_eval = data[split_point+eval_len:]               # 1059

In [20]:
# Keep only the medication list (index 2) of every visit
all_med = [[visit[2] for visit in patient] for patient in data]
train_med = [[visit[2] for visit in patient] for patient in data_train]
test_med = [[visit[2] for visit in patient] for patient in data_test]
val_med = [[visit[2] for visit in patient] for patient in data_eval]

In [26]:
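def ddi_rate_score(record, path='../data/ddi_A_final.pkl'):
    """
    Compute the DDI rate: the fraction of co-prescribed medication pairs that interact.

    Args:
        record (list): patients -> visits -> list of medication indices.
        path (str): path to the pickled DDI adjacency matrix.

    Returns:
        float: interacting pairs divided by all pairs, or 0 if there are no pairs.
    """
    with open(path, 'rb') as f:
        ddi_A = dill.load(f)
    all_cnt = 0
    dd_cnt = 0
    for patient in record:
        for adm in patient:
            med_code_set = adm
            for i, med_i in enumerate(med_code_set):
                for j, med_j in enumerate(med_code_set):
                    if j <= i:
                        continue
                    all_cnt += 1
                    if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                        dd_cnt += 1
    if all_cnt == 0:
        return 0
    return dd_cnt / all_cnt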

In [27]:
ddi_rate_score(all_med), ddi_rate_score(train_med), ddi_rate_score(test_med), ddi_rate_score(val_med)

(0.08640852474666356,
 0.09050506339269059,
 0.07768553459119497,
 0.07869765923715379)
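
The DDI rate is the fraction of co-prescribed drug pairs that are flagged in the DDI adjacency matrix. A minimal worked example on made-up data may help when reading the numbers above; the matrix `toy_ddi` and the visits in `toy_meds` are hypothetical, and a nested list stands in for the NumPy matrix.

```python
from itertools import combinations

# Hypothetical 3-drug interaction matrix: only drugs 0 and 1 interact.
toy_ddi = [
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 0],
]

# One patient with two visits; each visit is a list of medication indices.
toy_meds = [[[0, 1, 2], [1, 2]]]

pairs = 0
interacting = 0
for patient in toy_meds:
    for visit in patient:
        # Every unordered pair of drugs within a visit counts once.
        for i, j in combinations(visit, 2):
            pairs += 1
            interacting += toy_ddi[i][j]

rate = interacting / pairs
print(pairs, interacting, rate)
```

The first visit contributes three pairs, one of which interacts, and the second visit contributes one non-interacting pair, so the rate is 1 out of 4.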

To address this imbalance, we shuffled the dataset before applying the same index-based split. The DDI rates are now nearly equal across all datasets. We used the shuffled datasets to train and compare the models in the next steps.

In [28]:
import random
random.seed(1203)

random.shuffle(data)  
data_train = data[:split_point]    
data_test = data[split_point:split_point + eval_len]  
data_eval = data[split_point+eval_len:]               

In [30]:
# Keep only the medication list (index 2) of every visit
all_med = [[visit[2] for visit in patient] for patient in data]
train_med = [[visit[2] for visit in patient] for patient in data_train]
test_med = [[visit[2] for visit in patient] for patient in data_test]
val_med = [[visit[2] for visit in patient] for patient in data_eval]

ddi_rate_score(all_med), ddi_rate_score(train_med), ddi_rate_score(test_med), ddi_rate_score(val_med)

(0.08640852474666356,
 0.08665412003769873,
 0.08676206150376398,
 0.08512859907741303)
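
Since `random.shuffle` reorders `data` in place, one possible variant is to shuffle a copy with a local random generator so that the original order stays available for comparison. The helper `shuffled_split` below is a hypothetical sketch, not part of the authors' repo or our training scripts; it reuses the same split arithmetic as the cells above.

```python
import random


def shuffled_split(records, seed=1203, train_frac=2 / 3):
    """Return (train, test, evaluation) from a shuffled copy; the input is left untouched."""
    shuffled = list(records)               # shallow copy keeps the original order intact
    random.Random(seed).shuffle(shuffled)  # local generator, no global seeding
    split_point = int(len(shuffled) * train_frac)
    eval_len = int((len(shuffled) - split_point) / 2)
    train = shuffled[:split_point]
    test = shuffled[split_point:split_point + eval_len]
    evaluation = shuffled[split_point + eval_len:]
    return train, test, evaluation


toy = list(range(10))  # stand-in for the list of patient records
train, test, evaluation = shuffled_split(toy)
print(len(train), len(test), len(evaluation))
```

The three parts are disjoint and together cover every record exactly once.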

## Part 3: Run base models and the GAMENet model
We run the training scripts directly (via `%run`) rather than inlining their code, to keep the notebook clean. All the code is in our GitHub repo. For the two GAMENet runs, we used RMSprop as the optimizer, as mentioned in our final report; it performed better than the Adam optimizer used in the authors' original repo. The ground-truth DDI rate is printed in the run results, so the DDI rate reduction can be calculated without ambiguity.
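
One way to turn the printed ground-truth (GT) and predicted DDI rates into a reduction figure is the relative change with respect to the ground truth. The formula and the function name `ddi_reduction` are assumptions made for illustration; the final report may define the reduction differently.

```python
def ddi_reduction(gt_rate, pred_rate):
    """Relative DDI rate reduction versus ground truth (positive means fewer interactions)."""
    if gt_rate == 0:
        raise ValueError("ground-truth DDI rate must be non-zero")
    return (gt_rate - pred_rate) / gt_rate


# GT and predicted DDI rates taken from the Leap `--eval` run logged in this notebook.
reduction = ddi_reduction(0.0868, 0.0851)
print(f"{reduction:.2%}")
```

A negative value would indicate that the model's recommendations contain more interacting pairs than the ground-truth prescriptions.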

In [1]:
%cd ../code

/Users/rachelzou/Documents/Courses/MCS/DLH/Final-Project/RZ-GameNet/code


In [12]:
%run baseline_near.py

	DDI Rate: 0.0885, Jaccard: 0.3722, PRAUC: 0.3570, AVG_PRC: 0.5471, AVG_RECALL: 0.5527, AVG_F1: 0.5256

avg med 14.205563093622795


In [13]:
%run train_LR.py

	DDI Rate: 0.0803, Jaccard: 0.4116, PRAUC: 0.6687, AVG_PRC: 0.6623, AVG_RECALL: 0.5283, AVG_F1: 0.5682

avg med 11.088862559241706


In [15]:
%run train_Leap.py --model_name Leap

parameters 436884
Eval--Epoch: 0, Step: 1058/10593	GT DDI Rate: 0.0851, DDI Rate: 0.1364, Jaccard: 0.3379,  PRAUC: 0.5487, AVG_PRC: 0.4900, AVG_RECALL: 0.5415, AVG_F1: 0.4956
avg med 14.923879443585781
	Epoch: 0, Loss1: 3.2181, One Epoch Time: 2.88m, Appro Left Time: 1.87h

Eval--Epoch: 1, Step: 1058/10593	GT DDI Rate: 0.0851, DDI Rate: 0.1236, Jaccard: 0.3393,  PRAUC: 0.5353, AVG_PRC: 0.4766, AVG_RECALL: 0.5667, AVG_F1: 0.4970
avg med 16.11321483771252
	Epoch: 1, Loss1: 2.9221, One Epoch Time: 2.90m, Appro Left Time: 1.84h

Eval--Epoch: 2, Step: 1058/10593	GT DDI Rate: 0.0851, DDI Rate: 0.1164, Jaccard: 0.3560,  PRAUC: 0.5470, AVG_PRC: 0.5062, AVG_RECALL: 0.5719, AVG_F1: 0.5140
avg med 15.328438948995363
	Epoch: 2, Loss1: 2.8184, One Epoch Time: 2.89m, Appro Left Time: 1.78h

Eval--Epoch: 3, Step: 1058/10593	GT DDI Rate: 0.0851, DDI Rate: 0.1118, Jaccard: 0.3573,  PRAUC: 0.5425, AVG_PRC: 0.5001, AVG_RECALL: 0.5845, AVG_F1: 0.5153
avg med 15.86707882534776
	Epoch: 3, Loss1: 2.7524, One


Eval--Epoch: 32, Step: 1058/10593	GT DDI Rate: 0.0851, DDI Rate: 0.0809, Jaccard: 0.3817,  PRAUC: 0.5702, AVG_PRC: 0.5406, AVG_RECALL: 0.5850, AVG_F1: 0.5373
avg med 15.253091190108192
	Epoch: 32, Loss1: 2.3019, One Epoch Time: 2.66m, Appro Left Time: 0.31h

Eval--Epoch: 33, Step: 1058/10593	GT DDI Rate: 0.0851, DDI Rate: 0.0845, Jaccard: 0.3827,  PRAUC: 0.5662, AVG_PRC: 0.5403, AVG_RECALL: 0.5877, AVG_F1: 0.5385
avg med 15.326506955177743
	Epoch: 33, Loss1: 2.2991, One Epoch Time: 2.66m, Appro Left Time: 0.27h

Eval--Epoch: 34, Step: 1058/10593	GT DDI Rate: 0.0851, DDI Rate: 0.0783, Jaccard: 0.3827,  PRAUC: 0.5661, AVG_PRC: 0.5413, AVG_RECALL: 0.5848, AVG_F1: 0.5376
avg med 15.23338485316847
	Epoch: 34, Loss1: 2.2924, One Epoch Time: 2.65m, Appro Left Time: 0.22h

Eval--Epoch: 35, Step: 1058/10593	GT DDI Rate: 0.0851, DDI Rate: 0.0809, Jaccard: 0.3825,  PRAUC: 0.5670, AVG_PRC: 0.5423, AVG_RECALL: 0.5847, AVG_F1: 0.5379
avg med 15.198608964451314
	Epoch: 35, Loss1: 2.2877, One Epoch T

In [17]:
%run train_Leap.py --model_name Leap --resume_path Epoch_31_JA_0.3868_DDI_0.0816.model --eval

parameters 436884
Eval--Epoch: 0, Step: 1057/1058	GT DDI Rate: 0.0868, DDI Rate: 0.0851, Jaccard: 0.3809,  PRAUC: 0.5595, AVG_PRC: 0.5345, AVG_RECALL: 0.5872, AVG_F1: 0.5363
avg med 15.21958925750395


In [21]:
%run train_Retain.py --model_name Retain

parameters 289490
Train--Epoch: 0, Step: 4232/4233
Eval--Epoch: 0, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0850, Jaccard: 0.3115,  PRAUC: 0.5393, AVG_PRC: 0.3994, AVG_RECALL: 0.6615, AVG_F1: 0.4648
avg med 22.365598430346633
	Epoch: 0, Loss1: 0.8704, One Epoch Time: 0.44m, Appro Left Time: 0.57h

Train--Epoch: 1, Step: 4232/4233
Eval--Epoch: 1, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0933, Jaccard: 0.3486,  PRAUC: 0.5939, AVG_PRC: 0.4529, AVG_RECALL: 0.6584, AVG_F1: 0.5063
avg med 19.69718770438195
	Epoch: 1, Loss1: 0.5137, One Epoch Time: 0.42m, Appro Left Time: 0.54h

Train--Epoch: 2, Step: 4232/4233
Eval--Epoch: 2, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0950, Jaccard: 0.3606,  PRAUC: 0.6068, AVG_PRC: 0.4641, AVG_RECALL: 0.6672, AVG_F1: 0.5191
avg med 19.346631785480707
	Epoch: 2, Loss1: 0.4601, One Epoch Time: 0.42m, Appro Left Time: 0.54h

Train--Epoch: 3, Step: 4232/4233
Eval--Epoch: 3, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0925, Jaccard: 0.3

Eval--Epoch: 28, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0922, Jaccard: 0.4078,  PRAUC: 0.6475, AVG_PRC: 0.5669, AVG_RECALL: 0.6168, AVG_F1: 0.5681
avg med 15.089601046435579
	Epoch: 28, Loss1: 0.3415, One Epoch Time: 0.40m, Appro Left Time: 0.34h

Train--Epoch: 29, Step: 4232/4233
Eval--Epoch: 29, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0954, Jaccard: 0.4101,  PRAUC: 0.6478, AVG_PRC: 0.5632, AVG_RECALL: 0.6259, AVG_F1: 0.5703
avg med 15.487246566383257
	Epoch: 29, Loss1: 0.3356, One Epoch Time: 0.40m, Appro Left Time: 0.33h

Train--Epoch: 30, Step: 4232/4233
Eval--Epoch: 30, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0953, Jaccard: 0.4095,  PRAUC: 0.6489, AVG_PRC: 0.5742, AVG_RECALL: 0.6131, AVG_F1: 0.5699
avg med 14.734466971877044
	Epoch: 30, Loss1: 0.3327, One Epoch Time: 0.40m, Appro Left Time: 0.32h

Train--Epoch: 31, Step: 4232/4233
Eval--Epoch: 31, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0940, Jaccard: 0.4106,  PRAUC: 0.6478, AVG_PRC: 0.5641, AV

Eval--Epoch: 56, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0901, Jaccard: 0.4145,  PRAUC: 0.6543, AVG_PRC: 0.5739, AVG_RECALL: 0.6238, AVG_F1: 0.5744
avg med 15.224329627207325
	Epoch: 56, Loss1: 0.3158, One Epoch Time: 0.40m, Appro Left Time: 0.15h

Train--Epoch: 57, Step: 4232/4233
Eval--Epoch: 57, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0932, Jaccard: 0.4172,  PRAUC: 0.6527, AVG_PRC: 0.5664, AVG_RECALL: 0.6380, AVG_F1: 0.5768
avg med 15.710922171353825
	Epoch: 57, Loss1: 0.3189, One Epoch Time: 0.40m, Appro Left Time: 0.15h

Train--Epoch: 58, Step: 4232/4233
Eval--Epoch: 58, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0890, Jaccard: 0.4161,  PRAUC: 0.6537, AVG_PRC: 0.5728, AVG_RECALL: 0.6270, AVG_F1: 0.5758
avg med 15.424460431654676
	Epoch: 58, Loss1: 0.3169, One Epoch Time: 0.40m, Appro Left Time: 0.14h

Train--Epoch: 59, Step: 4232/4233
Eval--Epoch: 59, Step: 1058/1059	GT DDI Rate: 0.0876, DDI Rate: 0.0911, Jaccard: 0.4165,  PRAUC: 0.6539, AVG_PRC: 0.5713, AV

In [22]:
%run train_Retain.py --model_name Retain --resume_path Epoch_67_JA_0.4176_DDI_0.0881.model --eval

parameters 289490

Eval--Epoch: 0, Step: 1057/1058	GT DDI Rate: 0.0877, DDI Rate: 0.0914, Jaccard: 0.4132,  PRAUC: 0.6482, AVG_PRC: 0.5588, AVG_RECALL: 0.6345, AVG_F1: 0.5724
avg med 15.614654002713705
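
The evaluation lines report set-based metrics (Jaccard, AVG_PRC, AVG_RECALL, AVG_F1) together with a DDI rate. The following is a minimal sketch of how such metrics can be computed, assuming each visit's ground truth and prediction are sets of medication indices and that each metric is computed per visit and then averaged over visits. The function names `visit_metrics`, `average_metrics`, and `ddi_rate`, and the toy data, are hypothetical and are not taken from the training scripts. Per-visit averaging would be consistent with the Retain line above, where the harmonic mean of AVG_PRC 0.5588 and AVG_RECALL 0.6345 is about 0.594 rather than the reported AVG_F1 of 0.5724.

```python
from itertools import combinations


def visit_metrics(true_meds, pred_meds):
    """Return (jaccard, precision, recall, f1) for a single visit."""
    true_set, pred_set = set(true_meds), set(pred_meds)
    inter = true_set & pred_set
    union = true_set | pred_set
    jaccard = len(inter) / len(union) if union else 0.0
    precision = len(inter) / len(pred_set) if pred_set else 0.0
    recall = len(inter) / len(true_set) if true_set else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return jaccard, precision, recall, f1


def average_metrics(visits):
    """Average the per-visit metrics over a list of (true, predicted) pairs."""
    rows = [visit_metrics(t, p) for t, p in visits]
    n = len(rows)
    return tuple(sum(col) / n for col in zip(*rows))


def ddi_rate(predictions, ddi_pairs):
    """Fraction of co-prescribed medication pairs that are known to interact.

    Assumption: `ddi_pairs` is a set of frozensets, one per interacting pair.
    """
    total = 0
    interacting = 0
    for meds in predictions:
        for a, b in combinations(sorted(meds), 2):
            total += 1
            if frozenset((a, b)) in ddi_pairs:
                interacting += 1
    return interacting / total if total else 0.0


# Toy data (hypothetical): two visits as (ground truth, prediction) pairs.
toy_visits = [({1, 2, 3}, {2, 3, 4}), ({5, 6}, {5})]
toy_ddi_pairs = {frozenset((2, 3))}

jaccard, precision, recall, f1 = average_metrics(toy_visits)
rate = ddi_rate([pred for _, pred in toy_visits], toy_ddi_pairs)
print(f"Jaccard={jaccard:.4f} PRC={precision:.4f} RECALL={recall:.4f} "
      f"F1={f1:.4f} DDI={rate:.4f}")
```

Because F1 is averaged per visit here, it generally differs from the harmonic mean of the averaged precision and recall.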


In [24]:
%run train_GAMENet.py --model_name GAMENet --ddi

parameters 536534
Eval--Epoch: 0, Step: 1058/10593, L_p cnt: 6754, L_neg cnt: 3142	GT DDI Rate: 0.0851, DDI Rate: 0.0398, Jaccard: 0.3418,  PRAUC: 0.6060, AVG_PRC: 0.6724, AVG_RECALL: 0.4240, AVG_F1: 0.5005
avg med 8.311823802163833
	Epoch: 0, Loss: 0.1759, One Epoch Time: 1.74m, Appro Left Time: 1.13h

Eval--Epoch: 1, Step: 1058/10593, L_p cnt: 8652, L_neg cnt: 1244	GT DDI Rate: 0.0851, DDI Rate: 0.0460, Jaccard: 0.3637,  PRAUC: 0.6157, AVG_PRC: 0.6471, AVG_RECALL: 0.4688, AVG_F1: 0.5237
avg med 9.688948995363214
	Epoch: 1, Loss: 0.2205, One Epoch Time: 1.72m, Appro Left Time: 1.09h

Eval--Epoch: 2, Step: 1058/10593, L_p cnt: 7925, L_neg cnt: 1971	GT DDI Rate: 0.0851, DDI Rate: 0.0400, Jaccard: 0.3726,  PRAUC: 0.6245, AVG_PRC: 0.6508, AVG_RECALL: 0.4792, AVG_F1: 0.5329
avg med 9.744976816074189
	Epoch: 2, Loss: 0.1973, One Epoch Time: 1.94m, Appro Left Time: 1.20h

Eval--Epoch: 3, Step: 1058/10593, L_p cnt: 7619, L_neg cnt: 2277	GT DDI Rate: 0.0851, DDI Rate: 0.0445, Jaccard: 0.3756, 

avg med 11.652627511591962
	Epoch: 28, Loss: 0.2070, One Epoch Time: 1.70m, Appro Left Time: 0.31h

Eval--Epoch: 29, Step: 1058/10593, L_p cnt: 9468, L_neg cnt: 428	GT DDI Rate: 0.0851, DDI Rate: 0.0826, Jaccard: 0.4302,  PRAUC: 0.6767, AVG_PRC: 0.6533, AVG_RECALL: 0.5652, AVG_F1: 0.5871
avg med 11.639103554868624
	Epoch: 29, Loss: 0.2076, One Epoch Time: 1.67m, Appro Left Time: 0.28h

Eval--Epoch: 30, Step: 1058/10593, L_p cnt: 9539, L_neg cnt: 357	GT DDI Rate: 0.0851, DDI Rate: 0.0872, Jaccard: 0.4325,  PRAUC: 0.6764, AVG_PRC: 0.6472, AVG_RECALL: 0.5722, AVG_F1: 0.5892
avg med 11.906105100463678
	Epoch: 30, Loss: 0.2089, One Epoch Time: 1.84m, Appro Left Time: 0.28h

Eval--Epoch: 31, Step: 1058/10593, L_p cnt: 9598, L_neg cnt: 298	GT DDI Rate: 0.0851, DDI Rate: 0.0839, Jaccard: 0.4323,  PRAUC: 0.6769, AVG_PRC: 0.6474, AVG_RECALL: 0.5720, AVG_F1: 0.5892
avg med 11.877511591962906
	Epoch: 31, Loss: 0.2095, One Epoch Time: 1.68m, Appro Left Time: 0.22h

Eval--Epoch: 32, Step: 1058/10593

In [25]:
%run train_GAMENet.py --model_name GAMENet --ddi --resume_path Epoch_33_JA_0.4328_DDI_0.0826.model --eval

parameters 536534
Eval--Epoch: 0, Step: 1057/1058	GT DDI Rate: 0.0868, DDI Rate: 0.0861, Jaccard: 0.4255,  PRAUC: 0.6709, AVG_PRC: 0.6451, AVG_RECALL: 0.5643, AVG_F1: 0.5826
avg med 11.66785150078989


In [5]:
%run train_GAMENet.py --model_name GAMENet

parameters 536534
Eval--Epoch: 0, Step: 1058/10593, L_p cnt: 0, L_neg cnt: 0	GT DDI Rate: 0.0851, DDI Rate: 0.0644, Jaccard: 0.3464,  PRAUC: 0.6053, AVG_PRC: 0.6527, AVG_RECALL: 0.4402, AVG_F1: 0.5054
avg med 9.093508500772797
	Epoch: 0, Loss: 0.2581, One Epoch Time: 1.63m, Appro Left Time: 1.06h

Eval--Epoch: 1, Step: 1058/10593, L_p cnt: 0, L_neg cnt: 0	GT DDI Rate: 0.0851, DDI Rate: 0.0507, Jaccard: 0.3809,  PRAUC: 0.6263, AVG_PRC: 0.6445, AVG_RECALL: 0.4941, AVG_F1: 0.5389
avg med 10.234157650695519
	Epoch: 1, Loss: 0.2508, One Epoch Time: 1.67m, Appro Left Time: 1.06h

Eval--Epoch: 2, Step: 1058/10593, L_p cnt: 0, L_neg cnt: 0	GT DDI Rate: 0.0851, DDI Rate: 0.0620, Jaccard: 0.3898,  PRAUC: 0.6378, AVG_PRC: 0.6517, AVG_RECALL: 0.5016, AVG_F1: 0.5476
avg med 10.221792890262751
	Epoch: 2, Loss: 0.2418, One Epoch Time: 1.60m, Appro Left Time: 0.99h

Eval--Epoch: 3, Step: 1058/10593, L_p cnt: 0, L_neg cnt: 0	GT DDI Rate: 0.0851, DDI Rate: 0.0630, Jaccard: 0.3951,  PRAUC: 0.6471, AVG_PR


Eval--Epoch: 29, Step: 1058/10593, L_p cnt: 0, L_neg cnt: 0	GT DDI Rate: 0.0851, DDI Rate: 0.0809, Jaccard: 0.4328,  PRAUC: 0.6810, AVG_PRC: 0.6496, AVG_RECALL: 0.5703, AVG_F1: 0.5892
avg med 11.818778979907265
	Epoch: 29, Loss: 0.2119, One Epoch Time: 1.48m, Appro Left Time: 0.25h

Eval--Epoch: 30, Step: 1058/10593, L_p cnt: 0, L_neg cnt: 0	GT DDI Rate: 0.0851, DDI Rate: 0.0847, Jaccard: 0.4334,  PRAUC: 0.6805, AVG_PRC: 0.6496, AVG_RECALL: 0.5707, AVG_F1: 0.5898
avg med 11.857032457496135
	Epoch: 30, Loss: 0.2117, One Epoch Time: 1.53m, Appro Left Time: 0.23h

Eval--Epoch: 31, Step: 1058/10593, L_p cnt: 0, L_neg cnt: 0	GT DDI Rate: 0.0851, DDI Rate: 0.0886, Jaccard: 0.4349,  PRAUC: 0.6792, AVG_PRC: 0.6467, AVG_RECALL: 0.5760, AVG_F1: 0.5914
avg med 12.006568778979908
	Epoch: 31, Loss: 0.2112, One Epoch Time: 1.58m, Appro Left Time: 0.21h

Eval--Epoch: 32, Step: 1058/10593, L_p cnt: 0, L_neg cnt: 0	GT DDI Rate: 0.0851, DDI Rate: 0.0828, Jaccard: 0.4355,  PRAUC: 0.6812, AVG_PRC: 0.6485

In [6]:
%run train_GAMENet.py --model_name GAMENet --resume_path Epoch_32_JA_0.4355_DDI_0.0828.model --eval

parameters 536534
Eval--Epoch: 0, Step: 1057/1058	GT DDI Rate: 0.0868, DDI Rate: 0.0844, Jaccard: 0.4294,  PRAUC: 0.6747, AVG_PRC: 0.6422, AVG_RECALL: 0.5711, AVG_F1: 0.5860
avg med 11.873222748815166
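
To compare the three `--eval` runs above side by side, the metric values can be pulled out of the printed lines. The sketch below copies those three lines verbatim and parses them with a regular expression; the helper name `parse_eval_line` and the dictionary labels are our own choices for illustration, and the pattern assumes the `name: value` layout shown in the logs.

```python
import re

# Final evaluation lines copied from the outputs of the three --eval runs above.
EVAL_LOGS = {
    "Retain": (
        "Eval--Epoch: 0, Step: 1057/1058\tGT DDI Rate: 0.0877, DDI Rate: 0.0914, "
        "Jaccard: 0.4132,  PRAUC: 0.6482, AVG_PRC: 0.5588, AVG_RECALL: 0.6345, "
        "AVG_F1: 0.5724"
    ),
    "GAMENet (--ddi)": (
        "Eval--Epoch: 0, Step: 1057/1058\tGT DDI Rate: 0.0868, DDI Rate: 0.0861, "
        "Jaccard: 0.4255,  PRAUC: 0.6709, AVG_PRC: 0.6451, AVG_RECALL: 0.5643, "
        "AVG_F1: 0.5826"
    ),
    "GAMENet": (
        "Eval--Epoch: 0, Step: 1057/1058\tGT DDI Rate: 0.0868, DDI Rate: 0.0844, "
        "Jaccard: 0.4294,  PRAUC: 0.6747, AVG_PRC: 0.6422, AVG_RECALL: 0.5711, "
        "AVG_F1: 0.5860"
    ),
}

# "GT DDI Rate" is listed before "DDI Rate" so the longer name is matched first.
METRIC_RE = re.compile(
    r"(GT DDI Rate|DDI Rate|Jaccard|PRAUC|AVG_PRC|AVG_RECALL|AVG_F1): (\d+\.\d+)"
)


def parse_eval_line(line):
    """Return a dict mapping each metric name in the line to its float value."""
    return {name: float(value) for name, value in METRIC_RE.findall(line)}


results = {model: parse_eval_line(line) for model, line in EVAL_LOGS.items()}

# Print a small comparison table for a few of the metrics.
columns = ["DDI Rate", "Jaccard", "PRAUC", "AVG_F1"]
print(f"{'model':<18}" + "".join(f"{c:>10}" for c in columns))
for model, metrics in results.items():
    print(f"{model:<18}" + "".join(f"{metrics[c]:>10.4f}" for c in columns))
```

In these runs, both GAMENet variants report a higher Jaccard (0.4255 with `--ddi`, 0.4294 without) and a lower DDI rate (0.0861 and 0.0844) than Retain (0.4132 and 0.0914).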


## Conclusion
This demo notebook walks through the essential steps in reproducing the research paper "GAMENet: Graph Augmented MEmory Networks for Recommending Medication Combination": data preprocessing, dataset splitting, and training and evaluating the baseline models and GAMENet. Our final report discusses the results and the additional experiments we tried in more detail. We hope this demo helps others understand the reproduction process and inspires further research in this field.