In [1]:
import torch
from torch import optim
import torch.nn.functional as F

import pickle
from pyhpo import Ontology
from phenodp.encoders import PCL_HPOEncoder, get_training_sample, info_nce_loss

In [2]:
ontology = Ontology(data_folder='../data/hpo-2025-05-06')

with open('../data/node_embedding_dict.pkl', 'rb') as f:
    node_embedding = pickle.load(f)

In [3]:
disease_dict = dict()
disease_list = list(ontology.omim_diseases)
hps_list = (node_embedding.keys())
for d in disease_list:
    disease_dict[d.id] = [ontology.get_hpo_object(t).id for t in list(d.hpo) if ontology.get_hpo_object(t).id in hps_list]

d_count = []
disease_db = []
for i in list(disease_dict.keys()):
    if len(disease_dict[i]) >= 5:
        disease_db.append(i)


In [4]:
input_dim = 256
num_heads = 8
num_layers = 3
hidden_dim = 512
output_dim = 1
max_seq_length = 128

model = PCL_HPOEncoder(input_dim=input_dim, num_heads=num_heads, num_layers=num_layers, hidden_dim=hidden_dim, output_dim=output_dim, max_seq_length=max_seq_length)
optimizer = optim.Adam(model.parameters(), lr=1e-3)



In [5]:
# Training configuration
n_samples = 2000
max_seq_length = 128
num_epochs = 10
batch_size = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move model to device and set to training mode
model.to(device)
model.train()

# Prepare training data
inputs_list, mask_list = get_training_sample(disease_db, disease_dict, node_embedding, n_samples)
inputs1 = inputs_list[0].to(device)
inputs2 = inputs_list[1].to(device)
masks1 = mask_list[0].to(device)
masks2 = mask_list[1].to(device)

# Calculate number of batches
num_batches = n_samples // batch_size + (1 if n_samples % batch_size != 0 else 0)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_steps = 0
    
    for batch_idx in range(num_batches):
        # Get batch indices
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, n_samples)
        
        # Prepare batch data
        inputs1_batch = inputs1[start_idx:end_idx]
        inputs2_batch = inputs2[start_idx:end_idx]
        mask1_batch = masks1[start_idx:end_idx].float()
        mask2_batch = masks2[start_idx:end_idx].float()
        
        # Forward pass
        cls_embedding1, emb1 = model(inputs1_batch, mask1_batch)
        cls_embedding2, emb2 = model(inputs2_batch, mask2_batch)
        
        # Compute loss
        loss = info_nce_loss(cls_embedding1, cls_embedding2)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track loss
        total_loss += loss.item()
        total_steps += 1
    
    # Print epoch results
    avg_loss = total_loss / total_steps
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}')



Epoch 1/10, Loss: 4.6339
Epoch 2/10, Loss: 4.1799
Epoch 3/10, Loss: 3.9024
Epoch 4/10, Loss: 3.6616
Epoch 5/10, Loss: 3.4464
Epoch 6/10, Loss: 3.2853
Epoch 7/10, Loss: 3.1323
Epoch 8/10, Loss: 2.9392
Epoch 9/10, Loss: 2.7817
Epoch 10/10, Loss: 2.6304


In [6]:
# Save model
model.to('cpu')
torch.save(model.state_dict(), '../data/transformer_encoder_infoNCE.pth')