In [None]:

import os
import sys
# Make the parent of the current working directory importable and switch into
# it (presumably the project root - confirm against the repository layout),
# so the relative "data/..." paths used further down resolve from there.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
# Prepend so modules in the parent directory win over same-named ones later on sys.path.
sys.path.insert(0, parent_dir)
# Set the parent directory as the current directory.
# NOTE: running these lines again climbs one more level, because os.getcwd()
# then returns the directory that the previous run switched into.
os.chdir(parent_dir)

In [None]:
import json
import pandas as pd
import argparse
from pathlib import Path

def extract_flagged_entities(json_file):
    """Extract flagged entities with context from supervisor output JSON file.

    Reads ``summary.flagged_entities`` from the file and, for every entry,
    looks for the first record in ``results[<category>]`` that has the same
    ``document_id`` and ``entity``. When such a record exists its context,
    rare-disease flag and first ORPHA candidate are copied onto the entry.

    Args:
        json_file: Path to the supervisor output JSON file.

    Returns:
        A list with one dict per flagged entity. Every dict carries the keys
        ``entity``, ``document_id``, ``orpha_code``, ``category`` and
        ``explanation`` (empty string when absent in the file). ``context``
        and ``is_rare_disease`` are present only when a matching result record
        was found, and the ``top_candidate_*`` keys only when that record
        lists at least one candidate.
    """
    # JSON is UTF-8 by specification; relying on the platform default encoding
    # would mis-decode non-ASCII entity names on some systems.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    flagged_entities = data.get('summary', {}).get('flagged_entities', [])
    print(f"Found {len(flagged_entities)} flagged entities")

    # Loop-invariant: the per-category result lists used to enrich each entity.
    results_by_category = data.get('results', {})

    detailed_entities = []
    for entity in flagged_entities:
        category = entity.get('category', '')
        doc_id = entity.get('document_id', '')
        entity_info = {
            'entity': entity.get('entity', ''),
            'document_id': doc_id,
            'orpha_code': entity.get('orpha_code', ''),
            'category': category,
            'explanation': entity.get('explanation', '')
        }

        if doc_id and category and category in results_by_category:
            for result in results_by_category[category]:
                if (result.get('document_id') != doc_id or
                        result.get('entity') != entity_info['entity']):
                    continue

                entity_info['context'] = result.get('context', '')
                entity_info['is_rare_disease'] = result.get('is_rare_disease', False)

                # Only the first candidate is reported as the "top" one.
                candidates = result.get('orpha_candidates', [])
                if candidates:
                    top_candidate = candidates[0]
                    entity_info['top_candidate_name'] = top_candidate.get('name', '')
                    entity_info['top_candidate_id'] = top_candidate.get('id', '')
                    entity_info['top_candidate_similarity'] = top_candidate.get('similarity', 0.0)

                # Only the first matching result record is used.
                break

        detailed_entities.append(entity_info)

    return detailed_entities

def main(json_file="data/results/supervisor/multistage_no_min.json",
         category="false_positives"):
    """Load flagged entities for review and return them as a DataFrame.

    Prints the first extracted entity (when there is one) as a quick sanity
    check, then optionally keeps only the entities of one category.

    Args:
        json_file: Path to the supervisor output JSON file. Relative paths
            are resolved against the current working directory.
        category: Category to keep (e.g. "false_positives"). Pass a falsy
            value to keep every category.

    Returns:
        A pandas DataFrame with one row per retained entity (empty when
        nothing was flagged or nothing matched the category).
    """
    entities = extract_flagged_entities(json_file)

    # Guard the preview: indexing an empty result would raise IndexError.
    if entities:
        print(entities[0])

    # Filter by category if requested
    if category:
        entities = [e for e in entities if e.get('category') == category]
        print(f"Filtered to {len(entities)} {category}")

    # Hand the table back to the caller instead of discarding it.
    return pd.DataFrame(entities)
    
    
# Run the extraction only when this code is executed directly (as a script or
# as a notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

# Create a new set of annotations from the existing ones and the human corrections.

In [None]:
import json
import copy
from typing import Dict, List, Any, Set, Tuple
from datetime import datetime
import os
import logging

# Configure logging: INFO level, lines formatted as
# "YYYY-MM-DD HH:MM:SS - LEVEL - message".
# NOTE: logging.basicConfig() does nothing if the root logger already has
# handlers, so an earlier configuration in the same process takes precedence.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Module-level logger used by every function and class below.
logger = logging.getLogger(__name__)


class AnnotationCorrector:
    """
    Apply reviewed corrections to a set of clinical annotations.

    Corrections confirmed as rare diseases either update the annotation whose
    ``mention`` matches the corrected entity (case-insensitively) or are
    appended as new annotations. Corrections in the ``false_negatives``
    category that were verified *not* to be rare diseases cause the matching
    annotations to be removed.

    The annotations passed in are deep-copied, so the caller's dictionary is
    never mutated.
    """

    def __init__(self,
                 existing_annotations: Dict[str, Any],
                 corrections: Dict[str, Any],
                 debug: bool = False):
        """
        Initialize the AnnotationCorrector with existing annotations and corrections.

        Args:
            existing_annotations: Original annotations, keyed by document id;
                each document may hold an ``annotations`` list.
            corrections: Corrections dictionary from verification; the list
                under ``corrected_annotations`` is what gets processed.
            debug: Enable per-correction and summary log output.
        """
        self.existing_annotations = copy.deepcopy(existing_annotations)
        self.corrections = corrections
        self.debug = debug

        # Running counters; 'documents_modified' is a set while processing and
        # is published as a sorted list by process_corrections().
        self.stats = {
            'total_corrections': 0,
            'added_annotations': 0,
            'updated_annotations': 0,
            'removed_annotations': 0,
            'skipped_corrections': 0,
            'documents_modified': set()
        }

        # Track entities to be removed
        self.entities_to_remove = {}  # {document_id: [entity1, entity2, ...]}

    @staticmethod
    def _mention_key(annotation: Dict[str, Any]) -> str:
        """Return the lower-cased mention of *annotation* ('' when missing or None)."""
        return str(annotation.get('mention') or '').lower()

    def _identify_removals(self) -> None:
        """
        Identify annotations that should be removed based on correction data.
        These are entities marked as false negatives but verified not to be rare diseases.
        """
        for correction in self.corrections.get('corrected_annotations', []):
            document_id = str(correction.get('document_id', ''))
            entity = correction.get('entity', '').lower() if correction.get('entity') else None
            category = correction.get('category', '')
            is_rare_disease = correction.get('is_rare_disease', False)

            # Skip if missing critical information
            if not document_id or not entity:
                continue

            # If it's a false negative but NOT a rare disease, it should be removed
            if category == 'false_negatives' and not is_rare_disease:
                self.entities_to_remove.setdefault(document_id, []).append(entity)

                if self.debug:
                    logger.info(f"Marked for removal: Entity '{entity}' in document {document_id}")

    def _process_updates_and_additions(self) -> None:
        """
        Process corrections to update existing annotations or add new ones.

        A correction is skipped (and counted as skipped) when it lacks a
        document id, entity or ORPHA code, when it is not a confirmed rare
        disease, or when its document is not part of the annotations.
        """
        for correction in self.corrections.get('corrected_annotations', []):
            self.stats['total_corrections'] += 1

            # Extract key information
            document_id = str(correction.get('document_id', ''))
            entity = correction.get('entity', '')
            entity_lower = entity.lower() if entity else ''
            orpha_code = correction.get('orpha_code', '')
            is_rare_disease = correction.get('is_rare_disease', False)
            category = correction.get('category', '')

            if self.debug:
                logger.info("\nProcessing Correction:")
                logger.info(f"  Document ID: {document_id}")
                logger.info(f"  Entity: {entity}")
                logger.info(f"  ORPHA Code: {orpha_code}")
                logger.info(f"  Category: {category}")
                logger.info(f"  Is Rare Disease: {is_rare_disease}")

            # Skip if critical information is missing
            if not document_id or not entity or not orpha_code:
                if self.debug:
                    logger.info("  Skipping: Missing critical information")
                self.stats['skipped_corrections'] += 1
                continue

            # Skip if not a confirmed rare disease
            if not is_rare_disease:
                if self.debug:
                    logger.info("  Skipping: Not a confirmed rare disease")
                self.stats['skipped_corrections'] += 1
                continue

            # Check if the document exists in existing annotations
            if document_id not in self.existing_annotations:
                if self.debug:
                    logger.info(f"  Document {document_id} not found in annotations")
                self.stats['skipped_corrections'] += 1
                continue

            # Ensure annotations list exists
            annotations = self.existing_annotations[document_id].setdefault('annotations', [])

            # Only the first annotation with a matching mention is updated.
            match = next(
                (ann for ann in annotations if self._mention_key(ann) == entity_lower),
                None,
            )

            if match is not None:
                match['ordo_with_desc'] = f"{orpha_code} {entity}"
                self.stats['updated_annotations'] += 1

                if self.debug:
                    logger.info(f"  Updated existing annotation for {entity}")
            else:
                annotations.append({
                    'mention': entity,
                    'umls_with_desc': '',
                    'ordo_with_desc': f"{orpha_code} {entity}",
                    'gold_text_to_umls_label': 0,
                    'gold_text_to_ordo_label': 1,
                    'document_structure': '',
                    'semehr_label': 0,
                    'correction_source': 'iterative_verification'
                })
                self.stats['added_annotations'] += 1

                if self.debug:
                    logger.info(f"  Added new annotation for {entity}")

            # Track document modification
            self.stats['documents_modified'].add(document_id)

    def _remove_false_annotations(self) -> None:
        """
        Remove annotations that were identified as not being true rare diseases.
        """
        for document_id, entities in self.entities_to_remove.items():
            document = self.existing_annotations.get(document_id)
            if not document or 'annotations' not in document:
                continue

            annotations = document['annotations']
            # Set membership keeps the scan linear in the number of annotations.
            doomed = set(entities)

            kept = []
            for annotation in annotations:
                mention = self._mention_key(annotation)

                if mention in doomed:
                    self.stats['removed_annotations'] += 1

                    if self.debug:
                        logger.info(f"  Removed annotation for '{mention}' in document {document_id}")
                else:
                    kept.append(annotation)

            # Replace the list only when something was actually dropped.
            if len(kept) != len(annotations):
                document['annotations'] = kept
                self.stats['documents_modified'].add(document_id)

    def process_corrections(self) -> Dict[str, Any]:
        """
        Process the corrections to update the annotations.

        Calling this more than once is allowed; the counters in ``stats``
        then accumulate across calls.

        Returns:
            Dict: ``{'metadata': {'update_timestamp', 'stats'}, 'documents'}``
            where ``documents`` is the corrected copy of the annotations and
            ``stats['documents_modified']`` is a sorted list of document ids.
        """
        # A previous call leaves 'documents_modified' as a list; restore the
        # set so the .add() calls in the steps below keep working.
        self.stats['documents_modified'] = set(self.stats['documents_modified'])

        # Step 1: Identify annotations to be removed
        self._identify_removals()

        # Step 2: Process updates and additions
        self._process_updates_and_additions()

        # Step 3: Remove false annotations
        self._remove_false_annotations()

        # JSON cannot serialize a set; sorting makes the output reproducible
        # (set iteration order of strings varies between interpreter runs).
        self.stats['documents_modified'] = sorted(self.stats['documents_modified'])

        final_output = {
            'metadata': {
                'update_timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                'stats': self.stats
            },
            'documents': self.existing_annotations
        }

        # Print summary
        if self.debug:
            logger.info("\n=== Correction Processing Summary ===")
            logger.info(f"Total corrections processed: {self.stats['total_corrections']}")
            logger.info(f"Annotations added: {self.stats['added_annotations']}")
            logger.info(f"Annotations updated: {self.stats['updated_annotations']}")
            logger.info(f"Annotations removed: {self.stats['removed_annotations']}")
            logger.info(f"Corrections skipped: {self.stats['skipped_corrections']}")
            logger.info(f"Documents modified: {len(self.stats['documents_modified'])}")

        return final_output


def detailed_correction_diagnostic(
    existing_annotations: Dict[str, Any], 
    corrections: Dict[str, Any]
) -> None:
    """
    Log an overview of the corrections before they are applied.

    Reports the number of corrections, how many fall into each known
    category, which entries would trigger a removal (false negatives that
    are not rare diseases, first five shown) and the five documents with
    the most corrections.

    Args:
        existing_annotations: Original annotations dictionary (not inspected here)
        corrections: Corrections dictionary; its 'corrected_annotations' list is analysed
    """
    logger.info("\n===== CORRECTION DIAGNOSTIC =====")

    entries = corrections.get('corrected_annotations', [])
    logger.info(f"Total corrections found: {len(entries)}")

    # Bucket by category; entries with any other category are left out.
    buckets = {
        'false_negatives': [],
        'false_positives': [],
        'true_positives': []
    }
    for entry in entries:
        bucket = buckets.get(entry.get('category', 'unknown'))
        if bucket is not None:
            bucket.append(entry)

    logger.info(f"False negatives to review: {len(buckets['false_negatives'])}")
    logger.info(f"False positives to review: {len(buckets['false_positives'])}")
    logger.info(f"True positives to review: {len(buckets['true_positives'])}")

    # False negatives that turned out not to be rare diseases get removed.
    to_remove = [
        entry for entry in entries
        if entry.get('category', '') == 'false_negatives'
        and not entry.get('is_rare_disease', False)
    ]
    logger.info(f"Entities identified for removal: {len(to_remove)}")

    if to_remove:
        logger.info("\nSample entities to be removed:")
        for position, item in enumerate(to_remove[:5], 1):
            logger.info(f"  {position}. Entity: '{item.get('entity', '')}' - Document: {item.get('document_id', '')}")
            logger.info(f"     Explanation: {item.get('explanation', 'No explanation')}")

    # Number of corrections per document, in first-seen order.
    per_document = {}
    for entry in entries:
        key = entry.get('document_id', '')
        per_document[key] = per_document.get(key, 0) + 1

    logger.info(f"\nDocuments with corrections: {len(per_document)}")
    busiest = sorted(per_document.items(), key=lambda pair: pair[1], reverse=True)[:5]

    if busiest:
        logger.info("Documents with most corrections:")
        for key, total in busiest:
            logger.info(f"  Document {key}: {total} corrections")


def update_annotations_with_corrections(
    existing_annotations: Dict[str, Any], 
    corrections: Dict[str, Any], 
    output_file: str = None,
    debug: bool = False
) -> Dict[str, Any]:
    """
    Update existing annotations with corrections from the review process,
    including removing false annotations that aren't actually rare diseases.
    
    Args:
        existing_annotations: Original annotations dictionary
        corrections: Corrections from the review process; must contain the
            'corrected_annotations' key
        output_file: Path to save the updated annotations (skipped when falsy);
            missing parent directories are created
        debug: Enable debug output
    
    Returns:
        Dict: Updated annotations dictionary with enriched rare disease annotations
    
    Raises:
        ValueError: If existing_annotations is not a dict, or corrections is
            not a dict holding 'corrected_annotations'.
    """
    # Validate first: the diagnostic below calls corrections.get(), so running
    # it on malformed input would fail with an unrelated AttributeError
    # instead of the documented ValueError.
    if not isinstance(existing_annotations, dict):
        raise ValueError("Existing annotations must be a dictionary")
    
    if not isinstance(corrections, dict) or 'corrected_annotations' not in corrections:
        raise ValueError("Corrections must be a dictionary with 'corrected_annotations' key")
    
    # Print diagnostic information if debug is on
    if debug:
        detailed_correction_diagnostic(existing_annotations, corrections)
    
    # Process corrections using the AnnotationCorrector
    corrector = AnnotationCorrector(existing_annotations, corrections, debug)
    result = corrector.process_corrections()
    
    # Save to output file if specified
    if output_file:
        # abspath() ensures dirname() is non-empty for bare file names.
        output_dir = os.path.dirname(os.path.abspath(output_file))
        os.makedirs(output_dir, exist_ok=True)
        
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2)
        
        if debug:
            logger.info(f"Updated annotations saved to: {output_file}")
    
    return result


def process_annotations_correction(
    existing_annotations_path: str, 
    corrections_path: str, 
    output_path: str = None,
    debug: bool = True
) -> Dict[str, Any]:
    """
    Convenience function to load and process annotations corrections.
    
    The annotations file may either be the raw document mapping or wrap it
    under a 'documents' key. The corrections file may provide a flat
    'corrected_annotations' list, per-category lists under 'results', or a
    supervisor-style 'summary.flagged_entities' list. These formats are
    normalised regardless of the debug flag; debug only controls logging.
    
    Args:
        existing_annotations_path: Path to existing annotations JSON file
        corrections_path: Path to corrections JSON file
        output_path: Path to save updated annotations
        debug: Enable debug output
    
    Returns:
        Dict: Updated annotations dictionary
    """
    # Load existing annotations
    try:
        with open(existing_annotations_path, 'r', encoding='utf-8') as f:
            existing_annotations = json.load(f)
            
        if debug:
            logger.info(f"Loaded existing annotations from {existing_annotations_path}")
            
        # Unwrap the 'documents' key. This must not depend on the debug flag,
        # otherwise debug=False would process a differently shaped dictionary.
        if 'documents' in existing_annotations and isinstance(existing_annotations['documents'], dict):
            existing_annotations = existing_annotations['documents']
            if debug:
                logger.info("Extracted 'documents' key from annotations")
    except Exception as e:
        logger.error(f"Error loading existing annotations: {e}")
        raise
    
    # Load corrections
    try:
        with open(corrections_path, 'r', encoding='utf-8') as f:
            corrections = json.load(f)
            
        if debug:
            logger.info(f"Loaded corrections from {corrections_path}")
            
        # Normalise to {'corrected_annotations': [...]} (also independent of debug).
        if 'results' in corrections:
            # Use the flattened corrected_annotations list if available
            if 'corrected_annotations' not in corrections:
                # Extract from all categories and build a unified list
                all_corrections = []
                for category in ['false_negatives', 'false_positives', 'true_positives']:
                    if category in corrections['results']:
                        all_corrections.extend(corrections['results'][category])
                
                corrections = {'corrected_annotations': all_corrections}
                if debug:
                    logger.info(f"Extracted and merged corrections from categories: {len(all_corrections)} total")
        elif 'summary' in corrections and 'flagged_entities' in corrections['summary']:
            # Handle supervisor output format
            corrections = {'corrected_annotations': corrections['summary']['flagged_entities']}
            if debug:
                logger.info(f"Extracted flagged entities as corrections: {len(corrections['corrected_annotations'])} total")
    except Exception as e:
        logger.error(f"Error loading corrections: {e}")
        raise
    
    # Update and save annotations
    return update_annotations_with_corrections(
        existing_annotations, 
        corrections, 
        output_path, 
        debug=debug
    )


def print_annotations_differences(
    original_path: str, 
    corrected_path: str
) -> None:
    """
    Log a per-document and overall diff between two annotation files.

    Annotations are paired by lower-cased mention within each document of the
    corrected file; a pair counts as updated when 'ordo_with_desc' differs.
    At most three examples are listed per change type and document.

    Args:
        original_path: Path to original annotations file
        corrected_path: Path to corrected annotations file
    """
    def _load_documents(path):
        # Accept both the raw mapping and the {'documents': ...} wrapper.
        with open(path, 'r') as handle:
            payload = json.load(handle)
        return payload.get('documents', payload)

    def _log_capped(heading, annotations):
        # Heading, up to three "mention - code" lines, then an overflow note.
        logger.info(heading)
        for position, ann in enumerate(annotations[:3], 1):
            logger.info(f"    {position}. '{ann.get('mention', '')}' - {ann.get('ordo_with_desc', '')}")
        overflow = len(annotations) - 3
        if overflow > 0:
            logger.info(f"    ... and {overflow} more")

    original_annos = _load_documents(original_path)
    corrected_annos = _load_documents(corrected_path)

    logger.info("\n===== ANNOTATION DIFFERENCES =====")

    # Running totals over all documents.
    totals = {'added': 0, 'updated': 0, 'removed': 0, 'unchanged': 0}

    for doc_id, doc_data in corrected_annos.items():
        orig_doc = original_annos.get(doc_id, {})
        corr_anns = doc_data.get('annotations', [])
        orig_anns = orig_doc.get('annotations', [])

        if not (corr_anns or orig_anns):
            continue

        # Index both sides by lower-cased mention.
        before_by_mention = {ann.get('mention', '').lower(): ann for ann in orig_anns}
        after_by_mention = {ann.get('mention', '').lower(): ann for ann in corr_anns}

        added, updated, removed = [], [], []
        unchanged = 0

        for mention in set(before_by_mention.keys()) | set(after_by_mention.keys()):
            before = before_by_mention.get(mention)
            after = after_by_mention.get(mention)

            if before and after:
                if before.get('ordo_with_desc') != after.get('ordo_with_desc'):
                    updated.append((before, after))
                    totals['updated'] += 1
                else:
                    unchanged += 1
                    totals['unchanged'] += 1
            elif after:
                added.append(after)
                totals['added'] += 1
            elif before:
                removed.append(before)
                totals['removed'] += 1

        # Only documents with at least one change are reported.
        if len(added) + len(updated) + len(removed) == 0:
            continue

        logger.info(f"\nDocument {doc_id} changes:")
        logger.info(f"  Added: {len(added)}")
        logger.info(f"  Updated: {len(updated)}")
        logger.info(f"  Removed: {len(removed)}")
        logger.info(f"  Unchanged: {unchanged}")

        if added:
            _log_capped("\n  Added Annotations:", added)

        if removed:
            _log_capped("\n  Removed Annotations:", removed)

        if updated:
            logger.info("\n  Updated Annotations:")
            for position, (old, new) in enumerate(updated[:3], 1):
                logger.info(f"    {position}. '{old.get('mention', '')}'")
                logger.info(f"       Old: {old.get('ordo_with_desc', '')}")
                logger.info(f"       New: {new.get('ordo_with_desc', '')}")
            if len(updated) > 3:
                logger.info(f"    ... and {len(updated) - 3} more")

    affected = totals['added'] + totals['updated'] + totals['removed']
    logger.info("\n=== Overall Changes Summary ===")
    logger.info(f"Total annotations added: {totals['added']}")
    logger.info(f"Total annotations updated: {totals['updated']}")
    logger.info(f"Total annotations removed: {totals['removed']}")
    logger.info(f"Total annotations unchanged: {totals['unchanged']}")
    logger.info(f"Total annotations affected: {affected}")


# Example usage: apply the corrections file to the existing annotations, write
# the result, then log a diff between the input and the written output.
if __name__ == "__main__":
    # Example paths - update with your actual paths.
    # They are relative, so they resolve against the current working directory.
    EXISTING_ANNOTATIONS_PATH = "data/dataset/filtered_rd_annos_updated_adam.json"
    CORRECTIONS_PATH = "data/dataset/adam_corrections_v2.json"
    OUTPUT_PATH = "data/dataset/rd_annos_adam_corrected_v1.json" # named v1 although it is built from the v2 corrections file; kept as is.
    
    # Process the corrections (also writes OUTPUT_PATH)
    updated_annotations = process_annotations_correction(
        EXISTING_ANNOTATIONS_PATH,
        CORRECTIONS_PATH,
        OUTPUT_PATH,
        debug=True
    )
    
    # Print differences between original and updated annotations
    print_annotations_differences(EXISTING_ANNOTATIONS_PATH, OUTPUT_PATH)
# Example usage (commented out)
# updated_annotations = process_annotations_correction(
#     "data/dataset/filtered_rd_annos_updated_adam.json", 
#     "data/dataset/adam_corrections_v2.json", 
#     "data/dataset/rd_annos_adam_corrected_v1.json",
#     debug=True
# )

# print_annotations_differences(EXISTING_ANNOTATIONS_PATH, OUTPUT_PATH)