In [2]:
# Basic Libraries
import pickle
import random

import numpy as np
import pandas as pd

# PyTorch Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Transformers Libraries
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Graph Neural Network Libraries
import dgl
import obonet
import networkx as nx

# Progress bars
from tqdm.auto import tqdm

# Phenotype Libraries
from pyhpo.ontology import Ontology
from PhenoDP import *
from PhenoDP_Preprocess import *
from PSD_HPOEncoder import *

In [3]:
# Initialise pyhpo's ontology. PhenoDP_Initial is then handed the `Ontology`
# object itself (not this call's return value) -- presumably relying on the
# call below having populated pyhpo's shared ontology; TODO confirm.
Ontology()
# PhenoDP preprocessing model; per its log output it builds disease dicts,
# disease IC dicts and HPO weights. `pre_model` is not used in the cells shown.
pre_model = PhenoDP_Initial(Ontology)

generate disease dict...
generate disease dict...
related hpo num: 8950
generate disease ic dict... 
calculating hp weights
related hpo num: 8950
generate disease ic dict... 
calculating hp weights


# Loading the Pre-trained Summarizer

In [4]:
# FLAN-T5 tokenizer and base weights from a local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained("../flanT5/")
T5model = T5ForConditionalGeneration.from_pretrained("../flanT5/")

# Overwrite with the fine-tuned HPO-to-summary weights. T5model is never moved
# off the CPU in this notebook, so map the checkpoint to CPU: this lets a
# checkpoint that was saved from a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load('../HPO2SUM/flan-model.pth', map_location='cpu')
T5model.load_state_dict(state_dict)

<All keys matched successfully>

<All keys matched successfully>

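# Loading the HPO DAG

Each HPO term (graph node) gets a feature vector from `get_vec` using the fine-tuned FLAN-T5 model. This step takes several minutes for the full ontology.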

In [5]:

# Local copy of the HPO ontology in OBO format (a file path, not a URL).
obo_path = '../HPO2SUM/hp.obo'
graph = obonet.read_obo(obo_path)

# Attach a T5-derived feature vector to every HPO term; these become the GCN's
# input node features below. The T5 model is only used for inference here, and
# get_vec's internals are not visible, so no_grad guarantees no autograd history
# is retained for the ~18k nodes.
with torch.no_grad():
    for node in tqdm(graph.nodes(), desc="Processing nodes"):
        graph.nodes[node]['feature'] = get_vec(node, tokenizer, T5model)

Processing nodes: 100%|██████████| 18281/18281 [09:18<00:00, 32.75it/s]



In [7]:
# `nx.info` was deprecated in NetworkX 2.7 and removed in 3.0, so build the same
# summary (name, type, node/edge counts, average degree) directly.
n_nodes = graph.number_of_nodes()
n_edges = graph.number_of_edges()
graph_summary = [
    f"Name: {graph.name}",
    f"Type: {type(graph).__name__}",
    f"Number of nodes: {n_nodes}",
    f"Number of edges: {n_edges}",
]
if n_nodes:
    if graph.is_directed():
        # Every edge contributes one in- and one out-degree, so both averages match.
        avg_degree = n_edges / n_nodes
        graph_summary.append(f"Average in degree: {avg_degree:8.4f}")
        graph_summary.append(f"Average out degree: {avg_degree:8.4f}")
    else:
        graph_summary.append(f"Average degree: {2 * n_edges / n_nodes:8.4f}")
print("\n".join(graph_summary))

Name: hp.obo
Type: MultiDiGraph
Number of nodes: 18281
Number of edges: 22671
Average in degree:   1.2401
Average out degree:   1.2401
Name: hp.obo
Type: MultiDiGraph
Number of nodes: 18281
Number of edges: 22671
Average in degree:   1.2401
Average out degree:   1.2401


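# Training the PSD-HPOEncoder

A GCN is trained on the HPO graph with node/edge masking (`node_mask_percentage` / `edge_mask_percentage` = 0.2). Its latent node outputs are then saved as per-term embeddings.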

In [8]:
# Seed every RNG that weight initialisation and the node/edge masking during
# training may draw from, so re-running this cell reproduces the same encoder.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
dgl.seed(SEED)

dgl_graph = nx_to_dgl(graph)

# Input/output width follows the node features (768 in the original run) rather
# than a hard-coded constant. h_feats is the hidden width -- presumably the size
# of the latent embeddings extracted below; confirm in PSD_HPOEncoder.GCN.
feature_dimension = dgl_graph.ndata['feat'].shape[-1]
in_feats = feature_dimension
h_feats = 256
out_feats = feature_dimension

# Computation is pinned to CPU (the original note says the DGL build used here
# had no GPU support).
device = torch.device("cpu")
model = GCN(in_feats, h_feats, out_feats).to(device)

# Set to a small value (e.g. 50) for a quick smoke run.
N_EPOCHS = 2000
train_model(model, dgl_graph, epochs=N_EPOCHS, lr=0.001,
            node_mask_percentage=0.2, edge_mask_percentage=0.2)


Epoch 0, Loss: 4.0047609218163416e-05
Epoch 0, Loss: 4.0047609218163416e-05
Epoch 10, Loss: 1.798968696675729e-05
Epoch 10, Loss: 1.798968696675729e-05
Epoch 20, Loss: 9.810275514610112e-06
Epoch 20, Loss: 9.810275514610112e-06
Epoch 30, Loss: 7.659506991331e-06
Epoch 30, Loss: 7.659506991331e-06
Epoch 40, Loss: 6.901744200149551e-06
Epoch 40, Loss: 6.901744200149551e-06


In [9]:
# Inference pass of the trained encoder; `latent` supplies the per-node embeddings.
model.eval()
with torch.no_grad():
    outputs, latent = model(dgl_graph, dgl_graph.ndata['feat'])

# Rows of `latent` are matched to HPO ids by position, which assumes nx_to_dgl
# preserved NetworkX node order (TODO confirm). The check below at least stops a
# node-count mismatch from silently misaligning embeddings.
node_ids = list(graph.nodes)
assert latent.shape[0] == len(node_ids), (
    f"latent has {latent.shape[0]} rows but graph has {len(node_ids)} nodes"
)
node_embedding_dict = {node_id: row.numpy() for node_id, row in zip(node_ids, latent)}

In [None]:
# Persist the {HPO term id -> latent vector} mapping for downstream use.
embedding_path = './node_embedding_dict_T5_gcn.pkl'
with open(embedding_path, 'wb') as out_file:
    pickle.dump(node_embedding_dict, out_file)