In [2]:
import os
import sys
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)
# Set the parent directory as the current directory
os.chdir(parent_dir)

In [None]:
import json
import re
import difflib
import numpy as np
from collections import Counter
from typing import List, Dict, Any, Optional

def normalize_text(text):
    """
    Normalize text for more robust matching:
    - Convert to lowercase
    - Remove extra whitespaces
    - Remove punctuation
    """
    if not text:
        return ""
        
    # Convert to lowercase
    text = text.lower()
    
    # Remove extra whitespaces
    text = re.sub(r'\s+', ' ', text).strip()
    
    # Remove punctuation (but keep some meaningful characters)
    text = re.sub(r'[^\w\s°-]', ' ', text)
    
    return text

def is_direct_phenotype(phenotype, clinical_text, threshold=0.85):
    """
    Determine if a phenotype is directly mentioned in the text, accounting for minor variations.
    
    Args:
        phenotype: The phenotype name to check
        clinical_text: The clinical note text
        threshold: Similarity threshold for considering direct mention (higher = stricter)
        
    Returns:
        Boolean indicating if the phenotype is directly mentioned
    """
    norm_phenotype = normalize_text(phenotype)
    norm_text = normalize_text(clinical_text)
    
    # Check for exact match first
    if norm_phenotype in norm_text:
        return True
    
    # Check individual words from the text 
    text_words = norm_text.split()
    
    # For single-word phenotypes
    if len(norm_phenotype.split()) == 1:
        for word in text_words:
            # Use character-level similarity for single words
            similarity = difflib.SequenceMatcher(None, norm_phenotype, word).ratio()
            if similarity >= threshold:
                return True
    
    # For multi-word phenotypes, check n-grams in the text
    phenotype_words = norm_phenotype.split()
    n = len(phenotype_words)
    
    for i in range(len(text_words) - n + 1):
        text_ngram = " ".join(text_words[i:i+n])
        
        # Check similarity of the entire phrase
        similarity = difflib.SequenceMatcher(None, norm_phenotype, text_ngram).ratio()
        if similarity >= threshold:
            return True
    
    # Check for medical term variations
    variations = generate_medical_variations(phenotype)
    for variation in variations:
        norm_variation = normalize_text(variation)
        if norm_variation in norm_text:
            return True
        
        # Also check fuzzy matching for variations
        for i in range(len(text_words) - len(norm_variation.split()) + 1):
            text_ngram = " ".join(text_words[i:i+len(norm_variation.split())])
            similarity = difflib.SequenceMatcher(None, norm_variation, text_ngram).ratio()
            if similarity >= threshold:
                return True
    
    # If we've reached here, it's not a direct phenotype
    return False

def generate_medical_variations(term):
    """Generate common variations of medical terms."""
    variations = []
    
    # Handle plural forms
    if not term.endswith('s'):
        variations.append(term + 's')  # Simple plural
    
    if term.endswith('y'):
        variations.append(term[:-1] + 'ies')  # y -> ies
    
    if term.endswith('us'):
        variations.append(term[:-2] + 'i')  # us -> i
    
    # Handle hyphenated vs non-hyphenated
    if '-' in term:
        variations.append(term.replace('-', ' '))
    if ' ' in term:
        variations.append(term.replace(' ', '-'))
    
    # Handle common abbreviations
    # This would need to be expanded for specific medical domains
    abbreviations = {
        'cardiomyopathy': 'CM',
        'myocardial infarction': 'MI',
        'congestive heart failure': 'CHF',
        'diabetes mellitus': 'DM',
        'hypertension': 'HTN',
        'chronic obstructive pulmonary disease': 'COPD'
    }
    
    if term.lower() in abbreviations:
        variations.append(abbreviations[term.lower()])
    
    # Check if term is an abbreviation and add the full form
    inv_abbreviations = {v.lower(): k for k, v in abbreviations.items()}
    if term.upper() in [k.upper() for k in inv_abbreviations.keys()]:
        variations.append(inv_abbreviations[term.lower()])
    
    return variations

def find_best_context(phenotype, clinical_text, window_size=150):
    """
    Find the most likely context where the implied phenotype might reside,
    using fuzzy matching and a sliding window approach.
    
    Args:
        phenotype: The implied phenotype to match
        clinical_text: The complete clinical note text
        window_size: Size of context window (characters)
        
    Returns:
        The best matching context snippet
    """
    norm_phenotype = normalize_text(phenotype)
    norm_text = normalize_text(clinical_text)
    
    # If exact match exists, extract the context around it
    if norm_phenotype in norm_text:
        # Find position of the match
        start_pos = norm_text.find(norm_phenotype)
        # Extract context window
        context_start = max(0, start_pos - window_size//2)
        context_end = min(len(norm_text), start_pos + len(norm_phenotype) + window_size//2)
        # Return the original text for this section
        window_start = len(normalize_text(clinical_text[:context_start]))
        window_end = len(normalize_text(clinical_text[:context_end]))
        return clinical_text[window_start:window_end].strip()
    
    # No exact match, try fuzzy matching with sliding windows
    best_score = 0
    best_context = ""
    
    # Create text windows with sentences to preserve logical units
    sentences = re.split(r'(?<=[.!?])\s+', clinical_text)
    current_window = ""
    
    for sentence in sentences:
        # If adding this sentence exceeds window size, evaluate current window
        if len(current_window + " " + sentence) > window_size * 2:
            norm_window = normalize_text(current_window)
            
            # Calculate similarity score
            words = norm_window.split()
            max_word_score = 0
            
            # Try matching with individual words and phrases
            for j in range(len(words)):
                for k in range(j+1, min(j+6, len(words)+1)):  # Try phrases up to 5 words
                    phrase = " ".join(words[j:k])
                    score = difflib.SequenceMatcher(None, norm_phenotype, phrase).ratio()
                    max_word_score = max(max_word_score, score)
            
            # Also try matching against the whole window
            window_score = difflib.SequenceMatcher(None, norm_phenotype, norm_window).ratio()
            score = max(max_word_score, window_score)
            
            if score > best_score:
                best_score = score
                best_context = current_window
            
            # Reset window, starting with the current sentence
            current_window = sentence
        else:
            # Add sentence to current window
            if current_window:
                current_window += " " + sentence
            else:
                current_window = sentence
    
    # Check the last window
    if current_window:
        norm_window = normalize_text(current_window)
        window_score = difflib.SequenceMatcher(None, norm_phenotype, norm_window).ratio()
        if window_score > best_score:
            best_score = window_score
            best_context = current_window
    
    # If we found a reasonably good match
    if best_score > 0.4:
        return best_context.strip()
    
    # Look for relevant medical terms in each sentence as a fallback
    # This helps find relevant context even with conceptual implications
    medical_terms = [
        "elevated", "increased", "decreased", "low", "high", "abnormal",
        "level", "count", "test", "positive", "negative", "detected",
        "diagnosis", "symptom", "sign", "finding", "disease", "disorder",
        "syndrome", "condition", "complication", "deficit", "dysfunction"
    ]
    
    best_term_count = 0
    best_term_context = ""
    
    for sentence in sentences:
        term_count = sum(1 for term in medical_terms if term in sentence.lower())
        if term_count > best_term_count:
            best_term_count = term_count
            best_term_context = sentence
    
    if best_term_context:
        return best_term_context.strip()
    
    # Ultimate fallback: Return first part of the text as context
    return clinical_text[:min(window_size, len(clinical_text))].strip()

def process_dataset(input_file_path, output_file_path):
    """
    Process dataset to create a new JSON structure with implied phenotypes, HPO codes and contexts.
    
    Args:
        input_file_path: Path to input JSON dataset
        output_file_path: Path to output JSON dataset
    """
    # Load the input dataset
    with open(input_file_path, 'r') as f:
        input_data = json.load(f)
    
    # Create output dataset
    output_data = {}
    
    # Process each document
    for doc_id, doc_data in input_data.items():
        # Get clinical text
        clinical_text = doc_data.get("clinical_text", "")
        if not clinical_text:
            continue
        
        # Get phenotypes
        phenotypes = doc_data.get("phenotypes", [])
        if not phenotypes and "verified_phenotypes" in doc_data:
            phenotypes = doc_data.get("verified_phenotypes", [])
        
        # Track all phenotypes for this document
        all_phenotypes = []
        implied_phenotypes_with_context = []
        direct_phenotypes = []
        
        for phenotype_item in phenotypes:
            # Handle different phenotype data structures
            if isinstance(phenotype_item, dict):
                # Extract phenotype name based on available fields
                phenotype_name = phenotype_item.get("phenotype", "")
                if not phenotype_name:
                    phenotype_name = phenotype_item.get("phenotype_name", "")
                
                # Extract HPO code
                hpo_code = phenotype_item.get("hp_id", "")
                if not hpo_code:
                    hpo_code = phenotype_item.get("HPO_Term", "")  # Alternative field name
                if not hpo_code:
                    hpo_code = phenotype_item.get("hpo_id", "")    # Another alternative
                if not hpo_code:
                    hpo_code = phenotype_item.get("hpo_term", "")  # Another alternative
                
                # Try to determine if it's implied using proper fuzzy matching
                status = phenotype_item.get("status", "")
                
                # If status is already specified, use it
                if status == "implied_phenotype":
                    is_implied = True
                elif status == "direct_phenotype":
                    is_implied = False
                else:
                    # Otherwise use our fuzzy matching logic
                    is_implied = not is_direct_phenotype(phenotype_name, clinical_text)
            
            elif isinstance(phenotype_item, str):
                phenotype_name = phenotype_item
                hpo_code = ""  # No HPO code available for string phenotypes
                # Use fuzzy matching to determine if implied
                is_implied = not is_direct_phenotype(phenotype_name, clinical_text)
            
            else:
                continue
            
            # Skip if no phenotype name
            if not phenotype_name:
                continue
                
            # Categorize as direct or implied
            phenotype_info = {
                "phenotype": phenotype_name,
                "hpo_code": hpo_code,
            }
            
            all_phenotypes.append(phenotype_info)
            
            if is_implied:
                # Find best context for implied phenotypes
                context = find_best_context(phenotype_name, clinical_text)
                
                # Add to implied phenotypes list
                implied_phenotypes_with_context.append({
                    "implied_phenotype": phenotype_name,
                    "hpo_code": hpo_code,
                    "context": context
                })
            else:
                # Track direct phenotypes
                direct_phenotypes.append(phenotype_info)
        
        # Add to output data
        output_data[doc_id] = {
            "text": clinical_text,
            "all_phenotypes": all_phenotypes,
            "direct_phenotypes": direct_phenotypes,
            "implied_phenotypes_with_context": implied_phenotypes_with_context
        }
    
    # Save the output dataset
    with open(output_file_path, 'w') as f:
        json.dump(output_data, f, indent=2)
    
    # Calculate statistics
    total_direct = sum(len(doc["direct_phenotypes"]) for doc in output_data.values())
    total_implied = sum(len(doc["implied_phenotypes_with_context"]) for doc in output_data.values())
    total_phenotypes = total_direct + total_implied
    
    print(f"Processed {len(input_data)} documents, saved {len(output_data)} to {output_file_path}")
    print(f"Found {total_phenotypes} total phenotypes")
    print(f"Direct phenotypes: {total_direct} ({total_direct/total_phenotypes*100:.1f}% of total)")
    print(f"Implied phenotypes: {total_implied} ({total_implied/total_phenotypes*100:.1f}% of total)")
    
    # Print statistics on HPO code availability
    with_hpo = 0
    without_hpo = 0
    for doc in output_data.values():
        for item in doc["implied_phenotypes_with_context"]:
            if item["hpo_code"]:
                with_hpo += 1
            else:
                without_hpo += 1
    
    if with_hpo + without_hpo > 0:
        print(f"Implied phenotypes with HPO codes: {with_hpo} ({with_hpo/(with_hpo+without_hpo)*100:.1f}%)")
        print(f"Implied phenotypes without HPO codes: {without_hpo} ({without_hpo/(with_hpo+without_hpo)*100:.1f}%)")
    
    return output_data

# Example usage for Jupyter notebook:
# Define your input and output paths

# Run the processing function
# output_data = process_dataset(input_file_path, output_file_path)


# Run the processing function
# output_data = process_dataset(input_file_path, output_file_path)

# Example usage for Jupyter notebook:
# Define your input and output paths

# hpo_data = read_json_file('data/dataset/mine_hpo.json')

input_file_path = 'data/dataset/mine_hpo.json'
output_file_path = "data/dataset/implied_phenotypes.json"

# Run the processing function
output_data = process_dataset(input_file_path, output_file_path)


Processed 116 documents, saved 116 to data/dataset/implied_phenotypes.json
Found 1813 total phenotypes
Direct phenotypes: 1320 (72.8% of total)
Implied phenotypes: 493 (27.2% of total)
Implied phenotypes with HPO codes: 493 (100.0%)
Implied phenotypes without HPO codes: 0 (0.0%)


In [10]:
from rdma.utils.data import read_json_file, print_json_structure

annotations_file = 'data/dataset/implied_phenotypes.json'
ground_truth_file = 'data/dataset/mine_hpo.json' 
annotations_file = read_json_file(annotations_file)
ground_truth_file = read_json_file(ground_truth_file)
print("annotations file")
print_json_structure(annotations_file)
print("-----------------")
print("ground truth file")
print_json_structure(ground_truth_file)


annotations file
Dictionary:
  1 (dict): 
  Dictionary:
    text (str): 
    all_phenotypes (list): 
    List:
      Item 0 (dict): 
      Dictionary:
        phenotype (str): 
        hpo_code (str): 
      Item 1 (dict): 
      Dictionary:
        phenotype (str): 
        hpo_code (str): 
      Item 2 (dict): 
      Dictionary:
        phenotype (str): 
        hpo_code (str): 
      Item 3 (dict): 
      Dictionary:
        phenotype (str): 
        hpo_code (str): 
      Item 4 (dict): 
      Dictionary:
        phenotype (str): 
        hpo_code (str): 
      Item 5 (dict): 
      Dictionary:
        phenotype (str): 
        hpo_code (str): 
      Item 6 (dict): 
      Dictionary:
        phenotype (str): 
        hpo_code (str): 
      Item 7 (dict): 
      Dictionary:
        phenotype (str): 
        hpo_code (str): 
    direct_phenotypes (list): 
    List:
      Item 0 (dict): 
      Dictionary:
        phenotype (str): 
        hpo_code (str): 
      Item 1 (dict): 
      D

In [19]:
def analyze_phenotype_distribution(annotations_data: Dict) -> Dict:
    """
    Analyze the distribution of direct vs implied phenotypes in the annotations.
    
    Args:
        annotations_data: Dictionary with annotation data
        
    Returns:
        Dictionary with counts and percentages for direct and implied phenotypes
    """
    # Initialize counters
    total_direct = 0
    total_implied = 0
    
    # Count direct and implied phenotypes across all documents
    for doc_id, doc_data in annotations_data.items():
        total_direct += len(doc_data.get("direct_phenotypes", []))
        total_implied += len(doc_data.get("implied_phenotypes_with_context", []))
    
    total_phenotypes = total_direct + total_implied
    
    # Calculate percentages
    direct_percentage = (total_direct / total_phenotypes * 100) if total_phenotypes > 0 else 0
    implied_percentage = (total_implied / total_phenotypes * 100) if total_phenotypes > 0 else 0
    
    # Print statistics to console for verification
    print(f"Total phenotypes: {total_phenotypes}")
    print(f"Direct phenotypes: {total_direct} ({direct_percentage:.2f}%)")
    print(f"Implied phenotypes: {total_implied} ({implied_percentage:.2f}%)")
    
    return {
        'total': total_phenotypes,
        'direct': {
            'count': total_direct,
            'percentage': direct_percentage
        },
        'implied': {
            'count': total_implied,
            'percentage': implied_percentage
        }
    }
def plot_error_distribution_individual(false_positive_counts, false_negative_counts, code_to_name, n=5, error_type='fp', use_plotly=True):
    """
    Create individual bar charts for false positives or false negatives.
    
    Args:
        false_positive_counts: Counter with false positive codes
        false_negative_counts: Counter with false negative codes
        code_to_name: Dictionary mapping HPO codes to phenotype names
        n: Number of top errors to display
        error_type: 'fp' for false positives, 'fn' for false negatives
        use_plotly: Whether to use Plotly (True) or Matplotlib (False)
        
    Returns:
        Figure object (either Plotly or Matplotlib)
    """
    is_fp = error_type == 'fp'
    error_counts = false_positive_counts if is_fp else false_negative_counts
    color = 'rgb(75,75,75)' if is_fp else 'rgb(141,160,203)'  # Dark Gray for FP, Grayish-blue for FN
    
    if use_plotly:
        import plotly.graph_objects as go
        
        # Create dataframe for the selected error type
        data = []
        for code, count in error_counts.most_common(n):
            name = code_to_name.get(code, "Unknown")
            data.append({
                'HPO Code': code,
                'Phenotype': name,
                'Label': f"{code} ({name})",
                'Count': count,
                'Error Type': 'False Positive' if is_fp else 'False Negative'
            })
        
        if data:
            import pandas as pd
            df = pd.DataFrame(data)
            
            # Create the figure
            fig = go.Figure()
            
            fig.add_trace(
                go.Bar(
                    y=df['Label'],
                    x=df['Count'],
                    orientation='h',
                    marker=dict(color=color),
                    name='False Positives' if is_fp else 'False Negatives'
                )
            )
            
            # Update layout - with NO TITLE
            fig.update_layout(
                height=500,
                width=800,
                showlegend=False,
                font=dict(family="Arial", size=12*SCALE),
                margin=dict(l=20, r=20, t=20, b=20),  # Reduced top margin since no title
            )
            
            # Update axes
            fig.update_xaxes(title_text="Count", title_font=dict(size=14*SCALE))
            fig.update_yaxes(autorange="reversed")  # Reverse y-axis
            
            return fig
        else:
            # Empty figure with a message
            fig = go.Figure()
            fig.add_annotation(
                text=f"No {'false positives' if is_fp else 'false negatives'} found",
                xref="paper", yref="paper",
                x=0.5, y=0.5,
                showarrow=False,
                font=dict(size=14*SCALE)
            )
            fig.update_layout(
                height=500,
                width=800,
                margin=dict(l=20, r=20, t=20, b=20)
            )
            return fig
    else:
        import matplotlib.pyplot as plt
        import numpy as np
        import textwrap
        
        # Create a wrapper function for text wrapping
        def wrap_labels(text, width=25):
            """Wrap text to fit within specified width."""
            return '\n'.join(textwrap.wrap(text, width=width))
        
        # Create figure
        fig, ax = plt.subplots(figsize=(10, 8))
        
        # Get top N errors
        top_errors = error_counts.most_common(n)
        if top_errors:
            codes, counts = zip(*top_errors)
            names = [f"{code} ({code_to_name.get(code, 'Unknown')})" for code in codes]
            names_wrapped = [wrap_labels(name) for name in names]
            
            # Create horizontal bar chart
            y_pos = np.arange(len(names_wrapped))
            ax.barh(y_pos, counts, align='center', color=color)
            ax.set_yticks(y_pos)
            ax.set_yticklabels(names_wrapped, fontsize=10*SCALE)
            ax.invert_yaxis()  # Labels read top-to-bottom
            ax.set_xlabel('Count', fontsize=12*SCALE)
            ax.tick_params(axis='both', which='major', labelsize=10*SCALE)
        else:
            ax.text(0.5, 0.5, f"No {'false positives' if is_fp else 'false negatives'} found", 
                   ha='center', va='center', fontsize=12*SCALE)
        
        # Adjust layout
        plt.tight_layout()
        
        return fig

def plot_error_type_pie_individual(error_counts, error_type='fp', use_plotly=True):
    """
    Create individual pie chart for false positive or false negative error types.
    
    Args:
        error_counts: Dictionary with error counts by type (direct/implied/unknown)
        error_type: 'fp' for false positives, 'fn' for false negatives
        use_plotly: Whether to use Plotly (True) or Matplotlib (False)
        
    Returns:
        Figure object (either Plotly or Matplotlib)
    """
    is_fp = error_type == 'fp'
    
    # Define color palette
    color_palette = [
        'rgb(252,141,98)',   # Orange (for Direct Phenotypes)
        'rgb(141,160,203)',  # Grayish-blue (for Implied Phenotypes)
        'rgb(150,150,150)'   # Gray (for Unknown Type)
    ]
    
    if use_plotly:
        import plotly.graph_objects as go
        
        # Create data for the pie chart
        data = []
        if error_counts['direct'] > 0:
            data.append({'Category': 'Direct Phenotypes', 'Count': error_counts['direct']})
        if error_counts['implied'] > 0:
            data.append({'Category': 'Implied Phenotypes', 'Count': error_counts['implied']})
        if error_counts['unknown'] > 0:
            data.append({'Category': 'Unknown Type', 'Count': error_counts['unknown']})
        
        # Create figure
        fig = go.Figure()
        
        # Add pie chart trace
        if data:
            import pandas as pd
            df = pd.DataFrame(data)
            fig.add_trace(
                go.Pie(
                    labels=df['Category'],
                    values=df['Count'],
                    textinfo='percent',
                    textfont=dict(size=14*SCALE, weight='bold'),
                    marker=dict(colors=color_palette[:len(df)]),
                    hovertemplate='%{label}<br>%{value} (%{percent})<extra></extra>',
                    hole=0
                )
            )
        else:
            fig.add_annotation(
                text=f"No {'false positives' if is_fp else 'false negatives'} found",
                xref="paper", yref="paper",
                x=0.5, y=0.5,
                showarrow=False,
                font=dict(size=14*SCALE)
            )
        
        # Update layout - NO TITLE
        fig.update_layout(
            height=500,
            width=500,
            font=dict(family="Arial", size=12*SCALE),
            legend=dict(
                orientation="h",          # Horizontal legend
                yanchor="bottom",         # Anchor point
                y=1.02,                   # Position above chart
                xanchor="center",         # Center horizontally
                x=0.5,                    # Center position
                font=dict(
                    size=14*SCALE,        # Larger font size
                    family="Arial",
                    weight="bold"         # Bold text
                )
            ),
            margin=dict(t=30, b=30, l=30, r=30)  # Reduced top margin since no title
        )
        
        return fig
    else:
        import matplotlib.pyplot as plt
        
        # Prepare data
        labels = []
        sizes = []
        
        if error_counts['direct'] > 0:
            labels.append('Direct Phenotypes')
            sizes.append(error_counts['direct'])
            
        if error_counts['implied'] > 0:
            labels.append('Implied Phenotypes')
            sizes.append(error_counts['implied'])
            
        if error_counts['unknown'] > 0:
            labels.append('Unknown Type')
            sizes.append(error_counts['unknown'])
        
        # Create figure
        fig, ax = plt.subplots(figsize=(8, 8))
        
        # Plot pie chart
        if sizes:
            ax.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=90, 
                  colors=['rgb(252,141,98)', 'rgb(141,160,203)', 'rgb(150,150,150)'],
                  textprops={'fontsize': 12*SCALE, 'fontweight': 'bold'})
            ax.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle
        else:
            ax.text(0.5, 0.5, f"No {'false positives' if is_fp else 'false negatives'} found", 
                   ha='center', va='center', fontsize=12*SCALE)
        
        plt.tight_layout()
        return fig

def plot_phenotype_distribution_pie(phenotype_dist: Dict, use_plotly=True):
    """
    Create a pie chart for the distribution of direct vs implied phenotypes.
    
    Args:
        phenotype_dist: Dictionary with distribution data
        use_plotly: Whether to use Plotly (True) or Matplotlib (False)
        
    Returns:
        Figure object (either Plotly or Matplotlib)
    """
    # Set colors
    color_palette = [
        'rgb(252,141,98)',   # Orange (for Direct Phenotypes)
        'rgb(141,160,203)'   # Grayish-blue (for Implied Phenotypes)
    ]
    
    if use_plotly:
        import plotly.express as px
        import pandas as pd
        
        # Create DataFrame for the chart
        data = {
            'Category': ['Direct Phenotypes', 'Implied Phenotypes'],
            'Count': [
                phenotype_dist['direct']['count'], 
                phenotype_dist['implied']['count']
            ],
            'Percentage': [
                phenotype_dist['direct']['percentage'], 
                phenotype_dist['implied']['percentage']
            ]
        }
        
        # Create a DataFrame
        df = pd.DataFrame(data)
        
        # Create the pie chart - NO TITLE
        fig = px.pie(
            df, 
            values='Percentage', 
            names='Category',
            color_discrete_sequence=color_palette,
            title=None,  # No title
            hole=0,  # Regular pie chart (no hole)
        )
        
        # Add only percentages inside the pie
        fig.update_traces(
            textposition='inside', 
            textinfo='percent',  # Only percentage, no label
            textfont=dict(size=24*SCALE, weight="bold"),  # Scaled size and bold
            hovertemplate='%{label}<br>%{value:.1f}%<br>Count: %{customdata[0]}<extra></extra>',
            customdata=df[['Count']].values
        )
        
        # Update layout
        fig.update_layout(
            width=500*SCALE,
            height=500*SCALE,
            legend=dict(
                orientation="h",          # Horizontal legend
                yanchor="bottom",         # Anchor point
                y=1.02,                   # Position above chart
                xanchor="center",         # Center horizontally
                x=0.5,                    # Center position
                font=dict(
                    size=24*SCALE,        # Larger font size
                    family="Arial",
                    color="black",
                    weight="bold"         # Bold text
                )
            ),
            margin=dict(t=30*SCALE, b=20*SCALE, l=20*SCALE, r=20*SCALE),  # Reduced top margin
        )
        
        return fig
    else:
        import matplotlib.pyplot as plt
        
        # Create a figure
        fig, ax = plt.subplots(figsize=(10, 10))
        
        # Data
        labels = ['Direct Phenotypes', 'Implied Phenotypes']
        sizes = [
            phenotype_dist['direct']['percentage'],
            phenotype_dist['implied']['percentage']
        ]
        
        # Plot pie chart
        wedges, texts, autotexts = ax.pie(
            sizes, 
            labels=labels, 
            autopct='%1.1f%%',
            colors=color_palette,
            textprops={'fontsize': 14*SCALE, 'fontweight': 'bold'}
        )
        
        # Increase font size for labels and percentages
        plt.setp(autotexts, size=16*SCALE, weight='bold')
        plt.setp(texts, size=14*SCALE, weight='bold')
        
        # Equal aspect ratio ensures pie is circular
        ax.axis('equal')
        
        return fig

def run_improved_analysis():
    """Run the improved HPO error analysis with focus on error types."""
    # Define file paths
    predictions_file = 'data/results/agents/hpo/step3/verified_lab_test_matched.json'
    ground_truth_file = 'data/dataset/mine_hpo.json'
    annotations_file = 'data/dataset/implied_phenotypes.json'
    
    # Create figures directory if it doesn't exist
    os.makedirs('data/figures', exist_ok=True)
    
    # Whether to use Plotly for visualizations
    use_plotly = True
    # Scale for font and visualization sizes
    global SCALE
    SCALE = 1.5
    
    # Load data
    print(f"Loading data from files...")
    predictions_data, ground_truth_data, annotations_data = load_data(
        predictions_file, ground_truth_file, annotations_file
    )
    
    # Extract predictions and ground truth
    print(f"Extracting HPO codes from predictions and ground truth...")
    predictions_dict = extract_predictions(predictions_data)
    ground_truth_dict = extract_ground_truth(ground_truth_data)
    
    # Extract phenotype types from annotations and prediction data
    print(f"Extracting phenotype types from annotations...")
    annotation_phenotype_types = extract_phenotype_types_from_annotations(annotations_data)
    prediction_phenotype_types = extract_prediction_types(predictions_data)
    
    # Combine type information from both sources, prioritizing annotations
    phenotype_types = {**prediction_phenotype_types, **annotation_phenotype_types}
    
    # Extract phenotype names
    print(f"Extracting phenotype names...")
    pred_code_to_name = extract_phenotype_names(predictions_data, is_ground_truth=False)
    gt_code_to_name = extract_phenotype_names(ground_truth_data, is_ground_truth=True)
    annotation_code_to_name = extract_additional_names_from_annotations(annotations_data)
    
    # Combine name mappings, prioritize annotations over predictions and ground truth
    code_to_name = {**gt_code_to_name, **pred_code_to_name, **annotation_code_to_name}
    
    # Find errors
    print(f"Finding false positives and false negatives...")
    false_positive_counts, false_negative_counts, case_errors = find_errors(predictions_dict, ground_truth_dict)
    
    # Analyze errors with improved approach
    print(f"Analyzing errors by type...")
    fp_error_counts, fn_error_counts = analyze_errors_with_text_matching(
        false_positive_counts, 
        false_negative_counts, 
        ground_truth_data,
        code_to_name,
        phenotype_types
    )
    
    # 1. Create and save individual false positive bar chart
    print(f"Creating false positive bar chart...")
    if use_plotly:
        fp_bar_fig = plot_error_distribution_individual(
            false_positive_counts, false_negative_counts, code_to_name, n=5, error_type='fp', use_plotly=True
        )
        fp_bar_fig.write_image('data/figures/top5_false_positives_bar.png', scale=SCALE)
        # Also save as HTML for interactive viewing
        fp_bar_fig.write_html('data/figures/top5_false_positives_bar.html')
        print("Top 5 false positives bar chart saved to 'data/figures/top5_false_positives_bar.png' and HTML")
    else:
        fp_bar_fig = plot_error_distribution_individual(
            false_positive_counts, false_negative_counts, code_to_name, n=5, error_type='fp', use_plotly=False
        )
        fp_bar_fig.savefig('data/figures/top5_false_positives_bar.png', dpi=300*SCALE, bbox_inches='tight')
        plt.close(fp_bar_fig)
        print("Top 5 false positives bar chart saved to 'data/figures/top5_false_positives_bar.png'")
    
    # 2. Create and save individual false negative bar chart
    print(f"Creating false negative bar chart...")
    if use_plotly:
        fn_bar_fig = plot_error_distribution_individual(
            false_positive_counts, false_negative_counts, code_to_name, n=5, error_type='fn', use_plotly=True
        )
        fn_bar_fig.write_image('data/figures/top5_false_negatives_bar.png', scale=SCALE)
        # Also save as HTML for interactive viewing
        fn_bar_fig.write_html('data/figures/top5_false_negatives_bar.html')
        print("Top 5 false negatives bar chart saved to 'data/figures/top5_false_negatives_bar.png' and HTML")
    else:
        fn_bar_fig = plot_error_distribution_individual(
            false_positive_counts, false_negative_counts, code_to_name, n=5, error_type='fn', use_plotly=False
        )
        fn_bar_fig.savefig('data/figures/top5_false_negatives_bar.png', dpi=300*SCALE, bbox_inches='tight')
        plt.close(fn_bar_fig)
        print("Top 5 false negatives bar chart saved to 'data/figures/top5_false_negatives_bar.png'")
    
    # 3. Create and save individual false positive pie chart
    print(f"Creating false positive pie chart...")
    if use_plotly:
        fp_pie_fig = plot_error_type_pie_individual(fp_error_counts, error_type='fp', use_plotly=True)
        fp_pie_fig.write_image('data/figures/false_positives_pie.png', scale=SCALE)
        # Also save as HTML for interactive viewing
        fp_pie_fig.write_html('data/figures/false_positives_pie.html')
        print("False positive error type pie chart saved to 'data/figures/false_positives_pie.png' and HTML")
    else:
        fp_pie_fig = plot_error_type_pie_individual(fp_error_counts, error_type='fp', use_plotly=False)
        fp_pie_fig.savefig('data/figures/false_positives_pie.png', dpi=300*SCALE, bbox_inches='tight')
        plt.close(fp_pie_fig)
        print("False positive error type pie chart saved to 'data/figures/false_positives_pie.png'")
    
    # 4. Create and save individual false negative pie chart
    print(f"Creating false negative pie chart...")
    if use_plotly:
        fn_pie_fig = plot_error_type_pie_individual(fn_error_counts, error_type='fn', use_plotly=True)
        fn_pie_fig.write_image('data/figures/false_negatives_pie.png', scale=SCALE)
        # Also save as HTML for interactive viewing
        fn_pie_fig.write_html('data/figures/false_negatives_pie.html')
        print("False negative error type pie chart saved to 'data/figures/false_negatives_pie.png' and HTML")
    else:
        fn_pie_fig = plot_error_type_pie_individual(fn_error_counts, error_type='fn', use_plotly=False)
        fn_pie_fig.savefig('data/figures/false_negatives_pie.png', dpi=300*SCALE, bbox_inches='tight')
        plt.close(fn_pie_fig)
        print("False negative error type pie chart saved to 'data/figures/false_negatives_pie.png'")
    
    # 5. Create and save distribution pie chart for all phenotypes from annotations
    print("\nCreating distribution of direct vs implied phenotypes from annotations...")
    if annotations_data:
        phenotype_dist = analyze_phenotype_distribution(annotations_data)
        if use_plotly:
            dist_fig = plot_phenotype_distribution_pie(phenotype_dist, use_plotly=True)
            dist_fig.write_image('data/figures/phenotype_presence_chart.png', scale=SCALE)
            dist_fig.write_html('data/figures/phenotype_presence_chart.html')
            print("Phenotype distribution chart saved to 'data/figures/phenotype_presence_chart.png' and HTML")
    
    # Also export data for further analysis
    error_analysis_data = {
        "false_positive_counts": dict(false_positive_counts),
        "false_negative_counts": dict(false_negative_counts),
        "fp_error_types": fp_error_counts,
        "fn_error_types": fn_error_counts,
        "case_errors": case_errors
    }
    
    with open('data/figures/error_analysis_data.json', 'w') as f:
        json.dump(error_analysis_data, f, indent=2)
    
    print("\nError analysis data saved to 'data/figures/error_analysis_data.json'")
    print("\nAnalysis complete! Individual visualizations have been saved to the data/figures directory.")
    
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from typing import Dict, List, Tuple, Any
from fuzzywuzzy import fuzz
import os
import textwrap
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Set style for better visualization
plt.style.use('ggplot')
# Scale parameter for font sizes and figure quality
SCALE = 1.5  # Default scale factor

def normalize_text(text: str) -> str:
    """Normalize text by converting to lowercase and removing extra whitespace."""
    if not isinstance(text, str):
        return ""
    return ' '.join(text.lower().split())

def load_data(predictions_file: str, ground_truth_file: str, annotations_file: str = None):
    """Load prediction and ground truth data, and annotations if provided."""
    with open(predictions_file, 'r') as f:
        predictions = json.load(f)
    
    with open(ground_truth_file, 'r') as f:
        ground_truth = json.load(f)
    
    # Handle nested structure with metadata and results keys
    if isinstance(predictions, dict) and "results" in predictions:
        predictions_data = predictions.get("results", {})
    else:
        predictions_data = predictions
    
    # Load annotations if provided
    annotations_data = None
    if annotations_file and os.path.exists(annotations_file):
        try:
            with open(annotations_file, 'r') as f:
                annotations_data = json.load(f)
            print(f"Successfully loaded annotations from {annotations_file}")
        except Exception as e:
            print(f"Error loading annotations file: {e}")
    
    return predictions_data, ground_truth, annotations_data

def extract_predictions(predictions_data: Dict) -> Dict[str, List[str]]:
    """Extract HPO codes from the predictions data structure."""
    result = {}
    
    # Handle different possible formats
    if isinstance(predictions_data, dict):
        for case_id, case_data in predictions_data.items():
            hpo_codes = []
            
            # Check for different possible field names for phenotype lists
            phenotype_fields = ["matched_phenotypes", "verified_phenotypes"]
            for field in phenotype_fields:
                if field in case_data:
                    phenotype_list = case_data[field]
                    if isinstance(phenotype_list, list):
                        for item in phenotype_list:
                            if not isinstance(item, dict):
                                continue
                                
                            # Check for different possible field names for HPO codes
                            for code_field in ["HPO_Term", "hpo_term", "hpo_id", "hp_id"]:
                                if code_field in item:
                                    code = item[code_field]
                                    # Ensure code starts with HP:
                                    if code and isinstance(code, str) and code.startswith("HP:"):
                                        hpo_codes.append(code)
                                        break
            
            # Add non-empty lists to result
            if hpo_codes:
                result[str(case_id)] = hpo_codes
    
    return result

def extract_prediction_types(predictions_data: Dict) -> Dict[str, str]:
    """Extract the phenotype types (direct/implied) from predictions data."""
    code_to_type = {}
    
    if isinstance(predictions_data, dict):
        for case_id, case_data in predictions_data.items():
            for field in ["matched_phenotypes", "verified_phenotypes"]:
                if field in case_data and isinstance(case_data[field], list):
                    for item in case_data[field]:
                        if not isinstance(item, dict):
                            continue
                        
                        # Get the HPO code
                        hpo_code = None
                        for code_field in ["HPO_Term", "hpo_term", "hpo_id", "hp_id"]:
                            if code_field in item and item[code_field]:
                                hpo_code = item[code_field]
                                break
                        
                        # Get the phenotype type
                        if hpo_code and "status" in item:
                            code_to_type[hpo_code] = item["status"]
    
    return code_to_type

def extract_ground_truth(ground_truth_data: Dict) -> Dict[str, List[str]]:
    """Extract ground truth HPO codes from the ground truth data structure."""
    result = {}
    
    # Handle different possible formats
    if isinstance(ground_truth_data, dict):
        for case_id, case_data in ground_truth_data.items():
            hpo_codes = []
            
            # Try different possible field names
            for field in ["phenotypes", "ground_truth", "hpo_terms"]:
                if field in case_data and isinstance(case_data[field], list):
                    for item in case_data[field]:
                        # Check for different possible field names in nested dictionaries
                        if isinstance(item, dict):
                            for code_field in ["hpo_id", "HPO_Term", "hpo_code", "hp_id"]:
                                if code_field in item:
                                    code = item[code_field]
                                    # Ensure code starts with HP:
                                    if code and isinstance(code, str) and code.startswith("HP:"):
                                        hpo_codes.append(code)
                                        break
                        elif isinstance(item, str) and item.startswith("HP:"):
                            # Direct HPO code string
                            hpo_codes.append(item)
            
            # Add non-empty lists to result
            if hpo_codes:
                result[str(case_id)] = hpo_codes
    
    return result

def extract_phenotype_types_from_annotations(annotations_data: Dict) -> Dict[str, str]:
    """
    Extract the phenotype types (direct/implied) from annotations file.
    
    Returns:
        Dictionary mapping HPO codes to their type ('direct_phenotype' or 'implied_phenotype')
    """
    code_to_type = {}
    
    if not annotations_data:
        return code_to_type
    
    for case_id, case_data in annotations_data.items():
        # Process direct phenotypes
        if "direct_phenotypes" in case_data and isinstance(case_data["direct_phenotypes"], list):
            for item in case_data["direct_phenotypes"]:
                if isinstance(item, dict) and "hpo_code" in item:
                    code_to_type[item["hpo_code"]] = "direct_phenotype"
        
        # Process implied phenotypes
        if "implied_phenotypes_with_context" in case_data and isinstance(case_data["implied_phenotypes_with_context"], list):
            for item in case_data["implied_phenotypes_with_context"]:
                if isinstance(item, dict) and "hpo_code" in item:
                    code_to_type[item["hpo_code"]] = "implied_phenotype"
                    
        # Process all_phenotypes as a fallback if no type is specified
        if "all_phenotypes" in case_data and isinstance(case_data["all_phenotypes"], list):
            for item in case_data["all_phenotypes"]:
                if isinstance(item, dict) and "hpo_code" in item:
                    # Only add if not already categorized
                    if item["hpo_code"] not in code_to_type:
                        code_to_type[item["hpo_code"]] = "unknown"
    
    print(f"Extracted {len(code_to_type)} phenotype types from annotations")
    return code_to_type

def extract_phenotype_names(data: Dict, is_ground_truth: bool = False) -> Dict[str, str]:
    """Extract HPO code to phenotype name mappings."""
    code_to_name = {}
    
    if isinstance(data, dict):
        for case_id, case_data in data.items():
            # Determine which fields to check based on whether it's ground truth or predictions
            if is_ground_truth:
                fields_to_check = ["phenotypes", "ground_truth", "hpo_terms"]
            else:
                fields_to_check = ["matched_phenotypes", "verified_phenotypes"]
                
            for field in fields_to_check:
                if field in case_data:
                    phenotype_list = case_data[field]
                    if isinstance(phenotype_list, list):
                        for item in phenotype_list:
                            if not isinstance(item, dict):
                                continue
                            
                            # Get phenotype name and code
                            phenotype_name = None
                            hpo_code = None
                            
                            # Check various possible field names for the phenotype name
                            for name_field in ["phenotype", "entity", "name", "phenotype_name", "text"]:
                                if name_field in item and item[name_field]:
                                    phenotype_name = item[name_field]
                                    break
                            
                            # Check various possible field names for the HPO code
                            for code_field in ["HPO_Term", "hpo_term", "hpo_id", "hp_id", "HPO_ID"]:
                                if code_field in item and item[code_field]:
                                    hpo_code = item[code_field]
                                    break
                            
                            if hpo_code and isinstance(hpo_code, str) and hpo_code.startswith("HP:") and phenotype_name:
                                code_to_name[hpo_code] = phenotype_name
    
    return code_to_name

def extract_additional_names_from_annotations(annotations_data: Dict) -> Dict[str, str]:
    """Extract additional phenotype names from annotations."""
    code_to_name = {}
    
    if not annotations_data:
        return code_to_name
    
    for case_id, case_data in annotations_data.items():
        # Process all phenotypes sections for names
        for section in ["all_phenotypes", "direct_phenotypes"]:
            if section in case_data and isinstance(case_data[section], list):
                for item in case_data[section]:
                    if isinstance(item, dict) and "phenotype" in item and "hpo_code" in item:
                        code_to_name[item["hpo_code"]] = item["phenotype"]
        
        # Process implied phenotypes
        if "implied_phenotypes_with_context" in case_data and isinstance(case_data["implied_phenotypes_with_context"], list):
            for item in case_data["implied_phenotypes_with_context"]:
                if isinstance(item, dict) and "implied_phenotype" in item and "hpo_code" in item:
                    code_to_name[item["hpo_code"]] = item["implied_phenotype"]
    
    return code_to_name

def find_errors(predictions_dict: Dict[str, List[str]], ground_truth_dict: Dict[str, List[str]]):
    """Find false positives and false negatives across all cases."""
    all_false_positives = []
    all_false_negatives = []
    case_errors = {}
    
    # Process each case
    for case_id in sorted(set(list(predictions_dict.keys()) + list(ground_truth_dict.keys()))):
        # Get HPO codes for this case
        predictions = set(predictions_dict.get(case_id, []))
        ground_truth = set(ground_truth_dict.get(case_id, []))
        
        # Find false positives (codes in predictions but not in ground truth)
        false_positives = predictions - ground_truth
        
        # Find false negatives (codes in ground truth but not in predictions)
        false_negatives = ground_truth - predictions
        
        # Store case-specific errors
        if false_positives or false_negatives:
            case_errors[case_id] = {
                "false_positives": list(false_positives),
                "false_negatives": list(false_negatives)
            }
        
        # Add to overall counts
        all_false_positives.extend(false_positives)
        all_false_negatives.extend(false_negatives)
    
    # Count frequency of each error
    false_positive_counts = Counter(all_false_positives)
    false_negative_counts = Counter(all_false_negatives)
    
    return false_positive_counts, false_negative_counts, case_errors

def check_if_phenotype_in_text(hpo_code: str, clinical_text: str, code_to_name: Dict[str, str]) -> bool:
    """
    Check if the phenotype name appears directly in the clinical text.
    This helps determine if it's a direct or implied phenotype.
    """
    if not clinical_text or not hpo_code in code_to_name:
        return False
    
    phenotype_name = code_to_name[hpo_code]
    clinical_text_lower = clinical_text.lower()
    phenotype_name_lower = phenotype_name.lower()
    
    # Check if the exact phenotype name appears in the text
    if phenotype_name_lower in clinical_text_lower:
        return True
    
    # Check for partial matches (more than 3 words in phenotype name)
    words = phenotype_name_lower.split()
    if len(words) > 3:
        # Check if any 3-word combinations appear in the text
        for i in range(len(words) - 2):
            three_word_phrase = " ".join(words[i:i+3])
            if three_word_phrase in clinical_text_lower and len(three_word_phrase) > 10:
                return True
    
    return False

def analyze_errors_with_text_matching(
    false_positive_counts: Counter, 
    false_negative_counts: Counter, 
    ground_truth_data: Dict,
    code_to_name: Dict[str, str], 
    annotation_types: Dict[str, str] = None
) -> Tuple[Dict[str, int], Dict[str, int]]:
    """
    Analyze errors by checking if phenotype names appear in clinical text.
    
    Args:
        false_positive_counts: Counter of false positive codes
        false_negative_counts: Counter of false negative codes
        ground_truth_data: Dictionary with clinical text for each case
        code_to_name: Dictionary mapping HPO codes to phenotype names
        annotation_types: Dictionary mapping HPO codes to phenotype types (from annotations)
        
    Returns:
        Tuple of (fp_error_counts, fn_error_counts) dictionaries
    """
    fp_error_counts = {'direct': 0, 'implied': 0, 'unknown': 0}
    fn_error_counts = {'direct': 0, 'implied': 0, 'unknown': 0}
    
    # Create a map of all clinical texts
    case_to_text = {}
    for case_id, case_data in ground_truth_data.items():
        if "clinical_text" in case_data:
            case_to_text[case_id] = case_data["clinical_text"]
    
    # Analyze false negatives first
    for hpo_code, count in false_negative_counts.items():
        # Check if we have type information from annotations
        if annotation_types and hpo_code in annotation_types:
            # Use annotation type information
            phenotype_type = annotation_types[hpo_code]
            if phenotype_type == "direct_phenotype":
                fn_error_counts['direct'] += count
            elif phenotype_type == "implied_phenotype":
                fn_error_counts['implied'] += count
            else:
                fn_error_counts['unknown'] += count
        else:
            # If no annotation type, try text-based classification
            found_in_text = False
            for case_id, text in case_to_text.items():
                if check_if_phenotype_in_text(hpo_code, text, code_to_name):
                    found_in_text = True
                    break
            
            if found_in_text:
                fn_error_counts['direct'] += count
            else:
                # If not found in text, it's likely implied
                fn_error_counts['implied'] += count
    
    # Analyze false positives similarly
    for hpo_code, count in false_positive_counts.items():
        # Check if we have type information from annotations
        if annotation_types and hpo_code in annotation_types:
            # Use annotation type information
            phenotype_type = annotation_types[hpo_code]
            if phenotype_type == "direct_phenotype":
                fp_error_counts['direct'] += count
            elif phenotype_type == "implied_phenotype":
                fp_error_counts['implied'] += count
            else:
                fp_error_counts['unknown'] += count
        else:
            # If no annotation type, try text-based classification
            found_in_text = False
            for case_id, text in case_to_text.items():
                if check_if_phenotype_in_text(hpo_code, text, code_to_name):
                    found_in_text = True
                    break
            
            if found_in_text:
                fp_error_counts['direct'] += count
            else:
                # If not found in text, it's likely implied
                fp_error_counts['implied'] += count
    
    return fp_error_counts, fn_error_counts

def plot_error_distribution(false_positive_counts, false_negative_counts, code_to_name, n=5, use_plotly=True):
    """
    Plot the distribution of the most common false positives and false negatives with wrapped labels.
    
    Args:
        false_positive_counts: Counter with false positive codes
        false_negative_counts: Counter with false negative codes
        code_to_name: Dictionary mapping HPO codes to phenotype names
        n: Number of top errors to display
        use_plotly: Whether to use Plotly (True) or Matplotlib (False)
        
    Returns:
        Figure object (either Plotly or Matplotlib)
    """
    if use_plotly:
        # Create two separate dataframes for false positives and negatives
        fp_data = []
        for code, count in false_positive_counts.most_common(n):
            name = code_to_name.get(code, "Unknown")
            fp_data.append({
                'HPO Code': code,
                'Phenotype': name,
                'Label': f"{code} ({name})",
                'Count': count,
                'Error Type': 'False Positive'
            })
            
        fn_data = []
        for code, count in false_negative_counts.most_common(n):
            name = code_to_name.get(code, "Unknown")
            fn_data.append({
                'HPO Code': code,
                'Phenotype': name,
                'Label': f"{code} ({name})",
                'Count': count,
                'Error Type': 'False Negative'
            })
        
        # Create subplots: 1 row, 2 columns
        fig = make_subplots(rows=1, cols=2, 
                            subplot_titles=('Top 5 False Positives (Incorrect Predictions)', 
                                          'Top 5 False Negatives (Missed Phenotypes)'),
                            horizontal_spacing=0.15)
        
        # Add trace for false positives
        if fp_data:
            fp_df = pd.DataFrame(fp_data)
            fig.add_trace(
                go.Bar(
                    y=fp_df['Label'],
                    x=fp_df['Count'],
                    orientation='h',
                    marker=dict(color='rgb(75,75,75)'),  # Dark Gray
                    name='False Positives'
                ),
                row=1, col=1
            )
        else:
            fig.add_annotation(
                text="No false positives found",
                xref="x1", yref="y1",
                x=0.5, y=0.5,
                showarrow=False,
                font=dict(size=14*SCALE),
                row=1, col=1
            )
            
        # Add trace for false negatives
        if fn_data:
            fn_df = pd.DataFrame(fn_data)
            fig.add_trace(
                go.Bar(
                    y=fn_df['Label'],
                    x=fn_df['Count'],
                    orientation='h',
                    marker=dict(color='rgb(141,160,203)'),  # Grayish-blue
                    name='False Negatives'
                ),
                row=1, col=2
            )
        else:
            fig.add_annotation(
                text="No false negatives found",
                xref="x2", yref="y2",
                x=0.5, y=0.5,
                showarrow=False,
                font=dict(size=14*SCALE),
                row=1, col=2
            )
        
        # Update layout
        fig.update_layout(
            height=500,
            width=1000,
            showlegend=False,
            title_text="HPO Code Error Analysis",
            title_font=dict(size=16*SCALE),
            font=dict(family="Arial", size=12*SCALE),
            margin=dict(l=20, r=20, t=50, b=20),
        )
        
        # Update axes
        fig.update_xaxes(title_text="Count", title_font=dict(size=14*SCALE))
        fig.update_yaxes(autorange="reversed")  # Reverse y-axis to match matplotlib
        
        return fig
        
    else:
        # Create a figure with two subplots using matplotlib
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
        
        # Create a wrapper function for text wrapping
        def wrap_labels(text, width=25):
            """Wrap text to fit within specified width."""
            return '\n'.join(textwrap.wrap(text, width=width))
        
        # Plot most common false positives - REDUCED TO TOP 5
        top_fps = false_positive_counts.most_common(n)
        if top_fps:
            fp_codes, fp_counts = zip(*top_fps)
            fp_names = [f"{code} ({code_to_name.get(code, 'Unknown')})" for code in fp_codes]
            fp_names_wrapped = [wrap_labels(name) for name in fp_names]
            
            # Create horizontal bar chart
            y_pos = np.arange(len(fp_names_wrapped))
            ax1.barh(y_pos, fp_counts, align='center', color='rgb(75,75,75)')  # Dark Gray
            ax1.set_yticks(y_pos)
            ax1.set_yticklabels(fp_names_wrapped, fontsize=10*SCALE)
            ax1.invert_yaxis()  # Labels read top-to-bottom
            ax1.set_xlabel('Count', fontsize=12*SCALE)
            ax1.set_title('Top 5 False Positives (Incorrect Predictions)', fontsize=14*SCALE)
            ax1.tick_params(axis='both', which='major', labelsize=10*SCALE)
        else:
            ax1.text(0.5, 0.5, 'No false positives found', ha='center', va='center', fontsize=12*SCALE)
        
        # Plot most common false negatives - REDUCED TO TOP 5
        top_fns = false_negative_counts.most_common(n)
        if top_fns:
            fn_codes, fn_counts = zip(*top_fns)
            fn_names = [f"{code} ({code_to_name.get(code, 'Unknown')})" for code in fn_codes]
            fn_names_wrapped = [wrap_labels(name) for name in fn_names]
            
            # Create horizontal bar chart
            y_pos = np.arange(len(fn_names_wrapped))
            ax2.barh(y_pos, fn_counts, align='center', color='rgb(141,160,203)')  # Grayish-blue
            ax2.set_yticks(y_pos)
            ax2.set_yticklabels(fn_names_wrapped, fontsize=10*SCALE)
            ax2.invert_yaxis()  # Labels read top-to-bottom
            ax2.set_xlabel('Count', fontsize=12*SCALE)
            ax2.set_title('Top 5 False Negatives (Missed Phenotypes)', fontsize=14*SCALE)
            ax2.tick_params(axis='both', which='major', labelsize=10*SCALE)
        else:
            ax2.text(0.5, 0.5, 'No false negatives found', ha='center', va='center', fontsize=12*SCALE)
        
        # Adjust layout to make room for wrapped labels
        plt.tight_layout()
        plt.subplots_adjust(left=0.2, right=0.95)
        
        return fig

def plot_error_type_pie(fp_error_counts, fn_error_counts, use_plotly=True):
    """
    Create pie charts showing the distribution of direct vs implied phenotype errors.
    
    Args:
        fp_error_counts: Dictionary with false positive counts by type
        fn_error_counts: Dictionary with false negative counts by type
        use_plotly: Whether to use Plotly (True) or Matplotlib (False)
        
    Returns:
        Figure object (either Plotly or Matplotlib)
    """
    # Define color palette to match the requested colors
    color_palette = [
        'rgb(252,141,98)',   # Orange (for Direct Phenotypes)
        'rgb(141,160,203)',  # Grayish-blue (for Implied Phenotypes)
        'rgb(150,150,150)'   # Gray (for Unknown Type)
    ]
    
    if use_plotly:
        # Create data for both pie charts
        fp_data = []
        if fp_error_counts['direct'] > 0:
            fp_data.append({'Category': 'Direct Phenotypes', 'Count': fp_error_counts['direct']})
        if fp_error_counts['implied'] > 0:
            fp_data.append({'Category': 'Implied Phenotypes', 'Count': fp_error_counts['implied']})
        if fp_error_counts['unknown'] > 0:
            fp_data.append({'Category': 'Unknown Type', 'Count': fp_error_counts['unknown']})
            
        fn_data = []
        if fn_error_counts['direct'] > 0:
            fn_data.append({'Category': 'Direct Phenotypes', 'Count': fn_error_counts['direct']})
        if fn_error_counts['implied'] > 0:
            fn_data.append({'Category': 'Implied Phenotypes', 'Count': fn_error_counts['implied']})
        if fn_error_counts['unknown'] > 0:
            fn_data.append({'Category': 'Unknown Type', 'Count': fn_error_counts['unknown']})
        
        # Create subplots
        fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'domain'}, {'type': 'domain'}]],
                           subplot_titles=('False Positives by Phenotype Type', 'False Negatives by Phenotype Type'))
        
        # Add traces
        if fp_data:
            fp_df = pd.DataFrame(fp_data)
            fig.add_trace(
                go.Pie(
                    labels=fp_df['Category'],
                    values=fp_df['Count'],
                    textinfo='percent',
                    textfont=dict(size=14*SCALE, weight='bold'),
                    marker=dict(colors=color_palette[:len(fp_df)]),
                    hovertemplate='%{label}<br>%{value} (%{percent})<extra></extra>',
                    hole=0
                ),
                row=1, col=1
            )
        else:
            fig.add_annotation(
                text="No false positives found",
                xref="x1", yref="y1",
                x=0.5, y=0.5,
                showarrow=False,
                font=dict(size=14*SCALE),
                row=1, col=1
            )
            
        if fn_data:
            fn_df = pd.DataFrame(fn_data)
            fig.add_trace(
                go.Pie(
                    labels=fn_df['Category'],
                    values=fn_df['Count'],
                    textinfo='percent',
                    textfont=dict(size=14*SCALE, weight='bold'),
                    marker=dict(colors=color_palette[:len(fn_df)]),
                    hovertemplate='%{label}<br>%{value} (%{percent})<extra></extra>',
                    hole=0
                ),
                row=1, col=2
            )
        else:
            fig.add_annotation(
                text="No false negatives found",
                xref="x2", yref="y2",
                x=0.5, y=0.5,
                showarrow=False,
                font=dict(size=14*SCALE),
                row=1, col=2
            )
        
        # Update layout
        fig.update_layout(
            title_text="Error Analysis by Phenotype Type",
            height=500,
            width=1000,
            font=dict(family="Arial", size=12*SCALE),
            legend=dict(
                orientation="h",          # Horizontal legend
                yanchor="bottom",         # Anchor point
                y=1.15,                   # Position above chart
                xanchor="center",         # Center horizontally
                x=0.5,                    # Center position
                font=dict(
                    size=14*SCALE,        # Larger font size
                    family="Arial",
                    weight="bold"         # Bold text
                )
            ),
            margin=dict(t=80, b=30, l=30, r=30)
        )
        
        return fig
        
    else:
        # Create figure with two subplots using matplotlib
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
        
        # Prepare data for false positives
        fp_labels = []
        fp_sizes = []
        
        if fp_error_counts['direct'] > 0:
            fp_labels.append('Direct Phenotypes')
            fp_sizes.append(fp_error_counts['direct'])
            
        if fp_error_counts['implied'] > 0:
            fp_labels.append('Implied Phenotypes')
            fp_sizes.append(fp_error_counts['implied'])
            
        if fp_error_counts['unknown'] > 0:
            fp_labels.append('Unknown Type')
            fp_sizes.append(fp_error_counts['unknown'])
        
        # Plot false positives pie chart
        if fp_sizes:
            ax1.pie(fp_sizes, labels=fp_labels, autopct='%1.1f%%', startangle=90, 
                    colors=['rgb(252,141,98)', 'rgb(141,160,203)', 'rgb(150,150,150)'],
                    textprops={'fontsize': 12*SCALE, 'fontweight': 'bold'})
            ax1.set_title('False Positives by Phenotype Type', fontsize=14*SCALE, fontweight='bold')
        else:
            ax1.text(0.5, 0.5, 'No false positives found', ha='center', va='center', fontsize=12*SCALE)
        
        # Prepare data for false negatives
        fn_labels = []
        fn_sizes = []
        
        if fn_error_counts['direct'] > 0:
            fn_labels.append('Direct Phenotypes')
            fn_sizes.append(fn_error_counts['direct'])
            
        if fn_error_counts['implied'] > 0:
            fn_labels.append('Implied Phenotypes')
            fn_sizes.append(fn_error_counts['implied'])
            
        if fn_error_counts['unknown'] > 0:
            fn_labels.append('Unknown Type')
            fn_sizes.append(fn_error_counts['unknown'])
        
        # Plot false negatives pie chart
        if fn_sizes:
            ax2.pie(fn_sizes, labels=fn_labels, autopct='%1.1f%%', startangle=90,
                    colors=['rgb(252,141,98)', 'rgb(141,160,203)', 'rgb(150,150,150)'],
                    textprops={'fontsize': 12*SCALE, 'fontweight': 'bold'})
            ax2.set_title('False Negatives by Phenotype Type', fontsize=14*SCALE, fontweight='bold')
        else:
            ax2.text(0.5, 0.5, 'No false negatives found', ha='center', va='center', fontsize=12*SCALE)
        
        plt.tight_layout()
        return fig

def run_improved_analysis():
    """Run the improved HPO error analysis with focus on error types."""
    # Define file paths
    predictions_file = 'data/results/agents/hpo/step3/verified_lab_test_matched.json'
    ground_truth_file = 'data/dataset/mine_hpo.json'
    annotations_file = 'data/dataset/implied_phenotypes.json'
    
    # Create figures directory if it doesn't exist
    os.makedirs('data/figures', exist_ok=True)
    
    # Whether to use Plotly for visualizations
    use_plotly = True
    # Scale for font and visualization sizes
    global SCALE
    SCALE = 1
    
    # Load data
    print(f"Loading data from files...")
    predictions_data, ground_truth_data, annotations_data = load_data(
        predictions_file, ground_truth_file, annotations_file
    )
    
    # Extract predictions and ground truth
    print(f"Extracting HPO codes from predictions and ground truth...")
    predictions_dict = extract_predictions(predictions_data)
    ground_truth_dict = extract_ground_truth(ground_truth_data)
    
    # Extract phenotype types from annotations and prediction data
    print(f"Extracting phenotype types from annotations...")
    annotation_phenotype_types = extract_phenotype_types_from_annotations(annotations_data)
    prediction_phenotype_types = extract_prediction_types(predictions_data)
    
    # Combine type information from both sources, prioritizing annotations
    phenotype_types = {**prediction_phenotype_types, **annotation_phenotype_types}
    
    # Extract phenotype names
    print(f"Extracting phenotype names...")
    pred_code_to_name = extract_phenotype_names(predictions_data, is_ground_truth=False)
    gt_code_to_name = extract_phenotype_names(ground_truth_data, is_ground_truth=True)
    annotation_code_to_name = extract_additional_names_from_annotations(annotations_data)
    
    # Combine name mappings, prioritize annotations over predictions and ground truth
    code_to_name = {**gt_code_to_name, **pred_code_to_name, **annotation_code_to_name}
    
    # Find errors
    print(f"Finding false positives and false negatives...")
    false_positive_counts, false_negative_counts, case_errors = find_errors(predictions_dict, ground_truth_dict)
    
    # Analyze errors with improved approach
    print(f"Analyzing errors by type...")
    fp_error_counts, fn_error_counts = analyze_errors_with_text_matching(
        false_positive_counts, 
        false_negative_counts, 
        ground_truth_data,
        code_to_name,
        phenotype_types
    )
    
    # 1. Plot top 5 error distribution with wrapped labels
    print(f"Creating error distribution visualization...")
    if use_plotly:
        error_dist_fig = plot_error_distribution(false_positive_counts, false_negative_counts, code_to_name, n=5, use_plotly=True)
        error_dist_fig.write_image('data/figures/top5_error_distribution.png', scale=SCALE)
        # Also save as HTML for interactive viewing
        error_dist_fig.write_html('data/figures/top5_error_distribution.html')
        print("Top 5 error distribution saved to 'data/figures/top5_error_distribution.png' and HTML")
    else:
        error_dist_fig = plot_error_distribution(false_positive_counts, false_negative_counts, code_to_name, n=5, use_plotly=False)
        error_dist_fig.savefig('data/figures/top5_error_distribution.png', dpi=300*SCALE, bbox_inches='tight')
        plt.close(error_dist_fig)
        print("Top 5 error distribution saved to 'data/figures/top5_error_distribution.png'")
    
    # 2. Plot pie charts showing the distribution of error types
    print(f"Creating error type pie charts...")
    if use_plotly:
        error_type_fig = plot_error_type_pie(fp_error_counts, fn_error_counts, use_plotly=True)
        error_type_fig.write_image('data/figures/error_types_pie.png', scale=SCALE)
        # Also save as HTML for interactive viewing
        error_type_fig.write_html('data/figures/error_types_pie.html')
        print("Error type distribution saved to 'data/figures/error_types_pie.png' and HTML")
    else:
        error_type_fig = plot_error_type_pie(fp_error_counts, fn_error_counts, use_plotly=False)
        error_type_fig.savefig('data/figures/error_types_pie.png', dpi=300*SCALE, bbox_inches='tight')
        plt.close(error_type_fig)
        print("Error type distribution saved to 'data/figures/error_types_pie.png'")
    
    # Print summary statistics
    print("\nError Type Analysis:")
    print("\nFalse Positives:")
    fp_total = sum(fp_error_counts.values())
    if fp_total > 0:
        print(f"  Direct Phenotypes: {fp_error_counts['direct']} ({fp_error_counts['direct']/fp_total*100:.1f}%)")
        print(f"  Implied Phenotypes: {fp_error_counts['implied']} ({fp_error_counts['implied']/fp_total*100:.1f}%)")
        print(f"  Unknown Type: {fp_error_counts['unknown']} ({fp_error_counts['unknown']/fp_total*100:.1f}%)")
    else:
        print("  No false positives found")
    
    print("\nFalse Negatives:")
    fn_total = sum(fn_error_counts.values())
    if fn_total > 0:
        print(f"  Direct Phenotypes: {fn_error_counts['direct']} ({fn_error_counts['direct']/fn_total*100:.1f}%)")
        print(f"  Implied Phenotypes: {fn_error_counts['implied']} ({fn_error_counts['implied']/fn_total*100:.1f}%)")
        print(f"  Unknown Type: {fn_error_counts['unknown']} ({fn_error_counts['unknown']/fn_total*100:.1f}%)")
    else:
        print("  No false negatives found")
    
    # Create and save distribution pie chart for all phenotypes from annotations
    print("\nCreating distribution of direct vs implied phenotypes from annotations...")
    if annotations_data:
        phenotype_dist = analyze_phenotype_distribution(annotations_data)
        if use_plotly:
            dist_fig = plot_phenotype_distribution_pie(phenotype_dist, use_plotly=True)
            dist_fig.write_image('data/figures/phenotype_presence_chart.png', scale=SCALE)
            dist_fig.write_html('data/figures/phenotype_presence_chart.html')
            print("Phenotype distribution chart saved to 'data/figures/phenotype_presence_chart.png' and HTML")
    
    # Also export data for further analysis
    error_analysis_data = {
        "false_positive_counts": dict(false_positive_counts),
        "false_negative_counts": dict(false_negative_counts),
        "fp_error_types": fp_error_counts,
        "fn_error_types": fn_error_counts,
        "case_errors": case_errors
    }
    
    with open('data/figures/error_analysis_data.json', 'w') as f:
        json.dump(error_analysis_data, f, indent=2)
    
    print("\nError analysis data saved to 'data/figures/error_analysis_data.json'")
    print("\nAnalysis complete! Visualizations have been saved to the data/figures directory.")

if __name__ == "__main__":
    run_improved_analysis()

Loading data from files...
Successfully loaded annotations from data/dataset/implied_phenotypes.json
Extracting HPO codes from predictions and ground truth...
Extracting phenotype types from annotations...
Extracted 1234 phenotype types from annotations
Extracting phenotype names...
Finding false positives and false negatives...
Analyzing errors by type...
Creating error distribution visualization...
Top 5 error distribution saved to 'data/figures/top5_error_distribution.png' and HTML
Creating error type pie charts...
Error type distribution saved to 'data/figures/error_types_pie.png' and HTML

Error Type Analysis:

False Positives:
  Direct Phenotypes: 618 (84.1%)
  Implied Phenotypes: 117 (15.9%)
  Unknown Type: 0 (0.0%)

False Negatives:
  Direct Phenotypes: 335 (57.4%)
  Implied Phenotypes: 249 (42.6%)
  Unknown Type: 0 (0.0%)

Creating distribution of direct vs implied phenotypes from annotations...
Total phenotypes: 1813
Direct phenotypes: 1320 (72.81%)
Implied phenotypes: 493 (2

In [None]:
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from fuzzywuzzy import fuzz
import re
from collections import defaultdict
from IPython.display import display, HTML

def load_data(file_path):
    """Load data from a JSON file."""
    with open(file_path, 'r') as f:
        data = json.load(f)
    
    # Handle structured format with metadata/results
    if isinstance(data, dict) and "results" in data:
        results = data["results"]
    else:
        results = data
    
    # Ensure all keys are strings
    if isinstance(results, dict):
        results = {str(k): v for k, v in results.items()}
    
    return results

def extract_predictions(data):
    """Extract HPO codes from the predictions data structure, based on evaluate_step3.py."""
    result = {}
    
    # Handle nested structure with metadata and results keys
    if isinstance(data, dict) and "results" in data:
        data = data["results"]
    
    # Process each case in the data
    for case_id, case_data in data.items():
        hpo_codes = []
        
        # Check for different possible field names for phenotype lists
        phenotype_fields = ["matched_phenotypes", "verified_phenotypes"]
        for field in phenotype_fields:
            if field in case_data:
                phenotype_list = case_data[field]
                if isinstance(phenotype_list, list):
                    for item in phenotype_list:
                        if not isinstance(item, dict):
                            continue
                            
                        # Extract the HPO code and phenotype information
                        phenotype = None
                        hpo_code = None
                        status = item.get("status", "unknown")
                        context = item.get("context", "")
                        
                        # Get phenotype name
                        if "phenotype" in item:
                            phenotype = item["phenotype"]
                        
                        # Check for different possible field names for HPO codes
                        for code_field in ["HPO_Term", "hpo_term", "hpo_id", "hp_id"]:
                            if code_field in item:
                                code = item[code_field]
                                # Ensure code starts with HP:
                                if code and isinstance(code, str) and code.startswith("HP:"):
                                    hpo_code = code
                                    break
                        
                        if phenotype:
                            hpo_codes.append({
                                "phenotype": phenotype,
                                "hpo_code": hpo_code,
                                "status": status,
                                "context": context
                            })
        
        # Add non-empty lists to result
        if hpo_codes:
            result[str(case_id)] = hpo_codes
    
    return result

def extract_ground_truth(data):
    """Extract ground truth HPO codes, based on evaluate_step3.py."""
    result = {}
    
    # Process each case in the data
    for case_id, case_data in data.items():
        hpo_codes = []
        phenotypes = []
        
        # Try different possible field names
        for field in ["phenotypes", "ground_truth", "hpo_terms"]:
            if field in case_data:
                field_data = case_data[field]
                if isinstance(field_data, list):
                    for item in field_data:
                        # Check for different possible field formats
                        if isinstance(item, dict):
                            # Try different field names for the phenotype
                            phenotype = None
                            for phenotype_field in ["phenotype", "phenotype_name", "name"]:
                                if phenotype_field in item:
                                    phenotype = item[phenotype_field]
                                    break
                            
                            # Try different field names for the HPO code
                            hpo_code = None
                            for code_field in ["hpo_id", "HPO_Term", "hpo_code", "hp_id"]:
                                if code_field in item:
                                    code = item[code_field]
                                    # Ensure code starts with HP:
                                    if code and isinstance(code, str) and code.startswith("HP:"):
                                        hpo_code = code
                                        break
                            
                            if phenotype:
                                phenotypes.append(phenotype)
                            if hpo_code:
                                hpo_codes.append(hpo_code)
                                
                        elif isinstance(item, str):
                            # If it's an HPO code, add to hpo_codes
                            if item.startswith("HP:"):
                                hpo_codes.append(item)
                            else:
                                phenotypes.append(item)
                    break
        
        # Add non-empty lists to result
        if hpo_codes:
            result[str(case_id)] = hpo_codes
        elif phenotypes:
            # Fallback to phenotypes if no HPO codes found
            result[str(case_id)] = phenotypes
        
        # Debug info for this case
        if result.get(str(case_id)):
            print(f"Case {case_id}: Found {len(result[str(case_id)])} ground truth entries")
    
    return result

def extract_clinical_text(data):
    """Extract clinical text from prediction data."""
    case_texts = {}
    
    for case_id, case_data in data.items():
        if "clinical_text" in case_data:
            case_texts[case_id] = case_data["clinical_text"]
        elif "original_text" in case_data:
            case_texts[case_id] = case_data["original_text"]
    
    return case_texts

def normalize_text(text):
    """Normalize text for comparison."""
    return ' '.join(text.lower().split())

def compare_predictions(pred_phenotypes, gt_phenotypes, similarity_threshold=80):
    """
    Compare predicted phenotypes to ground truth, based on evaluate_step3.py.
    
    Args:
        pred_phenotypes: List of predicted phenotype dictionaries with hpo_code
        gt_phenotypes: List of ground truth HPO codes or phenotype strings
        similarity_threshold: Threshold for fuzzy matching (for phenotype text comparison)
        
    Returns:
        Tuple of (true_positives, false_positives, false_negatives)
    """
    # Initialize result containers
    true_positives = []
    false_positives = []
    false_negatives = gt_phenotypes.copy()
    
    # First try exact matching of HPO codes if available
    for pred in pred_phenotypes:
        pred_hpo = pred.get("hpo_code")
        
        # If we have an HPO code, try exact matching first
        if pred_hpo and pred_hpo in gt_phenotypes:
            # Found an exact HPO match
            matched = True
            true_positives.append(pred)
            false_negatives.remove(pred_hpo)
        else:
            # Extract phenotype string for fuzzy matching
            pred_phenotype = pred.get("phenotype", "")
            
            # Normalize the phenotype string
            norm_pred = normalize_text(pred_phenotype)
            
            # Flag to track if we found a match
            matched = False
            
            # Try fuzzy matching against ground truth
            for gt in list(false_negatives):  # Use a copy to safely modify during iteration
                # Normalize the ground truth string for comparison
                norm_gt = normalize_text(gt)
                
                # Check for match using fuzzy string matching
                similarity = fuzz.ratio(norm_pred, norm_gt)
                if similarity >= similarity_threshold:
                    # Found a match
                    matched = True
                    # Add to true positives
                    true_positives.append(pred)
                    # Remove from false negatives
                    false_negatives.remove(gt)
                    break
            
            # If no match found, add to false positives
            if not matched:
                false_positives.append(pred)
    
    return true_positives, false_positives, false_negatives

def compare_performance(pred_a, pred_b, ground_truth):
    """
    Compare performance of prediction files A and B against ground truth.
    
    Returns:
        Tuple of (cases_a_better, cases_b_better) where each is a dict
        mapping case_id to comparison metrics
    """
    cases_a_better = {}
    cases_b_better = {}
    
    # Get common case IDs across all three files
    common_cases = set(pred_a.keys()) & set(pred_b.keys()) & set(ground_truth.keys())
    
    for case_id in common_cases:
        # Get phenotypes for this case
        phenotypes_a = pred_a[case_id]
        phenotypes_b = pred_b[case_id]
        gt_phenotypes = ground_truth[case_id]
        
        # Compare predictions to ground truth
        tp_a, fp_a, fn_a = compare_predictions(phenotypes_a, gt_phenotypes)
        tp_b, fp_b, fn_b = compare_predictions(phenotypes_b, gt_phenotypes)
        
        # Calculate F1 scores
        precision_a = len(tp_a) / (len(tp_a) + len(fp_a)) if (len(tp_a) + len(fp_a)) > 0 else 0
        recall_a = len(tp_a) / (len(tp_a) + len(fn_a)) if (len(tp_a) + len(fn_a)) > 0 else 0
        f1_a = 2 * (precision_a * recall_a) / (precision_a + recall_a) if (precision_a + recall_a) > 0 else 0
        
        precision_b = len(tp_b) / (len(tp_b) + len(fp_b)) if (len(tp_b) + len(fp_b)) > 0 else 0
        recall_b = len(tp_b) / (len(tp_b) + len(fn_b)) if (len(tp_b) + len(fn_b)) > 0 else 0
        f1_b = 2 * (precision_b * recall_b) / (precision_b + recall_b) if (precision_b + recall_b) > 0 else 0
        
        # Compare F1 scores to determine which is better
        if f1_a > f1_b:
            # A performed better
            cases_a_better[case_id] = {
                "tp_a": tp_a,
                "fp_a": fp_a,
                "fn_a": fn_a,
                "tp_b": tp_b,
                "fp_b": fp_b,
                "fn_b": fn_b,
                "f1_a": f1_a,
                "f1_b": f1_b,
                "advantage": f1_a - f1_b
            }
        elif f1_b > f1_a:
            # B performed better
            cases_b_better[case_id] = {
                "tp_a": tp_a,
                "fp_a": fp_a,
                "fn_a": fn_a,
                "tp_b": tp_b,
                "fp_b": fp_b,
                "fn_b": fn_b,
                "f1_a": f1_a,
                "f1_b": f1_b,
                "advantage": f1_b - f1_a
            }
    
    return cases_a_better, cases_b_better

def highlight_entity_in_context(entity, context, highlight_color="#FFFF00"):
    """
    Find and highlight an entity in its context.
    
    Args:
        entity: The entity/phenotype to highlight
        context: The context text
        highlight_color: HTML color for highlighting
        
    Returns:
        HTML-formatted text with entity highlighted
    """
    if not context:
        return f"<span style='background-color: {highlight_color};'>{entity}</span> <i>(No context available)</i>"
    
    # If entity is directly in context, highlight it
    if entity in context:
        highlighted = context.replace(
            entity, 
            f"<span style='background-color: {highlight_color};'>{entity}</span>"
        )
        return highlighted
    
    # Try case-insensitive match
    pattern = re.compile(re.escape(entity), re.IGNORECASE)
    match = pattern.search(context)
    if match:
        start, end = match.span()
        highlighted = (
            context[:start] + 
            f"<span style='background-color: {highlight_color};'>{context[start:end]}</span>" + 
            context[end:]
        )
        return highlighted
    
    # If entity not found in context, show both
    return (
        f"<span style='background-color: {highlight_color};'>{entity}</span> <br>"
        f"Context: {context}"
    )

def get_examples(cases_dict, pred_a_texts, pred_b_texts, max_examples=5, example_type="tp"):
    """
    Get example entities with context.
    
    Args:
        cases_dict: Dictionary of cases where one prediction file performed better
        pred_a_texts: Dictionary of clinical texts from prediction file A
        pred_b_texts: Dictionary of clinical texts from prediction file B
        max_examples: Maximum number of examples to return
        example_type: Type of examples to get ("tp", "fp", or "fn")
        
    Returns:
        List of examples with case_id, entity, and context
    """
    examples = []
    
    # Determine which fields to look at based on example_type
    if example_type.startswith("tp"):
        field_name = f"tp_{example_type[-1]}"  # tp_a or tp_b
    elif example_type.startswith("fp"):
        field_name = f"fp_{example_type[-1]}"  # fp_a or fp_b
    elif example_type.startswith("fn"):
        field_name = f"fn_{example_type[-1]}"  # fn_a or fn_b
    else:
        return []
    
    # Get examples from cases
    for case_id, metrics in cases_dict.items():
        entities = metrics.get(field_name, [])
        
        if not entities:
            continue
            
        # Get clinical text for context
        text = pred_a_texts.get(case_id, "") if example_type[-1] == "a" else pred_b_texts.get(case_id, "")
        
        for entity_info in entities:
            # Format depends on whether we're dealing with true/false positives or false negatives
            if example_type.startswith("fn"):
                # False negatives are just strings
                entity = entity_info
                context = text  # Use full text as context
                
                # Try to extract a better context from the text
                if text:
                    sentences = re.split(r'(?<=[.!?])\s+', text)
                    for sentence in sentences:
                        if entity.lower() in sentence.lower():
                            context = sentence
                            break
            else:
                # True/false positives are dictionaries
                entity = entity_info["phenotype"]
                context = entity_info.get("context", "")
                
                # If no context in entity_info, try to find it in the clinical text
                if not context and text:
                    # Simple context extraction - find the sentence containing the entity
                    sentences = re.split(r'(?<=[.!?])\s+', text)
                    for sentence in sentences:
                        if entity.lower() in sentence.lower():
                            context = sentence
                            break
            
            examples.append({
                "case_id": case_id,
                "entity": entity,
                "context": context,
                "f1_diff": metrics.get("advantage", 0)
            })
    
    # Sort examples by F1 difference (advantage)
    examples.sort(key=lambda x: x["f1_diff"], reverse=True)
    
    return examples[:max_examples]

def display_examples(examples, title, highlight_color="#FFFF00"):
    """Display examples with highlighted entities."""
    if not examples:
        display(HTML(f"<h3>{title}</h3><p>No examples found.</p>"))
        return
        
    html_output = f"<h3>{title}</h3>"
    html_output += "<table style='width:100%; border-collapse: collapse;'>"
    html_output += "<tr><th style='text-align:left; padding:8px; border:1px solid #ddd; background-color:#f2f2f2;'>Case ID</th>"
    html_output += "<th style='text-align:left; padding:8px; border:1px solid #ddd; background-color:#f2f2f2;'>Entity in Context</th>"
    html_output += "<th style='text-align:left; padding:8px; border:1px solid #ddd; background-color:#f2f2f2;'>F1 Difference</th></tr>"
    
    for i, example in enumerate(examples):
        # Alternate row colors
        row_style = "background-color:#f9f9f9;" if i % 2 == 0 else ""
        
        # Format entity in context
        highlighted = highlight_entity_in_context(
            example["entity"], 
            example["context"],
            highlight_color
        )
        
        html_output += f"<tr style='{row_style}'>"
        html_output += f"<td style='padding:8px; border:1px solid #ddd;'>{example['case_id']}</td>"
        html_output += f"<td style='padding:8px; border:1px solid #ddd;'>{highlighted}</td>"
        html_output += f"<td style='padding:8px; border:1px solid #ddd;'>{example['f1_diff']:.4f}</td>"
        html_output += "</tr>"
    
    html_output += "</table>"
    
    display(HTML(html_output))

def get_overall_stats(cases_a_better, cases_b_better):
    """Calculate overall statistics for the comparison."""
    # Count by status (direct vs implied phenotypes)
    status_counts_a = defaultdict(int)
    status_counts_b = defaultdict(int)
    
    # Average F1 score differences
    f1_diffs_a = []
    f1_diffs_b = []
    
    # Process cases where A performed better
    for case_id, metrics in cases_a_better.items():
        f1_diffs_a.append(metrics["advantage"])
        
        # Count by status
        for tp in metrics.get("tp_a", []):
            if isinstance(tp, dict) and "status" in tp:
                status_counts_a[tp["status"]] += 1
    
    # Process cases where B performed better
    for case_id, metrics in cases_b_better.items():
        f1_diffs_b.append(metrics["advantage"])
        
        # Count by status
        for tp in metrics.get("tp_b", []):
            if isinstance(tp, dict) and "status" in tp:
                status_counts_b[tp["status"]] += 1
    
    # Calculate averages
    avg_advantage_a = np.mean(f1_diffs_a) if f1_diffs_a else 0
    avg_advantage_b = np.mean(f1_diffs_b) if f1_diffs_b else 0
    
    return {
        "cases_a_better": len(cases_a_better),
        "cases_b_better": len(cases_b_better),
        "avg_advantage_a": avg_advantage_a,
        "avg_advantage_b": avg_advantage_b,
        "status_counts_a": dict(status_counts_a),
        "status_counts_b": dict(status_counts_b)
    }

def visualize_stats(stats):
    """Visualize comparison statistics."""
    # Display case counts
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Cases bar chart
    labels = ['Prediction A Better', 'Prediction B Better']
    values = [stats["cases_a_better"], stats["cases_b_better"]]
    
    ax1.bar(labels, values, color=['blue', 'orange'])
    ax1.set_title('Number of Cases Where Each Prediction Performed Better')
    ax1.set_ylabel('Number of Cases')
    
    for i, v in enumerate(values):
        ax1.text(i, v + 1, str(v), ha='center')
    
    # Average F1 advantage bar chart
    avg_labels = ['Avg. F1 Advantage (A)', 'Avg. F1 Advantage (B)']
    avg_values = [stats["avg_advantage_a"], stats["avg_advantage_b"]]
    
    ax2.bar(avg_labels, avg_values, color=['lightblue', 'lightsalmon'])
    ax2.set_title('Average F1 Score Advantage')
    ax2.set_ylabel('Average F1 Difference')
    
    for i, v in enumerate(avg_values):
        ax2.text(i, v + 0.01, f"{v:.4f}", ha='center')
    
    plt.tight_layout()
    plt.show()
    
    # Display phenotype type breakdowns
    if stats["status_counts_a"] or stats["status_counts_b"]:
        fig, (ax3, ax4) = plt.subplots(1, 2, figsize=(14, 5))
        
        if stats["status_counts_a"]:
            # Status counts for A
            status_labels_a = list(stats["status_counts_a"].keys())
            status_values_a = list(stats["status_counts_a"].values())
            
            ax3.pie(status_values_a, labels=status_labels_a, autopct='%1.1f%%')
            ax3.set_title('Phenotype Types in A Better Cases')
        else:
            ax3.text(0.5, 0.5, "No phenotype type data", 
                    horizontalalignment='center', verticalalignment='center')
            ax3.set_title('Phenotype Types in A Better Cases (No Data)')
        
        if stats["status_counts_b"]:
            # Status counts for B
            status_labels_b = list(stats["status_counts_b"].keys())
            status_values_b = list(stats["status_counts_b"].values())
            
            ax4.pie(status_values_b, labels=status_labels_b, autopct='%1.1f%%')
            ax4.set_title('Phenotype Types in B Better Cases')
        else:
            ax4.text(0.5, 0.5, "No phenotype type data", 
                    horizontalalignment='center', verticalalignment='center')
            ax4.set_title('Phenotype Types in B Better Cases (No Data)')
        
        plt.tight_layout()
        plt.show()

def compare_prediction_files(pred_file_a, pred_file_b, ground_truth_file, similarity_threshold=80):
    """
    Compare two prediction files against a ground truth file.
    
    Args:
        pred_file_a: Path to prediction file A
        pred_file_b: Path to prediction file B
        ground_truth_file: Path to ground truth file
        similarity_threshold: Threshold for fuzzy matching (0-100)
    """
    print(f"Loading prediction file A: {pred_file_a}")
    pred_a_data = load_data(pred_file_a)
    
    print(f"Loading prediction file B: {pred_file_b}")
    pred_b_data = load_data(pred_file_b)
    
    print(f"Loading ground truth file: {ground_truth_file}")
    gt_data = load_data(ground_truth_file)
    
    # Extract HPO codes and phenotypes
    print("Extracting phenotypes and HPO codes from files...")
    pred_a_phenotypes = extract_predictions(pred_a_data)
    pred_b_phenotypes = extract_predictions(pred_b_data)
    gt_phenotypes = extract_ground_truth(gt_data)
    
    print(f"Found predictions for {len(pred_a_phenotypes)} cases in file A")
    print(f"Found predictions for {len(pred_b_phenotypes)} cases in file B")
    print(f"Found ground truth for {len(gt_phenotypes)} cases")
    
    # Extract clinical texts
    pred_a_texts = extract_clinical_text(pred_a_data)
    pred_b_texts = extract_clinical_text(pred_b_data)
    
    # Compare performance
    print("Comparing prediction performances...")
    cases_a_better, cases_b_better = compare_performance(
        pred_a_phenotypes, 
        pred_b_phenotypes, 
        gt_phenotypes
    )
    
    # Calculate and visualize statistics
    stats = get_overall_stats(cases_a_better, cases_b_better)
    visualize_stats(stats)
    
    # Print summary stats
    print("\n=== Summary Statistics ===")
    print(f"Cases where prediction A performed better: {stats['cases_a_better']}")
    print(f"Cases where prediction B performed better: {stats['cases_b_better']}")
    print(f"Average F1 advantage when A is better: {stats['avg_advantage_a']:.4f}")
    print(f"Average F1 advantage when B is better: {stats['avg_advantage_b']:.4f}")
    
    # Get examples where A performed better
    print("\n=== Examples where Prediction A performed better ===")
    a_better_tp = get_examples(cases_a_better, pred_a_texts, pred_b_texts, example_type="tp_a")
    display_examples(a_better_tp, "True Positives in A (but not in B)", "#AAFFAA")
    
    a_better_fp_b = get_examples(cases_a_better, pred_a_texts, pred_b_texts, example_type="fp_b")
    display_examples(a_better_fp_b, "False Positives in B (but correct in A)", "#FFAAAA")
    
    a_better_fn_b = get_examples(cases_a_better, pred_a_texts, pred_b_texts, example_type="fn_b")
    display_examples(a_better_fn_b, "False Negatives in B (but found in A)", "#AAAAFF")
    
    # Get examples where B performed better
    print("\n=== Examples where Prediction B performed better ===")
    b_better_tp = get_examples(cases_b_better, pred_a_texts, pred_b_texts, example_type="tp_b")
    display_examples(b_better_tp, "True Positives in B (but not in A)", "#AAFFAA")
    
    b_better_fp_a = get_examples(cases_b_better, pred_a_texts, pred_b_texts, example_type="fp_a")
    display_examples(b_better_fp_a, "False Positives in A (but correct in B)", "#FFAAAA")
    
    b_better_fn_a = get_examples(cases_b_better, pred_a_texts, pred_b_texts, example_type="fn_a")
    display_examples(b_better_fn_a, "False Negatives in A (but found in B)", "#AAAAFF")
    
    print("\nComparison complete!")

# Example of how to use the tool with variables instead of the widget interface
def run_hpo_comparison(pred_file_a, pred_file_b, ground_truth_file, similarity_threshold=80):
    """
    Run HPO prediction comparison with file paths provided as variables.
    
    Args:
        pred_file_a: Path to prediction file A
        pred_file_b: Path to prediction file B
        ground_truth_file: Path to ground truth file
        similarity_threshold: Threshold for fuzzy matching (0-100)
    """
    print("Starting HPO Prediction Comparison")
    print("==================================")
    
    compare_prediction_files(
        pred_file_a,
        pred_file_b,
        ground_truth_file,
        similarity_threshold
    )

# Example usage:
if __name__ == "__main__":
    # Set your file paths here
    PRED_FILE_A = "data/results/agents/hpo/step3/verified_lab_test_matched.json"
    PRED_FILE_B = "data/results/hporag/llama70b_updated.json"
    GROUND_TRUTH_FILE = "data/dataset/mine_hpo.json"
    SIMILARITY_THRESHOLD = 100
    
    # Run the comparison
    run_hpo_comparison(
        PRED_FILE_A,
        PRED_FILE_B,
        GROUND_TRUTH_FILE,
        SIMILARITY_THRESHOLD
    )

Starting HPO Prediction Comparison
Loading prediction file A: data/results/agents/hpo/step3/verified_lab_test_matched.json
Loading prediction file B: data/results/hporag/llama70b_updated.json
Loading ground truth file: data/dataset/mine_hpo.json
Extracting phenotypes and HPO codes from files...
Case 1: Found 8 ground truth entries
Case 2: Found 21 ground truth entries
Case 3: Found 6 ground truth entries
Case 4: Found 13 ground truth entries
Case 5: Found 11 ground truth entries
Case 6: Found 19 ground truth entries
Case 7: Found 29 ground truth entries
Case 8: Found 18 ground truth entries
Case 9: Found 26 ground truth entries
Case 10: Found 20 ground truth entries
Case 11: Found 26 ground truth entries
Case 12: Found 10 ground truth entries
Case 13: Found 11 ground truth entries
Case 14: Found 8 ground truth entries
Case 15: Found 9 ground truth entries
Case 16: Found 10 ground truth entries
Case 17: Found 4 ground truth entries
Case 18: Found 9 ground truth entries
Case 19: Found 12

ValueError: list.remove(x): x not in list