In [17]:
import obonet
import torch
import pickle
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from phenodp.encoders import PSD_HPOEncoder, GCN

In [None]:
def get_average_embedding(text, tokenizer, model, device):
    """
    Compute the average embedding of a given text using the pre-trained model.

    The text is tokenized, run through the model with hidden states enabled,
    and the last hidden-state tensor is averaged over dimension 1
    (presumably the token axis — confirm against the model's output layout).

    Args:
        text (str): The input text to encode.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``;
            its result must support ``.to(device)`` and ``**`` unpacking.
        model: Callable invoked as ``model(**inputs, output_hidden_states=True)``;
            its result must expose a ``hidden_states`` sequence.
        device: Device the tokenized inputs are moved to before the forward pass.

    Returns:
        np.ndarray: The average embedding of the text, as float32.

    Raises:
        ValueError: If NaN values are detected in the embedding.
    """
    # Convert the text into model inputs
    inputs = tokenizer(text, return_tensors="pt").to(device)

    # Obtain the hidden states from the model (no gradients needed for encoding)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Extract the hidden states from the last layer
    last_hidden_states = outputs.hidden_states[-1]

    # Average over dim 1, then cast to float32 on CPU: bfloat16 tensors cannot be
    # converted to numpy directly (unsupported ScalarType error).
    average_encoding = last_hidden_states.mean(dim=1).squeeze().to(torch.float32).cpu()

    # Enforce the documented contract: a NaN embedding (e.g. from an input that
    # tokenizes to nothing, or from numeric overflow) must not be returned silently.
    if torch.isnan(average_encoding).any():
        raise ValueError(f"NaN values detected in the embedding for text: {text!r}")

    return average_encoding.numpy()

In [None]:
# Location of the pre-trained causal LM checkpoint (also holds the tokenizer files).
model_path = '../data/model/Bio-Medical-3B-CoT-Finetuned'

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the language model onto `device`; torch_dtype="auto" defers the dtype
# choice to transformers instead of hard-coding one here.
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map=device)

# The tokenizer is read from the same checkpoint directory as the model.
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [None]:
# Local path to the HPO ontology file (OBO format) that defines the graph.
url = '../data/hpo-2025-05-06/hp.obo'  
graph = obonet.read_obo(url)
# Expected length of every node embedding; checked per node below so that a
# mismatch is reported here rather than surfacing later in a downstream cell.
feature_dimension = 2048

# Process each node in the graph to compute and store its embedding
for node in tqdm(graph.nodes(), desc="Processing nodes"):
    try:
        # Use the node's 'name' attribute as the text; nodes without one fall
        # back to the placeholder string.
        hpo_term = graph.nodes[node].get('name', 'Unknown Term')
        # Compute the embedding for the node's name
        embedding = get_average_embedding(hpo_term, tokenizer, model, device)
        # Fail fast if the encoder does not produce the expected vector size.
        if embedding.shape != (feature_dimension,):
            raise ValueError(
                f"expected embedding of shape ({feature_dimension},), got {embedding.shape}"
            )
        # Store the embedding in the node's features
        graph.nodes[node]['feature'] = embedding
    except ValueError as e:
        # Report which node failed (bad embedding), then re-raise to stop the run
        print(f"Error processing node {node}: {e}")
        raise

In [None]:
# Layer sizes handed positionally to GCN as (in_feats, h_feats, out_feats):
# input and output widths both equal feature_dimension, the hidden width is 256.
feature_dimension = 2048
in_feats, h_feats, out_feats = feature_dimension, 256, feature_dimension

# Convert the networkx ontology graph into the DGL graph used for training.
dgl_graph = PSD_HPOEncoder.nx_to_dgl(graph)

# DGL graph does not implement the API for GPU, so we use CPU for computation
# NOTE: `device` and `model` are rebound here; from this point on `model` is the
# GCN, not the language model loaded in an earlier cell.
device = torch.device("cpu")
model = GCN(in_feats, h_feats, out_feats).to(device)

# Train the GCN on the graph with the masking ratios given below.
PSD_HPOEncoder.train_model(
    model,
    dgl_graph,
    epochs=500,
    lr=0.001,
    node_mask_percentage=0.2,
    edge_mask_percentage=0.2,
)

In [19]:
# Inference only: put the model in eval mode and skip gradient tracking.
model.eval()
with torch.no_grad():
    outputs, latent = model(dgl_graph, dgl_graph.ndata['feat'])

# Pair row i of `latent` with the i-th node in graph.nodes iteration order.
# NOTE(review): assumes nx_to_dgl keeps the networkx node order — confirm.
node_embedding_dict = dict(
    (node_id, latent[row].numpy()) for row, node_id in enumerate(graph.nodes)
)

In [20]:
# Serialize the node-id -> embedding mapping to disk with pickle.
with open('../data/node_embedding_dict.pkl', 'wb') as pkl_file:
    pickle.dump(node_embedding_dict, pkl_file)