# In [None]:
import re
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from transformers import BartForConditionalGeneration, BartTokenizer
from transformers import RobertaForQuestionAnswering, RobertaTokenizer

class FactualConsistencyEvaluator:
    """
    QA-based factual consistency evaluation pipeline.

    Questions are generated from a generated note with a BART
    seq2seq model, then each question is answered twice with an extractive
    RoBERTa QA model: once against the source clinical dialogue and once
    against the note itself. A question counts as "consistent" when the token
    F1 between the two answers reaches a threshold.

    NOTE(review): the default ``qg_model_name`` is the plain
    ``facebook/bart-large`` checkpoint; meaningful questions presumably
    require a checkpoint fine-tuned for question generation -- confirm the
    prompt format used in ``generate_questions`` matches that fine-tuning.
    """

    # Matches the "question:" marker in any letter case so it can be stripped
    # from decoded generations.
    _QUESTION_PREFIX = re.compile(r"question:", re.IGNORECASE)

    # Sentence punctuation trimmed from token edges before answers are
    # compared. Characters such as "/", "%", "-" and "+" are deliberately
    # kept so values like "120/80" or "50%" are not altered.
    _EDGE_PUNCTUATION = ".,;:!?\"'()[]{}"

    def __init__(self,
                 qg_model_name: str = "facebook/bart-large",
                 qa_model_name: str = "deepset/roberta-base-squad2",
                 device: Optional[str] = None):
        """
        Load both models and move them to the target device in eval mode.

        Args:
            qg_model_name: Name/path of the BART model for question generation
            qa_model_name: Name/path of the RoBERTa model for QA
            device: Device to run models on ('cuda', 'cpu', or None for auto)
        """
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')

        # Question generation model
        print(f"Loading QG model: {qg_model_name}")
        self.qg_tokenizer = BartTokenizer.from_pretrained(qg_model_name)
        self.qg_model = BartForConditionalGeneration.from_pretrained(qg_model_name)
        self.qg_model.to(self.device)
        self.qg_model.eval()

        # Extractive question answering model
        print(f"Loading QA model: {qa_model_name}")
        self.qa_tokenizer = RobertaTokenizer.from_pretrained(qa_model_name)
        self.qa_model = RobertaForQuestionAnswering.from_pretrained(qa_model_name)
        self.qa_model.to(self.device)
        self.qa_model.eval()

    def generate_questions(self, generated_note: str, num_questions: int = 5) -> List[str]:
        """
        Generate factual questions from the generated note.

        Args:
            generated_note: The generated clinical note
            num_questions: Maximum number of questions to generate

        Returns:
            De-duplicated, non-empty questions in generation order (at most
            ``num_questions``). Empty when the note is blank or
            ``num_questions`` is below 1; the model is not called then.
        """
        if num_questions < 1 or not generated_note.strip():
            return []

        prompt = f"Generate {num_questions} factual questions based on this clinical note: {generated_note}"

        # Input is truncated to 512 tokens, so long notes are only partly seen.
        inputs = self.qg_tokenizer(prompt,
                                   max_length=512,
                                   truncation=True,
                                   return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.qg_model.generate(
                **inputs,
                max_length=200,
                # Beam search cannot return more sequences than it has beams,
                # so widen the beam when more questions are requested.
                num_beams=max(4, num_questions),
                num_return_sequences=num_questions,
                early_stopping=True
            )

        questions = []
        for output in outputs:
            text = self.qg_tokenizer.decode(output, skip_special_tokens=True)
            # Keep only what follows the last "question:" marker, if any.
            text = self._QUESTION_PREFIX.split(text)[-1].strip()
            if text:
                questions.append(text)

        # dict.fromkeys removes duplicates while keeping a deterministic order.
        return list(dict.fromkeys(questions))[:num_questions]

    def extract_answer(self,
                      context: str,
                      question: str,
                      model_type: str = 'qa') -> Tuple[str, float]:
        """
        Extract an answer span from ``context`` using the QA model.

        The question/context pair is truncated to 512 tokens, so text beyond
        that limit cannot contribute to the answer.

        Args:
            context: Text to search for answer (dialogue or generated note)
            question: Question to answer
            model_type: Which model to use (only 'qa' is supported)

        Returns:
            Tuple of (answer_text, confidence_score). The confidence is the
            mean of the selected start and end logits (an unnormalised score,
            not a probability). The answer is an empty string when the
            selected end position precedes the start position or the span
            holds only special tokens.

        Raises:
            ValueError: If ``model_type`` is not 'qa'.
        """
        if model_type != 'qa':
            raise ValueError(f"Unknown model type: {model_type}")
        tokenizer, model = self.qa_tokenizer, self.qa_model

        inputs = tokenizer(question,
                           context,
                           max_length=512,
                           truncation=True,
                           return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Single example per call, so work on row 0 of the logits.
        start_logits = outputs.start_logits[0]
        end_logits = outputs.end_logits[0]
        start_idx = int(torch.argmax(start_logits))
        end_idx = int(torch.argmax(end_logits))

        confidence = (start_logits[start_idx].item() +
                      end_logits[end_idx].item()) / 2

        # Inclusive end index; an inverted span yields an empty slice.
        answer_tokens = inputs['input_ids'][0][start_idx:end_idx + 1]
        answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)

        return answer.strip(), confidence

    def _answer_tokens(self, text: str) -> set:
        """Return the lower-cased, edge-punctuation-stripped word set of ``text``."""
        trimmed = (tok.strip(self._EDGE_PUNCTUATION) for tok in text.lower().split())
        return {tok for tok in trimmed if tok}

    def calculate_f1(self, pred_answer: str, true_answer: str) -> float:
        """
        Calculate a token-set F1 score between predicted and true answers.

        Answers are lower-cased, split on whitespace, and sentence
        punctuation is trimmed from token edges so that "pain." matches
        "pain". Tokens are compared as sets (repeats are ignored).

        Args:
            pred_answer: Predicted answer
            true_answer: True answer

        Returns:
            F1 score (0-1); 0.0 when either answer has no tokens.
        """
        pred_tokens = self._answer_tokens(pred_answer)
        true_tokens = self._answer_tokens(true_answer)

        if not pred_tokens or not true_tokens:
            return 0.0

        overlap = len(pred_tokens & true_tokens)
        if overlap == 0:
            return 0.0

        precision = overlap / len(pred_tokens)
        recall = overlap / len(true_tokens)
        return 2 * (precision * recall) / (precision + recall)

    def evaluate_consistency(self,
                           clinical_dialogue: str,
                           generated_note: str,
                           num_questions: int = 5,
                           f1_threshold: float = 0.5) -> Dict:
        """
        Main evaluation method for factual consistency of one pair.

        Args:
            clinical_dialogue: Original clinical dialogue (source)
            generated_note: Generated note to evaluate
            num_questions: Number of questions to generate
            f1_threshold: F1 threshold for considering answers as matching

        Returns:
            Dictionary with 'consistency_score' (fraction of consistent
            questions, 0.0 when no question was generated), 'num_questions',
            'num_consistent', 'avg_f1_score', 'avg_confidence_dialogue',
            'avg_confidence_note' and per-question 'detailed_results'.
        """
        # Step 1: Generate questions from generated note
        print("Generating factual questions...")
        questions = self.generate_questions(generated_note, num_questions)

        results = {
            'questions': questions,
            'answers_from_dialogue': [],
            'answers_from_note': [],
            'f1_scores': [],
            'is_consistent': [],
            'confidence_scores_dialogue': [],
            'confidence_scores_note': []
        }

        # Step 2: For each question, get answers from both sources
        print(f"Answering {len(questions)} questions...")
        for i, question in enumerate(questions):
            # The dialogue is the ground-truth source.
            answer_dialogue, conf_dialogue = self.extract_answer(
                clinical_dialogue, question, 'qa'
            )
            answer_note, conf_note = self.extract_answer(
                generated_note, question, 'qa'
            )

            f1_score = self.calculate_f1(answer_note, answer_dialogue)
            is_consistent = f1_score >= f1_threshold

            results['answers_from_dialogue'].append(answer_dialogue)
            results['answers_from_note'].append(answer_note)
            results['f1_scores'].append(f1_score)
            results['is_consistent'].append(is_consistent)
            results['confidence_scores_dialogue'].append(conf_dialogue)
            results['confidence_scores_note'].append(conf_note)

            print(f"  Q{i+1}: {question[:50]}...")
            print(f"    Dialogue Answer: {answer_dialogue[:50]}... (conf: {conf_dialogue:.3f})")
            print(f"    Note Answer: {answer_note[:50]}... (conf: {conf_note:.3f})")
            print(f"    F1: {f1_score:.3f}, Consistent: {is_consistent}")

        # Step 3: Calculate overall consistency score.
        # Guard against an empty question list to avoid dividing by zero.
        num_consistent = sum(results['is_consistent'])
        consistency_score = num_consistent / len(questions) if questions else 0.0

        def _mean(values: List[float]) -> float:
            # Plain float keeps the result JSON-friendly; 0.0 for no data.
            return float(np.mean(values)) if values else 0.0

        final_results = {
            'consistency_score': consistency_score,
            'num_questions': len(questions),
            'num_consistent': num_consistent,
            'avg_f1_score': _mean(results['f1_scores']),
            'avg_confidence_dialogue': _mean(results['confidence_scores_dialogue']),
            'avg_confidence_note': _mean(results['confidence_scores_note']),
            'detailed_results': results
        }

        print(f"\nOverall Consistency Score: {consistency_score:.3f} "
              f"({num_consistent}/{len(questions)} questions)")

        return final_results

    def batch_evaluate(self,
                      dialogues: List[str],
                      generated_notes: List[str],
                      num_questions: int = 5,
                      f1_threshold: float = 0.5) -> Dict:
        """
        Batch evaluation for multiple dialogue-note pairs.

        A pair whose evaluation raises is recorded with a consistency score
        of 0.0 and an 'error' message, and still counts toward the
        aggregate statistics.

        Args:
            dialogues: List of clinical dialogues
            generated_notes: List of corresponding generated notes
            num_questions: Number of questions per note
            f1_threshold: F1 threshold forwarded to ``evaluate_consistency``

        Returns:
            Dictionary with mean/std/min/max/median consistency (all 0.0 for
            an empty batch), 'num_samples', 'num_failed' and
            'individual_results'.

        Raises:
            ValueError: If the two input lists differ in length.
        """
        if len(dialogues) != len(generated_notes):
            raise ValueError("Number of dialogues must match number of generated notes")

        batch_results = []
        all_scores = []
        num_failed = 0

        print(f"Starting batch evaluation of {len(dialogues)} pairs...")
        for i, (dialogue, note) in enumerate(zip(dialogues, generated_notes)):
            print(f"\nEvaluating pair {i+1}/{len(dialogues)}...")

            try:
                result = self.evaluate_consistency(
                    dialogue, note, num_questions, f1_threshold
                )
            except Exception as e:
                # Deliberate best-effort: one bad pair must not abort the
                # whole batch; the failure is recorded in the results.
                print(f"Error evaluating pair {i+1}: {e}")
                num_failed += 1
                result = {
                    'consistency_score': 0.0,
                    'num_questions': num_questions,
                    'num_consistent': 0,
                    'avg_f1_score': 0.0,
                    'error': str(e)
                }
            batch_results.append(result)
            all_scores.append(result['consistency_score'])

        # NumPy reductions warn or raise on empty input, so report zeros then.
        if all_scores:
            stats = {
                'mean_consistency': float(np.mean(all_scores)),
                'std_consistency': float(np.std(all_scores)),
                'min_consistency': float(np.min(all_scores)),
                'max_consistency': float(np.max(all_scores)),
                'median_consistency': float(np.median(all_scores)),
            }
        else:
            stats = {
                'mean_consistency': 0.0,
                'std_consistency': 0.0,
                'min_consistency': 0.0,
                'max_consistency': 0.0,
                'median_consistency': 0.0,
            }

        aggregated = {
            **stats,
            'num_samples': len(batch_results),
            'num_failed': num_failed,
            'individual_results': batch_results
        }

        print(f"\n{'='*50}")
        print(f"BATCH EVALUATION SUMMARY")
        print(f"{'='*50}")
        print(f"Mean Consistency: {aggregated['mean_consistency']:.3f} ± {aggregated['std_consistency']:.3f}")
        print(f"Range: [{aggregated['min_consistency']:.3f}, {aggregated['max_consistency']:.3f}]")
        print(f"Median: {aggregated['median_consistency']:.3f}")
        print(f"Samples: {aggregated['num_samples']}")

        return aggregated


# Driver script: load the fine-tuned models, read dialogue/note pairs from
# JSON, run the batch evaluation and write the aggregated results to disk.
import json

evaluator = FactualConsistencyEvaluator(
    qg_model_name="./models/fine-tuned-bart-qg",
    qa_model_name="./models/fine-tuned-roberta-qa"
)

# Read data from files.
# Files are opened as UTF-8 explicitly so clinical text with non-ASCII
# characters is read identically on every platform.

# Read dialogues (assuming JSON file with 'dialogues' key)
with open('clinical_dialogues.json', 'r', encoding='utf-8') as f:
    dialogue_data = json.load(f)
clinical_dialogues = dialogue_data['dialogues']

# Read generated notes (assuming JSON file with 'notes' key)
with open('generated_notes.json', 'r', encoding='utf-8') as f:
    notes_data = json.load(f)
generated_notes = notes_data['notes']

# Or read from CSV/text files
# import pandas as pd
# df = pd.read_csv('data.csv')
# clinical_dialogues = df['dialogue'].tolist()
# generated_notes = df['generated_note'].tolist()

# Validate data length
if len(clinical_dialogues) != len(generated_notes):
    print(f"Warning: Mismatched data - {len(clinical_dialogues)} dialogues vs {len(generated_notes)} notes")
    # Truncate to minimum length.
    # NOTE(review): this assumes the extra items are at the end of the longer
    # list; if an item is missing mid-list the pairs will be misaligned.
    min_len = min(len(clinical_dialogues), len(generated_notes))
    clinical_dialogues = clinical_dialogues[:min_len]
    generated_notes = generated_notes[:min_len]

# Batch evaluation
print(f"Running evaluation on {len(clinical_dialogues)} samples...")
batch_results = evaluator.batch_evaluate(
    dialogues=clinical_dialogues,
    generated_notes=generated_notes,
    num_questions=3  # Adjust as needed
)

# Save results to file
output_file = 'factual_consistency_results.json'
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(batch_results, f, indent=2)

print(f"Evaluation complete. Results saved to {output_file}")
print(f"Mean factual consistency: {batch_results['mean_consistency']:.3f} ± {batch_results['std_consistency']:.3f}")