# Production-Grade RAG System for Insurance Document Analysis
Optimized Version with Enhanced Performance and Modularity


In [None]:
import os
import time
import re
import json
import random
import numpy as np
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
import logging

# Configure logging: INFO level with a "timestamp - level - message" format.
# NOTE(review): basicConfig runs when this module/cell is executed; it only
# takes effect if the root logger has no handlers yet.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger shared by the classes and functions in this file.
logger = logging.getLogger(__name__)

## CONFIGURATION MANAGEMENT

In [None]:
@dataclass
class RAGConfig:
    """Single container for every tunable of the RAG pipeline.

    All fields have defaults, so ``RAGConfig()`` yields a ready-to-use
    configuration; override individual values by keyword.
    """

    # --- Language model ---------------------------------------------------
    model_name: str = 'gpt-3.5-turbo'
    temperature: float = 0.7

    # --- Chunking and retrieval -------------------------------------------
    chunk_size: int = 512          # passed to the sentence splitter
    chunk_overlap: int = 50        # passed to the sentence splitter
    similarity_top_k: int = 5      # number of results each retriever returns

    # --- Content filtering ------------------------------------------------
    min_content_length: int = 100      # shorter chunks are rejected by the hybrid filter
    min_content_indicators: int = 1    # minimum indicator words a chunk must contain

    # --- Confidence scoring: caps for the individual score components -----
    max_source_score: int = 25
    max_length_score: int = 20
    max_specificity_score: int = 25
    max_uncertainty_penalty: int = 20
    max_precision_score: int = 15
    max_source_quality_score: int = 20

    # --- Score multipliers applied to BM25 results -------------------------
    severe_penalty_multiplier: float = 0.01
    moderate_penalty_multiplier: float = 0.3
    content_boost_multiplier: float = 1.5

    # --- Response shaping ---------------------------------------------------
    optimal_response_length_min: int = 30    # words
    optimal_response_length_max: int = 150   # words
    context_window_size: int = 4             # conversation turns kept in history

    # --- Performance ----------------------------------------------------------
    enable_caching: bool = True
    cache_ttl: int = 3600  # seconds
    max_retries: int = 3
    timeout: int = 30  # seconds

class QuestionType(Enum):
    """Closed set of categories a user question can be classified into."""

    FACTUAL = "factual"          # who / what / when / where / which
    COMPARISON = "comparison"    # compare, difference, versus ...
    PROCEDURAL = "procedural"    # how, process, steps ...
    SUMMARY = "summary"          # summarize, overview, explain ...
    FOLLOWUP = "followup"        # refers back to the earlier conversation

## CONTENT QUALITY ANALYSIS

In [None]:
class ContentQualityAnalyzer:
    """Heuristic content-quality checks for text chunks.

    Every check is a classmethod. The regular expressions are compiled once
    when the class is created, and the three boolean checks memoise their
    results with ``lru_cache`` (at most 1024 entries each).
    """

    # Boilerplate that marks a chunk as front matter / navigation.
    SEVERE_PENALTY_PATTERNS = re.compile(
        r'table of contents|gc 6001 table of contents|'
        r'this policy has been updated effective january 1, 2014 gc 6001',
        re.IGNORECASE
    )

    # Structural headings and one- or two-digit page markers.
    MODERATE_PENALTY_PATTERNS = re.compile(
        r'section [a-d] -|part [iv]+ -|page \d{1,2}(?!\d)',
        re.IGNORECASE
    )

    # Phrases that mark substantive policy language.
    CONTENT_BOOST_PATTERNS = re.compile(
        r'coverage exclusion|claim procedure|premium payment|'
        r'death benefit|proof of loss|notice of claim|'
        r'medical examination|autopsy|legal action',
        re.IGNORECASE
    )

    # Single words whose presence suggests real policy content.
    CONTENT_INDICATORS = {
        'coverage', 'benefit', 'exclusion', 'procedure', 'payment',
        'claim', 'premium', 'death', 'accident', 'medical',
        'within', 'days', 'shall', 'must', 'required', 'employee',
        'insurance', 'policy', 'amount', 'termination', 'effective'
    }

    # Query topics for which boosting is considered at all.
    _BOOST_QUERY_TOPICS = frozenset({'exclusion', 'procedure', 'payment', 'claim'})

    # Runs of ASCII letters; used to tokenize lower-cased text so that
    # punctuation attached to a word ("coverage,", "days.") does not hide it.
    _WORD_PATTERN = re.compile(r'[a-z]+')

    @classmethod
    @lru_cache(maxsize=1024)
    def has_severe_penalty(cls, text: str) -> bool:
        """Return True if ``text`` contains any severe-penalty phrase (cached)."""
        return cls.SEVERE_PENALTY_PATTERNS.search(text) is not None

    @classmethod
    @lru_cache(maxsize=1024)
    def has_moderate_penalty(cls, text: str) -> bool:
        """Return True if ``text`` is short and looks structural (cached).

        Texts of 300 characters or more are never penalised, regardless of
        whether they contain a structural marker.
        """
        if len(text) >= 300:
            return False
        return cls.MODERATE_PENALTY_PATTERNS.search(text) is not None

    @classmethod
    @lru_cache(maxsize=1024)
    def should_boost_content(cls, text: str, query: str) -> bool:
        """Return True if ``text`` deserves a boost for ``query`` (cached).

        A boost is only considered when the query mentions one of the boost
        topics (substring match on the lower-cased query); the text must then
        contain one of the boost phrases.
        """
        query_lower = query.lower()
        if not any(topic in query_lower for topic in cls._BOOST_QUERY_TOPICS):
            return False
        return cls.CONTENT_BOOST_PATTERNS.search(text) is not None

    @classmethod
    def count_content_indicators(cls, text: str) -> int:
        """Count how many distinct indicator words occur in ``text``.

        The text is lower-cased and split into runs of letters, so indicator
        words followed or preceded by punctuation are still recognised. Each
        indicator is counted at most once.
        """
        words = set(cls._WORD_PATTERN.findall(text.lower()))
        return len(cls.CONTENT_INDICATORS & words)


## OPTIMIZED RETRIEVERS

In [None]:
class OptimizedBM25Retriever:
    """BM25 keyword retriever whose scores are adjusted by content-quality heuristics.

    Node texts are lower-cased and whitespace-tokenized once at construction.
    At query time the raw BM25 scores are multiplied by the penalty / boost
    multipliers from the configuration before the top results are selected.
    """

    def __init__(self, nodes, config: "RAGConfig"):
        """Index ``nodes`` for BM25 search.

        Args:
            nodes: Sequence of objects exposing a ``text`` attribute.
            config: Configuration providing ``similarity_top_k`` and the
                penalty / boost multipliers.
        """
        self.nodes = nodes
        self.config = config
        self.analyzer = ContentQualityAnalyzer()

        # Lower-cased text per node index, reused for quality checks.
        self.node_text_cache = {i: node.text.lower() for i, node in enumerate(nodes)}

        # Whitespace tokens of each lower-cased text, in node order.
        self.tokenized_docs = [text.split() for text in self.node_text_cache.values()]

        # An empty corpus cannot be indexed; retrieve() then returns no results.
        if self.tokenized_docs:
            from rank_bm25 import BM25Okapi
            self.bm25 = BM25Okapi(self.tokenized_docs)
        else:
            self.bm25 = None

    def retrieve(self, query_str: str) -> List:
        """Return up to ``similarity_top_k`` nodes, best first.

        Args:
            query_str: A plain string, or an object carrying the query in a
                ``query_str`` or ``text`` attribute.

        Returns:
            ``NodeWithScore`` objects in descending score order; nodes whose
            adjusted score is not positive are omitted.
        """
        if self.bm25 is None:
            return []

        query_text = self._extract_query_text(query_str)
        scores = self.bm25.get_scores(query_text.lower().split())
        boosted_scores = self._boost_content_quality_vectorized(scores, query_text)
        top_indices = self._top_k_indices(boosted_scores, self.config.similarity_top_k)

        from llama_index.core.schema import NodeWithScore
        return [
            NodeWithScore(node=self.nodes[i], score=float(boosted_scores[i]))
            for i in top_indices if boosted_scores[i] > 0
        ]

    @staticmethod
    def _top_k_indices(scores: np.ndarray, k: int) -> np.ndarray:
        """Return the indices of the ``k`` largest scores in descending score order.

        ``k`` is clamped to the number of scores, because ``np.argpartition``
        raises when asked for more elements than exist. A non-positive ``k``
        or an empty score array yields an empty index array.
        """
        k = min(k, len(scores))
        if k <= 0:
            return np.array([], dtype=int)
        candidates = np.argpartition(scores, -k)[-k:]
        return candidates[np.argsort(scores[candidates])][::-1]

    def _boost_content_quality_vectorized(self, scores: np.ndarray, query_text: str) -> np.ndarray:
        """Return a copy of ``scores`` with quality multipliers applied per node.

        For each node exactly one adjustment is applied, checked in this
        order: severe penalty, moderate penalty, content boost. The input
        array is not modified. (Despite the name, this is a per-node loop.)
        """
        boosted_scores = scores.copy()

        for i, cached_text in self.node_text_cache.items():
            if self.analyzer.has_severe_penalty(cached_text):
                boosted_scores[i] *= self.config.severe_penalty_multiplier
            elif self.analyzer.has_moderate_penalty(cached_text):
                boosted_scores[i] *= self.config.moderate_penalty_multiplier
            elif self.analyzer.should_boost_content(cached_text, query_text):
                boosted_scores[i] *= self.config.content_boost_multiplier

        return boosted_scores

    @staticmethod
    def _extract_query_text(query_str) -> str:
        """Return the query as a string.

        Uses the ``query_str`` attribute if present, else the ``text``
        attribute, else ``str(query_str)``.
        """
        if hasattr(query_str, 'query_str'):
            return query_str.query_str
        elif hasattr(query_str, 'text'):
            return query_str.text
        return str(query_str)

class OptimizedHybridRetriever:
    """Combine a vector retriever and a BM25 retriever, then filter the results.

    Results from both retrievers are concatenated (vector results first),
    de-duplicated by node text and filtered for substantial content. If fewer
    than two results survive, a more lenient backup pass tops the list up.
    """

    def __init__(self, vector_retriever, bm25_retriever, config: "RAGConfig"):
        """Store the two retrievers and the configuration.

        Args:
            vector_retriever: Object with a ``retrieve(query)`` method.
            bm25_retriever: Object with a ``retrieve(query)`` method.
            config: Configuration providing ``similarity_top_k``,
                ``min_content_length`` and ``min_content_indicators``.
        """
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        self.config = config
        self.analyzer = ContentQualityAnalyzer()

        # Per-instance memo of _is_substantial_content verdicts, keyed by the
        # raw node text. A plain dict is used instead of lru_cache on the
        # method: that would hash (self, node), keeping instances alive in a
        # class-level cache and failing for node objects that are unhashable.
        self._filter_cache = {}

    def retrieve(self, query_str: str) -> List:
        """Return at most ``similarity_top_k`` filtered results for the query.

        Args:
            query_str: A plain string, or an object carrying the query in a
                ``query_str`` or ``text`` attribute.
        """
        query_text = self._extract_query_text(query_str)

        # The two retrievers are queried one after the other.
        vector_results = self.vector_retriever.retrieve(query_text)
        bm25_results = self.bm25_retriever.retrieve(query_text)
        combined = vector_results + bm25_results

        filtered_results = self._filter_and_deduplicate_optimized(combined)

        # Too few survivors: admit lower-quality chunks as backup.
        if len(filtered_results) < 2:
            filtered_results = self._apply_selective_backup(combined, filtered_results)

        return filtered_results[:self.config.similarity_top_k]

    def _filter_and_deduplicate_optimized(self, all_results: List) -> List:
        """Keep the first occurrence of each distinct, substantial node text.

        Input order is preserved. Texts are compared directly, so two
        different texts can never be mistaken for duplicates.
        """
        seen_texts = set()
        filtered_results = []

        for result in all_results:
            text = result.node.text
            if text in seen_texts:
                continue
            if self._is_substantial_content(result.node):
                seen_texts.add(text)
                filtered_results.append(result)

        return filtered_results

    def _is_substantial_content(self, node) -> bool:
        """Return True if the node's text passes the quality filter (memoised).

        A text is rejected when it carries a severe penalty, is shorter than
        ``min_content_length``, is under 200 characters with a moderate
        penalty, or contains fewer than ``min_content_indicators`` indicator
        words. Checks run on the lower-cased, stripped text.
        """
        raw_text = node.text
        verdict = self._filter_cache.get(raw_text)
        if verdict is None:
            verdict = self._evaluate_substantial(raw_text.lower().strip())
            self._filter_cache[raw_text] = verdict
        return verdict

    def _evaluate_substantial(self, text: str) -> bool:
        """Apply the quality rules to an already normalised text (uncached)."""
        # Quick rejection checks
        if self.analyzer.has_severe_penalty(text) or len(text) < self.config.min_content_length:
            return False

        # Medium-length structural content check
        if len(text) < 200 and self.analyzer.has_moderate_penalty(text):
            return False

        # Content indicator requirement
        return self.analyzer.count_content_indicators(text) >= self.config.min_content_indicators

    def _apply_selective_backup(self, all_results: List, current_results: List) -> List:
        """Top up ``current_results`` with acceptable, not-yet-included results.

        ``current_results`` is extended in place (and returned) until it holds
        ``similarity_top_k`` entries or the candidates run out.
        """
        seen_texts = {result.node.text for result in current_results}

        for result in all_results:
            if len(current_results) >= self.config.similarity_top_k:
                break
            text = result.node.text
            if text not in seen_texts and self._is_acceptable_backup(result.node):
                current_results.append(result)
                seen_texts.add(text)

        return current_results

    def _is_acceptable_backup(self, node) -> bool:
        """Return True if the node is good enough for the backup pass.

        Rejects tables of contents and texts under 80 characters; otherwise
        requires at least one policy term to occur as a substring.
        """
        text = node.text.lower().strip()

        if 'table of contents' in text or len(text) < 80:
            return False

        policy_terms = {'coverage', 'benefit', 'claim', 'insurance', 'policy', 'employee', 'procedure'}
        return any(word in text for word in policy_terms)

    @staticmethod
    def _extract_query_text(query_str) -> str:
        """Return the query as a string.

        Uses the ``query_str`` attribute if present, else the ``text``
        attribute, else ``str(query_str)``.
        """
        if hasattr(query_str, 'query_str'):
            return query_str.query_str
        elif hasattr(query_str, 'text'):
            return query_str.text
        return str(query_str)

## INTELLIGENT QUERY CLASSIFICATION

In [None]:
class QueryClassifier:
    """Rule-based question classifier with cached results.

    A question is first tested for follow-up phrasing; otherwise the first
    matching entry of ``CLASSIFICATION_RULES`` (in definition order) decides
    the type, with ``QuestionType.FACTUAL`` as the fallback.
    """

    # Per type: the trigger words and the compiled word-bounded pattern that
    # classify_question() actually searches with.
    CLASSIFICATION_RULES = {
        QuestionType.FACTUAL: {
            'keywords': frozenset(['what', 'who', 'when', 'where', 'which']),
            'pattern': re.compile(r'\b(what|who|when|where|which)\b', re.IGNORECASE)
        },
        QuestionType.COMPARISON: {
            'keywords': frozenset(['compare', 'difference', 'vs', 'versus', 'better']),
            'pattern': re.compile(r'\b(compare|difference|vs|versus|better)\b', re.IGNORECASE)
        },
        QuestionType.PROCEDURAL: {
            'keywords': frozenset(['how', 'process', 'procedure', 'steps']),
            'pattern': re.compile(r'\b(how|process|procedure|steps)\b', re.IGNORECASE)
        },
        QuestionType.SUMMARY: {
            'keywords': frozenset(['summarize', 'summary', 'overview', 'explain']),
            'pattern': re.compile(r'\b(summarize|summary|overview|explain)\b', re.IGNORECASE)
        }
    }

    # Words and phrases that signal a reference to the earlier conversation.
    FOLLOWUP_INDICATORS = frozenset([
        'elaborate', 'explain more', 'tell me more', 'expand', 'details',
        'that', 'it', 'this', 'further', 'more about', 'specific',
        'can you', 'what about', 'how about'
    ])

    # One alternation over all follow-up indicators, anchored at word
    # boundaries so that short indicators such as 'it' or 'that' do not fire
    # inside longer words ("benefit", "limit").
    _FOLLOWUP_PATTERN = re.compile(
        r'\b(?:' + '|'.join(map(re.escape, sorted(FOLLOWUP_INDICATORS))) + r')\b',
        re.IGNORECASE
    )

    @classmethod
    @lru_cache(maxsize=256)
    def classify_question(cls, question: str) -> QuestionType:
        """Return the ``QuestionType`` for ``question`` (results are cached).

        Follow-up indicators take precedence over all other rules. Matching
        is case-insensitive and on whole words / phrases.
        """
        # Check for follow-up first
        if cls._FOLLOWUP_PATTERN.search(question):
            return QuestionType.FOLLOWUP

        # First rule whose pattern matches wins (dict definition order).
        for q_type, rules in cls.CLASSIFICATION_RULES.items():
            if rules['pattern'].search(question):
                return q_type

        return QuestionType.FACTUAL

    @classmethod
    def should_use_sub_questions(cls, question: str, question_type: QuestionType) -> bool:
        """Return True if the question should go to the sub-question engine.

        That is the case for comparison and summary questions, and for any
        question longer than 15 whitespace-separated words.
        """
        if question_type in (QuestionType.COMPARISON, QuestionType.SUMMARY):
            return True

        return len(question.split()) > 15

## ADVANCED CONFIDENCE SCORING

In [None]:
class ConfidenceScorer:
    """Multi-factor heuristic confidence score for a generated response.

    The score is the sum of six component assessments (source count, response
    length, policy-specific wording, uncertain wording, numbers, source
    quality), optionally jittered, then clamped to 0..100.
    """

    def __init__(self, config: "RAGConfig"):
        """Store the configuration and compile the wording patterns.

        Args:
            config: Provides the ``max_*`` caps and the optimal response
                length bounds used by the assessments.
        """
        self.config = config

        # Words and phrases that indicate concrete, policy-grounded wording.
        self.specific_indicators = re.compile(
            r'\b(section|page|part|according to|states that|specifically|'
            r'outlined|policy|coverage|benefit|procedure|days|within)\b',
            re.IGNORECASE
        )

        # Hedging phrases that indicate an uncertain answer.
        self.uncertainty_patterns = re.compile(
            r'\b(not sure|unclear|might be|possibly|perhaps|generally|'
            r'typically|usually|contact the|consult with|it is advisable)\b',
            re.IGNORECASE
        )

        # Any run of digits counts as one numeric value.
        self.number_pattern = re.compile(r'\d+')

    def calculate_confidence_score(
        self, response: str, retrieved_nodes: List, jitter: float = 3.0
    ) -> Tuple[int, List[str]]:
        """Compute the overall confidence score and the contributing factors.

        Args:
            response: Generated answer text; ``None`` is treated as empty.
            retrieved_nodes: Retrieved results, each exposing ``node.text``;
                ``None`` is treated as no results.
            jitter: Half-width of the uniform random offset added to the raw
                score. The default keeps the original +/-3 variability; pass
                ``0`` for a fully deterministic score.

        Returns:
            ``(score, factors)`` where ``score`` is an int in 0..100 and
            ``factors`` lists the description of every non-zero component.
        """
        response = response or ""
        retrieved_nodes = retrieved_nodes or []
        response_lower = response.lower()

        assessments = [
            self._assess_source_quantity(retrieved_nodes),
            self._assess_response_length(response),
            self._assess_policy_specificity(response_lower),
            self._assess_uncertainty(response_lower),
            self._assess_numerical_precision(response),
            self._assess_source_quality(retrieved_nodes)
        ]

        # Only components that actually moved the score are reported.
        score = 0.0
        factors = []
        for assessment_score, factor_desc in assessments:
            if assessment_score != 0:
                score += assessment_score
                factors.append(factor_desc)

        if jitter:
            score += random.uniform(-jitter, jitter)

        final_score = max(0, min(100, score))
        return round(final_score), factors

    def _assess_source_quantity(self, retrieved_nodes: List) -> Tuple[float, str]:
        """Award 5 points per source, capped at ``max_source_score``."""
        num_sources = len(retrieved_nodes) if retrieved_nodes else 0
        source_score = min(num_sources * 5, self.config.max_source_score)
        return source_score, f"Sources: {num_sources} (+{source_score}pts)"

    def _assess_response_length(self, response: str) -> Tuple[float, str]:
        """Score the word count of the response in four tiers.

        The tiers award 100%, 75%, 50% and 25% of ``max_length_score``
        (20 / 15 / 10 / 5 points with the default cap of 20): full marks
        inside the optimal range, less the further the length is from it.
        """
        word_count = len(response.split())
        lower = self.config.optimal_response_length_min
        upper = self.config.optimal_response_length_max

        if lower <= word_count <= upper:
            fraction = 1.0
        elif 20 <= word_count < lower or upper < word_count <= 200:
            fraction = 0.75
        elif 10 <= word_count < 20 or 200 < word_count <= 300:
            fraction = 0.5
        else:
            fraction = 0.25

        score = round(self.config.max_length_score * fraction)
        return score, f"Length: {word_count} words (+{score}pts)"

    def _assess_policy_specificity(self, response_lower: str) -> Tuple[float, str]:
        """Award 3 points per specific term, capped at ``max_specificity_score``."""
        matches = len(self.specific_indicators.findall(response_lower))
        score = min(matches * 3, self.config.max_specificity_score)
        return score, f"Policy specificity: {matches} terms (+{score}pts)"

    def _assess_uncertainty(self, response_lower: str) -> Tuple[float, str]:
        """Deduct 8 points per hedging phrase, capped at ``max_uncertainty_penalty``.

        Returns ``(0, "")`` when no hedging phrase is found.
        """
        matches = len(self.uncertainty_patterns.findall(response_lower))
        penalty = min(matches * 8, self.config.max_uncertainty_penalty)
        if penalty > 0:
            return -penalty, f"Uncertainty: -{penalty}pts"
        return 0, ""

    def _assess_numerical_precision(self, response: str) -> Tuple[float, str]:
        """Award 3 points per number, capped at ``max_precision_score``.

        Returns ``(0, "")`` when the response contains no digits.
        """
        numbers = len(self.number_pattern.findall(response))
        score = min(numbers * 3, self.config.max_precision_score)
        if score > 0:
            return score, f"Numerical precision: {numbers} values (+{score}pts)"
        return 0, ""

    def _assess_source_quality(self, retrieved_nodes: List) -> Tuple[float, str]:
        """Award 4 points per source longer than 150 characters, capped at 16.

        Returns -5 when there are no sources or none is substantial.
        NOTE(review): the cap of 16 is hard-coded here and the configured
        ``max_source_quality_score`` is not consulted - confirm which is intended.
        """
        if not retrieved_nodes:
            return -5, "Source quality: No sources (-5pts)"

        substantial_sources = sum(
            1 for node in retrieved_nodes
            if len(node.node.text) > 150
        )

        quality_score = min(substantial_sources * 4, 16)

        if quality_score > 0:
            return quality_score, f"Source quality: {substantial_sources} substantial (+{quality_score}pts)"
        return -5, "Source quality: Low-quality sources (-5pts)"

## PERFORMANCE MONITORING

In [None]:
@dataclass
class QueryMetrics:
    """Record of one processed query, as logged by the performance monitor."""

    timestamp: str                  # ISO-format time the record was created
    question: str                   # the user's question as received
    question_type: QuestionType     # classification result
    processing_time: float          # elapsed seconds
    confidence_score: int           # final confidence score
    num_sources: int                # number of source nodes on the response
    response_length: int            # response length in words
    context_used: bool              # whether conversation history existed
    sub_questions_used: bool        # whether the sub-question engine was selected
    confidence_factors: List[str]   # descriptions of the score components

class PerformanceMonitor:
    """Collects per-query metrics and maintains running aggregates."""

    def __init__(self):
        """Start with an empty history and zeroed aggregates."""
        self.query_history: List[QueryMetrics] = []
        self.metrics_cache = {
            'total_queries': 0,
            'avg_processing_time': 0.0,
            'avg_confidence_score': 0.0,
            'question_type_distribution': {},
            'source_quality_stats': {}
        }

    def log_query(self, metrics: QueryMetrics):
        """Append ``metrics`` to the history and fold it into the aggregates."""
        self.query_history.append(metrics)
        self._update_metrics(metrics)

    def _update_metrics(self, metrics: QueryMetrics):
        """Update the query count, the running averages and the type histogram."""
        cache = self.metrics_cache

        cache['total_queries'] += 1
        n = cache['total_queries']

        # Running mean: previous mean weighted by the previous count, plus
        # the new observation, divided by the new count.
        weighted_time = cache['avg_processing_time'] * (n - 1)
        cache['avg_processing_time'] = (weighted_time + metrics.processing_time) / n

        weighted_confidence = cache['avg_confidence_score'] * (n - 1)
        cache['avg_confidence_score'] = (weighted_confidence + metrics.confidence_score) / n

        # Histogram of question types, keyed by the enum's string value.
        distribution = cache['question_type_distribution']
        type_key = metrics.question_type.value
        distribution[type_key] = distribution.get(type_key, 0) + 1

    def get_summary(self) -> Dict[str, Any]:
        """Return the aggregate metrics plus a digest of the last five queries.

        The ``'metrics'`` entry is the monitor's own ``metrics_cache`` dict,
        not a copy. Questions in the digest are cut to 60 characters.
        """
        recent_queries = []
        for entry in self.query_history[-5:]:
            recent_queries.append({
                'question': entry.question[:60],
                'type': entry.question_type.value,
                'time': f"{entry.processing_time:.2f}s",
                'confidence': entry.confidence_score
            })

        return {
            'metrics': self.metrics_cache,
            'recent_queries': recent_queries
        }

## MAIN RAG SYSTEM

In [None]:
class OptimizedRAGSystem:
    """Facade tying together retrieval, classification, scoring and monitoring.

    Construct it, call ``initialize_components(documents, llm)`` once, then
    call ``process_query(question)`` for each user question.
    """

    def __init__(self, config: Optional[RAGConfig] = None):
        """Create the helper objects; retrieval/query engines are set up later.

        Args:
            config: System configuration; a default ``RAGConfig`` is used
                when omitted.
        """
        self.config = config or RAGConfig()
        self.performance_monitor = PerformanceMonitor()
        self.query_classifier = QueryClassifier()
        self.confidence_scorer = ConfidenceScorer(self.config)
        self.conversation_history = []

        # Populated by initialize_components().
        self.hybrid_retriever = None
        self.query_engine = None
        self.sub_question_engine = None

        logger.info("RAG System initialized with optimized configuration")

    def initialize_components(self, documents, llm):
        """Build the index, retrievers and query engines.

        Args:
            documents: Documents to split into nodes and index.
            llm: Language model used for response synthesis and, when
                available, by the sub-question engine.
        """
        from llama_index.core import VectorStoreIndex
        from llama_index.core.node_parser import SentenceSplitter
        from llama_index.core.retrievers import VectorIndexRetriever
        from llama_index.core.query_engine import RetrieverQueryEngine
        from llama_index.core.response_synthesizers import get_response_synthesizer

        # Parse documents
        parser = SentenceSplitter(
            chunk_size=self.config.chunk_size,
            chunk_overlap=self.config.chunk_overlap
        )
        nodes = parser.get_nodes_from_documents(documents)

        # Build index
        index = VectorStoreIndex(nodes)

        # Create retrievers
        vector_retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=self.config.similarity_top_k
        )
        bm25_retriever = OptimizedBM25Retriever(nodes, self.config)
        self.hybrid_retriever = OptimizedHybridRetriever(
            vector_retriever, bm25_retriever, self.config
        )

        # Create query engine. The caller's llm is handed to the synthesizer
        # so both engines answer with the same model.
        self.query_engine = RetrieverQueryEngine(
            retriever=self.hybrid_retriever,
            response_synthesizer=get_response_synthesizer(llm=llm, response_mode="compact")
        )

        # Try to create sub-question engine
        try:
            from llama_index.core.query_engine import SubQuestionQueryEngine
            from llama_index.core.tools import QueryEngineTool, ToolMetadata

            tools = [
                QueryEngineTool(
                    query_engine=self.query_engine,
                    metadata=ToolMetadata(
                        name="insurance_policy",
                        description="Insurance policy information"
                    )
                )
            ]
            self.sub_question_engine = SubQuestionQueryEngine.from_defaults(
                query_engine_tools=tools,
                llm=llm
            )
        except ImportError:
            logger.warning("SubQuestionQueryEngine not available, using standard engine")
            self.sub_question_engine = self.query_engine

        logger.info("All components initialized successfully")

    def process_query(self, question: str) -> Dict[str, Any]:
        """Answer ``question`` and record metrics and conversation history.

        Returns:
            Dict with keys ``response``, ``question_type``, ``confidence``,
            ``factors``, ``processing_time`` and ``source_nodes``.

        Raises:
            RuntimeError: If ``initialize_components`` has not been called.
        """
        if self.query_engine is None or self.sub_question_engine is None:
            raise RuntimeError(
                "RAG components are not initialized; call initialize_components() first"
            )

        # Monotonic clock: unaffected by system clock adjustments.
        start_time = time.perf_counter()

        question_type = self.query_classifier.classify_question(question)
        use_sub_questions = self.query_classifier.should_use_sub_questions(
            question, question_type
        )
        contextual_question = self._build_contextual_question(question, question_type)

        engine = self.sub_question_engine if use_sub_questions else self.query_engine
        response = engine.query(contextual_question)

        # The engine may report no text or no sources; treat both as empty.
        response_text = response.response or ""
        source_nodes = getattr(response, 'source_nodes', None) or []

        confidence, factors = self.confidence_scorer.calculate_confidence_score(
            response_text, source_nodes
        )

        processing_time = time.perf_counter() - start_time
        metrics = QueryMetrics(
            timestamp=datetime.now().isoformat(),
            question=question,
            question_type=question_type,
            processing_time=processing_time,
            confidence_score=confidence,
            num_sources=len(source_nodes),
            response_length=len(response_text.split()),
            # Evaluated before this exchange is appended below.
            context_used=len(self.conversation_history) > 0,
            sub_questions_used=use_sub_questions,
            confidence_factors=factors
        )
        self.performance_monitor.log_query(metrics)

        # Update conversation history: user entry first, then assistant.
        self.conversation_history.append({'role': 'user', 'content': question})
        self.conversation_history.append({'role': 'assistant', 'content': response_text})

        # Keep at most context_window_size exchanges (two entries each).
        max_entries = self.config.context_window_size * 2
        if len(self.conversation_history) > max_entries:
            self.conversation_history = self.conversation_history[-max_entries:]

        return {
            'response': response_text,
            'question_type': question_type.value,
            'confidence': confidence,
            'factors': factors,
            'processing_time': processing_time,
            'source_nodes': source_nodes
        }

    def _build_contextual_question(self, question: str, question_type: QuestionType) -> str:
        """Prefix ``question`` with conversation context when history exists.

        Follow-up questions get the last four history entries (each cut to
        200 characters). Other questions get the second-to-last entry, which
        is the previous user question given the order in which
        ``process_query`` appends entries. Without history the question is
        returned unchanged.
        """
        if not self.conversation_history:
            return question

        if question_type == QuestionType.FOLLOWUP:
            recent_history = self.conversation_history[-4:]
            context = "\n".join([
                f"{msg['role'].title()}: {msg['content'][:200]}"
                for msg in recent_history
            ])
            return f"Context:\n{context}\n\nFollow-up Question: {question}"

        if len(self.conversation_history) >= 2:
            recent_context = self.conversation_history[-2]['content'][:200]
        else:
            recent_context = ""
        return f"Previous context: {recent_context}\n\nNew Question: {question}"

    def reset_conversation(self):
        """Discard the conversation history."""
        self.conversation_history = []
        logger.info("Conversation history reset")

    def get_performance_summary(self) -> Dict[str, Any]:
        """Return the performance monitor's summary."""
        return self.performance_monitor.get_summary()

## MAIN EXECUTION

In [None]:
def main():
    """Create a RAG system with the default configuration and return it.

    The retrieval and query components are not built here; once documents
    and an LLM are available, call
    ``rag_system.initialize_components(documents, llm)`` on the result.
    """
    rag_system = OptimizedRAGSystem(RAGConfig())

    logger.info("Optimized RAG System ready for use")

    return rag_system

# Script entry point: build the system and report success on stdout.
if __name__ == "__main__":
    system = main()
    print("RAG System initialized successfully")