<a href="https://colab.research.google.com/github/ricaelum42/Replication-of-SafeDrug/blob/main/baseline/retain/RETAIN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive so the project files stored there are reachable from the
# Colab VM.
from google.colab import drive
drive.mount('/content/drive')
# Switch into the project's code directory: the relative paths used later in
# this notebook ('saved/...', '../data/...') are resolved from here.
# NOTE(review): this is a user-specific Drive path -- adjust it when running
# the notebook from another account or machine.
%cd /content/drive/MyDrive/UIUC/spring2022/GAMENet/code
# Sanity check: list the directory contents and print the working directory.
!ls
!pwd

Mounted at /content/drive
/content/drive/MyDrive/UIUC/spring2022/GAMENet/code
baseline   models.py	SafeDrug.py  train_GAMENet.py
layers.py  __pycache__	saved	     util.py
/content/drive/MyDrive/UIUC/spring2022/GAMENet/code


In [2]:
!pip install dnc

Collecting dnc
  Downloading dnc-1.1.0-py3-none-any.whl (20 kB)
Collecting flann
  Downloading flann-1.6.13-py3-none-any.whl (24 kB)
Installing collected packages: flann, dnc
Successfully installed dnc-1.1.0 flann-1.6.13


In [3]:
import torch
import torch.nn as nn
from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
import numpy as np
import dill
import time
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import os
import torch.nn.functional as F
from collections import defaultdict

import sys
sys.path.append("..")
from models import Retain
from util import llprint, multi_label_metric, ddi_rate_score, get_n_params

In [4]:
# Seed PyTorch's random number generator so runs are repeatable.
torch.manual_seed(1203)
model_name = 'Retain'
# Checkpoint that main() restores when it runs in TEST mode.
resume_name = 'saved/{}/Epoch_49_JA_0.4905_DDI_0.0777.model'.format(model_name)

# Output directory for checkpoints and pickled results. exist_ok=True makes
# the cell safe to re-run and avoids a separate check-then-create step.
os.makedirs(os.path.join("saved", model_name), exist_ok=True)

In [11]:
def eval(model, data_eval, voc_size, epoch, threshold=0.3):
    """Evaluate ``model`` on ``data_eval`` and report multi-label metrics.

    For every patient with at least two visits, each visit ``i >= 1`` is
    predicted from the preceding visits ``patient[:i]``. Sigmoid probabilities
    at or above ``threshold`` count as predicted labels.

    Args:
        model: Module called as ``model(visits)`` that returns logits; row 0
            of the output is scored (assumes a single-row batch -- TODO
            confirm against the model definition).
        data_eval: Sequence of patients, each a sequence of visits.
            ``visit[2]`` must hold the label indices of that visit
            (presumably medication codes -- confirm against the data builder).
        voc_size: Indexable of vocabulary sizes; ``voc_size[2]`` is the width
            of the label vector.
        epoch: Epoch number, used only in the progress message.
        threshold: Probability cut-off for a positive prediction. Defaults to
            0.3, the value that used to be hard-coded.

    Returns:
        Tuple ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1)``; the last five
        entries are means over the evaluated patients.

    Side effects:
        Prints progress and a summary line, and overwrites
        ``saved/<model_name>/case_study.pkl``.
    """
    print('')
    model.eval()
    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    case_study = defaultdict(dict)
    med_cnt = 0
    visit_cnt = 0

    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        for step, patient in enumerate(data_eval):
            # A prediction needs at least one earlier visit as history.
            if len(patient) < 2:
                continue

            y_gt = []
            y_pred = []
            y_pred_prob = []
            y_pred_label = []
            for i in range(1, len(patient)):
                # Multi-hot ground truth for visit i.
                y_gt_tmp = np.zeros(voc_size[2])
                y_gt_tmp[patient[i][2]] = 1
                y_gt.append(y_gt_tmp)

                # Predict visit i from visits 0..i-1 only.
                logits = model(patient[:i])
                prob = torch.sigmoid(logits).detach().cpu().numpy()[0]
                y_pred_prob.append(prob)

                # Binarise in place on a copy so the probabilities stay intact.
                y_pred_tmp = prob.copy()
                y_pred_tmp[y_pred_tmp >= threshold] = 1
                y_pred_tmp[y_pred_tmp < threshold] = 0
                y_pred.append(y_pred_tmp)

                # Indices of the predicted labels, as plain Python ints.
                y_pred_label_tmp = np.flatnonzero(y_pred_tmp == 1).tolist()
                y_pred_label.append(y_pred_label_tmp)
                med_cnt += len(y_pred_label_tmp)
                visit_cnt += 1

            smm_record.append(y_pred_label)
            adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
                np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))
            # NOTE(review): keyed by the patient's Jaccard score, so patients
            # with an identical score overwrite each other.
            case_study[adm_ja] = {'ja': adm_ja, 'patient': patient, 'y_label': y_pred_label}
            ja.append(adm_ja)
            prauc.append(adm_prauc)
            avg_p.append(adm_avg_p)
            avg_r.append(adm_avg_r)
            avg_f1.append(adm_avg_f1)
            llprint('\rEval--Epoch: %d, Step: %d/%d' % (epoch, step, len(data_eval)))

    # Close the file deterministically instead of leaking the handle.
    with open(os.path.join('saved', model_name, 'case_study.pkl'), 'wb') as case_file:
        dill.dump(case_study, case_file)

    # ddi rate
    ddi_rate = ddi_rate_score(smm_record)

    mean_ja = np.mean(ja)
    mean_prauc = np.mean(prauc)
    mean_p = np.mean(avg_p)
    mean_r = np.mean(avg_r)
    mean_f1 = np.mean(avg_f1)

    llprint('\tDDI Rate: %.4f, Jaccard: %.4f,  PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f\n' % (
        ddi_rate, mean_ja, mean_prauc, mean_p, mean_r, mean_f1
    ))
    # Guard against data in which no patient has two or more visits.
    print('avg med', med_cnt / visit_cnt if visit_cnt else 0.0)

    return ddi_rate, mean_ja, mean_prauc, mean_p, mean_r, mean_f1


def main():
    """Train the Retain baseline, or evaluate a saved checkpoint.

    Loads the pickled records and vocabularies, splits the records 2/3 for
    training and 1/6 each for ``data_test`` and ``data_eval``, and then either

    * TEST mode (``TEST = True``): restores ``resume_name`` and evaluates it
      on ``data_test``; or
    * training mode: trains for ``EPOCH`` epochs, evaluating on ``data_eval``
      and saving a checkpoint after every epoch, then writes ``history.pkl``
      and ``final.model`` under ``saved/<model_name>``.

    Returns:
        In TEST mode the tuple ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1)``
        produced by ``eval``; in training mode ``None``.
    """
    save_dir = os.path.join("saved", model_name)
    os.makedirs(save_dir, exist_ok=True)

    data_path = '../data/records_final.pkl'
    voc_path = '../data/voc_final.pkl'
    # Use the GPU when present; otherwise fall back to the CPU instead of
    # failing on a hard-coded 'cuda:0'.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # NOTE: dill/pickle files can execute arbitrary code when loaded -- only
    # open files from a trusted source.
    with open(data_path, 'rb') as data_file:
        data = dill.load(data_file)
    with open(voc_path, 'rb') as voc_file:
        voc = dill.load(voc_file)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

    # First 2/3 of the records train the model; the remainder is halved into
    # data_test (first half) and data_eval (second half).
    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point+eval_len:]
    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))

    EPOCH = 50
    LR = 0.0002
    TEST = True

    model = Retain(voc_size, device=device)
    if TEST:
        # Load the checkpoint once; map_location lets a GPU-saved checkpoint
        # be restored on whichever device was selected above.
        model.load_state_dict(torch.load(resume_name, map_location=device))
    model.to(device=device)
    print('parameters', get_n_params(model))

    if TEST:
        tic = time.time()
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1 = eval(model, data_test, voc_size, 0)
        print ('TEST time: {}'.format(time.time() - tic))
        return ddi_rate, ja, prauc, avg_p, avg_r, avg_f1

    # The optimizer is only needed when training.
    optimizer = Adam(model.parameters(), lr=LR)
    history = defaultdict(list)
    for epoch in range(EPOCH):
        tic = time.time()

        loss_record = []
        model.train()
        for step, patient in enumerate(data_train):
            # A prediction needs at least one earlier visit as history.
            if len(patient) < 2:
                continue

            loss = 0
            for i in range(1, len(patient)):
                # Multi-hot target for visit i, predicted from visits 0..i-1.
                target = np.zeros((1, voc_size[2]))
                target[:, patient[i][2]] = 1

                output_logits = model(patient[:i])
                visit_loss = F.binary_cross_entropy_with_logits(
                    output_logits, torch.FloatTensor(target).to(device))
                loss += visit_loss
                # Log the per-visit loss, not the running sum, so the epoch
                # mean is not inflated for patients with many visits.
                loss_record.append(visit_loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            llprint('\rTrain--Epoch: %d, Step: %d/%d' % (epoch, step, len(data_train)))
        train_time = time.time() - tic

        # Time the evaluation itself; the report is printed afterwards so the
        # "test time" is the real duration rather than ~0.
        tic2 = time.time()
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1 = eval(model, data_eval, voc_size, epoch)
        eval_time = time.time() - tic2

        print ('training time: {}, test time: {}'.format(train_time, eval_time))

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        # Training only; evaluation time is tracked separately.
        history['train_time'].append(train_time)
        history['eval_time'].append(eval_time)
        # Allocated GPU memory in MiB; 0.0 when running on the CPU.
        history['memory'].append(
            torch.cuda.memory_allocated(device) / 1024 / 1024 if device.type == 'cuda' else 0.0)

        # Whole-epoch wall time (training + evaluation) in minutes.
        elapsed_time = (time.time() - tic) / 60
        llprint('\tEpoch: %d, Loss1: %.4f, One Epoch Time: %.2fm, Appro Left Time: %.2fh\n' % (
            epoch,
            np.mean(loss_record),
            elapsed_time,
            elapsed_time * (EPOCH - epoch - 1) / 60))

        # Passing a path lets torch.save open and close the file itself.
        torch.save(model.state_dict(),
                   os.path.join(save_dir, 'Epoch_%d_JA_%.4f_DDI_%.4f.model' % (epoch, ja, ddi_rate)))
        print('')

    with open(os.path.join(save_dir, 'history.pkl'), 'wb') as history_file:
        dill.dump(history, history_file)

    # Keep the weights of the last epoch under a fixed name.
    torch.save(model.state_dict(), os.path.join(save_dir, 'final.model'))


# Run the pipeline when this cell executes (inside a notebook kernel __name__
# is '__main__'). Whatever main() returns is discarded here.
if __name__ == '__main__':
    main()

parameters 285489

Eval--Epoch: 0, Step: 1057/1058	DDI Rate: 0.0800, Jaccard: 0.4857,  PRAUC: 0.7530, AVG_PRC: 0.5928, AVG_RECALL: 0.7595, AVG_F1: 0.6451
avg med 26.316749585406303
TEST time: 7.729975700378418


In [None]:
import pickle

# Read back the per-epoch metrics stored in saved/Retain/history.pkl (the file
# main() writes at the end of a training run).
history_path = os.path.join('saved', 'Retain', 'history.pkl')
with open(history_path, 'rb') as f:
    fin = pickle.load(f)

In [None]:
import numpy as np

# Summarise the per-epoch history: means for time and memory, mean and
# standard deviation for the DDI rate and the Jaccard score.
mean_train_time = np.average(fin['train_time'])
print('Average Training time is: {}'.format(mean_train_time))

mean_memory = np.average(fin['memory'])
print('Average Memory Usage is: {}'.format(mean_memory))

ddi_rates = fin['ddi_rate']
print('Average DDI is {}, Std DDI is {}'.format(np.average(ddi_rates), np.std(ddi_rates)))

jaccards = fin['ja']
print('Average Jaccard is {}, Std Jaccard is {}'.format(np.average(jaccards), np.std(jaccards)))

Average Training time is: 37.40925647735596
Average Memory Usage is: 4.3642578125
Average DDI is 0.08121607624395909, Std DDI is 0.003033679479938456
Average Jaccard is 0.476266234715875, Std Jaccard is 0.018412171853630543
