In [1]:
import os
import sys
# Resolve the directory one level above the kernel's current working directory.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
# Put the parent directory first on the import path
# (presumably so the project's `utils` package used by later cells resolves — confirm).
sys.path.insert(0, parent_dir)
# Make the parent directory the working directory.
# NOTE(review): this cell is not idempotent — re-running it without restarting the
# kernel moves up one more level and inserts another entry into sys.path.
os.chdir(parent_dir)

In [2]:
# Sample entity string kept as a module-level global. The functions visible in this
# part of the notebook receive their entity as an argument (or rebind it locally),
# so none of them read this value.
entity = "Mg"
def load_mistral_llm_client(
    model_type="mistral_24b",
    device="cuda:0",
    cache_dir="/u/zelalae2/scratch/rdma_cache",
    temperature=0.0001,
):
    """
    Build a LocalLLMClient, by default configured for the Mistral 24B model
    on the first GPU.

    Calling this function with no arguments produces exactly the same
    configuration as before; every setting can now be overridden so the
    notebook also runs on machines with a different GPU layout or cache
    location.

    Args:
        model_type: Model identifier forwarded to LocalLLMClient
            (default "mistral_24b").
        device: Device string forwarded to LocalLLMClient (default "cuda:0").
        cache_dir: Cache directory forwarded to LocalLLMClient. The default
            is the directory taken from mine_hpo.py.
        temperature: Sampling temperature forwarded to LocalLLMClient. The
            default is the value taken from mine_hpo.py.

    Returns:
        LocalLLMClient: Client initialized with the settings above.
    """
    # Imported lazily so that merely defining this function does not require
    # the project package to be importable.
    from utils.llm_client import LocalLLMClient

    return LocalLLMClient(
        model_type=model_type,
        device=device,
        cache_dir=cache_dir,
        temperature=temperature,
    )

In [3]:

def initialize_abbreviation_searcher(
    retriever_type="sentence_transformer",
    model_name="abhinand/MedEmbed-small-v0.1",
    abbreviations_file="/home/johnwu3/projects/rare_disease/workspace/repos/RDMA/data/tools/abbreviations_medembed_sm.npy",
    device="cpu",
    top_k=3,
):
    """
    Initialize the abbreviation searcher and load its embeddings file.

    Args:
        retriever_type: Type of embedding model to use (forwarded to
            ToolSearcher as ``model_type``).
        model_name: Name of the embedding model.
        abbreviations_file: Path to the abbreviation embeddings file handed to
            ``ToolSearcher.load_embeddings``. Defaults to the path previously
            hard-coded in this notebook; override it on other machines.
        device: Device string for the searcher. Defaults to "cpu" to avoid
            GPU conflicts with the LLM.
        top_k: Number of matches the searcher is configured to return
            (default 3).

    Returns:
        Initialized ToolSearcher for abbreviation lookups
    """
    from utils.search_tools import ToolSearcher

    # Only the ToolSearcher is built here: nothing in this function uses a
    # separate embedding manager, so none is constructed.
    abbreviation_searcher = ToolSearcher(
        model_type=retriever_type,
        model_name=model_name,
        device=device,
        top_k=top_k,
    )

    # Load embeddings
    abbreviation_searcher.load_embeddings(abbreviations_file)
    print("Abbreviation searcher initialized successfully")
    return abbreviation_searcher
   

In [6]:
import time
def check_abbreviation_llm(entity, llm_client, system_message):
    """
    Check if an entity is an abbreviation using LLM.

    The LLM is asked for a bare YES/NO answer. The reply is split into whole
    words before it is interpreted, so words that merely contain the letters
    "NO" or "YES" (e.g. "known", "not", "eyes") do not affect the verdict.
    The entity counts as an abbreviation only when the reply contains the
    word YES and does not contain the word NO.

    Args:
        entity: Entity text to check
        llm_client: LLM client for querying; must expose
            ``query(prompt, system_message)`` returning the reply text
        system_message: System message for LLM

    Returns:
        Dictionary with abbreviation check results:
            'is_abbreviation': bool verdict
            'confidence': fixed 0.9 for a positive verdict, 0.1 otherwise
            'response': the raw LLM reply
            'method': always 'llm'
            'execution_time': elapsed wall-clock seconds
    """
    import re
    import time

    start_time = time.time()

    # Create the prompt
    abbreviation_prompt = (
        f"Is '{entity}' an abbreviation in medical or clinical context? "
        "Respond with ONLY 'YES' if it's an abbreviation or 'NO' if it's not."
    )

    # Query the LLM
    response = llm_client.query(abbreviation_prompt, system_message)

    # Parse the response on word boundaries rather than by substring, so that
    # e.g. "Yes, it is a known abbreviation" is not rejected because "KNOWN"
    # happens to contain "NO".
    reply_words = set(re.findall(r"[A-Z]+", response.strip().upper()))
    is_abbreviation = "YES" in reply_words and "NO" not in reply_words

    end_time = time.time()

    return {
        'is_abbreviation': is_abbreviation,
        'confidence': 0.9 if is_abbreviation else 0.1,
        'response': response,
        'method': 'llm',
        'execution_time': end_time - start_time
    }


def check_abbreviation_retrieval(entity, abbreviation_searcher, similarity_threshold=0.98):
    """
    Check if an entity is an abbreviation using vector retrieval.

    The entity is looked up with ``abbreviation_searcher.search``. It is
    reported as an abbreviation only when the top hit's similarity exceeds
    ``similarity_threshold`` and the hit's ``query_term`` equals ``entity``
    exactly (case-sensitive).

    Args:
        entity: Entity text to check
        abbreviation_searcher: ToolSearcher for abbreviation lookups; its
            ``search(entity)`` is expected to return a sequence of dicts with
            'similarity', 'query_term' and 'result' keys
        similarity_threshold: Minimum (exclusive) similarity of the top hit
            for a positive verdict. Defaults to 0.98.

    Returns:
        Dictionary with abbreviation check results. Both the match and the
        no-match path return the same keys:
            'is_abbreviation': bool verdict
            'expanded_term': expansion of the top hit, or None
            'confidence': 0.9 when nothing was retrieved; otherwise the
                similarity for a positive verdict and 1 - similarity for a
                negative one
            'method': 'no_match_found' or 'abbreviation_lookup'
            'top_matches': up to three raw search results
            'execution_time': elapsed wall-clock seconds
    """
    import time

    start_time = time.time()

    # Search for abbreviation
    search_results = abbreviation_searcher.search(entity)

    if not search_results:
        # Same schema as the match path so callers can read every key
        # unconditionally.
        return {
            'is_abbreviation': False,
            'expanded_term': None,
            'confidence': 0.9,
            'method': 'no_match_found',
            'top_matches': [],
            'execution_time': time.time() - start_time
        }

    # Get top result
    top_result = search_results[0]
    # Coerce to a builtin float so that the value (and the comparison below)
    # does not carry a library-specific scalar type into the result dict.
    similarity = float(top_result.get('similarity', 0.0))
    query_term = top_result.get('query_term', '')
    expanded_term = top_result.get('result', '')

    # Check if this is a good match
    is_abbreviation = bool(similarity > similarity_threshold and query_term == entity)

    return {
        'is_abbreviation': is_abbreviation,
        'expanded_term': expanded_term if is_abbreviation else None,
        'confidence': similarity if is_abbreviation else (1.0 - similarity),
        'method': 'abbreviation_lookup',
        'top_matches': search_results[:3],  # Include top 3 matches
        'execution_time': time.time() - start_time
    }
        
   

In [7]:
def benchmark_abbreviation_detection(test_cases, llm_client, abbreviation_searcher):
    """
    Benchmark LLM vs. retrieval approaches for abbreviation detection.

    Every test case is run through both ``check_abbreviation_llm`` and
    ``check_abbreviation_retrieval``. Per-case results and a summary are
    printed, and everything is returned for further processing. An empty
    ``test_cases`` list is handled: all averages and accuracies are 0 and the
    relative time improvement is reported as not available.

    Args:
        test_cases: List of dictionaries with an 'entity' key and an optional
            'expected' key (bool). Cases without 'expected' are timed but do
            not count towards accuracy.
        llm_client: LLM client for querying
        abbreviation_searcher: ToolSearcher for abbreviation lookups

    Returns:
        Dictionary with benchmark results:
            'detailed_results': one dict per test case with 'entity',
                'expected', 'llm_result' and 'retrieval_result'
            'summary': average times, accuracies, correct counts and case
                counts for both approaches
    """
    system_message = "You are a medical expert specializing in clinical terminology and abbreviations."
    results = []

    llm_times = []
    retrieval_times = []
    llm_correct = 0
    retrieval_correct = 0
    total_with_expected = 0

    print(f"Running benchmark on {len(test_cases)} test cases...")

    for test_case in test_cases:
        entity = test_case['entity']
        expected = test_case.get('expected', None)

        print(f"\nTesting entity: '{entity}'")

        # Test LLM approach
        llm_result = check_abbreviation_llm(entity, llm_client, system_message)
        llm_times.append(llm_result['execution_time'])

        # Test retrieval approach
        retrieval_result = check_abbreviation_retrieval(entity, abbreviation_searcher)
        retrieval_times.append(retrieval_result['execution_time'])

        # Check correctness if expected result is provided
        if expected is not None:
            total_with_expected += 1
            if llm_result['is_abbreviation'] == expected:
                llm_correct += 1
            if retrieval_result['is_abbreviation'] == expected:
                retrieval_correct += 1

        # Store results
        results.append({
            'entity': entity,
            'expected': expected,
            'llm_result': llm_result,
            'retrieval_result': retrieval_result
        })

        # Print summary for this test case
        print(f"  LLM: {llm_result['is_abbreviation']} ({llm_result['execution_time']:.3f}s)")
        print(f"  Retrieval: {retrieval_result['is_abbreviation']} ({retrieval_result['execution_time']:.3f}s)")
        if expected is not None:
            print(f"  Expected: {expected}")

    # Calculate overall metrics
    avg_llm_time = sum(llm_times) / len(llm_times) if llm_times else 0
    avg_retrieval_time = sum(retrieval_times) / len(retrieval_times) if retrieval_times else 0

    llm_accuracy = llm_correct / total_with_expected if total_with_expected > 0 else 0
    retrieval_accuracy = retrieval_correct / total_with_expected if total_with_expected > 0 else 0

    # Print summary
    print("\n=== Benchmark Summary ===")
    print(f"Total test cases: {len(test_cases)}")
    print(f"LLM average time: {avg_llm_time:.3f}s")
    print(f"Retrieval average time: {avg_retrieval_time:.3f}s")
    # The improvement is relative to the LLM time, so it is undefined when that
    # time is zero (e.g. no test cases were supplied).
    if avg_llm_time > 0:
        print(f"Time improvement: {(avg_llm_time - avg_retrieval_time) / avg_llm_time * 100:.1f}%")
    else:
        print("Time improvement: n/a (no LLM timings recorded)")

    if total_with_expected > 0:
        print(f"LLM accuracy: {llm_accuracy:.2f} ({llm_correct}/{total_with_expected})")
        print(f"Retrieval accuracy: {retrieval_accuracy:.2f} ({retrieval_correct}/{total_with_expected})")

    return {
        'detailed_results': results,
        'summary': {
            'avg_llm_time': avg_llm_time,
            'avg_retrieval_time': avg_retrieval_time,
            'llm_accuracy': llm_accuracy,
            'retrieval_accuracy': retrieval_accuracy,
            'llm_correct': llm_correct,
            'retrieval_correct': retrieval_correct,
            'total_cases': len(test_cases),
            'total_with_expected': total_with_expected
        }
    }

def run_abbreviation_benchmark():
    """
    Main function to run the benchmark.

    Loads the LLM client and the abbreviation searcher, runs
    ``benchmark_abbreviation_detection`` over a fixed list of test cases and
    writes the results to ``abbreviation_benchmark_<timestamp>.json`` in the
    current working directory.

    Returns:
        The dictionary produced by ``benchmark_abbreviation_detection``, so
        the results can be inspected in the notebook without re-reading the
        JSON file.
    """
    import json
    import time

    def _to_jsonable(value):
        # Fallback for objects json cannot encode natively. Objects exposing
        # tolist()/item() (as array/scalar types commonly do) are converted
        # through those; anything else is stored as its string form so that
        # saving never discards a finished benchmark run.
        for converter_name in ("tolist", "item"):
            converter = getattr(value, converter_name, None)
            if callable(converter):
                return converter()
        return str(value)

    # Load LLM client
    llm_client = load_mistral_llm_client()

    # Initialize abbreviation searcher
    abbreviation_searcher = initialize_abbreviation_searcher()

    # Define test cases
    test_cases = [
        # Common medical abbreviations (should be detected by both methods)
        {'entity': 'Mg', 'expected': True},  # Magnesium
        {'entity': 'CHF', 'expected': True},  # Congestive Heart Failure
        {'entity': 'MI', 'expected': True},  # Myocardial Infarction
        {'entity': 'HTN', 'expected': True},  # Hypertension
        {'entity': 'DM', 'expected': True},  # Diabetes Mellitus

        # Rare disease abbreviations (might be harder for retrieval)
        {'entity': 'PKU', 'expected': True},  # Phenylketonuria
        {'entity': 'CF', 'expected': True},  # Cystic Fibrosis
        {'entity': 'DMD', 'expected': True},  # Duchenne Muscular Dystrophy

        # Ambiguous terms
        {'entity': 'HF', 'expected': True},  # Heart Failure or Hydrofluoric Acid
        {'entity': 'MS', 'expected': True},  # Multiple Sclerosis or Mass Spectrometry

        # Non-abbreviations that look like abbreviations
        {'entity': 'DNA', 'expected': True},  # Technically an abbreviation
        {'entity': 'LASER', 'expected': True},  # Originally an acronym
        {'entity': 'RARE', 'expected': False},  # Just an uppercase word

        # Normal words (should not be detected as abbreviations)
        {'entity': 'disease', 'expected': False},
        {'entity': 'syndrome', 'expected': False},
        {'entity': 'patient', 'expected': False},

        # Borderline cases
        {'entity': 'IQ', 'expected': True},  # Intelligence Quotient
        {'entity': 'pH', 'expected': True},  # Power of Hydrogen
        {'entity': 'mm', 'expected': True},  # Millimeter

        # Additional test cases can be added here
    ]

    # Run benchmark
    benchmark_results = benchmark_abbreviation_detection(test_cases, llm_client, abbreviation_searcher)

    # Save results to file (optional)
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    result_file = f"abbreviation_benchmark_{timestamp}.json"

    with open(result_file, 'w') as f:
        # The detailed results embed raw search hits and LLM replies whose
        # exact types are not controlled here, hence the fallback encoder.
        json.dump(benchmark_results, f, indent=2, default=_to_jsonable)

    print(f"Benchmark results saved to {result_file}")

    return benchmark_results

# Execute the benchmark when this code runs as the main module; the guard keeps
# the benchmark (LLM load, embedding load, file write) from firing on import.
if __name__ == "__main__":
    run_abbreviation_benchmark()

Initialized ModelLoader with cache directory: /shared/rsaas/jw3/rare_disease/model_cache
Loading LLM!
Device configuration: cuda:0
Using device map: {'': 'cuda:0'}
Loading 70B model with quantization: mistral_24b
Generated cache path: /shared/rsaas/jw3/rare_disease/model_cache/Mistral-Small-24B-Instruct-2501_4bit_nf4
Valid cache found at /shared/rsaas/jw3/rare_disease/model_cache/Mistral-Small-24B-Instruct-2501_4bit_nf4
Loading cached quantized model from /shared/rsaas/jw3/rare_disease/model_cache/Mistral-Small-24B-Instruct-2501_4bit_nf4




Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Device set to use cuda:0


Hello! I'm here to help. How can I assist you today? If you have any medical questions or need information on a specific topic, feel free to ask. Please note that while I strive to provide accurate and helpful information, I am an AI and my knowledge cutoff is 2023, so I might not have real-time or up-to-date information. For urgent medical concerns, always consult a healthcare professional.

Here are a few examples of how I can assist you:

* Explain medical terms or concepts
* Provide information on diseases, symptoms, and treatments
* Offer insights into medical procedures and tests
* Discuss healthcare guidelines and recommendations
* Answer questions related to biomedical research and studies

What would you like to know or discuss?
Loading model...
Model type: sentence_transformer
Model name: abhinand/MedEmbed-small-v0.1
Device: cpu
Initializing SentenceTransformer with model: abhinand/MedEmbed-small-v0.1 on device: cpu
Model running on CPU
Verifying model by embedding sample tex