In [None]:
import os
import sys
# Make the parent of the current working directory importable and switch into it.
# NOTE: this is relative to the *current* working directory, so re-running this
# cell moves one more level up each time it is executed.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
# Put the parent directory first on sys.path so it wins over other entries.
sys.path.insert(0, parent_dir)
# Set the parent directory as the current directory
os.chdir(parent_dir)

In [3]:
import json
import os
import pandas as pd
import string
import re
from collections import Counter, defaultdict
from typing import Dict, List, Set, Tuple, Optional, Any

def analyze_false_positives(
    results_file: str,
    output_file: Optional[str] = None,
    match_summary: bool = True,
    detailed: bool = False,
    verbose: bool = True,
    verify_existence: bool = True
) -> Dict[str, Any]:
    """
    Analyze supervised evaluation results to extract information about false positives.

    The results file is loaded as JSON, the original and reclassified false
    positives are collected, and (optionally) the reclassified entities are
    checked against each document's original text and summarized by their
    rare disease match. Progress and a report are printed when ``verbose``.

    Args:
        results_file: Path to supervised evaluation JSON results file
        output_file: Optional path to save analysis results as CSV
        match_summary: Whether to generate summary of rare disease matches
        detailed: Whether to include detailed information in output
        verbose: Whether to print progress and results to console
        verify_existence: Whether to verify reclassified entities exist in original text

    Returns:
        Dictionary containing analysis results with the keys ``statistics``,
        ``sets``, ``entity_details`` (``None`` unless ``detailed``),
        ``top_fps``, ``match_analysis`` and ``existence_verification``.
        If the results file cannot be loaded, ``{"error": <message>}`` is
        returned instead of raising.
    """
    # Load results
    if verbose:
        print(f"Loading results from {results_file}")

    try:
        with open(results_file, 'r', encoding='utf-8') as f:
            results = json.load(f)
    except Exception as e:
        # Best-effort boundary: callers receive an error dict, never an exception.
        error_msg = f"Error loading results file: {str(e)}"
        if verbose:
            print(error_msg)
        return {"error": error_msg}

    # Extract false positives
    if verbose:
        print("Analyzing false positives...")

    all_fps, reclassified_fps, fp_document_counts, entity_details = _extract_false_positives(results)

    # Verify existence in original text if requested
    existence_verification = None
    if verify_existence:
        if verbose:
            print("Verifying entity existence in original text...")
        existence_verification = _verify_entities_in_text(results, reclassified_fps)

        if verbose:
            verification = existence_verification['verification_results']
            total_count = len(verification)
            if total_count:
                verified_count = sum(1 for e in verification.values()
                                     if e['document_existence_rate'] > 0)
                print(f"Verified {verified_count}/{total_count} reclassified entities exist in at least one document "
                      f"({verified_count/total_count*100:.1f}% exist)")

                # Count entities that exist in all their documents
                perfect_count = sum(1 for e in verification.values()
                                    if e['document_existence_rate'] == 1.0)
                print(f"{perfect_count}/{total_count} entities exist in all their associated documents "
                      f"({perfect_count/total_count*100:.1f}%)")
            else:
                # Nothing was reclassified; skip the percentages to avoid dividing by zero.
                print("No reclassified entities to verify")

    # Calculate statistics
    total_fps = len(all_fps)
    total_reclassified = len(reclassified_fps)
    reclassification_rate = (total_reclassified / total_fps) * 100 if total_fps > 0 else 0

    # Calculate document frequency statistics
    doc_freq = Counter(fp_document_counts.values())

    # Sort false positives by frequency (stable: ties keep first-seen order)
    fps_by_freq = sorted(fp_document_counts.items(), key=lambda x: x[1], reverse=True)

    # Analyze matches if requested
    match_analysis = None
    if match_summary:
        if verbose:
            print("Analyzing rare disease matches...")
        match_analysis = _analyze_matches(entity_details)

    # Print summary statistics if verbose
    if verbose:
        print("\n===== False Positive Analysis =====")
        print(f"Total unique false positives: {total_fps}")
        print(f"Reclassified as true positives: {total_reclassified} ({reclassification_rate:.1f}%)")

        print("\nFalse positive document frequency:")
        for count, freq in sorted(doc_freq.items()):
            print(f"  {count} document{'s' if count > 1 else ''}: {freq} entities")

        print("\nTop 20 most frequent false positives:")
        for i, (entity, count) in enumerate(fps_by_freq[:20], 1):
            reclassified = "✓" if entity in reclassified_fps else "✗"
            # Add existence verification if available
            existence_mark = ""
            existence_info = ""
            if verify_existence and entity in reclassified_fps:
                if entity in existence_verification['verification_results']:
                    result = existence_verification['verification_results'][entity]
                    exists_rate = result['document_existence_rate'] * 100
                    existence_mark = "📄" if exists_rate > 0 else "❌"
                    existence_info = f" (exists in {exists_rate:.1f}% of docs)"
            print(f"  {i}. {entity} ({count} docs) {reclassified} {existence_mark}{existence_info}")

        if match_analysis:
            print("\nTop 10 rare disease matches for reclassified false positives:")
            for i, (match, count) in enumerate(match_analysis['match_counts'].most_common(10), 1):
                print(f"  {i}. {match}: {count} entities")

    # Prepare data for output
    if output_file:
        if verbose:
            print(f"\nSaving analysis to {output_file}")

        # Create one CSV row per original false positive
        data = []
        for entity in all_fps:
            details = entity_details[entity]
            row = {
                'entity': entity,
                'document_count': fp_document_counts[entity],
                'reclassified': entity in reclassified_fps,
                'document_ids': ','.join(details['documents']),
            }

            # Add existence verification if available
            if verify_existence and entity in reclassified_fps and entity in existence_verification['verification_results']:
                result = existence_verification['verification_results'][entity]
                row['existence_rate'] = result['document_existence_rate']
                row['documents_exists_in'] = ','.join(result['documents_exists_in'])
                row['documents_missing_from'] = ','.join(result['documents_missing_from'])

            # Add match information if available and entity was reclassified
            if match_summary and entity in reclassified_fps and entity in match_analysis.get('entity_to_match', {}):
                row['matched_to'] = match_analysis['entity_to_match'][entity]

            data.append(row)

        # Convert to DataFrame and sort
        if data:
            df = pd.DataFrame(data).sort_values('document_count', ascending=False)
        else:
            # No false positives at all: still write a header-only CSV instead of
            # failing on sort_values() with a missing column.
            df = pd.DataFrame(columns=['entity', 'document_count', 'reclassified', 'document_ids'])

        # Save to CSV (abspath guarantees a non-empty directory name)
        os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)
        df.to_csv(output_file, index=False)

        if verbose:
            print(f"Analysis saved to {output_file}")

    # Prepare return value with FP sets and statistics
    confirmed_fps = all_fps - reclassified_fps

    if verbose:
        print("\nFalse positive sets:")
        print("1. All false positives:")
        print(sorted(all_fps))

        print("\n2. Reclassified false positives:")
        print(sorted(reclassified_fps))

        print("\n3. Confirmed false positives (not reclassified):")
        print(sorted(confirmed_fps))

        if verify_existence:
            # Entities that don't exist in any documents
            nonexistent = [e for e, v in existence_verification['verification_results'].items()
                           if v['document_existence_rate'] == 0]
            if nonexistent:
                print("\n4. Reclassified entities that don't exist in ANY of their documents:")
                print(sorted(nonexistent))

            # Entities that exist in some but not all documents
            partially_existent = [e for e, v in existence_verification['verification_results'].items()
                                  if 0 < v['document_existence_rate'] < 1]
            if partially_existent:
                print("\n5. Reclassified entities that exist in SOME but not ALL of their documents:")
                for entity in sorted(partially_existent):
                    result = existence_verification['verification_results'][entity]
                    print(f"  - {entity}")
                    print(f"    ✓ Exists in: {', '.join(result['documents_exists_in'])}")
                    print(f"    ✗ Missing from: {', '.join(result['documents_missing_from'])}")

    # Return analysis results
    analysis_results = {
        'statistics': {
            'total_fps': total_fps,
            'reclassified_fps': total_reclassified,
            'reclassification_rate': reclassification_rate,
            'document_frequency': dict(doc_freq)
        },
        'sets': {
            'all_fps': sorted(all_fps),
            'reclassified_fps': sorted(reclassified_fps),
            'confirmed_fps': sorted(confirmed_fps)
        },
        'entity_details': entity_details if detailed else None,
        'top_fps': fps_by_freq[:50],  # Top 50 most frequent FPs
        'match_analysis': match_analysis,
        'existence_verification': existence_verification
    }

    return analysis_results

def _extract_false_positives(results: Dict) -> Tuple[Set[str], Set[str], Dict, Dict]:
    """
    Extract all false positives and reclassified false positives.
    
    Args:
        results: Results dictionary from supervised evaluation
        
    Returns:
        Tuple containing:
        - Set of all original false positive entities
        - Set of reclassified false positive entities
        - Dictionary mapping entity to document count
        - Dictionary with entity details
    """
    all_fps = set()
    reclassified_fps = set()
    fp_document_counts = {}  # Track how many documents each FP appears in
    entity_details = {}      # Track detailed info for each entity
    
    # Process each document
    for doc_id, doc_results in results['document_results'].items():
        # Process original false positives
        for fp in doc_results.get('old_false_positives', []):
            entity = fp['entity'].lower()
            all_fps.add(entity)
            
            # Update document count
            fp_document_counts[entity] = fp_document_counts.get(entity, 0) + 1
            
            # Store entity details
            if entity not in entity_details:
                entity_details[entity] = {
                    'entity': entity,
                    'documents': [],
                    'reclassified': False,
                    'matches': []
                }
            
            entity_details[entity]['documents'].append(doc_id)
        
        # Process reclassified false positives
        for fp in doc_results.get('new_true_positives', []):
            entity = fp['entity'].lower()
            reclassified_fps.add(entity)
            
            # Update entity details
            if entity in entity_details:
                entity_details[entity]['reclassified'] = True
                
                # Store match details if available
                if 'rd_term' in fp and 'orpha_id' in fp:
                    match = {
                        'rd_term': fp['rd_term'],
                        'orpha_id': fp['orpha_id']
                    }
                    entity_details[entity]['matches'].append(match)
    
    return all_fps, reclassified_fps, fp_document_counts, entity_details

def _verify_entities_in_text(results: Dict, reclassified_fps: Set[str]) -> Dict:
    """
    Verify if reclassified entities actually exist in the original text of each document.
    
    Args:
        results: Results dictionary from supervised evaluation
        reclassified_fps: Set of reclassified false positive entities
        
    Returns:
        Dictionary with verification results
    """
    verification_results = {}
    entity_doc_map = defaultdict(list)
    
    # Create a mapping of entity to documents
    for doc_id, doc_results in results['document_results'].items():
        # Get all documents where an entity was classified as a false positive
        for fp in doc_results.get('old_false_positives', []):
            entity = fp['entity'].lower()
            if entity in reclassified_fps:
                entity_doc_map[entity].append(doc_id)
    
    # Check each reclassified entity in each of its documents
    for entity in reclassified_fps:
        # Get all documents this entity appears in
        docs = entity_doc_map.get(entity, [])
        
        # Track documents where entity exists vs. doesn't exist
        exists_in_docs = []
        missing_from_docs = []
        
        if not docs:
            verification_results[entity] = {
                'document_existence_rate': 0,
                'documents_exists_in': [],
                'documents_missing_from': [],
                'doc_count': 0
            }
            continue
        
        # Check each document for this entity
        for doc_id in docs:
            doc_text = results['document_results'][doc_id].get('original_text', '')
            if not doc_text:
                missing_from_docs.append(doc_id)  # No text available
                continue
                
            if _entity_exists_in_text(entity, doc_text):
                exists_in_docs.append(doc_id)
            else:
                missing_from_docs.append(doc_id)
        
        # Calculate existence rate for this entity
        total_docs = len(docs)
        existence_rate = len(exists_in_docs) / total_docs if total_docs > 0 else 0
        
        verification_results[entity] = {
            'document_existence_rate': existence_rate,
            'documents_exists_in': exists_in_docs,
            'documents_missing_from': missing_from_docs,
            'doc_count': total_docs
        }
    
    # Calculate summary statistics
    total_verified = len(verification_results)
    exists_in_any_count = sum(1 for result in verification_results.values() 
                           if result['document_existence_rate'] > 0)
    exists_in_all_count = sum(1 for result in verification_results.values() 
                           if result['document_existence_rate'] == 1.0)
    missing_count = sum(1 for result in verification_results.values() 
                      if result['document_existence_rate'] == 0)
    
    return {
        'verification_results': verification_results,
        'summary': {
            'total_verified': total_verified,
            'exists_in_any_count': exists_in_any_count,
            'exists_in_all_count': exists_in_all_count,
            'missing_count': missing_count,
            'exists_in_any_rate': (exists_in_any_count / total_verified * 100) if total_verified else 0,
            'exists_in_all_rate': (exists_in_all_count / total_verified * 100) if total_verified else 0
        }
    }

def _entity_exists_in_text(entity: str, text: str) -> bool:
    """
    Check if an entity exists in the original text.
    
    Args:
        entity: Entity text to search for
        text: Original text to search in
        
    Returns:
        True if entity is found in text, False otherwise
    """
    # Normalize both entity and text for more accurate matching
    entity_normalized = entity.lower()
    text_normalized = text.lower()
    
    # Check for exact match
    if entity_normalized in text_normalized:
        return True
        
    # Check for case where entity has punctuation that might be different in text
    # Strip punctuation and check again
    entity_no_punct = re.sub(r'[^\w\s]', '', entity_normalized)
    if entity_no_punct and entity_no_punct in text_normalized:
        return True
        
    # Check for other common formatting issues
    # Try splitting entity on common separators and check if parts exist together
    parts = re.split(r'[-/,\s]+', entity_normalized)
    if len(parts) > 1:
        # Check if all parts appear within a reasonable window of each other
        all_parts_present = True
        for part in parts:
            if part and len(part) > 2 and part not in text_normalized:  # Only check meaningful parts
                all_parts_present = False
                break
                
        if all_parts_present:
            return True
            
    return False

def _analyze_matches(entity_details: Dict) -> Dict:
    """
    Analyze the matches for reclassified false positives.
    
    Args:
        entity_details: Dictionary with entity details
        
    Returns:
        Dictionary with match statistics
    """
    match_counts = Counter()
    entity_to_match = {}
    
    for entity, details in entity_details.items():
        if details['reclassified'] and details['matches']:
            # Use the most common match for this entity
            matches = [f"{m['rd_term']} ({m['orpha_id']})" for m in details['matches']]
            most_common_match = Counter(matches).most_common(1)[0][0]
            
            match_counts[most_common_match] += 1
            entity_to_match[entity] = most_common_match
    
    return {
        'match_counts': match_counts,
        'entity_to_match': entity_to_match
    }

# Example usage as a script if needed
if __name__ == "__main__":
    import argparse
    
    # Command-line parsing is kept for reference but disabled; the settings
    # below are hard-coded instead.
    # def parse_arguments() -> argparse.Namespace:
    #     """Parse command line arguments."""
    #     parser = argparse.ArgumentParser(
    #         description='Analyze supervised evaluation results to extract false positive information'
    #     )
    #     parser.add_argument('--results_file', type=str, required=True,
    #                       help='Path to supervised evaluation JSON results file')
    #     parser.add_argument('--output_file', type=str,
    #                       help='Path to save analysis results (CSV)')
    #     parser.add_argument('--detailed', action='store_true',
    #                       help='Include detailed information in output')
    #     parser.add_argument('--match_summary', action='store_true',
    #                       help='Generate summary of rare disease matches')
    #     parser.add_argument('--quiet', action='store_true',
    #                       help='Do not print progress and results')
        
    #     return parser.parse_args()
    
    # args = parse_arguments()
    results_file = "data/results/agents/rd/rd_mistral24b_medembed_supervised.json"
    verbose = True 
    detailed = True 
    match_summary = True
    # Call the main function, forwarding every setting chosen above (previously
    # verbose/detailed/match_summary were assigned but never passed, so the
    # function silently ran with its defaults, e.g. detailed=False).
    results = analyze_false_positives(
        results_file=results_file,
        match_summary=match_summary,
        detailed=detailed,
        verbose=verbose,
        verify_existence=True
    )
    # Bare expression: shows the result as the cell output when run in a notebook
    results

Loading results from data/results/agents/rd/rd_mistral24b_medembed_supervised.json
Analyzing false positives...
Verifying entity existence in original text...
Verified 18/18 reclassified entities exist in at least one document (100.0% exist)
18/18 entities exist in all their associated documents (100.0%)
Analyzing rare disease matches...

===== False Positive Analysis =====
Total unique false positives: 92
Reclassified as true positives: 18 (19.6%)

False positive document frequency:
  1 document: 73 entities
  2 documents: 10 entities
  3 documents: 4 entities
  4 documents: 2 entities
  5 documents: 2 entities
  6 documents: 1 entities

Top 20 most frequent false positives:
  1. afib (6 docs) ✗ 
  2. cholangitis (5 docs) ✗ 
  3. hepatic encephalopathy (5 docs) ✗ 
  4. respiratory distress syndrome (4 docs) ✗ 
  5. hcc (4 docs) ✗ 
  6. tracheomalacia (3 docs) ✓ 📄 (exists in 100.0% of docs)
  7. atrial septal defect (3 docs) ✗ 
  8. pulmonary artery systolic hypertension (3 docs) ✗ 
  

In [7]:
# List every verified entity that is absent from at least one of its documents.
# Relies on `results` holding the return value of analyze_false_positives() run
# with verify_existence=True (otherwise 'existence_verification' is None).
for entity, details in results['existence_verification']['verification_results'].items():
    if details['document_existence_rate'] < 1.0:
        print(f"Entity '{entity}' missing from documents: {', '.join(details['documents_missing_from'])}")

Entity 'pyruvate kinase deficiency' missing from documents: 52501
Entity 'thrombotic thrombocytopenic purpura' missing from documents: 52501
Entity 'hemolytic uremic syndrome' missing from documents: 52501
Entity 'paroxysmal nocturnal hemoglobinuria' missing from documents: 52501
Entity 'restrictive cardiomyopathy' missing from documents: 52501
Entity 'beta thalassemia' missing from documents: 52501
Entity 'guillain-barre syndrome' missing from documents: 38775
Entity 'myelomeningocele' missing from documents: 53406


# Checking again if there are any false positives

In [2]:
import json
import pandas as pd
import traceback
from typing import Dict, List, Set, Tuple

def process_mimic_json(filepath: str) -> pd.DataFrame:
    """Process MIMIC-style JSON with annotations.

    The file is expected to hold a mapping of document id to document data.
    Documents without a 'note_details' entry are skipped, and documents whose
    note text is missing (null) are dropped. Annotations without a 'mention'
    are ignored. Dataset statistics are printed to stdout.

    Args:
        filepath: Path to JSON file containing clinical notes with annotations

    Returns:
        pd.DataFrame: DataFrame with processed notes and annotations, one row
        per document (columns: document_id, patient_id, admission_id,
        category, chart_date, clinical_note, gold_annotations). On any error
        the error is printed with its traceback and an empty DataFrame is
        returned.
    """
    columns = [
        'document_id', 'patient_id', 'admission_id', 'category',
        'chart_date', 'clinical_note', 'gold_annotations',
    ]

    try:
        # Load JSON data
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Process each document
        records = []
        for doc_id, doc_data in data.items():
            if 'note_details' not in doc_data:
                continue

            note_details = doc_data['note_details']

            # Process all annotations that have a mention
            gold_annotations = []
            for ann in doc_data.get('annotations') or []:
                if not ann.get('mention'):
                    continue
                # 'ordo_with_desc' is split as "<id> <description...>"; splitting once
                # also copes with missing, empty or whitespace-only values.
                ordo_parts = (ann.get('ordo_with_desc') or '').split()
                gold_annotations.append({
                    'mention': ann['mention'],
                    'orpha_id': ordo_parts[0] if ordo_parts else '',
                    'orpha_desc': ' '.join(ordo_parts[1:]),
                    'document_section': ann.get('document_structure'),
                    'confidence': 1.0
                })

            # Extract relevant fields
            records.append({
                'document_id': doc_id,
                'patient_id': note_details.get('subject_id'),
                'admission_id': note_details.get('hadm_id'),
                'category': note_details.get('category'),
                'chart_date': note_details.get('chartdate'),
                'clinical_note': note_details.get('text', ''),
                'gold_annotations': gold_annotations
            })

        # Create DataFrame (explicit columns keep the schema even with no records)
        df = pd.DataFrame(records, columns=columns)

        # Basic validation and cleaning: drop missing notes BEFORE casting to str,
        # otherwise a null note becomes the literal string 'None' and is kept.
        df = df.dropna(subset=['clinical_note']).reset_index(drop=True)
        df['clinical_note'] = df['clinical_note'].astype(str)

        # Print dataset statistics
        annotation_counts = [len(anns) for anns in df['gold_annotations']]
        print(f"\nDataset Statistics:")
        print(f"Total documents: {len(df)}")
        print(f"Documents with annotations: {sum(1 for n in annotation_counts if n > 0)}")
        print(f"Total annotations: {sum(annotation_counts)}")
        print(f"Document categories: {df['category'].value_counts().to_dict()}")

        return df

    except Exception as e:
        # Best-effort boundary: report the problem and hand back an empty frame
        print(f"Error processing JSON file: {str(e)}")
        traceback.print_exc()
        return pd.DataFrame()

def get_error_sets(predictions_df: pd.DataFrame, gold_df: pd.DataFrame) -> Dict[str, Set[str]]:
    """
    Get sets of all false positives, false negatives, and true positives.

    Entities are compared per document and case-insensitively; document ids
    are compared as strings so integer and string ids line up. Only documents
    present in ``gold_df`` are evaluated. The per-document results are merged
    into global sets, so an entity can end up in more than one set if it is
    correct in one document and wrong in another. The input DataFrames are
    not modified.

    Args:
        predictions_df: DataFrame with predicted entities (columns
            'document_id' and 'entity'; missing entities are ignored)
        gold_df: DataFrame with gold standard annotations from process_mimic_json

    Returns:
        Dict containing sets of all false positives, false negatives, and true positives
    """
    # Group the predicted entities by document once, instead of filtering the
    # whole predictions frame again for every gold document.
    predicted_by_doc: Dict[str, Set[str]] = {}
    pred_doc_ids = predictions_df['document_id'].astype(str)
    if 'entity' in predictions_df.columns:
        for doc_id, entity in zip(pred_doc_ids, predictions_df['entity']):
            if pd.isna(entity):
                continue
            # str() keeps non-string cell values (e.g. numbers) from crashing .lower()
            entity = str(entity).lower()
            if entity:
                predicted_by_doc.setdefault(doc_id, set()).add(entity)

    # Initialize sets to collect all errors and correct predictions
    all_true_positives = set()
    all_false_positives = set()
    all_false_negatives = set()

    # Process all gold documents
    gold_doc_ids = gold_df['document_id'].astype(str)
    for doc_id, gold_anns in zip(gold_doc_ids, gold_df['gold_annotations']):
        # Create gold standard set
        gold_entities = {ann['mention'].lower() for ann in gold_anns}
        pred_entities = predicted_by_doc.get(doc_id, set())

        # Merge the document-specific sets into the global ones
        all_true_positives |= gold_entities & pred_entities
        all_false_positives |= pred_entities - gold_entities
        all_false_negatives |= gold_entities - pred_entities

    return {
        'true_positives': all_true_positives,
        'false_positives': all_false_positives,
        'false_negatives': all_false_negatives,
    }

def print_error_analysis(error_sets: Dict[str, Set[str]]):
    """
    Print a report for the given error sets.

    The report contains the size of each set, the precision, recall and F1
    score derived from those sizes, the raw false positive and false negative
    sets, and finally an alphabetically sorted listing of every set.

    Args:
        error_sets: Dict containing sets of true positives, false positives, and false negatives
    """
    print("\n===== ERROR ANALYSIS =====")
    n_tp = len(error_sets['true_positives'])
    print(f"True positives: {n_tp}")
    n_fp = len(error_sets['false_positives'])
    print(f"False positives: {n_fp}")
    n_fn = len(error_sets['false_negatives'])
    print(f"False negatives: {n_fn}")

    # Metrics fall back to 0 whenever their denominator would be zero
    predicted_total = n_tp + n_fp
    gold_total = n_tp + n_fn
    precision = n_tp / predicted_total if predicted_total > 0 else 0
    recall = n_tp / gold_total if gold_total > 0 else 0
    if (precision + recall) > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0

    print(f"Precision: {precision:.3f}")
    print(f"Recall: {recall:.3f}")
    print(f"F1 Score: {f1:.3f}")

    # Raw (unsorted) dump of the two error sets
    print(error_sets["false_positives"])
    print(error_sets["false_negatives"])

    # Sorted entity listings, one section per set
    sections = (
        ("\n----- TRUE POSITIVES -----", 'true_positives'),
        ("\n----- FALSE POSITIVES -----", 'false_positives'),
        ("\n----- FALSE NEGATIVES -----", 'false_negatives'),
    )
    for heading, key in sections:
        print(heading)
        for entity in sorted(error_sets[key]):
            print(f"  {entity}")

# Example of usage in a Jupyter notebook

import pandas as pd
import json

# Process gold standard data
# (process_mimic_json prints dataset statistics and returns an empty DataFrame on error)
gold_df = process_mimic_json("data/dataset/filtered_rd_annos_updated_adam.json")

# Load predictions
# (get_error_sets reads the 'document_id' and 'entity' columns of this file)
predictions_df = pd.read_csv("data/results/rd_llama70b_fastembed.csv")

# Get error sets
# (entities are compared per document, case-insensitively; only gold documents are evaluated)
error_sets = get_error_sets(predictions_df, gold_df)

# Print analysis
print_error_analysis(error_sets)

# To save to files for further analysis
# import csv

# # Save true positives
# with open('true_positives.txt', 'w') as f:
#     for entity in sorted(error_sets['true_positives']):
#         f.write(f"{entity}\n")

# # Save false positives  
# with open('false_positives.txt', 'w') as f:
#     for entity in sorted(error_sets['false_positives']):
#         f.write(f"{entity}\n")
        
# # Save false negatives
# with open('false_negatives.txt', 'w') as f:
#     for entity in sorted(error_sets['false_negatives']):
#         f.write(f"{entity}\n")



Dataset Statistics:
Total documents: 117
Documents with annotations: 117
Total annotations: 333
Document categories: {'Discharge summary': 117}

===== ERROR ANALYSIS =====
True positives: 54
False positives: 137
False negatives: 59
Precision: 0.283
Recall: 0.478
F1 Score: 0.355
{'myelomeningocele', 'spontaneous pneumothorax', 'pancreatic cystic dz.', 'myeloproliferative neoplasm', 'polycythemia', 'moyamoya disease', 'mitral valve stenosis', 'malignant salivary gland neoplasms', 'tbi', 'creat', 'necrotizing soft tissue infection', 'adrenal insufficiency', 'intestinal necrosis', 'vre pyelonephritis', 'nocardia', 'vascular malformation', 'fmf', 'thrombotic thrombocytopenic purpura', 'thalassemia', 'ebv positive b-cell lymphoma', 'eosinophilic granulomatosis with polyangiitis', 'cryptococcal ag', 'apnea prematurity', 'sessile serrated adenoma', 'giant cell arteritis', 'ampullary intestinal adenoma', 'esophageal candidiasis', 'sss', 'vogt-koyanagi syndrome', 'juvenile idiopathic arthritis'

In [5]:
def calculate_document_level_metrics(predictions_df, gold_df, corrected_rare_diseases=None):
    """
    Calculate entity metrics per document and aggregate them over all gold documents.

    For every row of ``gold_df`` the lower-cased gold mentions are compared with the
    lower-cased predicted entities of the same document. Counts are summed over the
    gold rows, so predictions for documents that do not appear in ``gold_df`` are
    ignored. A second ("corrected") set of counts treats every false positive that
    appears in ``corrected_rare_diseases`` as a true positive instead.

    Args:
        predictions_df: DataFrame with predicted entities. Must have a
            'document_id' column; entities are read from the 'entity' column
            (a missing column is treated as "no predictions"). Null and empty
            entities are skipped; non-string values are converted with ``str``.
        gold_df: DataFrame with gold standard annotations. Must have
            'document_id' and 'gold_annotations' columns, where each
            'gold_annotations' value is an iterable of dicts with a 'mention' key.
        corrected_rare_diseases: Optional list of diseases that should not be
            considered false positives (compared case-insensitively).

    Returns:
        Dict with the keys:
            'original' / 'corrected': dicts holding 'true_positives',
                'false_positives', 'false_negatives', 'precision', 'recall', 'f1'.
            'moved_to_true_positives': set of entities reclassified as true positives.
            'not_found_in_false_positives': lower-cased corrected diseases that
                never appeared among any document's false positives.
    """
    # Compare document ids as strings so that e.g. 12 and "12" match. Casting into
    # new Series keeps the caller's frames untouched without copying them.
    pred_doc_ids = predictions_df['document_id'].astype(str)
    gold_doc_ids = gold_df['document_id'].astype(str)

    # Lower-cased set for O(1) membership checks.
    corrected_set = {disease.lower() for disease in (corrected_rare_diseases or [])}

    # Group the predictions once instead of filtering the whole frame per gold row.
    predictions_by_doc = {}
    if 'entity' in predictions_df.columns:
        for doc_id, raw_entity in zip(pred_doc_ids, predictions_df['entity']):
            if pd.isna(raw_entity):
                continue
            # str() guards against numeric cells, which have no .lower().
            entity = str(raw_entity).lower()
            if entity:
                predictions_by_doc.setdefault(doc_id, set()).add(entity)

    def _precision_recall_f1(tp, fp, fn):
        """Return (precision, recall, f1); each is 0 when its denominator is 0."""
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        return precision, recall, f1

    # Counters for the original metrics.
    entity_true_positives = 0
    entity_false_positives = 0
    entity_false_negatives = 0

    # Counters for the corrected metrics.
    corrected_true_positives = 0
    corrected_false_positives = 0
    corrected_false_negatives = 0

    # Entities that were actually reclassified in at least one document.
    moved_to_true_positives = set()

    for doc_id, gold_anns in zip(gold_doc_ids, gold_df['gold_annotations']):
        gold_entities = {ann['mention'].lower() for ann in gold_anns}
        pred_entities = predictions_by_doc.get(doc_id, set())

        doc_true_positives = gold_entities & pred_entities
        doc_false_positives = pred_entities - gold_entities
        doc_false_negatives = gold_entities - pred_entities

        entity_true_positives += len(doc_true_positives)
        entity_false_positives += len(doc_false_positives)
        entity_false_negatives += len(doc_false_negatives)

        # False positives on the correction list count as true positives instead.
        rescued = doc_false_positives & corrected_set
        moved_to_true_positives |= rescued

        corrected_true_positives += len(doc_true_positives) + len(rescued)
        corrected_false_positives += len(doc_false_positives) - len(rescued)
        # The correction never touches false negatives.
        corrected_false_negatives += len(doc_false_negatives)

    precision_original, recall_original, f1_original = _precision_recall_f1(
        entity_true_positives, entity_false_positives, entity_false_negatives
    )
    precision_corrected, recall_corrected, f1_corrected = _precision_recall_f1(
        corrected_true_positives, corrected_false_positives, corrected_false_negatives
    )

    # Corrected diseases that were never seen among any document's false positives.
    not_found_in_false_positives = corrected_set - moved_to_true_positives

    return {
        'original': {
            'true_positives': entity_true_positives,
            'false_positives': entity_false_positives,
            'false_negatives': entity_false_negatives,
            'precision': precision_original,
            'recall': recall_original,
            'f1': f1_original
        },
        'corrected': {
            'true_positives': corrected_true_positives,
            'false_positives': corrected_false_positives,
            'false_negatives': corrected_false_negatives,
            'precision': precision_corrected,
            'recall': recall_corrected,
            'f1': f1_corrected
        },
        'moved_to_true_positives': moved_to_true_positives,
        'not_found_in_false_positives': not_found_in_false_positives
    }

def print_document_level_metrics(results):
    """
    Print original and corrected document-level metrics side by side.

    After the table, lists the entities that were moved to true positives and the
    correction-list entries that never showed up as false positives; each list is
    printed only when it is non-empty.

    Args:
        results: Dictionary returned by calculate_document_level_metrics
    """
    before = results['original']
    after = results['corrected']

    print("\n===== DOCUMENT-LEVEL METRICS COMPARISON =====")
    print("                   ORIGINAL    CORRECTED")

    # Integer counts are right-aligned in 5 columns.
    count_rows = (
        ("True Positives:    ", 'true_positives'),
        ("False Positives:   ", 'false_positives'),
        ("False Negatives:   ", 'false_negatives'),
    )
    for label, key in count_rows:
        print(f"{label}{before[key]:5d}       {after[key]:5d}")

    # Ratios are shown with three decimals.
    ratio_rows = (
        ("Precision:         ", 'precision'),
        ("Recall:            ", 'recall'),
        ("F1 Score:          ", 'f1'),
    )
    for label, key in ratio_rows:
        print(f"{label}{before[key]:.3f}      {after[key]:.3f}")

    # Details about the corrections, alphabetically sorted.
    detail_sections = (
        ('moved_to_true_positives', "Entities moved from false positives to true positives"),
        ('not_found_in_false_positives', "Entities in correction list but not found in false positives"),
    )
    for key, heading in detail_sections:
        entities = results[key]
        if not entities:
            continue
        print(f"\n{heading}: {len(entities)}")
        for entity in sorted(entities):
            print(f"  {entity}")

# Example usage template. The triple-quoted block below is a bare string
# expression, so nothing inside it is executed.
"""
# Define the corrected list of rare diseases
rare_diseases = [
    "myelomeningocele",
    "fmf",
    "thrombotic thrombocytopenic purpura",
    # ... other diseases ...
]

# Process gold standard data
gold_df = process_mimic_json("data/dataset/filtered_rd_annos_updated_adam.json")

# Load predictions
predictions_df = pd.read_csv("rd_llama70b_fastembed.csv")

# Calculate document-level metrics with corrections
doc_metrics = calculate_document_level_metrics(predictions_df, gold_df, rare_diseases)

# Print metrics comparison
print_document_level_metrics(doc_metrics)
"""
# Executed example: score the medembed predictions with a correction list.

# Corrected list of rare diseases: predictions matching one of these entries
# (case-insensitively) are counted as true positives in the corrected metrics
# instead of false positives.
rare_diseases = [
    "myelomeningocele",
    "fmf",
    "thrombotic thrombocytopenic purpura",
    "eosinophilic granulomatosis with polyangiitis",
    "giant cell arteritis",
    "vogt-koyanagi syndrome",
    "juvenile idiopathic arthritis",
    "churg-strauss syndrome",
    "primary pulmonary hypertension",
    "scleroderma",
    "arrhythmogenic right ventricular dysplasia",
    "fibrous dysplasia",
    "synovial osteochondromatosis",
    "relapsing polychondritis",
    "takayasu arteritis",
    "cryoglobulinemic vasculitis",
    "dermatomyositis",
    "hemophagocytic syndrome",
    "systemic lupus erythematosus",
    "mixed connective tissue disease",
    "thromboangiitis obliterans",
    "buerger's disease",
    "microscopic polyangiitis",
    "aplastic anemia",
    "kawasaki disease",
    "antithrombin deficiency",
    "henoch-schonlein purpura",
    "myelodysplastic syndrome",
    "temporal arteritis",
    "pah"
]

# Recalculate metrics
# Rebuild the gold-standard DataFrame from the annotation JSON.
# NOTE(review): process_mimic_json is not defined in this cell; presumably it
# comes from an earlier notebook cell - confirm.
gold_df = process_mimic_json("data/dataset/filtered_rd_annos_updated_adam.json")

# Load the predictions to evaluate (the medembed run in this cell).
predictions_df = pd.read_csv("data/results/rd_llama70b_medembed.csv")

# Calculate document-level metrics, original and corrected.
doc_metrics = calculate_document_level_metrics(predictions_df, gold_df, rare_diseases)

# Print the side-by-side metrics comparison.
print_document_level_metrics(doc_metrics)


Dataset Statistics:
Total documents: 117
Documents with annotations: 117
Total annotations: 333
Document categories: {'Discharge summary': 117}

===== DOCUMENT-LEVEL METRICS COMPARISON =====
                   ORIGINAL    CORRECTED
True Positives:       87          98
False Positives:     145         134
False Negatives:     108         108
Precision:         0.375      0.422
Recall:            0.446      0.476
F1 Score:          0.407      0.447

Entities moved from false positives to true positives: 11
  aplastic anemia
  arrhythmogenic right ventricular dysplasia
  fibrous dysplasia
  fmf
  hemophagocytic syndrome
  myelomeningocele
  pah
  primary pulmonary hypertension
  synovial osteochondromatosis
  thrombotic thrombocytopenic purpura
  vogt-koyanagi syndrome

Entities in correction list but not found in false positives: 19
  antithrombin deficiency
  buerger's disease
  churg-strauss syndrome
  cryoglobulinemic vasculitis
  dermatomyositis
  eosinophilic granulomatosis with po

# Further dataset cleaning to re-label false positives.

In [1]:
import numpy as np
from utils.embedding import EmbeddingsManager
import json

def inspect_embeddings(embeddings_file: str):
    """
    Print a summary of the first record stored in an embeddings file.

    The file is read with ``np.load(..., allow_pickle=True)``; the first record is
    expected to be a mapping with an 'embedding' array next to its other fields.
    Any failure is reported on stdout instead of being raised.

    Args:
        embeddings_file: Path to the embeddings file to inspect
    """
    separator = "-"
    print(f"\nInspecting embeddings from: {embeddings_file}")
    print(separator * 50)

    try:
        # allow_pickle=True unpickles stored objects: only use on trusted files.
        records = np.load(embeddings_file, allow_pickle=True)
        print(f"Total number of documents: {len(records)}")

        sample = records[0]
        print("\nFirst document structure:")
        print(separator * 25)

        # Everything except the vector itself is printed verbatim.
        for field, content in sample.items():
            if field == 'embedding':
                continue
            print(f"{field}: {content}")

        vector = sample['embedding']
        print("\nEmbedding details:")
        print(f"Shape: {vector.shape}")
        print(f"Type: {vector.dtype}")
        print(f"First 5 values: {vector[:5]}")
        print(f"Min value: {np.min(vector)}")
        print(f"Max value: {np.max(vector)}")
        print(f"Mean value: {np.mean(vector)}")

    except Exception as e:
        print(f"Error loading embeddings: {str(e)}")

def test_retrieval(embeddings_file: str, model_type: str = 'fastembed', model_name: str = "BAAI/bge-small-en-v1.5"):
    """
    Run a few sample queries against the embedded documents and print the top hits.

    Builds an EmbeddingsManager, loads the embedded documents, creates a search
    index and prints the five best matches for each built-in test query. Any
    failure is reported on stdout instead of being raised.

    Args:
        embeddings_file: Path to the embedded documents file
        model_type: Embedding backend passed to EmbeddingsManager
        model_name: Embedding model name passed to EmbeddingsManager
    """
    print("\nTesting retrieval functionality")
    print("-" * 50)

    try:
        # Initialize EmbeddingsManager
        embeddings_manager = EmbeddingsManager(
            model_type=model_type,
            model_name=model_name
        )

        # Load embedded documents
        embedded_documents = embeddings_manager.load_documents(embeddings_file)
        print(f"Loaded {len(embedded_documents)} documents")

        # Prepare embeddings array and create index
        embeddings_array = embeddings_manager.prepare_embeddings(embedded_documents)
        index = embeddings_manager.create_index(embeddings_array)

        # Test queries
        test_queries = [
            "rare genetic disorder affecting bone development",
            "inherited metabolic disorder",
            "rare autoimmune condition"
        ]

        for query in test_queries:
            print(f"\nQuery: {query}")
            print("-" * 25)

            # Embed the query as a single-row matrix and search the index.
            query_vector = embeddings_manager.query_text(query).reshape(1, -1)
            distances, indices = embeddings_manager.search(query_vector, index, k=5)

            # Print top 5 results
            print("Top 5 matches:")
            for idx, (distance, doc_idx) in enumerate(zip(distances[0], indices[0]), 1):
                doc = embedded_documents[doc_idx]
                # Turn the distance into a score that shrinks as distance grows
                # (assumes distances are non-negative - TODO confirm for the index used).
                similarity = 1 / (1 + distance)
                # Stored ids can already carry the "ORPHA:" prefix (e.g. "ORPHA:10"),
                # so only add it when it is missing to avoid "ORPHA:ORPHA:10".
                doc_id = str(doc['id'])
                orpha_label = doc_id if doc_id.upper().startswith("ORPHA:") else f"ORPHA:{doc_id}"
                print(f"{idx}. {doc['name']} ({orpha_label}) - Similarity: {similarity:.3f}")
                if doc.get('definition'):
                    print(f"   Definition: {doc['definition'][:200]}...")
                print()

    except Exception as e:
        print(f"Error testing retrieval: {str(e)}")

# Entry point: inspect the stored embeddings first, then run the sample
# retrieval queries against the same file.
if __name__ == "__main__":
    # Path of the embeddings file to inspect; update this for your environment.
    EMBEDDINGS_FILE = "data/vector_stores/rd_orpha_fastembed.npy"
    
    # Print the structure and statistics of the first stored record.
    inspect_embeddings(EMBEDDINGS_FILE)
    
    # Run the built-in test queries with the default model settings.
    test_retrieval(EMBEDDINGS_FILE)

  from .autonotebook import tqdm as notebook_tqdm



Inspecting embeddings from: data/vector_stores/rd_orpha_fastembed.npy
--------------------------------------------------
Total number of documents: 26750

First document structure:
-------------------------
name: 48,xxyy syndrome
id: ORPHA:10
definition: a rare sex chromosome number anomaly disorder characterized, genetically, by the presence of an extra x and y chromosome in males and, clinically, by tall stature, dysfunctional testes associated with infertility and insufficient testosterone production, cognitive, affective and social functioning impairments, global developmental delay, and an increased risk of congenital malformations.

Embedding details:
Shape: (384,)
Type: float32
First 5 values: [-0.03409642 -0.01160857  0.05178694  0.04140279 -0.00915304]
Min value: -0.32415685057640076
Max value: 0.329723596572876
Mean value: 0.0003787704335991293

Testing retrieval functionality
--------------------------------------------------
Loading model...
Model type: fastembed
Model nam

# Corrections to Updated_Adam.json

In [1]:
import json
from typing import List, Dict
import re

def update_annotations_with_validated_mentions(
    annotations_file: str,
    validated_mentions: List[str],
    output_file: str
) -> Dict[str, int]:
    """
    Update annotations by adding validated false positives as new annotations.

    Every document's text (``doc['note_details']['text']``) is searched,
    case-insensitively and on whole-word boundaries, for each validated mention.
    A mention that is already annotated in a document (case-insensitive match on
    an existing annotation's 'mention') is skipped for that document. Every
    occurrence found is appended to ``doc['annotations']`` with its character
    offsets, and the updated structure is written to ``output_file`` as JSON.

    Args:
        annotations_file: Path to original annotations file
        validated_mentions: List of validated mentions to add; empty or
            whitespace-only entries are ignored
        output_file: Path to save updated annotations

    Returns:
        Dict containing statistics about updates made: 'documents_updated',
        'new_annotations_added' and 'mentions_found'
    """
    # Load original annotations (JSON is read and written as UTF-8).
    with open(annotations_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    stats = {
        'documents_updated': 0,
        'new_annotations_added': 0,
        'mentions_found': 0
    }

    # Compile one case-insensitive pattern per mention. Lookarounds are used
    # instead of \b because \b cannot match after a mention that ends in a
    # non-word character such as ")" when whitespace or punctuation follows.
    # Empty mentions are dropped: they would match at every position.
    mention_patterns = {
        mention: re.compile(r'(?<!\w)' + re.escape(mention) + r'(?!\w)', re.IGNORECASE)
        for mention in validated_mentions
        if mention.strip()
    }

    # Process each document
    for doc_id, doc in annotations.items():
        text = doc['note_details']['text']
        existing_mentions = {anno['mention'].lower() for anno in doc['annotations']}
        new_annotations = []

        # Check for each validated mention in the text
        for mention, pattern in mention_patterns.items():
            # Skip if mention is already annotated in this document
            if mention.lower() in existing_mentions:
                continue

            # Add each occurrence as a new annotation
            for match in pattern.finditer(text):
                stats['mentions_found'] += 1
                new_annotations.append({
                    'mention': match.group(0),  # Exact text (original casing) from document
                    'start': match.start(),
                    'end': match.end(),
                    'type': 'RARE_DISEASE',  # Assuming this is the standard type
                    'validated_false_positive': True  # Mark as validated false positive
                })

        # Update document if new annotations were found
        if new_annotations:
            doc['annotations'].extend(new_annotations)
            stats['documents_updated'] += 1
            stats['new_annotations_added'] += len(new_annotations)

    # Save updated annotations
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(annotations, f, indent=2)

    return stats

# Example usage: add the manually validated false positives to the dataset.
# Each entry is searched case-insensitively in every document's text.
validated_mentions = [
    'papillary carcinoma',
    'glioblastoma multiforme',
    'transitional cell carcinoma',
    'multifocal atrial tachycardia (mat)',
    'sarcoidosis',
    'methemoglobinemia',
    'central nervous system and systemic lymphoma',
    'sclerosis cholangitis',  # corrected from 'sclerosing cholangitis'
    'mediastinitis',
    'mesenteric vein thrombosis',
    'multiple myeloma',
    'hepatocellular carcinoma',  # corrected from 'hepatocellular ca'
    'primary cns lymphoma',
    "bechet's disease",
    'neovascular glaucoma',
    'meningocele',
    'alopecia',
    'neovascular glaucoma angle closure',
    'pyoderma gangrenosum',
    'budd-chiari',
    'intraductal papillary mucinous tumor',
    'complex tracheal stenosis',
    'cervical stenosis',
    'bronchiectasis',
    'medullary sponge kidney',
    'protein s',
    'antiphospholipid antibody syndrome',
    'protein c',
    'acute myelogenous leukemia',
    'anaplastic thyroid carcinoma',
    'thymoma',
    'congenital bleeding disorder',
    'tracheal stenosis'
]
# Update annotations: read the original file and write the augmented copy to
# the path that the evaluation cells load as their gold standard.
stats = update_annotations_with_validated_mentions(
    annotations_file='data/dataset/rd_annos.json',
    validated_mentions=validated_mentions,
    output_file='data/dataset/filtered_rd_annos_updated_adam.json'
)

# Report how many documents changed and how many annotations were added.
print("\nUpdate Statistics:")
print(f"Documents updated: {stats['documents_updated']}")
print(f"New annotations added: {stats['new_annotations_added']}")
print(f"Total mentions found: {stats['mentions_found']}")


Update Statistics:
Documents updated: 31
New annotations added: 74
Total mentions found: 74


In [2]:
import json
import re
from typing import List, Set, Dict, Tuple

def is_abbreviation(text: str) -> bool:
    """
    Heuristically decide whether a string looks like an abbreviation.

    Only strings of 2 to 5 characters (spaces not counted) qualify, and
    "H1N1" / "B12" (any casing) are never treated as abbreviations. A remaining
    string counts as an abbreviation when it has no lowercase letters and at
    least one letter, or when it contains two or more uppercase characters.
    """
    # Length and special-case checks ignore spaces.
    compact = text.replace(" ", "")

    if not 2 <= len(compact) <= 5:
        return False

    if compact.upper() in {"H1N1", "B12"}:
        return False

    # Fully upper-cased text that contains at least one letter.
    contains_letter = any(ch.isalpha() for ch in text)
    if contains_letter and text == text.upper():
        return True

    # Mixed-case text: require at least two capitals.
    capitals = [ch for ch in text if ch.isupper()]
    return len(capitals) >= 2

def should_remove_mention(mention: str) -> bool:
    """
    Tell whether a mention should be dropped from the metrics.

    A mention is removed when, stripped and lower-cased, it is one of the
    hard-coded terms "hits", "als" or "nph", or when is_abbreviation() flags
    the mention as given.
    """
    normalized = mention.strip().lower()

    # Known abbreviations that may appear in lowercase, where the
    # capital-letter heuristic cannot detect them.
    if normalized in {"hits", "als", "nph"}:
        return True

    return is_abbreviation(mention)

def calculate_metrics(tp: int, fp: int, fn: int) -> Dict:
    """
    Derive precision, recall and F1 from raw counts.

    Each score is 0 when its denominator is 0. The input counts are echoed
    back in the result under 'true_positives', 'false_positives' and
    'false_negatives'.
    """
    predicted_total = tp + fp
    gold_total = tp + fn

    precision = tp / predicted_total if predicted_total > 0 else 0
    recall = tp / gold_total if gold_total > 0 else 0

    if (precision + recall) > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'true_positives': tp,
        'false_positives': fp,
        'false_negatives': fn
    }

def recalculate_metrics_with_improvements(
    results_file: str,
    validated_mentions: List[str],
    exclude_abbreviations: bool = False
) -> Dict:
    """
    Recompute the evaluation metrics under progressively more lenient scenarios.

    Scenarios, summed over every document in
    ``results['evaluation_results']['document_metrics']``:
        'original': counts taken directly from each document's analysis.
        'with_validated': spurious mentions found in ``validated_mentions``
            (case-insensitive) are counted as true positives.
        'with_validated_no_abbrev': as above, after dropping every mention for
            which should_remove_mention() is true. None unless
            ``exclude_abbreviations`` is set.

    Args:
        results_file: Path to the original results JSON file
        validated_mentions: List of mentions that were validated as true positives
        exclude_abbreviations: Whether to compute the abbreviation-free scenario

    Returns:
        Dict mapping each scenario name to the dict produced by calculate_metrics
        (or None for the skipped third scenario).
    """
    with open(results_file, 'r') as handle:
        payload = json.load(handle)

    validated_lookup = {mention.lower() for mention in validated_mentions}

    # Running tp/fp/fn totals per scenario.
    scenario_names = ('original', 'with_validated', 'with_validated_no_abbrev')
    totals = {name: {'tp': 0, 'fp': 0, 'fn': 0} for name in scenario_names}

    def _accumulate(name, tp, fp, fn):
        bucket = totals[name]
        bucket['tp'] += tp
        bucket['fp'] += fp
        bucket['fn'] += fn

    def _without_removed(mentions):
        return {m for m in mentions if not should_remove_mention(m)}

    for doc_metrics in payload['evaluation_results']['document_metrics'].values():
        analysis = doc_metrics['analysis']

        correct = set(analysis['correct_mentions'])
        spurious = set(analysis['spurious_mentions'])
        missed = set(analysis['missed_mentions'])

        # 1. Counts exactly as evaluated.
        _accumulate('original', len(correct), len(spurious), len(missed))

        # 2. Validated spurious mentions move from FP to TP.
        validated_spurious = {m for m in spurious if m.lower() in validated_lookup}
        remaining_spurious = spurious - validated_spurious
        _accumulate(
            'with_validated',
            len(correct) + len(validated_spurious),
            len(remaining_spurious),
            len(missed),
        )

        # 3. Same as 2, after filtering out removable mentions everywhere.
        if not exclude_abbreviations:
            continue
        kept_correct = _without_removed(correct)
        kept_spurious = _without_removed(remaining_spurious)
        kept_validated = _without_removed(validated_spurious)
        kept_missed = _without_removed(missed)
        _accumulate(
            'with_validated_no_abbrev',
            len(kept_correct) + len(kept_validated),
            len(kept_spurious),
            len(kept_missed),
        )

    def _finalize(name):
        bucket = totals[name]
        return calculate_metrics(bucket['tp'], bucket['fp'], bucket['fn'])

    summary = {
        'original': _finalize('original'),
        'with_validated': _finalize('with_validated'),
    }
    if exclude_abbreviations:
        summary['with_validated_no_abbrev'] = _finalize('with_validated_no_abbrev')
    else:
        summary['with_validated_no_abbrev'] = None
    return summary

# Example usage: mentions that were manually validated as true positives.
# They are matched case-insensitively against each document's spurious mentions.
validated_mentions = [
    'papillary carcinoma',
    'glioblastoma multiforme',
    'transitional cell carcinoma',
    'multifocal atrial tachycardia (mat)',
    'sarcoidosis',
    'methemoglobinemia',
    'central nervous system and systemic lymphoma',
    'sclerosis cholangitis',  # corrected from 'sclerosing cholangitis'
    'mediastinitis',
    'mesenteric vein thrombosis',
    'multiple myeloma',
    'hepatocellular carcinoma',  # corrected from 'hepatocellular ca'
    'primary cns lymphoma',
    "bechet's disease",
    'neovascular glaucoma',
    'meningocele',
    'alopecia',
    'neovascular glaucoma angle closure',
    'pyoderma gangrenosum',
    'budd-chiari',
    'intraductal papillary mucinous tumor',
    'complex tracheal stenosis',
    'cervical stenosis',
    'bronchiectasis',
    'medullary sponge kidney',
    'protein s',
    'antiphospholipid antibody syndrome',
    'protein c',
    'acute myelogenous leukemia',
    'anaplastic thyroid carcinoma',
    'thymoma',
    'congenital bleeding disorder',
    'tracheal stenosis'
]

# Evaluation results JSON to re-score.
results_file = "data/results/rd_results_llama3_70b.json"

# Compute all three scenarios; exclude_abbreviations=True is required for the
# 'with_validated_no_abbrev' entry to be populated (it is None otherwise).
metrics = recalculate_metrics_with_improvements(
    results_file, 
    validated_mentions,
    exclude_abbreviations=True
)

def print_metrics(metrics_dict: Dict, title: str):
    """
    Print one metrics dictionary under a title.

    Precision, recall and F1 are shown with three decimals, followed by the raw
    true-positive, false-positive and false-negative counts.
    """
    print(f"\n{title}:")

    score_rows = (("Precision", 'precision'), ("Recall", 'recall'), ("F1", 'f1'))
    for label, key in score_rows:
        print(f"{label}: {metrics_dict[key]:.3f}")

    count_rows = (
        ("True Positives", 'true_positives'),
        ("False Positives", 'false_positives'),
        ("False Negatives", 'false_negatives'),
    )
    for label, key in count_rows:
        print(f"{label}: {metrics_dict[key]}")

# Full report for each of the three scenarios.
print_metrics(metrics['original'], "Original Metrics")
print_metrics(metrics['with_validated'], "Metrics with Validated Mentions")
print_metrics(metrics['with_validated_no_abbrev'], "Metrics with Validated Mentions + No Abbreviations")

# Score deltas between consecutive scenarios; generator expressions keep the
# helper names out of the notebook's global namespace.
print("\nImprovements:")
print("\n1. Impact of Validating Mentions:")
print(
    *(
        f"{label}: {metrics['with_validated'][field] - metrics['original'][field]:.3f}"
        for label, field in (("Precision", 'precision'), ("Recall", 'recall'), ("F1", 'f1'))
    ),
    sep="\n",
)

print("\n2. Additional Impact of Excluding Abbreviations:")
print(
    *(
        f"{label}: {metrics['with_validated_no_abbrev'][field] - metrics['with_validated'][field]:.3f}"
        for label, field in (("Precision", 'precision'), ("Recall", 'recall'), ("F1", 'f1'))
    ),
    sep="\n",
)


Original Metrics:
Precision: 0.347
Recall: 0.316
F1: 0.331
True Positives: 50
False Positives: 94
False Negatives: 108

Metrics with Validated Mentions:
Precision: 0.590
Recall: 0.440
F1: 0.504
True Positives: 85
False Positives: 59
False Negatives: 108

Metrics with Validated Mentions + No Abbreviations:
Precision: 0.590
Recall: 0.475
F1: 0.526
True Positives: 85
False Positives: 59
False Negatives: 94

Improvements:

1. Impact of Validating Mentions:
Precision: 0.243
Recall: 0.124
F1: 0.173

2. Additional Impact of Excluding Abbreviations:
Precision: 0.000
Recall: 0.034
F1: 0.022


# Printing the correct, spurious, and commonly missed mentions from the evaluation results.

In [7]:
import json
from typing import Dict, Set, Tuple

def analyze_mentions(results: Dict) -> Tuple[Set[str], Set[str], Set[str]]:
    """
    Collect the distinct mentions of every outcome across all documents.

    Prints the top-level keys of ``results`` and then merges the
    'correct_mentions', 'spurious_mentions' and 'missed_mentions' lists of each
    document's 'analysis' entry. Documents without an 'analysis' entry are
    skipped, and a missing list counts as empty.

    Args:
        results (dict): Dictionary containing evaluation results

    Returns:
        tuple: Sets of correct, spurious, and missed mentions
    """
    collected = {
        'correct_mentions': set(),
        'spurious_mentions': set(),
        'missed_mentions': set(),
    }
    print(results.keys())

    per_document = results["evaluation_results"]["document_metrics"]
    for doc_results in per_document.values():
        if 'analysis' not in doc_results:
            continue
        analysis = doc_results['analysis']
        for field, bucket in collected.items():
            bucket.update(analysis.get(field, []))

    return (
        collected['correct_mentions'],
        collected['spurious_mentions'],
        collected['missed_mentions'],
    )

def analyze_results_file(file_path: str) -> Tuple[Set[str], Set[str], Set[str]]:
    """
    Read a results JSON file and return its aggregated mention sets.

    Args:
        file_path (str): Path to the results JSON file

    Returns:
        tuple: Sets of correct, spurious, and missed mentions, as produced by
        analyze_mentions
    """
    with open(file_path, 'r') as handle:
        loaded_results = json.load(handle)

    # Aggregation is delegated to analyze_mentions.
    mention_sets = analyze_mentions(loaded_results)
    return mention_sets

# Notebook example: load one results file and dump the three mention sets,
# each under its own banner line.
file_path = "results/results_llama3_70b.json"
correct, spurious, missed = analyze_results_file(file_path)

print("------Correct-----", correct, sep="\n")
print("------Spurious-----", spurious, sep="\n")
print("------Missed-----", missed, sep="\n")

dict_keys(['evaluation_results', 'args', 'device', 'processing_stats'])
------Correct-----
{'arachnoid cyst', 'pulmonary arterial hypertension', 'dilated cardiomyopathy', 'essential thrombocythemia', 'legionella', 'heparin-induced thrombocytopenia', 'babesiosis', 'retinopathy of prematurity', 'sarcoidosis', 'budd chiari syndrome', 'autoimmune pancreatitis', 'primary adrenal insufficiency', 'essential thrombocytosis', 'multifocal atrial tachycardia', 'nocardiosis', 'heparin induced thrombocytopenia', 'necrotizing enterocolitis', 'post polio syndrome', 'hypogammaglobulinemia', 'alport syndrome', 'autoimmune hepatitis', 'beta-thalassemia', 'budd-chiari syndrome', 'protein c deficiency', 'sick sinus syndrome', 'crest syndrome', 'bullous pemphigoid', 'hemochromatosis', 'primary sclerosing cholangitis', 'familial mediterranean fever', 'alkaptonuria', 'retinitis pigmentosa', 'central retinal vein occlusion', 'calciphylaxis', 'amyotrophic lateral sclerosis'}
------Spurious-----
{'fibrous dyspl

# Ontology Structure Printing

In [1]:
import json
from collections import defaultdict
from typing import Dict, Any, Set, List
from pathlib import Path

class RareDiseaseDataInspector:
    """Inspector class for analyzing rare disease data files."""
    
    def __init__(self, triples_path: str, ontology_path: str):
        """
        Initialize with paths to both data files.
        
        Args:
            triples_path: Path to RareDisease_Phenotype_Triples.json
            ontology_path: Path to rare_disease_ontology.jsonl
        """
        self.triples_path = Path(triples_path)
        self.ontology_path = Path(ontology_path)
        
    def _load_jsonl(self, path: Path) -> List[Dict]:
        """Load JSONL file line by line."""
        data = []
        with open(path) as f:
            for line in f:
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"Error decoding line in {path}: {e}")
        return data
    
    def _analyze_structure(self, data: Any, prefix: str = "") -> Dict:
        """
        Recursively analyze the structure of nested data.
        
        Args:
            data: Data structure to analyze
            prefix: Current key prefix for nested structures
            
        Returns:
            Dict containing structure analysis
        """
        structure = {
            "type": type(data).__name__,
            "sample": str(data)[:100] if isinstance(data, (str, bytes)) else None,
            "length": len(data) if hasattr(data, "__len__") else None
        }
        
        if isinstance(data, dict):
            structure["keys"] = {
                k: self._analyze_structure(v, f"{prefix}.{k}" if prefix else k)
                for k, v in (data.items() if len(data) < 5 else list(data.items())[:5])
            }
        elif isinstance(data, list) and data:
            structure["element_sample"] = self._analyze_structure(data[0])
            
        return structure
    
    def _count_field_occurrences(self, data: List[Dict]) -> Dict[str, int]:
        """Count how often each field appears across all entries."""
        field_counts = defaultdict(int)
        total_entries = len(data)
        
        for entry in data:
            if isinstance(entry, dict):
                for key in entry.keys():
                    field_counts[key] += 1
        
        # Convert to percentages
        return {
            key: (count / total_entries * 100)
            for key, count in field_counts.items()
        }
    
    def inspect_files(self) -> None:
        """Inspect and print analysis of both data files."""
        # Analyze triples file
        print("\n=== RareDisease_Phenotype_Triples.json Analysis ===")
        try:
            with open(self.triples_path) as f:
                triples_data = json.load(f)
            
            print("\nStructure:")
            structure = self._analyze_structure(triples_data)
            self._print_structure(structure)
            
            if isinstance(triples_data, list):
                print(f"\nTotal entries: {len(triples_data)}")
                if triples_data:
                    field_stats = self._count_field_occurrences(triples_data)
                    print("\nField coverage:")
                    for field, percentage in sorted(field_stats.items()):
                        print(f"  {field}: {percentage:.1f}%")
        except Exception as e:
            print(f"Error analyzing triples file: {e}")
        
        # Analyze ontology file
        print("\n=== rare_disease_ontology.jsonl Analysis ===")
        try:
            ontology_data = self._load_jsonl(self.ontology_path)
            
            print(f"\nTotal entries: {len(ontology_data)}")
            if ontology_data:
                print("\nStructure of first entry:")
                structure = self._analyze_structure(ontology_data[0])
                self._print_structure(structure)
                
                field_stats = self._count_field_occurrences(ontology_data)
                print("\nField coverage:")
                for field, percentage in sorted(field_stats.items()):
                    print(f"  {field}: {percentage:.1f}%")
                
                # Analyze ORPHA IDs
                orpha_ids = [entry.get('id') for entry in ontology_data if 'id' in entry]
                malformed_ids = [id for id in orpha_ids if not str(id).startswith('ORPHA:')]
                print(f"\nORPHA ID analysis:")
                print(f"  Total IDs: {len(orpha_ids)}")
                print(f"  Malformed IDs: {len(malformed_ids)}")
                if malformed_ids:
                    print(f"  Sample malformed ID: {malformed_ids[0]}")
        except Exception as e:
            print(f"Error analyzing ontology file: {e}")
    
    def _print_structure(self, structure: Dict, indent: int = 2) -> None:
        """Pretty print the structure analysis."""
        def _print_level(data: Dict, level: int = 0):
            prefix = " " * (level * indent)
            
            if "type" in data:
                print(f"{prefix}Type: {data['type']}")
                if data.get("length") is not None:
                    print(f"{prefix}Length: {data['length']}")
                if data.get("sample"):
                    print(f"{prefix}Sample: {data['sample']}")
            
            if "keys" in data:
                print(f"{prefix}Keys:")
                for key, value in data["keys"].items():
                    print(f"{prefix}{key}:")
                    _print_level(value, level + 1)
            
            if "element_sample" in data:
                print(f"{prefix}Element sample:")
                _print_level(data["element_sample"], level + 1)
        
        _print_level(structure)

def inspect_rare_disease_files(triples_path: str, ontology_path: str) -> None:
    """
    Build a RareDiseaseDataInspector for the two files and print its report.

    Args:
        triples_path: Path to RareDisease_Phenotype_Triples.json
        ontology_path: Path to rare_disease_ontology.jsonl
    """
    RareDiseaseDataInspector(triples_path, ontology_path).inspect_files()

# Example usage:
# Inside a notebook kernel __name__ is "__main__", so this guard runs as soon as
# the cell is executed. Both arguments are absolute paths on the original
# author's machine and must be adjusted before running elsewhere.
if __name__ == "__main__":
    inspect_rare_disease_files(
        "/home/johnwu3/projects/rare_disease/workspace/ontology/RareDisease_Phenotype_Triples.json",
        "/home/johnwu3/projects/rare_disease/workspace/repos/AutoRD/data/rare_disease_ontology.jsonl"
    )


=== RareDisease_Phenotype_Triples.json Analysis ===

Structure:
Type: list
Length: 114994
Element sample:
  Type: dict
  Length: 3
  Keys:
  source:
    Type: dict
    Length: 3
    Keys:
    id:
      Type: str
      Length: 9
      Sample: Orpha:324
    name:
      Type: list
      Length: 6
      Element sample:
        Type: str
        Length: 13
        Sample: Fabry disease
    definition:
      Type: str
      Length: 400
      Sample: A rare genetic, multisystemic lysosomal disease characterized by specific cutaneous (angiokeratoma),
  link:
    Type: dict
    Length: 1
    Keys:
    frequency:
      Type: str
      Length: 22
      Sample: Very frequent (99-80%)
  target:
    Type: dict
    Length: 3
    Keys:
    id:
      Type: str
      Length: 10
      Sample: HP:0001635
    name:
      Type: list
      Length: 7
      Element sample:
        Type: str
        Length: 3
        Sample: CHF
    definition:
      Type: str
      Length: 344
      Sample: The presence of an

# Fixing the SimpleAutoRD pipeline

In [93]:
import json
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
import torch
from difflib import get_close_matches
import re

@dataclass
class Entity:
    """One extracted mention together with its optional ontology link."""
    name: str  # Original extracted term
    start: int  # Start offset of the matched span (taken from a regex match)
    end: int  # End offset of the matched span (exclusive, regex match end)
    entity_type: str  # Category label; the pipeline in this file uses 'rare_disease'
    negated: bool = False  # Defaults to False; the visible pipeline code never sets it
    orpha_id: Optional[str] = None  # Ontology ID of the linked entry, None when unlinked
    confidence: float = 1.0  # Defaults to 1.0; the visible pipeline code never sets it
    extracted_phrase: str = ""  # The exact phrase from text
    matched_orpha_name: Optional[str] = None  # The name from Orphanet if matched

class SimpleAutoRD:
    def __init__(
        self,
        pipeline: Any,
        rare_disease_ontology_path: str,
    ):
        self.pipeline = pipeline
        self.rare_disease_ontology = self._load_jsonl(rare_disease_ontology_path)
        # Create a mapping of disease names to their ontology entries
        self.disease_mapping = {}
        for entry in self.rare_disease_ontology:
            # Keep the original ID format but also add a normalized version
            normalized_entry = {
                'name': entry['name'],
                'id': entry['id'],  # Keep original format
                'orpha_id': entry['id'],  # Add normalized version
                'definition': entry.get('definition', ''),
            }
            self.disease_mapping[entry['name'].lower()] = normalized_entry

    def inspect_ontology(self) -> None:
        """Print the structure and statistics of the loaded ontology file."""
        print("\nOntology Structure Analysis:")
        print("-" * 50)
        
        # Basic statistics
        print(f"Total entries: {len(self.rare_disease_ontology)}")
        
        # Analyze structure of first few entries
        if self.rare_disease_ontology:
            print("\nSample entry structure:")
            sample_entry = self.rare_disease_ontology[0]
            for key, value in sample_entry.items():
                print(f"  {key}: {type(value).__name__} = {value[:100] if isinstance(value, str) else value}")
        
        # Analyze key consistency
        all_keys = set()
        key_counts = {}
        
        for entry in self.rare_disease_ontology:
            entry_keys = set(entry.keys())
            all_keys.update(entry_keys)
            
            # Count how many entries have each key
            for key in entry_keys:
                key_counts[key] = key_counts.get(key, 0) + 1
        
        print("\nKey statistics:")
        total_entries = len(self.rare_disease_ontology)
        for key in sorted(all_keys):
            count = key_counts[key]
            percentage = (count / total_entries) * 100
            print(f"  {key}: present in {count}/{total_entries} entries ({percentage:.1f}%)")
        
        # Check for potential issues
        print("\nPotential issues:")
        required_keys = {'name', 'id'}  # Updated required keys
        missing_required = [entry for entry in self.rare_disease_ontology 
                        if not all(key in entry for key in required_keys)]
        
        if missing_required:
            print(f"- {len(missing_required)} entries missing required keys")
            print("  Sample problematic entry:")
            print(f"  {missing_required[0]}")
        
        # Value analysis for critical fields
        print("\nValue analysis:")
        print("Name statistics:")
        empty_names = sum(1 for entry in self.rare_disease_ontology if not entry.get('name'))
        print(f"- Empty names: {empty_names}")
        
        print("\nORPHA ID statistics:")
        orpha_ids = [entry['id'] for entry in self.rare_disease_ontology if 'id' in entry]
        malformed_ids = [id for id in orpha_ids if not id.startswith('ORPHA:')]
        print(f"- Total ORPHA IDs: {len(orpha_ids)}")
        print(f"- Malformed ORPHA IDs: {len(malformed_ids)}")
        if malformed_ids:
            print(f"  Sample malformed ID: {malformed_ids[0]}")
            
        print("\nDefinition statistics:")
        has_definition = sum(1 for entry in self.rare_disease_ontology if entry.get('definition'))
        print(f"- Entries with definitions: {has_definition}/{total_entries} ({(has_definition/total_entries)*100:.1f}%)")
        
    def _load_jsonl(self, path: str) -> List[Dict]:
        with open(path) as f:
            return [json.loads(line) for line in f]

    def _generate_text(self, prompt: str) -> str:
        """Generate text using the pipeline with chat template."""
        messages = [
            {
                "role": "system", 
                "content": "You are an expert in healthcare and biomedical domain. Extract medical entities accurately."
            },
            {
                "role": "user", 
                "content": prompt
            },
        ]
        
        full_prompt = self.pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        outputs = self.pipeline(
            full_prompt,
            max_new_tokens=20000, # 20k new tokens to see full response in case it's that long.
            do_sample=True,
            temperature=0.1,
            top_p=0.9,
        )
        
        return outputs[0]["generated_text"][len(full_prompt):]

    def _extract_terms(self, text: str) -> List[str]:
        """Extract medical terms from text, returning a simple list of terms."""
        prompt = f"""Extract all diseases and conditions that are NOT negated (i.e., don't include terms that are preceded by 'no', 'not', 'without', etc.) from the text below.

            Text: {text}

            Return only a Python list of strings, with each term exactly as it appears in the text."""
        
        try:
            response = self._generate_text(prompt).strip()
            # print("LLM DEBUG:", response)
            
            # Extract content between square brackets if present
            if '[' in response and ']' in response:
                response = response[response.find('[') + 1:response.rfind(']')]
            
            # Split on commas and clean up each term
            terms = []
            for term in response.split(','):
                # Clean up the term
                cleaned_term = term.strip()
                # Remove surrounding quotes (single or double)
                cleaned_term = cleaned_term.strip('"').strip("'")
                # Only add non-empty terms
                if cleaned_term:
                    terms.append(cleaned_term)
            
            # print("TERMS DEBUG:", terms)
            return terms
                
        except Exception as e:
            print(f"Error in term extraction: {str(e)}")
            return []

    def _string_similarity(self, s1: str, s2: str) -> float:
        """Calculate string similarity score using a combination of metrics."""
        from difflib import SequenceMatcher
        
        # Convert to lowercase for comparison
        s1, s2 = s1.lower(), s2.lower()
        
        # Get word sets
        words1 = set(s1.split())
        words2 = set(s2.split())
        
        # Calculate Jaccard similarity for words
        word_similarity = len(words1.intersection(words2)) / len(words1.union(words2))
        
        # Calculate sequence similarity
        seq_similarity = SequenceMatcher(None, s1, s2).ratio()
        
        # Combine both metrics (weighing sequence similarity more heavily)
        return (0.3 * word_similarity + 0.7 * seq_similarity)

    def _verify_match_with_llm(self, term: str, candidates: List[Tuple[str, float]]) -> Optional[str]:
        """Use LLM to verify the best match among candidates."""
        if not candidates:
            return None
            
        candidates_str = "\n".join(f"- {name} (similarity: {score:.2f})" 
                                 for name, score in candidates)
        
        prompt = f"""Given a medical term and potential matches from our ontology, determine if any are valid matches.

Term: {term}

Potential matches:
{candidates_str}

Respond with only the name of the best matching term, or "none" if none are valid matches."""

        try:
            response = self._generate_text(prompt).strip().lower()
            if response == "none":
                return None
            # Find the candidate that most closely matches the LLM's response
            matches = get_close_matches(response, [c[0] for c in candidates], n=1, cutoff=0.9)
            return matches[0] if matches else None
        except Exception:
            return candidates[0][0] if candidates else None
        
    def _split_text(self, text: str, max_chunk_size: int = 4000) -> List[str]:
        """Split text into chunks, ensuring splits occur at delimiters.
        
        Args:
            text: Input text to split
            max_chunk_size: Maximum size of each chunk before forcing a split
            
        Returns:
            List of text chunks
        """
        chunks = []
        current_chunk = ""
        
        # Define preferred split points in order of preference
        delimiters = ["\n\n", "\n", ". ", "; ", ", "]
        
        # Helper function to find the best delimiter position
        def find_split_point(text: str, max_size: int) -> int:
            if len(text) <= max_size:
                return len(text)
                
            # Try each delimiter in order of preference
            for delimiter in delimiters:
                # Look for the last delimiter before max_size
                pos = text[:max_size].rfind(delimiter)
                if pos > 0:  # Found a good split point
                    return pos + len(delimiter)
                    
            # If no good delimiter found, force split at max_size
            return max_size
        
        while text:
            split_point = find_split_point(text, max_chunk_size)
            
            if split_point == len(text):
                chunks.append(text)
                break
                
            # Add chunk and continue with remaining text
            chunks.append(text[:split_point].strip())
            text = text[split_point:].strip()
        
        return [c for c in chunks if c]  # Remove any empty chunks
        
    def process_text(self, text: str) -> List[Entity]:
        """Process text through the pipeline."""
        # Split text into natural chunks
        chunks = self._split_text(text)
        
        # Process each chunk and combine results
        all_entities = []
        for chunk in chunks:
            try:
                chunk_entities = self._process_chunk(chunk)
                all_entities.extend(chunk_entities)
            except Exception as e:
                print(f"Error processing chunk: {str(e)}")
                continue
                
        return all_entities
    
    def _fuzzy_match_disease(self, term: str, threshold: float = 0.6) -> List[Tuple[str, float]]:
        """Find closest matching diseases in ontology using fuzzy matching.
        
        Returns:
            List of tuples (disease_name, similarity_score) above threshold
        """
        term = term.lower()
        matches = []
        
        # First try exact word set matching
        term_words = set(term.split())
        
        for disease_name in self.disease_mapping.keys():
            # Calculate similarity score
            similarity = self._string_similarity(term, disease_name)
            if similarity >= threshold:
                matches.append((disease_name, similarity))
        
        # Sort by similarity score
        matches.sort(key=lambda x: x[1], reverse=True)
        return matches[:5]  # Return top 5 matches


    def _verify_rare_disease_with_llm(self, term: str, ontology_entry: Optional[Dict] = None) -> Tuple[bool, float]:
        """Verify if the term is actually a rare disease using LLM.
        
        Args:
            term: The extracted term to verify
            ontology_entry: Optional ontology entry if available
            
        Returns:
            Tuple[bool, float]: (is_rare_disease, confidence)
        """
        context = ""
        if ontology_entry:
            context = f"\nOrphanet information:\nDisease: {ontology_entry['name']}\nOrpha ID: {ontology_entry['orpha_id']}"
        
        prompt = f"""Determine if the following medical term represents a rare disease.

        Term: {term}{context}

        Consider:
        1. Is this a disease (not just a symptom or condition)?
        2. Is it rare (affecting less than 1 in 2000 people)?
        3. If an Orphanet entry is provided, does the term actually match it?
        4. If the orphanet diseases have rare in its name, make sure that the extracted term contains the word rare for it to be considered a rare disease.
        5. Specific forms/variants of common diseases are not rare diseases unless explicitly stated as rare in the term and context.
        
        Respond with only one line in this format:
        DECISION: true/false

        Example responses:
        DECISION: true
        DECISION: false"""

        try:
            response = self._generate_text(prompt).strip().lower()
            # Extract decision
            match = re.search(r'decision:\s*(true|false)', response)
            if match:
                return match.group(1) == 'true'
            return False
        except Exception as e:
            print(f"Error in LLM verification: {str(e)}")
            return False

    def _process_chunk(self, text: str) -> List[Entity]:
        """Process a single chunk of text.

        Steps: extract terms with the LLM, deduplicate them case-insensitively,
        locate each term in ``text``, link it to the ontology (exact name
        lookup first, fuzzy matching otherwise), ask the LLM whether it is a
        rare disease, and emit one Entity per occurrence of each confirmed
        term.

        Args:
            text: The chunk to process; entity ``start``/``end`` values are
                regex match positions within this string

        Returns:
            Entities for confirmed terms. A failure on a single term is printed
            and that term is skipped; a failure outside the per-term loop is
            printed and an empty list is returned.
        """
        try:
            extracted_terms = self._extract_terms(text)
            
            # Deduplicate terms while preserving order
            seen = set()
            unique_terms = []
            for term in extracted_terms:
                if term.lower() not in seen:
                    seen.add(term.lower())
                    unique_terms.append(term)
            
            entities = []
            for term in unique_terms:
                try:
                    # Find the exact phrase in text
                    # NOTE(review): the pattern has no word boundaries, so a
                    # short term also matches inside longer words.
                    term_pattern = re.escape(term)
                    matches = list(re.finditer(term_pattern, text, re.IGNORECASE))
                    if not matches:
                        # The LLM returned a term that does not occur in the chunk.
                        continue
                        
                    # First try exact matching
                    ontology_entry = self.disease_mapping.get(term.lower())
                    matched_name = None
                    
                    if not ontology_entry:
                        # If no exact match, try fuzzy matching
                        candidates = self._fuzzy_match_disease(term)
                        if candidates:
                            # Only verify with LLM if we have close matches
                            if candidates[0][1] > 0.9:  # High confidence match
                                best_match = candidates[0][0]
                                ontology_entry = self.disease_mapping[best_match]
                                matched_name = best_match
                            else:
                                # Verify with LLM for less confident matches
                                best_match = self._verify_match_with_llm(term, candidates)
                                if best_match:
                                    ontology_entry = self.disease_mapping[best_match]
                                    matched_name = best_match
                    
                    # Verify if it's actually a rare disease using LLM
                    # This runs for every located term, even when no ontology
                    # entry was found (ontology_entry is then None).
                    is_rare = self._verify_rare_disease_with_llm(term, ontology_entry)
                    
                    # Only create entity if LLM confirms it's a rare disease
                    if is_rare:
                        # One entity per occurrence of the term in the chunk.
                        for match in matches:
                            entity = Entity(
                                name=term,
                                start=match.start(),
                                end=match.end(),
                                entity_type='rare_disease',
                                orpha_id=ontology_entry.get('orpha_id') if ontology_entry else None,
                                extracted_phrase=match.group(),
                                matched_orpha_name=matched_name or (ontology_entry['name'] if ontology_entry else None)
                            )
                            entities.append(entity)
                            
                except Exception as e:
                    print(f"Error processing term '{term}': {str(e)}")
                    continue
                    
            return entities
        except Exception as e:
            print(f"Error in _process_chunk: {str(e)}")
            return []
        
        

In [94]:
# Load the LLM pipeline that SimpleAutoRD uses for every generation call.
# ModelLoader is a project-local helper; the cache directory and CUDA device
# below are machine-specific and need adjusting on other hosts.
from extraction.llm import ModelLoader
cache_dir = "/u/zelalae2/scratch/rdma_cache"
print(f"Initializing ModelLoader with cache directory: {cache_dir}")
model_loader = ModelLoader(cache_dir=cache_dir)
model = "llama3_70b"
print(f"Loading {model} model...")
device = "cuda:1"
pipeline = model_loader.get_llm_pipeline(device, model)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Initializing ModelLoader with cache directory: /shared/rsaas/jw3/rare_disease/model_cache
Initialized ModelLoader with cache directory: /shared/rsaas/jw3/rare_disease/model_cache
Loading llama3_70b model...
Loading LLM!
Loading 70B model: llama3_70b
Generated cache path: /shared/rsaas/jw3/rare_disease/model_cache/Llama-3.3-70B-Instruct_4bit_nf4
Valid cache found at /shared/rsaas/jw3/rare_disease/model_cache/Llama-3.3-70B-Instruct_4bit_nf4
Loading cached quantized model from /shared/rsaas/jw3/rare_disease/model_cache/Llama-3.3-70B-Instruct_4bit_nf4


Loading checkpoint shards:  12%|█▎        | 1/8 [00:01<00:12,  1.80s/it]


KeyboardInterrupt: 

In [8]:
def load_mimic_annotations(file_path: str) -> Dict:
    """Load and parse the MIMIC annotations file.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.

    Args:
        file_path: Path to the annotations JSON file

    Returns:
        The parsed JSON content

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
    
# Absolute path on the original author's machine; adjust before running elsewhere.
annotations_file = "/home/johnwu3/projects/rare_disease/workspace/repos/RareDiseaseMention/filtered_rd_annos.json"
annotations = load_mimic_annotations(annotations_file)


dict_keys(['287', '869', '1423', '1208', '950', '977', '1552', '1284', '1790', '2452', '2541', '2795', '3060', '3149', '3390', '4226', '4043', '4735', '4806', '4938', '5873', '10406', '5738', '6465', '5811', '7363', '13231', '7688', '6188', '8979', '7699', '10859', '6206', '10004', '10715', '7949', '8877', '9512', '6819', '8001', '8960', '7519', '13600', '11238', '16328', '16330', '16334', '16347', '13666', '14936', '11052', '11938', '11838', '11996', '11604', '18299', '22297', '17850', '15908', '17640', '15693', '20208', '31392', '24716', '30442', '31210', '24774', '31762', '29858', '20162', '32721', '25095', '20589', '20617', '23916', '22133', '20387', '21573', '21179', '21197', '39676', '38775', '37158', '37801', '39348', '40555', '37217', '37240', '40196', '25829', '28965', '26976', '26123', '27335', '29045', '29048', '26825', '30504', '32261', '28690', '41498', '47026', '47035', '45138', '42584', '42597', '44132', '44150', '44472', '46706', '45078', '46361', '46820', '46636', '329

In [95]:
# Run the pipeline on one annotated note (key '287') and collect its gold
# mention strings so they can be compared with the extracted entities.
# print(annotations['287'])
auto_rd = SimpleAutoRD(pipeline, "data_preprocessing/data/rare_disease_ontology.jsonl")
test_text = annotations['287']['note_details']['text']
test_annotations = annotations['287']['annotations']
# rare_diseases = test_annotations.keys()
print(test_annotations)
# Gold mentions: the 'mention' field of every annotation on this note.
mentions = []
for annotation in test_annotations:
    mention = annotation['mention']
    mentions.append(mention)

entities = auto_rd.process_text(test_text)     

[{'mention': 'sick sinus syndrome', 'umls_with_desc': 'C0037052 Sick Sinus Syndrome', 'ordo_with_desc': 'Orphanet_166282  Familial sick sinus syndrome', 'gold_text_to_umls_label': 1, 'gold_text_to_ordo_label': 0, 'document_structure': 'History_of_Past_Illness', 'semehr_label': 1}]


In [58]:
auto_rd.inspect_ontology()


Ontology Structure Analysis:
--------------------------------------------------
Total entries: 14588

Sample entry structure:
  id: str = ORPHA:10
  name: str = 48,XXYY syndrome
  definition: str = A rare sex chromosome number anomaly disorder characterized, genetically, by the presence of an extr

Key statistics:
  definition: present in 14588/14588 entries (100.0%)
  id: present in 14588/14588 entries (100.0%)
  name: present in 14588/14588 entries (100.0%)

Potential issues:
- 14588 entries missing required keys
  Sample problematic entry:
  {'id': 'ORPHA:10', 'name': '48,XXYY syndrome', 'definition': 'A rare sex chromosome number anomaly disorder characterized, genetically, by the presence of an extra X and Y chromosome in males and, clinically, by tall stature, dysfunctional testes associated with infertility and insufficient testosterone production, cognitive, affective and social functioning impairments, global developmental delay, and an increased risk of congenital malformati

In [96]:
def get_rare_diseases(entities) -> List[Entity]:
    """
    Keep only the entities that are linked rare diseases.

    Args:
        entities: Iterable of already-extracted entities (no text is
            processed here)

    Returns:
        List[Entity]: Entities whose entity_type is 'rare_disease' and whose
        orpha_id is not None, in their original order
    """
    kept = []
    for candidate in entities:
        if candidate.entity_type != 'rare_disease':
            continue
        if candidate.orpha_id is None:
            # Confirmed by the LLM but never linked to an ontology entry.
            continue
        kept.append(candidate)
    return kept
# Compare the linked rare-disease entities with the gold mentions of the note.
rds = get_rare_diseases(entities)
for rd in rds:
    # NOTE(review): the result of this membership test is discarded; nothing
    # is printed or stored for it.
    rd.extracted_phrase in mentions
print(rds)
print(mentions)


[Entity(name='sick sinus syndrome', start=3609, end=3628, entity_type='rare_disease', negated=False, orpha_id='ORPHA:166282', confidence=1.0, extracted_phrase='sick sinus syndrome', matched_orpha_name='familial sick sinus syndrome')]
['sick sinus syndrome']


In [6]:


auto_rd = SimpleAutoRD(pipeline, "data_preprocessing/data/rare_disease_ontology.jsonl")
print("\n--- Test Cases for _extract_terms ---")


def test_extract_terms(text, expected_terms, test_name):
    """Run ``auto_rd._extract_terms`` on ``text`` and require an exact match.

    Prints the input, the expected list and the extracted list, then asserts
    that the two lists are equal (same terms, same order).
    """
    print(f"\nRunning test: {test_name}")
    actual_terms = auto_rd._extract_terms(text)
    for report_line in (
        f"  Text: '{text}'",
        f"  Expected Terms: {expected_terms}",
        f"  Extracted Terms: {actual_terms}",
    ):
        print(report_line)
    assert actual_terms == expected_terms, f"Test '{test_name}' failed. Expected {expected_terms}, but got {actual_terms}"
    print(f"  Test '{test_name}' passed!\n")


# NOTE(review): every case below compares the raw LLM output with an exact
# list, and _generate_text samples (do_sample=True), so results can vary
# between runs and models. The first failing assert stops the cell.

# Test case 1: Simple case with a few diseases
text1 = "The patient has type 2 diabetes and hypertension."
expected_terms1 = ["type 2 diabetes", "hypertension"]
test_extract_terms(text1, expected_terms1, "Simple diseases")

# Test case 2: Negated term
text2 = "The patient has no asthma, but does have COPD."
expected_terms2 = ["COPD"]
test_extract_terms(text2, expected_terms2, "Negated term")

# Test case 3: Symptoms and conditions
text3 = "Symptoms include fatigue, joint pain, and muscle weakness.  The condition is rheumatoid arthritis."
expected_terms3 = ["fatigue", "joint pain", "muscle weakness", "rheumatoid arthritis"]
test_extract_terms(text3, expected_terms3, "Symptoms and Conditions")

# Test case 4: Comma-separated list in a sentence
text4 = "The patient presented with fever, cough, and malaise."
expected_terms4 = ["fever", "cough", "malaise"] # LLM might return slightly different order, adjust expected if needed.
test_extract_terms(text4, expected_terms4, "Comma-separated list")

# Test case 5: No medical terms
text5 = "This is a normal sentence with no medical terms."
expected_terms5 = []
test_extract_terms(text5, expected_terms5, "No medical terms")

# Test case 6: Only negated term
text6 = "The patient has no fever and no cough."
expected_terms6 = []
test_extract_terms(text6, expected_terms6, "Only negated terms")

# # Test case 7: Empty text
# text7 = ""
# expected_terms7 = []
# test_extract_terms(text7, expected_terms7, "Empty text")

# Test case 8:  Terms with numbers and hyphens
text8 = "Patient has type 1 diabetes and pre-diabetes."
expected_terms8 = ["type 1 diabetes", "pre-diabetes"]
test_extract_terms(text8, expected_terms8, "Terms with numbers and hyphens")

print("--- _extract_terms tests completed ---")


--- Test Cases for _extract_terms ---

Running test: Simple diseases
Extracting terms from text...
Generating response from LLM...

Parsing LLM response...

Extracted 2 valid terms
Sample terms:
  - type 2 diabetes
  - hypertension
  Text: 'The patient has type 2 diabetes and hypertension.'
  Expected Terms: ['type 2 diabetes', 'hypertension']
  Extracted Terms: ['type 2 diabetes', 'hypertension']
  Test 'Simple diseases' passed!


Running test: Negated term
Extracting terms from text...
Generating response from LLM...

Parsing LLM response...

Extracted 1 valid terms
Sample terms:
  - COPD
  Text: 'The patient has no asthma, but does have COPD.'
  Expected Terms: ['COPD']
  Extracted Terms: ['COPD']
  Test 'Negated term' passed!


Running test: Symptoms and Conditions
Extracting terms from text...
Generating response from LLM...

Parsing LLM response...

Extracted 4 valid terms
Sample terms:
  - fatigue
  - joint pain
  - muscle weakness
  - rheumatoid arthritis
  Text: 'Symptoms inc

In [7]:
print("\n--- Test Cases for _fuzzy_match_disease ---")

def test_fuzzy_match_disease(term, expected_matches, test_name, threshold=0.6):
    """Run one fuzzy-match lookup and report whether it meets expectations.

    The lookup is delegated to ``auto_rd._fuzzy_match_disease``. The raw
    matches are printed, followed by the expectation and a PASS/FAIL verdict.
    A failed expectation is only reported, never raised, so the remaining
    cases in the cell still run.

    Args:
        term: Disease name to look up.
        expected_matches: Names expected among the results. An empty list
            means "no match at all is expected".
        test_name: Human-readable label printed with the result.
        threshold: Similarity threshold forwarded to the matcher.

    Returns:
        None. Results are reported on stdout.
    """
    print(f"\nRunning test: {test_name} (threshold={threshold})")
    matches = auto_rd._fuzzy_match_disease(term, threshold=threshold)
    print(matches)

    # Compare case-insensitively: the recorded output of this cell shows the
    # matcher returning lower-cased names while expectations are title-cased.
    found = {str(match).casefold() for match in (matches or [])}
    wanted = {str(name).casefold() for name in (expected_matches or [])}

    if wanted:
        # Containment rather than equality: extra near-matches are acceptable
        # as long as every expected name is present.
        missing = sorted(wanted - found)
        passed = not missing
    else:
        # An empty expectation means the term should not match anything.
        missing = []
        passed = not found

    print(f"  Expected Matches: {expected_matches}")
    if missing:
        print(f"  Missing from results: {missing}")
    print(f"  Result: {'PASS' if passed else 'FAIL'}")

# The expectations below assume "Spinocerebellar Ataxia Type 2" and
# "Spinocerebellar Ataxia Type 3" are present in the ontology.
#
# Each entry is (positional arguments, keyword overrides) for
# test_fuzzy_match_disease; the cases run in the order listed.
fuzzy_match_cases = [
    # Case 1: exact match
    (("Spinocerebellar Ataxia Type 2", ["Spinocerebellar Ataxia Type 2"], "Exact match"), {}),
    # Case 2: slightly fuzzy match (Roman numeral instead of a digit)
    (("Spinocerebellar Ataxia Type II", ["Spinocerebellar Ataxia Type 2"], "Slightly fuzzy - Roman numeral"), {}),
    # Case 3: fuzzier input containing a typo
    (("Spinocerebellar Ataxia typ 2", ["Spinocerebellar Ataxia Type 2"], "More fuzzy - typo"), {}),
    # Case 4: similar but different disease; not expected to match at a strict threshold...
    (("Spinocerebellar Ataxia Type 99", [], "No match - different type"), {"threshold": 0.9}),
    # ...but may match at a looser threshold, depending on the ontology contents.
    (
        ("Spinocerebellar Ataxia Type 99", ['Spinocerebellar Ataxia Type 2', 'Spinocerebellar Ataxia Type 3'], "Possible matches at lower threshold"),
        {"threshold": 0.5},
    ),
    # Case 5: unrelated term, no match expected
    (("Completely unrelated term", [], "No match - unrelated term"), {}),
]

for case_args, case_kwargs in fuzzy_match_cases:
    test_fuzzy_match_disease(*case_args, **case_kwargs)

print("--- _fuzzy_match_disease tests completed ---")


--- Test Cases for _fuzzy_match_disease ---

Running test: Exact match (threshold=0.6)
['spinocerebellar ataxia with axonal neuropathy type 2', 'spinocerebellar ataxia type 2']

Running test: Slightly fuzzy - Roman numeral (threshold=0.6)
['spinocerebellar ataxia type 8', 'spinocerebellar ataxia type 7', 'spinocerebellar ataxia type 6', 'spinocerebellar ataxia type 5', 'spinocerebellar ataxia type 4']

Running test: More fuzzy - typo (threshold=0.6)
['spinocerebellar ataxia type 2', 'spinocerebellar ataxia type 42', 'spinocerebellar ataxia type 32', 'spinocerebellar ataxia type 29', 'spinocerebellar ataxia type 28']

Running test: No match - different type (threshold=0.9)
['spinocerebellar ataxia type 49', 'spinocerebellar ataxia type 29', 'spinocerebellar ataxia type 8', 'spinocerebellar ataxia type 7', 'spinocerebellar ataxia type 6']

Running test: Possible matches at lower threshold (threshold=0.5)
['spinocerebellar ataxia type 49', 'spinocerebellar ataxia type 29', 'spinocerebell