In [1]:
import pandas as pd
import numpy as np
import pickle
from pyhpo import Ontology
from PCL_HPOEncoder import *
import torch
import torch.nn.functional as F
from torch import optim
import random



In [2]:
Ontology()

<pyhpo.ontology.OntologyClass at 0x7f1697240400>

In [3]:
with open('../HPO2SUM/github_project/node_embedding_dict_T5_gcn.pkl', 'rb') as f:
    node_embedding = pickle.load(f)


In [4]:
disease_dict = dict()
disease_list = list(Ontology.omim_diseases)
for d in disease_list:
    disease_dict[d.id] = [Ontology.get_hpo_object(t).id for t in list(d.hpo)]

d_count = []
disease_db = []
for i in list(disease_dict.keys()):
    if len(disease_dict[i]) >= 5:
        disease_db.append(i)


In [8]:
input_dim = 256
num_heads = 8
num_layers = 3
hidden_dim = 512
output_dim = 1
max_seq_length = 128

model = PCL_HPOEncoder(input_dim, num_heads, num_layers, hidden_dim, output_dim, max_seq_length)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [10]:
n_s = 2000
max_seq_length = 128
# num_epochs = 10
num_epochs = 2000
batch_size = 100
device = 'cuda:3'
model.to(device)
model.train()
inputs_list, mask_list = get_training_sample(disease_db, disease_dict, node_embedding, n_s)
inputs1 = inputs_list[0].to(device)
inputs2 = inputs_list[1].to(device)
masks1 = mask_list[0].to(device)
masks2 = mask_list[0].to(device)


num_batches = n_s // batch_size + (1 if n_s % batch_size != 0 else 0)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_steps = 0
    
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, n_s)
        
        inputs1_batch = inputs1[start_idx:end_idx]
        inputs2_batch = inputs2[start_idx:end_idx]
        mask1_batch = masks1[start_idx:end_idx]
        mask2_batch = masks2[start_idx:end_idx]
        cls_embedding1, emb1 = model(inputs1_batch, mask1_batch)
        cls_embedding2, emb2 = model(inputs2_batch, mask2_batch)
        
        labels = torch.tensor([1.0 if i == j else 0.0 for i in range(len(inputs1_batch)) for j in range(len(inputs2_batch))]).to(device).view(inputs1_batch.size(0), inputs2_batch.size(0))
        
        loss = info_nce_loss(cls_embedding1, cls_embedding2)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        total_steps += 1
    
    print(f'Epoch {epoch + 1}, Loss: {total_loss / total_steps}')


Epoch 1, Loss: 5.1400044679641725
Epoch 2, Loss: 5.075931429862976
Epoch 3, Loss: 5.1044677734375
Epoch 4, Loss: 5.087262034416199
Epoch 5, Loss: 5.056925201416016
Epoch 6, Loss: 5.038917374610901
Epoch 7, Loss: 5.007943749427795
Epoch 8, Loss: 4.988239336013794
Epoch 9, Loss: 4.974637484550476
Epoch 10, Loss: 4.9491067886352536


In [11]:
model.to('cpu')
torch.save(model.state_dict(), './transformer_encoder_infonce_norm.pth')