In [1]:

import os
import sys
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)
# Set the parent directory as the current directory
os.chdir(parent_dir)

In [2]:
# load fully corrected datasets
from rdma.utils.data import read_json_file, print_json_structure


# all human labels
human_corrections_full = read_json_file("data/dataset/rare_disease_corrections_john.json")
print("-------- Human Corrections File ------")
print_json_structure(human_corrections_full)
# human + supervisor labels
human_rdma_corrections = read_json_file("data/dataset/rare_disease_annotations_rdma_john_comprehensive.json")
# supervisor labels
print("------- Supervisor Corrections File -------")
rdma_corrections = read_json_file("data/results/supervisor/multistage_no_min.json")
print_json_structure(rdma_corrections)

-------- Human Corrections File ------
Dictionary:
  metadata (dict): 
  Dictionary:
    timestamp (str): 
    total_entities_in_file (int): 
    reviewed_entities (int): 
  corrected_annotations (list): 
  List: (333 items)
    Item 0 (dict): 
    Dictionary:
      entity (str): 
      document_id (str): 
      orpha_code (str): 
      category (str): 
      is_rare_disease (bool): 
      ... and 2 more items
    Item 1 (dict): 
    Dictionary:
      entity (str): 
      document_id (str): 
      orpha_code (str): 
      category (str): 
      is_rare_disease (bool): 
      ... and 2 more items
    Item 2 (dict): 
    Dictionary:
      entity (str): 
      document_id (str): 
      orpha_code (str): 
      category (str): 
      is_rare_disease (bool): 
      ... and 2 more items
    Item 3 (dict): 
    Dictionary:
      entity (str): 
      document_id (str): 
      orpha_code (str): 
      category (str): 
      is_rare_disease (bool): 
      ... and 2 more items
    Item 4 (dict): 

# Comparing RDMA (only) and my annotations, because they go through the entire set of possible annotations

In [3]:
"""
Improved Rare Disease Annotator Agreement Analysis with ORPHA Code Priority

This script computes the inter-annotator agreement between human corrections
and RDMA supervisor corrections for rare disease entity recognition,
using ORPHA codes as the primary matching criterion and falling back to
hierarchical clustering for entity variants without ORPHA codes.
"""

import json
import numpy as np
from typing import Dict, List, Set, Tuple, Any, Optional
from sklearn.metrics import cohen_kappa_score
import scipy.stats as stats
from collections import defaultdict, Counter
from fuzzywuzzy import fuzz
import re
import networkx as nx
from tqdm import tqdm

def read_json_file(filename: str) -> dict:
    """Read a JSON file and return its contents."""
    with open(filename, 'r') as f:
        return json.load(f)

def normalize_entity(entity: str, abbreviations: Dict[str, str] = None) -> str:
    """
    Normalize entity text by lowercasing, removing extra spaces, 
    and expanding known abbreviations.
    
    Args:
        entity: The entity text to normalize
        abbreviations: Dictionary of abbreviations to expand
        
    Returns:
        Normalized entity text
    """
    if not entity:
        return ""
        
    # Convert to lowercase and strip
    normalized = entity.lower().strip()
    
    # Remove multiple spaces and replace with single space
    normalized = re.sub(r'\s+', ' ', normalized)
    
    # Remove hyphens between words (convert "heparin-induced" to "heparin induced")
    normalized = re.sub(r'(\w)-(\w)', r'\1 \2', normalized)
    
    # Expand abbreviation if it exists in the dictionary
    if abbreviations and normalized in abbreviations:
        return abbreviations[normalized]
        
    return normalized

def normalize_orpha_code(orpha_code: str) -> str:
    """
    Normalize ORPHA code to a standard format.
    
    Args:
        orpha_code: Raw ORPHA code string
        
    Returns:
        Normalized ORPHA code (e.g., "ORPHA:12345")
    """
    if not orpha_code or not isinstance(orpha_code, str):
        return ""
    
    # Remove whitespace and convert to uppercase
    code = orpha_code.strip().upper()
    
    # Extract just the number part if it exists
    match = re.search(r'(\d+)', code)
    if match:
        number = match.group(1)
        return f"ORPHA:{number}"
    
    return ""

def build_entity_similarity_graph(entities: List[str], threshold: int = 90) -> nx.Graph:
    """
    Build a graph where nodes are entities and edges exist if similarity exceeds threshold.
    Improved to better handle substring relationships.
    
    Args:
        entities: List of entity strings
        threshold: Minimum similarity score to create an edge
        
    Returns:
        NetworkX graph with entities as nodes and similarities as edge weights
    """
    G = nx.Graph()
    
    # Add all entities as nodes
    for entity in entities:
        G.add_node(entity)
    
    # Add edges for similar entities
    n = len(entities)
    print(f"Building similarity graph for {n} entities without ORPHA codes...")
    
    # Use tqdm for progress tracking in the nested loop
    with tqdm(total=n*(n-1)//2) as pbar:
        for i in range(n):
            for j in range(i+1, n):
                # Update progress
                pbar.update(1)
                
                # Skip comparison if entities are identical
                entity1 = entities[i]
                entity2 = entities[j]
                if entity1 == entity2:
                    continue
                
                # Use improved substring checking
                should_connect, similarity_score, match_type = improved_substring_check(entity1, entity2, threshold)
                
                if should_connect:
                    G.add_edge(entity1, entity2, weight=similarity_score)
    
    print(f"Graph built with {len(G.nodes)} nodes and {len(G.edges)} edges")
    return G

def improved_substring_check(entity1: str, entity2: str, threshold: int = 90) -> Tuple[bool, int, str]:
    """
    Improved substring checking that should catch more cases.
    
    Returns:
        Tuple of (should_connect, similarity_score, match_type)
    """
    # Normalize entities
    norm1 = ' '.join(entity1.lower().split())
    norm2 = ' '.join(entity2.lower().split())
    
    # 1. Exact match
    if norm1 == norm2:
        return True, 100, "exact"
    
    # 2. Direct substring (case insensitive)
    if norm1 in norm2 or norm2 in norm1:
        return True, 95, "substring"
    
    # 3. Word-level subset check
    words1 = set(norm1.split())
    words2 = set(norm2.split())
    
    if len(words1) != len(words2):  # Only check if different word counts
        if len(words1) < len(words2) and words1.issubset(words2):
            return True, 92, "word_subset"
        elif len(words2) < len(words1) and words2.issubset(words1):
            return True, 92, "word_subset"
    
    # 4. Check for partial word matches (e.g., "sarcoid" vs "sarcoidosis")
    # This handles cases where one entity is a shortened form of another
    if len(words1) == 1 and len(words2) == 1:
        word1, word2 = list(words1)[0], list(words2)[0]
        if len(word1) >= 4 and len(word2) >= 4:  # Only for words of reasonable length
            if word1 in word2 or word2 in word1:
                return True, 90, "partial_word"
    
    # 5. Fuzzy matching for typos and slight variations
    fuzzy_score = fuzz.token_sort_ratio(entity1, entity2)
    if fuzzy_score >= threshold:
        return True, fuzzy_score, "fuzzy"
    
    return False, 0, "no_match"

def cluster_entities(entities: List[str], threshold: int = 90) -> Dict[str, str]:
    """
    Cluster entities using connected components in a similarity graph.
    Improved to better handle canonical form selection.
    
    Args:
        entities: List of entity strings
        threshold: Minimum similarity score to consider entities similar
        
    Returns:
        Dictionary mapping each entity to its canonical form
    """
    if not entities:
        return {}
        
    # Build similarity graph
    G = build_entity_similarity_graph(entities, threshold)
    
    # Find connected components (clusters)
    clusters = list(nx.connected_components(G))
    print(f"Found {len(clusters)} entity clusters for entities without ORPHA codes")
    
    # Create entity mapping
    entity_mapping = {}
    
    # Process each cluster
    for cluster in clusters:
        # Choose canonical form more intelligently
        canonical = choose_canonical_entity(list(cluster))
        
        # Map all entities in this cluster to the canonical form
        for entity in cluster:
            entity_mapping[entity] = canonical
    
    # Add identity mappings for entities not in any cluster
    for entity in entities:
        if entity not in entity_mapping:
            entity_mapping[entity] = entity
    
    return entity_mapping

def choose_canonical_entity(cluster_entities: List[str]) -> str:
    """
    Choose the best canonical entity from a cluster of similar entities.
    Prioritizes longer, more complete forms while avoiding overly complex ones.
    
    Args:
        cluster_entities: List of entities in the same cluster
        
    Returns:
        The canonical entity string
    """
    if len(cluster_entities) == 1:
        return cluster_entities[0]
    
    # Normalize entities for comparison
    normalized_entities = [(entity, ' '.join(entity.lower().split())) for entity in cluster_entities]
    
    # Sort by various criteria to find the best canonical form
    def entity_score(entity_tuple):
        entity, normalized = entity_tuple
        
        # Prefer entities that are not abbreviations (length > 3 or contain spaces)
        is_likely_abbreviation = len(normalized) <= 3 and ' ' not in normalized
        abbreviation_penalty = 50 if is_likely_abbreviation else 0
        
        # Prefer longer entities (more complete names)
        length_score = len(normalized)
        
        # Prefer entities with more words (more descriptive)
        word_count_score = len(normalized.split()) * 10
        
        # Penalize entities that are too long (might be overly specific)
        length_penalty = max(0, len(normalized) - 100) * 0.5
        
        return length_score + word_count_score - abbreviation_penalty - length_penalty
    
    # Sort by score and return the best one
    scored_entities = [(entity, entity_score((entity, normalized))) 
                      for entity, normalized in normalized_entities]
    scored_entities.sort(key=lambda x: x[1], reverse=True)
    
    canonical = scored_entities[0][0]
    
    # Debug output for clusters with multiple entities
    if len(cluster_entities) > 1:
        print(f"  Cluster: {cluster_entities}")
        print(f"  Chosen canonical: '{canonical}'")
    
    return canonical

def create_entity_canonical_mapping(
    all_entities_with_orpha: List[Tuple[str, str]], 
    entities_without_orpha: List[str],
    similarity_threshold: int = 90
) -> Dict[str, str]:
    """
    Create a mapping from entities to their canonical forms, prioritizing ORPHA codes.
    Improved to handle substring relationships better.
    
    Args:
        all_entities_with_orpha: List of (entity, orpha_code) tuples
        entities_without_orpha: List of entity strings without ORPHA codes
        similarity_threshold: Minimum similarity for entity clustering
        
    Returns:
        Dictionary mapping each entity to its canonical identifier
    """
    entity_mapping = {}
    
    # Step 1: Group entities by ORPHA code
    orpha_to_entities = defaultdict(list)
    for entity, orpha_code in all_entities_with_orpha:
        normalized_orpha = normalize_orpha_code(orpha_code)
        if normalized_orpha:
            orpha_to_entities[normalized_orpha].append(entity)
    
    print(f"Found {len(orpha_to_entities)} unique ORPHA codes")
    
    # Step 2: For each ORPHA code group, choose a canonical entity
    orpha_canonical_entities = {}
    for orpha_code, entities in orpha_to_entities.items():
        # Use the improved canonical selection for ORPHA groups too
        canonical_entity = choose_canonical_entity(entities)
        orpha_canonical_entities[orpha_code] = canonical_entity
        
        # Map all entities with this ORPHA code to the canonical entity
        for entity in entities:
            entity_mapping[entity] = canonical_entity
    
    print(f"Mapped {sum(len(entities) for entities in orpha_to_entities.values())} entities using ORPHA codes")
    
    # Step 3: Cluster entities without ORPHA codes using improved clustering
    if entities_without_orpha:
        print(f"Clustering {len(entities_without_orpha)} entities without ORPHA codes")
        string_clustering = cluster_entities(entities_without_orpha, similarity_threshold)
        entity_mapping.update(string_clustering)
    
    return entity_mapping

def create_entity_canonical_mapping_fixed(
    all_entities_with_orpha: List[Tuple[str, str]], 
    entities_without_orpha: List[str],
    similarity_threshold: int = 85  # Lower threshold to catch more similarities
) -> Dict[str, str]:
    """
    Create entity mapping with more aggressive clustering to catch cases like 'budd chiari'.
    """
    entity_mapping = {}
    
    # Step 1: Group entities by ORPHA code
    orpha_to_entities = defaultdict(list)
    for entity, orpha_code in all_entities_with_orpha:
        normalized_orpha = normalize_orpha_code(orpha_code)
        if normalized_orpha:
            orpha_to_entities[normalized_orpha].append(entity)
    
    print(f"Found {len(orpha_to_entities)} unique ORPHA codes")
    
    # Step 2: For each ORPHA code group, choose a canonical entity
    for orpha_code, entities in orpha_to_entities.items():
        canonical_entity = choose_canonical_entity(entities)
        
        # Map all entities with this ORPHA code to the canonical entity
        for entity in entities:
            entity_mapping[entity] = canonical_entity
    # Step 2.5: Cross-reference ORPHA entities with non-ORPHA entities
    for entity_without_orpha in entities_without_orpha:
        for orpha_code, orpha_entities in orpha_to_entities.items():
            for orpha_entity in orpha_entities:
                should_connect, _, _ = improved_substring_check_aggressive(
                    entity_without_orpha, orpha_entity, similarity_threshold
                )
                print("DEBUG:", orpha_entity, entity_without_orpha, should_connect)
                if should_connect:
                    # Map the non-ORPHA entity to the ORPHA canonical form
                    entity_mapping[entity_without_orpha] = choose_canonical_entity(orpha_entities)
                    break
    
    print(f"Mapped {sum(len(entities) for entities in orpha_to_entities.values())} entities using ORPHA codes")
    
    # Step 3: Cluster entities without ORPHA codes with more aggressive matching
    if entities_without_orpha:
        print(f"Clustering {len(entities_without_orpha)} entities without ORPHA codes")
        
        # Use more aggressive clustering
        string_clustering = cluster_entities_aggressive(entities_without_orpha, similarity_threshold)
        entity_mapping.update(string_clustering)
    
    return entity_mapping

def cluster_entities_aggressive(entities: List[str], threshold: int = 85) -> Dict[str, str]:
    """
    More aggressive clustering that should definitely catch 'budd chiari' variants.
    """
    if not entities:
        return {}
    
    print(f"Using aggressive clustering with threshold {threshold}")
    
    # Build similarity graph with more aggressive matching
    G = build_entity_similarity_graph_aggressive(entities, threshold)
    
    # Find connected components
    clusters = list(nx.connected_components(G))
    print(f"Found {len(clusters)} entity clusters for entities without ORPHA codes")
    
    # Create entity mapping
    entity_mapping = {}
    
    for cluster in clusters:
        canonical = choose_canonical_entity(list(cluster))
        
        # Debug output for interesting clusters
        if len(cluster) > 1:
            print(f"  Cluster: {list(cluster)} -> '{canonical}'")
        
        for entity in cluster:
            entity_mapping[entity] = canonical
    
    # Add identity mappings for entities not in any cluster
    for entity in entities:
        if entity not in entity_mapping:
            entity_mapping[entity] = entity
    
    return entity_mapping

def build_entity_similarity_graph_aggressive(entities: List[str], threshold: int = 85) -> nx.Graph:
    """
    More aggressive similarity graph building that should catch more substring relationships.
    """
    G = nx.Graph()
    
    # Add all entities as nodes
    for entity in entities:
        G.add_node(entity)
    
    # Track edges for debugging
    edges_added = []
    
    n = len(entities)
    print(f"Building aggressive similarity graph for {n} entities...")
    
    for i in range(n):
        for j in range(i+1, n):
            entity1 = entities[i]
            entity2 = entities[j]
            if "budd" in entity1.lower() or "budd" in entity2.lower():
                print(f"  Checking '{entity1}' <-> '{entity2}'")
            if entity1 == entity2:
                continue
            
            # Use# Use improved substring check with lower threshold
            should_connect, similarity_score, match_type = improved_substring_check_aggressive(
                entity1, entity2, threshold
            )
            
            if should_connect:
                G.add_edge(entity1, entity2, weight=similarity_score)
                edges_added.append((entity1, entity2, similarity_score, match_type))
                
                # Debug output for interesting connections
                if 'budd' in entity1.lower() or 'budd' in entity2.lower():
                    print(f"  BUDD EDGE: '{entity1}' <-> '{entity2}' (score={similarity_score}, type={match_type})")
    
    print(f"Graph built with {len(G.nodes)} nodes and {len(G.edges)} edges")
    
    return G

def improved_substring_check_aggressive(entity1: str, entity2: str, threshold: int = 85) -> Tuple[bool, int, str]:
    """
    More aggressive substring checking that should definitely catch 'budd chiari' cases.
    """
    # Normalize entities
    norm1 = ' '.join(entity1.lower().split())
    norm2 = ' '.join(entity2.lower().split())
    
    # 1. Exact match
    if norm1 == norm2:
        return True, 100, "exact"
    
    # 2. Direct substring (case insensitive) - this should catch "budd chiari" in "budd chiari syndrome"
    if norm1 in norm2 or norm2 in norm1:
        return True, 95, "substring"
    
    # 3. Word-level subset check - more aggressive
    words1 = set(norm1.split())
    words2 = set(norm2.split())
    
    # Check if one is a subset of the other (regardless of length difference)
    if words1.issubset(words2) and words1 != words2:
        return True, 92, "word_subset"
    elif words2.issubset(words1) and words1 != words2:
        return True, 92, "word_subset"
    
    # 4. Check for partial word matches with lower requirements
    if len(words1) == 1 and len(words2) == 1:
        word1, word2 = list(words1)[0], list(words2)[0]
        if len(word1) >= 3 and len(word2) >= 3:  # Lower requirement
            if word1 in word2 or word2 in word1:
                return True, 88, "partial_word"
    
    # 5. More aggressive fuzzy matching
    fuzzy_score = fuzz.token_sort_ratio(entity1, entity2)
    if fuzzy_score >= threshold:
        return True, fuzzy_score, "fuzzy"
    
    # 6. Additional check: ratio similarity for very close matches
    ratio_score = fuzz.ratio(norm1, norm2)
    if ratio_score >= threshold:
        return True, ratio_score, "ratio"
    
    return False, 0, "no_match"

def apply_or_operation_to_similar_entities(doc_entities: Dict[str, bool], 
                                         entity_mapping: Dict[str, str]) -> Dict[str, bool]:
    """
    Apply OR operation to entities that map to the same canonical form.
    
    Args:
        doc_entities: Dictionary of {entity: is_rare_disease} for a single document
        entity_mapping: Mapping from entities to their canonical forms
        
    Returns:
        Dictionary with unified entities using OR operation
    """
    # Group entities by their canonical form
    canonical_groups = defaultdict(list)
    
    for entity, is_rare in doc_entities.items():
        canonical = entity_mapping.get(entity, entity)  # Use entity itself if not in mapping
        canonical_groups[canonical].append((entity, is_rare))
    
    # Apply OR operation for each canonical group
    unified_entities = {}
    
    for canonical, entity_list in canonical_groups.items():
        # OR operation: True if ANY entity in the group is True
        unified_is_rare = any(is_rare for _, is_rare in entity_list)
        unified_entities[canonical] = unified_is_rare
        
        # Debug output for groups with multiple entities
        if len(entity_list) > 1:
            print(f"  Unified '{canonical}': {entity_list} -> {unified_is_rare}")
    
    return unified_entities

def fix_document_entities_with_or_operation(
    human_doc_entities: Dict[str, Dict[str, bool]], 
    supervisor_doc_entities: Dict[str, Dict[str, bool]],
    entity_mapping: Dict[str, str]
) -> Tuple[Dict[str, Dict[str, bool]], Dict[str, Dict[str, bool]]]:
    """
    Apply OR operation to fix documents where similar entities are separate.
    
    Returns:
        Tuple of (fixed_human_doc_entities, fixed_supervisor_doc_entities)
    """
    print("\n=== APPLYING OR OPERATION TO SIMILAR ENTITIES ===")
    
    fixed_human = {}
    fixed_supervisor = {}
    
    # Process human documents
    for doc_id, entities in human_doc_entities.items():
        unified = apply_or_operation_to_similar_entities(entities, entity_mapping)
        fixed_human[doc_id] = unified
        
        if unified != entities:
            print(f"  Document {doc_id} - Human entities unified")
    
    # Process supervisor documents
    for doc_id, entities in supervisor_doc_entities.items():
        unified = apply_or_operation_to_similar_entities(entities, entity_mapping)
        fixed_supervisor[doc_id] = unified
        
        if unified != entities:
            print(f"  Document {doc_id} - Supervisor entities unified")
    
    return fixed_human, fixed_supervisor

def create_entity_canonical_mapping_with_cross_orpha(
    all_entities_with_orpha: List[Tuple[str, str]], 
    entities_without_orpha: List[str],
    similarity_threshold: int = 85
) -> Dict[str, str]:
    """
    Create entity mapping with cross-ORPHA matching to unify entities like 'budd chiari' -> 'budd chiari syndrome'.
    """
    entity_mapping = {}
    
    # Step 1: Group entities by ORPHA code
    orpha_to_entities = defaultdict(list)
    for entity, orpha_code in all_entities_with_orpha:
        normalized_orpha = normalize_orpha_code(orpha_code)
        if normalized_orpha:
            orpha_to_entities[normalized_orpha].append(entity)
    
    print(f"Found {len(orpha_to_entities)} unique ORPHA codes")
    
    # Step 2: For each ORPHA code group, choose a canonical entity
    orpha_canonical_entities = {}
    for orpha_code, entities in orpha_to_entities.items():
        canonical_entity = choose_canonical_entity(entities)
        orpha_canonical_entities[orpha_code] = canonical_entity
        
        # Map all entities with this ORPHA code to the canonical entity
        for entity in entities:
            entity_mapping[entity] = canonical_entity
    
    print(f"Mapped {sum(len(entities) for entities in orpha_to_entities.values())} entities using ORPHA codes")
    
    # Step 3: Cross-reference non-ORPHA entities with ORPHA entities
    remaining_entities_without_orpha = []
    
    for entity_without_orpha in entities_without_orpha:
        matched_to_orpha = False
        
        # Check against all ORPHA canonical entities
        for orpha_code, canonical_entity in orpha_canonical_entities.items():
            # Check against the canonical entity
            should_connect, score, match_type = improved_substring_check_aggressive(
                entity_without_orpha, canonical_entity, similarity_threshold
            )
            
            if should_connect:
                entity_mapping[entity_without_orpha] = canonical_entity
                matched_to_orpha = True
                print(f"  Cross-ORPHA match: '{entity_without_orpha}' -> '{canonical_entity}' (ORPHA entity, {match_type}, score={score})")
                break
            
            # Also check against all entities in this ORPHA group
            if not matched_to_orpha:
                for orpha_entity in orpha_to_entities[orpha_code]:
                    should_connect, score, match_type = improved_substring_check_aggressive(
                        entity_without_orpha, orpha_entity, similarity_threshold
                    )
                    
                    if should_connect:
                        entity_mapping[entity_without_orpha] = canonical_entity
                        matched_to_orpha = True
                        print(f"  Cross-ORPHA match: '{entity_without_orpha}' -> '{canonical_entity}' (via '{orpha_entity}', {match_type}, score={score})")
                        break
            
            if matched_to_orpha:
                break
        
        # If not matched to any ORPHA entity, keep for string clustering
        if not matched_to_orpha:
            remaining_entities_without_orpha.append(entity_without_orpha)
    
    print(f"Cross-ORPHA matching: {len(entities_without_orpha) - len(remaining_entities_without_orpha)} entities matched to ORPHA entities")
    print(f"Remaining for string clustering: {len(remaining_entities_without_orpha)} entities")
    
    # Step 4: Cluster remaining entities without ORPHA codes
    if remaining_entities_without_orpha:
        print(f"Clustering {len(remaining_entities_without_orpha)} remaining entities without ORPHA codes")
        string_clustering = cluster_entities_aggressive(remaining_entities_without_orpha, similarity_threshold)
        entity_mapping.update(string_clustering)
    
    return entity_mapping

def extract_document_entity_sets_with_orpha_priority_fixed(
    human_corrections: dict, 
    supervisor_corrections: dict, 
    similarity_threshold: int = 85
) -> Tuple[Dict[str, Dict[str, bool]], Dict[str, Dict[str, bool]], Dict[str, str], Dict[str, Dict]]:
    """
    Extract document-level entity sets with proper OR logic applied during processing.
    """
    # Dictionary for abbreviation expansion
    abbreviations = {
        "hit": "heparin induced thrombocytopenia",
        "pah": "pulmonary arterial hypertension",
        "pfo": "patent foramen ovale",
        "pcd": "primary ciliary dyskinesia",
        "hids": "hyper-igd syndrome",
        "ald": "adrenoleukodystrophy",
    }
    
    def is_valid_annotation(entity, context):
        """Check if an annotation is valid."""
        excluded_terms = ["high altitude pulmonary edema"]
        if any(term.lower() in entity.lower() for term in excluded_terms):
            return False
        return True
    
    # Step 1: Collect all entities and create mapping (same as before)
    all_entities_with_orpha = []
    entities_without_orpha = []
    raw_entity_to_original = {}
    
    def process_annotations_for_mapping(annotations, source_name):
        """First pass: collect entities for mapping creation."""
        for annotation in annotations:
            if 'entity' in annotation and 'document_id' in annotation and 'is_rare_disease' in annotation:
                entity = annotation['entity']
                context = annotation.get('context', '')
                
                if not is_valid_annotation(entity, context):
                    continue
                
                normalized = normalize_entity(entity, abbreviations)
                raw_entity_to_original[normalized] = entity
                
                orpha_code = annotation.get('orpha_code', '') or annotation.get('orpha_id', '')
                normalized_orpha = normalize_orpha_code(orpha_code)
                
                if normalized_orpha:
                    all_entities_with_orpha.append((normalized, normalized_orpha))
                else:
                    entities_without_orpha.append(normalized)
    
    # Process both annotation sets to get all entities
    if human_corrections and 'corrected_annotations' in human_corrections:
        process_annotations_for_mapping(human_corrections['corrected_annotations'], "human")
    
    if supervisor_corrections and 'results' in supervisor_corrections:
        all_supervisor_annotations = []
        for category in ['true_positives', 'false_negatives']:
            if category in supervisor_corrections['results']:
                all_supervisor_annotations.extend(supervisor_corrections['results'][category])
        process_annotations_for_mapping(all_supervisor_annotations, "supervisor")
    
    # Create entity mapping
    unique_entities_with_orpha = list(set(all_entities_with_orpha))
    unique_entities_without_orpha = list(set(entities_without_orpha))
    
    print(f"Found {len(unique_entities_with_orpha)} unique entities with ORPHA codes")
    print(f"Found {len(unique_entities_without_orpha)} unique entities without ORPHA codes")
    
    entity_mapping = create_entity_canonical_mapping_with_cross_orpha(
        unique_entities_with_orpha, 
        unique_entities_without_orpha,
        similarity_threshold
    )
    
    # Step 2: Process annotations with OR logic applied immediately
    human_doc_entities = defaultdict(dict)
    supervisor_doc_entities = defaultdict(dict)
    context_mapping = defaultdict(dict)
    
    def process_annotations_with_or_logic(annotations, target_doc_entities, source_name):
        """Process annotations and apply OR logic immediately for same canonical entities."""
        for annotation in annotations:
            if 'entity' in annotation and 'document_id' in annotation and 'is_rare_disease' in annotation:
                entity = annotation['entity']
                context = annotation.get('context', '')
                doc_id = annotation['document_id']
                is_rare = annotation['is_rare_disease']
                orpha_code = annotation.get('orpha_code', '') or annotation.get('orpha_id', '')
                
                if not is_valid_annotation(entity, context):
                    continue
                
                # Get canonical form
                normalized = normalize_entity(entity, abbreviations)
                canonical_entity = entity_mapping.get(normalized, normalized)
                
                # Apply OR logic immediately - if canonical entity already exists, OR with existing value
                if canonical_entity in target_doc_entities[doc_id]:
                    # OR logic: True if current OR existing is True
                    target_doc_entities[doc_id][canonical_entity] = (
                        target_doc_entities[doc_id][canonical_entity] or is_rare
                    )
                    print(f"  OR operation: '{canonical_entity}' in doc {doc_id} - {source_name}: "
                          f"existing={target_doc_entities[doc_id][canonical_entity]} OR new={is_rare} = "
                          f"{target_doc_entities[doc_id][canonical_entity] or is_rare}")
                else:
                    target_doc_entities[doc_id][canonical_entity] = is_rare
                
                # Store context information (keep the most recent or merge as needed)
                context_key = (doc_id, canonical_entity)
                if context_key not in context_mapping:
                    context_mapping[context_key] = {}
                context_mapping[context_key][source_name] = {
                    'context': context,
                    'original_entity': entity,
                    'orpha_code': orpha_code,
                    'is_rare': is_rare,
                    'canonical_entity': canonical_entity
                }
    
    # Process human annotations
    if human_corrections and 'corrected_annotations' in human_corrections:
        process_annotations_with_or_logic(
            human_corrections['corrected_annotations'], 
            human_doc_entities, 
            'human'
        )
    
    # Process supervisor annotations
    if supervisor_corrections and 'results' in supervisor_corrections:
        all_supervisor_annotations = []
        for category in ['true_positives', 'false_negatives']:
            if category in supervisor_corrections['results']:
                all_supervisor_annotations.extend(supervisor_corrections['results'][category])
        
        process_annotations_with_or_logic(
            all_supervisor_annotations, 
            supervisor_doc_entities, 
            'supervisor'
        )
    
    # Ensure all documents are represented in both dictionaries
    all_docs = set(human_doc_entities.keys()) | set(supervisor_doc_entities.keys())
    for doc_id in all_docs:
        if doc_id not in human_doc_entities:
            human_doc_entities[doc_id] = {}
        if doc_id not in supervisor_doc_entities:
            supervisor_doc_entities[doc_id] = {}
    
    print(f"\nProcessed {len(all_docs)} documents with OR logic applied during processing")
    
    return dict(human_doc_entities), dict(supervisor_doc_entities), entity_mapping, dict(context_mapping)


def compute_agreement_metrics(human_doc_entities: Dict[str, Dict[str, bool]], 
                             supervisor_doc_entities: Dict[str, Dict[str, bool]],
                             excluded_docs: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Compute agreement metrics between human and supervisor annotations without filtering by rarity.
    
    Args:
        human_doc_entities: Dictionary mapping document_id to a dict of {entity: is_rare_disease}
        supervisor_doc_entities: Dictionary mapping document_id to a dict of {entity: is_rare_disease}
        excluded_docs: Optional list of document IDs to exclude from evaluation
        
    Returns:
        Dictionary with agreement metrics
    """
    # Initialize counters
    all_judgments = []  # List of (human, supervisor) tuples for each entity judgment
    
    # Get all document IDs
    all_doc_ids = set(human_doc_entities.keys()) | set(supervisor_doc_entities.keys())
    
    # Remove excluded documents if specified
    if excluded_docs:
        all_doc_ids = all_doc_ids - set(excluded_docs)
    
    # Collect all unique entities across all documents
    all_unique_entities_set = set()
    all_doc_entity_pairs = []  # List of (doc_id, entity) pairs
    
    for doc_id in all_doc_ids:
        # Combine both sets of entities for this document
        human_entities = human_doc_entities[doc_id]
        supervisor_entities = supervisor_doc_entities[doc_id]
        all_entities = set(human_entities.keys()) | set(supervisor_entities.keys())
        
        for entity in all_entities:
            # Track unique entities
            all_unique_entities_set.add(entity)
            # Track all document-entity pairs
            all_doc_entity_pairs.append((doc_id, entity))
    
    # Lists for Cohen's Kappa and other metrics
    human_judgments = []
    supervisor_judgments = []
    
    # Count metrics for confusion matrix
    tp = 0  # Both say it's rare
    tn = 0  # Both say it's not rare
    fp = 0  # Supervisor says rare, Human says not rare
    fn = 0  # Human says rare, Supervisor says not rare
    
    # Process each document-entity pair
    for doc_id, entity in all_doc_entity_pairs:
        # Get human and supervisor judgments (default to False if entity is not present)
        human_judgment = human_doc_entities[doc_id].get(entity, False)
        supervisor_judgment = supervisor_doc_entities[doc_id].get(entity, False)
        
        # Add to lists for metrics calculation
        human_judgments.append(1 if human_judgment else 0)
        supervisor_judgments.append(1 if supervisor_judgment else 0)
        
        # Update confusion matrix
        if human_judgment and supervisor_judgment:
            tp += 1
        elif not human_judgment and not supervisor_judgment:
            tn += 1
        elif supervisor_judgment and not human_judgment:
            fp += 1
        elif human_judgment and not supervisor_judgment:
            fn += 1
        
        # Save the pair of judgments
        all_judgments.append((human_judgment, supervisor_judgment))
    
    # Calculate metrics
    total_judgments = len(all_judgments)
    agreements = tp + tn
    disagreements = fp + fn
    
    # Agreement rates
    if total_judgments > 0:
        percent_agreement = agreements / total_judgments
    else:
        percent_agreement = 0
    
    # Precision, recall, F1
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    
    # Calculate Cohen's Kappa
    if human_judgments and supervisor_judgments:
        kappa = cohen_kappa_score(human_judgments, supervisor_judgments)
    else:
        kappa = 0
    
    # Calculate Pearson correlation
    if human_judgments and supervisor_judgments and len(human_judgments) > 1:
        pearson_corr, p_value = stats.pearsonr(human_judgments, supervisor_judgments)
    else:
        pearson_corr = 0
        p_value = 1
    
    # Compile entity statistics (with rarity classifications)
    human_rare_entities = sum(1 for doc_id in all_doc_ids for entity, is_rare in human_doc_entities[doc_id].items() if is_rare)
    supervisor_rare_entities = sum(1 for doc_id in all_doc_ids for entity, is_rare in supervisor_doc_entities[doc_id].items() if is_rare)
    
    human_nonrare_entities = sum(1 for doc_id in all_doc_ids for entity, is_rare in human_doc_entities[doc_id].items() if not is_rare)
    supervisor_nonrare_entities = sum(1 for doc_id in all_doc_ids for entity, is_rare in supervisor_doc_entities[doc_id].items() if not is_rare)
    
    # Count unique entities
    unique_human_entities = set()
    unique_supervisor_entities = set()
    unique_human_rare = set()
    unique_supervisor_rare = set()
    
    for doc_id in all_doc_ids:
        for entity, is_rare in human_doc_entities[doc_id].items():
            unique_human_entities.add(entity)
            if is_rare:
                unique_human_rare.add(entity)
                
        for entity, is_rare in supervisor_doc_entities[doc_id].items():
            unique_supervisor_entities.add(entity)
            if is_rare:
                unique_supervisor_rare.add(entity)
    
    return {
        'total_documents': len(all_doc_ids),
        'total_entity_judgments': total_judgments,
        'total_agreements': agreements,
        'total_disagreements': disagreements,
        
        # Entity counts with rarity classification
        'human_rare_entities': human_rare_entities,
        'human_nonrare_entities': human_nonrare_entities,
        'supervisor_rare_entities': supervisor_rare_entities,
        'supervisor_nonrare_entities': supervisor_nonrare_entities,
        
        # Unique entity counts
        'unique_entities_total': len(all_unique_entities_set),
        'unique_human_entities': len(unique_human_entities),
        'unique_supervisor_entities': len(unique_supervisor_entities),
        'unique_human_rare': len(unique_human_rare),
        'unique_supervisor_rare': len(unique_supervisor_rare),
        
        # Confusion matrix
        'true_positives': tp,  # Both say it's rare
        'true_negatives': tn,  # Both say it's not rare
        'false_positives': fp,  # Supervisor says rare, Human says not rare  
        'false_negatives': fn,  # Human says rare, Supervisor says not rare
        
        # Agreement metrics
        'percent_agreement': percent_agreement,
        'cohen_kappa': kappa,
        'pearson_correlation': pearson_corr,
        'pearson_p_value': p_value,
        
        # Classification metrics
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

def analyze_disagreements(human_doc_entities: Dict[str, Dict[str, bool]], 
                         supervisor_doc_entities: Dict[str, Dict[str, bool]],
                         context_mapping: Dict[Tuple[str, str], Dict] = None) -> Dict[str, Any]:
    """
    Analyze disagreements in entity classification with context information.
    
    Args:
        human_doc_entities: Dictionary mapping document_id to entity judgments
        supervisor_doc_entities: Dictionary mapping document_id to entity judgments  
        context_mapping: Optional dictionary with context information for each entity
    """
    # Get all document IDs
    all_doc_ids = set(human_doc_entities.keys()) | set(supervisor_doc_entities.keys())
    
    # Initialize disagreement lists
    human_rare_supervisor_not = []  # Human says rare, Supervisor says not rare
    supervisor_rare_human_not = []  # Supervisor says rare, Human says not rare
    
    # To track contradictory classifications
    entity_classifications = defaultdict(lambda: {"human_rare": 0, "human_not_rare": 0, 
                                                 "supervisor_rare": 0, "supervisor_not_rare": 0,
                                                 "documents": set()})
    
    # Process each document
    for doc_id in all_doc_ids:
        human_entities = human_doc_entities.get(doc_id, {})
        supervisor_entities = supervisor_doc_entities.get(doc_id, {})
        
        # Get all entities from both sets
        all_entities = set(human_entities.keys()) | set(supervisor_entities.keys())
        
        for entity in all_entities:
            # Get judgments (default to False if not present)
            human_judgment = human_entities.get(entity, False)
            supervisor_judgment = supervisor_entities.get(entity, False)
            
            # Track all classifications for this entity
            entity_classifications[entity]["documents"].add(doc_id)
            if human_judgment:
                entity_classifications[entity]["human_rare"] += 1
            else:
                entity_classifications[entity]["human_not_rare"] += 1
            
            if supervisor_judgment:
                entity_classifications[entity]["supervisor_rare"] += 1
            else:
                entity_classifications[entity]["supervisor_not_rare"] += 1
            
            # Check for disagreements
            if human_judgment and not supervisor_judgment:
                # Human says rare, Supervisor says not rare
                disagreement_entry = {
                    'entity': entity,
                    'document_id': doc_id,
                    'human_judgment': True,
                    'supervisor_judgment': False
                }
                
                # Add context information if available
                if context_mapping:
                    context_key = (doc_id, entity)
                    context_info = context_mapping.get(context_key, {})
                    
                    human_info = context_info.get('human', {})
                    supervisor_info = context_info.get('supervisor', {})
                    
                    disagreement_entry.update({
                        'human_context': human_info.get('context', 'No context available'),
                        'human_original_entity': human_info.get('original_entity', entity),
                        'human_orpha_code': human_info.get('orpha_code', ''),
                        'supervisor_context': supervisor_info.get('context', 'No context available'),
                        'supervisor_original_entity': supervisor_info.get('original_entity', 'Entity not found in supervisor'),
                        'supervisor_orpha_code': supervisor_info.get('orpha_code', '')
                    })
                
                human_rare_supervisor_not.append(disagreement_entry)
                
            elif supervisor_judgment and not human_judgment:
                # Supervisor says rare, Human says not rare
                disagreement_entry = {
                    'entity': entity,
                    'document_id': doc_id,
                    'human_judgment': False,
                    'supervisor_judgment': True
                }
                
                # Add context information if available
                if context_mapping:
                    context_key = (doc_id, entity)
                    context_info = context_mapping.get(context_key, {})
                    
                    human_info = context_info.get('human', {})
                    supervisor_info = context_info.get('supervisor', {})
                    
                    disagreement_entry.update({
                        'human_context': human_info.get('context', 'No context available'),
                        'human_original_entity': human_info.get('original_entity', 'Entity not found in human'),
                        'human_orpha_code': human_info.get('orpha_code', ''),
                        'supervisor_context': supervisor_info.get('context', 'No context available'),
                        'supervisor_original_entity': supervisor_info.get('original_entity', entity),
                        'supervisor_orpha_code': supervisor_info.get('orpha_code', '')
                    })
                
                supervisor_rare_human_not.append(disagreement_entry)
    
    # Count frequencies
    human_rare_freq = {}
    for item in human_rare_supervisor_not:
        entity = item['entity']
        if entity not in human_rare_freq:
            human_rare_freq[entity] = 0
        human_rare_freq[entity] += 1
    
    supervisor_rare_freq = {}
    for item in supervisor_rare_human_not:
        entity = item['entity']
        if entity not in supervisor_rare_freq:
            supervisor_rare_freq[entity] = 0
        supervisor_rare_freq[entity] += 1
    
    # Sort by frequency
    human_rare_sorted = sorted(human_rare_freq.items(), key=lambda x: x[1], reverse=True)
    supervisor_rare_sorted = sorted(supervisor_rare_freq.items(), key=lambda x: x[1], reverse=True)
    
    # Identify problematic entities with mixed classifications
    contradictory_entities = {}
    for entity, stats in entity_classifications.items():
        if stats["human_rare"] > 0 and stats["human_not_rare"] > 0:
            contradictory_entities[entity] = {
                "human_contradictory": True,
                "human_rare_count": stats["human_rare"],
                "human_not_rare_count": stats["human_not_rare"],
                "documents": list(stats["documents"])
            }
        if stats["supervisor_rare"] > 0 and stats["supervisor_not_rare"] > 0:
            if entity not in contradictory_entities:
                contradictory_entities[entity] = {
                    "human_contradictory": False,
                    "documents": list(stats["documents"])
                }
            contradictory_entities[entity]["supervisor_contradictory"] = True
            contradictory_entities[entity]["supervisor_rare_count"] = stats["supervisor_rare"]
            contradictory_entities[entity]["supervisor_not_rare_count"] = stats["supervisor_not_rare"]
    
    # For specific entities of interest
    entities_of_interest = ["heparin induced thrombocytopenia", "portal vein thrombosis", 
                          "tracheobronchomalacia", "rheumatic fever", "sarcoid"]
    entities_detail = {}
    
    # Find the closest match for each entity of interest
    for target_entity in entities_of_interest:
        best_match = None
        best_score = 0
        
        # Find the best matching entity in our classification data
        for entity in entity_classifications:
            # Try exact match first
            if entity.lower() == target_entity.lower():
                best_match = entity
                break
                
            # Otherwise use fuzzy matching
            score = fuzz.token_sort_ratio(entity.lower(), target_entity.lower())
            if score > best_score and score >= 85:  # At least 85% similarity
                best_score = score
                best_match = entity
        
        # If we found a match, get its details
        if best_match:
            stats = entity_classifications[best_match]
            entities_detail[target_entity] = {
                "matched_entity": best_match,
                "match_score": best_score if best_match.lower() != target_entity.lower() else 100,
                "stats": stats,
                "human_rare_documents": [doc_id for doc_id in stats["documents"] 
                                        if human_doc_entities.get(doc_id, {}).get(best_match, False)],
                "supervisor_rare_documents": [doc_id for doc_id in stats["documents"] 
                                            if supervisor_doc_entities.get(doc_id, {}).get(best_match, False)]
            }
    
    return {
        'human_rare_supervisor_not': human_rare_supervisor_not,
        'supervisor_rare_human_not': supervisor_rare_human_not,
        'human_rare_freq': human_rare_sorted,
        'supervisor_rare_freq': supervisor_rare_sorted,
        'total_human_rare_disagreements': len(human_rare_supervisor_not),
        'total_supervisor_rare_disagreements': len(supervisor_rare_human_not),
        'unique_human_rare_disagreements': len(human_rare_freq),
        'unique_supervisor_rare_disagreements': len(supervisor_rare_freq),
        'contradictory_entities': contradictory_entities,
        'entities_detail': entities_detail
    }


In [4]:


def validate_or_logic_application(human_doc_entities, supervisor_doc_entities, entity_mapping):
    """
    Validate that OR logic has been properly applied by checking for any potential missed groupings.
    """
    print("\n=== VALIDATING OR LOGIC APPLICATION ===")
    
    # Check if there are any entities in the same document that should be grouped but aren't
    issues_found = 0
    
    for doc_id in set(human_doc_entities.keys()) | set(supervisor_doc_entities.keys()):
        human_entities = human_doc_entities.get(doc_id, {})
        supervisor_entities = supervisor_doc_entities.get(doc_id, {})
        
        # Check if any entities in this document map to the same canonical form
        # This shouldn't happen if OR logic was applied correctly
        human_canonical_counts = defaultdict(int)
        supervisor_canonical_counts = defaultdict(int)
        
        for entity in human_entities:
            canonical = entity_mapping.get(entity, entity)
            human_canonical_counts[canonical] += 1
        
        for entity in supervisor_entities:
            canonical = entity_mapping.get(entity, entity)
            supervisor_canonical_counts[canonical] += 1
        
        # Check for multiple entities mapping to same canonical (should not happen)
        for canonical, count in human_canonical_counts.items():
            if count > 1:
                print(f"  WARNING: Doc {doc_id} has {count} human entities mapping to '{canonical}'")
                issues_found += 1
        
        for canonical, count in supervisor_canonical_counts.items():
            if count > 1:
                print(f"  WARNING: Doc {doc_id} has {count} supervisor entities mapping to '{canonical}'")
                issues_found += 1
    
    if issues_found == 0:
        print("✓ OR logic validation passed - no duplicate canonical entities found")
    else:
        print(f"✗ OR logic validation failed - found {issues_found} potential issues")
    
    return issues_found == 0


def analyze_disagreements_with_clustering_info(
    human_doc_entities: Dict[str, Dict[str, bool]], 
    supervisor_doc_entities: Dict[str, Dict[str, bool]],
    entity_mapping: Dict[str, str],
    context_mapping: Dict[Tuple[str, str], Dict] = None
) -> Dict[str, Any]:
    """
    Enhanced disagreement analysis that includes information about entity clustering and OR operations.
    """
    print("\n=== ANALYZING DISAGREEMENTS WITH CLUSTERING INFO ===")
    
    # Standard disagreement analysis
    disagreements = analyze_disagreements(human_doc_entities, supervisor_doc_entities, context_mapping)
    
    # Additional analysis: identify disagreements involving clustered entities
    clustered_disagreements = {
        'human_rare_supervisor_not_clustered': [],
        'supervisor_rare_human_not_clustered': []
    }
    
    # Find reverse mapping: canonical -> list of original entities
    canonical_to_entities = defaultdict(list)
    for entity, canonical in entity_mapping.items():
        if entity != canonical:  # Only include actual mappings
            canonical_to_entities[canonical].append(entity)
    
    # Analyze disagreements involving clustered entities
    for disagreement in disagreements['human_rare_supervisor_not']:
        entity = disagreement['entity']
        if entity in canonical_to_entities and len(canonical_to_entities[entity]) > 0:
            disagreement['is_clustered_entity'] = True
            disagreement['clustered_variants'] = canonical_to_entities[entity]
            clustered_disagreements['human_rare_supervisor_not_clustered'].append(disagreement)
        else:
            disagreement['is_clustered_entity'] = False
    
    for disagreement in disagreements['supervisor_rare_human_not']:
        entity = disagreement['entity']
        if entity in canonical_to_entities and len(canonical_to_entities[entity]) > 0:
            disagreement['is_clustered_entity'] = True
            disagreement['clustered_variants'] = canonical_to_entities[entity]
            clustered_disagreements['supervisor_rare_human_not_clustered'].append(disagreement)
        else:
            disagreement['is_clustered_entity'] = False
    
    # Add clustering info to the disagreements
    disagreements.update(clustered_disagreements)
    
    # Summary statistics
    total_human_rare_disagree = len(disagreements['human_rare_supervisor_not'])
    clustered_human_rare_disagree = len(clustered_disagreements['human_rare_supervisor_not_clustered'])
    
    total_supervisor_rare_disagree = len(disagreements['supervisor_rare_human_not'])
    clustered_supervisor_rare_disagree = len(clustered_disagreements['supervisor_rare_human_not_clustered'])
    
    print(f"Disagreements involving clustered entities:")
    print(f"  Human rare, supervisor not: {clustered_human_rare_disagree}/{total_human_rare_disagree}")
    print(f"  Supervisor rare, human not: {clustered_supervisor_rare_disagree}/{total_supervisor_rare_disagree}")
    
    return disagreements


def print_clustering_summary(entity_mapping: Dict[str, str]):
    """Print a summary of the entity clustering results."""
    print("\n=== ENTITY CLUSTERING SUMMARY ===")
    
    # Count mappings
    identity_mappings = sum(1 for entity, canonical in entity_mapping.items() if entity == canonical)
    actual_mappings = len(entity_mapping) - identity_mappings
    
    # Group by canonical entity
    canonical_groups = defaultdict(list)
    for entity, canonical in entity_mapping.items():
        canonical_groups[canonical].append(entity)
    
    # Count cluster sizes
    cluster_sizes = defaultdict(int)
    for canonical, entities in canonical_groups.items():
        cluster_size = len(entities)
        cluster_sizes[cluster_size] += 1
    
    print(f"Total entities: {len(entity_mapping)}")
    print(f"Identity mappings (no clustering): {identity_mappings}")
    print(f"Actual mappings (clustered): {actual_mappings}")
    print(f"Unique canonical entities: {len(canonical_groups)}")
    
    print("\nCluster size distribution:")
    for size in sorted(cluster_sizes.keys()):
        count = cluster_sizes[size]
        print(f"  Size {size}: {count} clusters")
    
    # Show largest clusters
    largest_clusters = sorted(canonical_groups.items(), key=lambda x: len(x[1]), reverse=True)[:10]
    print(f"\nLargest clusters:")
    for canonical, entities in largest_clusters:
        if len(entities) > 1:  # Only show actual clusters
            print(f"  '{canonical}': {len(entities)} entities")
            for entity in entities[:5]:  # Show first 5
                print(f"    - '{entity}'")
            if len(entities) > 5:
                print(f"    - ... and {len(entities)-5} more")


# Updated main analysis function
def run_complete_disagreement_analysis(human_corrections, supervisor_corrections, similarity_threshold=85):
    """
    Run the complete disagreement analysis with proper OR logic and validation.
    """
    print("=== RUNNING COMPLETE DISAGREEMENT ANALYSIS ===")
    
    # Step 1: Extract entities with proper OR logic
    human_doc_entities, supervisor_doc_entities, entity_mapping, context_mapping = (
        extract_document_entity_sets_with_orpha_priority_fixed(
            human_corrections, supervisor_corrections, similarity_threshold
        )
    )
    
    # Step 2: Print clustering summary
    print_clustering_summary(entity_mapping)
    
    # Step 3: Validate OR logic application
    validation_passed = validate_or_logic_application(
        human_doc_entities, supervisor_doc_entities, entity_mapping
    )
    
    # Step 4: Compute agreement metrics
    metrics = compute_agreement_metrics(human_doc_entities, supervisor_doc_entities)
    
    # Step 5: Analyze disagreements with clustering info
    disagreements = analyze_disagreements_with_clustering_info(
        human_doc_entities, supervisor_doc_entities, entity_mapping, context_mapping
    )
    
    # Step 6: Print comprehensive report
    print_report(metrics, disagreements)
    
    return {
        'metrics': metrics,
        'disagreements': disagreements,
        'entity_mapping': entity_mapping,
        'validation_passed': validation_passed,
        'human_doc_entities': human_doc_entities,
        'supervisor_doc_entities': supervisor_doc_entities
    }

In [5]:
def is_bracketed_context(context: str) -> bool:
    """
    Check if a context string is bracketed (starts with '[' and ends with ']').
    
    Args:
        context: The context string to check
        
    Returns:
        True if the context is bracketed, False otherwise
    """
    if not context or not isinstance(context, str):
        return False
    
    stripped_context = context.strip()
    return stripped_context.startswith('[') and stripped_context.endswith(']')


def filter_bracketed_disagreements(disagreements_list: List[Dict]) -> Tuple[List[Dict], int]:
    """
    Filter out disagreements where either human or supervisor context is bracketed.
    
    Args:
        disagreements_list: List of disagreement dictionaries
        
    Returns:
        Tuple of (filtered_disagreements, excluded_count)
    """
    filtered_disagreements = []
    excluded_count = 0
    
    for disagreement in disagreements_list:
        human_context = disagreement.get('human_context', '')
        supervisor_context = disagreement.get('supervisor_context', '')
        
        # Check if either context is bracketed
        human_bracketed = is_bracketed_context(human_context)
        supervisor_bracketed = is_bracketed_context(supervisor_context)
        
        if human_bracketed or supervisor_bracketed:
            excluded_count += 1
            # Optional: Add exclusion reason to the disagreement for debugging
            disagreement['excluded_reason'] = f"Bracketed context - Human: {human_bracketed}, Supervisor: {supervisor_bracketed}"
        else:
            filtered_disagreements.append(disagreement)
    
    return filtered_disagreements, excluded_count


def print_detailed_disagreements(disagreements: Dict[str, Any], max_examples: int = 5) -> None:
    """
    Print concise disagreement analysis with contexts, excluding bracketed contexts.
    
    Args:
        disagreements: Dictionary containing disagreement analysis results
        max_examples: Maximum number of examples to show for each type
    """
    print("\n=== DETAILED DISAGREEMENT ANALYSIS (EXCLUDING BRACKETED CONTEXTS) ===")
    
    # Human says rare, supervisor says not rare
    if disagreements['human_rare_supervisor_not']:
        # Filter out bracketed contexts
        filtered_human_rare, human_excluded = filter_bracketed_disagreements(
            disagreements['human_rare_supervisor_not']
        )
        
        print(f"\n--- HUMAN RARE, SUPERVISOR NOT RARE ---")
        print(f"Total disagreements: {len(disagreements['human_rare_supervisor_not'])}")
        print(f"Excluded (bracketed contexts): {human_excluded}")
        print(f"Valid disagreements: {len(filtered_human_rare)}")
        
        if filtered_human_rare:
            # Group by entity and show top examples
            human_rare_freq_with_context = {}
            for item in filtered_human_rare:
                entity = item['entity']
                if entity not in human_rare_freq_with_context:
                    human_rare_freq_with_context[entity] = []
                human_rare_freq_with_context[entity].append(item)
            
            sorted_entities = sorted(human_rare_freq_with_context.items(), 
                                   key=lambda x: len(x[1]), reverse=True)
            
            for i, (entity, examples) in enumerate(sorted_entities[:max_examples]):
                print(f"\n{i+1}. '{entity}' ({len(examples)}x)")
                example = examples[0]  # Show just the first example
                print(f"   Doc {example['document_id']} | H: '{example.get('human_original_entity', 'N/A')}' | S: '{example.get('supervisor_original_entity', 'N/A')}'")
                print(f"   H-Context: \"{example.get('human_context', 'N/A')[:100]}{'...' if len(example.get('human_context', '')) > 100 else ''}\"")
                print(f"   S-Context: \"{example.get('supervisor_context', 'N/A')[:100]}{'...' if len(example.get('supervisor_context', '')) > 100 else ''}\"")
        else:
            print("   No valid disagreements after filtering bracketed contexts.")
    
    # Supervisor says rare, human says not rare
    if disagreements['supervisor_rare_human_not']:
        # Filter out bracketed contexts
        filtered_supervisor_rare, supervisor_excluded = filter_bracketed_disagreements(
            disagreements['supervisor_rare_human_not']
        )
        
        print(f"\n--- SUPERVISOR RARE, HUMAN NOT RARE ---")
        print(f"Total disagreements: {len(disagreements['supervisor_rare_human_not'])}")
        print(f"Excluded (bracketed contexts): {supervisor_excluded}")
        print(f"Valid disagreements: {len(filtered_supervisor_rare)}")
        
        if filtered_supervisor_rare:
            supervisor_rare_freq_with_context = {}
            for item in filtered_supervisor_rare:
                entity = item['entity']
                if entity not in supervisor_rare_freq_with_context:
                    supervisor_rare_freq_with_context[entity] = []
                supervisor_rare_freq_with_context[entity].append(item)
            
            sorted_entities = sorted(supervisor_rare_freq_with_context.items(), 
                                   key=lambda x: len(x[1]), reverse=True)
            
            for i, (entity, examples) in enumerate(sorted_entities[:max_examples]):
                print(f"\n{i+1}. '{entity}' ({len(examples)}x)")
                example = examples[0]  # Show just the first example
                print(f"   Doc {example['document_id']} | H: '{example.get('human_original_entity', 'N/A')}' | S: '{example.get('supervisor_original_entity', 'N/A')}'")
                print(f"   H-Context: \"{example.get('human_context', 'N/A')[:100]}{'...' if len(example.get('human_context', '')) > 100 else ''}\"")
                print(f"   S-Context: \"{example.get('supervisor_context', 'N/A')[:100]}{'...' if len(example.get('supervisor_context', '')) > 100 else ''}\"")
        else:
            print("   No valid disagreements after filtering bracketed contexts.")


def compute_filtered_disagreement_stats(disagreements: Dict[str, Any]) -> Dict[str, Any]:
    """
    Compute disagreement statistics after filtering out bracketed contexts.
    
    Args:
        disagreements: Dictionary containing disagreement analysis results
        
    Returns:
        Dictionary with filtered statistics
    """
    # Filter human rare disagreements
    filtered_human_rare, human_excluded = filter_bracketed_disagreements(
        disagreements.get('human_rare_supervisor_not', [])
    )
    
    # Filter supervisor rare disagreements
    filtered_supervisor_rare, supervisor_excluded = filter_bracketed_disagreements(
        disagreements.get('supervisor_rare_human_not', [])
    )
    
    # Compute frequencies for filtered disagreements
    human_rare_freq = {}
    for item in filtered_human_rare:
        entity = item['entity']
        human_rare_freq[entity] = human_rare_freq.get(entity, 0) + 1
    
    supervisor_rare_freq = {}
    for item in filtered_supervisor_rare:
        entity = item['entity']
        supervisor_rare_freq[entity] = supervisor_rare_freq.get(entity, 0) + 1
    
    # Sort by frequency
    human_rare_sorted = sorted(human_rare_freq.items(), key=lambda x: x[1], reverse=True)
    supervisor_rare_sorted = sorted(supervisor_rare_freq.items(), key=lambda x: x[1], reverse=True)
    
    return {
        'filtered_human_rare_supervisor_not': filtered_human_rare,
        'filtered_supervisor_rare_human_not': filtered_supervisor_rare,
        'human_excluded_count': human_excluded,
        'supervisor_excluded_count': supervisor_excluded,
        'filtered_human_rare_freq': human_rare_sorted,
        'filtered_supervisor_rare_freq': supervisor_rare_sorted,
        'filtered_total_human_rare_disagreements': len(filtered_human_rare),
        'filtered_total_supervisor_rare_disagreements': len(filtered_supervisor_rare),
        'filtered_unique_human_rare_disagreements': len(human_rare_freq),
        'filtered_unique_supervisor_rare_disagreements': len(supervisor_rare_freq),
    }


def print_report(metrics: Dict[str, Any], disagreements: Dict[str, Any]) -> None:
    """
    Print a report with the agreement metrics, excluding bracketed contexts from disagreement analysis.
    
    Args:
        metrics: Dictionary containing agreement metrics
        disagreements: Dictionary containing disagreement analysis results
    """
    print("\n=== RARE DISEASE ANNOTATOR AGREEMENT REPORT WITH ORPHA CODE PRIORITY ===")
    
    print("\n--- ENTITY STATISTICS ---")
    print(f"Documents analyzed: {metrics['total_documents']}")
    print(f"Total entity judgments: {metrics['total_entity_judgments']}")
    
    print(f"\nHuman annotations:")
    print(f"  Total entities: {metrics['human_rare_entities'] + metrics['human_nonrare_entities']}")
    print(f"  Rare disease entities: {metrics['human_rare_entities']}")
    print(f"  Non-rare entities: {metrics['human_nonrare_entities']}")
    
    print(f"\nSupervisor annotations:")
    print(f"  Total entities: {metrics['supervisor_rare_entities'] + metrics['supervisor_nonrare_entities']}")
    print(f"  Rare disease entities: {metrics['supervisor_rare_entities']}")
    print(f"  Non-rare entities: {metrics['supervisor_nonrare_entities']}")
    
    print(f"\nUnique entities after ORPHA-priority mapping:")
    print(f"  Total unique entities: {metrics['unique_entities_total']}")
    print(f"  Unique in human annotations: {metrics['unique_human_entities']}")
    print(f"  Unique in supervisor annotations: {metrics['unique_supervisor_entities']}")
    print(f"  Unique rare in human: {metrics['unique_human_rare']}")
    print(f"  Unique rare in supervisor: {metrics['unique_supervisor_rare']}")
    
    print("\n--- AGREEMENT METRICS ---")
    print(f"Agreements: {metrics['total_agreements']} of {metrics['total_entity_judgments']} judgments")
    print(f"Percentage Agreement (Accuracy): {metrics['percent_agreement']:.4f}")
    print(f"Cohen's Kappa: {metrics['cohen_kappa']:.4f}")
    
    print("\n--- CONFUSION MATRIX ---")
    print(f"True Positives (both rare): {metrics['true_positives']}")
    print(f"True Negatives (both non-rare): {metrics['true_negatives']}")
    print(f"False Positives (supervisor rare, human non-rare): {metrics['false_positives']}")
    print(f"False Negatives (human rare, supervisor non-rare): {metrics['false_negatives']}")
    
    print("\n--- CLASSIFICATION METRICS ---")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")
    
    # Compute filtered disagreement statistics
    filtered_stats = compute_filtered_disagreement_stats(disagreements)
    
    print("\n--- DISAGREEMENT SUMMARY (EXCLUDING BRACKETED CONTEXTS) ---")
    print(f"Original disagreements:")
    print(f"  Human says rare, supervisor says not: {disagreements['total_human_rare_disagreements']} total")
    print(f"  Supervisor says rare, human says not: {disagreements['total_supervisor_rare_disagreements']} total")
    
    print(f"\nFiltered disagreements (excluding bracketed contexts):")
    print(f"  Human says rare, supervisor says not: {filtered_stats['filtered_total_human_rare_disagreements']} valid ({filtered_stats['human_excluded_count']} excluded)")
    print(f"  Supervisor says rare, human says not: {filtered_stats['filtered_total_supervisor_rare_disagreements']} valid ({filtered_stats['supervisor_excluded_count']} excluded)")
    print(f"  Unique entities in human rare disagreements: {filtered_stats['filtered_unique_human_rare_disagreements']}")
    print(f"  Unique entities in supervisor rare disagreements: {filtered_stats['filtered_unique_supervisor_rare_disagreements']}")
    
    if filtered_stats['filtered_human_rare_freq']:
        print("\nTop disagreements (human says rare, supervisor says not) - after filtering:")
        for entity, count in filtered_stats['filtered_human_rare_freq'][:5]:
            print(f"  {entity}: {count} occurrences")
    
    if filtered_stats['filtered_supervisor_rare_freq']:
        print("\nTop disagreements (supervisor says rare, human says not) - after filtering:")
        for entity, count in filtered_stats['filtered_supervisor_rare_freq'][:5]:
            print(f"  {entity}: {count} occurrences")
    
    # Print contradictory entities (unchanged)
    if disagreements.get('contradictory_entities'):
        print("\nEntities with contradictory classifications:")
        for entity, details in sorted(disagreements['contradictory_entities'].items(), 
                                    key=lambda x: x[1].get('human_rare_count', 0) + 
                                                 x[1].get('supervisor_rare_count', 0), 
                                    reverse=True)[:10]:
            print(f"\n  {entity}:")
            if details.get('human_contradictory', False):
                print(f"    Human classified as rare: {details['human_rare_count']} times")
                print(f"    Human classified as not rare: {details['human_not_rare_count']} times")
            if details.get('supervisor_contradictory', False):
                print(f"    Supervisor classified as rare: {details['supervisor_rare_count']} times")
                print(f"    Supervisor classified as not rare: {details['supervisor_not_rare_count']} times")
            print(f"    Appears in {len(details['documents'])} documents")
    
    # Print details for entities of interest (unchanged)
    if disagreements.get('entities_detail'):
        print("\nSpecific entities of interest:")
        for target_entity, details in disagreements['entities_detail'].items():
            matched_entity = details.get("matched_entity", "")
            match_info = ""
            if matched_entity != target_entity and details.get("match_score", 0) < 100:
                match_info = f" (matched to '{matched_entity}' with {details['match_score']}% similarity)"
                
            stats = details['stats']
            print(f"\n  {target_entity}{match_info}:")
            print(f"    Human classified as rare: {stats['human_rare']} times")
            print(f"    Human classified as not rare: {stats['human_not_rare']} times")
            print(f"    Supervisor classified as rare: {stats['supervisor_rare']} times")
            print(f"    Supervisor classified as not rare: {stats['supervisor_not_rare']} times")
            print(f"    Human rare documents: {details['human_rare_documents']}")
            print(f"    Supervisor rare documents: {details['supervisor_rare_documents']}")
    
    # Validation assertions (unchanged)
    print("\n--- VALIDATION ---")
    print("Assertion checks:")
    print("human rare entities:", metrics['human_rare_entities'])
    print("supervisor rare entities:", metrics['supervisor_rare_entities'])
    print("true positives:", metrics['true_positives'])
    print("false negatives:", metrics['false_negatives'])
    print("false positives:", metrics['false_positives'])

    try:
        assert metrics['human_rare_entities'] == metrics['true_positives'] + metrics['false_negatives']
        print("✓ Human rare entities = TP + FN")
    except AssertionError:
        print("✗ Human rare entities ≠ TP + FN")
    
    try:
        assert metrics['supervisor_rare_entities'] == metrics['true_positives'] + metrics['false_positives']
        print("✓ Supervisor rare entities = TP + FP")
    except AssertionError:
        print("✗ Supervisor rare entities ≠ TP + FP")
    
    # Print detailed disagreement analysis with filtering
    print_detailed_disagreements(disagreements)


# Example usage and test cases
def test_bracket_filtering():
    """Test the bracket filtering functionality."""
    test_contexts = [
        "[Entity 'retinopathy of prematurity' occurrence #2 (index 1) not found by string search]",
        "Patient has diabetes mellitus",
        "[This is a bracketed context]",
        "Normal clinical context without brackets",
        "   [  Bracketed with spaces  ]   ",
        "",
        None
    ]
    
    print("Testing bracket filtering:")
    for context in test_contexts:
        is_bracketed = is_bracketed_context(context)
        print(f"'{context}' -> Bracketed: {is_bracketed}")
    
    # Test disagreement filtering
    test_disagreements = [
        {
            'entity': 'diabetes',
            'human_context': 'Patient has diabetes mellitus',
            'supervisor_context': 'Diabetes noted in history'
        },
        {
            'entity': 'retinopathy',
            'human_context': "[Entity 'retinopathy of prematurity' occurrence #2 not found]",
            'supervisor_context': 'Retinopathy of prematurity diagnosed'
        },
        {
            'entity': 'hypertension',
            'human_context': 'Blood pressure elevated',
            'supervisor_context': '[Hypertension context bracketed]'
        }
    ]
    
    filtered, excluded_count = filter_bracketed_disagreements(test_disagreements)
    print(f"\nFiltering test: {len(test_disagreements)} original, {len(filtered)} filtered, {excluded_count} excluded")
    print("Remaining disagreements:", [d['entity'] for d in filtered])


if __name__ == "__main__":
    test_bracket_filtering()

Testing bracket filtering:
'[Entity 'retinopathy of prematurity' occurrence #2 (index 1) not found by string search]' -> Bracketed: True
'Patient has diabetes mellitus' -> Bracketed: False
'[This is a bracketed context]' -> Bracketed: True
'Normal clinical context without brackets' -> Bracketed: False
'   [  Bracketed with spaces  ]   ' -> Bracketed: True
'' -> Bracketed: False
'None' -> Bracketed: False

Filtering test: 3 original, 1 filtered, 2 excluded
Remaining disagreements: ['diabetes']


In [6]:

# Load human and supervisor corrections
human_corrections = read_json_file("data/dataset/rare_disease_corrections_john.json")
rdma_corrections = read_json_file("data/results/supervisor/multistage_no_min.json")

# Extract document-level entity sets with ORPHA code priority
human_doc_entities, supervisor_doc_entities, entity_mapping, context_mapping = extract_document_entity_sets_with_orpha_priority_fixed(
    human_corrections, 
    rdma_corrections,
    similarity_threshold=90  # Using 90% as the clustering threshold for entities without ORPHA codes
)

# Compute agreement metrics
metrics = compute_agreement_metrics(
    human_doc_entities, 
    supervisor_doc_entities
)

# Analyze disagreements (now with context information)
disagreements = analyze_disagreements(
    human_doc_entities, 
    supervisor_doc_entities,
    context_mapping
)

# Print report
print_report(metrics, disagreements)

Found 98 unique entities with ORPHA codes
Found 31 unique entities without ORPHA codes
Found 84 unique ORPHA codes
  Cluster: ['postpolio syndrome', 'post polio syndrome']
  Chosen canonical: 'post polio syndrome'
  Cluster: ['lyme disease', "lyme's disease"]
  Chosen canonical: 'lyme's disease'
  Cluster: ['hemochromatosis', 'iron storage disease']
  Chosen canonical: 'iron storage disease'
  Cluster: ["alport's syndrome", 'alport syndrome']
  Chosen canonical: 'alport's syndrome'
  Cluster: ['als', 'amyotrophic lateral sclerosis']
  Chosen canonical: 'amyotrophic lateral sclerosis'
  Cluster: ['fulminant liver failure', 'acute hepatic failure', 'fulminant hepatic failure']
  Chosen canonical: 'fulminant hepatic failure'
  Cluster: ['essential thrombocythemia', 'essential thrombocytosis']
  Chosen canonical: 'essential thrombocythemia'
  Cluster: ['epileptic seizures', 'epileptic', 'seizure disorder', 'epilepsy', 'epileptic seizure']
  Chosen canonical: 'epileptic seizures'
  Cluster:

In [7]:
# print(sorted(supervisor_doc_entities.keys()))
for doc_id, entities in human_doc_entities.items():
    
    if doc_id in supervisor_doc_entities:
        if len(supervisor_doc_entities[doc_id]) != len(human_doc_entities[doc_id]):
            print(f"Document ID: {doc_id}")
            print(f"Entities: {entities}")
            print("Supervisor Entities:")
            print(supervisor_doc_entities[doc_id])
    else:
        print("No supervisor entities found for this document.")
        print(human_doc_entities[doc_id])
    

In [8]:
import re
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Tuple
from circuitsvis.tokens import colored_tokens

class SimpleDisagreementVisualizer:
    """Simple visualizer for rare disease entity extraction disagreements."""
    
    def __init__(self):
        pass
    
    def is_bracketed_context(self, context: str) -> bool:
        """Check if context is bracketed (starts with '[' and ends with ']')."""
        if not context or not isinstance(context, str):
            return False
        stripped = context.strip()
        return stripped.startswith('[') and stripped.endswith(']')
    
    def tokenize_text(self, text: str) -> List[str]:
        """Tokenization that preserves spaces and handles punctuation."""
        if not text or self.is_bracketed_context(text):
            return []
        
        # Split text into tokens including words, punctuation, and spaces
        # This handles cases like "thrombosis/portal" by keeping "/" as separate token
        tokens = []
        current_pos = 0
        
        # Find all word boundaries, punctuation, and spaces
        import re
        for match in re.finditer(r'\S+', text):
            start, end = match.span()
            
            # Add any spaces before this token
            if start > current_pos:
                space_text = text[current_pos:start]
                if space_text:
                    tokens.append(space_text)
            
            # Get the word/punctuation chunk
            chunk = text[start:end]
            
            # Split chunk further if it contains punctuation mixed with words
            # This handles "thrombosis/portal" -> ["thrombosis", "/", "portal"]
            sub_tokens = re.findall(r'\w+|[^\w\s]', chunk)
            
            for i, sub_token in enumerate(sub_tokens):
                tokens.append(sub_token)
                # Don't add space after the last sub-token of this chunk
                if i < len(sub_tokens) - 1:
                    # Add minimal space between sub-tokens within same chunk
                    pass  # Let the natural spacing handle this
            
            current_pos = end
        
        # Add any trailing spaces
        if current_pos < len(text):
            trailing_space = text[current_pos:]
            if trailing_space:
                tokens.append(trailing_space)
        
        return tokens
    
    def find_entity_positions(self, tokens: List[str], entity: str) -> List[Tuple[int, int]]:
        """Find positions of entity in tokenized text, including partial matches within chunks."""
        if not tokens or not entity:
            return []
        
        entity_words = entity.lower().split()
        if not entity_words:
            return []
        
        positions = []
        
        # Method 1: Find complete word sequence matches
        word_tokens = []
        token_indices = []
        
        for i, token in enumerate(tokens):
            if token.strip() and re.match(r'\w', token):  # Only actual words, not spaces or pure punctuation
                word_tokens.append(token.lower())
                token_indices.append(i)
        
        # Find complete entity matches in word sequence
        for i in range(len(word_tokens) - len(entity_words) + 1):
            if word_tokens[i:i + len(entity_words)] == entity_words:
                start_idx = token_indices[i]
                if i + len(entity_words) - 1 < len(token_indices):
                    # Include tokens up to and including the last entity word
                    end_idx = token_indices[i + len(entity_words) - 1] + 1
                else:
                    end_idx = len(tokens)
                positions.append((start_idx, end_idx))
        
        # Method 2: Find partial matches within individual tokens
        # This handles cases like "thrombosis/portal vein thrombosis" where entity spans across punctuation
        entity_text = entity.lower()
        
        for i, token in enumerate(tokens):
            if token.strip() and re.match(r'\w', token):  # Only check word tokens
                token_lower = token.lower()
                
                # Check if this token contains any part of our entity
                if any(word in token_lower for word in entity_words):
                    # Check if we can find the full entity starting from this position
                    remaining_text = ' '.join([t for t in tokens[i:] if t.strip() and re.match(r'\w', t)])
                    
                    if entity_text in remaining_text.lower():
                        # Find how many tokens we need to include to get the full entity
                        current_text = ""
                        end_idx = i
                        
                        for j in range(i, len(tokens)):
                            if tokens[j].strip() and re.match(r'\w', tokens[j]):
                                if current_text:
                                    current_text += " "
                                current_text += tokens[j].lower()
                                
                                if entity_text in current_text:
                                    end_idx = j + 1
                                    # Avoid duplicates by checking if we already found this position
                                    if (i, end_idx) not in positions:
                                        positions.append((i, end_idx))
                                    break
        
        # Remove duplicate positions and sort
        positions = list(set(positions))
        positions.sort()
        
        return positions
    
    def create_highlight_scores(self, tokens: List[str], entity: str) -> List[float]:
        """Create highlight scores for tokens (1.0 for entity, 0.0 for others)."""
        if not tokens:
            return []
        
        scores = [0.0] * len(tokens)
        positions = self.find_entity_positions(tokens, entity)
        
        for start_idx, end_idx in positions:
            for i in range(start_idx, end_idx):
                scores[i] = 1.0
                
        return scores
    
    def visualize_single_disagreement(self, disagreement: Dict, title: str = ""):
        """Visualize a single disagreement with both contexts highlighted."""
        entity = disagreement.get('entity', '')
        human_context = disagreement.get('human_context', '')
        supervisor_context = disagreement.get('supervisor_context', '')
        doc_id = disagreement.get('document_id', 'Unknown')
        
        print(f"\n{'-'*60}")
        if title:
            print(f"📊 {title}")
        print(f"Entity: '{entity}'")
        print(f"Document: {doc_id}")
        
        # Skip if either context is bracketed
        if self.is_bracketed_context(human_context) or self.is_bracketed_context(supervisor_context):
            print("⚠️  SKIPPED: Bracketed context detected")
            return
        
        # Show human context
        if human_context:
            print(f"\n👤 Context:")
            tokens = self.tokenize_text(human_context)
            if tokens:
                scores = self.create_highlight_scores(tokens, entity)
                colored_tokens(tokens, scores, positive_color="#ff6b6b")
    
    def filter_valid_disagreements(self, disagreements_list: List[Dict]) -> List[Dict]:
        """Filter out disagreements with bracketed contexts."""
        return [d for d in disagreements_list 
                if not (self.is_bracketed_context(d.get('human_context', '')) or 
                       self.is_bracketed_context(d.get('supervisor_context', '')))]
    
    def visualize_category(self, disagreements: Dict, category: str, max_examples: int = 5):
        """Visualize disagreements for a specific category."""
        if category not in disagreements:
            print(f"Category '{category}' not found")
            return
        
        category_data = disagreements[category]
        valid_disagreements = self.filter_valid_disagreements(category_data)
        
        # Category titles
        titles = {
            'human_rare_supervisor_not': '👤 SAYS RARE, 🤖 SAYS NOT RARE',
            'supervisor_rare_human_not': '🤖 SAYS RARE, 👤 SAYS NOT RARE'
        }
        
        title = titles.get(category, category.upper())
        
        print(f"\n{'='*20} {title} {'='*20}")
        print(f"Total: {len(category_data)} | Valid: {len(valid_disagreements)} | Showing: {min(max_examples, len(valid_disagreements))}")
        
        if not valid_disagreements:
            print("No valid disagreements to show.")
            return
        
        # Show examples
        for i, disagreement in enumerate(valid_disagreements[:max_examples]):
            self.visualize_single_disagreement(
                disagreement, 
                f"Example {i+1}/{min(max_examples, len(valid_disagreements))}"
            )
    
    def visualize_all(self, disagreements: Dict, max_examples_per_category: int = 3):
        """Visualize both disagreement categories."""
        print("\n" + "="*80)
        print("🔍 DISAGREEMENT VISUALIZATION")
        print("="*80)
        
        # Human rare, supervisor not rare
        self.visualize_category(disagreements, 'human_rare_supervisor_not', max_examples_per_category)
        
        # Supervisor rare, human not rare  
        self.visualize_category(disagreements, 'supervisor_rare_human_not', max_examples_per_category)
    
    def save_to_html(self, disagreements: Dict, max_examples: int = 10, output_dir: str = "figs", 
                     show_human: bool = True):
        """Save disagreement visualizations to HTML files using CircuitsVis."""
        # Create output directory
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        
        # Save each category to separate HTML files
        categories = {
            'human_rare_supervisor_not': {
                'title': '👤 Says Rare, 🤖 Says Not Rare',
                'filename_human': 'human_rare_supervisor_not_human_context.html',
                'filename_supervisor': 'human_rare_supervisor_not_supervisor_context.html'
            },
            'supervisor_rare_human_not': {
                'title': '🤖 Says Rare, 👤 Says Not Rare', 
                'filename_human': 'supervisor_rare_human_not_human_context.html',
                'filename_supervisor': 'supervisor_rare_human_not_supervisor_context.html'
            }
        }
        
        for category, info in categories.items():
            if category in disagreements:
                # Save human context version
                html_file_human = os.path.join(output_dir, info['filename_human'])
                self._create_html_file(
                    disagreements[category], 
                    info['title'] + " (👤 Context)",
                    html_file_human,
                    max_examples,
                    show_context='human'
                )
                print(f"Saved {category} (human context) to: {html_file_human}")
                
                # Save supervisor context version
                html_file_supervisor = os.path.join(output_dir, info['filename_supervisor'])
                self._create_html_file(
                    disagreements[category], 
                    info['title'] + " (🤖 Context)",
                    html_file_supervisor,
                    max_examples,
                    show_context='supervisor'
                )
                print(f"Saved {category} (supervisor context) to: {html_file_supervisor}")
    
    def _create_html_file(self, category_data: List[Dict], title: str, 
                         output_file: str, max_examples: int, show_context: str = 'human'):
        """Create simple HTML file showing just highlighted tokens using table layout."""
        # Filter valid disagreements
        valid_disagreements = self.filter_valid_disagreements(category_data)
        
        if not valid_disagreements:
            html_content = f"""<!DOCTYPE html>
<html><head><title>{title}</title></head>
<body><h1>{title}</h1><p>No valid disagreements found.</p></body></html>"""
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(html_content)
            return
        
        # Group by entity and take top examples
        entity_groups = {}
        for item in valid_disagreements:
            entity = item['entity']
            if entity not in entity_groups:
                entity_groups[entity] = []
            entity_groups[entity].append(item)
        
        sorted_entities = sorted(entity_groups.items(), key=lambda x: len(x[1]), reverse=True)
        
        # Create HTML with table-based layout to prevent CSS cascading
        html_content = f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>{title}</title>
    <style>
        body {{ 
            font-family: Arial, sans-serif; 
            margin: 15px; 
            padding: 0;
        }}
        
        .disagreement-table {{
            width: 100%;
            border-collapse: separate;
            border-spacing: 0;
            margin: 8px 0;  /* Reduced from 20px */
        }}
        
        .disagreement-row {{
            border: 1px solid #ddd;  /* Thinner border */
        }}
        
        .entity-cell {{
            background-color: #f5f5f5;
            padding: 8px 12px;  /* Reduced from 15px */
            border-bottom: 1px solid #ddd;
            font-weight: bold;
            font-size: 15px;  /* Slightly smaller */
        }}
        
        .context-cell {{
            padding: 10px 12px;  /* Reduced from 15px */
            background-color: white;
            vertical-align: top;
        }}
        
        .context-label {{
            font-weight: bold;
            margin-bottom: 5px;  /* Reduced from 8px */
            color: #555;
            font-size: 14px;
        }}
        
        .tokens-wrapper {{
            background-color: #f9f9f9;
            padding: 6px 8px;  /* Reduced from 10px */
            border-radius: 3px;
            border: 1px solid #e0e0e0;
            margin: 2px 0;  /* Reduced from 5px */
            line-height: 1.3;  /* Tighter line spacing */
        }}
    </style>
</head>
<body>
    <h1 style="margin-bottom: 15px;">{title}</h1>
"""
        
        # Generate visualizations and create table rows
        for i, (entity, examples) in enumerate(sorted_entities[:max_examples]):
            example = examples[0]
            human_context = example.get('human_context', '')
            supervisor_context = example.get('supervisor_context', '')
            human_original = example.get('human_original_entity', entity)
            supervisor_original = example.get('supervisor_original_entity', entity)
            doc_id = example.get('document_id', 'Unknown')
            
            # Generate appropriate context visualization based on show_context parameter
            context_viz_html = ""
            context_icon = ""
            original_entity = entity
            
            if show_context == 'human':
                context_icon = "👤"
                original_entity = human_original
                if human_context and not self.is_bracketed_context(human_context):
                    tokens = self.tokenize_text(human_context)
                    if tokens:
                        scores = self.create_highlight_scores(tokens, human_original)
                        context_viz = colored_tokens(
                            tokens, scores, 
                            min_value=0.0, max_value=1.0,
                            positive_color="#ff6b6b"
                        )
                        context_viz_html = context_viz._repr_html_()
            else:  # supervisor
                context_icon = "🤖"
                original_entity = supervisor_original
                if supervisor_context and not self.is_bracketed_context(supervisor_context):
                    tokens = self.tokenize_text(supervisor_context)
                    if tokens:
                        scores = self.create_highlight_scores(tokens, supervisor_original)
                        context_viz = colored_tokens(
                            tokens, scores, 
                            min_value=0.0, max_value=1.0,
                            positive_color="#ff6b6b"
                        )
                        context_viz_html = context_viz._repr_html_()
            
            # Create table row for this disagreement
            html_content += f"""
<table class="disagreement-table">
    <tr class="disagreement-row">
        <td class="entity-cell">
            #{i+1}: "{entity}" ({len(examples)} occurrences) - Document: {doc_id}
        </td>
    </tr>
"""
            
            if context_viz_html:
                html_content += f"""
    <tr class="disagreement-row">
        <td class="context-cell">
            <div class="context-label">{context_icon} Context:</div>
            <div class="tokens-wrapper">{context_viz_html}</div>
        </td>
    </tr>
"""
            else:
                # Show message if no valid context
                html_content += f"""
    <tr class="disagreement-row">
        <td class="context-cell">
            <div class="context-label">{context_icon} Context:</div>
            <div class="tokens-wrapper" style="color: #999; font-style: italic;">No valid context available</div>
        </td>
    </tr>
"""
            
            html_content += "</table>\n\n"
        
        html_content += "</body></html>"
        
        # Write to file
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(html_content)

# Simple usage functions
def visualize_disagreements(disagreements_dict: Dict, max_examples: int = 5, save_html: bool = False):
    """
    Main function to visualize disagreements.
    
    Args:
        disagreements_dict: Dictionary with 'human_rare_supervisor_not' and 'supervisor_rare_human_not' keys
        max_examples: Maximum examples to show per category
        save_html: Whether to save HTML files to figs/ directory
    """
    visualizer = SimpleDisagreementVisualizer()
    
    # Console visualization
    visualizer.visualize_all(disagreements_dict, max_examples)
    
    # Save HTML files if requested
    if save_html:
        print(f"\n💾 Saving HTML files...")
        visualizer.save_to_html(disagreements_dict, max_examples * 2, "figs")
        print(f"HTML files saved to figs/ directory")

def save_disagreements_html(disagreements_dict: Dict, max_examples: int = 10, output_dir: str = "figs"):
    """
    Save disagreement visualizations to HTML files.
    
    Args:
        disagreements_dict: Dictionary with disagreement data
        max_examples: Maximum examples per category
        output_dir: Output directory (default: "figs")
    """
    visualizer = SimpleDisagreementVisualizer()
    visualizer.save_to_html(disagreements_dict, max_examples, output_dir)
    
    print(f"✅ HTML files created:")
    print(f"   👤 Human context versions:")
    print(f"     - {output_dir}/human_rare_supervisor_not_human_context.html")
    print(f"     - {output_dir}/supervisor_rare_human_not_human_context.html")
    print(f"   🤖 Supervisor context versions:")
    print(f"     - {output_dir}/human_rare_supervisor_not_supervisor_context.html")
    print(f"     - {output_dir}/supervisor_rare_human_not_supervisor_context.html")

# Test with sample data
def test_visualization():
    """Test the visualizer with sample data."""
    sample_data = {
        'human_rare_supervisor_not': [
            {
                'entity': 'diabetes',
                'human_context': 'Patient diagnosed with diabetes mellitus type 1 last year',
                'supervisor_context': 'History of diabetes noted in chart',
                'document_id': 'DOC001'
            }
        ],
        'supervisor_rare_human_not': [
            {
                'entity': 'hypertension', 
                'human_context': 'Blood pressure consistently elevated above normal range',
                'supervisor_context': 'Patient has essential hypertension requiring medication',
                'document_id': 'DOC002'
            }
        ]
    }
    
    visualize_disagreements(sample_data, max_examples=2)

# if __name__ == "__main__":
#     test_visualization()

In [9]:
visualize_disagreements(disagreements, max_examples=2, save_html=True)



🔍 DISAGREEMENT VISUALIZATION

Total: 43 | Valid: 41 | Showing: 2

------------------------------------------------------------
📊 Example 1/2
Entity: 'methemoglobinemia'
Document: 14936

👤 Context:

------------------------------------------------------------
📊 Example 2/2
Entity: 'hypogammaglobulinemia'
Document: 33893

👤 Context:

Total: 6 | Valid: 2 | Showing: 2

------------------------------------------------------------
📊 Example 1/2
Entity: 'heparin induced thrombocytopenia'
Document: 10406

👤 Context:

------------------------------------------------------------
📊 Example 2/2
Entity: 'heparin induced thrombocytopenia'
Document: 52528

👤 Context:

💾 Saving HTML files...


Saved human_rare_supervisor_not (human context) to: figs/human_rare_supervisor_not_human_context.html
Saved human_rare_supervisor_not (supervisor context) to: figs/human_rare_supervisor_not_supervisor_context.html
Saved supervisor_rare_human_not (human context) to: figs/supervisor_rare_human_not_human_context.html
Saved supervisor_rare_human_not (supervisor context) to: figs/supervisor_rare_human_not_supervisor_context.html
HTML files saved to figs/ directory


# All Human vs. RDMA + Human

# 