In [1]:
import torch
from torch import optim
import torch.nn.functional as F

import pickle
from pyhpo import Ontology
from phenodp.encoders import PCL_HPOEncoder, get_training_sample, info_nce_loss

In [2]:
ontology = Ontology(data_folder='../data/hpo-2025-05-06')

with open('../data/node_embedding_dict.pkl', 'rb') as f:
    node_embedding = pickle.load(f)

In [3]:
disease_dict = dict()
disease_list = list(ontology.omim_diseases)
hps_list = (node_embedding.keys())
for d in disease_list:
    disease_dict[d.id] = [ontology.get_hpo_object(t).id for t in list(d.hpo) if ontology.get_hpo_object(t).id in hps_list]

d_count = []
disease_db = []
for i in list(disease_dict.keys()):
    if len(disease_dict[i]) >= 5:
        disease_db.append(i)


In [4]:
input_dim = 256
num_heads = 8
num_layers = 3
hidden_dim = 512
output_dim = 1
max_seq_length = 128

model = PCL_HPOEncoder(input_dim=input_dim, num_heads=num_heads, num_layers=num_layers, hidden_dim=hidden_dim, output_dim=output_dim, max_seq_length=max_seq_length)
optimizer = optim.Adam(model.parameters(), lr=1e-3)



In [None]:
# Training configuration
n_samples = 2000
max_seq_length = 128
num_epochs = 200
batch_size = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move model to device and set to training mode
model.to(device)
model.train()

# Prepare training data
inputs_list, mask_list = get_training_sample(disease_db, disease_dict, node_embedding, n_samples)
inputs1 = inputs_list[0].to(device)
inputs2 = inputs_list[1].to(device)
masks1 = mask_list[0].to(device)
masks2 = mask_list[1].to(device)

# Calculate number of batches
num_batches = n_samples // batch_size + (1 if n_samples % batch_size != 0 else 0)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_steps = 0
    
    for batch_idx in range(num_batches):
        # Get batch indices
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, n_samples)
        
        # Prepare batch data
        inputs1_batch = inputs1[start_idx:end_idx]
        inputs2_batch = inputs2[start_idx:end_idx]
        mask1_batch = masks1[start_idx:end_idx].float()
        mask2_batch = masks2[start_idx:end_idx].float()
        
        # Forward pass
        cls_embedding1, emb1 = model(inputs1_batch, mask1_batch)
        cls_embedding2, emb2 = model(inputs2_batch, mask2_batch)
        
        # Compute loss
        loss = info_nce_loss(cls_embedding1, cls_embedding2)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track loss
        total_loss += loss.item()
        total_steps += 1
    
    # Print epoch results
    avg_loss = total_loss / total_steps
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}')



Epoch 1/200, Loss: 4.6314
Epoch 2/200, Loss: 4.1975
Epoch 3/200, Loss: 3.9316
Epoch 4/200, Loss: 3.6773
Epoch 5/200, Loss: 3.4620
Epoch 6/200, Loss: 3.2744
Epoch 7/200, Loss: 3.1047
Epoch 8/200, Loss: 2.9351
Epoch 9/200, Loss: 2.7830
Epoch 10/200, Loss: 2.6486
Epoch 11/200, Loss: 2.5288
Epoch 12/200, Loss: 2.3874
Epoch 13/200, Loss: 2.2671
Epoch 14/200, Loss: 2.1567
Epoch 15/200, Loss: 2.0143
Epoch 16/200, Loss: 1.9517
Epoch 17/200, Loss: 1.8623
Epoch 18/200, Loss: 1.7854
Epoch 19/200, Loss: 1.7337
Epoch 20/200, Loss: 1.6850
Epoch 21/200, Loss: 1.6446
Epoch 22/200, Loss: 1.5920
Epoch 23/200, Loss: 1.5442
Epoch 24/200, Loss: 1.4915
Epoch 25/200, Loss: 1.4448
Epoch 26/200, Loss: 1.4229
Epoch 27/200, Loss: 1.3995
Epoch 28/200, Loss: 1.3801
Epoch 29/200, Loss: 1.3447
Epoch 30/200, Loss: 1.3016
Epoch 31/200, Loss: 1.2866
Epoch 32/200, Loss: 1.2587
Epoch 33/200, Loss: 1.2327
Epoch 34/200, Loss: 1.1918
Epoch 35/200, Loss: 1.1749
Epoch 36/200, Loss: 1.1506
Epoch 37/200, Loss: 1.1344
Epoch 38/2

In [6]:
# Save model
model.to('cpu')
torch.save(model.state_dict(), '../data/transformer_encoder_infoNCE.pth')