<a href="https://colab.research.google.com/github/russpv/SafeDrug/blob/main/LR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

! pip install memory_profiler

Not connected to a GPU
Your runtime has 13.6 gigabytes of available RAM

Not using a high-RAM runtime
Collecting memory_profiler
  Downloading memory_profiler-0.60.0.tar.gz (38 kB)
Building wheels for collected packages: memory-profiler
  Building wheel for memory-profiler (setup.py) ... [?25l[?25hdone
  Created wheel for memory-profiler: filename=memory_profiler-0.60.0-py3-none-any.whl size=31284 sha256=7479175c62cdfd7c61818e87b16726a7350e7374af34e3efc78f7edcd3b32323
  Stored in directory: /root/.cache/pip/wheels/67/2b/fb/326e30d638c538e69a5eb0aa47f4223d979f502bbdb403950f
Successfully built memory-profiler
Installing collected packages: memory-profiler
Successfully installed memory-profiler-0.60.0


# Args

In [2]:
import argparse
def arg_parser():
    """ Parse command line arguments

    Outputs:
        arguments {object} -- object containing command line arguments
    """

    # Initializer
    parser = argparse.ArgumentParser()

    # Add arguments here
    parser.add_argument('--Test', action='store_true', default=False, help="test mode")
    parser.add_argument('--model_name', type=str, default='none', help="model name")
    parser.add_argument('--resume_path', type=str, default='none', help='resume path')
    parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
    parser.add_argument('--target_ddi', type=float, default=0.06, help="target ddi")
    parser.add_argument('--dropout', type=float, default=0.5, help="dropout for embeddings")
    parser.add_argument('--cuda', type=int, default=0, help='which cuda') ###

    parser.add_argument('--smalldata', type=int, default=1, help='debug data set') ###
    parser.add_argument('--mydata', type=int, default=1, help='paper code') ###
    parser.add_argument('--Inf_time', type=int, default=0, help='inference time test') ###
 
    # Parse and return arguments
    return(parser.parse_args(args=[]))

args = arg_parser()

In [3]:
! pip install memory_profiler
%load_ext memory_profiler



In [4]:
import os
import dill
import random
import numpy as np
import pandas as pd
import sys
import time
import statistics
import datetime as dt
import logging
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from collections import defaultdict
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import jaccard_score

# set seed
seed = 1203 #1203
random.seed(seed)

# define data path
DATA_PATH = "drive/MyDrive/DL4H/Project/PaperCode/processed_orig/"
MYDATA_PATH = "drive/MyDrive/DL4H/Project/SAFEDRUG_lib/data/processed/"
WORKING_PATH = "drive/MyDrive/DL4H/Project/LR/"
TEST_PATH = "drive/MyDrive/DL4H/Project/LR/results/"

# define dataset
args.mydata = 0
args.smalldata = 0

# setting
args.model_name = 'LR'
# args.resume_path = WORKING_PATH + 'saved/' + ''

logger = logging.getLogger('')
logger.setLevel(logging.CRITICAL)

# Data Setup

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
# Data switch
if args.mydata == 1:
    data_path = MYDATA_PATH + 'ehr.pkl'
    voc_path = MYDATA_PATH + 'vocabs.pkl'

    # ehr_adj_path = MYDATA_PATH + 'ehradj.pkl'
    ddi_adj_path = MYDATA_PATH + 'ddiadj.pkl'
    # ddi_mask_path = MYDATA_PATH + 'hmask.pkl'
    # molecule_path = MYDATA_PATH + 'atc2SMILES.pkl'

    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_vocab'].index2word, voc['pro_vocab'].index2word, voc['med_vocab'].index2word

else:
    data_path = DATA_PATH + 'records_final.pkl'
    voc_path = DATA_PATH + 'voc_final.pkl'


    # ehr_adj_path = DATA_PATH + 'ehr_adj_final.pkl'
    ddi_adj_path = DATA_PATH + 'ddi_A_final.pkl'
    # ddi_mask_path = DATA_PATH + 'ddi_mask_H.pkl'
    # molecule_path = DATA_PATH + 'atc3toSMILES.pkl'
    
    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'].idx2word, voc['pro_voc'].idx2word, voc['med_voc'].idx2word

# ehr_adj = dill.load(open(ehr_adj_path, 'rb'))
# ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
# ddi_mask_H = dill.load(open(ddi_mask_path, 'rb'))
data = dill.load(open(data_path, 'rb'))
# molecule = dill.load(open(molecule_path, 'rb')) 

if args.smalldata == 1:
    data_train = data[:200] 
    data_test = data[200:250]
    data_eval = data[250:300]
else:
    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point+eval_len:]


In [7]:
# concatenated diagnosis and procedure codes into multi-hot np.arrays
def process_data_LR(data, diag_voc, pro_voc, med_voc):
    X, y = [], []
    for patient in data:
        for visit in patient:
            multi_hot_input = np.zeros(len(diag_voc) + len(pro_voc))
            multi_hot_input[visit[0]] = 1
            multi_hot_input[np.array(visit[1]) + len(diag_voc)] = 1  #access proc indices with offset

            multi_hot_output = np.zeros(len(med_voc))
            multi_hot_output[visit[2]] = 1

            X.append(multi_hot_input)
            y.append(multi_hot_output)

    return np.array(X), np.array(y)

# Utils

In [8]:
from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
import warnings
warnings.filterwarnings('ignore')

def llprint(message):
    sys.stdout.write(message)
    sys.stdout.flush()

def ddi_rate_score_LR(preds, path=ddi_adj_path): ###
    # check intersection of predicted meds against ddi matrix
    ddi_A = dill.load(open(path, 'rb'))
    all_cnt, dd_cnt, med_cnt, visit_cnt = 0, 0, 0, 0
    for visit_meds in preds:
        visit_cnt += 1
        med_code_set = np.where(visit_meds==1)[0]
        med_cnt += len(med_code_set)
        for i, med_i in enumerate(med_code_set):
            for j, med_j in enumerate(med_code_set):
                if j <= i:
                    continue
                all_cnt += 1
                if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                    dd_cnt += 1
    logger.warning(f'visits: {visit_cnt}  meds: {med_cnt}')
    return 0. if all_cnt == 0 else dd_cnt / all_cnt, 0. if visit_cnt == 0 else med_cnt / visit_cnt

def multi_label_metric(y_gt, y_pred, y_prob):

    def jaccard(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            union = set(out_list) | set(target)
            jaccard_score = 0 if union == 0 else len(inter) / len(union)
            score.append(jaccard_score)
        return np.mean(score)

    def average_prc(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list)
            score.append(prc_score)
        return score

    def average_recall(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            recall_score = 0 if len(target) == 0 else len(inter) / len(target)
            score.append(recall_score)
        return score

    def average_f1(average_prc, average_recall):
        score = []
        for idx in range(len(average_prc)):
            if average_prc[idx] + average_recall[idx] == 0:
                score.append(0)
            else:
                score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx]))
        return score

    def f1(y_gt, y_pred):
        all_micro = []
        for b in range(y_gt.shape[0]):
            all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro'))
        return np.mean(all_micro)

    def roc_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(roc_auc_score(y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_at_k(y_gt, y_prob, k=3):
        precision = 0
        sort_index = np.argsort(y_prob, axis=-1)[:, ::-1][:, :k]
        for i in range(len(y_gt)):
            TP = 0
            for j in range(len(sort_index[i])):
                if y_gt[i, sort_index[i, j]] == 1:
                    TP += 1
            precision += TP / len(sort_index[i])
        return precision / len(y_gt)

    # roc_auc
    try:
        auc = roc_auc(y_gt, y_prob)
    except:
        auc = 0
    # precision
    p_1 = precision_at_k(y_gt, y_prob, k=1)
    p_3 = precision_at_k(y_gt, y_prob, k=3)
    p_5 = precision_at_k(y_gt, y_prob, k=5)
    # macro f1
    f1 = f1(y_gt, y_pred)
    # precision
    prauc = precision_auc(y_gt, y_prob)
    # jaccard
    ja = jaccard(y_gt, y_pred)
    # pre, recall, f1
    avg_prc = average_prc(y_gt, y_pred)
    avg_recall = average_recall(y_gt, y_pred)
    avg_f1 = average_f1(avg_prc, avg_recall)

    return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1)

# Train

In [9]:
def main():
    voc_size = (len(diag_voc), len(pro_voc), len(med_voc))
    model = LogisticRegression()
    csfr = OneVsRestClassifier(model)

    history = defaultdict(list)
    times_train, times_eval = [], [] ###

    for epoch in range(1):
        # data processing
        X_train, y_train = process_data_LR(data_train, diag_voc, pro_voc, med_voc)
        X_test, y_test = process_data_LR(data_test, diag_voc, pro_voc, med_voc)
        X_eval, y_eval = process_data_LR(data_eval, diag_voc, pro_voc, med_voc)

        time_start = time.time() 
        csfr.fit(X_train, y_train)
        time_end = time.time()  
        time_fit = time_end - time_start 
        print(f'fitting time: {time_fit}') ###

        #Test inference
        result =[]
        for rnd in range(10): 
            # take a sample of 80% of total test visits
            rng = np.random.default_rng()
            test_set_indices = rng.choice(a=np.arange(start=0, stop=len(X_test)), size=round(len(X_test) * 0.8), replace=True)
            X_set = X_test[test_set_indices]
            y_labels = y_test[test_set_indices]
            time_start = time.time()
            y_preds = csfr.predict(X_set)
            logger.debug(f'preds: {y_preds}')
            time_pred = time.time() - time_start
            print(f'round {rnd} prediction time: {time_pred}')

            y_probs = csfr.predict_proba(X_set)

            ja, prauc, avg_p, avg_r, avg_f1  = multi_label_metric(y_labels, y_preds, y_probs)
            ddi_rate, avg_med = ddi_rate_score_LR(y_preds)
            result.append([ddi_rate, ja, avg_f1, prauc, avg_med])
            times_eval.append(time_pred)

        result = np.array(result)
        mean = result.mean(axis=0)
        std = result.std(axis=0)

        outstring = ""
        for m, s in zip(mean, std):
            outstring += "  & {:.4f} "u"\u00B1"" {:.4f}".format(m, s) ###
        print(outstring)

        print('Epoch: {}, DDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'.format(
            epoch, ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med
            ))

        times_train.append(time_fit) ###

        history['ja'].append(ja)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['med'].append(avg_med)

        datadump = np.array([mean, std])
        df = pd.DataFrame(datadump, columns=['ddi', 'ja', 'prauc', 'f1', 'med'], index=['mean', 'std'])
        df.to_csv(TEST_PATH + 'Test_' + args.model_name + f'{dt.datetime.now()}' + '.csv' )

    dill.dump(history, open(WORKING_PATH +'history_{}_{}.pkl'.format(args.model_name, dt.datetime.now()), 'wb')) ###

    timings = np.array(list(zip(times_train, times_eval))) ###
    df = pd.DataFrame(timings, columns=['train', 'test']) ###
    df.to_csv(TEST_PATH + 'TimesTrain_' + args.model_name + f'{dt.datetime.now()}' + '.csv' ) ###

# Execute

In [10]:
if __name__ == '__main__':
    %reload_ext memory_profiler
    %memit -r1 main()

fitting time: 459.54309821128845
round 0 prediction time: 1.0115156173706055
round 1 prediction time: 1.121532917022705
round 2 prediction time: 1.0610957145690918
round 3 prediction time: 1.0308449268341064
round 4 prediction time: 1.035505771636963
round 5 prediction time: 1.038905382156372
round 6 prediction time: 1.0420901775360107
round 7 prediction time: 1.0627961158752441
round 8 prediction time: 1.0496745109558105
round 9 prediction time: 1.0802490711212158
  & 0.0776 ± 0.0012  & 0.4892 ± 0.0033  & 0.6480 ± 0.0030  & 0.7576 ± 0.0028  & 17.1273 ± 0.1544
Epoch: 0, DDI Rate: 0.07743, Jaccard: 0.4895, PRAUC: 0.7561, AVG_PRC: 0.7276, AVG_RECALL: 0.6068, AVG_F1: 0.6484, AVG_MED: 17.16

peak memory: 909.55 MiB, increment: 712.00 MiB
