# Get Model Parameters

In [13]:
import torch
import torch.nn as nn
from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
import numpy as np
import dill
import time
import argparse
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import os
import torch.nn.functional as F
from collections import defaultdict
import sys
sys.path.append("..")
from models import Retain, Leap, GAMENet, SafeDrugModel
from util import llprint, multi_label_metric, ddi_rate_score, get_n_params, sequence_output_process, buildMPNN

In [14]:
# Accumulator keyed by model name; the model cells below each store the
# value returned by get_n_params(model) under their own key.
parameters = dict()

In [15]:
# Retain: instantiate the model and record the value of get_n_params(model).
# Input artifacts: the visit records and the vocabulary bundle.
data_path = '../data/output/records_final.pkl'
voc_path = '../data/output/voc_final.pkl'

device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Using {device} device")

# NOTE(review): dill.load can execute arbitrary code while unpickling; only
# point these paths at trusted files.
# The `with` blocks close each handle as soon as loading finishes; the
# previous bare open() calls left the files open until garbage collection.
with open(data_path, 'rb') as data_file:
    data = dill.load(data_file)
with open(voc_path, 'rb') as voc_file:
    voc = dill.load(voc_file)
diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

# First two thirds of the records -> train. The remaining third is halved:
# first half -> test, second half (plus any odd leftover record) -> eval.
# The splits are not used by the parameter count below; they are kept so the
# cell defines the same names as before.
split_point = int(len(data) * 2 / 3)
data_train = data[:split_point]
eval_len = int(len(data[split_point:]) / 2)
data_test = data[split_point:split_point + eval_len]
data_eval = data[split_point + eval_len:]

# Sizes of the (diagnosis, procedure, medication) vocabularies.
voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))

model = Retain(voc_size, device=device)

# print('retain parameters:', get_n_params(model))
parameters['retain'] = get_n_params(model)

In [16]:
# Leap: instantiate the model and record the value of get_n_params(model).
# Input artifacts: the visit records and the vocabulary bundle.
data_path = '../data/output/records_final.pkl'
voc_path = '../data/output/voc_final.pkl'

device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Using {device} device")

# NOTE(review): dill.load can execute arbitrary code while unpickling; only
# point these paths at trusted files.
# The `with` blocks close each handle as soon as loading finishes; the
# previous bare open() calls left the files open until garbage collection.
with open(data_path, 'rb') as data_file:
    data = dill.load(data_file)
with open(voc_path, 'rb') as voc_file:
    voc = dill.load(voc_file)
diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

# First two thirds of the records -> train. The remaining third is halved:
# first half -> test, second half (plus any odd leftover record) -> eval.
# The splits are not used by the parameter count below; they are kept so the
# cell defines the same names as before.
split_point = int(len(data) * 2 / 3)
data_train = data[:split_point]
eval_len = int(len(data[split_point:]) / 2)
data_test = data[split_point:split_point + eval_len]
data_eval = data[split_point + eval_len:]

# Sizes of the (diagnosis, procedure, medication) vocabularies.
voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))

model = Leap(voc_size, device=device)

# print('leap parameters:', get_n_params(model))
parameters['leap'] = get_n_params(model)

In [17]:
# GAMENet: instantiate the model and record the value of get_n_params(model).
# Input artifacts: the visit records and the vocabulary bundle ...
data_path = '../data/output/records_final.pkl'
voc_path = '../data/output/voc_final.pkl'

# ... plus the two adjacency pickles handed to the GAMENet constructor.
ehr_adj_path = '../data/output/ehr_adj_final.pkl'
ddi_adj_path = '../data/output/ddi_A_final.pkl'

device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Using {device} device")

# NOTE(review): dill.load can execute arbitrary code while unpickling; only
# point these paths at trusted files.
# The `with` blocks close each handle as soon as loading finishes; the
# previous bare open() calls left the files open until garbage collection.
with open(ehr_adj_path, 'rb') as ehr_adj_file:
    ehr_adj = dill.load(ehr_adj_file)
with open(ddi_adj_path, 'rb') as ddi_adj_file:
    ddi_adj = dill.load(ddi_adj_file)
with open(data_path, 'rb') as data_file:
    data = dill.load(data_file)

with open(voc_path, 'rb') as voc_file:
    voc = dill.load(voc_file)
diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

# Shuffling is currently disabled:
# np.random.seed(2048)
# np.random.shuffle(data)

# First two thirds of the records -> train. The remaining third is halved:
# first half -> test, second half (plus any odd leftover record) -> eval.
# The splits are not used by the parameter count below; they are kept so the
# cell defines the same names as before.
split_point = int(len(data) * 2 / 3)
data_train = data[:split_point]
eval_len = int(len(data[split_point:]) / 2)
data_test = data[split_point:split_point + eval_len]
data_eval = data[split_point + eval_len:]

# Sizes of the (diagnosis, procedure, medication) vocabularies.
voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))
model = GAMENet(voc_size, ehr_adj, ddi_adj, emb_dim=64, device=device, ddi_in_memory=True)

# print('gamenet parameters:', get_n_params(model))
parameters['gamenet'] = get_n_params(model)

In [18]:
# SafeDrugModel: instantiate the model and record the value of get_n_params(model).
# Input artifacts: the visit records and the vocabulary bundle ...
data_path = '../data/output/records_final.pkl'
voc_path = '../data/output/voc_final.pkl'

# ... plus the DDI adjacency, the DDI mask and the molecule pickle.
ddi_adj_path = '../data/output/ddi_A_final.pkl'
ddi_mask_path = '../data/output/ddi_mask_H.pkl'
molecule_path = '../data/output/atc3toSMILES.pkl'

device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Using {device} device")

# NOTE(review): dill.load can execute arbitrary code while unpickling; only
# point these paths at trusted files.
# The `with` blocks close each handle as soon as loading finishes; the
# previous bare open() calls left the files open until garbage collection.
with open(ddi_adj_path, 'rb') as ddi_adj_file:
    ddi_adj = dill.load(ddi_adj_file)
with open(ddi_mask_path, 'rb') as ddi_mask_file:
    ddi_mask_H = dill.load(ddi_mask_file)
with open(data_path, 'rb') as data_file:
    data = dill.load(data_file)
with open(molecule_path, 'rb') as molecule_file:
    molecule = dill.load(molecule_file)

with open(voc_path, 'rb') as voc_file:
    voc = dill.load(voc_file)
diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

# First two thirds of the records -> train. The remaining third is halved:
# first half -> test, second half (plus any odd leftover record) -> eval.
# The splits are not used by the parameter count below; they are kept so the
# cell defines the same names as before.
split_point = int(len(data) * 2 / 3)
data_train = data[:split_point]
eval_len = int(len(data[split_point:]) / 2)
data_test = data[split_point:split_point + eval_len]
data_eval = data[split_point + eval_len:]

# NOTE(review): the literal 2 is presumably a fingerprint radius -- confirm
# against util.buildMPNN before changing it.
MPNNSet, N_fingerprint, average_projection = buildMPNN(molecule, med_voc.idx2word, 2, device)

# Sizes of the (diagnosis, procedure, medication) vocabularies.
voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))

model = SafeDrugModel(voc_size, ddi_adj, ddi_mask_H, MPNNSet, N_fingerprint, average_projection, emb_dim=64, device=device)

# print('safedrug parameters:', get_n_params(model))
parameters['safedrug'] = get_n_params(model)

In [19]:
# Display the collected mapping of model name -> value returned by get_n_params.
parameters

{'retain': 285489, 'leap': 177395, 'gamenet': 444209, 'safedrug': 368777}