<a href="https://colab.research.google.com/github/ricaelum42/Replication-of-SafeDrug/blob/main/code/SafeDrug.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q condacolab
import condacolab
condacolab.install()

✨🍰✨ Everything looks OK!


In [2]:
!curl -L bit.ly/rdkit-colab | tar xz -C /

curl: /usr/local/lib/libcurl.so.4: no version information available (required by curl)
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100   163  100   163    0     0   1405      0 --:--:-- --:--:-- --:--:--  1405
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 29.6M  100 29.6M    0     0  24.2M      0  0:00:01  0:00:01 --:--:-- 26.2M


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
%cd /content/drive/My Drive/Colab Notebooks/SafeDrug_Replication/code

/content/drive/.shortcut-targets-by-id/1l2OESp-U_3tWFs26iaoqwRpKJ2fsZbAb/SafeDrug_Replication/code


In [5]:
!pip install dnc

Collecting dnc
  Using cached dnc-1.1.0-py3-none-any.whl (20 kB)
Collecting numpy
  Using cached numpy-1.21.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.7 MB)
Collecting flann
  Using cached flann-1.6.13-py3-none-any.whl (24 kB)
Collecting torch
  Downloading torch-1.11.0-cp37-cp37m-manylinux1_x86_64.whl (750.6 MB)
[K     |████████████████████████████████| 750.6 MB 10 kB/s 
[?25hCollecting typing-extensions
  Downloading typing_extensions-4.2.0-py3-none-any.whl (24 kB)
Installing collected packages: typing-extensions, numpy, torch, flann, dnc
Successfully installed dnc-1.1.0 flann-1.6.13 numpy-1.21.6 torch-1.11.0 typing-extensions-4.2.0


In [6]:
import dill
import numpy as np
import argparse
from collections import defaultdict
from sklearn.metrics import jaccard_score
from torch.optim import Adam
import os
import torch
import time
from models import SafeDrugModel
from util import llprint, multi_label_metric, ddi_rate_score, get_n_params, buildMPNN
import torch.nn.functional as F

In [20]:
torch.manual_seed(1203)
np.random.seed(2048)
# torch.set_num_threads(30)
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# setting
model_name = 'SafeDrug'
resume_path = 'saved/{}/Epoch_49_TARGET_0.06_JA_0.5111_DDI_0.0619.model'.format(model_name)
# resume_path = 'Epoch_49_TARGET_0.06_JA_0.5099_DDI_0.06283.model'

if not os.path.exists(os.path.join("saved", model_name)):
        os.makedirs(os.path.join("saved", model_name))

In [8]:
# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('--Test', action='store_true', default=False, help="test mode")
parser.add_argument('--model_name', type=str, default=model_name, help="model name")
parser.add_argument('--resume_path', type=str, default=resume_path, help='resume path')
parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')
parser.add_argument('--target_ddi', type=float, default=0.06, help='target ddi')
parser.add_argument('--kp', type=float, default=0.05, help='coefficient of P signal')
parser.add_argument('--dim', type=int, default=64, help='dimension')
parser.add_argument('--cuda', type=int, default=0, help='which cuda')

args = parser.parse_args(args=[])

In [9]:
# evaluate
def eval(model, data_eval, voc_size, epoch):
    model.eval()

    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    med_cnt, visit_cnt = 0, 0

    for step, input in enumerate(data_eval):
        y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
        for adm_idx, adm in enumerate(input):
            target_output, _ = model(input[:adm_idx+1])

            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[adm[2]] = 1
            y_gt.append(y_gt_tmp)

            # prediction prod
            target_output = F.sigmoid(target_output).detach().cpu().numpy()[0]
            y_pred_prob.append(target_output)
            
            # prediction med set
            y_pred_tmp = target_output.copy()
            y_pred_tmp[y_pred_tmp>=0.5] = 1
            y_pred_tmp[y_pred_tmp<0.5] = 0
            y_pred.append(y_pred_tmp)

            # prediction label
            y_pred_label_tmp = np.where(y_pred_tmp == 1)[0]
            y_pred_label.append(sorted(y_pred_label_tmp))
            visit_cnt += 1
            med_cnt += len(y_pred_label_tmp)

        smm_record.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))

        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(smm_record, path='../data/output/ddi_A_final.pkl')

    llprint('\nDDI Rate: {:.4}, Jaccard: {:.4},  PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'.format(
        ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt
    ))

    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt

def main():
    
    # load data
    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'

    ddi_adj_path = '../data/output/ddi_A_final.pkl'
    ddi_mask_path = '../data/output/ddi_mask_H.pkl'
    molecule_path = '../data/output/atc3toSMILES.pkl'
    device = torch.device('cuda:{}'.format(args.cuda))

    ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
    ddi_mask_H = dill.load(open(ddi_mask_path, 'rb'))
    data = dill.load(open(data_path, 'rb'))
    molecule = dill.load(open(molecule_path, 'rb')) 

    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point+eval_len:]

    MPNNSet, N_fingerprint, average_projection = buildMPNN(molecule, med_voc.idx2word, 2, device)
    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))

    model = SafeDrugModel(voc_size, ddi_adj, ddi_mask_H, MPNNSet, N_fingerprint, average_projection, emb_dim=args.dim, device=device)
    # model.load_state_dict(torch.load(open(args.resume_path, 'rb')))

    if args.Test:
      tic = time.time()
      model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
      model.to(device=device)
      ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(model, data_eval, voc_size, 0)
      print ('training time: {}'.format(time.time() - tic))
      return ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med

    model.to(device=device)
    print('parameters', get_n_params(model))
    # exit()
    optimizer = Adam(list(model.parameters()), lr=args.lr)

    # start iterations
    history = defaultdict(list)
    best_epoch, best_ja = 0, 0

    EPOCH = 50
    for epoch in range(EPOCH):
        tic = time.time()
        print ('\nepoch {} --------------------------'.format(epoch + 1))
        
        model.train()
        for step, input in enumerate(data_train):

            loss = 0
            for idx, adm in enumerate(input):

                seq_input = input[:idx+1]
                loss_bce_target = np.zeros((1, voc_size[2]))
                loss_bce_target[:, adm[2]] = 1

                loss_multi_target = np.full((1, voc_size[2]), -1)
                for idx, item in enumerate(adm[2]):
                    loss_multi_target[0][idx] = item

                result, loss_ddi = model(seq_input)

                loss_bce = F.binary_cross_entropy_with_logits(result, torch.FloatTensor(loss_bce_target).to(device))
                loss_multi = F.multilabel_margin_loss(F.sigmoid(result), torch.LongTensor(loss_multi_target).to(device))

                result = F.sigmoid(result).detach().cpu().numpy()[0]
                result[result >= 0.5] = 1
                result[result < 0.5] = 0
                y_label = np.where(result == 1)[0]
                current_ddi_rate = ddi_rate_score([[y_label]], path='../data/output/ddi_A_final.pkl')
                
                if current_ddi_rate <= args.target_ddi:
                    loss = 0.95 * loss_bce + 0.05 * loss_multi
                else:
                    beta = min(0, 1 + (args.target_ddi - current_ddi_rate) / args.kp)
                    loss = beta * (0.95 * loss_bce + 0.05 * loss_multi) + (1 - beta) * loss_ddi

                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

            llprint('\rtraining step: {} / {}'.format(step, len(data_train)))

        print ()
        tic2 = time.time() 
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(model, data_eval, voc_size, epoch)
        print ('training time: {}, test time: {}'.format(time.time() - tic, time.time() - tic2))

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['med'].append(avg_med)
        history['train_time'].append(time.time() - tic)
        history['memory'].append(torch.cuda.memory_allocated(0)/1024/1024)

        if epoch >= 5:
            print ('ddi: {}, Med: {}, Ja: {}, F1: {}, PRAUC: {}'.format(
                np.mean(history['ddi_rate'][-5:]),
                np.mean(history['med'][-5:]),
                np.mean(history['ja'][-5:]),
                np.mean(history['avg_f1'][-5:]),
                np.mean(history['prauc'][-5:])
                ))

        torch.save(model.state_dict(), open(os.path.join('saved', args.model_name, \
            'Epoch_{}_TARGET_{:.2}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, args.target_ddi, ja, ddi_rate)), 'wb'))

        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

        print ('best_epoch: {}'.format(best_epoch))

    dill.dump(history, open(os.path.join('saved', args.model_name, 'history_{}.pkl'.format(args.model_name)), 'wb'))


In [10]:
if __name__ == '__main__':
    main()

parameters 366122

epoch 1 --------------------------
training step: 4232 / 4233
test step: 1058 / 1059
DDI Rate: 0.06063, Jaccard: 0.4428,  PRAUC: 0.7101, AVG_PRC: 0.5768, AVG_RECALL: 0.679, AVG_F1: 0.6047, AVG_MED: 23.14
training time: 240.47699213027954, test time: 15.89719557762146
best_epoch: 0

epoch 2 --------------------------
training step: 4232 / 4233
test step: 1058 / 1059
DDI Rate: 0.05598, Jaccard: 0.4535,  PRAUC: 0.7175, AVG_PRC: 0.6512, AVG_RECALL: 0.6122, AVG_F1: 0.6165, AVG_MED: 18.35
training time: 239.02172327041626, test time: 15.472921371459961
best_epoch: 1

epoch 3 --------------------------
training step: 4232 / 4233
test step: 1058 / 1059
DDI Rate: 0.06429, Jaccard: 0.462,  PRAUC: 0.7252, AVG_PRC: 0.6656, AVG_RECALL: 0.6141, AVG_F1: 0.6244, AVG_MED: 18.14
training time: 234.1456458568573, test time: 15.208494424819946
best_epoch: 2

epoch 4 --------------------------
training step: 4232 / 4233
test step: 1058 / 1059
DDI Rate: 0.06498, Jaccard: 0.4644,  PRAUC: 0

In [11]:
import pickle
with open(os.path.join('saved', args.model_name, 'history_{}.pkl'.format(args.model_name)), 'rb') as f:
  fin = pickle.load(f)

In [12]:
import numpy as np
print('Average Training time is: {}'.format(np.average(fin['train_time'])))
print('Average Memory Usage is: {}'.format(np.average(fin['memory'])))
print('Average DDI is {}, Std DDI is {}'.format(np.average(fin['ddi_rate']), np.std(fin['ddi_rate'])))
print('Average Jaccard is {}, Std Jaccard is {}'.format(np.average(fin['ja']), np.std(fin['ja'])))

Average Training time is: 239.16050338745117
Average Memory Usage is: 688.23740234375
Average DDI is 0.06285029051592592, Std DDI is 0.0021224032912721138
Average Jaccard is 0.49252456207397977, Std Jaccard is 0.0165789511246957


In [13]:
!nvidia-smi

Sun May  8 02:55:52 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   67C    P0    40W / 250W |   2187MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [22]:
#Testing settings
resume_path = 'saved/{}/Epoch_49_TARGET_0.06_JA_0.5111_DDI_0.0619.model'.format(model_name)
parser = argparse.ArgumentParser()
parser.add_argument('--Test', action='store_true', default=True, help="test mode")
parser.add_argument('--model_name', type=str, default=model_name, help="model name")
parser.add_argument('--resume_path', type=str, default=resume_path, help='resume path')
parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')
parser.add_argument('--target_ddi', type=float, default=0.06, help='target ddi')
parser.add_argument('--kp', type=float, default=0.05, help='coefficient of P signal')
parser.add_argument('--dim', type=int, default=64, help='dimension')
parser.add_argument('--cuda', type=int, default=0, help='which cuda')

args = parser.parse_args(args=[])

In [23]:
if __name__ == '__main__':
    main()

test step: 1058 / 1059
DDI Rate: 0.0619, Jaccard: 0.5111,  PRAUC: 0.7678, AVG_PRC: 0.6941, AVG_RECALL: 0.6714, AVG_F1: 0.6681, AVG_MED: 19.32
training time: 16.017808198928833
