# In [None]:
"""
PublicGuard: Audi Alteram Partem (AAP) Framework Implementation
Real-time hallucination detection for critical public services
"""

import re
import json
import logging
from typing import Dict, List, Tuple, Optional
from datetime import datetime
from LLMAPI import LLMAPI

# Audit-trail logging: records at INFO and above go both to the file
# 'publicguard_audit.log' and to a console stream handler. Note that
# logging.basicConfig is a no-op if the root logger already has handlers.
_AUDIT_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    format=_AUDIT_LOG_FORMAT,
    level=logging.INFO,
    handlers=[logging.FileHandler('publicguard_audit.log'), logging.StreamHandler()],
)
logger = logging.getLogger('PublicGuard')

class PublicGuard:
    """
    Main class for the Audi Alteram Partem hallucination detection framework.
    Aggregates predictions from multiple LLMs using Bayesian inference.

    Each model's prior is used as its reliability: the probability that its
    TRUE/FALSE verdict is correct. Votes are combined as independent evidence,
    starting from equal odds for TRUE and FALSE (see ``_compute_posterior``).
    """

    # Matches a whole-word TRUE or FALSE token, case-insensitively. The first
    # match in a response is taken as the model's verdict.
    _VERDICT_RE = re.compile(r"\b(TRUE|FALSE)\b", re.IGNORECASE)

    def __init__(self, config_path: str = "config.json"):
        """
        Initialize PublicGuard with model configurations.

        Args:
            config_path: Path to a JSON configuration file containing model
                priors; the built-in defaults are used if it does not exist.
        """
        # Default model priors computed from MCC scores: P = (MCC + 1) / 2
        self.DEFAULT_PRIORS = {
            "phi4": 0.734,
            "qwen2.5": 0.724,
            "mistral": 0.642,
            "gemma3": 0.633,
            "llama3.2": 0.573,
            "deepseek-r1": 0.710,
            "moonshot-v1-8k": 0.689,
        }

        # Load custom configuration if available
        self.model_priors = self._load_config(config_path)

        # Standard prompt template for consistent model responses
        self.PROMPT_TEMPLATE = """Determine if the following statement is TRUE or FALSE.
Only answer with "TRUE" or "FALSE".

Statement: "{statement}"

Answer:"""

        # Audit trail storage: one dict per successful evaluate() call, kept
        # in memory for the lifetime of this instance.
        self.audit_trail = []

    def _load_config(self, config_path: str) -> Dict[str, float]:
        """
        Load model priors from a JSON configuration file.

        The file is read as a JSON object; its optional ``"model_priors"``
        entry supplies the priors. If the file is missing, or the key is
        absent, ``DEFAULT_PRIORS`` is used. Malformed JSON is not caught and
        propagates to the caller.

        Args:
            config_path: Path to JSON configuration file

        Returns:
            A new dictionary of model priors. It is never the
            ``DEFAULT_PRIORS`` object itself, so later edits to
            ``model_priors`` cannot alter the defaults.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
        except FileNotFoundError:
            logger.warning(f"Config file {config_path} not found, using defaults")
            return dict(self.DEFAULT_PRIORS)
        logger.info(f"Loaded configuration from {config_path}")
        return dict(config.get('model_priors', self.DEFAULT_PRIORS))

    @classmethod
    def _parse_verdict(cls, response: str) -> Optional[str]:
        """
        Extract a verdict from a raw model response.

        Returns "TRUE" or "FALSE" for the FIRST whole-word occurrence of
        either token (case-insensitive). Returns None if neither appears.
        Using the first occurrence means an answer such as
        "FALSE - the claim is not TRUE" is read as FALSE.
        """
        match = cls._VERDICT_RE.search(response)
        return match.group(1).upper() if match else None

    def _query_model(self, model_name: str, statement: str) -> Optional[str]:
        """
        Query a single model for hallucination detection.

        Args:
            model_name: Name of the model to query
            statement: Statement to evaluate

        Returns:
            "TRUE" or "FALSE" taken from the first verdict token in the
            response. Responses containing neither token are logged as
            ambiguous and counted as "FALSE". Returns None if the API call or
            response parsing raises.
        """
        try:
            prompt = self.PROMPT_TEMPLATE.format(statement=statement)
            response = LLMAPI(prompt, model_name)
            verdict = self._parse_verdict(response)
        except Exception as e:
            # Best-effort: a failing model is skipped rather than aborting the
            # whole evaluation; evaluate() requires at least two answers.
            logger.error(f"Error querying {model_name}: {str(e)}")
            return None

        if verdict is None:
            # Default to FALSE for ambiguous responses
            logger.warning(f"Ambiguous response from {model_name}: {response}")
            return "FALSE"
        return verdict

    def _compute_posterior(self, predictions: Dict[str, str]) -> float:
        """
        Compute posterior probability using Bayesian aggregation.

        Each model with prior p contributes a factor p to the hypothesis it
        voted for and (1 - p) to the other. It starts from equal weights, and
        the factors are multiplied, i.e. votes are treated as independent.
        A prior below 0.5 therefore counts as evidence for the opposite of
        that model's vote. Models missing from ``model_priors`` get 0.5 and
        have no effect.

        Args:
            predictions: Dictionary of {model_name: prediction}; any value
                other than "TRUE" is treated as a FALSE vote.

        Returns:
            Posterior probability that the statement is TRUE. Returns 0.5 if
            both likelihoods are zero (possible only with priors of exactly
            0 or 1 that contradict each other).
        """
        likelihood_true = 1.0
        likelihood_false = 1.0

        for model, prediction in predictions.items():
            prior = self.model_priors.get(model, 0.5)  # Default to 0.5 if unknown

            if prediction == "TRUE":
                likelihood_true *= prior
                likelihood_false *= (1 - prior)
            else:
                likelihood_true *= (1 - prior)
                likelihood_false *= prior

        total_likelihood = likelihood_true + likelihood_false
        if total_likelihood > 0:
            return likelihood_true / total_likelihood
        # Handle edge case of zero likelihood
        return 0.5

    def evaluate(self,
                 statement: str,
                 num_models: int = 5,
                 confidence_threshold: float = 0.5,
                 context: str = "general") -> Dict:
        """
        Evaluate a statement for potential hallucination using multiple models.

        Args:
            statement: The statement to evaluate
            num_models: Number of models to use (highest priors first)
            confidence_threshold: Posterior at or above which the verdict is TRUE
            context: Application context (e.g., "medical", "legal", "emergency")

        Returns:
            On success, a dictionary containing:
                - is_truthful: Boolean indicating if statement is truthful
                - confidence: Probability of the predicted label, rounded
                  to 3 places
                - prediction: "TRUE" or "FALSE"
                - raw_posterior: Posterior probability of TRUE
                - model_votes: Individual model predictions
                - num_models_used: Number of models that answered
                - audit_id: Unique identifier for audit trail
                - context: The context argument
                - evaluation_time_ms: Wall-clock duration
            If fewer than two models answered, only
            {"error": ..., "audit_id": ...} is returned, and nothing is added
            to the audit trail.
        """
        start_time = datetime.now()
        audit_id = f"AAP-{start_time.strftime('%Y%m%d-%H%M%S-%f')}"

        logger.info(f"Starting evaluation {audit_id} for context: {context}")

        # Select top models based on prior probabilities
        sorted_models = sorted(self.model_priors.items(),
                               key=lambda x: x[1],
                               reverse=True)
        selected_models = [model for model, _ in sorted_models[:num_models]]

        # Query each selected model; failed queries (None) are skipped
        predictions = {}
        for model in selected_models:
            prediction = self._query_model(model, statement)
            if prediction is not None:
                predictions[model] = prediction
                logger.info(f"Model {model} predicted: {prediction}")

        if len(predictions) < 2:
            logger.error("Insufficient model responses for reliable evaluation")
            return {
                "error": "Insufficient model responses",
                "audit_id": audit_id
            }

        posterior = self._compute_posterior(predictions)
        final_prediction = "TRUE" if posterior >= confidence_threshold else "FALSE"

        # Report confidence in the predicted label: 1 - posterior for FALSE.
        confidence = 1 - posterior if final_prediction == "FALSE" else posterior

        result = {
            "is_truthful": final_prediction == "TRUE",
            "confidence": round(confidence, 3),
            "prediction": final_prediction,
            "raw_posterior": round(posterior, 3),
            "model_votes": predictions,
            "num_models_used": len(predictions),
            "audit_id": audit_id,
            "context": context,
            "evaluation_time_ms": int((datetime.now() - start_time).total_seconds() * 1000)
        }

        audit_entry = {
            **result,
            # Snapshot the votes so a caller mutating result["model_votes"]
            # cannot rewrite the stored audit record.
            "model_votes": dict(predictions),
            # Long statements are stored truncated to 200 chars plus "..."
            "statement": statement[:200] + "..." if len(statement) > 200 else statement,
            "timestamp": start_time.isoformat()
        }
        self.audit_trail.append(audit_entry)

        logger.info(f"Evaluation complete: {final_prediction} "
                    f"(confidence: {confidence:.3f}, time: {result['evaluation_time_ms']}ms)")

        return result

    def batch_evaluate(self, statements: List[str], **kwargs) -> List[Dict]:
        """
        Evaluate multiple statements sequentially.

        Args:
            statements: List of statements to evaluate
            **kwargs: Additional arguments passed to evaluate()

        Returns:
            List of evaluation results, in input order (error dicts included)
        """
        results = []
        for i, statement in enumerate(statements):
            logger.info(f"Processing batch item {i+1}/{len(statements)}")
            results.append(self.evaluate(statement, **kwargs))
        return results

    def get_audit_trail(self,
                        start_date: Optional[str] = None,
                        end_date: Optional[str] = None,
                        context: Optional[str] = None) -> List[Dict]:
        """
        Retrieve audit trail with optional filtering.

        Bounds are compared as strings against the entries' naive
        ``datetime.isoformat()`` timestamps. They should therefore use the
        same "YYYY-MM-DDTHH:MM:SS" layout (any prefix of it) and no timezone
        offset. Both bounds are inclusive at the precision given: a
        date-only ``end_date`` such as "2024-01-15" includes that whole day.

        Args:
            start_date: Keep entries at or after this ISO date/datetime
            end_date: Keep entries at or before this ISO date/datetime
            context: Filter by application context

        Returns:
            A new list of matching entries. Mutating the list does not alter
            the stored audit trail.
        """
        filtered_trail = list(self.audit_trail)

        if start_date:
            filtered_trail = [e for e in filtered_trail
                              if e['timestamp'] >= start_date]

        if end_date:
            # Truncate each timestamp to the bound's length. A full string
            # compare would drop same-day entries, because
            # "2024-01-15T09:00:00" > "2024-01-15".
            filtered_trail = [e for e in filtered_trail
                              if e['timestamp'][:len(end_date)] <= end_date]

        if context:
            filtered_trail = [e for e in filtered_trail
                              if e.get('context') == context]

        return filtered_trail

    @staticmethod
    def _mcc(tp: int, tn: int, fp: int, fn: int) -> float:
        """
        Matthews correlation coefficient of a binary confusion matrix.

        Returns a value in [-1, 1]. Returns 0.0 when any row or column of the
        matrix is empty, since the coefficient is undefined there; this also
        covers a model that always gives the same answer.
        """
        denominator = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
        if denominator == 0:
            return 0.0
        return (tp * tn - fp * fn) / denominator

    def calibrate_model(self,
                        model_name: str,
                        test_statements: List[Tuple[str, bool]]) -> float:
        """
        Calibrate a new model by computing its prior probability.

        The model is queried on each labelled statement to build a confusion
        matrix. Its Matthews correlation coefficient is then mapped to a
        prior with P = (MCC + 1) / 2, the same formula used for
        DEFAULT_PRIORS. Unlike raw accuracy, this does not reward a model
        that always gives the majority answer; such a model gets 0.5. The
        result is only returned: add it to ``model_priors`` to use it.

        Args:
            model_name: Name of the model to calibrate
            test_statements: List of (statement, is_true) tuples. This should
                include both true and false statements.

        Returns:
            Computed prior in [0, 1], or 0.5 if no query succeeded
        """
        tp = tn = fp = fn = 0
        for statement, ground_truth in test_statements:
            prediction = self._query_model(model_name, statement)
            if prediction is None:
                continue  # failed queries are excluded from calibration
            said_true = prediction == "TRUE"
            if ground_truth:
                if said_true:
                    tp += 1
                else:
                    fn += 1
            elif said_true:
                fp += 1
            else:
                tn += 1

        total_predictions = tp + tn + fp + fn
        if total_predictions == 0:
            logger.error(f"Failed to calibrate {model_name}: no valid predictions")
            return 0.5

        if tp + fn == 0 or tn + fp == 0:
            logger.warning(f"Calibration set for {model_name} lacks TRUE or FALSE "
                           f"examples; MCC is undefined and the prior defaults to 0.5")

        accuracy = (tp + tn) / total_predictions
        prior = (self._mcc(tp, tn, fp, fn) + 1) / 2

        logger.info(f"Calibrated {model_name}: accuracy={accuracy:.3f}, prior={prior:.3f}")
        return prior


# Example usage and testing
if __name__ == "__main__":
    # Initialize PublicGuard
    guard = PublicGuard()

    # Example 1: Medical context
    medical_statement = "Aspirin is commonly used to reduce fever and relieve mild to moderate pain."
    result = guard.evaluate(
        medical_statement,
        num_models=5,
        confidence_threshold=0.5,
        context="medical"
    )

    print("\n=== PublicGuard Evaluation Result ===")
    # Only add an ellipsis when the statement was actually truncated.
    if len(medical_statement) > 100:
        preview = medical_statement[:100] + "..."
    else:
        preview = medical_statement
    print(f"Statement: {preview}")
    # evaluate() returns only {"error", "audit_id"} when fewer than two models
    # answered, so check for that before reading the verdict fields.
    if "error" in result:
        print(f"Evaluation failed: {result['error']}")
        print(f"Audit ID: {result['audit_id']}")
    else:
        print(f"Prediction: {result['prediction']}")
        print(f"Confidence: {result['confidence']:.1%}")
        print(f"Is Truthful: {result['is_truthful']}")
        print(f"Audit ID: {result['audit_id']}")
        print(f"Model Consensus: {result['model_votes']}")

    # Example 2: Batch evaluation for legal context
    legal_statements = [
        "A contract requires offer, acceptance, and consideration to be valid.",
        "In the US, jury verdicts in civil cases must always be unanimous.",
        "Attorney-client privilege protects all communications between lawyer and client."
    ]

    print("\n=== Batch Evaluation for Legal Context ===")
    batch_results = guard.batch_evaluate(
        legal_statements,
        num_models=4,
        context="legal"
    )

    for i, result in enumerate(batch_results, start=1):
        if "error" in result:
            print(f"\nStatement {i}: ERROR ({result['error']})")
        else:
            print(f"\nStatement {i}: {result['prediction']} "
                  f"(confidence: {result['confidence']:.1%})")

    # Example 3: Retrieve audit trail (only successful evaluations are recorded)
    print("\n=== Audit Trail Summary ===")
    audit_entries = guard.get_audit_trail(context="medical")
    print(f"Total medical evaluations: {len(audit_entries)}")

    if audit_entries:
        avg_confidence = sum(e['confidence'] for e in audit_entries) / len(audit_entries)
        print(f"Average confidence: {avg_confidence:.1%}")

    # Example 4: High-stakes emergency context with higher threshold
    emergency_statement = "In case of cardiac arrest, begin CPR with 30 chest compressions."
    emergency_result = guard.evaluate(
        emergency_statement,
        num_models=7,  # Use more models for critical decisions
        confidence_threshold=0.7,  # Higher threshold for safety
        context="emergency"
    )

    print("\n=== Emergency Response Evaluation ===")
    print("Critical Statement Evaluation:")
    if "error" in emergency_result:
        print(f"Evaluation failed: {emergency_result['error']}")
        accept = False
    else:
        print(f"Prediction: {emergency_result['prediction']}")
        print(f"Confidence Level: {emergency_result['confidence']:.1%}")
        # Auto-accept only a confident TRUE verdict. 'confidence' is the
        # probability of the *predicted* label, so a confident FALSE must
        # never be accepted.
        accept = emergency_result['is_truthful'] and emergency_result['confidence'] > 0.9
    print(f"Recommendation: {'ACCEPT' if accept else 'REQUIRE HUMAN REVIEW'}")