In [1]:
import obonet
import torch
import pickle
from tqdm import tqdm
from pyhpo import Ontology
from transformers import AutoModelForCausalLM, AutoTokenizer
from phenodp.encoders import PSD_HPOEncoder, GCN

In [4]:
def get_average_embedding(text, tokenizer, model, device):
    """
    Compute the mean-pooled last-layer hidden state of ``text``.

    The text is tokenized as a single sequence and run through ``model`` with
    ``output_hidden_states=True``. The final hidden layer is then averaged over
    the token axis to give one vector per text.

    Args:
        text (str): The input text to encode.
        tokenizer: Tokenizer called as ``tokenizer(text, return_tensors="pt")``;
            its output must support ``.to(device)`` and ``**`` unpacking.
        model: Model called with the tokenized inputs as keyword arguments and
            ``output_hidden_states=True``. It must return an object exposing a
            ``hidden_states`` sequence.
        device: Device the tokenized inputs are moved to (e.g. ``"cuda"`` or
            ``"cpu"``). It should match the device ``model`` lives on.

    Returns:
        np.ndarray: 1-D float32 array, the token-averaged last hidden layer.

    Raises:
        ValueError: If NaN values are detected in the embedding.
    """
    # Convert the text into model inputs on the target device.
    inputs = tokenizer(text, return_tensors="pt").to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Last layer of the hidden-state stack; assumed shape
    # (batch=1, seq_len, hidden_size) since a single text is tokenized — TODO confirm.
    last_hidden_states = outputs.hidden_states[-1]

    # Average over the token axis and drop the size-1 batch axis. Upcast to
    # float32 first: a bfloat16 tensor cannot be converted to NumPy directly
    # (unsupported ScalarType).
    average_encoding = last_hidden_states.mean(dim=1).squeeze(0).to(torch.float32)

    # Enforce the documented contract: never hand back an embedding with NaNs.
    if torch.isnan(average_encoding).any():
        raise ValueError(f"NaN values detected in embedding for text {text!r}")

    return average_encoding.cpu().numpy()

In [3]:
# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Local checkpoint of the biomedical causal LM used to embed HPO term names.
model_path = '../data/model/Bio-Medical-3B-CoT-Finetuned'

# Tokenizer and weights come from the same checkpoint directory.
# torch_dtype="auto" lets transformers pick the dtype from the checkpoint,
# and device_map places the weights on `device`.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map=device)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
# Single source of truth for the HPO release used by both loaders.
hpo_dir = '../data/hpo-2025-05-06'
ontology = Ontology(data_folder=hpo_dir)

# Local OBO file (not a URL) for the same HPO release.
obo_path = f'{hpo_dir}/hp.obo'
graph = obonet.read_obo(obo_path)

# Expected embedding width. It must equal the LLM's hidden size and the GCN's
# in_feats; every embedding is checked against it below.
feature_dimension = 2048

# Embed each HPO term's name with the LLM and attach it as the node feature.
for node in tqdm(graph.nodes(), desc="Processing nodes"):
    try:
        embedding = get_average_embedding(ontology.get_hpo_object(node).name, tokenizer, model, device)
        # Catch a model/width mismatch here rather than deep inside GCN training.
        if embedding.shape != (feature_dimension,):
            raise ValueError(
                f"embedding shape {embedding.shape} != expected ({feature_dimension},)"
            )
    except ValueError as e:
        # NaN or wrong-shape embedding: stop immediately, and name the offending
        # node in the exception itself (not just in printed output).
        raise ValueError(f"Error processing node {node}: {e}") from e
    # Store the embedding in the node's features.
    graph.nodes[node]['feature'] = embedding

Processing nodes: 100%|██████████| 19177/19177 [17:02<00:00, 18.76it/s]


In [6]:
# Convert the NetworkX HPO graph (nodes carry 'feature' embeddings) to DGL.
dgl_graph = PSD_HPOEncoder.nx_to_dgl(graph)

# Input and output widths both equal the embedding width. feature_dimension is
# defined (and checked) in the embedding cell, so it is not re-declared here.
# h_feats is the hidden width of the GCN (presumably the width of the `latent`
# returned by the model — TODO confirm in GCN).
in_feats = feature_dimension
h_feats = 256
out_feats = feature_dimension

# Training is stochastic (weight init plus node/edge masking). Seeding torch
# makes torch-based randomness reproducible. Any other RNG used inside
# PSD_HPOEncoder (numpy, DGL) is NOT covered by this — TODO confirm.
GCN_SEED = 42
torch.manual_seed(GCN_SEED)

# Train on CPU. The original note says the DGL graph lacks GPU support here
# (presumably a CPU-only DGL build — confirm). A separate name is used so the
# LLM's `device` variable is left intact.
gcn_device = torch.device("cpu")

# NOTE: this rebinds `model` (previously the LLM); the embedding cell cannot be
# re-run after this cell without reloading the LLM.
model = GCN(in_feats, h_feats, out_feats).to(gcn_device)

PSD_HPOEncoder.train_model(model, dgl_graph, epochs=500, lr=0.001, node_mask_percentage=0.2, edge_mask_percentage=0.2)

  dgl_graph.ndata['feat'] = torch.tensor(features, dtype=torch.float32)


Epoch 0, Loss: 4.326616287231445
Epoch 10, Loss: 3.6091842651367188
Epoch 20, Loss: 3.4174458980560303
Epoch 30, Loss: 3.3009395599365234
Epoch 40, Loss: 3.3029990196228027
Epoch 50, Loss: 3.2532284259796143
Epoch 60, Loss: 3.219243288040161
Epoch 70, Loss: 3.1401946544647217
Epoch 80, Loss: 3.155456066131592
Epoch 90, Loss: 3.1003711223602295
Epoch 100, Loss: 3.041008234024048
Epoch 110, Loss: 2.9589552879333496
Epoch 120, Loss: 2.9611613750457764
Epoch 130, Loss: 2.86667799949646
Epoch 140, Loss: 2.8205506801605225
Epoch 150, Loss: 2.7890706062316895
Epoch 160, Loss: 2.6870105266571045
Epoch 170, Loss: 2.6204681396484375
Epoch 180, Loss: 2.6074070930480957
Epoch 190, Loss: 2.491335868835449
Epoch 200, Loss: 2.4607696533203125
Epoch 210, Loss: 2.4158196449279785
Epoch 220, Loss: 2.3929336071014404
Epoch 230, Loss: 2.329793930053711
Epoch 240, Loss: 2.2623910903930664
Epoch 250, Loss: 2.2221603393554688
Epoch 260, Loss: 2.182469606399536
Epoch 270, Loss: 2.101799726486206
Epoch 280, Lo

In [7]:
# Inference pass of the trained GCN; only the latent representation is kept.
model.eval()
with torch.no_grad():
    reconstructed_feats, latent = model(dgl_graph, dgl_graph.ndata['feat'])

# Convert once instead of calling .numpy() per row.
latent_np = latent.cpu().numpy()

# Row i of `latent` is assumed to belong to the i-th node of `graph`, i.e.
# nx_to_dgl preserves NetworkX node order — TODO confirm. At least verify the counts agree.
assert latent_np.shape[0] == graph.number_of_nodes(), (
    f"{latent_np.shape[0]} latent rows vs {graph.number_of_nodes()} graph nodes"
)
node_embedding_dict = dict(zip(graph.nodes, latent_np))

In [8]:
# Persist the {HPO node ID: latent GCN embedding} mapping for downstream use.
# NOTE: unpickling can execute arbitrary code — only load this file from trusted sources.
embedding_output_path = '../data/node_embedding_dict.pkl'
with open(embedding_output_path, 'wb') as embedding_file:
    pickle.dump(node_embedding_dict, embedding_file)