In [1]:
import json

file_dir = "./graphs/condition/CCSCM"

file_id2ent = f"{file_dir}/id2ent.json"
file_ent2id = f"{file_dir}/ent2id.json"
file_id2rel = f"{file_dir}/id2rel.json"
file_rel2id = f"{file_dir}/rel2id.json"

with open(file_id2ent, 'r') as file:
    cond_id2ent = json.load(file)
with open(file_ent2id, 'r') as file:
    cond_ent2id = json.load(file)
with open(file_id2rel, 'r') as file:
    cond_id2rel = json.load(file)
with open(file_rel2id, 'r') as file:
    cond_rel2id = json.load(file)


import csv

condition_mapping_file = "./resources/CCSCM.csv"
procedure_mapping_file = "./resources/CCSPROC.csv"
drug_file = "./resources/ATC.csv"

condition_dict = {}
with open(condition_mapping_file, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        condition_dict[row['code']] = row['name'].lower()

procedure_dict = {}
with open(procedure_mapping_file, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        procedure_dict[row['code']] = row['name'].lower()

drug_dict = {}
with open(drug_file, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        if row['level'] == '5.0':
            drug_dict[row['code']] = row['name'].lower()

In [3]:
# from pyhealth.datasets import MIMIC3Dataset
# from GraphCare.task_fn import drug_recommendation_fn

# mimic3_ds = MIMIC3Dataset(
#     root="../../../data/physionet.org/files/mimiciii/1.4/", 
#     tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],      
#     code_mapping={
#         "NDC": ("ATC", {"target_kwargs": {"level": 3}}),
#         "ICD9CM": "CCSCM",
#         "ICD9PROC": "CCSPROC"
#         },
# )

# sample_dataset = mimic3_ds.set_task(drug_recommendation_fn)

In [2]:
def flatten(lst):
    result = []
    for item in lst:
        if isinstance(item, list):
            result.extend(flatten(item))
        else:
            result.append(item)
    return result

In [3]:
from pyhealth.tokenizer import Tokenizer
import numpy as np
from tqdm import tqdm
import torch

def multihot(label, num_labels):
    multihot = np.zeros(num_labels)
    for l in label:
        multihot[l] = 1
    return multihot

def prepare_label(drugs):
    label_tokenizer = Tokenizer(
        sample_dataset.get_all_tokens(key='drugs')
    )

    labels_index = label_tokenizer.convert_tokens_to_indices(drugs)
    # print(labels_index)
    # convert to multihot
    num_labels = label_tokenizer.get_vocabulary_size()
    # print(num_labels)
    labels = multihot(labels_index, num_labels)
    return labels


# for patient in tqdm(sample_dataset):
#     # patient['drugs_all'] = flatten(patient['drugs'])
#     # print(patient['drugs_all'])
#     patient['drugs_ind'] = torch.tensor(prepare_label(patient['drugs']))

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# import pickle

# with open('./exp_data/ccscm_ccsproc/sample_dataset.pkl', 'wb') as f:
#     pickle.dump(sample_dataset, f)

In [5]:
import pickle

with open('./exp_data/ccscm_ccsproc/sample_dataset.pkl', 'rb') as f:
    sample_dataset= pickle.load(f)

In [6]:
from pyhealth.datasets import split_by_patient

train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1], seed=528)

In [7]:
train_dataset[0]['conditions'], train_dataset[0]['procedures']

([['101', '106', '98', '138']], [['44', '47', '50']])

In [8]:
from tqdm import tqdm
import numpy as np
import networkx as nx
import pickle

with open('./graphs/cond_proc/CCSCM_CCSPROC/ent2id.json', 'r') as file:
    ent2id = json.load(file)
with open('./graphs/cond_proc/CCSCM_CCSPROC/rel2id.json', 'r') as file:
    rel2id = json.load(file)
with open('./graphs/cond_proc/CCSCM_CCSPROC/entity_embedding.pkl', 'rb') as file:
    ent_emb = pickle.load(file)
    

In [9]:
ent_emb[0]

array([-0.02357265,  0.002313  ,  0.02204529, ..., -0.01157682,
        0.01255   ,  0.00188047])

In [12]:
G = nx.Graph()

for i in range(len(ent_emb)):
    G.add_nodes_from([
        (i, {'y': i, 'x': ent_emb[i]})
    ])

triples_all = []
for patient in tqdm(sample_dataset):
    triples = []
    triple_set = set()
    # node_set = set()
    conditions = flatten(patient['conditions'])
    for condition in conditions:
        cond_file = f'./graphs/condition/CCSCM/{condition}.txt'
        with open(cond_file, 'r') as f:
            lines = f.readlines()
        for line in lines:
            items = line.split('\t')
            if len(items) == 3:
                h, r, t = items
                t = t[:-1]
                h = int(ent2id[h])
                r = int(rel2id[r])
                t = int(ent2id[t])
                triple = (h, r, t)
                if triple not in triple_set:
                    triples.append((h, t))
                    triple_set.add(triple)
                    # node_set.add(h)
                    # node_set.add(r)

    G.add_edges_from(
        triples,
        # label=prepare_label(patient['drugs'])
    )
    
    # triples.append(prepare_label(patient['drugs']))
    # triples_all.append(np.array(triples))


100%|██████████| 44399/44399 [01:54<00:00, 388.58it/s]


In [13]:
from torch_geometric.utils import to_networkx, from_networkx
import pickle

G_tg = from_networkx(G)

# with open('./exp_data/ccscm_ccsproc/graph_tg.pkl', 'wb') as f:
#     pickle.dump(G, f)

  data[key] = torch.tensor(value)


In [14]:
import torch

def get_subgraph(dataset):
    
    subgraph_list = []
    for patient in tqdm(dataset):
        triple_set = set()
        node_set = set()
        conditions = flatten(patient['conditions'])
        for condition in conditions:
            cond_file = f'./graphs/condition/CCSCM/{condition}.txt'
            with open(cond_file, 'r') as f:
                lines = f.readlines()
            for line in lines:
                items = line.split('\t')
                if len(items) == 3:
                    h, r, t = items
                    t = t[:-1]
                    h = int(ent2id[h])
                    r = int(rel2id[r])
                    t = int(ent2id[t])
                    triple = (h, r, t)
                    if triple not in triple_set:
                        triple_set.add(triple)
                        node_set.add(h)
                        node_set.add(r)

        P = G_tg.subgraph(torch.tensor([*node_set]))
        P.label = patient['drugs_ind']
        subgraph_list.append(P)

    return subgraph_list


In [15]:
train_graph_list = get_subgraph(train_dataset)
val_graph_list = get_subgraph(val_dataset)
test_graph_list = get_subgraph(test_dataset)

100%|██████████| 35578/35578 [02:17<00:00, 259.39it/s]
100%|██████████| 4346/4346 [00:16<00:00, 263.44it/s]
100%|██████████| 4475/4475 [00:18<00:00, 247.05it/s]


In [16]:
from torch_geometric.loader import DataListLoader, DataLoader

class Dataset(torch.utils.data.Dataset):
    def __init__(self, graph_list):
        self.graph_list=graph_list
    def __len__(self):
        return len(self.graph_list)
    def __getitem__(self, idx):
        return self.graph_list[idx]

train_set = Dataset(train_graph_list)
val_set = Dataset(val_graph_list)
test_set = Dataset(test_graph_list)

In [17]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GATConv
from torch_geometric.data import DataLoader, Data
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import DataParallel
from torch_geometric.loader import DataListLoader
import pickle

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels*heads, hidden_channels, heads=heads)
        self.conv3 = GATConv(hidden_channels*heads, hidden_channels, heads=1)

        self.fc = Linear(hidden_channels, out_channels)
        
    def forward(self, x, edge_index, batch):
        x = F.elu(self.conv1(x, edge_index))
        # print(x.shape)
        x = F.elu(self.conv2(x, edge_index))
        # print(x.shape)
        x = F.elu(self.conv3(x, edge_index))
        # print(x.shape)
        x = global_mean_pool(x, batch)
        # print(x.shape)
        x = F.dropout(x, p=0.5, training=self.training)
        # print(x.shape)
        logits = self.fc(x)
        # print(logits.shape)
        return logits

In [18]:
from torch_geometric.nn import GINConv

class GIN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GIN, self).__init__()
        self.conv1 = GINConv(Linear(in_channels, hidden_channels))
        self.conv2 = GINConv(Linear(hidden_channels, hidden_channels))
        self.conv3 = GINConv(Linear(hidden_channels, hidden_channels))

        self.fc = Linear(hidden_channels, out_channels)
        
    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        logits = self.fc(x)
        return logits

In [19]:
from torch_geometric.nn import HGTConv

class HGT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=2):
        super(HGT, self).__init__()
        self.conv1 = HGTConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = HGTConv(hidden_channels*heads, hidden_channels, heads=heads)
        self.conv3 = HGTConv(hidden_channels*heads, hidden_channels, heads=1)

        self.fc = Linear(hidden_channels, out_channels)
        
    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        logits = self.fc(x)
        return logits

In [20]:
from tqdm import tqdm

def train(model, device, train_loader, optimizer):
    model.train()
    training_loss = 0
    pbar = tqdm(train_loader)
    for data in pbar:
        pbar.set_description(f'loss: {training_loss}')
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        try:
            label = data.label.reshape(int(train_loader.batch_size), int(len(data.label)/train_loader.batch_size))
        except:
            continue
        loss = F.binary_cross_entropy_with_logits(out, label)
        loss.backward()
        training_loss = loss
        optimizer.step()
    
    return training_loss 

In [21]:
# for data in train_loader:
#     print(train_loader.batch_size)
#     print(len(data.label))
#     print(data.label.reshape(int(train_loader.batch_size), int(len(data.label)/train_loader.batch_size)).shape)
#     break

In [22]:
from pyhealth.metrics import multilabel_metrics_fn

def evaluate(model, device, loader):
    model.eval()
    y_prob_all = []
    y_true_all = []

    for data in tqdm(loader):
        data = data.to(device)
        with torch.no_grad():
            logits = model(data.x, data.edge_index, data.batch)
            y_prob = torch.sigmoid(logits)
            try:
                y_true = data.label.reshape(int(loader.batch_size), int(len(data.label)/loader.batch_size))
            except:
                continue
            y_prob_all.append(y_prob.cpu())
            y_true_all.append(y_true.cpu())
            
    y_true_all = np.concatenate(y_true_all, axis=0)
    y_prob_all = np.concatenate(y_prob_all, axis=0)
    # pr_auc = multilabel_metrics_fn(y_true=y_true_all, y_prob=y_true_all, metrics="pr_auc_macro")

    return y_true_all, y_prob_all

In [23]:
from sklearn.metrics import average_precision_score

def train_loop(train_loader, val_loader, model, optimizer, device, epochs):
    for epoch in range(1, epochs+1):
        loss = train(model, device, train_loader, optimizer)
        # y_true_all, y_prob_all = evaluate(model, device, train_loader)
        # train_pr_auc = average_precision_score(y_true_all, y_prob_all, average="macro")
        y_true_all, y_prob_all = evaluate(model, device, val_loader)
        val_pr_auc = average_precision_score(y_true_all, y_prob_all, average="samples")
        print(f'Epoch: {epoch}, Training loss: {loss}, Val PRAUC: {val_pr_auc:.4f}')


In [24]:
train_set[0].x.shape[1], len(train_set[0].label)

(1536, 197)

In [25]:
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=8, shuffle=False)
test_loader = DataLoader(test_set, batch_size=8, shuffle=False)



In [None]:
in_channels = train_set[0].x.shape[1]
out_channels = len(train_set[0].label)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = GAT(in_channels=in_channels, out_channels=out_channels, hidden_channels=256, heads=3).to(device)
# model = GIN(in_channels=in_channels, out_channels=out_channels, hidden_channels=512).to(device)
model = HGT(in_channels=in_channels, out_channels=out_channels, hidden_channels=512, heads=2).to(device)

model.double()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_loop(train_loader=train_loader, val_loader=val_loader, model=model, optimizer=optimizer, device=device, epochs=100)

In [None]:
# torch.save(model.state_dict(), './exp_data/saved_weights_gat_mimic3_drugrec.pkl')
# torch.save(model.state_dict(), './exp_data/saved_weights_gin_mimic3_drugrec.pkl')
torch.save(model.state_dict(), './exp_data/saved_weights_hgt_mimic3_drugrec.pkl')

In [246]:
y_true_all, y_prob_all = evaluate(model, device, test_loader)

100%|██████████| 544/544 [00:39<00:00, 13.72it/s]


In [255]:
from pyhealth.metrics import multilabel_metrics_fn
from sklearn.metrics import average_precision_score

# len(y_true_all), len(y_prob_all)
pr_auc = average_precision_score(y_true_all, y_prob_all, average="samples")
print(pr_auc)

0.757423070584265


In [None]:
# import numpy as np

# graph_all_patients = np.array(triples_all)

In [None]:
# import pickle

# with open('./graph_mimic3_patients.pkl', 'wb') as f:
#     pickle.dump(graph_all_patients, f)

In [None]:
# import pickle

# with open('./graph_mimic3_patients.pkl', 'rb') as f:
#     graph_all_patients = pickle.load(f)

In [None]:
# with open('./graphs/cond_proc/CCSCM_CCSPROC/id2ent.json', 'r') as file:
#     id2ent = json.load(file)
# with open('./graphs/cond_proc/CCSCM_CCSPROC/id2rel.json', 'r') as file:
#     id2rel = json.load(file)

In [None]:
# import torch
# import dgl
# import numpy as np

# patient = graph_all_patients[0]
# nodes = set()
# for triple in patient:
#     h, r, t = triple[0], triple[1], triple[2]
#     nodes.add(h)
#     nodes.add(t)

# node_idx_ori2new = {node : idx for idx, node in enumerate(nodes)}
# node_idx_new2ori = {idx: node for node, idx in node_idx_ori2new.items()}
# heads = torch.tensor([node_idx_ori2new[head] for head in patient[:, 0]])
# tails = torch.tensor([node_idx_ori2new[tail] for tail in patient[:, 2]])

# g = dgl.graph((heads, tails), num_nodes=len(nodes))

# g.edata['edge_ids'] = torch.from_numpy(patient[:, 1])
# g.ndata['node_ids'] = torch.from_numpy(np.array([node_idx_new2ori[i] for i in range(len(nodes))]))

In [None]:
# g.ndata['node_ids'][165], g.edata['edge_ids'][0], g.ndata['node_ids'][147], patient[0]

(tensor(10904), tensor(3966), tensor(14935), array([10904,  3966, 14935]))

In [None]:
# g.ndata['node_ids'][226], g.edata['edge_ids'][1], g.ndata['node_ids'][385], patient[1]

(tensor(9090), tensor(3966), tensor(7639), array([9090, 3966, 7639]))

In [None]:
# """Torch modules for graph attention networks with fully valuable edges (EGAT)."""
# # pylint: disable= no-member, arguments-differ, invalid-name
# import torch as th
# from torch import nn
# from torch.nn import init

# from dgl import function as fn
# from dgl.nn.functional import edge_softmax
# from dgl.base import DGLError
# from dgl.utils import expand_as_pair

# # pylint: enable=W0235
# class EGATConv(nn.Module):
#     def __init__(self,
#                  in_node_feats,
#                  in_edge_feats,
#                  out_node_feats,
#                  out_edge_feats,
#                  num_heads,
#                  bias=True):

#         super().__init__()
#         self._num_heads = num_heads
#         self._in_src_node_feats, self._in_dst_node_feats = expand_as_pair(in_node_feats)
#         self._out_node_feats = out_node_feats
#         self._out_edge_feats = out_edge_feats
#         if isinstance(in_node_feats, tuple):
#             self.fc_node_src = nn.Linear(
#                 self._in_src_node_feats, out_node_feats * num_heads, bias=False)
#             self.fc_ni = nn.Linear(
#                 self._in_src_node_feats, out_edge_feats*num_heads, bias=False)
#             self.fc_nj = nn.Linear(
#                 self._in_dst_node_feats, out_edge_feats*num_heads, bias=False)
#         else:
#             self.fc_node_src = nn.Linear(
#                 self._in_src_node_feats, out_node_feats * num_heads, bias=False)
#             self.fc_ni = nn.Linear(
#                 self._in_src_node_feats, out_edge_feats*num_heads, bias=False)
#             self.fc_nj = nn.Linear(
#                 self._in_src_node_feats, out_edge_feats*num_heads, bias=False)

#         self.fc_fij = nn.Linear(in_edge_feats, out_edge_feats*num_heads, bias=False)
#         self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_edge_feats)))
#         if bias:
#             self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_edge_feats,)))
#         else:
#             self.register_buffer('bias', None)
#         self.reset_parameters()

#     def reset_parameters(self):
#         """
#         Reinitialize learnable parameters.
#         """
#         gain = init.calculate_gain('relu')
#         init.xavier_normal_(self.fc_node_src.weight, gain=gain)
#         init.xavier_normal_(self.fc_ni.weight, gain=gain)
#         init.xavier_normal_(self.fc_fij.weight, gain=gain)
#         init.xavier_normal_(self.fc_nj.weight, gain=gain)
#         init.xavier_normal_(self.attn, gain=gain)
#         init.constant_(self.bias, 0)

#     def forward(self, graph, nfeats, efeats, get_attention=False):
#         with graph.local_scope():
#             if (graph.in_degrees() == 0).any():
#                 raise DGLError('There are 0-in-degree nodes in the graph, '
#                                'output for those nodes will be invalid. '
#                                'This is harmful for some applications, '
#                                'causing silent performance regression. '
#                                'Adding self-loop on the input graph by '
#                                'calling `g = dgl.add_self_loop(g)` will resolve '
#                                'the issue.')

#             # calc edge attention
#             # same trick way as in dgl.nn.pytorch.GATConv, but also includes edge feats
#             # https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.py
#             if isinstance(nfeats, tuple):
#                 nfeats_src, nfeats_dst = nfeats
#             else:
#                 nfeats_src = nfeats_dst = nfeats

#             f_ni = self.fc_ni(nfeats_src)
#             f_nj = self.fc_nj(nfeats_dst)
#             f_fij = self.fc_fij(efeats)

#             graph.srcdata.update({'f_ni': f_ni})
#             graph.dstdata.update({'f_nj': f_nj})
#             # add ni, nj factors
#             graph.apply_edges(fn.u_add_v('f_ni', 'f_nj', 'f_tmp'))
#             # add fij to node factor
#             f_out = graph.edata.pop('f_tmp') + f_fij
#             if self.bias is not None:
#                 f_out = f_out + self.bias
#             f_out = nn.functional.leaky_relu(f_out)
#             f_out = f_out.view(-1, self._num_heads, self._out_edge_feats)
#             # compute attention factor
#             e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
#             graph.edata['a'] = edge_softmax(graph, e)
#             graph.srcdata['h_out'] = self.fc_node_src(nfeats_src).view(-1, self._num_heads,
#                                                              self._out_node_feats)
#             # calc weighted sum
#             graph.update_all(fn.u_mul_e('h_out', 'a', 'm'),
#                              fn.sum('m', 'h_out'))

#             h_out = graph.dstdata['h_out'].view(-1, self._num_heads, self._out_node_feats)
#             if get_attention:
#                 return h_out, f_out, graph.edata.pop('a')
#             else:
#                 return h_out, f_out


In [None]:
# import dgl
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# # from dgl.nn.pytorch import EGATConv

# class GNNModel(nn.Module):
#     def __init__(self, node_feat_dim, edge_feat_dim, num_classes, node_feats, edge_feats, training):
#         super(GNNModel, self).__init__()
#         self.node_feat_dim = node_feat_dim
#         self.edge_feat_dim = edge_feat_dim
#         self.num_classes = num_classes
#         self.training = training
        
#         # Define node and edge feature embeddings
#         self.node_feat_emb = nn.Embedding.from_pretrained(torch.from_numpy(node_feats).double(), freeze=True)
#         self.edge_feat_emb = nn.Embedding.from_pretrained(torch.from_numpy(edge_feats).double(), freeze=True)
        
#         # Define GNN layers with attention mechanism
#         self.gnn_layers = nn.ModuleList()
#         self.gnn_layers.append(EGATConv(in_node_feats=node_feat_dim, 
#                                         in_edge_feats=edge_feat_dim, 
#                                         out_node_feats=node_feat_dim, 
#                                         out_edge_feats=edge_feat_dim, 
#                                         num_heads=3
#                                         ))
#         self.gnn_layers.append(EGATConv(in_node_feats=node_feat_dim, 
#                                         in_edge_feats=edge_feat_dim, 
#                                         out_node_feats=node_feat_dim, 
#                                         out_edge_feats=edge_feat_dim, 
#                                         num_heads=3, 
#                                         ))
#         self.gnn_layers.append(EGATConv(in_node_feats=node_feat_dim + edge_feat_dim + node_feat_dim,
#                                         in_edge_feats=edge_feat_dim,
#                                         out_node_feats=node_feat_dim,
#                                         out_edge_feats=edge_feat_dim,
#                                         num_heads=3,
#                                         ))
#         self.gnn_layers.append(EGATConv(in_node_feats=node_feat_dim, 
#                                         in_edge_feats=edge_feat_dim, 
#                                         out_node_feats=node_feat_dim, 
#                                         out_edge_feats=edge_feat_dim, 
#                                         num_heads=3, 
#                                         ))
#         self.gnn_layers.append(EGATConv(in_node_feats=node_feat_dim, 
#                                         in_edge_feats=edge_feat_dim, 
#                                         out_node_feats=node_feat_dim, 
#                                         out_edge_feats=edge_feat_dim, 
#                                         num_heads=3, 
#                                         ))
        
#         # Define attention mechanism
#         self.attn_node = nn.Linear(node_feat_dim, 1, bias=False)
#         self.attn_edge = nn.Linear(edge_feat_dim, 1, bias=False)
        
#         # Define final output layer
#         self.output_layer = nn.Linear(node_feat_dim, num_classes)
    
#     def forward(self, g, get_attention=False):
#         # Get node and edge features
#         node_feats = self.node_feat_emb(g.ndata['node_ids']).double()
#         edge_feats = self.edge_feat_emb(g.edata['edge_ids']).double()
#         print(node_feats.shape, edge_feats.shape)
        
#         # Pass node and edge features through GNN layers with attention
#         attentions = []
#         for i in range(len(self.gnn_layers)):
#             if i == 2:
#                 # Add edge features to node features and perform attention
#                 node_feats = torch.cat([node_feats, edge_feats, node_feats], dim=-1)
#                 node_feats = F.dropout(node_feats, p=0.5, training=self.training)
#                 attn_scores = self.attn_edge(edge_feats) + self.attn_node(node_feats)
#                 attn_scores = torch.softmax(attn_scores, dim=1)
#                 attn_scores = F.dropout(attn_scores, p=0.2, training=self.training)
#                 edge_feats = edge_feats * attn_scores
#                 node_feats, edge_feats = self.gnn_layers[i](g, node_feats, edge_feats)
#                 node_feats = node_feats + edge_feats
#                 attentions.append(attn_scores)
#             else:
#                 # Perform GNN layer and attention
#                 node_feats, edge_feats = self.gnn_layers[i](g, node_feats, edge_feats)
#                 node_feats = F.dropout(node_feats, p=0.5, training=self.training)
#                 attn_scores = self.attn_node(node_feats)
#                 attn_scores = torch.softmax(attn_scores, dim=1)
#                 attn_scores = F.dropout(attn_scores, p=0.2, training=self.training)
#                 node_feats = node_feats * attn_scores
#                 attentions.append(attn_scores)
        
#         # Compute final output probabilities with sigmoid activation
#         output_feats = self.output_layer(node_feats)
#         output_probs = torch.sigmoid(output_feats)
        
#         if get_attention:
#             return output_probs, attentions
#         else:
#             return output_probs


In [None]:
# with open('./graphs/cond_proc/CCSCM_CCSPROC/entity_embedding.pkl', 'rb') as f:
#     ent_emb = pickle.load(f)
# with open('./graphs/cond_proc/CCSCM_CCSPROC/relation_embedding.pkl', 'rb') as f:
#     rel_emb = pickle.load(f)

# ent_emb.shape, rel_emb.shape

((19347, 1536), (5042, 1536))

In [None]:
# node_feat_dim = ent_emb.shape[-1]
# edge_feat_dim = rel_emb.shape[-1]
# num_classes = len(sample_dataset.get_all_tokens('drugs'))
# node_feats = ent_emb
# edge_feats = rel_emb

# gnn = GNNModel(node_feat_dim, edge_feat_dim, num_classes, node_feats, edge_feats, training=True)

In [None]:
# import importlib
# importlib.reload(dgl)

# g = dgl.add_self_loop(g)
# gnn.double()
# gnn(g)