In [None]:

import os
import sys
# Make the project (parent) directory importable and use it as the working directory,
# so that relative paths such as "data/dataset/..." resolve from there.
# Resolve it only once per kernel session: without this guard, re-running the cell would
# take dirname() of the *already changed* cwd and keep climbing the directory tree.
if "parent_dir" not in globals():
    current_dir = os.getcwd()
    parent_dir = os.path.dirname(current_dir)
# Avoid piling duplicate entries onto sys.path on re-runs.
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
# Set the parent directory as the current directory
os.chdir(parent_dir)

In [None]:
# load the dataset 
from utils.data import read_json_file, print_json_structure

dataset_path = "data/dataset/filtered_rd_annos_updated_adam.json"
dataset = read_json_file(dataset_path)
print_json_structure(dataset)

# Pick one record and pull out its clinical note text for the extraction run below.
sample_record = dataset["287"]
text = sample_record["note_details"]["text"]


In [None]:
def load_mistral_llm_client(model_type="mistral_24b", device="cuda:0",
                            cache_dir=None, temperature=0.1):
    """
    Build a LocalLLMClient, by default for the Mistral 24B model on cuda:0.

    All arguments are optional; calling with no arguments reproduces the previous
    hard-coded configuration.

    Args:
        model_type: Model identifier passed to LocalLLMClient (default "mistral_24b").
        device: Torch device string to place the model on (default "cuda:0").
        cache_dir: Model cache directory. When None, the RDMA_CACHE_DIR environment
            variable is used if set, otherwise the legacy default path.
        temperature: Sampling temperature passed to the client (default 0.1).

    Returns:
        LocalLLMClient: Initialized LLM client.
    """
    import os
    from utils.llm_client import LocalLLMClient

    # Legacy default (originally copied from mine_hpo.py). It is an absolute, per-user path,
    # so allow other users to override it without editing the notebook.
    if cache_dir is None:
        cache_dir = os.environ.get("RDMA_CACHE_DIR", "/u/zelalae2/scratch/rdma_cache")

    return LocalLLMClient(
        model_type=model_type,
        device=device,
        cache_dir=cache_dir,
        temperature=temperature,
    )

# Instantiate the LLM client once; the extractor constructed later in the notebook receives it.
llm_client = load_mistral_llm_client()

In [None]:
from abc import ABC, abstractmethod
import json
import pandas as pd
import numpy as np
import re
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from fuzzywuzzy import fuzz
from hporag.context import ContextExtractor
from rdrag.entity import BaseRDExtractor

class RetrievalEnhancedRDExtractor(BaseRDExtractor):
    """
    Entity extractor that uses embedding retrieval to enhance LLM-based extraction.

    For each sentence (or merged chunk of sentences):
    1. Retrieves relevant rare disease terms to provide context
    2. Uses this context to prompt the LLM for more accurate disease extraction

    This approach helps the LLM with domain knowledge before extraction begins.

    Features:
    - Optionally merges small sentences to ensure minimum chunk size for processing
    - Optional verbose mode that prints each prompt and raw LLM response for debugging
    """

    # Number of nearest neighbours requested from the index before de-duplicating down to
    # top_k. A wide pool is searched because several hits may share the same name/id.
    search_pool_size: int = 800

    def __init__(self, llm_client, embedding_manager, embedded_documents,
                 system_message: str, top_k: int = 10, min_sentence_size: Optional[int] = None,
                 verbose: bool = False):
        """
        Initialize the retrieval-enhanced rare disease extractor.

        Args:
            llm_client: Client for querying the language model (must expose `query(prompt, system_message)`)
            embedding_manager: Manager for vector embeddings and search
            embedded_documents: Rare disease terms with embeddings (indexable, supports len())
            system_message: System message for LLM extraction
            top_k: Number of top candidates to retrieve per sentence
            min_sentence_size: Minimum character length for sentences (smaller ones will be merged)
            verbose: When True, print chunking statistics and every prompt/response pair
        """
        self.llm_client = llm_client
        self.embedding_manager = embedding_manager
        self.embedded_documents = embedded_documents
        self.index = None  # built lazily by prepare_index()
        self.system_message = system_message
        self.top_k = top_k
        self.min_sentence_size = min_sentence_size
        self.verbose = verbose
        self.context_extractor = ContextExtractor()

    def prepare_index(self):
        """Prepare FAISS index from embedded documents if not already prepared."""
        if self.index is None:
            embeddings_array = self.embedding_manager.prepare_embeddings(self.embedded_documents)
            self.index = self.embedding_manager.create_index(embeddings_array)

    def _retrieve_candidates(self, sentence: str) -> List[Dict]:
        """
        Retrieve relevant rare disease candidates for a sentence.

        Args:
            sentence: Clinical sentence to find relevant rare disease terms for

        Returns:
            Up to top_k dictionaries with keys 'name', 'id', 'definition' and
            'similarity_score', de-duplicated on (name, id). Empty when there are
            no embedded documents.
        """
        n_documents = len(self.embedded_documents)
        if n_documents == 0:
            return []

        self.prepare_index()

        # Embed the query as a single-row matrix for the index search.
        query_vector = self.embedding_manager.query_text(sentence).reshape(1, -1)

        # Never ask for fewer neighbours than top_k, nor more than the index holds.
        pool_size = min(max(self.search_pool_size, self.top_k), n_documents)
        distances, indices = self.embedding_manager.search(query_vector, self.index, k=pool_size)

        candidates = []
        seen_metadata = set()

        for idx, distance in zip(indices[0], distances[0]):
            # FAISS pads results with -1 when fewer than k neighbours are found; a negative
            # index would otherwise silently wrap around to the end of the document array.
            if idx < 0:
                continue
            try:
                # Access document directly since it already has name, id, definition
                document = self.embedded_documents[idx]

                # Create an identifier for deduplication
                metadata_id = f"{document.get('name', '')}-{document.get('id', '')}"
                if metadata_id in seen_metadata:
                    continue
                seen_metadata.add(metadata_id)

                candidates.append({
                    'name': document.get('name', ''),
                    'id': document.get('id', ''),
                    'definition': document.get('definition', ''),
                    # Monotone distance -> similarity mapping; presumably an L2-style
                    # distance (smaller is closer) — confirm against embedding_manager.search.
                    'similarity_score': 1.0 / (1.0 + distance)
                })

                if len(candidates) >= self.top_k:
                    break
            except Exception as e:
                # Best-effort: skip malformed documents rather than aborting the whole note.
                print(f"Error processing metadata at index {idx}: {e}")
                continue

        return candidates

    def _create_enhanced_prompt(self, sentence: str, candidates: List[Dict]) -> str:
        """
        Create a prompt enhanced with retrieved candidates.

        Args:
            sentence: Clinical sentence to extract rare diseases from
            candidates: Retrieved rare disease candidates to use as context

        Returns:
            Formatted prompt for LLM
        """
        # One bullet per retrieved term: "- <name> (ID: <id>)"
        context_text = "\n".join(
            f"- {candidate['name']} (ID: {candidate['id']})" for candidate in candidates
        )

        prompt = (
            f"I have CLINICAL TEXT: \"{sentence}\"\n\n"
            f"Here are some relevant rare disease terms for reference that may help you find rare disease mentions in the sentence:\n\n"
            f"{context_text}\n\n"
            f"Based on this sentence and the provided rare disease terms as reference, extract all potential disease mentions "
            f"that are NOT negated (i.e., NOT preceded by 'no', 'not', 'without', 'ruled out', etc.). "
            f"Please also include any potential abbreviations that might be referring to rare diseases in the CLINICAL TEXT."
            f"\n\nReturn only a Python list of strings, with each disease exactly as it appears in the CLINICAL TEXT. "
            f"Ensure the output is concise without any additional notes, commentary, or meta explanations."
        )

        return prompt

    def _merge_small_sentences(self, sentences: List[str], min_size: int) -> List[str]:
        """
        Merge sentences smaller than the minimum size with subsequent sentences.

        Args:
            sentences: List of extracted sentences
            min_size: Minimum character length for a sentence (None or <= 0 disables merging)

        Returns:
            List of merged sentences meeting the minimum size requirement
            (the final chunk may still be shorter if the input runs out).
        """
        if not sentences:
            return []

        if min_size is None or min_size <= 0:
            return sentences

        merged_sentences = []
        current_idx = 0

        while current_idx < len(sentences):
            current_sentence = sentences[current_idx]

            # If the current sentence is already large enough, add it directly
            if len(current_sentence) >= min_size:
                merged_sentences.append(current_sentence)
                current_idx += 1
                continue

            # Start merging with next sentences until we reach min_size
            merged_chunk = current_sentence
            next_idx = current_idx + 1

            while next_idx < len(sentences) and len(merged_chunk) < min_size:
                # Join with a space only when both sides are non-empty
                if merged_chunk and sentences[next_idx]:
                    merged_chunk += " " + sentences[next_idx]
                else:
                    merged_chunk += sentences[next_idx]
                next_idx += 1

            merged_sentences.append(merged_chunk)

            # Continue after the sentences consumed by this chunk
            current_idx = next_idx

        return merged_sentences

    def extract_entities(self, text: str) -> List[str]:
        """
        Extract rare disease mentions from text using retrieval-enhanced prompting.
        With sentence merging for efficiency when min_sentence_size is set.

        Args:
            text: Clinical text to extract rare disease mentions from

        Returns:
            List of extracted rare disease mentions, de-duplicated case-insensitively
            while preserving first-seen order and casing.
        """
        original_sentences = self.context_extractor.extract_sentences(text)

        # Merge small sentences if min_sentence_size is set
        if self.min_sentence_size:
            sentences = self._merge_small_sentences(original_sentences, self.min_sentence_size)
            if self.verbose:
                print(f"After merging: Processing {len(sentences)} chunks instead of {len(original_sentences)} raw sentences")
        else:
            sentences = original_sentences

        all_entities = []

        for sentence in sentences:
            # Skip empty or very short sentences
            if not sentence or len(sentence) < 5:
                continue

            candidates = self._retrieve_candidates(sentence)
            prompt = self._create_enhanced_prompt(sentence, candidates)
            findings_text = self.llm_client.query(prompt, self.system_message)

            # Full prompt/response dump is opt-in; it floods output on long notes.
            if self.verbose:
                print("--------- DEBUG ---------")
                print(prompt)
                print()
                print(findings_text)
                print("------------------------")

            all_entities.extend(self._extract_findings_from_response(findings_text))

        return self._deduplicate_case_insensitive(all_entities)

    @staticmethod
    def _deduplicate_case_insensitive(entities: List[str]) -> List[str]:
        """Drop empty strings and case-insensitive duplicates, keeping first occurrences in order."""
        unique_entities = []
        seen = set()
        for entity in entities:
            if not entity:
                continue
            key = entity.lower()
            if key not in seen:
                seen.add(key)
                unique_entities.append(entity)
        return unique_entities

    def _extract_findings_from_response(self, response: str) -> List[str]:
        """
        Parse LLM response to extract findings.

        The bracketed part of the response is first parsed as a Python literal, so
        items that themselves contain commas survive intact. If that fails (or the
        literal is not a list of strings), the bracket contents — or the whole
        response when there are no brackets — are split on commas instead.

        Args:
            response: Raw LLM response text

        Returns:
            List of extracted entities (empty on empty/unparseable input)
        """
        import ast

        if not response:
            return []

        try:
            start, end = response.find('['), response.rfind(']')
            if start != -1 and end > start:
                bracketed = response[start:end + 1]
                # ast.literal_eval only evaluates literals, so it is safe on model output.
                try:
                    parsed = ast.literal_eval(bracketed)
                except (ValueError, SyntaxError, MemoryError, RecursionError):
                    parsed = None
                if isinstance(parsed, (list, tuple)) and all(isinstance(item, str) for item in parsed):
                    return [item.strip() for item in parsed if item.strip()]
                # Not a clean literal: fall back to splitting the bracket contents.
                response = bracketed[1:-1]

            findings = []
            for term in response.split(','):
                cleaned_term = term.strip().strip('"\'')
                if cleaned_term:
                    findings.append(cleaned_term)
            return findings
        except Exception as e:
            print(f"Error parsing LLM response: {str(e)}")
            return []

    def process_batch(self, texts: List[str]) -> List[List[str]]:
        """
        Process a batch of texts for rare disease extraction.

        Args:
            texts: List of clinical texts to process

        Returns:
            List of lists containing extracted rare disease mentions for each text
        """
        return [self.extract_entities(text) for text in texts]
    
from utils.embedding import EmbeddingsManager
# Sentence-transformer used to embed both the query chunks and the stored rare disease terms.
embedding_manager = EmbeddingsManager(
    model_type="sentence_transformer",
    model_name="abhinand/MedEmbed-small-v0.1",  # Medical-domain sentence transformer model
    device="cuda:0"
)

# Pre-embedded rare disease vocabulary. Prefer the copy relative to the working directory
# (set by os.chdir in the first cell); fall back to the original absolute location, which
# only exists on one user's machine.
EMBEDDINGS_PATH = "data/vector_stores/rd_orpha_medembed.npy"
LEGACY_EMBEDDINGS_PATH = "/home/johnwu3/projects/rare_disease/workspace/repos/RDMA/data/vector_stores/rd_orpha_medembed.npy"
embeddings_path = EMBEDDINGS_PATH if os.path.exists(EMBEDDINGS_PATH) else LEGACY_EMBEDDINGS_PATH

# NOTE(review): allow_pickle=True unpickles object arrays, which can execute arbitrary code —
# only load trusted, locally produced files here.
embedded_documents = np.load(embeddings_path, allow_pickle=True)

extractor = RetrievalEnhancedRDExtractor(llm_client,
                                         embedding_manager,
                                         embedded_documents,
                                         system_message="You are a rare disease expert. Please extract disease mentions from the text.",
                                         min_sentence_size=500,
                                         top_k=10)
extracted = extractor.extract_entities(text)


In [None]:
# Bare last expression: Jupyter renders the list with its rich repr instead of plain print output.
extracted

In [None]:
# NOTE(review): hard-coded list of disease-mention strings; presumably the output of an earlier
# extraction run, kept for side-by-side comparison with `extracted` — confirm its provenance.
original = ['congestive heart failure',
 "Parkinson's",
 'cardiogenic shock',
 'atrial fibrillation',
 'solar urticaria',
 'insufficiency',
 'mitral regurgitation',
 'mib e3 ubiquitin protein ligase 2',
 'Hypercholesterolemia',
 'Peripheral vascular disease',
 'peripheral neuropathy',
 'glaucoma',
 "Parkinson's disease",
 'Osteoarthritis',
 'NPH insulin',
 'epilepsy and/or ataxia with myoclonus as a major feature',
 'diffuse cerebral and cerebellar atrophy-intractable seizures-progressive microcephaly syndrome',
 'infantile cerebral and cerebellar atrophy with postnatal progressive microcephaly',
 'hypomyelination with atrophy of basal ganglia and cerebellum',
 'porencephaly',
 'progressive myoclonic epilepsy type 9',
 'focal epilepsy-intellectual disability-cerebro-cerebellar malformation',
 'neuronal tumor',
 'spastic paraplegia 29',
 'phosphatase 80',
 'albumin',
 'serum',
 'spastic paraplegia 36',
 'Type 2 diabetes mellitus',
 'sick sinus syndrome',
 'cardiomegaly',
 'hypertrophy',
 'restrictive lung disease',
 'multifocal opacities',
 'adenoma',
 'pneumonia',
 'insulin']