In [None]:

import os
import sys

# Make the repository root (the parent of the notebook's directory) both
# importable and the working directory, so that project imports and relative
# data paths used by later cells resolve.
#
# The starting directory is remembered the first time this cell runs.
# Without that, re-running the cell would call os.getcwd() *after* the
# previous os.chdir() and climb one more directory level on every run.
current_dir = globals().setdefault("_NOTEBOOK_START_DIR", os.getcwd())
parent_dir = os.path.dirname(current_dir)

# Keep the parent directory first on sys.path without stacking duplicates
# when the cell is re-run.
if sys.path[:1] != [parent_dir]:
    sys.path.insert(0, parent_dir)

# Set the parent directory as the current directory
os.chdir(parent_dir)

In [None]:
from utils.data import read_json_file, print_json_structure

# Load the dataset; each entry read below is indexed by a string id and
# carries a "clinical_text" string plus a list of "phenotypes" dicts.
hpo_data = read_json_file('data/dataset/mine_hpo.json')
print(len(hpo_data))

# Sample 5 texts and their ground-truth phenotype names for testing.
ids = ["53", "54", "55", "56", "57"]
texts = []
ground_truths = []
for sample_id in ids:  # not `id`, which would shadow the builtin
    sample_text = hpo_data[sample_id]["clinical_text"]
    texts.append(sample_text)
    truth = hpo_data[sample_id]["phenotypes"]
    ground_truth = [item["phenotype_name"] for item in truth]
    ground_truths.append(ground_truth)

    # Sanity check: how many ground-truth names occur verbatim
    # (case-sensitive substring match) in the clinical text, compared with
    # the total number of annotated phenotypes for this sample.
    sanity_check_list = [name for name in ground_truth if name in sample_text]
    print("Pairwise checks:")
    print(len(sanity_check_list))
    print(len(truth))

# Single text used by the later cells to benchmark extractor runtime.
text = hpo_data["53"]["clinical_text"]


In [None]:
def load_mistral_llm_client(
    device="cuda:0",
    cache_dir="/u/zelalae2/scratch/rdma_cache",
    temperature=0.0001,
):
    """
    Build a LocalLLMClient that requests the "mistral_24b" model type.

    The defaults reproduce the previously hard-coded configuration, so
    calling the function without arguments behaves exactly as before.

    Args:
        device: Device string handed to the client (default "cuda:0").
        cache_dir: Cache directory handed to the client. The default is a
            user-specific absolute path; override it when running on a
            different machine or account.
        temperature: Sampling temperature handed to the client.

    Returns:
        LocalLLMClient: The newly constructed client.
    """
    # Imported lazily so that defining this function does not require the
    # project package to be importable yet.
    from utils.llm_client import LocalLLMClient

    return LocalLLMClient(
        model_type="mistral_24b",  # Explicitly request the mistral_24b model
        device=device,
        cache_dir=cache_dir,
        temperature=temperature,
    )

# Build the LLM client once; later cells reuse this module-level object
# (the extractors are constructed with it and its temperature is adjusted).
llm_client = load_mistral_llm_client()

In [None]:
from hporag.entity import RetrievalEnhancedEntityExtractor, RetrievalEnhancedEntityExtractorV2
import json
from hporag.entity import BaseEntityExtractor
from utils.embedding import EmbeddingsManager
import numpy as np

# Load system prompts. The encoding is stated explicitly so decoding does not
# depend on the platform's default locale.
with open('data/prompts/system_prompts.json', 'r', encoding='utf-8') as f:
    prompts = json.load(f)
system_message_extraction = prompts.get("system_message_I", "")
system_message_verification = prompts.get("system_message_II", "")

# Initialize embedding manager with MedEmbed using sentence transformers.
# It is placed on cuda:1, while the LLM client was created on cuda:0.
embedding_manager = EmbeddingsManager(
    model_type="sentence_transformer",
    model_name="abhinand/MedEmbed-small-v0.1",  # Medical-domain sentence transformer model
    device="cuda:1"
)

# Load the pre-computed embeddings.
# NOTE(review): allow_pickle=True makes np.load unpickle Python objects, which
# can execute arbitrary code. Only load this file from a trusted source.
embedded_documents = np.load('data/vector_stores/G2GHPO_metadata_medembed.npy', allow_pickle=True)

# Override the client's temperature for extraction.
# NOTE(review): 0.001 is *higher* than the 0.0001 the client was created with,
# so this raises the temperature rather than lowering it. Confirm which value
# is intended (and that the client reads this attribute at generation time).
llm_client.temperature = 0.001

# Arguments shared by both extractor variants, so the two stay configured
# identically and the comparison in later cells is like-for-like.
shared_extractor_kwargs = dict(
    llm_client=llm_client,
    embedding_manager=embedding_manager,
    embedded_documents=embedded_documents,
    system_message=system_message_extraction,
    top_k=5,
)

# Initialize the retrieval-enhanced extractor (V1).
retrieval_extractor = RetrievalEnhancedEntityExtractor(**shared_extractor_kwargs)

# V2 takes the same configuration plus a maximum batch size.
retrieval_extractor_v2 = RetrievalEnhancedEntityExtractorV2(
    max_batch_size=32,
    **shared_extractor_kwargs,
)

Runtime: 1m 24s

In [None]:
# Run the V1 retrieval-enhanced extractor on the single benchmark text.
retr_v1 = retrieval_extractor.extract_entities(text)

Runtime with max_batch_size = 48: 1m 2s

In [None]:
# Run the V2 (batched) retrieval-enhanced extractor on the same benchmark text.
retr_v2 = retrieval_extractor_v2.extract_entities(text)

In [None]:
# Result sizes of the two extractors, V1 first, one per line.
print(len(retr_v1), len(retr_v2), sep="\n")

In [None]:
def intersection_method1(list1, list2):
    """Return the distinct elements present in both inputs.

    Both inputs are converted to sets, so duplicates collapse and the order
    of the returned list is arbitrary. Elements must be hashable.
    """
    common = set(list1).intersection(set(list2))
    return list(common)
# Number of distinct entities returned by both extractors.
print(len(intersection_method1(retr_v1, retr_v2)))

In [None]:
def disjoint_elements(list1, list2):
    """Return the distinct elements that occur in exactly one of the inputs.

    This is the symmetric difference of the two inputs taken as sets:
    duplicates collapse and the order of the returned list is arbitrary.
    Elements must be hashable.
    """
    only_in_one = set(list1).symmetric_difference(set(list2))
    return list(only_in_one)

# Worked example: values that appear in exactly one of the two lists.
a = [1, 2, 3, 4, 5]
b = [4, 5, 6, 7, 8]
assert sorted(disjoint_elements(a, b)) == [1, 2, 3, 6, 7, 8]

# Entities reported by only one of the two extractors (order is arbitrary
# because the result comes from a set).
print(disjoint_elements(retr_v1, retr_v2))