<a href="https://colab.research.google.com/github/russpv/SafeDrug/blob/main/SAFEDRUG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

!pip install memory_profiler
!pip install rdkit-pypi

Sun May  8 19:17:08 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P0    27W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Args

In [2]:
import argparse
def arg_parser():
    """ Parse command line arguments

    Outputs:
        arguments {object} -- object containing command line arguments
    """

    # Initializer
    parser = argparse.ArgumentParser()

    # Add arguments here
    parser.add_argument('--Test', action='store_true', default=False, help="test mode")
    parser.add_argument('--model_name', type=str, default='none', help="model name")
    parser.add_argument('--resume_path', type=str, default='none', help='resume path')
    parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
    parser.add_argument('--target_ddi', type=float, default=0.06, help='target ddi')
    parser.add_argument('--kp', type=float, default=0.05, help='coefficient of P signal')
    parser.add_argument('--dim', type=int, default=64, help='dimension')
    parser.add_argument('--cuda', type=int, default=0, help='which cuda') ###

    parser.add_argument('--smalldata', type=int, default=1, help='which cuda') ###
    parser.add_argument('--mydata', type=int, default=1, help='which cuda') ###
    parser.add_argument('--Inf_time', type=int, default=0, help='which cuda') ###
 
    # Parse and return arguments
    return(parser.parse_args(args=[]))

args = arg_parser()

In [24]:
import os
import dill
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pandas as pd
import sys
import time
import statistics
import datetime as dt
import logging

# set seed
seed = 1203 #1203
random.seed(seed)
np.random.seed(seed) #2048
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# define data path
DATA_PATH = "drive/MyDrive/DL4H/Project/PaperCode/processed_orig/"
MYDATA_PATH = "drive/MyDrive/DL4H/Project/SAFEDRUG_lib/data/processed/"
WORKING_PATH = "drive/MyDrive/DL4H/Project/SAFEDRUG/"
TEST_PATH = "drive/MyDrive/DL4H/Project/SAFEDRUG/results/"

# define dataset
args.mydata = 0
args.smalldata = 0
EPOCH = 50

# define routine
args.Test = False
args.Inf_time = True

# setting
args.model_name = 'SafeDrug_Repl_orig'
args.resume_path = WORKING_PATH + 'saved/' + 'SafeDrug_Repl_origEpoch_25_TARGET_0.06_JA_0.5323_DDI_0.06914_2022-05-08 20:20:05.975347.model'
# 'SafeDrug_ReplEpoch_39_TARGET_0.06_JA_0.2375_DDI_0.3543_2022-05-08 18:38:56.320473.model'
logger = logging.getLogger('')
logger.setLevel(logging.WARNING)

# Data

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
# Data switch
if args.mydata == 1:
    data_path = MYDATA_PATH + 'ehr.pkl'
    voc_path = MYDATA_PATH + 'vocabs.pkl'

    ehr_adj_path = MYDATA_PATH + 'ehradj.pkl'
    ddi_adj_path = MYDATA_PATH + 'ddiadj.pkl'
    ddi_mask_path = MYDATA_PATH + 'hmask.pkl'
    molecule_path = MYDATA_PATH + 'atc2SMILES.pkl'

    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_vocab'].index2word, voc['pro_vocab'].index2word, voc['med_vocab'].index2word

else:
    data_path = DATA_PATH + 'records_final.pkl'
    voc_path = DATA_PATH + 'voc_final.pkl'


    ehr_adj_path = DATA_PATH + 'ehr_adj_final.pkl'
    ddi_adj_path = DATA_PATH + 'ddi_A_final.pkl'
    ddi_mask_path = DATA_PATH + 'ddi_mask_H.pkl'
    molecule_path = DATA_PATH + 'atc3toSMILES.pkl'
    
    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'].idx2word, voc['pro_voc'].idx2word, voc['med_voc'].idx2word

ehr_adj = dill.load(open(ehr_adj_path, 'rb'))
ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
ddi_mask_H = dill.load(open(ddi_mask_path, 'rb'))
data = dill.load(open(data_path, 'rb'))
molecule = dill.load(open(molecule_path, 'rb')) 

if args.smalldata == 1:
    data_train = data[:200] 
    data_test = data[200:250]
    data_eval = data[250:300]
else:
    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point+eval_len:]

# Utils

In [6]:
from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
from sklearn.model_selection import train_test_split
import warnings

from collections import Counter
from rdkit import Chem
from collections import defaultdict
warnings.filterwarnings('ignore')

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

# use the same metric from DMNC
def llprint(message):
    sys.stdout.write(message)
    sys.stdout.flush()

def transform_split(X, Y):
    x_train, x_eval, y_train, y_eval = train_test_split(X, Y, train_size=2/3, random_state=1203)
    x_eval, x_test, y_eval, y_test = train_test_split(x_eval, y_eval, test_size=0.5, random_state=1203)
    return x_train, x_eval, x_test, y_train, y_eval, y_test

def sequence_output_process(output_logits, filter_token):
    pind = np.argsort(output_logits, axis=-1)[:, ::-1]

    out_list = []
    break_flag = False
    for i in range(len(pind)):
        if break_flag:
            break
        for j in range(pind.shape[1]):
            label = pind[i][j]
            if label in filter_token:
                break_flag = True
                break
            if label not in out_list:
                out_list.append(label)
                break
    y_pred_prob_tmp = []
    for idx, item in enumerate(out_list):
        y_pred_prob_tmp.append(output_logits[idx, item])
    sorted_predict = [x for _, x in sorted(zip(y_pred_prob_tmp, out_list), reverse=True)]
    return out_list, sorted_predict


def sequence_metric(y_gt, y_pred, y_prob, y_label):
    def average_prc(y_gt, y_label):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b]==1)[0]
            out_list = y_label[b]
            inter = set(out_list) & set(target)
            prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list)
            score.append(prc_score)
        return score


    def average_recall(y_gt, y_label):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = y_label[b]
            inter = set(out_list) & set(target)
            recall_score = 0 if len(target) == 0 else len(inter) / len(target)
            score.append(recall_score)
        return score


    def average_f1(average_prc, average_recall):
        score = []
        for idx in range(len(average_prc)):
            if (average_prc[idx] + average_recall[idx]) == 0:
                score.append(0)
            else:
                score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx]))
        return score


    def jaccard(y_gt, y_label):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = y_label[b]
            inter = set(out_list) & set(target)
            union = set(out_list) | set(target)
            jaccard_score = 0 if union == 0 else len(inter) / len(union)
            score.append(jaccard_score)
        return np.mean(score)

    def f1(y_gt, y_pred):
        all_micro = []
        for b in range(y_gt.shape[0]):
            all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro'))
        return np.mean(all_micro)

    def roc_auc(y_gt, y_pred_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(roc_auc_score(y_gt[b], y_pred_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_at_k(y_gt, y_prob_label, k):
        precision = 0
        for i in range(len(y_gt)):
            TP = 0
            for j in y_prob_label[i][:k]:
                if y_gt[i, j] == 1:
                    TP += 1
            precision += TP / k
        return precision / len(y_gt)
    try:
        auc = roc_auc(y_gt, y_prob)
    except ValueError:
        auc = 0
    p_1 = precision_at_k(y_gt, y_label, k=1)
    p_3 = precision_at_k(y_gt, y_label, k=3)
    p_5 = precision_at_k(y_gt, y_label, k=5)
    f1 = f1(y_gt, y_pred)
    prauc = precision_auc(y_gt, y_prob)
    ja = jaccard(y_gt, y_label)
    avg_prc = average_prc(y_gt, y_label)
    avg_recall = average_recall(y_gt, y_label)
    avg_f1 = average_f1(avg_prc, avg_recall)

    return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1)


def multi_label_metric(y_gt, y_pred, y_prob):

    def jaccard(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            union = set(out_list) | set(target)
            jaccard_score = 0 if union == 0 else len(inter) / len(union)
            score.append(jaccard_score)
        return np.mean(score)

    def average_prc(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list)
            score.append(prc_score)
        return score

    def average_recall(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            recall_score = 0 if len(target) == 0 else len(inter) / len(target)
            score.append(recall_score)
        return score

    def average_f1(average_prc, average_recall):
        score = []
        for idx in range(len(average_prc)):
            if average_prc[idx] + average_recall[idx] == 0:
                score.append(0)
            else:
                score.append(2*average_prc[idx]*average_recall[idx] / (average_prc[idx] + average_recall[idx]))
        return score

    def f1(y_gt, y_pred):
        all_micro = []
        for b in range(y_gt.shape[0]):
            all_micro.append(f1_score(y_gt[b], y_pred[b], average='macro'))
        return np.mean(all_micro)

    def roc_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(roc_auc_score(y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(average_precision_score(y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    def precision_at_k(y_gt, y_prob, k=3):
        precision = 0
        sort_index = np.argsort(y_prob, axis=-1)[:, ::-1][:, :k]
        for i in range(len(y_gt)):
            TP = 0
            for j in range(len(sort_index[i])):
                if y_gt[i, sort_index[i, j]] == 1:
                    TP += 1
            precision += TP / len(sort_index[i])
        return precision / len(y_gt)

    # roc_auc
    try:
        auc = roc_auc(y_gt, y_prob)
    except:
        auc = 0
    # precision
    p_1 = precision_at_k(y_gt, y_prob, k=1)
    p_3 = precision_at_k(y_gt, y_prob, k=3)
    p_5 = precision_at_k(y_gt, y_prob, k=5)
    # macro f1
    f1 = f1(y_gt, y_pred)
    # precision
    prauc = precision_auc(y_gt, y_prob)
    # jaccard
    ja = jaccard(y_gt, y_pred)
    # pre, recall, f1
    avg_prc = average_prc(y_gt, y_pred)
    avg_recall = average_recall(y_gt, y_pred)
    avg_f1 = average_f1(avg_prc, avg_recall)

    return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1)

def ddi_rate_score(record, path=ddi_adj_path): ###
    # ddi rate
    ddi_A = dill.load(open(path, 'rb'))
    all_cnt = 0
    dd_cnt = 0
    for patient in record:
        for adm in patient:
            med_code_set = adm
            for i, med_i in enumerate(med_code_set):
                for j, med_j in enumerate(med_code_set):
                    if j <= i:
                        continue
                    all_cnt += 1
                    if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                        dd_cnt += 1
    if all_cnt == 0:
        return 0
    return dd_cnt / all_cnt

## Process molecules into graphs

In [7]:
!pip install rdkit-pypi



In [8]:
### Util function that makes a drug graph
# Input: dict of ATC:SMILES strings (up to 3); 
# Output: adj matrix
# Leverage the BRICS package to parse SMILES

from rdkit import Chem
import dill, random, torch
from collections import defaultdict

DATA_PATH = "drive/MyDrive/DL4H/Project/SAFEDRUG_lib/data/processed/"
SMILES_path = DATA_PATH + 'atc2SMILES.pkl'
SMILES_file = dill.load(open(SMILES_path, 'rb')) 

#random.choice(list(SMILES_file.items()))
# https://www.rdkit.org/docs/GettingStartedInPython.html
# https://towardsdatascience.com/expressive-power-of-graph-neural-networks-and-the-weisefeiler-lehman-test-b883db3c7c49
# https://www.youtube.com/watch?v=zCEYiCxrL_0&ab_channel=MicrosoftResearch

def get_elements_by_index(mol, atomtyp_dict):
    # Input: rdkit object
    # Returns a list indexed on atom symbol (+ 'aromatic') for the molecule, sorry periodic table
    symbols = [a.GetSymbol() for a in mol.GetAtoms()]
    aromatics = [a.GetIdx() for a in mol.GetAromaticAtoms()]
    symbols_with_aromatics = [(sym, 'aromatic') if i in aromatics else sym for i, sym in enumerate(symbols)]

    # debugging: {symbol: node#s}
    atom_dict = defaultdict(lambda: []) 
    for i, symbol in enumerate(symbols_with_aromatics):
        atom_dict[symbol].append(i)

    # update index of symbols
    counter = len(atomtyp_dict)
    for sym in symbols_with_aromatics:
        if sym not in atomtyp_dict:
            atomtyp_dict[sym] = counter
            counter = len(atomtyp_dict)

    # return list of nodes indexed on symbol list(atom#)
    atom_list = [atomtyp_dict[a] for a in symbols_with_aromatics]

    return atom_list, atomtyp_dict, atom_dict


def get_bonds_by_index(mol, bondtyp_dict):
    # Returns a dict of EDGETYPES specific to atom node# and bond type
    # Definition of edgetype: {node#: (node#, bondtyp#)}
    bonds = [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), str(bond.GetBondType())) for bond in mol.GetBonds()]

    # update index of bond types (single, double, etc.)
    counter = len(bondtyp_dict)
    for (_, _, b) in bonds:
        if b not in bondtyp_dict:
            bondtyp_dict[b] = counter
            counter = len(bondtyp_dict)

    # make index of nodes with nodes connected to it {i: (j, single#)}
    bond_dict = {}
    for i, j, btyp in bonds:
        if i not in bond_dict:
            bond_dict[i] = [(j, bondtyp_dict[btyp])]
        else:
            bond_dict[i].append((j, bondtyp_dict[btyp]))
        if j not in bond_dict:
            bond_dict[j] = [(i, bondtyp_dict[btyp])]
        else:
            bond_dict[j].append((i, bondtyp_dict[btyp]))

    #print(f'total bonds: {len(bonds)}')

    # list(node#, node#, bondtyp#)
    bond_list = [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), bondtyp_dict[str(bond.GetBondType())]) for bond in mol.GetBonds()]

    return bond_list, bondtyp_dict, bond_dict

In [9]:
# have index (numbering) of nodes, bonds based on node id
# Get the fingerprints for nodes, defined as the symbols and bond types from a node
# Also, update index of edges hashed to nodeID by indexing edge types so nodeIdx: (neighborIdx, edgetype)

def get_fingerprints(bonds, atoms, fptyp_dict, radius, fpbond_dict):
    # Returns a list of FINGERPRINT IDS or subgraphs
    # Definition of fingerprint: (atomtyp#, ((atomtyp#, bondtyp#), (atomtyp#, bondtyp#), ...))
    # Input bonds = {node#: (node#, bondtyp#)}
    # Input atoms = list(atomtyp#)
    fp_list = []
    fp_indexed = []
    fp_dict_counter = len(fptyp_dict)
    fpbond_dict_counter = len(fpbond_dict)

    if radius == 0 or len(atoms) == 1:
        for a in atoms:
            # build/update index of fingerprint types (atomtyp#, atomtyp#, bondtyp#)
            if a not in fptyp_dict:
                fptyp_dict[a] = fp_dict_counter
                fp_dict_counter = len(fptyp_dict)

        fp_indexed = [fptyp_dict[a] for a in atoms]

    else:
        subgroups = atoms
        subgroup_bonds = bonds

        for _ in range(radius):
            new_subgroups = []
            new_bonds = defaultdict(list)

            for i, bond in subgroup_bonds.items():
                neighborhood = [(subgroups[j], btyp) for (j, btyp) in bond]
                fingerprint_tuple = tuple([subgroups[i], tuple(sorted(neighborhood))])
                
                # build/update index of fingerprint types (subgrouptyp#, (subgrouptyp#, bondtyp#))
                if fingerprint_tuple not in fptyp_dict:
                    fptyp_dict[fingerprint_tuple] = fp_dict_counter
                    fp_dict_counter = len(fptyp_dict)
                
                fp_list.append(fingerprint_tuple)
                fp_indexed.append(fptyp_dict[fingerprint_tuple])

                # build nodes to include fingerprints found for radius x+1
                new_subgroups.append(fptyp_dict[fingerprint_tuple])

                ''' redefine bonds between radius x and x+1 constituents '''
                for j, bondtyp in bond:
                    pairs = tuple(sorted((subgroups[i], subgroups[j])))

                    # build/update dict of bonds between subgroups
                    if (pairs, bondtyp) not in fpbond_dict:
                        fpbond_dict[(pairs, bondtyp)] = fpbond_dict_counter
                        fpbond_dict_counter = len(fpbond_dict)

                    bond_ = fpbond_dict[(pairs, bondtyp)]
                    new_bonds[i].append((j, bond_))
            
            subgroups = new_subgroups
            subgroup_bonds = new_bonds

    return fp_list, fptyp_dict, fp_indexed, subgroups

In [10]:
def make_MPNN_data(med_voc, SMILES_file, radius, device):
    data = []
    
    fingerprinttyp_dict = {}
    fingerprintbond_dict = {}

    atomtyp_dict = {}
    bondtyp_dict = {}
    
    smiles_count_index = []
    for index, atc in med_voc.items():
        counter = 0
        for smiles_code in SMILES_file[atc]:
            try: # some molecules cannot be processed; bond indexing shifted > list index mismatch
                mol = Chem.AddHs(Chem.MolFromSmiles(smiles_code))
                atoms, atomtyp_dict, atom_dict = get_elements_by_index(mol, atomtyp_dict)
                bonds, bondtyp_dict, bond_dict = get_bonds_by_index(mol, bondtyp_dict)
                atom_count = len(atoms)
                _, fingerprinttyp_dict, fingerprints_hist, fingerprints = get_fingerprints(bond_dict, atoms, fingerprinttyp_dict, radius, fingerprintbond_dict)
                
                adj_matrix = Chem.GetAdjacencyMatrix(mol)
                if adj_matrix.shape[0] != len(fingerprints):
                    #print(f'mismatch shape: {atc} {adj_matrix.shape[0]} {len(fingerprints)}')
                    for _ in range(adj_matrix.shape[0] - len(fingerprints)):
                        fingerprints.append(1) # pad with '1'
                    #print(f'fixed shape: {atc} {adj_matrix.shape[0]} {len(fingerprints)}')
                fingerprints = torch.tensor(fingerprints, dtype=torch.long)
                adj_matrix = torch.tensor(adj_matrix, dtype=torch.float32)
                data.append((fingerprints, adj_matrix, atom_count))
                counter += 1
            except:
                continue

                    
        smiles_count_index.append(counter)

    # build grid matrix for drugs x smiles (for convolution?)
    total_cols = sum(smiles_count_index) # moleculecount
    total_rows = len(smiles_count_index) # drugcount: paper does len(smiles_count_index) instead of med_voc
    current_col = 0
    
    grid_matrix = [[0 for _ in range(total_cols)] for _ in range(total_rows)]
    for i, smiles_count in enumerate(smiles_count_index):
        for x in range(smiles_count):
            grid_matrix[i][current_col+x] = 1 / smiles_count # normalized
        current_col += smiles_count
    
    grid_matrix = torch.tensor(np.array(grid_matrix)).to(device)

    # fingerprints of each molecule are appended with adj_matrix to the data object  
    return data, len(fingerprinttyp_dict), bonds, bondtyp_dict, atomtyp_dict, fingerprinttyp_dict, fingerprints, grid_matrix

In [11]:
''' TESTING BLOCK
if __name__ == '__main__':
    
    # MPNN Unit tests for small data set
    #med_voc1 = {40: 'N06A', 39: 'N05C'}
    med_voc1 = med_voc.index2word
    molecule_test = SMILES_file
    radius1 = 2
    dataset1, fp_count1, bonds1, btyp1, atyp1, fptyp1, tempdata1, grid1 = make_MPNN_data(med_voc1, molecule_test, radius1, device='cuda')
    
    
    import unittest
    class myTest(unittest.TestCase):
        def __init__(self, expected, actual):
            super().__init__()
            self.expected = expected
            self.actual = actual
        def test(self):         
            for i, _ in enumerate(self.expected):
                self.assertEqual(type(self.expected), type(self.actual), "Should be same type")  # check they are the same type
                self.assertEqual(len(self.expected), len(self.actual), "Should be same length")  # check they are the same length 
                torch.testing.assert_close(self.expected[i][0], self.actual[i][0])  # check fingerprints
                torch.testing.assert_close(self.expected[i][1], self.actual[i][1])  # check adjacency matrix 
                assert self.expected[i][2] == self.actual[i][2], "Fingerprint count should be the same"
    
    # COMPARE WITH PAPER
    #datapaper, fpcountpaper, gridpaper = buildMPNN(molecule_test, med_voc1, radius1)
    #test1 = myTest(dataset, datapaper)
    #test1.test()
'''

' TESTING BLOCK\nif __name__ == \'__main__\':\n    \n    # MPNN Unit tests for small data set\n    #med_voc1 = {40: \'N06A\', 39: \'N05C\'}\n    med_voc1 = med_voc.index2word\n    molecule_test = SMILES_file\n    radius1 = 2\n    dataset1, fp_count1, bonds1, btyp1, atyp1, fptyp1, tempdata1, grid1 = make_MPNN_data(med_voc1, molecule_test, radius1, device=\'cuda\')\n    \n    \n    import unittest\n    class myTest(unittest.TestCase):\n        def __init__(self, expected, actual):\n            super().__init__()\n            self.expected = expected\n            self.actual = actual\n        def test(self):         \n            for i, _ in enumerate(self.expected):\n                self.assertEqual(type(self.expected), type(self.actual), "Should be same type")  # check they are the same type\n                self.assertEqual(len(self.expected), len(self.actual), "Should be same length")  # check they are the same length \n                torch.testing.assert_close(self.expected[i][0], s

### PAPER CODE

In [12]:

# PAPER CODE
def create_atoms(mol, atom_dict):
    """Transform the atom types in a molecule (e.g., H, C, and O)
    into the indices (e.g., H=0, C=1, and O=2).
    Note that each atom index considers the aromaticity.
    """
    atoms = [a.GetSymbol() for a in mol.GetAtoms()]
    for a in mol.GetAromaticAtoms():
        i = a.GetIdx()
        atoms[i] = (atoms[i], 'aromatic')
    atoms = [atom_dict[a] for a in atoms]
    return np.array(atoms)

def create_ijbonddict(mol, bond_dict):
    """Create a dictionary, in which each key is a node ID
    and each value is the tuples of its neighboring node
    and chemical bond (e.g., single and double) IDs.
    """
    i_jbond_dict = defaultdict(lambda: [])
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bond = bond_dict[str(b.GetBondType())]
        i_jbond_dict[i].append((j, bond))
        i_jbond_dict[j].append((i, bond))
    return i_jbond_dict

def extract_fingerprints(radius, atoms, i_jbond_dict,
                         fingerprint_dict, edge_dict):
    """Extract the fingerprints from a molecular graph
    based on Weisfeiler-Lehman algorithm.
    """

    if (len(atoms) == 1) or (radius == 0):
        nodes = [fingerprint_dict[a] for a in atoms]

    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict

        for _ in range(radius):

            """Update each node ID considering its neighboring nodes and edges.
            The updated node IDs are the fingerprint IDs.
            """
            nodes_ = []
            for i, j_edge in i_jedge_dict.items():
                neighbors = [(nodes[j], edge) for j, edge in j_edge]
                fingerprint = (nodes[i], tuple(sorted(neighbors)))
                nodes_.append(fingerprint_dict[fingerprint])

            """Also update each edge ID considering
            its two nodes on both sides.
            """
            i_jedge_dict_ = defaultdict(lambda: [])
            for i, j_edge in i_jedge_dict.items():
                for j, edge in j_edge:
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    edge = edge_dict[(both_side, edge)]
                    i_jedge_dict_[i].append((j, edge))

            nodes = nodes_
            i_jedge_dict = i_jedge_dict_

    return np.array(nodes)

def buildMPNN(molecule, med_voc, radius=1, device="cpu:0"):

    atom_dict = defaultdict(lambda: len(atom_dict))
    bond_dict = defaultdict(lambda: len(bond_dict))
    fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))
    edge_dict = defaultdict(lambda: len(edge_dict))
    MPNNSet, average_index = [], []

    for index, atc3 in med_voc.items():

        smilesList = list(molecule[atc3])
        """Create each data with the above defined functions."""
        counter = 0 # counter how many drugs are under that ATC-3
        for smiles in smilesList:
            try:
                mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
                atoms = create_atoms(mol, atom_dict)
                molecular_size = len(atoms)
                i_jbond_dict = create_ijbonddict(mol, bond_dict)
                fingerprints = extract_fingerprints(radius, atoms, i_jbond_dict,
                                                    fingerprint_dict, edge_dict)
                adjacency = Chem.GetAdjacencyMatrix(mol)
                # if fingerprints.shape[0] == adjacency.shape[0]:
                for _ in range(adjacency.shape[0] - fingerprints.shape[0]):
                    #print(f'mismatch: {atc3} {adjacency.shape[0]} {fingerprints.shape[0]} ')
                    fingerprints = np.append(fingerprints, 1)
                
                fingerprints = torch.LongTensor(fingerprints).to(device)
                adjacency = torch.FloatTensor(adjacency).to(device)
                MPNNSet.append((fingerprints, adjacency, molecular_size))
                counter += 1
            except:
                continue
        
        average_index.append(counter)

        """Transform the above each data of numpy
        to pytorch tensor on a device (i.e., CPU or GPU).
        """

    N_fingerprint = len(fingerprint_dict)
    # transform into projection matrix
    n_col = sum(average_index)
    n_row = len(average_index)

    average_projection = np.zeros((n_row, n_col))
    col_counter = 0
    for i, item in enumerate(average_index):
        if item > 0:
            average_projection[i, col_counter : col_counter + item] = 1 / item
        col_counter += item

    return MPNNSet, N_fingerprint, torch.FloatTensor(average_projection)



# Model

In [13]:
class MPNN(nn.Module):
    ''' Create the drug memory embeddings '''
    # z_next = message(fpi_l, fpj_l, W_l)
    # fpi_next = update(fpj_l, z_next)
    # concat fingerprints, add embedding dimension to them, activate feedforward, sumproduct (mm)
    def __init__(self, fingerprint_count, embed_dim, L_layers, device):
        super(MPNN, self).__init__()
        self.device = device
        self.hidden_layers = L_layers

        self.embedding = nn.Embedding(num_embeddings = fingerprint_count, embedding_dim = embed_dim).to(self.device)
        self.linears = nn.ModuleList([nn.Linear(in_features = embed_dim, out_features = embed_dim).to(self.device)
                                     for _ in range(L_layers)])

    def pad(self, matrices, pad_value= 0.):
        ''' Make one big adjacency matrix '''
        shapes = [m.shape for m in matrices]
        M, N = sum([m[0] for m in shapes]), sum([m[1] for m in shapes])
        assert M == N, "something wrong with adjacencies"

        bigmat = torch.zeros(M, M).to(self.device) + pad_value
        col_marker = 0
        row_marker = 0
        for i, matrix in enumerate(matrices):
            m, n = matrix.shape[0], matrix.shape[1]
            bigmat[row_marker:row_marker+m, col_marker:col_marker+n] = matrix
            col_marker += n
            row_marker += m
        return bigmat

    def message(self, layer, input):
        # feedforward across embedding dim (fingerprints x embed)
        return torch.relu(self.linears[layer](input)) # Eq7: MESSAGE, the W(l) part
        
    def update(self, layer, matrix, input):
        # sumproduct operation across connected nodes
        return torch.mm(matrix, input) + input # Eq8: UPDATE

    def readoutprep(self, input, splits):
        # sum of vectors per molecule
        result = [torch.sum(vector, dim=0) for vector in torch.split(input, splits)]
        return torch.stack(result)

    def forward(self, MPNN_data, grid):
        # Input: Zipped dataset fingerprint lists, adjacency matrices, fingerprint lengths
        # Operation: do ops on each set of nodes and adjacency matrix
        # Output: convolved vectors

        fp, adj, fp_lengths = MPNN_data
        # combine and process at once
        fp = torch.cat(fp).to(self.device)
        adj = self.pad(adj)

        fp = self.embedding(fp)
        for i in range(self.hidden_layers):
            x = self.message(i, fp)
            x = self.update(i, adj, x)

        # apply 1/n operation across connected nodes in one direction only
        x = self.readoutprep(x, fp_lengths)
        x = torch.mm(grid.float().to(self.device), x) # Eq9: READOUT average pooling

        return x

In [14]:
class PatientQuery(nn.Module):
    def __init__(self, emb_dim, vocab_size_diag, vocab_size_proc, device):
        super(PatientQuery, self).__init__()
        """ Output: patient representation, concat of diag/proc codes
            RNNs process the embedding dimension
        """
        self.device = device

        self.embeddings_diag = nn.Embedding(vocab_size_diag, emb_dim)
        self.embeddings_proc = nn.Embedding(vocab_size_proc, emb_dim)

        self.dropout = nn.Dropout(p=0.5)
        self.encoder_diag = nn.GRU(input_size=emb_dim, hidden_size=emb_dim, batch_first=True)
        self.encoder_proc = nn.GRU(input_size=emb_dim, hidden_size=emb_dim, batch_first=True)

        self.query = nn.Sequential(
                nn.ReLU(),
                nn.Linear(2 * emb_dim, emb_dim)
        )

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embeddings_diag.weight.data.uniform_(-initrange, initrange)
        self.embeddings_proc.weight.data.uniform_(-initrange, initrange)

    def forward(self, codes_diag, codes_proc):

        diag_seq = []
        proc_seq = []

        # make sequences of visit embeddings
        for visit in codes_diag:
            diag = self.embeddings_diag(torch.LongTensor(visit).unsqueeze(dim=0).to(self.device))
            diag = self.dropout(diag)
            diag = torch.sum(diag, dim=1).unsqueeze(dim=0) # (1, 1, dim)
            diag_seq.append(diag)

        for visit in codes_proc:
            proc = self.embeddings_diag(torch.LongTensor(visit).unsqueeze(dim=0).to(self.device))
            proc = self.dropout(proc)
            proc = torch.sum(proc, dim=1).unsqueeze(dim=0) # (1, 1, dim)
            proc_seq.append(proc)

        diag_seq = torch.cat(diag_seq, dim=1) #(1, seq, dim)
        proc_seq = torch.cat(proc_seq, dim=1) #(1, seq, dim)

        emb_diag, _ = self.encoder_diag(diag_seq) #(batch, seq, dim)
        emb_proc, _ = self.encoder_proc(proc_seq)

        emb_cat = torch.cat((emb_diag, emb_proc), dim=-1).squeeze(dim=0) # (seq, dim*2)
        patient_representations = self.query(emb_cat)[-1:, :] # (seq, dim) "-1:" preserves shape

        return patient_representations

In [15]:
# PAPER CODE: See if this hand-roll controls loss runaway
from torch.nn.parameter import Parameter

class MaskLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(MaskLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, mask):
        weight = torch.mul(self.weight, mask)
        output = torch.mm(input, weight)

        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

In [16]:
class BipartiteEncoder(nn.Module):
    def __init__(self, emb_dim, mask_size, drug_size, bias=False):
        super(BipartiteEncoder, self).__init__()
        """ 
            Bipartite Learning
            Input: mask_size = subgroups, drug_size = drugs
            Output: 
            m_f functionality vector, dim transformation: feed-forward layer NN3, emb_dim -> S subgroups
                sigmoid
            m_l local drug representation, feed-forward layer NN4, m_f * H mask, S -> M drugs
        """
        # adjust shape of patient query to ddi table - subgroups S
        self.transformation1 = nn.Linear(in_features=emb_dim, out_features=mask_size)
        # apply ddi mask
        #self.mask = HMaskApplicator()
        # adjust shape of result to ddi table - drugs M
        self.transformation2 = nn.Linear(in_features=mask_size, out_features=drug_size, bias=bias)
        # PAPER CODE
        self.bipartite_output = MaskLinear(in_features=mask_size, out_features=drug_size, bias=False)

    def forward(self, input, mask):
        x = self.transformation1(input)
        ''' POINT OF DIFF, no sigmoid in paper code'''
        #x = torch.sigmoid(x) #Eq12: NN3, mf, sigmoid not in the paper code!
        ''' POINT OF DIFF, handroll with param-reset vs in-place'''
        self.transformation2.weight.data.mul_(mask) # Eq13: NN4 w/o bias, apply ddi mask H to weights
        x = self.transformation2(x)
        # PAPER CODE
        # x = self.bipartite_output(x, mask.t())
       
        return x

In [17]:
class SafeDrug(nn.Module): 
    def __init__(self, vocab_size_diag, vocab_size_proc, H_mask, ddi_adj, MPNN_data, fingerprint_count, averaging_grid, emb_dim=256, L_layers=2, device='cuda'):
        super(SafeDrug, self).__init__()
        self.device = device
        self.H_mask = H_mask

        # embeddings > RNN encoders > patient query
        self.patient_rep = PatientQuery(emb_dim, vocab_size_diag, vocab_size_proc, device)

        # bipartite encoder
        subgroup_size = H_mask.shape[1]
        drug_size = H_mask.shape[0]
        self.bipartite = BipartiteEncoder(emb_dim, subgroup_size, drug_size, False)

        # MPNN, do message passing first
        MPNN_data = list(zip(*MPNN_data))
        self.MPNN = MPNN(fingerprint_count, emb_dim, L_layers, device).forward(MPNN_data, averaging_grid)
        self.MPNN_match = nn.Sigmoid() #Eq10 patient-to-drug matching, m_r
        self.MPNN_output = nn.Linear(drug_size, drug_size) # NN2
        self.MPNN_layernorm = nn.LayerNorm(drug_size) #Eq11: LN, m_g

        # setup
        self.ddi_adj = torch.FloatTensor(ddi_adj).to(device)
        self.H_mask = torch.FloatTensor(H_mask).to(device)

    def forward(self, x):
        input_diag = [visit[0] for visit in x]
        input_proc = [visit[1] for visit in x]
        
        # patient representation
        query = self.patient_rep(input_diag, input_proc) # (seq, dim)
        # dual molecular graph encoders
        emb_bipartite = self.bipartite(query, self.H_mask)

        matches = self.MPNN_match(torch.mm(query, self.MPNN.t())) # Eq10: m_r relevant drugs, (seq, dim) x (dim, drug)
        scores = self.MPNN_output(matches)
        attention = self.MPNN_layernorm(matches + scores) #Eq11: NN2

        # medication representation (1,M or voc_size)
        result = torch.mul(emb_bipartite, attention) #Eq14: (elementwise mult) but do sigmoid in training

        if self.training:
            ''' DDI Loss '''
            neg_pred_prob = torch.sigmoid(result)
            neg_pred_prob = torch.mul(neg_pred_prob.t(), neg_pred_prob)  # (voc_size, voc_size)
            loss_ddi = 0.0005 * neg_pred_prob.mul(self.ddi_adj).sum() # Eq 17: L_ddi, the coefficient here is not shown

            return result, loss_ddi
        else:
            return result

# Train & Inference

## Eval

In [18]:
def eval(model, data_eval, voc_size, epoch):
    model.eval()

    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    med_cnt, visit_cnt = 0, 0

    for step, input in enumerate(data_eval):
        y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
        for adm_idx, adm in enumerate(input):
            target_output = model(input[:adm_idx+1])

            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[adm[2]] = 1
            y_gt.append(y_gt_tmp)

            # prediction prod
            target_output = torch.sigmoid(target_output).detach().cpu().numpy()[0]
            y_pred_prob.append(target_output)
            
            # prediction med set
            y_pred_tmp = target_output.copy()
            y_pred_tmp[y_pred_tmp>=0.5] = 1
            y_pred_tmp[y_pred_tmp<0.5] = 0
            y_pred.append(y_pred_tmp)

            # prediction label
            y_pred_label_tmp = np.where(y_pred_tmp == 1)[0]
            y_pred_label.append(sorted(y_pred_label_tmp))
            visit_cnt += 1
            med_cnt += len(y_pred_label_tmp)

        smm_record.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))

        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rtest step: {} / {}'.format(step+1, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(smm_record, path=ddi_adj_path) ###

    llprint('\nDDI Rate: {:.4}, Jaccard: {:.4},  PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'.format(
        ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt
    ))

    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt

## Train

In [19]:
from torch.optim import Adam

def main():
    device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')

    if args.mydata == 1:
        MPNNSet, N_fingerprint, _, _, _, _, _, average_projection = make_MPNN_data(med_voc, molecule, 2, device)
        voc_size = (len(diag_voc), len(pro_voc), len(med_voc))
    else:
        MPNNSet, N_fingerprint, average_projection = buildMPNN(molecule, med_voc, 2, device)
        voc_size = (len(diag_voc), len(pro_voc), len(med_voc))

    model = SafeDrug(voc_size[0], voc_size[1], ddi_mask_H, ddi_adj, MPNNSet, N_fingerprint, average_projection, emb_dim=args.dim, L_layers=2, device=device)
    # model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
    
    if args.Inf_time:
        #https://towardsdatascience.com/the-correct-way-to-measure-inference-time-of-deep-neural-networks-304a54e5187f
        model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
        model.to(device=device)
        tic = time.time()

        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        repetitions = len(data_test)
        timings = np.zeros((repetitions,1))
        dummy_input = [[[13, 98, 585, 1065, 21, 37, 454, 278], [69, 47], [4, 22, 12, 2, 67, 0, 86]],\
                       [[377, 326, 21, 46, 454], [115, 94], [3, 6, 12, 14, 5, 22, 2, 29, 1, 16, 11, 86]],\
                       [[377, 246, 453, 46, 21, 454], [151, 127, 128], [14, 2, 6, 29, 18, 0, 86]], [[963, 258, 32, 93, 94, 13, 103, 571, 21], [164, 423, 424, 425, 95, 426, 361, 48, 46, 2], [5, 4, 6, 7, 9, 11, 12, 3, 13, 16, 14, 22, 1, 2, 29, 44, 45, 48, 56, 20, 76, 86]]]

        #GPU-WARM-UP
        for _ in range(10):
            _ = model(dummy_input)
        count = 0

        # MEASURE PERFORMANCE
        with torch.no_grad():
            #for rep in range(repetitions):
            for rep, example in enumerate(data_test):
                starter.record()
                _ = model(example)
                ender.record()
                # WAIT FOR GPU SYNC
                torch.cuda.synchronize()
                curr_time = starter.elapsed_time(ender)
                timings[rep] = curr_time
                count += 1

        mean_syn = np.sum(timings) / repetitions
        std_syn = np.std(timings)
        print(f'Inference reps {count}, average: {mean_syn} \u00B1 {std_syn} seconds')

        data = np.array([mean_syn, std_syn, count])
        df = pd.DataFrame(data, index=['mean inference time', 'stdev', 'reps'])
        df.to_csv(TEST_PATH + 'Inf_' + args.model_name + device.type + f'{dt.datetime.now()}' + '.csv' )

        return

    if args.Test:
        model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
        model.to(device=device)
        tic = time.time()

        ddi_list, ja_list, prauc_list, f1_list, med_list = [], [], [], [], []

        result = []
        for _ in range(10):
            time_start = time.time()
            test_sample = np.random.choice(data_test, round(len(data_test) * 0.8), replace=True)
            ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(model, test_sample, voc_size, 0)
            time_sample = time.time() - time_start ###
            result.append([ddi_rate, ja, avg_f1, prauc, avg_med, time_sample])
            
        result = np.array(result)
        mean = result.mean(axis=0)
        std = result.std(axis=0)

        outstring = ""
        for m, s in zip(mean, std):
            outstring += "{:.4f} "u"\u00B1"" {:.4f} & ".format(m, s) ###

        print(outstring)
        time_round = time.time() - tic
        print(f'test time: {time_round}')
        
        elapsed_time = [0. for _ in range(5)]
        elapsed_time.append(time_round)
        data = np.array([mean, std, elapsed_time])

        df = pd.DataFrame(data, columns=['ddi', 'ja', 'prauc', 'f1', 'med', 'time'], index=['mean', 'std', 'seconds'])
        df.to_csv(TEST_PATH + 'Test_' + args.model_name + device.type + f'{dt.datetime.now()}' + '.csv' )

        return 

    if 'cpu' not in device.type:
        torch.cuda.reset_peak_memory_stats() # flush 
    model.to(device=device)

    print('parameters', sum(p.numel() for p in model.parameters() if p.requires_grad)) ###
    # exit()
    optimizer = Adam(list(model.parameters()), lr=args.lr)

    # start iterations
    history = defaultdict(list)
    best_epoch, best_ja = 0, 0

    times_train, times_eval = [], [] ###
    for epoch in range(EPOCH):
        time_start = time.time() ###
        print ('\nepoch {} --------------------------'.format(epoch + 1))
        
        model.train()
        beta_log = []
        for step, patient in enumerate(data_train): # PATIENT, visit, (diag, proc, med), codes
           
            loss = 0
            for idx, visit in enumerate(patient): # patient, VISIT, (diag, proc, med), codes
                
                seq_patient = patient[:idx+1] #  sequential expansion wise

                loss_bce_target = np.zeros((1, voc_size[2]))
                loss_bce_target[:, visit[2]] = 1
                loss_multi_target = np.full((1, voc_size[2]), -1)
                for idx, item in enumerate(visit[2]): # patient, visit, (diag, prod, MED), codes
                    loss_multi_target[0][idx] = item

                result, loss_ddi = model(seq_patient)

                loss_bce = F.binary_cross_entropy_with_logits(result, torch.FloatTensor(loss_bce_target).to(device))
                loss_multi = F.multilabel_margin_loss(torch.sigmoid(result), torch.LongTensor(loss_multi_target).to(device))

                result = torch.sigmoid(result).detach().cpu().numpy()[0] # Apply final sigmoid()
                result[result >= 0.5] = 1
                result[result < 0.5] = 0
                y_label = np.where(result == 1)[0]
                current_ddi_rate = ddi_rate_score([[y_label]], path=ddi_adj_path) # External DDI knowledge
                
                if current_ddi_rate <= args.target_ddi:
                    loss = 0.95 * loss_bce + 0.05 * loss_multi
                    beta = 1
                else:
                    beta = max(0, 1 - (current_ddi_rate - args.target_ddi) / args.kp) # per published paper
                    loss = beta * (0.95 * loss_bce + 0.05 * loss_multi) + (1 - beta) * loss_ddi # alpha = 0.95
                beta_log.append(beta)
                
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

            llprint(f'\rtraining step: {step+1} / {len(data_train)} loss: {loss} loss_ddi: {loss_ddi}  beta: {beta}') ###
        
        print()
        print(f'\navg_beta: {statistics.mean(beta_log)}') ###
        time_end = time.time()  ###
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(model, data_eval, voc_size, epoch)
        time_train = time_end - time_start ###
        time_eval = time.time() - time_end  ###
        print(f'training time: {time_train}, test time: {time_eval}') ###

        times_train.append(time_train) ###
        times_eval.append(time_eval) ###

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['med'].append(avg_med)

        if epoch >= 5:
            print('ddi: {}, Med: {}, Ja: {}, F1: {}, PRAUC: {}'.format(
                np.mean(history['ddi_rate'][-5:]),
                np.mean(history['med'][-5:]),
                np.mean(history['ja'][-5:]),
                np.mean(history['avg_f1'][-5:]),
                np.mean(history['prauc'][-5:])
                ))

        torch.save(model.state_dict(), open(WORKING_PATH +''.join(('saved/', args.model_name, \
            'Epoch_{}_TARGET_{:.2}_JA_{:.4}_DDI_{:.4}_{}.model'.format(epoch, args.target_ddi, ja, ddi_rate, dt.datetime.now()))), 'wb')) ###

        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

        print('best_epoch: {}'.format(best_epoch))

    dill.dump(history, open(WORKING_PATH +'history_{}_{}.pkl'.format(args.model_name, dt.datetime.now()), 'wb')) ###
    
    timings = np.array(list(zip(times_train, times_eval))) ###
    df = pd.DataFrame(timings, columns=['train', 'test']) ###
    df.to_csv(TEST_PATH + 'TimesTrain_' + args.model_name + f'{dt.datetime.now()}' + '.csv' ) ###

    # Maximum cuda memory allocated
    if 'cpu' not in device.type:
        print(f'peak training memory allocated: {torch.cuda.max_memory_allocated(device)}')


# Execute

In [25]:
if __name__ == '__main__':
    %reload_ext memory_profiler
    %memit -r1 main()

Inference reps 1058, average: 1.197962977597754 ± 0.22361619573222505 seconds
peak memory: 4506.52 MiB, increment: 0.36 MiB


In [21]:
!nvidia-smi

Sun May  8 21:18:14 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   68C    P0    31W /  70W |   2728MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces