*Framework for evaluating AI-generated SOAP notes via NER validation and LLM-based Detect missing information, hallucinations, and clinical accuracy issues.*

In [2]:
!pip install groq

Collecting groq
  Downloading groq-0.33.0-py3-none-any.whl.metadata (16 kB)
Downloading groq-0.33.0-py3-none-any.whl (135 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/135.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m135.8/135.8 kB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: groq
Successfully installed groq-0.33.0


In [8]:
# from datasets import load_dataset
# import pandas as pd

# # Load the dataset
# dataset = load_dataset("adesouza1/soap_notes")
# dataset2 = load_dataset("omi-health/medical-dialogue-to-soap-summary")

In [9]:
# from google.colab import userdata

# # SeepScribe_OPENAI = userdata.get('SeepScribe_OPENAI')
# Groq_DeepScribe = userdata.get('Groq_DeepScribe')
# HugFace_DeepScribe = userdata.get('HugFace_DeepScribe')

In [13]:
dataset['train']

Dataset({
    features: ['age', 'patient_name', 'doctor_data', 'gender', 'dob', 'phone', 'person_data', 'health_problem', 'patient_convo', 'soap_notes', 'doctor_name', 'address', 'full_patient_data'],
    num_rows: 558
})

### Both Combined

In [14]:
### new with better task 2 per note metrics

import pandas as pd
import numpy as np
import torch
import logging
import time
import json
import csv
from concurrent.futures import ThreadPoolExecutor, as_completed
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, pipeline
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, List, Tuple, Any
from openai import OpenAI
from groq import Groq
import warnings
from google.colab import userdata
warnings.filterwarnings('ignore')

Groq_DeepScribe = userdata.get('Groq_DeepScribe')
HugFace_DeepScribe = userdata.get('HugFace_DeepScribe')
# ================================================================================
# CENTRALIZED CONFIGURATION - All settings for Task 1 & Task 2
# ================================================================================

# Dataset & Processing
NUM_SAMPLES = 10  # Number of transcript-SOAP note pairs to evaluate

# Task 1: NER Entity Validation
NER_MODEL = "Helios9/BioMed_NER"
EMBEDDING_MODEL = "emilyalsentzer/Bio_ClinicalBERT"
CONFIDENCE_THRESHOLD = 0.5
SIMILARITY_THRESHOLD = 0.7
MAX_TEXT_LENGTH = 2000

# Task 1 Outputs
NER_SUMMARY_OUTPUT = "ner_evaluation_summary.csv"
NER_DETAILS_OUTPUT = "ner_entity_matches.csv"
TASK1_LOG_FILE = "task1_ner_evaluation.log"

# Task 2: LYNX Hallucination Detection
LYNX_MODEL = "PatronusAI/Llama-3-Patronus-Lynx-8B-Instruct:featherless-ai"
GROQ_MODEL = "openai/gpt-oss-20b"
MAX_CONCURRENT = 20
RETRY_MAX = 3
RETRY_BACKOFF_BASE = 1.0

# Task 2 Outputs
HALLUCINATION_RESULTS_CSV = "lynx_hallucination_results.csv"
TASK2_LOG_FILE = "task2_lynx_evaluation.log"

# API Keys - SET THESE BEFORE RUNNING
HugFace_DeepScribe = HugFace_DeepScribe
Groq_DeepScribe = Groq_DeepScribe

# Entity type weights for criticality scoring (Task 1)
ENTITY_WEIGHTS = {
    'MEDICATION': 1.0, 'DRUG': 1.0,
    'DIAGNOSIS': 0.9, 'DISEASE': 0.9, 'DISORDER': 0.9,
    'PROCEDURE': 0.7, 'TEST': 0.7, 'TREATMENT': 0.7,
    'SYMPTOM': 0.5, 'SIGN': 0.5,
    'ANATOMY': 0.3, 'OTHER': 0.3
}

# ================================================================================
# LOGGING SETUP
# ================================================================================

def setup_logging(log_file: str, task_name: str):
    """Configure logging for a specific task"""
    logger = logging.getLogger(task_name)
    logger.setLevel(logging.DEBUG)
    logger.handlers = []  # Clear existing handlers

    file_handler = logging.FileHandler(log_file, mode='w', encoding='utf-8')
    file_handler.setLevel(logging.DEBUG)
    file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(file_formatter)
    logger.addHandler(file_handler)

    logger.info("="*80)
    logger.info(f"{task_name} - Logging Started")
    logger.info("="*80)

    return logger

# ================================================================================
# TASK 1: NER ENTITY VALIDATION
# ================================================================================

class MedicalEntityEvaluator:
    """NER-based evaluator for detecting missing entities in AI-generated SOAP notes"""

    def __init__(self, logger):
        self.logger = logger
        self.logger.info("Initializing Medical Entity Evaluator...")

        # Load NER Model
        self.logger.info(f"Loading NER model: {NER_MODEL}")
        self.ner_pipeline = pipeline(
            "token-classification",
            model=NER_MODEL,
            aggregation_strategy="simple"
        )
        self.logger.info("NER model loaded successfully")

        # Load Embedding Model
        self.logger.info(f"Loading embedding model: {EMBEDDING_MODEL}")
        self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
        self.embedding_model = AutoModel.from_pretrained(EMBEDDING_MODEL)

        if torch.cuda.is_available():
            self.embedding_model = self.embedding_model.cuda()
            self.logger.info("Using GPU acceleration")
        else:
            self.logger.info("Using CPU")

        self.entity_weights = ENTITY_WEIGHTS
        self.logger.info("Evaluator initialized successfully")

    def extract_entities(self, text: str, source_type: str = "text") -> List[Dict]:
        """Extract medical entities using NER model"""
        self.logger.debug(f"Extracting entities from {source_type} (length: {len(text)} chars)")

        original_length = len(text)
        if original_length > MAX_TEXT_LENGTH:
            text = text[:MAX_TEXT_LENGTH]
            self.logger.debug(f"Text truncated from {original_length} to {MAX_TEXT_LENGTH} chars")

        entities = []
        try:
            ner_results = self.ner_pipeline(text)

            for entity in ner_results:
                confidence = entity.get('score', 0.0)
                if confidence >= CONFIDENCE_THRESHOLD:
                    entity_dict = {
                        'text': entity.get('word', '').strip(),
                        'type': entity.get('entity_group', 'UNKNOWN').upper(),
                        'confidence': round(confidence, 3),
                        'start': entity.get('start', 0),
                        'end': entity.get('end', 0)
                    }
                    entity_dict['weight'] = self.entity_weights.get(entity_dict['type'], 0.3)
                    entities.append(entity_dict)

            self.logger.debug(f"Found {len(entities)} entities above confidence threshold")
        except Exception as e:
            self.logger.error(f"Error extracting entities: {e}")

        # Remove duplicates
        seen = set()
        unique_entities = []
        for ent in entities:
            key = (ent['text'].lower(), ent['type'])
            if key not in seen and len(ent['text']) > 1:
                seen.add(key)
                unique_entities.append(ent)

        if len(unique_entities) != len(entities):
            self.logger.debug(f"Deduplicated to {len(unique_entities)} unique entities")

        return unique_entities

    def get_embedding(self, text: str) -> np.ndarray:
        """Get BioClinicalBERT embedding for text"""
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=128, padding=True)
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.embedding_model(**inputs)
            embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        return embedding

    def match_entities_semantically(self, transcript_entities: List[Dict],
                                   soap_entities: List[Dict]) -> Tuple[Dict, List[Dict], pd.DataFrame]:
        """Match entities using semantic similarity"""
        self.logger.debug(f"Matching {len(transcript_entities)} transcript with {len(soap_entities)} SOAP entities")

        matches = {}
        unmatched = []
        match_details = []

        if not transcript_entities or not soap_entities:
            return matches, transcript_entities if transcript_entities else [], pd.DataFrame()

        # Compute embeddings
        transcript_embeddings = np.vstack([self.get_embedding(ent['text']) for ent in transcript_entities])
        soap_embeddings = np.vstack([self.get_embedding(ent['text']) for ent in soap_entities])

        similarity_matrix = cosine_similarity(soap_embeddings, transcript_embeddings)
        matched_transcript_indices = set()

        for soap_idx, soap_ent in enumerate(soap_entities):
            similarities = similarity_matrix[soap_idx]
            best_match_idx = np.argmax(similarities)
            best_similarity = similarities[best_match_idx]

            if best_similarity >= SIMILARITY_THRESHOLD and best_match_idx not in matched_transcript_indices:
                transcript_ent = transcript_entities[best_match_idx]
                matched_transcript_indices.add(best_match_idx)

                matches[soap_ent['text']] = {
                    'matched_entity': transcript_ent['text'],
                    'similarity': round(best_similarity, 3),
                    'soap_type': soap_ent['type'],
                    'transcript_type': transcript_ent['type']
                }

                match_details.append({
                    'soap_entity': soap_ent['text'],
                    'soap_type': soap_ent['type'],
                    'transcript_entity': transcript_ent['text'],
                    'transcript_type': transcript_ent['type'],
                    'similarity': round(best_similarity, 3),
                    'status': 'Matched'
                })

        for idx, ent in enumerate(transcript_entities):
            if idx not in matched_transcript_indices:
                unmatched.append(ent)
                match_details.append({
                    'soap_entity': '', 'soap_type': '',
                    'transcript_entity': ent['text'],
                    'transcript_type': ent['type'],
                    'similarity': 0.0,
                    'status': 'Missing in SOAP'
                })

        self.logger.debug(f"Matched: {len(matches)}, Unmatched: {len(unmatched)}")
        match_df = pd.DataFrame(match_details) if match_details else pd.DataFrame()
        return matches, unmatched, match_df

    def compute_metrics(self, transcript_entities: List[Dict], soap_entities: List[Dict],
                       matches: Dict, unmatched: List[Dict]) -> Dict:
        """Compute evaluation metrics"""
        total_transcript = len(transcript_entities)
        total_matched = len(matches)

        coverage_score = (total_matched / total_transcript * 100) if total_transcript > 0 else 0.0

        if transcript_entities:
            total_weight = sum(e['weight'] for e in transcript_entities)
            matched_weight = sum(e['weight'] for e in transcript_entities
                               if e['text'] in [m['matched_entity'] for m in matches.values()])
            criticality_score = (matched_weight / total_weight * 100) if total_weight > 0 else 0.0
        else:
            criticality_score = 0.0

        missing_breakdown = {
            'critical': sum(1 for e in unmatched if e['weight'] >= 0.9),
            'moderate': sum(1 for e in unmatched if 0.5 <= e['weight'] < 0.9),
            'low': sum(1 for e in unmatched if e['weight'] < 0.5)
        }

        all_confidences = [e['confidence'] for e in transcript_entities + soap_entities]
        extraction_confidence = np.mean(all_confidences) if all_confidences else 0.0

        return {
            'coverage_score': round(coverage_score, 1),
            'criticality_score': round(criticality_score, 1),
            'extraction_confidence': round(extraction_confidence, 3),
            'total_transcript_entities': total_transcript,
            'total_soap_entities': len(soap_entities),
            'matched_entities': total_matched,
            'unmatched_entities': len(unmatched),
            'missing_breakdown': missing_breakdown
        }

    def evaluate_single_pair(self, transcript: str, soap_note: str) -> Dict:
        """Evaluate a single transcript-SOAP pair"""
        start_time = time.time()

        transcript_entities = self.extract_entities(transcript, "transcript")
        soap_entities = self.extract_entities(soap_note, "SOAP note")
        matches, unmatched, match_details = self.match_entities_semantically(transcript_entities, soap_entities)
        metrics = self.compute_metrics(transcript_entities, soap_entities, matches, unmatched)

        return {
            'transcript_entities': transcript_entities,
            'soap_entities': soap_entities,
            'matches': matches,
            'unmatched': unmatched,
            'match_details': match_details,
            'metrics': metrics,
            'processing_time': round(time.time() - start_time, 2)
        }

# ================================================================================
# TASK 2: LYNX HALLUCINATION DETECTION
# ================================================================================

# Prompt templates
SYSTEM_PROMPT = (
    "You are a clinical-reasoning assistant. Your job is:\n"
    "1. Read a SOAP note (S, O, A, P sections).\n"
    "2. Identify atomic clinical claims from that SOAP note; each claim corresponds to one fact (for example: a medication with dose/frequency, a vital with units, an allergy, a symptom present/absent, laterality of a finding, temporal timing, a procedure, social or family history, a plan item).\n"
    "3. For each claim, generate hallucination-oriented probes to check whether the claim is supported by a transcript. For each claim, create:\n"
    "   - q_pos: a yes/no QUESTION that captures the COMPLETE claim with ALL details.\n"
    "      - If the claim mentions multiple symptoms, body parts, medications, or attributes, include ALL of them in the question.\n"
    "      - Example: For claim 'Patient reports pain in right knee and left ankle'\n"
    "        - question should be 'Does patient report pain in right knee AND left ankle?' not just 'Does patient report pain in right knee?'\n"
    "   - a_pos: a short declarative ANSWER text that restates the claim (for example: 'Patient denies cough.').\n"
    "   - q_neg: the logically negated yes/no QUESTION (for example: 'Does the patient report cough?').\n"
    "   - a_neg: the short declarative ANSWER for the negation (for example: 'Patient reports cough.').\n"
    "4. Only generate probes for categories that actually appear in the SOAP note. For example, if the SOAP note has a medication with dose/frequency, generate the claims + probes for medication. If the SOAP note has no family history, skip family history.\n"
    "5. Generate a total of 10 claims, each with one hallucination probe pair (i.e., 5 q_pos/a_pos and 5 q_neg/a_neg).\n"
    "6. Output must be in strict JSON format with the structure:\n"
    "{\n"
    "  \"claims\": [\n"
    "    {\n"
    "      \"claim\": <string>,\n"
    "      \"section\": <\"S\"|\"O\"|\"A\"|\"P\">,\n"
    "      \"category\": <one of medications, laterality, negation, allergies, diagnoses, vitals, temporal, symptoms, procedures, social_history, family_history, plans>,\n"
    "      \"hallucination_probes\": [\n"
    "        {\"q_pos\": <string>, \"a_pos\": <string>, \"q_neg\": <string>, \"a_neg\": <string>}\n"
    "      ]\n"
    "    }\n"
    "  ]\n"
    "}\n"
)

USER_PROMPT_TEMPLATE = "SOAP NOTE:\n{soap}\n\nGenerate claims and their associated hallucination-oriented probes as described above."

def build_prompt(question: str, answer: str, document: str) -> str:
    """Build LYNX evaluation prompt"""
    return f"""Given the following QUESTION, DOCUMENT and ANSWER you must analyze the provided answer and determine whether it is faithful to the contents of the DOCUMENT.
--
QUESTION (THIS DOES NOT COUNT AS BACKGROUND INFORMATION):
{question}

--
DOCUMENT:
{document}

--
ANSWER:
{answer}

--
Your output must be EXACTLY in this JSON format:
{{"REASONING": ["point 1", "point 2"], "SCORE": "PASS"}}
OR
{{"REASONING": ["point 1", "point 2"], "SCORE": "FAIL"}}

CRITICAL: SCORE value MUST be in quotes: "PASS" or "FAIL".
Respond with JSON only:
"""

def clean_json_with_llm(messy_content: str, groq_client: Groq) -> Dict[str, Any]:
    """Use Groq to clean messy JSON from LYNX"""
    prompt = f"""Convert this messy JSON-like text into valid JSON.
Extract only the REASONING (as array of strings) and SCORE (as string "PASS" or "FAIL").

Messy input:
{messy_content}

Return ONLY valid JSON in this exact format:
{{"REASONING": ["point 1", "point 2"], "SCORE": "PASS"}}

Valid JSON:"""

    resp = groq_client.chat.completions.create(
        model=GROQ_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_completion_tokens=10000,
        response_format={"type": "json_object"}
    )

    clean_json = resp.choices[0].message.content.strip()
    return json.loads(clean_json)

def call_lynx(question: str, answer: str, document: str,
              lynx_client: OpenAI, groq_client: Groq, logger) -> Dict[str, Any]:
    """Call LYNX model to evaluate a claim"""
    prompt = build_prompt(question, answer, document)
    attempt = 0

    while attempt < RETRY_MAX:
        try:
            resp = lynx_client.chat.completions.create(
                model=LYNX_MODEL,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                max_tokens=4000
            )

            messy_content = resp.choices[0].message.content.strip()
            result = clean_json_with_llm(messy_content, groq_client)

            if "SCORE" not in result:
                raise ValueError("Missing SCORE")

            return {"question": question, "answer": answer, "result": result}

        except Exception as e:
            logger.warning(f"LYNX call error (attempt {attempt+1}/{RETRY_MAX}): {str(e)}")
            attempt += 1
            if attempt >= RETRY_MAX:
                logger.error(f"LYNX call failed after {RETRY_MAX} attempts")
                return {
                    "question": question,
                    "answer": answer,
                    "result": {"REASONING": [f"ERROR: {str(e)}"], "SCORE": "FAIL"}
                }
            time.sleep(RETRY_BACKOFF_BASE * (2 ** (attempt-1)))

def evaluate_probes_for_document(probes_json: Dict[str, Any], document: str,
                                lynx_client: OpenAI, groq_client: Groq, logger) -> List[Dict[str, Any]]:
    """Evaluate all probes for a document using LYNX"""
    results: List[Dict[str, Any]] = []
    tasks = []

    logger.info(f"Evaluating {len(probes_json['claims'])} claims with LYNX")

    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as executor:
        for claim_obj in probes_json["claims"]:
            for probe in claim_obj["hallucination_probes"]:
                tasks.append((claim_obj, "pos", probe["q_pos"], probe["a_pos"]))
                tasks.append((claim_obj, "neg", probe["q_neg"], probe["a_neg"]))

        futures = {executor.submit(call_lynx, q, a, document, lynx_client, groq_client, logger):
                  (claim_obj, tag, q, a) for (claim_obj, tag, q, a) in tasks}

        for fut in as_completed(futures):
            claim_obj, tag, q, a = futures[fut]
            resp = fut.result()
            results.append({
                "claim": claim_obj["claim"],
                "section": claim_obj["section"],
                "category": claim_obj["category"],
                "tag": tag,
                "question": q,
                "answer": a,
                "score": resp["result"].get("SCORE", "FAIL"),
                "reasoning": resp["result"].get("REASONING", [])
            })

    logger.info(f"Completed {len(results)} evaluations")
    return results

def bucket_claims(evals: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Bucket claim evaluations into final statuses"""
    buckets: List[Dict[str, Any]] = []
    temp: Dict[Any, Any] = {}

    for rec in evals:
        key = (rec["claim"], rec["section"], rec["category"])
        if key not in temp:
            temp[key] = {"pos": None, "neg": None, "claim": rec["claim"],
                        "section": rec["section"], "category": rec["category"]}
        temp[key][rec["tag"]] = rec

    for key, pair in temp.items():
        p_pos = pair.get("pos", {"score": "FAIL"})["score"]
        p_neg = pair.get("neg", {"score": "FAIL"})["score"]

        if p_pos == "PASS" and p_neg == "FAIL":
            status = "Supported"
        elif p_pos == "FAIL" and p_neg == "PASS":
            status = "Hallucination"
        elif p_pos == "PASS" and p_neg == "PASS":
            status = "Ambiguous"
        else:
            status = "Unclear"

        buckets.append({
            "claim": pair["claim"],
            "section": pair["section"],
            "category": pair["category"],
            "score_pos": p_pos,
            "score_neg": p_neg,
            "reasoning_pos": "\n".join(pair.get("pos",{}).get("reasoning", [])),
            "reasoning_neg": "\n".join(pair.get("neg",{}).get("reasoning", [])),
            "status": status
        })

    return buckets

def compute_hallucination_metrics(bucketed: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Compute hallucination detection metrics"""
    T = len(bucketed)
    S = sum(1 for b in bucketed if b["status"] == "Supported")
    H = sum(1 for b in bucketed if b["status"] == "Hallucination")
    A = sum(1 for b in bucketed if b["status"] == "Ambiguous")
    U = sum(1 for b in bucketed if b["status"] == "Unclear")

    return {
        "total_claims": T,
        "supported": S,
        "hallucination": H,
        "ambiguous": A,
        "unclear": U,
        "hallucination_rate": (H / T * 100) if T else 0.0,
        "accuracy_rate": (S / T * 100) if T else 0.0,
        "ambiguity_rate": ((A + U) / T * 100) if T else 0.0,
        "overall_clarity": ((S / T * 100) if T else 0.0) - (((A + U) / T * 100) if T else 0.0) * 0.5
    }

def extract_probes(soap_note: str, groq_client: Groq, logger) -> Dict[str, Any]:
    """Extract hallucination probe claims from SOAP note"""
    prompt_user = USER_PROMPT_TEMPLATE.format(soap=soap_note)

    resp = groq_client.chat.completions.create(
        model=GROQ_MODEL,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt_user}
        ],
        temperature=0.2,
        top_p=1.0,
        max_completion_tokens=10000,
        stream=False,
        response_format={"type": "json_object"}
    )

    raw = resp.choices[0].message.content
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON from Groq: {e}")
        logger.error(f"Raw response: {raw[:500]}...")
        raise ValueError(f"Unable to parse JSON:\n{raw}\nError: {e}")

    if "claims" not in payload or not isinstance(payload["claims"], list):
        logger.error(f"Invalid JSON structure. Got keys: {list(payload.keys())}")
        raise ValueError(f"Output JSON missing top-level 'claims' list. Got: {list(payload.keys())}")

    # Validate each claim has required fields
    for idx, claim_obj in enumerate(payload["claims"]):
        for key in ["claim", "section", "category", "hallucination_probes"]:
            if key not in claim_obj:
                logger.error(f"Claim {idx} missing key '{key}'. Claim keys: {list(claim_obj.keys())}")
                logger.error(f"Full claim object: {json.dumps(claim_obj, indent=2)}")
                raise ValueError(f"Claim object {idx} missing required key: {key}")

        # Validate hallucination_probes structure
        if not isinstance(claim_obj["hallucination_probes"], list) or len(claim_obj["hallucination_probes"]) == 0:
            logger.error(f"Claim {idx} has invalid hallucination_probes structure")
            raise ValueError(f"Claim {idx} hallucination_probes must be a non-empty list")

        for probe_idx, probe in enumerate(claim_obj["hallucination_probes"]):
            for probe_key in ["q_pos", "a_pos", "q_neg", "a_neg"]:
                if probe_key not in probe:
                    logger.error(f"Claim {idx}, probe {probe_idx} missing key '{probe_key}'")
                    raise ValueError(f"Claim {idx}, probe {probe_idx} missing required key: {probe_key}")

    logger.info(f"Generated {len(payload['claims'])} claims from SOAP note")
    return payload

def write_hallucination_csv(bucketed: List[Dict[str, Any]], path: str):
    """Write hallucination results to CSV"""
    keys = ["claim", "section", "category", "score_pos", "score_neg", "status", "reasoning_pos", "reasoning_neg"]
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=keys)
        writer.writeheader()
        for b in bucketed:
            writer.writerow({k: b[k] for k in keys})

# ================================================================================
# MAIN PIPELINE EXECUTION
# ================================================================================

def main():
    """Integrated pipeline: Task 1 (NER) → Task 2 (LYNX Hallucination Detection)"""

    print("="*60)
    print("MEDICAL AI EVALUATION PIPELINE")
    print("="*60)
    print("Task 1: NER Entity Validation")
    print("Task 2: LYNX Hallucination Detection")
    print("="*60)

    # Setup logging for both tasks
    task1_logger = setup_logging(TASK1_LOG_FILE, "Task 1: NER")
    task2_logger = setup_logging(TASK2_LOG_FILE, "Task 2: LYNX")

    # Load dataset ONCE
    print(f"\n📚 Loading dataset...")
    dataset = load_dataset("adesouza1/soap_notes")
    print(f"✓ Loaded {len(dataset['train'])} documents")

    # Initialize API clients for Task 2
    lynx_client = OpenAI(api_key=HugFace_DeepScribe, base_url="https://router.huggingface.co/v1")
    groq_client = Groq(api_key=Groq_DeepScribe)

    # ========== TASK 1: NER ENTITY VALIDATION ==========
    print(f"\n{'='*60}")
    print("TASK 1: NER ENTITY VALIDATION")
    print("="*60)
    print(f"Initializing NER models...")
    evaluator = MedicalEntityEvaluator(task1_logger)
    print(f"✓ Models loaded")

    print(f"\nProcessing {NUM_SAMPLES} samples (Task 1)...")
    ner_results = []
    all_match_details = []

    for i in range(NUM_SAMPLES):
        task1_logger.info(f"\n{'='*80}\nProcessing sample {i+1}/{NUM_SAMPLES}\n{'='*80}")

        transcript = dataset['train']['patient_convo'][i]
        soap_note = dataset['train']['soap_notes'][i]

        result = evaluator.evaluate_single_pair(transcript, soap_note)

        flat_result = {
            'document_id': i,
            'coverage_score': f"{result['metrics']['coverage_score']}%",
            'criticality_score': f"{result['metrics']['criticality_score']}%",
            'extraction_confidence': result['metrics']['extraction_confidence'],
            'entities_transcript': result['metrics']['total_transcript_entities'],
            'entities_soap': result['metrics']['total_soap_entities'],
            'entities_matched': result['metrics']['matched_entities'],
            'missing_critical': result['metrics']['missing_breakdown']['critical'],
            'missing_moderate': result['metrics']['missing_breakdown']['moderate'],
            'missing_low': result['metrics']['missing_breakdown']['low'],
            'processing_time': result['processing_time']
        }

        ner_results.append(flat_result)

        if not result['match_details'].empty:
            match_detail_df = result['match_details'].copy()
            match_detail_df['document_id'] = i
            all_match_details.append(match_detail_df)

        print(f"  Sample {i+1}/{NUM_SAMPLES}... ✓")

    # Save Task 1 results
    ner_results_df = pd.DataFrame(ner_results)
    ner_results_df.to_csv(NER_SUMMARY_OUTPUT, index=False)

    if all_match_details:
        matches_df = pd.concat(all_match_details, ignore_index=True)
        matches_df.to_csv(NER_DETAILS_OUTPUT, index=False)

    task1_logger.info("Task 1 complete")

    # ========== TASK 2: LYNX HALLUCINATION DETECTION ==========
    print(f"\n{'='*60}")
    print("TASK 2: LYNX HALLUCINATION DETECTION")
    print("="*60)

    print(f"\nProcessing {NUM_SAMPLES} samples (Task 2)...")
    total_bucketed = []
    per_note_metrics = []

    for i in range(NUM_SAMPLES):
        task2_logger.info(f"\n{'='*80}\nProcessing sample {i+1}/{NUM_SAMPLES}\n{'='*80}")

        transcript = dataset['train']['patient_convo'][i]
        soap_note = dataset['train']['soap_notes'][i]

        try:
            probes = extract_probes(soap_note, groq_client, task2_logger)
            evals = evaluate_probes_for_document(probes, transcript, lynx_client, groq_client, task2_logger)
            bucketed = bucket_claims(evals)

            # Compute metrics for THIS note
            note_metrics = compute_hallucination_metrics(bucketed)
            note_metrics['document_id'] = i
            per_note_metrics.append(note_metrics)

            task2_logger.info(f"Note {i+1} metrics: Supported={note_metrics['supported']}/{note_metrics['total_claims']}, "
                            f"Hallucinations={note_metrics['hallucination']}, "
                            f"Rate={note_metrics['hallucination_rate']:.1f}%")

            # Add to overall total
            total_bucketed.extend(bucketed)

            print(f"  Sample {i+1}/{NUM_SAMPLES}... ✓")
        except KeyError as e:
            task2_logger.error(f"KeyError processing sample {i+1}: {str(e)}")
            task2_logger.error(f"This likely means the Groq model didn't return the expected JSON structure")
            print(f"  Sample {i+1}/{NUM_SAMPLES}... ✗ (KeyError: {str(e)})")
        except ValueError as e:
            task2_logger.error(f"ValueError processing sample {i+1}: {str(e)}")
            print(f"  Sample {i+1}/{NUM_SAMPLES}... ✗ (ValueError: {str(e)[:100]}...)")
        except Exception as e:
            task2_logger.error(f"Error processing sample {i+1}: {str(e)}", exc_info=True)
            print(f"  Sample {i+1}/{NUM_SAMPLES}... ✗ (error: {type(e).__name__})")

    # Save Task 2 results
    if total_bucketed:
        # Save detailed claims
        write_hallucination_csv(total_bucketed, HALLUCINATION_RESULTS_CSV)

        # Save per-note metrics summary
        per_note_df = pd.DataFrame(per_note_metrics)
        per_note_summary_file = "lynx_per_note_summary.csv"
        per_note_df.to_csv(per_note_summary_file, index=False)

        task2_logger.info("Task 2 complete")
        task2_logger.info(f"Per-note summary saved to {per_note_summary_file}")
    else:
        task2_logger.warning("No claims were successfully processed in Task 2")
        print("  ⚠️  No claims processed - check logs for errors")

    # ========== DISPLAY FINAL METRICS ==========
    print("\n" + "="*60)
    print("TASK 1: NER ENTITY VALIDATION - FINAL METRICS")
    print("="*60)

    coverage_values = ner_results_df['coverage_score'].str.rstrip('%').astype(float)
    criticality_values = ner_results_df['criticality_score'].str.rstrip('%').astype(float)

    print(f"Total Documents:       {len(ner_results_df)}")
    print(f"Avg Coverage Score:    {coverage_values.mean():.1f}% ± {coverage_values.std():.1f}%")
    print(f"Avg Criticality Score: {criticality_values.mean():.1f}% ± {criticality_values.std():.1f}%")
    print(f"Missing Critical:      {ner_results_df['missing_critical'].sum()} total ({ner_results_df['missing_critical'].mean():.1f} per doc)")
    print(f"\n✓ Results: {NER_SUMMARY_OUTPUT}, {NER_DETAILS_OUTPUT}")

    print("\n" + "="*60)
    print("TASK 2: LYNX HALLUCINATION DETECTION - FINAL METRICS")
    print("="*60)

    if total_bucketed and per_note_metrics:
        # Overall metrics (aggregated across all notes)
        overall_metrics = compute_hallucination_metrics(total_bucketed)

        print(f"\n📊 OVERALL METRICS (All {len(per_note_metrics)} documents):")
        print(f"Total Claims:          {overall_metrics['total_claims']}")
        print(f"Supported:             {overall_metrics['supported']} ({overall_metrics['accuracy_rate']:.1f}%)")
        print(f"Hallucinations:        {overall_metrics['hallucination']} ({overall_metrics['hallucination_rate']:.1f}%)")
        print(f"Ambiguous/Unclear:     {overall_metrics['ambiguous'] + overall_metrics['unclear']} ({overall_metrics['ambiguity_rate']:.1f}%)")
        print(f"Overall Clarity:       {overall_metrics['overall_clarity']:.1f}%")

        # Per-note averages
        per_note_df = pd.DataFrame(per_note_metrics)
        print(f"\n📋 PER-NOTE AVERAGES:")
        print(f"Avg Claims per Note:   {per_note_df['total_claims'].mean():.1f} ± {per_note_df['total_claims'].std():.1f}")
        print(f"Avg Hallucination Rate: {per_note_df['hallucination_rate'].mean():.1f}% ± {per_note_df['hallucination_rate'].std():.1f}%")
        print(f"Avg Accuracy Rate:     {per_note_df['accuracy_rate'].mean():.1f}% ± {per_note_df['accuracy_rate'].std():.1f}%")
        print(f"Avg Clarity Score:     {per_note_df['overall_clarity'].mean():.1f}% ± {per_note_df['overall_clarity'].std():.1f}%")

        print(f"\n✓ Results:")
        print(f"  - Detailed claims: {HALLUCINATION_RESULTS_CSV}")
        print(f"  - Per-note summary: lynx_per_note_summary.csv")
    else:
        print("⚠️  No claims were successfully processed")
        print("Check the Task 2 log file for error details:")
        print(f"   {TASK2_LOG_FILE}")

    print("\n" + "="*60)
    print("✅ PIPELINE COMPLETE")
    print("="*60)
    print(f"Logs: {TASK1_LOG_FILE}, {TASK2_LOG_FILE}")
    print("="*60)

if __name__ == "__main__":
    main()

INFO:Task 1: NER:Task 1: NER - Logging Started
INFO:Task 2: LYNX:Task 2: LYNX - Logging Started


MEDICAL AI EVALUATION PIPELINE
Task 1: NER Entity Validation
Task 2: LYNX Hallucination Detection

📚 Loading dataset...
✓ Loaded 558 documents


INFO:Task 1: NER:Initializing Medical Entity Evaluator...
INFO:Task 1: NER:Loading NER model: Helios9/BioMed_NER



TASK 1: NER ENTITY VALIDATION
Initializing NER models...


Device set to use cuda:0
INFO:Task 1: NER:NER model loaded successfully
INFO:Task 1: NER:Loading embedding model: emilyalsentzer/Bio_ClinicalBERT
INFO:Task 1: NER:Using GPU acceleration
INFO:Task 1: NER:Evaluator initialized successfully
INFO:Task 1: NER:
Processing sample 1/10
DEBUG:Task 1: NER:Extracting entities from transcript (length: 2632 chars)
DEBUG:Task 1: NER:Text truncated from 2632 to 2000 chars
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
DEBUG:Task 1: NER:Found 20 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 17 unique entities
DEBUG:Task 1: NER:Extracting entities from SOAP note (length: 1694 chars)


✓ Models loaded

Processing 10 samples (Task 1)...


DEBUG:Task 1: NER:Found 37 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 25 unique entities
DEBUG:Task 1: NER:Matching 17 transcript with 25 SOAP entities
DEBUG:Task 1: NER:Matched: 13, Unmatched: 4
INFO:Task 1: NER:
Processing sample 2/10
DEBUG:Task 1: NER:Extracting entities from transcript (length: 2514 chars)
DEBUG:Task 1: NER:Text truncated from 2514 to 2000 chars
DEBUG:Task 1: NER:Found 25 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 22 unique entities
DEBUG:Task 1: NER:Extracting entities from SOAP note (length: 1127 chars)


  Sample 1/10... ✓


DEBUG:Task 1: NER:Found 34 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 25 unique entities
DEBUG:Task 1: NER:Matching 22 transcript with 25 SOAP entities
DEBUG:Task 1: NER:Matched: 17, Unmatched: 5
INFO:Task 1: NER:
Processing sample 3/10
DEBUG:Task 1: NER:Extracting entities from transcript (length: 2584 chars)
DEBUG:Task 1: NER:Text truncated from 2584 to 2000 chars


  Sample 2/10... ✓


DEBUG:Task 1: NER:Found 40 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 29 unique entities
DEBUG:Task 1: NER:Extracting entities from SOAP note (length: 1467 chars)
DEBUG:Task 1: NER:Found 66 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 58 unique entities
DEBUG:Task 1: NER:Matching 29 transcript with 58 SOAP entities
DEBUG:Task 1: NER:Matched: 24, Unmatched: 5
INFO:Task 1: NER:
Processing sample 4/10
DEBUG:Task 1: NER:Extracting entities from transcript (length: 2290 chars)
DEBUG:Task 1: NER:Text truncated from 2290 to 2000 chars
DEBUG:Task 1: NER:Found 43 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 29 unique entities
DEBUG:Task 1: NER:Extracting entities from SOAP note (length: 1333 chars)


  Sample 3/10... ✓


DEBUG:Task 1: NER:Found 44 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 32 unique entities
DEBUG:Task 1: NER:Matching 29 transcript with 32 SOAP entities
DEBUG:Task 1: NER:Matched: 22, Unmatched: 7
INFO:Task 1: NER:
Processing sample 5/10
DEBUG:Task 1: NER:Extracting entities from transcript (length: 2811 chars)
DEBUG:Task 1: NER:Text truncated from 2811 to 2000 chars
DEBUG:Task 1: NER:Found 13 entities above confidence threshold
DEBUG:Task 1: NER:Extracting entities from SOAP note (length: 2024 chars)
DEBUG:Task 1: NER:Text truncated from 2024 to 2000 chars


  Sample 4/10... ✓


DEBUG:Task 1: NER:Found 33 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 31 unique entities
DEBUG:Task 1: NER:Matching 13 transcript with 31 SOAP entities
DEBUG:Task 1: NER:Matched: 11, Unmatched: 2
INFO:Task 1: NER:
Processing sample 6/10
DEBUG:Task 1: NER:Extracting entities from transcript (length: 3023 chars)
DEBUG:Task 1: NER:Text truncated from 3023 to 2000 chars
DEBUG:Task 1: NER:Found 13 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 12 unique entities
DEBUG:Task 1: NER:Extracting entities from SOAP note (length: 1782 chars)


  Sample 5/10... ✓


DEBUG:Task 1: NER:Found 38 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 32 unique entities
DEBUG:Task 1: NER:Matching 12 transcript with 32 SOAP entities
DEBUG:Task 1: NER:Matched: 11, Unmatched: 1
INFO:Task 1: NER:
Processing sample 7/10
DEBUG:Task 1: NER:Extracting entities from transcript (length: 3058 chars)
DEBUG:Task 1: NER:Text truncated from 3058 to 2000 chars
DEBUG:Task 1: NER:Found 33 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 21 unique entities
DEBUG:Task 1: NER:Extracting entities from SOAP note (length: 1368 chars)


  Sample 6/10... ✓


DEBUG:Task 1: NER:Found 62 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 35 unique entities
DEBUG:Task 1: NER:Matching 21 transcript with 35 SOAP entities
DEBUG:Task 1: NER:Matched: 18, Unmatched: 3
INFO:Task 1: NER:
Processing sample 8/10
DEBUG:Task 1: NER:Extracting entities from transcript (length: 2892 chars)
DEBUG:Task 1: NER:Text truncated from 2892 to 2000 chars


  Sample 7/10... ✓


DEBUG:Task 1: NER:Found 34 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 31 unique entities
DEBUG:Task 1: NER:Extracting entities from SOAP note (length: 1310 chars)
DEBUG:Task 1: NER:Found 51 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 44 unique entities
DEBUG:Task 1: NER:Matching 31 transcript with 44 SOAP entities
DEBUG:Task 1: NER:Matched: 21, Unmatched: 10
INFO:Task 1: NER:
Processing sample 9/10
DEBUG:Task 1: NER:Extracting entities from transcript (length: 2115 chars)
DEBUG:Task 1: NER:Text truncated from 2115 to 2000 chars
DEBUG:Task 1: NER:Found 31 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 23 unique entities
DEBUG:Task 1: NER:Extracting entities from SOAP note (length: 1265 chars)


  Sample 8/10... ✓


DEBUG:Task 1: NER:Found 38 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 22 unique entities
DEBUG:Task 1: NER:Matching 23 transcript with 22 SOAP entities
DEBUG:Task 1: NER:Matched: 17, Unmatched: 6
INFO:Task 1: NER:
Processing sample 10/10
DEBUG:Task 1: NER:Extracting entities from transcript (length: 2548 chars)
DEBUG:Task 1: NER:Text truncated from 2548 to 2000 chars
DEBUG:Task 1: NER:Found 41 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 29 unique entities
DEBUG:Task 1: NER:Extracting entities from SOAP note (length: 1572 chars)


  Sample 9/10... ✓


DEBUG:Task 1: NER:Found 75 entities above confidence threshold
DEBUG:Task 1: NER:Deduplicated to 50 unique entities
DEBUG:Task 1: NER:Matching 29 transcript with 50 SOAP entities
DEBUG:Task 1: NER:Matched: 21, Unmatched: 8
INFO:Task 1: NER:Task 1 complete
INFO:Task 2: LYNX:
Processing sample 1/10


  Sample 10/10... ✓

TASK 2: LYNX HALLUCINATION DETECTION

Processing 10 samples (Task 2)...


INFO:Task 2: LYNX:Generated 10 claims from SOAP note
INFO:Task 2: LYNX:Evaluating 10 claims with LYNX
INFO:Task 2: LYNX:Completed 20 evaluations
INFO:Task 2: LYNX:Note 1 metrics: Supported=7/10, Hallucinations=1, Rate=10.0%
INFO:Task 2: LYNX:
Processing sample 2/10


  Sample 1/10... ✓


INFO:Task 2: LYNX:Generated 10 claims from SOAP note
INFO:Task 2: LYNX:Evaluating 10 claims with LYNX
INFO:Task 2: LYNX:Completed 20 evaluations
INFO:Task 2: LYNX:Note 2 metrics: Supported=10/10, Hallucinations=0, Rate=0.0%
INFO:Task 2: LYNX:
Processing sample 3/10


  Sample 2/10... ✓


INFO:Task 2: LYNX:Generated 10 claims from SOAP note
INFO:Task 2: LYNX:Evaluating 10 claims with LYNX
INFO:Task 2: LYNX:Completed 20 evaluations
INFO:Task 2: LYNX:Note 3 metrics: Supported=7/10, Hallucinations=2, Rate=20.0%
INFO:Task 2: LYNX:
Processing sample 4/10


  Sample 3/10... ✓


INFO:Task 2: LYNX:Generated 10 claims from SOAP note
INFO:Task 2: LYNX:Evaluating 10 claims with LYNX
INFO:Task 2: LYNX:Completed 20 evaluations
INFO:Task 2: LYNX:Note 4 metrics: Supported=8/10, Hallucinations=0, Rate=0.0%
INFO:Task 2: LYNX:
Processing sample 5/10


  Sample 4/10... ✓


INFO:Task 2: LYNX:Generated 10 claims from SOAP note
INFO:Task 2: LYNX:Evaluating 10 claims with LYNX
INFO:Task 2: LYNX:Completed 20 evaluations
INFO:Task 2: LYNX:Note 5 metrics: Supported=4/10, Hallucinations=0, Rate=0.0%
INFO:Task 2: LYNX:
Processing sample 6/10


  Sample 5/10... ✓


INFO:Task 2: LYNX:Generated 10 claims from SOAP note
INFO:Task 2: LYNX:Evaluating 10 claims with LYNX
INFO:Task 2: LYNX:Completed 20 evaluations
INFO:Task 2: LYNX:Note 6 metrics: Supported=8/10, Hallucinations=0, Rate=0.0%
INFO:Task 2: LYNX:
Processing sample 7/10


  Sample 6/10... ✓


INFO:Task 2: LYNX:Generated 10 claims from SOAP note
INFO:Task 2: LYNX:Evaluating 10 claims with LYNX
INFO:Task 2: LYNX:Completed 20 evaluations
INFO:Task 2: LYNX:Note 7 metrics: Supported=7/10, Hallucinations=0, Rate=0.0%
INFO:Task 2: LYNX:
Processing sample 8/10


  Sample 7/10... ✓


INFO:Task 2: LYNX:Generated 10 claims from SOAP note
INFO:Task 2: LYNX:Evaluating 10 claims with LYNX
INFO:Task 2: LYNX:Completed 20 evaluations
INFO:Task 2: LYNX:Note 8 metrics: Supported=9/10, Hallucinations=0, Rate=0.0%
INFO:Task 2: LYNX:
Processing sample 9/10


  Sample 8/10... ✓


INFO:Task 2: LYNX:Generated 10 claims from SOAP note
INFO:Task 2: LYNX:Evaluating 10 claims with LYNX
INFO:Task 2: LYNX:Completed 20 evaluations
INFO:Task 2: LYNX:Note 9 metrics: Supported=10/10, Hallucinations=0, Rate=0.0%
INFO:Task 2: LYNX:
Processing sample 10/10


  Sample 9/10... ✓


INFO:Task 2: LYNX:Generated 10 claims from SOAP note
INFO:Task 2: LYNX:Evaluating 10 claims with LYNX
INFO:Task 2: LYNX:Completed 20 evaluations
INFO:Task 2: LYNX:Note 10 metrics: Supported=8/10, Hallucinations=0, Rate=0.0%
INFO:Task 2: LYNX:Task 2 complete
INFO:Task 2: LYNX:Per-note summary saved to lynx_per_note_summary.csv


  Sample 10/10... ✓

TASK 1: NER ENTITY VALIDATION - FINAL METRICS
Total Documents:       10
Avg Coverage Score:    78.8% ± 7.2%
Avg Criticality Score: 81.0% ± 6.2%
Missing Critical:      0 total (0.0 per doc)

✓ Results: ner_evaluation_summary.csv, ner_entity_matches.csv

TASK 2: LYNX HALLUCINATION DETECTION - FINAL METRICS

📊 OVERALL METRICS (All 10 documents):
Total Claims:          100
Supported:             78 (78.0%)
Hallucinations:        3 (3.0%)
Ambiguous/Unclear:     19 (19.0%)
Overall Clarity:       68.5%

📋 PER-NOTE AVERAGES:
Avg Claims per Note:   10.0 ± 0.0
Avg Hallucination Rate: 3.0% ± 6.7%
Avg Accuracy Rate:     78.0% ± 17.5%
Avg Clarity Score:     68.5% ± 25.7%

✓ Results:
  - Detailed claims: lynx_hallucination_results.csv
  - Per-note summary: lynx_per_note_summary.csv

✅ PIPELINE COMPLETE
Logs: task1_ner_evaluation.log, task2_lynx_evaluation.log


#### In parallel

In [4]:
### PARALLEL VERSION - Task 1 & Task 2 run simultaneously
### Clean console output - debug only in log files

import pandas as pd
import numpy as np
import torch
import logging
import time
import json
import csv
from concurrent.futures import ThreadPoolExecutor, as_completed
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, pipeline
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, List, Tuple, Any
from openai import OpenAI
from groq import Groq
import warnings
from google.colab import userdata
warnings.filterwarnings('ignore')

Groq_DeepScribe = userdata.get('Groq_DeepScribe')
HugFace_DeepScribe = userdata.get('HugFace_DeepScribe')

# ================================================================================
# CENTRALIZED CONFIGURATION
# ================================================================================

# Dataset & Processing
NUM_SAMPLES = 10

# Task 1: NER Entity Validation
NER_MODEL = "Helios9/BioMed_NER"
EMBEDDING_MODEL = "emilyalsentzer/Bio_ClinicalBERT"
CONFIDENCE_THRESHOLD = 0.5
SIMILARITY_THRESHOLD = 0.7
MAX_TEXT_LENGTH = 2000

# Task 1 Outputs
NER_SUMMARY_OUTPUT = "ner_evaluation_summary.csv"
NER_DETAILS_OUTPUT = "ner_entity_matches.csv"
TASK1_LOG_FILE = "task1_ner_evaluation.log"

# Task 2: LYNX Hallucination Detection
LYNX_MODEL = "PatronusAI/Llama-3-Patronus-Lynx-8B-Instruct:featherless-ai"
GROQ_MODEL = "openai/gpt-oss-20b"
MAX_CONCURRENT = 20
RETRY_MAX = 3
RETRY_BACKOFF_BASE = 1.0

# Task 2 Outputs
HALLUCINATION_RESULTS_CSV = "lynx_hallucination_results.csv"
TASK2_LOG_FILE = "task2_lynx_evaluation.log"

# API Keys
HugFace_DeepScribe = HugFace_DeepScribe
Groq_DeepScribe = Groq_DeepScribe

# Entity type weights for criticality scoring
ENTITY_WEIGHTS = {
    'MEDICATION': 1.0, 'DRUG': 1.0,
    'DIAGNOSIS': 0.9, 'DISEASE': 0.9, 'DISORDER': 0.9,
    'PROCEDURE': 0.7, 'TEST': 0.7, 'TREATMENT': 0.7,
    'SYMPTOM': 0.5, 'SIGN': 0.5,
    'ANATOMY': 0.3, 'OTHER': 0.3
}

# ================================================================================
# LOGGING SETUP - FILE ONLY, NO CONSOLE OUTPUT
# ================================================================================

def setup_logging(log_file: str, task_name: str):
    """Configure logging for a specific task - FILE ONLY, no console output"""
    logger = logging.getLogger(task_name)
    logger.setLevel(logging.DEBUG)
    logger.handlers = []  # Clear existing handlers
    logger.propagate = False  # CRITICAL: Prevent propagation to root logger

    file_handler = logging.FileHandler(log_file, mode='w', encoding='utf-8')
    file_handler.setLevel(logging.DEBUG)
    file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(file_formatter)
    logger.addHandler(file_handler)

    logger.info("="*80)
    logger.info(f"{task_name} - Logging Started")
    logger.info("="*80)

    return logger

# ================================================================================
# TASK 1: NER ENTITY VALIDATION (same as before, just copied here)
# ================================================================================

class MedicalEntityEvaluator:
    """NER-based evaluator for detecting missing entities"""

    def __init__(self, logger):
        self.logger = logger
        self.logger.info("Initializing Medical Entity Evaluator...")

        self.logger.info(f"Loading NER model: {NER_MODEL}")
        self.ner_pipeline = pipeline(
            "token-classification",
            model=NER_MODEL,
            aggregation_strategy="simple"
        )
        self.logger.info("NER model loaded successfully")

        self.logger.info(f"Loading embedding model: {EMBEDDING_MODEL}")
        self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
        self.embedding_model = AutoModel.from_pretrained(EMBEDDING_MODEL)

        if torch.cuda.is_available():
            self.embedding_model = self.embedding_model.cuda()
            self.logger.info("Using GPU acceleration")
        else:
            self.logger.info("Using CPU")

        self.entity_weights = ENTITY_WEIGHTS
        self.logger.info("Evaluator initialized successfully")

    def extract_entities(self, text: str, source_type: str = "text") -> List[Dict]:
        """Extract medical entities using NER model"""
        self.logger.debug(f"Extracting entities from {source_type} (length: {len(text)} chars)")

        original_length = len(text)
        if original_length > MAX_TEXT_LENGTH:
            text = text[:MAX_TEXT_LENGTH]
            self.logger.debug(f"Text truncated from {original_length} to {MAX_TEXT_LENGTH} chars")

        entities = []
        try:
            ner_results = self.ner_pipeline(text)

            for entity in ner_results:
                confidence = entity.get('score', 0.0)
                if confidence >= CONFIDENCE_THRESHOLD:
                    entity_dict = {
                        'text': entity.get('word', '').strip(),
                        'type': entity.get('entity_group', 'UNKNOWN').upper(),
                        'confidence': round(confidence, 3),
                        'start': entity.get('start', 0),
                        'end': entity.get('end', 0)
                    }
                    entity_dict['weight'] = self.entity_weights.get(entity_dict['type'], 0.3)
                    entities.append(entity_dict)

            self.logger.debug(f"Found {len(entities)} entities above confidence threshold")
        except Exception as e:
            self.logger.error(f"Error extracting entities: {e}")

        # Remove duplicates
        seen = set()
        unique_entities = []
        for ent in entities:
            key = (ent['text'].lower(), ent['type'])
            if key not in seen and len(ent['text']) > 1:
                seen.add(key)
                unique_entities.append(ent)

        if len(unique_entities) != len(entities):
            self.logger.debug(f"Deduplicated to {len(unique_entities)} unique entities")

        return unique_entities

    def get_embedding(self, text: str) -> np.ndarray:
        """Get BioClinicalBERT embedding for text"""
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=128, padding=True)
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.embedding_model(**inputs)
            embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        return embedding

    def match_entities_semantically(self, transcript_entities: List[Dict],
                                   soap_entities: List[Dict]) -> Tuple[Dict, List[Dict], pd.DataFrame]:
        """Match entities using semantic similarity"""
        self.logger.debug(f"Matching {len(transcript_entities)} transcript with {len(soap_entities)} SOAP entities")

        matches = {}
        unmatched = []
        match_details = []

        if not transcript_entities or not soap_entities:
            self.logger.warning("One or both entity lists empty - no matching performed")
            return matches, transcript_entities.copy() if transcript_entities else [], pd.DataFrame()

        for trans_ent in transcript_entities:
            best_match = None
            best_similarity = 0.0

            trans_emb = self.get_embedding(trans_ent['text'])

            for soap_ent in soap_entities:
                soap_emb = self.get_embedding(soap_ent['text'])
                similarity = cosine_similarity(trans_emb, soap_emb)[0][0]

                if similarity > best_similarity and similarity >= SIMILARITY_THRESHOLD:
                    best_similarity = similarity
                    best_match = soap_ent

            if best_match:
                matches[trans_ent['text']] = {
                    'transcript_entity': trans_ent,
                    'soap_entity': best_match,
                    'similarity': round(float(best_similarity), 3)
                }
                match_details.append({
                    'transcript_entity': trans_ent['text'],
                    'transcript_type': trans_ent['type'],
                    'soap_entity': best_match['text'],
                    'soap_type': best_match['type'],
                    'similarity': round(float(best_similarity), 3),
                    'status': 'Matched'
                })
            else:
                unmatched.append(trans_ent)
                match_details.append({
                    'transcript_entity': trans_ent['text'],
                    'transcript_type': trans_ent['type'],
                    'soap_entity': '',
                    'soap_type': '',
                    'similarity': 0.0,
                    'status': 'Missing in SOAP'
                })

        matches_df = pd.DataFrame(match_details)
        self.logger.debug(f"Matched: {len(matches)}, Unmatched: {len(unmatched)}")

        return matches, unmatched, matches_df

    def evaluate_single_pair(self, transcript: str, soap_note: str) -> Dict:
        """Evaluate a single transcript-SOAP pair"""
        start_time = time.time()

        # Extract entities
        transcript_entities = self.extract_entities(transcript, "transcript")
        soap_entities = self.extract_entities(soap_note, "SOAP note")

        # Match entities
        matches, unmatched, match_details_df = self.match_entities_semantically(
            transcript_entities, soap_entities
        )

        # Calculate coverage
        total_transcript = len(transcript_entities)
        matched = len(matches)
        coverage_score = (matched / total_transcript * 100) if total_transcript > 0 else 0.0

        # Calculate criticality score (weighted)
        if total_transcript > 0:
            total_weight = sum(ent['weight'] for ent in transcript_entities)
            matched_weight = sum(matches[key]['transcript_entity']['weight'] for key in matches.keys())
            criticality_score = (matched_weight / total_weight * 100) if total_weight > 0 else 0.0
        else:
            criticality_score = 0.0

        # Missing breakdown
        missing_breakdown = {'critical': 0, 'moderate': 0, 'low': 0}
        for ent in unmatched:
            weight = ent['weight']
            if weight >= 0.9:
                missing_breakdown['critical'] += 1
            elif weight >= 0.5:
                missing_breakdown['moderate'] += 1
            else:
                missing_breakdown['low'] += 1

        processing_time = round(time.time() - start_time, 2)

        return {
            'metrics': {
                'coverage_score': round(coverage_score, 1),
                'criticality_score': round(criticality_score, 1),
                'extraction_confidence': round(np.mean([e['confidence'] for e in transcript_entities]), 3) if transcript_entities else 0.0,
                'total_transcript_entities': total_transcript,
                'total_soap_entities': len(soap_entities),
                'matched_entities': matched,
                'missing_breakdown': missing_breakdown
            },
            'match_details': match_details_df,
            'processing_time': processing_time
        }

# ================================================================================
# TASK 2: LYNX HALLUCINATION DETECTION (keeping original functions)
# ================================================================================

def clean_json_with_llm(raw_text: str, groq_client, logger) -> dict:
    """Use Groq to clean messy JSON"""
    prompt = f"""Extract the JSON object from this text. Return ONLY valid JSON with REASONING and SCORE fields.

Text: {raw_text}

Return format:
{{"REASONING": ["reason1", "reason2"], "SCORE": "PASS"}}
or
{{"REASONING": ["reason1", "reason2"], "SCORE": "FAIL"}}"""

    try:
        resp = groq_client.chat.completions.create(
            model=GROQ_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            max_completion_tokens=10000,
            response_format={"type": "json_object"}
        )
        cleaned = json.loads(resp.choices[0].message.content)
        return cleaned
    except Exception as e:
        logger.error(f"JSON cleaning failed: {e}")
        return {"REASONING": ["Failed to parse"], "SCORE": "FAIL"}

def call_lynx(document: str, question: str, answer: str, lynx_client, groq_client, logger,
              retry_count=0) -> dict:
    """Call LYNX model to evaluate if answer is supported"""
    prompt = f"""[INST] <<SYS>> You are a helpful assistant. Given a QUESTION, ANSWER, and DOCUMENT, your job is to determine whether the ANSWER is faithful to the DOCUMENT. Your answer should be a valid JSON object with two fields: REASONING and SCORE. REASONING should be a list of strings. SCORE should be either "PASS" or "FAIL". <</SYS>>

QUESTION: {question}
ANSWER: {answer}
DOCUMENT: {document} [/INST]"""

    try:
        resp = lynx_client.chat.completions.create(
            model=LYNX_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            max_tokens=4000
        )

        raw_content = resp.choices[0].message.content.strip()

        try:
            result = json.loads(raw_content)
            if "REASONING" in result and "SCORE" in result:
                return result
            else:
                logger.warning("Missing required fields, attempting to clean")
                return clean_json_with_llm(raw_content, groq_client, logger)
        except json.JSONDecodeError:
            logger.warning(f"JSON decode error, cleaning response")
            return clean_json_with_llm(raw_content, groq_client, logger)

    except Exception as e:
        logger.error(f"LYNX call failed: {e}")
        if retry_count < RETRY_MAX:
            wait_time = RETRY_BACKOFF_BASE * (2 ** retry_count)
            time.sleep(wait_time)
            return call_lynx(document, question, answer, lynx_client, groq_client, logger, retry_count + 1)
        return {"REASONING": ["Error occurred"], "SCORE": "FAIL"}

def extract_probes(soap_note: str, groq_client, logger) -> List[Dict]:
    """Extract claims and probes from SOAP note"""
    prompt = f"""From this SOAP note, extract exactly 10 atomic clinical claims. For each claim, create:
1. A positive probe (q_pos, a_pos) that would be answered YES if claim is true
2. A negative probe (q_neg, a_neg) that would be answered YES if claim is false

SOAP Note:
{soap_note}

Return JSON:
{{
  "claims": [
    {{
      "claim": "...",
      "section": "S/O/A/P",
      "category": "medications/diagnoses/symptoms/vitals/procedures/family_history/social_history/allergies/imaging/lab_results/physical_exam/other",
      "q_pos": "Does ...",
      "a_pos": "Yes, ...",
      "q_neg": "Does patient deny/not have ...",
      "a_neg": "Yes, patient denies/does not have ..."
    }}
  ]
}}"""

    resp = groq_client.chat.completions.create(
        model=GROQ_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.2,
        top_p=1.0,
        max_completion_tokens=10000,
        stream=False,
        response_format={"type": "json_object"}
    )

    result = json.loads(resp.choices[0].message.content)
    return result.get("claims", [])

def evaluate_probes_for_document(probes: List[Dict], transcript: str, lynx_client,
                                groq_client, logger) -> List[Dict]:
    """Evaluate all probes for a document using LYNX"""
    results = []

    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as executor:
        futures = {}

        for probe in probes:
            # Submit positive probe
            future_pos = executor.submit(
                call_lynx, transcript, probe['q_pos'], probe['a_pos'],
                lynx_client, groq_client, logger
            )
            # Submit negative probe
            future_neg = executor.submit(
                call_lynx, transcript, probe['q_neg'], probe['a_neg'],
                lynx_client, groq_client, logger
            )

            futures[future_pos] = ('pos', probe)
            futures[future_neg] = ('neg', probe)

        # Collect results
        probe_results = {}
        for future in as_completed(futures):
            probe_type, probe = futures[future]
            result = future.result()

            claim_text = probe['claim']
            if claim_text not in probe_results:
                probe_results[claim_text] = {'probe': probe, 'pos': None, 'neg': None}

            probe_results[claim_text][probe_type] = result

        # Combine pos and neg results
        for claim_text, data in probe_results.items():
            results.append({
                'claim': claim_text,
                'section': data['probe']['section'],
                'category': data['probe']['category'],
                'q_pos': data['probe']['q_pos'],
                'a_pos': data['probe']['a_pos'],
                'q_neg': data['probe']['q_neg'],
                'a_neg': data['probe']['a_neg'],
                'result_pos': data['pos'],
                'result_neg': data['neg']
            })

    return results

def bucket_claims(evals: List[Dict]) -> List[Dict]:
    """Bucket claims into Supported/Hallucination/Ambiguous/Unclear"""
    bucketed = []

    for ev in evals:
        pos_score = ev['result_pos'].get('SCORE', 'FAIL')
        neg_score = ev['result_neg'].get('SCORE', 'FAIL')

        if pos_score == "PASS" and neg_score == "FAIL":
            status = "Supported"
        elif pos_score == "FAIL" and neg_score == "PASS":
            status = "Hallucination"
        elif pos_score == "PASS" and neg_score == "PASS":
            status = "Ambiguous"
        else:
            status = "Unclear"

        bucketed.append({
            'claim': ev['claim'],
            'section': ev['section'],
            'category': ev['category'],
            'score_pos': pos_score,
            'score_neg': neg_score,
            'status': status,
            'reasoning_pos': ', '.join(ev['result_pos'].get('REASONING', [])),
            'reasoning_neg': ', '.join(ev['result_neg'].get('REASONING', []))
        })

    return bucketed

def compute_hallucination_metrics(bucketed_claims: List[Dict]) -> Dict:
    """Compute hallucination metrics"""
    total = len(bucketed_claims)
    if total == 0:
        return {
            'total_claims': 0,
            'supported': 0,
            'hallucination': 0,
            'ambiguous': 0,
            'unclear': 0,
            'hallucination_rate': 0.0,
            'accuracy_rate': 0.0,
            'ambiguity_rate': 0.0,
            'overall_clarity': 0.0
        }

    supported = sum(1 for c in bucketed_claims if c['status'] == 'Supported')
    hallucination = sum(1 for c in bucketed_claims if c['status'] == 'Hallucination')
    ambiguous = sum(1 for c in bucketed_claims if c['status'] == 'Ambiguous')
    unclear = sum(1 for c in bucketed_claims if c['status'] == 'Unclear')

    hallucination_rate = (hallucination / total) * 100
    accuracy_rate = (supported / total) * 100
    ambiguity_rate = ((ambiguous + unclear) / total) * 100
    overall_clarity = accuracy_rate - (ambiguity_rate * 0.5)

    return {
        'total_claims': total,
        'supported': supported,
        'hallucination': hallucination,
        'ambiguous': ambiguous,
        'unclear': unclear,
        'hallucination_rate': round(hallucination_rate, 1),
        'accuracy_rate': round(accuracy_rate, 1),
        'ambiguity_rate': round(ambiguity_rate, 1),
        'overall_clarity': round(overall_clarity, 1)
    }

def write_hallucination_csv(bucketed: List[Dict], filename: str):
    """Write hallucination results to CSV"""
    if not bucketed:
        return

    with open(filename, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=bucketed[0].keys())
        writer.writeheader()
        writer.writerows(bucketed)

# ================================================================================
# PARALLEL EXECUTION FUNCTIONS
# ================================================================================

def run_task1(dataset, task1_logger):
    """Run Task 1 on all samples"""
    print(f"[Task 1] Initializing NER models...")
    evaluator = MedicalEntityEvaluator(task1_logger)
    print(f"[Task 1] ✓ Models loaded, starting processing...")

    ner_results = []
    all_match_details = []

    for i in range(NUM_SAMPLES):
        task1_logger.info(f"\n{'='*80}\nProcessing sample {i+1}/{NUM_SAMPLES}\n{'='*80}")

        transcript = dataset['train']['patient_convo'][i]
        soap_note = dataset['train']['soap_notes'][i]

        result = evaluator.evaluate_single_pair(transcript, soap_note)

        flat_result = {
            'document_id': i,
            'coverage_score': f"{result['metrics']['coverage_score']}%",
            'criticality_score': f"{result['metrics']['criticality_score']}%",
            'extraction_confidence': result['metrics']['extraction_confidence'],
            'entities_transcript': result['metrics']['total_transcript_entities'],
            'entities_soap': result['metrics']['total_soap_entities'],
            'entities_matched': result['metrics']['matched_entities'],
            'missing_critical': result['metrics']['missing_breakdown']['critical'],
            'missing_moderate': result['metrics']['missing_breakdown']['moderate'],
            'missing_low': result['metrics']['missing_breakdown']['low'],
            'processing_time': result['processing_time']
        }

        ner_results.append(flat_result)

        if not result['match_details'].empty:
            match_detail_df = result['match_details'].copy()
            match_detail_df['document_id'] = i
            all_match_details.append(match_detail_df)

        print(f"[Task 1] Sample {i+1}/{NUM_SAMPLES}... ✓")

    # Save results
    ner_results_df = pd.DataFrame(ner_results)
    ner_results_df.to_csv(NER_SUMMARY_OUTPUT, index=False)

    if all_match_details:
        matches_df = pd.concat(all_match_details, ignore_index=True)
        matches_df.to_csv(NER_DETAILS_OUTPUT, index=False)

    task1_logger.info("Task 1 complete")
    print(f"[Task 1] ✅ Complete!")

    return ner_results_df

def run_task2(dataset, lynx_client, groq_client, task2_logger):
    """Run Task 2 on all samples"""
    print(f"[Task 2] Starting LYNX hallucination detection...")

    total_bucketed = []
    per_note_metrics = []

    for i in range(NUM_SAMPLES):
        task2_logger.info(f"\n{'='*80}\nProcessing sample {i+1}/{NUM_SAMPLES}\n{'='*80}")

        transcript = dataset['train']['patient_convo'][i]
        soap_note = dataset['train']['soap_notes'][i]

        try:
            probes = extract_probes(soap_note, groq_client, task2_logger)
            evals = evaluate_probes_for_document(probes, transcript, lynx_client, groq_client, task2_logger)
            bucketed = bucket_claims(evals)

            # Compute metrics for THIS note
            note_metrics = compute_hallucination_metrics(bucketed)
            note_metrics['document_id'] = i
            per_note_metrics.append(note_metrics)

            task2_logger.info(f"Note {i+1} metrics: Supported={note_metrics['supported']}/{note_metrics['total_claims']}, "
                            f"Hallucinations={note_metrics['hallucination']}, "
                            f"Rate={note_metrics['hallucination_rate']:.1f}%")

            # Add to overall total
            total_bucketed.extend(bucketed)

            print(f"[Task 2] Sample {i+1}/{NUM_SAMPLES}... ✓")
        except KeyError as e:
            task2_logger.error(f"KeyError processing sample {i+1}: {str(e)}")
            print(f"[Task 2] Sample {i+1}/{NUM_SAMPLES}... ✗ (KeyError)")
        except ValueError as e:
            task2_logger.error(f"ValueError processing sample {i+1}: {str(e)}")
            print(f"[Task 2] Sample {i+1}/{NUM_SAMPLES}... ✗ (ValueError)")
        except Exception as e:
            task2_logger.error(f"Error processing sample {i+1}: {str(e)}", exc_info=True)
            print(f"[Task 2] Sample {i+1}/{NUM_SAMPLES}... ✗ (error)")

    # Save results
    if total_bucketed:
        write_hallucination_csv(total_bucketed, HALLUCINATION_RESULTS_CSV)

        per_note_df = pd.DataFrame(per_note_metrics)
        per_note_summary_file = "lynx_per_note_summary.csv"
        per_note_df.to_csv(per_note_summary_file, index=False)

        task2_logger.info("Task 2 complete")
        task2_logger.info(f"Per-note summary saved to {per_note_summary_file}")
        print(f"[Task 2] ✅ Complete!")

        return per_note_df, total_bucketed
    else:
        task2_logger.warning("No claims were successfully processed")
        print(f"[Task 2] ⚠️  No claims processed")
        return None, []

# ================================================================================
# MAIN - PARALLEL EXECUTION
# ================================================================================

def main():
    print("="*60)
    print("MEDICAL AI EVALUATION PIPELINE - PARALLEL EXECUTION")
    print("="*60)
    print("Task 1: NER Entity Validation")
    print("Task 2: LYNX Hallucination Detection")
    print("="*60)

    # Setup logging
    task1_logger = setup_logging(TASK1_LOG_FILE, "Task 1: NER")
    task2_logger = setup_logging(TASK2_LOG_FILE, "Task 2: LYNX")

    # Load dataset ONCE
    print(f"\n📚 Loading dataset...")
    dataset = load_dataset("adesouza1/soap_notes")
    print(f"✓ Loaded {len(dataset['train'])} documents")

    # Initialize API clients for Task 2
    lynx_client = OpenAI(api_key=HugFace_DeepScribe, base_url="https://router.huggingface.co/v1")
    groq_client = Groq(api_key=Groq_DeepScribe)

    # ========== RUN BOTH TASKS IN PARALLEL ==========
    print(f"\n{'='*60}")
    print("STARTING PARALLEL EXECUTION")
    print("="*60)
    print(f"Processing {NUM_SAMPLES} samples with both tasks running simultaneously...")
    print()

    start_time = time.time()

    with ThreadPoolExecutor(max_workers=2) as executor:
        # Submit both tasks
        task1_future = executor.submit(run_task1, dataset, task1_logger)
        task2_future = executor.submit(run_task2, dataset, lynx_client, groq_client, task2_logger)

        # Wait for both to complete
        ner_results_df = task1_future.result()
        per_note_df, total_bucketed = task2_future.result()

    total_time = time.time() - start_time

    print(f"\n{'='*60}")
    print("BOTH TASKS COMPLETE")
    print(f"Total parallel execution time: {total_time:.1f} seconds")
    print("="*60)

    # ========== DISPLAY FINAL METRICS ==========
    print("\n" + "="*60)
    print("TASK 1: NER ENTITY VALIDATION - FINAL METRICS")
    print("="*60)

    coverage_values = ner_results_df['coverage_score'].str.rstrip('%').astype(float)
    criticality_values = ner_results_df['criticality_score'].str.rstrip('%').astype(float)

    print(f"Total Documents:       {len(ner_results_df)}")
    print(f"Avg Coverage Score:    {coverage_values.mean():.1f}% ± {coverage_values.std():.1f}%")
    print(f"Avg Criticality Score: {criticality_values.mean():.1f}% ± {criticality_values.std():.1f}%")
    print(f"Missing Critical:      {ner_results_df['missing_critical'].sum()} total ({ner_results_df['missing_critical'].mean():.1f} per doc)")
    print(f"\n✓ Results: {NER_SUMMARY_OUTPUT}, {NER_DETAILS_OUTPUT}")

    print("\n" + "="*60)
    print("TASK 2: LYNX HALLUCINATION DETECTION - FINAL METRICS")
    print("="*60)

    if total_bucketed and per_note_df is not None:
        # Overall metrics
        overall_metrics = compute_hallucination_metrics(total_bucketed)

        print(f"\n📊 OVERALL METRICS (All {len(per_note_df)} documents):")
        print(f"Total Claims:          {overall_metrics['total_claims']}")
        print(f"Supported:             {overall_metrics['supported']} ({overall_metrics['accuracy_rate']:.1f}%)")
        print(f"Hallucinations:        {overall_metrics['hallucination']} ({overall_metrics['hallucination_rate']:.1f}%)")
        print(f"Ambiguous/Unclear:     {overall_metrics['ambiguous'] + overall_metrics['unclear']} ({overall_metrics['ambiguity_rate']:.1f}%)")
        print(f"Overall Clarity:       {overall_metrics['overall_clarity']:.1f}%")

        # Per-note averages
        print(f"\n📋 PER-NOTE AVERAGES:")
        print(f"Avg Claims per Note:    {per_note_df['total_claims'].mean():.1f} ± {per_note_df['total_claims'].std():.1f}")
        print(f"Avg Hallucination Rate: {per_note_df['hallucination_rate'].mean():.1f}% ± {per_note_df['hallucination_rate'].std():.1f}%")
        print(f"Avg Accuracy Rate:      {per_note_df['accuracy_rate'].mean():.1f}% ± {per_note_df['accuracy_rate'].std():.1f}%")
        print(f"Avg Clarity Score:      {per_note_df['overall_clarity'].mean():.1f}% ± {per_note_df['overall_clarity'].std():.1f}%")

        print(f"\n✓ Results:")
        print(f"  - Detailed claims: {HALLUCINATION_RESULTS_CSV}")
        print(f"  - Per-note summary: lynx_per_note_summary.csv")
    else:
        print("⚠️  No claims were successfully processed")
        print(f"Check log: {TASK2_LOG_FILE}")

    print("\n" + "="*60)
    print("✅ PIPELINE COMPLETE")
    print("="*60)
    print(f"Total Time: {total_time:.1f} seconds (parallel execution)")
    print(f"Logs: {TASK1_LOG_FILE}, {TASK2_LOG_FILE}")
    print("="*60)

if __name__ == "__main__":
    main()

MEDICAL AI EVALUATION PIPELINE - PARALLEL EXECUTION
Task 1: NER Entity Validation
Task 2: LYNX Hallucination Detection

📚 Loading dataset...
✓ Loaded 558 documents

STARTING PARALLEL EXECUTION
Processing 10 samples with both tasks running simultaneously...

[Task 1] Initializing NER models...
[Task 2] Starting LYNX hallucination detection...


Device set to use cuda:0
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[Task 1] ✓ Models loaded, starting processing...
[Task 1] Sample 1/10... ✓
[Task 2] Sample 1/10... ✓
[Task 1] Sample 2/10... ✓
[Task 2] Sample 2/10... ✓
[Task 1] Sample 3/10... ✓
[Task 2] Sample 3/10... ✓
[Task 1] Sample 4/10... ✓
[Task 1] Sample 5/10... ✓
[Task 1] Sample 6/10... ✓
[Task 2] Sample 4/10... ✓
[Task 1] Sample 7/10... ✓
[Task 1] Sample 8/10... ✓
[Task 2] Sample 5/10... ✓
[Task 1] Sample 9/10... ✓
[Task 2] Sample 6/10... ✓
[Task 1] Sample 10/10... ✓
[Task 1] ✅ Complete!
[Task 2] Sample 7/10... ✓
[Task 2] Sample 8/10... ✓
[Task 2] Sample 9/10... ✓
[Task 2] Sample 10/10... ✓
[Task 2] ✅ Complete!

BOTH TASKS COMPLETE
Total parallel execution time: 134.7 seconds

TASK 1: NER ENTITY VALIDATION - FINAL METRICS
Total Documents:       10
Avg Coverage Score:    99.6% ± 1.4%
Avg Criticality Score: 99.6% ± 1.2%
Missing Critical:      0 total (0.0 per doc)

✓ Results: ner_evaluation_summary.csv, ner_entity_matches.csv

TASK 2: LYNX HALLUCINATION DETECTION - FINAL METRICS

📊 OVERALL MET