<a href="https://colab.research.google.com/github/russpv/SafeDrug/blob/main/RETAIN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

! pip install memory_profiler

Mon May  9 00:16:13 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Args

In [2]:
import argparse
def arg_parser():
    """ Parse command line arguments

    Outputs:
        arguments {object} -- object containing command line arguments
    """

    # Initializer
    parser = argparse.ArgumentParser()

    # Add arguments here
    parser.add_argument('--Test', action='store_true', default=False, help="test mode")
    parser.add_argument('--model_name', type=str, default='none', help="model name")
    parser.add_argument('--resume_path', type=str, default='none', help='resume path')
    parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
    parser.add_argument('--target_ddi', type=float, default=0.06, help="target ddi")
    parser.add_argument('--dropout', type=float, default=0.5, help="dropout for embeddings")
    parser.add_argument('--dim', type=int, default=64, help='dimension')
    parser.add_argument('--cuda', type=int, default=0, help='which cuda') ###
    parser.add_argument('--reverse', type=int, default=1, help='reverse input sequence') ###

    parser.add_argument('--smalldata', type=int, default=1, help='debug data set') ###
    parser.add_argument('--mydata', type=int, default=1, help='paper code') ###
    parser.add_argument('--Inf_time', type=int, default=0, help='inference time test') ###
 
    # Parse and return arguments
    return(parser.parse_args(args=[]))

args = arg_parser()

In [11]:
import os
import dill
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pandas as pd
import sys
import time
import statistics
import datetime as dt
import logging

# set seed
seed = 1203 #1203
random.seed(seed)
np.random.seed(seed) #2048
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# define data path
DATA_PATH = "drive/MyDrive/DL4H/Project/PaperCode/processed_orig/"
MYDATA_PATH = "drive/MyDrive/DL4H/Project/SAFEDRUG_lib/data/processed/"
WORKING_PATH = "drive/MyDrive/DL4H/Project/RETAIN/"
TEST_PATH = "drive/MyDrive/DL4H/Project/RETAIN/results/"

# define dataset
args.mydata = 0
args.smalldata = 0
EPOCH = 50

# define routine
args.Test = True
args.Inf_time = False

# setting
args.model_name = 'RETAIN_orig_rev'
args.reverse = 1

args.resume_path = WORKING_PATH + 'saved/' + 'RETAIN_orig_rev_0Epoch_44_TARGET_0.06_JA_0.4534_DDI_0.08434_2022-05-08 22:12:56.834379.model'
# RETAIN_rev_1Epoch_44_TARGET_0.06_JA_0.4533_DDI_0.7759_2022-05-08 20:01:22.121790.model
# RETAIN_rev_0Epoch_49_TARGET_0.06_JA_0.4582_DDI_0.7824_2022-05-08 17:17:16.110156.model
logger = logging.getLogger('')
logger.setLevel(logging.CRITICAL)

# Data

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Data switch
if args.mydata == 1:
    data_path = MYDATA_PATH + 'ehr.pkl'
    voc_path = MYDATA_PATH + 'vocabs.pkl'

    ehr_adj_path = MYDATA_PATH + 'ehradj.pkl'
    ddi_adj_path = MYDATA_PATH + 'ddiadj.pkl'
    ddi_mask_path = MYDATA_PATH + 'hmask.pkl'
    molecule_path = MYDATA_PATH + 'atc2SMILES.pkl'

    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_vocab'].index2word, voc['pro_vocab'].index2word, voc['med_vocab'].index2word

else:
    data_path = DATA_PATH + 'records_final.pkl'
    voc_path = DATA_PATH + 'voc_final.pkl'


    ehr_adj_path = DATA_PATH + 'ehr_adj_final.pkl'
    ddi_adj_path = DATA_PATH + 'ddi_A_final.pkl'
    ddi_mask_path = DATA_PATH + 'ddi_mask_H.pkl'
    molecule_path = DATA_PATH + 'atc3toSMILES.pkl'
    
    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'].idx2word, voc['pro_voc'].idx2word, voc['med_voc'].idx2word

ehr_adj = dill.load(open(ehr_adj_path, 'rb'))
ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
ddi_mask_H = dill.load(open(ddi_mask_path, 'rb'))
data = dill.load(open(data_path, 'rb'))
molecule = dill.load(open(molecule_path, 'rb')) 

if args.smalldata == 1:
    data_train = data[:200] 
    data_test = data[200:250]
    data_eval = data[250:300]
else:
    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point+eval_len:]

# Utils

In [6]:
from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import sys
import warnings
import dill
from collections import Counter
from collections import defaultdict
import torch
warnings.filterwarnings('ignore')

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

# use the same metric from DMNC
def llprint(message):
    sys.stdout.write(message)
    sys.stdout.flush()

def transform_split(X, Y):
    x_train, x_eval, y_train, y_eval = train_test_split(X, Y, train_size=2/3, random_state=1203)
    x_eval, x_test, y_eval, y_test = train_test_split(x_eval, y_eval, test_size=0.5, random_state=1203)
    return x_train, x_eval, x_test, y_train, y_eval, y_test

def sequence_output_process(output_logits, filter_token):
    pind = np.argsort(output_logits, axis=-1)[:, ::-1]

    out_list = []
    break_flag = False
    for i in range(len(pind)):
        if break_flag:
            break
        for j in range(pind.shape[1]):
            label = pind[i][j]
            if label in filter_token:
                break_flag = True
                break
            if label not in out_list:
                out_list.append(label)
                break
    y_pred_prob_tmp = []
    for idx, item in enumerate(out_list):
        y_pred_prob_tmp.append(output_logits[idx, item])
    sorted_predict = [x for _, x in sorted(zip(y_pred_prob_tmp, out_list), reverse=True)]
    return out_list, sorted_predict


def sequence_metric(y_gt, y_pred, y_prob, y_label):
    def average_prc(y_gt, y_label):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b]==1)[0]
            out_list = y_label[b]
            inter = set(out_list) & set(target)
            prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list)
            score.append(prc_score)
        return score


    def average_recall(y_gt, y_label):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = y_label[b]
            inter = set(out_list) & set(target)
            recall_score = 0 if len(target) == 0 else len(inter) / len(target)
            score.append(recall_score)
        return score


    def average_f1(average_prc, average_recall):
        score = []
        for idx in range(len(average_prc)):
            if (average_prc[idx] + average_recall[idx]) == 0:
                score.append(0)
            else:
                score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx]))
        return score


    def jaccard(y_gt, y_label):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = y_label[b]
            inter = set(out_list) & set(target)
            union = set(out_list) | set(target)
            jaccard_score = 0 if union == 0 else len(inter) / len(union)
            score.append(jaccard_score)
        return np.mean(score)

    def f1(y_gt, y_pred):
        all_micro = []
        for b in range(y_gt.shape[0]):
            all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro'))
        return np.mean(all_micro)

    def roc_auc(y_gt, y_pred_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(roc_auc_score(y_gt[b], y_pred_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_at_k(y_gt, y_prob_label, k):
        precision = 0
        for i in range(len(y_gt)):
            TP = 0
            for j in y_prob_label[i][:k]:
                if y_gt[i, j] == 1:
                    TP += 1
            precision += TP / k
        return precision / len(y_gt)
    try:
        auc = roc_auc(y_gt, y_prob)
    except ValueError:
        auc = 0
    p_1 = precision_at_k(y_gt, y_label, k=1)
    p_3 = precision_at_k(y_gt, y_label, k=3)
    p_5 = precision_at_k(y_gt, y_label, k=5)
    f1 = f1(y_gt, y_pred)
    prauc = precision_auc(y_gt, y_prob)
    ja = jaccard(y_gt, y_label)
    avg_prc = average_prc(y_gt, y_label)
    avg_recall = average_recall(y_gt, y_label)
    avg_f1 = average_f1(avg_prc, avg_recall)

    return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1)

def ddi_rate_score(record, path=ddi_adj_path): ###
    # ddi rate
    ddi_A = dill.load(open(path, 'rb'))
    all_cnt = 0
    dd_cnt = 0
    for patient in record:
        for adm in patient:
            med_code_set = adm
            for i, med_i in enumerate(med_code_set):
                for j, med_j in enumerate(med_code_set):
                    if j <= i:
                        continue
                    all_cnt += 1
                    if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                        dd_cnt += 1
    if all_cnt == 0:
        return 0
    return dd_cnt / all_cnt

def multi_label_metric(y_gt, y_pred, y_prob):

    def jaccard(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            union = set(out_list) | set(target)
            jaccard_score = 0 if union == 0 else len(inter) / len(union)
            score.append(jaccard_score)
        return np.mean(score)

    def average_prc(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list)
            score.append(prc_score)
        return score

    def average_recall(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            recall_score = 0 if len(target) == 0 else len(inter) / len(target)
            score.append(recall_score)
        return score

    def average_f1(average_prc, average_recall):
        score = []
        for idx in range(len(average_prc)):
            if average_prc[idx] + average_recall[idx] == 0:
                score.append(0)
            else:
                score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx]))
        return score

    def f1(y_gt, y_pred):
        all_micro = []
        for b in range(y_gt.shape[0]):
            all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro'))
        return np.mean(all_micro)

    def roc_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(roc_auc_score(y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_at_k(y_gt, y_prob, k=3):
        precision = 0
        sort_index = np.argsort(y_prob, axis=-1)[:, ::-1][:, :k]
        for i in range(len(y_gt)):
            TP = 0
            for j in range(len(sort_index[i])):
                if y_gt[i, sort_index[i, j]] == 1:
                    TP += 1
            precision += TP / len(sort_index[i])
        return precision / len(y_gt)

    # roc_auc
    try:
        auc = roc_auc(y_gt, y_prob)
    except:
        auc = 0
    # precision
    p_1 = precision_at_k(y_gt, y_prob, k=1)
    p_3 = precision_at_k(y_gt, y_prob, k=3)
    p_5 = precision_at_k(y_gt, y_prob, k=5)
    # macro f1
    f1 = f1(y_gt, y_pred)
    # precision
    prauc = precision_auc(y_gt, y_prob)
    # jaccard
    ja = jaccard(y_gt, y_pred)
    # pre, recall, f1
    avg_prc = average_prc(y_gt, y_pred)
    avg_recall = average_recall(y_gt, y_pred)
    avg_f1 = average_f1(avg_prc, avg_recall)

    return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1)


# Model


In [7]:
class RETAIN(nn.Module):
    def __init__(self, diag_voc, pro_voc, med_voc, embedding_dim=64, dropout=0.5, device=torch.device('cpu:0')):
        super().__init__()
        self.device = device
        self.diag_voc, self.pro_voc, self.med_voc  = len(diag_voc), len(pro_voc), len(med_voc)
        self.vocab_len = self.diag_voc + self.pro_voc + self.med_voc

        self.embedding = nn.Embedding(self.vocab_len + 1, embedding_dim, padding_idx=self.vocab_len) # account for pad value
        self.dropout = nn.Dropout(dropout)
        self.rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        self.rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        self.att_a = nn.Linear(embedding_dim, 1)
        self.att_b = nn.Linear(embedding_dim, embedding_dim)

        self.fc = nn.Linear(embedding_dim, self.med_voc) # output medication codes, sigmoid in training
    
    def attention_sum(self, alpha, beta, rev_v, rev_masks):
        '''
        sums visit*embeddings into just embeddings

        Arguments:
            alpha: the alpha attention weights  (batch_size, seq_length, 1)
            beta: the beta attention weights of shape (batch_size, seq_length, hidden_dim)
            rev_v: the visit embeddings in reversed time of shape (batch_size, # visits, embedding_dim)
        '''
        applied = alpha * beta * rev_v # combine alpha and beta, broadcasted to (batch_size, visits, hidden_dim)
        c = torch.sum(applied, dim=-2) # sum over visits for (batch_size, hidden_dim)

        return c

    def sum_embeddings_with_mask(self, x, masks):
        x = x * masks.unsqueeze(-1)
        x = torch.sum(x, dim = -2)
        return x

    def forward(self, seq_input):
        """
        Arguments:
            seq_input: visit sequence in forward time (visits, batch_size, code)
        Outputs:
            probs: probabilities of shape (batch_size)
        """
        # Pad visits, create mask, reverse visits
        max_input_len = max([(len(v[0]) + len(v[1]) + len(v[2])) for v in seq_input])
        input_np = []
        mask = []
        for visit in seq_input:
            input_tmp = []
            input_tmp.extend(visit[0]) #diags
            input_tmp.extend(list(np.array(visit[1]) + self.diag_voc)) #procs, offset code set
            input_tmp.extend(list(np.array(visit[2]) + self.diag_voc + self.pro_voc)) #meds, offset code set
            if args.reverse == 1:
                input_tmp = list(reversed(input_tmp)) # reversed time
            mask_tmp = [1 for c in input_tmp]
            if len(input_tmp) < max_input_len:
                padding = [self.vocab_len]*(max_input_len - len(input_tmp)) # zero taken, use next available number
                mask_padding = [0 for x in padding]
                input_tmp.extend( padding )
                mask_tmp.extend( mask_padding )
            input_np.append(input_tmp) # make list of lists
            mask.append(mask_tmp)
        input_np = torch.LongTensor(input_np).to(self.device)
        mask = torch.LongTensor(mask).to(self.device)

        logger.warning(f'\nfinal input size: {input_np.size()}')
        logger.warning(f'mask size: {mask.size()}')
        
        emb_v = self.embedding(input_np) #(visits, codes, embedding_dim)
        emb_v = self.dropout(emb_v)
        emb_v = self.sum_embeddings_with_mask(emb_v, mask) #( # visits, embedding_dim) 

        g, _ = self.rnn_a(emb_v.unsqueeze(dim=0)) #( 1, seq_length=visit, embedding dim)
        h, _ = self.rnn_b(emb_v.unsqueeze(dim=0)) #( 1, seq_length=visit, embedding dim)
        logger.warning(f'g size: {g.size()}')
        logger.warning(f'h size: {h.size()}')
        alpha = torch.softmax(self.att_a(g.squeeze(dim=0)), dim=1) #(seq_length=visit, embedding dim)
        beta = torch.tanh(self.att_b(h.squeeze(dim=0))) #( seq_length=visit, embedding dim)
        logger.warning(f'alpha size: {alpha.size()}')
        logger.warning(f'beta size: {beta.size()}')
        c = self.attention_sum(alpha, beta, emb_v, mask).unsqueeze(dim=0) #( hidden states)
        logger.warning(f'c size: {c.size()}')
        logits = self.fc(c)

        logger.warning(f'output size: {logits.size()}')
        return logits

# Training


In [8]:
from torch.optim import Adam
import time

# evaluate
def eval(model, data_eval, voc_size, epoch):
    model.eval()

    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    med_cnt, visit_cnt = 0, 0

    for step, input in enumerate(data_eval):
        y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
        
        if len(input) < 2: continue # design requires more than one visit
        for idx in range(1, len(input)): 
            target_output = model(input[:idx])

            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[input[idx][2]] = 1
            y_gt.append(y_gt_tmp)

            # prediction prob
            target_output = torch.sigmoid(target_output).detach().cpu().numpy()[0]
            y_pred_prob.append(target_output)
            
            # prediction meds
            y_pred_tmp = target_output.copy()
            y_pred_tmp[y_pred_tmp>=0.5] = 1
            y_pred_tmp[y_pred_tmp<0.5] = 0
            y_pred.append(y_pred_tmp)

            # prediction label
            y_pred_label_tmp = np.where(y_pred_tmp == 1)[0]
            y_pred_label.append(y_pred_label_tmp)
            visit_cnt += 1
            med_cnt += len(y_pred_label_tmp)

        smm_record.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))

        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rtest step: {} / {}'.format(step+1, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(smm_record, path=ddi_adj_path) ###

    llprint('\nDDI Rate: {:.4}, Jaccard: {:.4},  PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'.format(
        ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt
    ))

    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt

def main():

    device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')

    voc_size = (len(diag_voc), len(pro_voc), len(med_voc))
    model = RETAIN(diag_voc, pro_voc, med_voc, args.dim, args.dropout, device=device)
    # model.load_state_dict(torch.load(open(args.resume_path, 'rb')))

    if args.Inf_time:
        if 'cpu' in device.type:
            print(f'Aborting inference timing, no GPU...')
            return
        #https://towardsdatascience.com/the-correct-way-to-measure-inference-time-of-deep-neural-networks-304a54e5187f
        model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
        model.to(device=device)
        tic = time.time()

        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        repetitions = len(data_test)
        timings = np.zeros((repetitions,1))
        dummy_input = [[[13, 98, 585, 1065, 21, 37, 454, 278], [69, 47], [4, 22, 12, 2, 67, 0, 86]],\
                       [[377, 326, 21, 46, 454], [115, 94], [3, 6, 12, 14, 5, 22, 2, 29, 1, 16, 11, 86]],\
                       [[377, 246, 453, 46, 21, 454], [151, 127, 128], [14, 2, 6, 29, 18, 0, 86]], [[963, 258, 32, 93, 94, 13, 103, 571, 21], [164, 423, 424, 425, 95, 426, 361, 48, 46, 2], [5, 4, 6, 7, 9, 11, 12, 3, 13, 16, 14, 22, 1, 2, 29, 44, 45, 48, 56, 20, 76, 86]]]

        #GPU-WARM-UP
        for _ in range(10):
            _ = model(dummy_input)
        count = 0

        # MEASURE PERFORMANCE
        with torch.no_grad():
            #for rep in range(repetitions):
            for rep, example in enumerate(data_test):
                starter.record()
                _ = model(example)
                ender.record()
                # WAIT FOR GPU SYNC
                torch.cuda.synchronize()
                curr_time = starter.elapsed_time(ender)
                timings[rep] = curr_time
                count += 1

        mean_syn = np.sum(timings) / repetitions
        std_syn = np.std(timings)
        print(f'Inference reps {count}, average: {mean_syn} \u00B1 {std_syn} seconds')

        data = np.array([mean_syn, std_syn, count])
        df = pd.DataFrame(data, index=['mean inference time', 'stdev', 'reps'])
        df.to_csv(TEST_PATH + 'Inf_' + args.model_name + device.type + f'{dt.datetime.now()}' + '.csv' )

        return

    if args.Test:
        model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
        model.to(device=device)
        tic = time.time()

        ddi_list, ja_list, prauc_list, f1_list, med_list = [], [], [], [], []

        result = []
        for _ in range(10):
            time_start = time.time()
            test_sample = np.random.choice(data_test, round(len(data_test) * 0.8), replace=True)
            ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(model, test_sample, voc_size, 0)
            time_sample = time.time() - time_start ###
            result.append([ddi_rate, ja, avg_f1, prauc, avg_med, time_sample])
            
        result = np.array(result)
        mean = result.mean(axis=0)
        std = result.std(axis=0)

        outstring = ""
        for m, s in zip(mean, std):
            outstring += "{:.4f} "u"\u00B1"" {:.4f} & ".format(m, s) ###

        print(outstring)
        time_round = time.time() - tic
        print(f'test time: {time_round}')
        
        elapsed_time = [0. for _ in range(5)]
        elapsed_time.append(time_round)
        data = np.array([mean, std, elapsed_time])

        df = pd.DataFrame(data, columns=['ddi', 'ja', 'prauc', 'f1', 'med', 'time'], index=['mean', 'std', 'seconds'])
        df.to_csv(TEST_PATH + 'Test_' + args.model_name + device.type + f'{dt.datetime.now()}' + '.csv' )

        return 

    if 'cpu' not in device.type:
        torch.cuda.reset_peak_memory_stats() # flush 
    model.to(device=device)
    print('parameters', sum(p.numel() for p in model.parameters() if p.requires_grad)) ###

    # exit()
    optimizer = Adam(list(model.parameters()), lr=args.lr)

    # start iterations
    history = defaultdict(list)
    best_epoch, best_ja = 0, 0

    times_train, times_eval = [], [] ###
    for epoch in range(EPOCH):
        time_start = time.time() ###
        print('\nepoch {} --------------------------'.format(epoch + 1))
        model.train()
       
        for step, input in enumerate(data_train): 
            if len(input) < 2: continue # design requires more than one visit

            loss = 0
            for idx in range(1, len(input)): 
                seq_input = input[:idx]
                loss_bce_target = np.zeros((1, voc_size[2]))
                loss_bce_target[:, input[idx][2]] = 1 # take the next visit's drugs

                result = model(seq_input)

                loss += F.binary_cross_entropy_with_logits(result, torch.FloatTensor(loss_bce_target).to(device)) # accumulate all visits

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            llprint(f'\rtraining step: {step+1} / {len(data_train)} loss: {loss} ') ###
        
        print()

        time_end = time.time()  ###
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(model, data_eval, voc_size, epoch)
        time_train = time_end - time_start ###
        time_eval = time.time() - time_end  ###
        print(f'training time: {time_train}, test time: {time_eval}, torch.device: {device}') ###

        times_train.append(time_train) ###
        times_eval.append(time_eval) ###

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['med'].append(avg_med)

        if epoch >= 5:
            print('ddi: {}, Med: {}, Ja: {}, F1: {}, PRAUC: {}'.format(
                np.mean(history['ddi_rate'][-5:]),
                np.mean(history['med'][-5:]),
                np.mean(history['ja'][-5:]),
                np.mean(history['avg_f1'][-5:]),
                np.mean(history['prauc'][-5:])
                ))

        torch.save(model.state_dict(), open(WORKING_PATH +''.join(('saved/', args.model_name, '_', 'rev_'+str(args.reverse),\
            'Epoch_{}_TARGET_{:.2}_JA_{:.4}_DDI_{:.4}_{}.model'.format(epoch, args.target_ddi, ja, ddi_rate, dt.datetime.now()))), 'wb')) ###

        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

        print('best_epoch: {}'.format(best_epoch))

    dill.dump(history, open(WORKING_PATH +'history_{}_{}.pkl'.format(args.model_name, dt.datetime.now()), 'wb')) ###
    
    timings = np.array(list(zip(times_train, times_eval))) ###
    df = pd.DataFrame(timings, columns=['train', 'test']) ###
    df.to_csv(TEST_PATH + 'TimesTrain_' + args.model_name + f'{dt.datetime.now()}' + '.csv' ) ###

    # Maximum cuda memory allocated
    if 'cpu' not in device.type:
        print(f'peak training memory allocated: {torch.cuda.max_memory_allocated(device)}')

# Execute

In [12]:
if __name__ == '__main__':
    %reload_ext memory_profiler
    %memit -r1 main()

test step: 846 / 846
DDI Rate: 0.08359, Jaccard: 0.4467,  PRAUC: 0.7507, AVG_PRC: 0.7391, AVG_RECALL: 0.555, AVG_F1: 0.6107, AVG_MED: 15.27
test step: 846 / 846
DDI Rate: 0.08479, Jaccard: 0.4447,  PRAUC: 0.7555, AVG_PRC: 0.7447, AVG_RECALL: 0.5471, AVG_F1: 0.6086, AVG_MED: 15.26
test step: 846 / 846
DDI Rate: 0.08406, Jaccard: 0.4437,  PRAUC: 0.7439, AVG_PRC: 0.7345, AVG_RECALL: 0.5499, AVG_F1: 0.6074, AVG_MED: 15.35
test step: 846 / 846
DDI Rate: 0.09188, Jaccard: 0.4531,  PRAUC: 0.7557, AVG_PRC: 0.7465, AVG_RECALL: 0.5587, AVG_F1: 0.617, AVG_MED: 15.48
test step: 846 / 846
DDI Rate: 0.08521, Jaccard: 0.443,  PRAUC: 0.7528, AVG_PRC: 0.7426, AVG_RECALL: 0.548, AVG_F1: 0.607, AVG_MED: 14.99
test step: 846 / 846
DDI Rate: 0.08377, Jaccard: 0.4475,  PRAUC: 0.7598, AVG_PRC: 0.7507, AVG_RECALL: 0.5503, AVG_F1: 0.6118, AVG_MED: 14.87
test step: 846 / 846
DDI Rate: 0.08435, Jaccard: 0.4448,  PRAUC: 0.7564, AVG_PRC: 0.7433, AVG_RECALL: 0.5504, AVG_F1: 0.6087, AVG_MED: 15.38
test step: 846 / 8

In [10]:
!nvidia-smi

Mon May  9 00:17:53 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    34W / 250W |    949MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces