In [1]:
import json

file_dir = "./graphs/condition/CCSCM"

file_id2ent = f"{file_dir}/id2ent.json"
file_ent2id = f"{file_dir}/ent2id.json"
file_id2rel = f"{file_dir}/id2rel.json"
file_rel2id = f"{file_dir}/rel2id.json"

with open(file_id2ent, 'r') as file:
    cond_id2ent = json.load(file)
with open(file_ent2id, 'r') as file:
    cond_ent2id = json.load(file)
with open(file_id2rel, 'r') as file:
    cond_id2rel = json.load(file)
with open(file_rel2id, 'r') as file:
    cond_rel2id = json.load(file)


import csv

condition_mapping_file = "./resources/CCSCM.csv"
procedure_mapping_file = "./resources/CCSPROC.csv"
drug_file = "./resources/ATC.csv"

condition_dict = {}
with open(condition_mapping_file, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        condition_dict[row['code']] = row['name'].lower()

procedure_dict = {}
with open(procedure_mapping_file, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        procedure_dict[row['code']] = row['name'].lower()

drug_dict = {}
with open(drug_file, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        if row['level'] == '5.0':
            drug_dict[row['code']] = row['name'].lower()

In [2]:
# from pyhealth.datasets import MIMIC3Dataset
# from GraphCare.task_fn import drug_recommendation_fn

# mimic3_ds = MIMIC3Dataset(
#     root="../../../data/physionet.org/files/mimiciii/1.4/", 
#     tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],      
#     code_mapping={
#         "NDC": ("ATC", {"target_kwargs": {"level": 3}}),
#         "ICD9CM": "CCSCM",
#         "ICD9PROC": "CCSPROC"
#         },
# )

# sample_dataset = mimic3_ds.set_task(drug_recommendation_fn)

In [3]:
def flatten(lst):
    result = []
    for item in lst:
        if isinstance(item, list):
            result.extend(flatten(item))
        else:
            result.append(item)
    return result

In [4]:
from pyhealth.tokenizer import Tokenizer
import numpy as np
from tqdm import tqdm
import torch

def multihot(label, num_labels):
    multihot = np.zeros(num_labels)
    for l in label:
        multihot[l] = 1
    return multihot

def prepare_label(drugs):
    label_tokenizer = Tokenizer(
        sample_dataset.get_all_tokens(key='drugs')
    )

    labels_index = label_tokenizer.convert_tokens_to_indices(drugs)
    # print(labels_index)
    # convert to multihot
    num_labels = label_tokenizer.get_vocabulary_size()
    # print(num_labels)
    labels = multihot(labels_index, num_labels)
    return labels


# for patient in tqdm(sample_dataset):
#     # patient['drugs_all'] = flatten(patient['drugs'])
#     # print(patient['drugs_all'])
#     patient['drugs_ind'] = torch.tensor(prepare_label(patient['drugs']))

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# import pickle

# with open('./exp_data/ccscm_ccsproc/sample_dataset.pkl', 'wb') as f:
#     pickle.dump(sample_dataset, f)

In [6]:
import pickle

with open('./exp_data/ccscm_ccsproc/sample_dataset.pkl', 'rb') as f:
    sample_dataset= pickle.load(f)

In [7]:
from pyhealth.datasets import split_by_patient

train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1], seed=528)

In [8]:
train_dataset[0]['conditions'], train_dataset[0]['procedures']

([['101', '106', '98', '138']], [['44', '47', '50']])

In [9]:
from tqdm import tqdm
import numpy as np
import networkx as nx
import pickle

with open('./graphs/cond_proc/CCSCM_CCSPROC/ent2id.json', 'r') as file:
    ent2id = json.load(file)
with open('./graphs/cond_proc/CCSCM_CCSPROC/rel2id.json', 'r') as file:
    rel2id = json.load(file)
with open('./graphs/cond_proc/CCSCM_CCSPROC/entity_embedding.pkl', 'rb') as file:
    ent_emb = pickle.load(file)
    

In [10]:
ent_emb[0]

array([-0.02357265,  0.002313  ,  0.02204529, ..., -0.01157682,
        0.01255   ,  0.00188047])

In [11]:
G = nx.Graph()

for i in range(len(ent_emb)):
    G.add_nodes_from([
        (i, {'y': i, 'x': ent_emb[i]})
    ])

triples_all = []
for patient in tqdm(sample_dataset):
    triples = []
    triple_set = set()
    # node_set = set()
    conditions = flatten(patient['conditions'])
    for condition in conditions:
        cond_file = f'./graphs/condition/CCSCM/{condition}.txt'
        with open(cond_file, 'r') as f:
            lines = f.readlines()
        for line in lines:
            items = line.split('\t')
            if len(items) == 3:
                h, r, t = items
                t = t[:-1]
                h = int(ent2id[h])
                r = int(rel2id[r])
                t = int(ent2id[t])
                triple = (h, r, t)
                if triple not in triple_set:
                    triples.append((h, t))
                    triple_set.add(triple)
                    # node_set.add(h)
                    # node_set.add(r)
    procedures = flatten(patient['procedures'])
    for procedure in procedures:
        proc_file = f'./graphs/procedure/CCSPROC/{procedure}.txt'
        with open(proc_file, 'r') as f:
            lines = f.readlines()
        for line in lines:
            items = line.split('\t')
            if len(items) == 3:
                h, r, t = items
                t = t[:-1]
                h = int(ent2id[h])
                r = int(rel2id[r])
                t = int(ent2id[t])
                triple = (h, r, t)
                if triple not in triple_set:
                    triples.append((h, t))
                    triple_set.add(triple)

    G.add_edges_from(
        triples,
        # label=prepare_label(patient['drugs'])
    )
    
    # triples.append(prepare_label(patient['drugs']))
    # triples_all.append(np.array(triples))


100%|██████████| 44399/44399 [02:26<00:00, 303.22it/s]


In [12]:
from torch_geometric.utils import to_networkx, from_networkx
import pickle

G_tg = from_networkx(G)

with open('./exp_data/ccscm_ccsproc/graph_tg.pkl', 'wb') as f:
    pickle.dump(G, f)

  data[key] = torch.tensor(value)


In [13]:
import torch

def get_subgraph(dataset, idx):
    
    subgraph_list = []
    # for patient in tqdm(dataset):
    patient = dataset[idx]
    triple_set = set()
    node_set = set()
    conditions = flatten(patient['conditions'])
    for condition in conditions:
        cond_file = f'./graphs/condition/CCSCM/{condition}.txt'
        with open(cond_file, 'r') as f:
            lines = f.readlines()
        for line in lines:
            items = line.split('\t')
            if len(items) == 3:
                h, r, t = items
                t = t[:-1]
                h = int(ent2id[h])
                r = int(rel2id[r])
                t = int(ent2id[t])
                triple = (h, r, t)
                if triple not in triple_set:
                    triple_set.add(triple)
                    node_set.add(h)
                    node_set.add(r)

    procedures = flatten(patient['procedures'])
    for procedure in procedures:
        proc_file = f'./graphs/procedure/CCSPROC/{procedure}.txt'
        with open(proc_file, 'r') as f:
            lines = f.readlines()
        for line in lines:
            items = line.split('\t')
            if len(items) == 3:
                h, r, t = items
                t = t[:-1]
                h = int(ent2id[h])
                r = int(rel2id[r])
                t = int(ent2id[t])
                triple = (h, r, t)
                if triple not in triple_set:
                    triple_set.add(triple)
                    node_set.add(h)
                    node_set.add(r)

    P = G_tg.subgraph(torch.tensor([*node_set]))
    P.label = patient['drugs_ind']
        # subgraph_list.append(P)

    return P


In [14]:
# train_graph_list = get_subgraph(train_dataset)
# val_graph_list = get_subgraph(val_dataset)
# test_graph_list = get_subgraph(test_dataset)

In [15]:
from torch_geometric.loader import DataListLoader, DataLoader

class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset=dataset
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        return get_subgraph(dataset=self.dataset, idx=idx)

train_set = Dataset(train_dataset)
val_set = Dataset(val_dataset)
test_set = Dataset(test_dataset)

In [16]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GATConv
from torch_geometric.data import DataLoader, Data
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import DataParallel
from torch_geometric.loader import DataListLoader
import pickle

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels*heads, hidden_channels, heads=heads)
        self.conv3 = GATConv(hidden_channels*heads, hidden_channels, heads=1)

        self.fc = Linear(hidden_channels, out_channels)
        
    def forward(self, x, edge_index, batch):
        x = F.elu(self.conv1(x, edge_index))
        # print(x.shape)
        x = F.elu(self.conv2(x, edge_index))
        # print(x.shape)
        x = F.elu(self.conv3(x, edge_index))
        # print(x.shape)
        x = global_mean_pool(x, batch)
        # print(x.shape)
        x = F.dropout(x, p=0.5, training=self.training)
        # print(x.shape)
        logits = self.fc(x)
        # print(logits.shape)
        return logits

In [17]:
from torch_geometric.nn import GINConv

class GIN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GIN, self).__init__()
        self.conv1 = GINConv(Linear(in_channels, hidden_channels))
        self.conv2 = GINConv(Linear(hidden_channels, hidden_channels))
        self.conv3 = GINConv(Linear(hidden_channels, hidden_channels))

        self.fc = Linear(hidden_channels, out_channels)
        
    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        logits = self.fc(x)
        return logits

In [18]:
from torch_geometric.nn import HGTConv

class HGT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=2):
        super(HGT, self).__init__()
        self.conv1 = HGTConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = HGTConv(hidden_channels*heads, hidden_channels, heads=heads)
        self.conv3 = HGTConv(hidden_channels*heads, hidden_channels, heads=1)

        self.fc = Linear(hidden_channels, out_channels)
        
    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        logits = self.fc(x)
        return logits

In [19]:
from tqdm import tqdm

def train(model, device, train_loader, optimizer):
    model.train()
    training_loss = 0
    pbar = tqdm(train_loader)
    for data in pbar:
        pbar.set_description(f'loss: {training_loss}')
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        try:
            label = data.label.reshape(int(train_loader.batch_size), int(len(data.label)/train_loader.batch_size))
        except:
            continue
        loss = F.binary_cross_entropy_with_logits(out, label)
        loss.backward()
        training_loss = loss
        optimizer.step()
    
    return training_loss 

In [20]:
# for data in train_loader:
#     print(train_loader.batch_size)
#     print(len(data.label))
#     print(data.label.reshape(int(train_loader.batch_size), int(len(data.label)/train_loader.batch_size)).shape)
#     break

In [21]:
from pyhealth.metrics import multilabel_metrics_fn

def evaluate(model, device, loader):
    model.eval()
    y_prob_all = []
    y_true_all = []

    for data in tqdm(loader):
        data = data.to(device)
        with torch.no_grad():
            logits = model(data.x, data.edge_index, data.batch)
            y_prob = torch.sigmoid(logits)
            try:
                y_true = data.label.reshape(int(loader.batch_size), int(len(data.label)/loader.batch_size))
            except:
                continue
            y_prob_all.append(y_prob.cpu())
            y_true_all.append(y_true.cpu())
            
    y_true_all = np.concatenate(y_true_all, axis=0)
    y_prob_all = np.concatenate(y_prob_all, axis=0)
    # pr_auc = multilabel_metrics_fn(y_true=y_true_all, y_prob=y_true_all, metrics="pr_auc_macro")

    return y_true_all, y_prob_all

In [22]:
from sklearn.metrics import average_precision_score

def train_loop(train_loader, val_loader, model, optimizer, device, epochs):
    for epoch in range(1, epochs+1):
        loss = train(model, device, train_loader, optimizer)
        # y_true_all, y_prob_all = evaluate(model, device, train_loader)
        # train_pr_auc = average_precision_score(y_true_all, y_prob_all, average="macro")
        y_true_all, y_prob_all = evaluate(model, device, val_loader)
        val_pr_auc = average_precision_score(y_true_all, y_prob_all, average="samples")
        print(f'Epoch: {epoch}, Training loss: {loss}, Val PRAUC: {val_pr_auc:.4f}')


In [23]:
train_set[0].x.shape[1], len(train_set[0].label)

(1536, 197)

In [24]:
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=8, shuffle=False)
test_loader = DataLoader(test_set, batch_size=8, shuffle=False)



In [25]:
in_channels = train_set[0].x.shape[1]
out_channels = len(train_set[0].label)

device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')
# model = GAT(in_channels=in_channels, out_channels=out_channels, hidden_channels=512, heads=2).to(device)
model = GIN(in_channels=in_channels, out_channels=out_channels, hidden_channels=512).to(device)
# model = HGT(in_channels=in_channels, out_channels=out_channels, hidden_channels=512, heads=2).to(device)

model.double()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_loop(train_loader=train_loader, val_loader=val_loader, model=model, optimizer=optimizer, device=device, epochs=100)

loss: 0.20845470383603468: 100%|██████████| 2224/2224 [08:51<00:00,  4.19it/s]
100%|██████████| 544/544 [01:05<00:00,  8.25it/s]


Epoch: 1, Training loss: 0.20845470383603468, Val PRAUC: 0.7174


loss: 0.23972805611222395: 100%|██████████| 2224/2224 [09:05<00:00,  4.08it/s]
100%|██████████| 544/544 [01:04<00:00,  8.40it/s]


Epoch: 2, Training loss: 0.23972805611222395, Val PRAUC: 0.7288


loss: 0.22434450699606728: 100%|██████████| 2224/2224 [08:41<00:00,  4.26it/s]
100%|██████████| 544/544 [01:05<00:00,  8.30it/s]


Epoch: 3, Training loss: 0.22434450699606728, Val PRAUC: 0.7291


loss: 0.2418892674426145: 100%|██████████| 2224/2224 [08:37<00:00,  4.30it/s] 
100%|██████████| 544/544 [00:58<00:00,  9.22it/s]


Epoch: 4, Training loss: 0.2418892674426145, Val PRAUC: 0.7354


loss: 0.22962764692098747: 100%|██████████| 2224/2224 [09:08<00:00,  4.05it/s]
100%|██████████| 544/544 [01:00<00:00,  9.03it/s]


Epoch: 5, Training loss: 0.22962764692098747, Val PRAUC: 0.7387


loss: 0.24029206685211477: 100%|██████████| 2224/2224 [08:22<00:00,  4.42it/s]
100%|██████████| 544/544 [00:53<00:00, 10.24it/s]


Epoch: 6, Training loss: 0.24029206685211477, Val PRAUC: 0.7409


loss: 0.2112569540783083: 100%|██████████| 2224/2224 [08:26<00:00,  4.39it/s] 
100%|██████████| 544/544 [00:54<00:00, 10.05it/s]


Epoch: 7, Training loss: 0.2112569540783083, Val PRAUC: 0.7395


loss: 0.22318523439524726: 100%|██████████| 2224/2224 [08:13<00:00,  4.51it/s]
100%|██████████| 544/544 [00:54<00:00, 10.02it/s]


Epoch: 8, Training loss: 0.22318523439524726, Val PRAUC: 0.7473


loss: 0.20956010375467465: 100%|██████████| 2224/2224 [08:22<00:00,  4.42it/s]
100%|██████████| 544/544 [00:55<00:00,  9.88it/s]


Epoch: 9, Training loss: 0.20956010375467465, Val PRAUC: 0.7501


loss: 0.21258448169505043: 100%|██████████| 2224/2224 [08:13<00:00,  4.50it/s]
100%|██████████| 544/544 [00:54<00:00,  9.95it/s]


Epoch: 10, Training loss: 0.21258448169505043, Val PRAUC: 0.7522


loss: 0.20101141359798774: 100%|██████████| 2224/2224 [08:26<00:00,  4.39it/s]
100%|██████████| 544/544 [00:56<00:00,  9.69it/s]


Epoch: 11, Training loss: 0.20101141359798774, Val PRAUC: 0.7545


loss: 0.17546434132352914: 100%|██████████| 2224/2224 [08:22<00:00,  4.42it/s]
100%|██████████| 544/544 [00:55<00:00,  9.80it/s]


Epoch: 12, Training loss: 0.17546434132352914, Val PRAUC: 0.7536


loss: 0.21171690192051368: 100%|██████████| 2224/2224 [08:23<00:00,  4.42it/s]
100%|██████████| 544/544 [00:52<00:00, 10.27it/s]


Epoch: 13, Training loss: 0.21171690192051368, Val PRAUC: 0.7524


loss: 0.1776039270844646: 100%|██████████| 2224/2224 [08:25<00:00,  4.40it/s] 
100%|██████████| 544/544 [00:52<00:00, 10.28it/s]


Epoch: 14, Training loss: 0.1776039270844646, Val PRAUC: 0.7559


loss: 0.26245010133760927: 100%|██████████| 2224/2224 [08:24<00:00,  4.41it/s]
100%|██████████| 544/544 [00:56<00:00,  9.66it/s]


Epoch: 15, Training loss: 0.26245010133760927, Val PRAUC: 0.7546


loss: 0.22165916829809776: 100%|██████████| 2224/2224 [08:26<00:00,  4.39it/s]
100%|██████████| 544/544 [00:56<00:00,  9.65it/s]


Epoch: 16, Training loss: 0.22165916829809776, Val PRAUC: 0.7584


loss: 0.19589370207235915: 100%|██████████| 2224/2224 [08:45<00:00,  4.23it/s]
100%|██████████| 544/544 [00:58<00:00,  9.31it/s]


Epoch: 17, Training loss: 0.19589370207235915, Val PRAUC: 0.7595


loss: 0.17310211766259195: 100%|██████████| 2224/2224 [08:28<00:00,  4.37it/s]
100%|██████████| 544/544 [00:56<00:00,  9.70it/s]


Epoch: 18, Training loss: 0.17310211766259195, Val PRAUC: 0.7579


loss: 0.19706800228609456: 100%|██████████| 2224/2224 [08:28<00:00,  4.37it/s]
100%|██████████| 544/544 [00:55<00:00,  9.83it/s]


Epoch: 19, Training loss: 0.19706800228609456, Val PRAUC: 0.7610


loss: 0.23533733806053414: 100%|██████████| 2224/2224 [08:20<00:00,  4.44it/s]
100%|██████████| 544/544 [00:55<00:00,  9.82it/s]


Epoch: 20, Training loss: 0.23533733806053414, Val PRAUC: 0.7604


loss: 0.208458969008806: 100%|██████████| 2224/2224 [08:42<00:00,  4.25it/s]  
100%|██████████| 544/544 [00:55<00:00,  9.78it/s]


Epoch: 21, Training loss: 0.208458969008806, Val PRAUC: 0.7613


loss: 0.2005399138125495: 100%|██████████| 2224/2224 [08:37<00:00,  4.29it/s] 
100%|██████████| 544/544 [00:56<00:00,  9.67it/s]


Epoch: 22, Training loss: 0.2005399138125495, Val PRAUC: 0.7633


loss: 0.16708124847935812: 100%|██████████| 2224/2224 [08:44<00:00,  4.24it/s]
100%|██████████| 544/544 [00:57<00:00,  9.54it/s]


Epoch: 23, Training loss: 0.16708124847935812, Val PRAUC: 0.7635


loss: 0.19868043015748374: 100%|██████████| 2224/2224 [08:41<00:00,  4.26it/s]
100%|██████████| 544/544 [01:06<00:00,  8.15it/s]


Epoch: 24, Training loss: 0.19868043015748374, Val PRAUC: 0.7626


loss: 0.17750884154332322: 100%|██████████| 2224/2224 [08:42<00:00,  4.26it/s]
100%|██████████| 544/544 [00:55<00:00,  9.88it/s]


Epoch: 25, Training loss: 0.17750884154332322, Val PRAUC: 0.7647


loss: 0.1761976968581584: 100%|██████████| 2224/2224 [08:35<00:00,  4.31it/s] 
100%|██████████| 544/544 [00:58<00:00,  9.38it/s]


Epoch: 26, Training loss: 0.1761976968581584, Val PRAUC: 0.7650


loss: 0.18582236073673208: 100%|██████████| 2224/2224 [08:40<00:00,  4.28it/s]
100%|██████████| 544/544 [00:58<00:00,  9.34it/s]


Epoch: 27, Training loss: 0.18582236073673208, Val PRAUC: 0.7656


loss: 0.2638587587362063: 100%|██████████| 2224/2224 [08:51<00:00,  4.18it/s] 
100%|██████████| 544/544 [00:58<00:00,  9.28it/s]


Epoch: 28, Training loss: 0.2638587587362063, Val PRAUC: 0.7649


loss: 0.20087744499222818: 100%|██████████| 2224/2224 [08:48<00:00,  4.21it/s]
100%|██████████| 544/544 [00:58<00:00,  9.28it/s]


Epoch: 29, Training loss: 0.20087744499222818, Val PRAUC: 0.7642


loss: 0.1780377905530934: 100%|██████████| 2224/2224 [08:53<00:00,  4.17it/s] 
100%|██████████| 544/544 [00:58<00:00,  9.31it/s]


Epoch: 30, Training loss: 0.1780377905530934, Val PRAUC: 0.7662


loss: 0.22515538179057615: 100%|██████████| 2224/2224 [08:46<00:00,  4.22it/s]
100%|██████████| 544/544 [00:57<00:00,  9.47it/s]


Epoch: 31, Training loss: 0.22515538179057615, Val PRAUC: 0.7667


loss: 0.1847679121280134: 100%|██████████| 2224/2224 [08:45<00:00,  4.23it/s] 
100%|██████████| 544/544 [00:59<00:00,  9.18it/s]


Epoch: 32, Training loss: 0.1847679121280134, Val PRAUC: 0.7670


loss: 0.2540184548991319: 100%|██████████| 2224/2224 [08:48<00:00,  4.21it/s] 
100%|██████████| 544/544 [00:58<00:00,  9.33it/s]


Epoch: 33, Training loss: 0.2540184548991319, Val PRAUC: 0.7670


loss: 0.20348826640186155: 100%|██████████| 2224/2224 [08:52<00:00,  4.18it/s]
100%|██████████| 544/544 [01:00<00:00,  8.99it/s]


Epoch: 34, Training loss: 0.20348826640186155, Val PRAUC: 0.7657


loss: 0.20162477042439256: 100%|██████████| 2224/2224 [08:44<00:00,  4.24it/s]
100%|██████████| 544/544 [00:58<00:00,  9.34it/s]


Epoch: 35, Training loss: 0.20162477042439256, Val PRAUC: 0.7679


loss: 0.21079879979609947: 100%|██████████| 2224/2224 [08:39<00:00,  4.28it/s]
100%|██████████| 544/544 [01:04<00:00,  8.49it/s]


Epoch: 36, Training loss: 0.21079879979609947, Val PRAUC: 0.7692


loss: 0.226747016133661: 100%|██████████| 2224/2224 [08:36<00:00,  4.30it/s]  
100%|██████████| 544/544 [00:55<00:00,  9.84it/s]


Epoch: 37, Training loss: 0.226747016133661, Val PRAUC: 0.7680


loss: 0.2285863094093403: 100%|██████████| 2224/2224 [08:33<00:00,  4.33it/s] 
100%|██████████| 544/544 [00:45<00:00, 11.88it/s]


Epoch: 38, Training loss: 0.2285863094093403, Val PRAUC: 0.7685


loss: 0.18187432528792763: 100%|██████████| 2224/2224 [07:23<00:00,  5.02it/s]
100%|██████████| 544/544 [00:45<00:00, 11.89it/s]


Epoch: 39, Training loss: 0.18187432528792763, Val PRAUC: 0.7699


loss: 0.1910438167349075: 100%|██████████| 2224/2224 [07:25<00:00,  4.99it/s] 
100%|██████████| 544/544 [00:47<00:00, 11.49it/s]


Epoch: 40, Training loss: 0.1910438167349075, Val PRAUC: 0.7689


loss: 0.19246263441065442: 100%|██████████| 2224/2224 [07:24<00:00,  5.00it/s]
100%|██████████| 544/544 [00:45<00:00, 12.04it/s]


Epoch: 41, Training loss: 0.19246263441065442, Val PRAUC: 0.7674


loss: 0.21248835976265035: 100%|██████████| 2224/2224 [07:18<00:00,  5.08it/s]
100%|██████████| 544/544 [00:46<00:00, 11.58it/s]


Epoch: 42, Training loss: 0.21248835976265035, Val PRAUC: 0.7686


loss: 0.19745928648934732: 100%|██████████| 2224/2224 [08:01<00:00,  4.62it/s]
100%|██████████| 544/544 [00:48<00:00, 11.12it/s]


Epoch: 43, Training loss: 0.19745928648934732, Val PRAUC: 0.7679


loss: 0.28744406703851383: 100%|██████████| 2224/2224 [07:59<00:00,  4.64it/s]
100%|██████████| 544/544 [00:55<00:00,  9.81it/s]


Epoch: 44, Training loss: 0.28744406703851383, Val PRAUC: 0.7692


loss: 0.27873875669321696: 100%|██████████| 2224/2224 [08:03<00:00,  4.60it/s]
100%|██████████| 544/544 [00:49<00:00, 10.97it/s]


Epoch: 45, Training loss: 0.27873875669321696, Val PRAUC: 0.7693


loss: 0.22060631948070736: 100%|██████████| 2224/2224 [07:23<00:00,  5.02it/s]
100%|██████████| 544/544 [00:46<00:00, 11.81it/s]


Epoch: 46, Training loss: 0.22060631948070736, Val PRAUC: 0.7700


loss: 0.2211875746321175: 100%|██████████| 2224/2224 [07:21<00:00,  5.04it/s] 
100%|██████████| 544/544 [00:45<00:00, 11.99it/s]


Epoch: 47, Training loss: 0.2211875746321175, Val PRAUC: 0.7675


loss: 0.21826267332896776: 100%|██████████| 2224/2224 [07:24<00:00,  5.00it/s]
100%|██████████| 544/544 [00:45<00:00, 12.00it/s]


Epoch: 48, Training loss: 0.21826267332896776, Val PRAUC: 0.7710


loss: 0.19526134620361288: 100%|██████████| 2224/2224 [07:26<00:00,  4.98it/s]
100%|██████████| 544/544 [00:44<00:00, 12.13it/s]


Epoch: 49, Training loss: 0.19526134620361288, Val PRAUC: 0.7709


loss: 0.2026119061625142: 100%|██████████| 2224/2224 [07:26<00:00,  4.99it/s] 
100%|██████████| 544/544 [00:45<00:00, 12.01it/s]


Epoch: 50, Training loss: 0.2026119061625142, Val PRAUC: 0.7690


loss: 0.24128056124161143: 100%|██████████| 2224/2224 [07:25<00:00,  4.99it/s]
100%|██████████| 544/544 [00:45<00:00, 11.86it/s]


Epoch: 51, Training loss: 0.24128056124161143, Val PRAUC: 0.7706


loss: 0.221472125282259: 100%|██████████| 2224/2224 [07:25<00:00,  5.00it/s]  
100%|██████████| 544/544 [00:45<00:00, 12.05it/s]


Epoch: 52, Training loss: 0.221472125282259, Val PRAUC: 0.7704


loss: 0.21651583959616824: 100%|██████████| 2224/2224 [07:28<00:00,  4.96it/s]
100%|██████████| 544/544 [00:45<00:00, 12.05it/s]


Epoch: 53, Training loss: 0.21651583959616824, Val PRAUC: 0.7712


loss: 0.2398981508837271: 100%|██████████| 2224/2224 [07:27<00:00,  4.97it/s] 
100%|██████████| 544/544 [00:46<00:00, 11.77it/s]


Epoch: 54, Training loss: 0.2398981508837271, Val PRAUC: 0.7711


loss: 0.2202709591540564: 100%|██████████| 2224/2224 [07:30<00:00,  4.93it/s] 
100%|██████████| 544/544 [00:46<00:00, 11.59it/s]


Epoch: 55, Training loss: 0.2202709591540564, Val PRAUC: 0.7725


loss: 0.23969705890707457: 100%|██████████| 2224/2224 [07:28<00:00,  4.96it/s]
100%|██████████| 544/544 [00:46<00:00, 11.71it/s]


Epoch: 56, Training loss: 0.23969705890707457, Val PRAUC: 0.7713


loss: 0.2119968150811147: 100%|██████████| 2224/2224 [07:22<00:00,  5.02it/s] 
100%|██████████| 544/544 [00:46<00:00, 11.82it/s]


Epoch: 57, Training loss: 0.2119968150811147, Val PRAUC: 0.7720


loss: 0.20444621267863758: 100%|██████████| 2224/2224 [07:16<00:00,  5.10it/s]
100%|██████████| 544/544 [00:45<00:00, 11.89it/s]


Epoch: 58, Training loss: 0.20444621267863758, Val PRAUC: 0.7716


loss: 0.20279250825933506: 100%|██████████| 2224/2224 [07:15<00:00,  5.10it/s]
100%|██████████| 544/544 [00:45<00:00, 12.04it/s]


Epoch: 59, Training loss: 0.20279250825933506, Val PRAUC: 0.7726


loss: 0.23962509343566418: 100%|██████████| 2224/2224 [08:57<00:00,  4.14it/s]
100%|██████████| 544/544 [01:23<00:00,  6.52it/s]


Epoch: 60, Training loss: 0.23962509343566418, Val PRAUC: 0.7723


loss: 0.20205192508226952: 100%|██████████| 2224/2224 [10:32<00:00,  3.52it/s]
100%|██████████| 544/544 [00:47<00:00, 11.51it/s]


Epoch: 61, Training loss: 0.20205192508226952, Val PRAUC: 0.7726


loss: 0.19791778946605693: 100%|██████████| 2224/2224 [07:29<00:00,  4.95it/s]
100%|██████████| 544/544 [00:46<00:00, 11.63it/s]


Epoch: 62, Training loss: 0.19791778946605693, Val PRAUC: 0.7704


loss: 0.2656511314999034: 100%|██████████| 2224/2224 [07:29<00:00,  4.95it/s] 
100%|██████████| 544/544 [00:46<00:00, 11.75it/s]


Epoch: 63, Training loss: 0.2656511314999034, Val PRAUC: 0.7722


loss: 0.17504192640642616: 100%|██████████| 2224/2224 [07:28<00:00,  4.96it/s]
100%|██████████| 544/544 [00:46<00:00, 11.79it/s]


Epoch: 64, Training loss: 0.17504192640642616, Val PRAUC: 0.7709


loss: 0.15994202639227145: 100%|██████████| 2224/2224 [07:32<00:00,  4.92it/s]
100%|██████████| 544/544 [00:46<00:00, 11.75it/s]


Epoch: 65, Training loss: 0.15994202639227145, Val PRAUC: 0.7715


loss: 0.1648333914614242: 100%|██████████| 2224/2224 [07:27<00:00,  4.98it/s] 
100%|██████████| 544/544 [00:46<00:00, 11.78it/s]


Epoch: 66, Training loss: 0.1648333914614242, Val PRAUC: 0.7719


loss: 0.13525692814589593: 100%|██████████| 2224/2224 [07:24<00:00,  5.00it/s]
100%|██████████| 544/544 [00:47<00:00, 11.48it/s]


Epoch: 67, Training loss: 0.13525692814589593, Val PRAUC: 0.7711


loss: 0.22046469213048428: 100%|██████████| 2224/2224 [07:28<00:00,  4.96it/s]
100%|██████████| 544/544 [00:46<00:00, 11.63it/s]


Epoch: 68, Training loss: 0.22046469213048428, Val PRAUC: 0.7732


loss: 0.19771240577363075: 100%|██████████| 2224/2224 [07:30<00:00,  4.93it/s]
100%|██████████| 544/544 [00:46<00:00, 11.61it/s]


Epoch: 69, Training loss: 0.19771240577363075, Val PRAUC: 0.7710


loss: 0.16152382769798265: 100%|██████████| 2224/2224 [07:28<00:00,  4.95it/s]
100%|██████████| 544/544 [00:46<00:00, 11.74it/s]


Epoch: 70, Training loss: 0.16152382769798265, Val PRAUC: 0.7729


loss: 0.2193718015479725: 100%|██████████| 2224/2224 [07:32<00:00,  4.92it/s] 
100%|██████████| 544/544 [00:47<00:00, 11.56it/s]


Epoch: 71, Training loss: 0.2193718015479725, Val PRAUC: 0.7719


loss: 0.22373552037258046: 100%|██████████| 2224/2224 [07:30<00:00,  4.94it/s]
100%|██████████| 544/544 [00:46<00:00, 11.66it/s]


Epoch: 72, Training loss: 0.22373552037258046, Val PRAUC: 0.7730


loss: 0.21471194562883106: 100%|██████████| 2224/2224 [07:29<00:00,  4.95it/s]
100%|██████████| 544/544 [00:46<00:00, 11.69it/s]


Epoch: 73, Training loss: 0.21471194562883106, Val PRAUC: 0.7700


loss: 0.2610336570418934: 100%|██████████| 2224/2224 [07:29<00:00,  4.95it/s] 
100%|██████████| 544/544 [00:47<00:00, 11.54it/s]


Epoch: 74, Training loss: 0.2610336570418934, Val PRAUC: 0.7707


loss: 0.2026405506556108: 100%|██████████| 2224/2224 [07:31<00:00,  4.93it/s] 
100%|██████████| 544/544 [00:47<00:00, 11.57it/s]


Epoch: 75, Training loss: 0.2026405506556108, Val PRAUC: 0.7726


loss: 0.22303605636571258: 100%|██████████| 2224/2224 [07:33<00:00,  4.90it/s]
100%|██████████| 544/544 [00:46<00:00, 11.60it/s]


Epoch: 76, Training loss: 0.22303605636571258, Val PRAUC: 0.7731


loss: 0.24921361295005828: 100%|██████████| 2224/2224 [07:41<00:00,  4.82it/s]
100%|██████████| 544/544 [00:47<00:00, 11.51it/s]


Epoch: 77, Training loss: 0.24921361295005828, Val PRAUC: 0.7712


loss: 0.1502805209450741: 100%|██████████| 2224/2224 [07:29<00:00,  4.95it/s] 
100%|██████████| 544/544 [00:45<00:00, 11.99it/s]


Epoch: 78, Training loss: 0.1502805209450741, Val PRAUC: 0.7723


loss: 0.17816704615517293: 100%|██████████| 2224/2224 [07:25<00:00,  4.99it/s]
100%|██████████| 544/544 [00:46<00:00, 11.79it/s]


Epoch: 79, Training loss: 0.17816704615517293, Val PRAUC: 0.7716


loss: 0.26479352117914584: 100%|██████████| 2224/2224 [07:23<00:00,  5.02it/s]
100%|██████████| 544/544 [00:45<00:00, 11.86it/s]


Epoch: 80, Training loss: 0.26479352117914584, Val PRAUC: 0.7727


loss: 0.2518688727092115: 100%|██████████| 2224/2224 [07:30<00:00,  4.94it/s] 
100%|██████████| 544/544 [00:45<00:00, 11.85it/s]


Epoch: 81, Training loss: 0.2518688727092115, Val PRAUC: 0.7723


loss: 0.23481248585433223: 100%|██████████| 2224/2224 [07:25<00:00,  4.99it/s]
100%|██████████| 544/544 [00:45<00:00, 11.87it/s]


Epoch: 82, Training loss: 0.23481248585433223, Val PRAUC: 0.7723


loss: 0.19270287113738618: 100%|██████████| 2224/2224 [07:22<00:00,  5.03it/s]
100%|██████████| 544/544 [00:45<00:00, 11.87it/s]


Epoch: 83, Training loss: 0.19270287113738618, Val PRAUC: 0.7730


loss: 0.21479746201523026:  55%|█████▍    | 1215/2224 [04:02<03:21,  5.02it/s]


KeyboardInterrupt: 

In [None]:
# torch.save(model.state_dict(), './exp_data/saved_weights_gat_mimic3_drugrec.pkl')
torch.save(model.state_dict(), './exp_data/saved_weights_gin_mimic3_drugrec.pkl')
# torch.save(model.state_dict(), './exp_data/saved_weights_hgt_mimic3_drugrec.pkl')

In [None]:
y_true_all, y_prob_all = evaluate(model, device, test_loader)

In [None]:
from pyhealth.metrics import multilabel_metrics_fn
from sklearn.metrics import average_precision_score

# len(y_true_all), len(y_prob_all)
pr_auc = average_precision_score(y_true_all, y_prob_all, average="samples")
print(pr_auc)

In [None]:
# import numpy as np

# graph_all_patients = np.array(triples_all)

In [None]:
# import pickle

# with open('./graph_mimic3_patients.pkl', 'wb') as f:
#     pickle.dump(graph_all_patients, f)

In [None]:
# import pickle

# with open('./graph_mimic3_patients.pkl', 'rb') as f:
#     graph_all_patients = pickle.load(f)

In [None]:
# with open('./graphs/cond_proc/CCSCM_CCSPROC/id2ent.json', 'r') as file:
#     id2ent = json.load(file)
# with open('./graphs/cond_proc/CCSCM_CCSPROC/id2rel.json', 'r') as file:
#     id2rel = json.load(file)

In [None]:
# import torch
# import dgl
# import numpy as np

# patient = graph_all_patients[0]
# nodes = set()
# for triple in patient:
#     h, r, t = triple[0], triple[1], triple[2]
#     nodes.add(h)
#     nodes.add(t)

# node_idx_ori2new = {node : idx for idx, node in enumerate(nodes)}
# node_idx_new2ori = {idx: node for node, idx in node_idx_ori2new.items()}
# heads = torch.tensor([node_idx_ori2new[head] for head in patient[:, 0]])
# tails = torch.tensor([node_idx_ori2new[tail] for tail in patient[:, 2]])

# g = dgl.graph((heads, tails), num_nodes=len(nodes))

# g.edata['edge_ids'] = torch.from_numpy(patient[:, 1])
# g.ndata['node_ids'] = torch.from_numpy(np.array([node_idx_new2ori[i] for i in range(len(nodes))]))

In [None]:
# g.ndata['node_ids'][165], g.edata['edge_ids'][0], g.ndata['node_ids'][147], patient[0]

In [None]:
# g.ndata['node_ids'][226], g.edata['edge_ids'][1], g.ndata['node_ids'][385], patient[1]

In [None]:
# """Torch modules for graph attention networks with fully valuable edges (EGAT)."""
# # pylint: disable= no-member, arguments-differ, invalid-name
# import torch as th
# from torch import nn
# from torch.nn import init

# from dgl import function as fn
# from dgl.nn.functional import edge_softmax
# from dgl.base import DGLError
# from dgl.utils import expand_as_pair

# # pylint: enable=W0235
# class EGATConv(nn.Module):
#     def __init__(self,
#                  in_node_feats,
#                  in_edge_feats,
#                  out_node_feats,
#                  out_edge_feats,
#                  num_heads,
#                  bias=True):

#         super().__init__()
#         self._num_heads = num_heads
#         self._in_src_node_feats, self._in_dst_node_feats = expand_as_pair(in_node_feats)
#         self._out_node_feats = out_node_feats
#         self._out_edge_feats = out_edge_feats
#         if isinstance(in_node_feats, tuple):
#             self.fc_node_src = nn.Linear(
#                 self._in_src_node_feats, out_node_feats * num_heads, bias=False)
#             self.fc_ni = nn.Linear(
#                 self._in_src_node_feats, out_edge_feats*num_heads, bias=False)
#             self.fc_nj = nn.Linear(
#                 self._in_dst_node_feats, out_edge_feats*num_heads, bias=False)
#         else:
#             self.fc_node_src = nn.Linear(
#                 self._in_src_node_feats, out_node_feats * num_heads, bias=False)
#             self.fc_ni = nn.Linear(
#                 self._in_src_node_feats, out_edge_feats*num_heads, bias=False)
#             self.fc_nj = nn.Linear(
#                 self._in_src_node_feats, out_edge_feats*num_heads, bias=False)

#         self.fc_fij = nn.Linear(in_edge_feats, out_edge_feats*num_heads, bias=False)
#         self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_edge_feats)))
#         if bias:
#             self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_edge_feats,)))
#         else:
#             self.register_buffer('bias', None)
#         self.reset_parameters()

#     def reset_parameters(self):
#         """
#         Reinitialize learnable parameters.
#         """
#         gain = init.calculate_gain('relu')
#         init.xavier_normal_(self.fc_node_src.weight, gain=gain)
#         init.xavier_normal_(self.fc_ni.weight, gain=gain)
#         init.xavier_normal_(self.fc_fij.weight, gain=gain)
#         init.xavier_normal_(self.fc_nj.weight, gain=gain)
#         init.xavier_normal_(self.attn, gain=gain)
#         init.constant_(self.bias, 0)

#     def forward(self, graph, nfeats, efeats, get_attention=False):
#         with graph.local_scope():
#             if (graph.in_degrees() == 0).any():
#                 raise DGLError('There are 0-in-degree nodes in the graph, '
#                                'output for those nodes will be invalid. '
#                                'This is harmful for some applications, '
#                                'causing silent performance regression. '
#                                'Adding self-loop on the input graph by '
#                                'calling `g = dgl.add_self_loop(g)` will resolve '
#                                'the issue.')

#             # calc edge attention
#             # same trick way as in dgl.nn.pytorch.GATConv, but also includes edge feats
#             # https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.py
#             if isinstance(nfeats, tuple):
#                 nfeats_src, nfeats_dst = nfeats
#             else:
#                 nfeats_src = nfeats_dst = nfeats

#             f_ni = self.fc_ni(nfeats_src)
#             f_nj = self.fc_nj(nfeats_dst)
#             f_fij = self.fc_fij(efeats)

#             graph.srcdata.update({'f_ni': f_ni})
#             graph.dstdata.update({'f_nj': f_nj})
#             # add ni, nj factors
#             graph.apply_edges(fn.u_add_v('f_ni', 'f_nj', 'f_tmp'))
#             # add fij to node factor
#             f_out = graph.edata.pop('f_tmp') + f_fij
#             if self.bias is not None:
#                 f_out = f_out + self.bias
#             f_out = nn.functional.leaky_relu(f_out)
#             f_out = f_out.view(-1, self._num_heads, self._out_edge_feats)
#             # compute attention factor
#             e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
#             graph.edata['a'] = edge_softmax(graph, e)
#             graph.srcdata['h_out'] = self.fc_node_src(nfeats_src).view(-1, self._num_heads,
#                                                              self._out_node_feats)
#             # calc weighted sum
#             graph.update_all(fn.u_mul_e('h_out', 'a', 'm'),
#                              fn.sum('m', 'h_out'))

#             h_out = graph.dstdata['h_out'].view(-1, self._num_heads, self._out_node_feats)
#             if get_attention:
#                 return h_out, f_out, graph.edata.pop('a')
#             else:
#                 return h_out, f_out


In [None]:
# import dgl
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# # from dgl.nn.pytorch import EGATConv

# class GNNModel(nn.Module):
#     def __init__(self, node_feat_dim, edge_feat_dim, num_classes, node_feats, edge_feats, training):
#         super(GNNModel, self).__init__()
#         self.node_feat_dim = node_feat_dim
#         self.edge_feat_dim = edge_feat_dim
#         self.num_classes = num_classes
#         self.training = training
        
#         # Define node and edge feature embeddings
#         self.node_feat_emb = nn.Embedding.from_pretrained(torch.from_numpy(node_feats).double(), freeze=True)
#         self.edge_feat_emb = nn.Embedding.from_pretrained(torch.from_numpy(edge_feats).double(), freeze=True)
        
#         # Define GNN layers with attention mechanism
#         self.gnn_layers = nn.ModuleList()
#         self.gnn_layers.append(EGATConv(in_node_feats=node_feat_dim, 
#                                         in_edge_feats=edge_feat_dim, 
#                                         out_node_feats=node_feat_dim, 
#                                         out_edge_feats=edge_feat_dim, 
#                                         num_heads=3
#                                         ))
#         self.gnn_layers.append(EGATConv(in_node_feats=node_feat_dim, 
#                                         in_edge_feats=edge_feat_dim, 
#                                         out_node_feats=node_feat_dim, 
#                                         out_edge_feats=edge_feat_dim, 
#                                         num_heads=3, 
#                                         ))
#         self.gnn_layers.append(EGATConv(in_node_feats=node_feat_dim + edge_feat_dim + node_feat_dim,
#                                         in_edge_feats=edge_feat_dim,
#                                         out_node_feats=node_feat_dim,
#                                         out_edge_feats=edge_feat_dim,
#                                         num_heads=3,
#                                         ))
#         self.gnn_layers.append(EGATConv(in_node_feats=node_feat_dim, 
#                                         in_edge_feats=edge_feat_dim, 
#                                         out_node_feats=node_feat_dim, 
#                                         out_edge_feats=edge_feat_dim, 
#                                         num_heads=3, 
#                                         ))
#         self.gnn_layers.append(EGATConv(in_node_feats=node_feat_dim, 
#                                         in_edge_feats=edge_feat_dim, 
#                                         out_node_feats=node_feat_dim, 
#                                         out_edge_feats=edge_feat_dim, 
#                                         num_heads=3, 
#                                         ))
        
#         # Define attention mechanism
#         self.attn_node = nn.Linear(node_feat_dim, 1, bias=False)
#         self.attn_edge = nn.Linear(edge_feat_dim, 1, bias=False)
        
#         # Define final output layer
#         self.output_layer = nn.Linear(node_feat_dim, num_classes)
    
#     def forward(self, g, get_attention=False):
#         # Get node and edge features
#         node_feats = self.node_feat_emb(g.ndata['node_ids']).double()
#         edge_feats = self.edge_feat_emb(g.edata['edge_ids']).double()
#         print(node_feats.shape, edge_feats.shape)
        
#         # Pass node and edge features through GNN layers with attention
#         attentions = []
#         for i in range(len(self.gnn_layers)):
#             if i == 2:
#                 # Add edge features to node features and perform attention
#                 node_feats = torch.cat([node_feats, edge_feats, node_feats], dim=-1)
#                 node_feats = F.dropout(node_feats, p=0.5, training=self.training)
#                 attn_scores = self.attn_edge(edge_feats) + self.attn_node(node_feats)
#                 attn_scores = torch.softmax(attn_scores, dim=1)
#                 attn_scores = F.dropout(attn_scores, p=0.2, training=self.training)
#                 edge_feats = edge_feats * attn_scores
#                 node_feats, edge_feats = self.gnn_layers[i](g, node_feats, edge_feats)
#                 node_feats = node_feats + edge_feats
#                 attentions.append(attn_scores)
#             else:
#                 # Perform GNN layer and attention
#                 node_feats, edge_feats = self.gnn_layers[i](g, node_feats, edge_feats)
#                 node_feats = F.dropout(node_feats, p=0.5, training=self.training)
#                 attn_scores = self.attn_node(node_feats)
#                 attn_scores = torch.softmax(attn_scores, dim=1)
#                 attn_scores = F.dropout(attn_scores, p=0.2, training=self.training)
#                 node_feats = node_feats * attn_scores
#                 attentions.append(attn_scores)
        
#         # Compute final output probabilities with sigmoid activation
#         output_feats = self.output_layer(node_feats)
#         output_probs = torch.sigmoid(output_feats)
        
#         if get_attention:
#             return output_probs, attentions
#         else:
#             return output_probs


In [None]:
# with open('./graphs/cond_proc/CCSCM_CCSPROC/entity_embedding.pkl', 'rb') as f:
#     ent_emb = pickle.load(f)
# with open('./graphs/cond_proc/CCSCM_CCSPROC/relation_embedding.pkl', 'rb') as f:
#     rel_emb = pickle.load(f)

# ent_emb.shape, rel_emb.shape

In [None]:
# node_feat_dim = ent_emb.shape[-1]
# edge_feat_dim = rel_emb.shape[-1]
# num_classes = len(sample_dataset.get_all_tokens('drugs'))
# node_feats = ent_emb
# edge_feats = rel_emb

# gnn = GNNModel(node_feat_dim, edge_feat_dim, num_classes, node_feats, edge_feats, training=True)

In [None]:
# import importlib
# importlib.reload(dgl)

# g = dgl.add_self_loop(g)
# gnn.double()
# gnn(g)