In [1]:
import pickle
import random

import numpy as np
import obonet
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from phenodp.encoders import PSD_HPOEncoder, GCN

In [2]:
def get_average_embedding(text, tokenizer, model, device):
    """
    Compute the mean-pooled last-layer hidden state of ``text`` under ``model``.

    Args:
        text (str): The input text to encode.
        tokenizer: Tokenizer paired with ``model``; it is called with
            ``return_tensors="pt"`` and its output is moved to ``device``.
        model: Model called with ``output_hidden_states=True``; the last entry
            of its ``hidden_states`` is averaged over dim 1 (tokens).
        device: Device the tokenized inputs are moved to (should match ``model``).

    Returns:
        np.ndarray: float32 vector, the mean of the last hidden layer over tokens.

    Raises:
        ValueError: If NaN values are detected in the embedding.
    """
    # Convert the text into model inputs
    inputs = tokenizer(text, return_tensors="pt").to(device)

    # Inference only: no autograd graph needed
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # One tensor per layer; [-1] is the last layer.
    # NOTE(review): assumed (batch, seq_len, hidden) per the usual HF convention.
    last_hidden_states = outputs.hidden_states[-1]

    # Average over tokens (dim=1). Squeeze only the batch dim so that a size-1
    # hidden dimension is never collapsed along with it.
    # numpy has no bfloat16, so cast to float32 before converting.
    embedding = last_hidden_states.mean(dim=1).squeeze(0).to(torch.float32).cpu()

    # Enforce the documented contract; the embedding loop stops on this error.
    if torch.isnan(embedding).any():
        raise ValueError(f"NaN detected in embedding for text {text!r}")

    return embedding.numpy()

In [3]:
# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model_path = '../data/model/Bio-Medical-3B-CoT-Finetuned'

# Tokenizer and causal LM are both loaded from the same local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# torch_dtype="auto" keeps the dtype stored in the checkpoint;
# device_map places the weights directly on the chosen device.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device,
    torch_dtype="auto",
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
url = '../data/hpo-2025-05-06/hp.obo'
graph = obonet.read_obo(url)
# Expected embedding width. NOTE(review): assumed to equal the LLM's last
# hidden-layer size; the check below enforces it per node.
feature_dimension = 2048

# Embed every HPO term's name and store it on the node as its 'feature' attribute.
for node, attrs in tqdm(graph.nodes(data=True), desc="Processing nodes"):
    # Nodes lacking a 'name' are embedded from the placeholder text 'Unknown Term'.
    hpo_term = attrs.get('name', 'Unknown Term')
    try:
        embedding = get_average_embedding(hpo_term, tokenizer, model, device)
    except ValueError as e:
        # Report which node failed, then re-raise to stop the run.
        print(f"Error processing node {node}: {e}")
        raise
    # Fail fast on a width mismatch instead of discovering it during GCN training.
    if embedding.shape != (feature_dimension,):
        raise ValueError(
            f"Embedding for node {node} has shape {embedding.shape}, "
            f"expected ({feature_dimension},)"
        )
    attrs['feature'] = embedding

Processing nodes: 100%|██████████| 19177/19177 [28:43<00:00, 11.12it/s]


In [5]:
dgl_graph = PSD_HPOEncoder.nx_to_dgl(graph)

# GCN widths: input and output at the node-feature width (`feature_dimension`,
# defined in the embedding cell), hidden layer of 256.
in_feats = feature_dimension
h_feats = 256
out_feats = feature_dimension

# Training is stochastic (weight init and, presumably, the random node/edge
# masking), so seed every RNG it might draw from to make re-runs reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Train on CPU. A separate name keeps the LLM's `device` from the loading cell intact.
# NOTE(review): whether DGL can run on GPU depends on the installed DGL build;
# confirm before switching this to CUDA.
gcn_device = torch.device("cpu")
# NOTE: this rebinds `model` (previously the LLM) to the GCN. Re-run the
# model-loading cell before re-running the embedding cell.
model = GCN(in_feats, h_feats, out_feats).to(gcn_device)

PSD_HPOEncoder.train_model(
    model,
    dgl_graph,
    epochs=500,
    lr=0.001,
    node_mask_percentage=0.2,
    edge_mask_percentage=0.2,
)

  dgl_graph.ndata['feat'] = torch.tensor(features, dtype=torch.float32)


Epoch 0, Loss: 4.382920742034912
Epoch 10, Loss: 3.642106056213379
Epoch 20, Loss: 3.432603120803833
Epoch 30, Loss: 3.311546564102173
Epoch 40, Loss: 3.289074420928955
Epoch 50, Loss: 3.2476754188537598
Epoch 60, Loss: 3.2304956912994385
Epoch 70, Loss: 3.117399215698242
Epoch 80, Loss: 3.140692710876465
Epoch 90, Loss: 3.0696816444396973
Epoch 100, Loss: 2.9765782356262207
Epoch 110, Loss: 2.9778101444244385
Epoch 120, Loss: 2.969238519668579
Epoch 130, Loss: 2.8587207794189453
Epoch 140, Loss: 2.804210901260376
Epoch 150, Loss: 2.733131170272827
Epoch 160, Loss: 2.6965548992156982
Epoch 170, Loss: 2.591127872467041
Epoch 180, Loss: 2.546566963195801
Epoch 190, Loss: 2.526063919067383
Epoch 200, Loss: 2.4742660522460938
Epoch 210, Loss: 2.3950393199920654
Epoch 220, Loss: 2.388521671295166
Epoch 230, Loss: 2.298631429672241
Epoch 240, Loss: 2.2302441596984863
Epoch 250, Loss: 2.2321877479553223
Epoch 260, Loss: 2.1530818939208984
Epoch 270, Loss: 2.139406442642212
Epoch 280, Loss: 2.

In [6]:
# Inference mode, no gradients: run the trained GCN once over the whole graph.
model.eval()
with torch.no_grad():
    outputs, latent = model(dgl_graph, dgl_graph.ndata['feat'])

# Row i of `latent` is paired with the i-th node of `graph`.
# NOTE(review): this assumes nx_to_dgl preserves networkx node order; confirm in phenodp.
# At minimum the counts must agree, otherwise extra rows would be silently dropped.
assert latent.shape[0] == graph.number_of_nodes(), (
    f"latent has {latent.shape[0]} rows but graph has {graph.number_of_nodes()} nodes"
)
# Convert once (and via .cpu() so this also works if the GCN is ever moved to GPU).
latent_np = latent.cpu().numpy()
node_embedding_dict = dict(zip(graph.nodes, latent_np))

In [7]:
# Persist the node -> latent-vector mapping built above as a pickle file.
embedding_out_path = '../data/node_embedding_dict.pkl'
with open(embedding_out_path, mode='wb') as out_file:
    pickle.dump(node_embedding_dict, out_file)