In [None]:
import re
import torch
from pyhpo import Ontology

from phenodp import PhenoDP, PhenoDP_Initial
from phenodp.encoders import PCL_HPOEncoder
from phenodp.utils import load_similarity_matrix, load_node_embeddings

In [2]:
# Load the HPO release stored under ../data. Later cells call
# Ontology.get_hpo_object(...) and Ontology.omim_diseases on the class itself,
# so this cell has to run before them (presumably the constructor populates
# shared class-level state -- TODO confirm against the pyhpo documentation).
ontology = Ontology(data_folder='../data/hpo-2025-05-06')

In [None]:
# Initialize the PhenoDP model
# The encoder hyper-parameters presumably have to match the ones the checkpoint
# loaded below was trained with, otherwise load_state_dict will reject it.
recommender = PCL_HPOEncoder(input_dim=256, num_heads=8, num_layers=3, hidden_dim=512, output_dim=1, max_seq_length=128)
pre_model = PhenoDP_Initial(ontology, ic_type='omim')
hp2d_sim_dict = load_similarity_matrix('../data/JC_sim_dict.pkl')
node_embedding = load_node_embeddings('../data/node_embedding_dict.pkl')
# map_location='cpu' lets a checkpoint that was saved from a GPU be restored on
# a machine without CUDA; load_state_dict copies the values into the existing
# parameters, so the device the encoder itself lives on is unchanged.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
recommender.load_state_dict(torch.load('../data/transformer_encoder_infoNCE.pth', map_location='cpu'))
# NOTE(review): the encoder is never switched to eval mode in this notebook --
# confirm PhenoDP does that internally if the encoder contains dropout.
phenodp = PhenoDP(pre_model, hp2d_sim_dict, node_embedding, recommender)

In [None]:
# HPO term IDs observed in the example patient; this list is the input to the
# ranker, recommender and summarizer cells below.
Patient_hps = [
    'HP:0000670',
    'HP:0004322',
    'HP:0000992',
    'HP:0001290',
    'HP:0000407',
    'HP:0000252',
    'HP:0000490',
]

In [4]:
def calculate_coefficient_of_variation(results, top_n=3):
    """
    Computes the coefficient of variation (CV) of the 'Total_Similarity' column
    over the top n rows of a given DataFrame.

    Parameters:
    results (pd.DataFrame): A DataFrame containing the 'Total_Similarity' column.
    top_n (int): The number of top rows to consider for the calculation. Default is 3.

    Returns:
    cv (float): The coefficient of variation, i.e. (std / mean) * 100, expressed
        as a percentage. Only the CV is returned; the mean and the standard
        deviation are intermediate values. The result is NaN when the CV is
        undefined: the mean is zero, or fewer than two rows are available
        (pandas' default sample standard deviation is NaN in that case).
    """
    # Extract the top n rows of the 'Total_Similarity' column
    top_scores = results.head(top_n)['Total_Similarity']

    mean = top_scores.mean()
    std = top_scores.std()

    # A zero mean makes the ratio undefined; return NaN explicitly instead of
    # dividing by zero (which emits a runtime warning and yields inf or NaN).
    if mean == 0:
        return float('nan')

    return (std / mean) * 100

def Get_Definition(hpo_list):
    """
    Collect the quoted definition text of a list of HPO terms.

    Each term is looked up with Ontology.get_hpo_object and the text between
    the first pair of double quotes in its `definition` attribute is kept.
    Terms without a definition, or whose definition contains no quoted text,
    are skipped.

    Args:
        hpo_list (list): HPO term IDs (e.g. 'HP:0000670').

    Returns:
        str: The extracted definitions joined by single spaces, in input
            order; an empty string when nothing could be extracted.
    """
    definition_list = []
    for term_id in hpo_list:
        definition = Ontology.get_hpo_object(term_id).definition
        if not definition:
            # re.search raises TypeError on None, so a term that carries no
            # definition is skipped instead of aborting the whole prompt.
            continue
        match = re.search(r'"(.*?)"', definition)
        if match:
            definition_list.append(match.group(1))
    return ' '.join(definition_list)

def generate_diagnosis_prompt(Patient_hps, results, Top_n=3, Top_Recom=2):
    """
    Generate a prompt for explaining potential symptoms to differentiate between candidate diseases.

    For each of the top candidate diseases the recommender (the notebook-level
    `phenodp` object) is asked for additional symptoms, and the disease label
    plus the names of those symptoms are embedded in the prompt text. The
    number of diseases mentioned in the prompt wording follows the number of
    candidates actually used.

    Args:
        Patient_hps (list): List of patient's observed HPO terms.
        results (pd.DataFrame): DataFrame containing candidate diseases and their details.
            Only its 'Disease' column is read here.
        Top_n (int): Number of top candidate diseases to consider. Default is 3.
        Top_Recom (int): Number of top recommended symptoms for each disease. Default is 2.

    Returns:
        str: A formatted prompt for disease differentiation.
    """
    # Spelled-out counts keep the default (three candidates) wording unchanged;
    # counts outside this table fall back to digits.
    number_words = {
        1: 'one', 2: 'two', 3: 'three', 4: 'four', 5: 'five',
        6: 'six', 7: 'seven', 8: 'eight', 9: 'nine', 10: 'ten',
    }

    # Get observed symptoms
    observed_text = Get_Definition(Patient_hps)

    # Get top candidate diseases
    candidate_diseases = results.head(Top_n)['Disease'].values
    candidate_count = len(candidate_diseases)
    count_word = number_words.get(candidate_count, str(candidate_count))
    count_heading = count_word.capitalize()

    # Build the id -> name lookup once instead of rescanning every OMIM
    # disease for each candidate.
    omim_names = {disease.id: disease.name for disease in Ontology.omim_diseases}

    diseases_list = []
    txt_inputs = []

    # Process each candidate disease
    for rank, disease_id in enumerate(candidate_diseases, start=1):
        # Get recommended symptoms for the disease
        recom = phenodp.run_Recommender(Patient_hps, target_disease=disease_id, candidate_diseases=candidate_diseases)
        symptom_names = [Ontology.get_hpo_object(hp).name for hp in recom.head(Top_Recom).hp.values]

        # Every candidate gets its own label, even when its id is not among
        # Ontology.omim_diseases; otherwise the symptoms below would be
        # attached to the previous disease's label (or indexing would fail on
        # the first candidate).
        label = str(rank) + '. [OMIM:' + str(disease_id) + ']'
        disease_name = omim_names.get(disease_id)
        if disease_name is not None:
            label += ' ' + disease_name
        diseases_list.append(label)

        # Format disease and symptoms for txt_inputs
        txt_inputs.append(label + ' : ' + ', '.join(symptom_names))

    # Format diseases_list and txt_inputs as strings
    diseases_list_str = "\n".join(diseases_list)
    txt_inputs_str = "\n".join(txt_inputs)

    # Generate the prompt
    prompt = f"""
Assume you are an experienced clinical physician. Below is a patient’s symptom description using HPO (Human Phenotype Ontology) terms, along with {count_word} candidate diagnoses. To further differentiate between these diagnoses, the physician has provided potential symptoms that the patient does not currently exhibit but could help clarify or confirm the diagnosis. Your task is to explain why these potential symptoms are critical for distinguishing between the {count_word} diseases.  

**Patient’s Symptom Description**:  
{observed_text}  

**{count_heading} Most Likely Disease Diagnoses**:  
{diseases_list_str}  

**Potential Symptoms for Further Differentiation**:  
{txt_inputs_str}  

**Instructions**:  
1. **Explain Potential Symptoms**: Provide a clear and concise rationale for why the listed potential symptoms are critical for distinguishing between the {count_word} diseases. Focus on how these symptoms are specific to or more prevalent in one disease compared to the others.  
2. **Do Not Diagnose**: Do not make any new diagnoses or suggest additional diseases. Your response should focus solely on explaining the potential symptoms for differentiation.  
3. **Length and Style**: The report should be approximately 200–300 words in length, written in a professional and authentic tone that mimics a human expert.  
4. **No References**: Do not include any references in the report.   
"""
    return prompt

### Ranker

In [None]:
# Run the ranker on the patient's HPO terms. The returned frame is used below
# through its 'Total_Similarity' and 'Disease' columns.
results = phenodp.run_Ranker(Patient_hps)
# Shown as the cell output: coefficient of variation (in percent) of the
# 'Total_Similarity' scores of the top 3 rows.
calculate_coefficient_of_variation(results, top_n=3)

### Recommender

In [None]:
# Ask the recommender for further phenotypes for one target disease, given the
# patient's terms and a fixed list of three candidate disease IDs (presumably
# OMIM IDs, since the prompt builder labels ranker disease IDs as 'OMIM:<id>').
# The result is displayed as the cell output; the prompt builder reads its
# `hp` column.
recommendations = phenodp.run_Recommender(
    given_hps=Patient_hps,
    target_disease=216400,
    candidate_diseases=[216400, 278760, 133540]
)
recommendations

### Summarizer

In [None]:
# Build the prompt from the top 5 ranked diseases, with up to 3 recommended
# symptoms per disease, and show it in the cell output.
prompt = generate_diagnosis_prompt(Patient_hps, results, Top_n=5, Top_Recom=3)
print(prompt)

# Save the prompt as a txt file
with open('../data/case_report_prompt.txt', 'w', encoding='utf-8') as f:
    f.write(prompt)