# Efficient RAG System for Insurance Documents

## Overview
This notebook implements an optimized Retrieval-Augmented Generation (RAG) system for insurance document querying using:
- **OpenAI Embeddings** for semantic understanding
- **ChromaDB** for efficient vector storage and retrieval
- **Cross-encoder re-ranking** for improved result relevance
- **GPT-3.5** for natural language response generation

### Key Improvements Over Original Implementation:
1. **Modular Architecture**: Separate classes for each component
2. **Error Handling**: Comprehensive error handling and logging
3. **Performance Optimization**: Batch processing and caching strategies
4. **Configuration Management**: Centralized configuration system
5. **Resource Management**: Efficient memory and API usage

## 1. Environment Setup and Dependency Installation

In [None]:
# Install dependencies with version pinning for reproducibility
!pip install -q --upgrade pip
!pip install -q pdfplumber==0.10.3 \
                tiktoken==0.5.2 \
                openai==1.3.8 \
                chromadb==0.4.18 \
                sentence-transformers==2.2.2 \
                pandas==2.1.4 \
                numpy==1.24.3 \
                tqdm==4.66.1

In [None]:
# Standard library imports
import os
import json
import logging
import warnings
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field
from contextlib import contextmanager

# Third-party imports
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from operator import itemgetter

# PDF processing
import pdfplumber

# NLP and ML libraries
import tiktoken
import openai
import chromadb
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from sentence_transformers import CrossEncoder

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=FutureWarning)

print("✅ All dependencies imported successfully!")

In [None]:
@dataclass
class RAGConfig:
    """Configuration class for the RAG system"""
    
    # API Configuration
    openai_api_key: str = ""
    embedding_model: str = "text-embedding-ada-002"
    chat_model: str = "gpt-3.5-turbo"
    
    # File paths
    pdf_directory: str = "./documents"
    chroma_persist_directory: str = "./chroma_db"
    
    # Processing parameters
    min_text_length: int = 10
    chunk_size: int = 1000
    chunk_overlap: int = 200
    
    # Search parameters
    similarity_threshold: float = 0.2
    max_search_results: int = 10
    top_k_rerank: int = 3
    
    # Cross-encoder model
    cross_encoder_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    
    # Collection names
    main_collection_name: str = "insurance_documents"
    cache_collection_name: str = "query_cache"
    
    def __post_init__(self):
        """Validate configuration after initialization"""
        if not self.openai_api_key:
            raise ValueError("OpenAI API key is required")
        
        # Set OpenAI API key
        openai.api_key = self.openai_api_key

# Initialize configuration
try:
    # Try to load API key from environment or file
    api_key = os.getenv('OPENAI_API_KEY')
    if not api_key:
        # For local development, you can create a file with your API key
        try:
            with open('openai_key.txt', 'r') as f:
                api_key = f.read().strip()
        except FileNotFoundError:
            api_key = input("Please enter your OpenAI API key: ")
    
    config = RAGConfig(openai_api_key=api_key)
    print("✅ Configuration initialized successfully!")
    
except Exception as e:
    logger.error(f"Configuration initialization failed: {e}")
    raise

## 2. Data Processing Pipeline for PDF Documents

In [None]:
class PDFProcessor:
    """Optimized PDF processing with batch capabilities and error handling"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
        
    @staticmethod
    def check_bboxes(word: Dict, table_bbox: Tuple) -> bool:
        """Check if word is inside a table bounding box"""
        word_bbox = (word['x0'], word['top'], word['x1'], word['bottom'])
        return (word_bbox[0] > table_bbox[0] and word_bbox[1] > table_bbox[1] and 
                word_bbox[2] < table_bbox[2] and word_bbox[3] < table_bbox[3])
    
    def extract_text_from_pdf(self, pdf_path: Path) -> List[Tuple[str, str]]:
        """Extract text and tables from PDF with error handling"""
        pages_data = []
        
        try:
            with pdfplumber.open(pdf_path) as pdf:
                for page_num, page in enumerate(pdf.pages, 1):
                    try:
                        # Extract tables and their bounding boxes
                        tables = page.find_tables()
                        table_bboxes = [table.bbox for table in tables]
                        table_data = [{'table': table.extract(), 'top': table.bbox[1]} 
                                    for table in tables]
                        
                        # Extract non-table words
                        words = page.extract_words()
                        non_table_words = [
                            word for word in words 
                            if not any(self.check_bboxes(word, bbox) for bbox in table_bboxes)
                        ]
                        
                        # Combine text and tables in reading order
                        all_elements = non_table_words + table_data
                        if all_elements:
                            clusters = pdfplumber.utils.cluster_objects(
                                all_elements, itemgetter('top'), tolerance=5
                            )
                            
                            lines = []
                            for cluster in clusters:
                                if cluster and 'text' in cluster[0]:
                                    try:
                                        line_text = ' '.join(item['text'] for item in cluster)
                                        lines.append(line_text)
                                    except (KeyError, TypeError):
                                        continue
                                elif cluster and 'table' in cluster[0]:
                                    table_json = json.dumps(cluster[0]['table'])
                                    lines.append(table_json)
                            
                            page_text = ' '.join(lines)
                            if len(page_text.split()) >= self.config.min_text_length:
                                pages_data.append((f"Page {page_num}", page_text))
                                
                    except Exception as e:
                        logger.warning(f"Error processing page {page_num} of {pdf_path.name}: {e}")
                        continue
                        
        except Exception as e:
            logger.error(f"Error opening PDF {pdf_path.name}: {e}")
            return []
            
        return pages_data
    
    def process_directory(self, directory_path: str) -> pd.DataFrame:
        """Process all PDFs in directory with progress tracking"""
        pdf_dir = Path(directory_path)
        if not pdf_dir.exists():
            raise FileNotFoundError(f"Directory not found: {directory_path}")
            
        pdf_files = list(pdf_dir.glob("*.pdf"))
        if not pdf_files:
            logger.warning(f"No PDF files found in {directory_path}")
            return pd.DataFrame()
        
        all_data = []
        
        with tqdm(pdf_files, desc="Processing PDFs") as pbar:
            for pdf_path in pbar:
                pbar.set_postfix({"Current": pdf_path.name})
                
                try:
                    pages_data = self.extract_text_from_pdf(pdf_path)
                    
                    for page_no, page_text in pages_data:
                        all_data.append({
                            'Page_No': page_no,
                            'Page_Text': page_text,
                            'Document_Name': pdf_path.name,
                            'Text_Length': len(page_text.split()),
                            'Policy_Name': pdf_path.stem,
                            'Source_Path': str(pdf_path)
                        })
                        
                except Exception as e:
                    logger.error(f"Failed to process {pdf_path.name}: {e}")
                    continue
        
        if not all_data:
            logger.warning("No data extracted from PDFs")
            return pd.DataFrame()
            
        df = pd.DataFrame(all_data)
        
        # Filter out pages with insufficient content
        initial_count = len(df)
        df = df[df['Text_Length'] >= self.config.min_text_length]
        filtered_count = len(df)
        
        logger.info(f"Processed {len(pdf_files)} PDFs, extracted {initial_count} pages, "
                   f"kept {filtered_count} pages after filtering")
        
        return df

# Initialize PDF processor
pdf_processor = PDFProcessor(config)
print("✅ PDF processor initialized!")

## 3. Vector Database Setup and Configuration

In [None]:
class VectorDatabaseManager:
    """Manages ChromaDB collections with connection pooling and error handling"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
        self.client = None
        self.embedding_function = None
        self.main_collection = None
        self.cache_collection = None
        
    def initialize_client(self):
        """Initialize ChromaDB client with persistence"""
        try:
            # Create persist directory if it doesn't exist
            persist_dir = Path(self.config.chroma_persist_directory)
            persist_dir.mkdir(parents=True, exist_ok=True)
            
            self.client = chromadb.PersistentClient(path=str(persist_dir))
            
            # Initialize embedding function
            self.embedding_function = OpenAIEmbeddingFunction(
                api_key=self.config.openai_api_key,
                model_name=self.config.embedding_model
            )
            
            logger.info("ChromaDB client initialized successfully")
            
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB client: {e}")
            raise
    
    def get_or_create_collections(self):
        """Create or retrieve existing collections"""
        if not self.client:
            self.initialize_client()
            
        try:
            # Main documents collection
            self.main_collection = self.client.get_or_create_collection(
                name=self.config.main_collection_name,
                embedding_function=self.embedding_function
            )
            
            # Cache collection for query optimization
            self.cache_collection = self.client.get_or_create_collection(
                name=self.config.cache_collection_name,
                embedding_function=self.embedding_function
            )
            
            logger.info(f"Collections initialized: "
                       f"Main({self.main_collection.count()} docs), "
                       f"Cache({self.cache_collection.count()} queries)")
            
        except Exception as e:
            logger.error(f"Failed to create collections: {e}")
            raise
    
    def add_documents_batch(self, documents_df: pd.DataFrame, batch_size: int = 100):
        """Add documents to collection in batches for efficiency"""
        if self.main_collection is None:
            self.get_or_create_collections()
        
        documents = documents_df['Page_Text'].tolist()
        metadatas = documents_df[['Policy_Name', 'Page_No', 'Document_Name', 'Text_Length']].to_dict('records')
        ids = [f"{row['Policy_Name']}_{row['Page_No']}" for _, row in documents_df.iterrows()]
        
        # Process in batches to avoid memory issues
        total_batches = len(documents) // batch_size + (1 if len(documents) % batch_size else 0)
        
        with tqdm(total=total_batches, desc="Adding documents to vector DB") as pbar:
            for i in range(0, len(documents), batch_size):
                batch_docs = documents[i:i + batch_size]
                batch_metadata = metadatas[i:i + batch_size]
                batch_ids = ids[i:i + batch_size]
                
                try:
                    self.main_collection.add(
                        documents=batch_docs,
                        metadatas=batch_metadata,
                        ids=batch_ids
                    )
                    pbar.update(1)
                    
                except Exception as e:
                    logger.error(f"Error adding batch {i//batch_size + 1}: {e}")
                    continue
        
        final_count = self.main_collection.count()
        logger.info(f"Successfully added documents. Total in collection: {final_count}")
        
        return final_count
    
    def check_collection_status(self):
        """Check the status of collections"""
        if not self.client:
            return "Client not initialized"
            
        try:
            collections = self.client.list_collections()
            status = {
                "total_collections": len(collections),
                "collection_names": [col.name for col in collections]
            }
            
            if self.main_collection:
                status["main_collection_count"] = self.main_collection.count()
            if self.cache_collection:
                status["cache_collection_count"] = self.cache_collection.count()
                
            return status
            
        except Exception as e:
            logger.error(f"Error checking collection status: {e}")
            return {"error": str(e)}

# Initialize vector database manager
vector_db = VectorDatabaseManager(config)
vector_db.initialize_client()
vector_db.get_or_create_collections()

print("✅ Vector database manager initialized!")
print(f"Status: {vector_db.check_collection_status()}")

## 4. Semantic Search with Intelligent Caching

In [None]:
class SemanticSearchEngine:
    """Optimized semantic search with intelligent caching and performance metrics"""
    
    def __init__(self, vector_db: VectorDatabaseManager, config: RAGConfig):
        self.vector_db = vector_db
        self.config = config
        self.search_metrics = {
            "total_searches": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "average_search_time": 0
        }
    
    def search_cache(self, query: str) -> Optional[Dict]:
        """Search for similar queries in cache"""
        try:
            cache_results = self.vector_db.cache_collection.query(
                query_texts=[query],
                n_results=1
            )
            
            if (cache_results['distances'] and 
                cache_results['distances'][0] and 
                cache_results['distances'][0][0] <= self.config.similarity_threshold):
                
                self.search_metrics["cache_hits"] += 1
                logger.info(f"Cache hit for query: {query[:50]}...")
                return self._parse_cache_results(cache_results)
            else:
                self.search_metrics["cache_misses"] += 1
                return None
                
        except Exception as e:
            logger.error(f"Cache search error: {e}")
            return None
    
    def _parse_cache_results(self, cache_results: Dict) -> Dict:
        """Parse cached search results"""
        try:
            metadata = cache_results['metadatas'][0][0]
            
            # Reconstruct results from cache metadata
            results = {
                'documents': [],
                'metadatas': [],
                'distances': [],
                'ids': []
            }
            
            for key, value in metadata.items():
                if key.startswith('documents'):
                    results['documents'].append(value)
                elif key.startswith('metadatas'):
                    results['metadatas'].append(eval(value))  # Note: Use json.loads in production
                elif key.startswith('distances'):
                    results['distances'].append(float(value))
                elif key.startswith('ids'):
                    results['ids'].append(value)
            
            return results
            
        except Exception as e:
            logger.error(f"Error parsing cache results: {e}")
            return None
    
    def search_main_collection(self, query: str) -> Dict:
        """Search main document collection"""
        try:
            results = self.vector_db.main_collection.query(
                query_texts=[query],
                n_results=self.config.max_search_results
            )
            
            # Cache the results for future use
            self._cache_search_results(query, results)
            
            return results
            
        except Exception as e:
            logger.error(f"Main collection search error: {e}")
            return {}
    
    def _cache_search_results(self, query: str, results: Dict):
        """Cache search results for future queries"""
        try:
            # Prepare metadata for caching
            cache_metadata = {}
            
            for i, (doc, meta, dist, doc_id) in enumerate(zip(
                results.get('documents', [[]])[0],
                results.get('metadatas', [[]])[0], 
                results.get('distances', [[]])[0],
                results.get('ids', [[]])[0]
            )):
                cache_metadata[f'documents{i}'] = doc
                cache_metadata[f'metadatas{i}'] = str(meta)
                cache_metadata[f'distances{i}'] = str(dist)
                cache_metadata[f'ids{i}'] = doc_id
            
            # Add to cache collection
            self.vector_db.cache_collection.add(
                documents=[query],
                ids=[f"query_{hash(query)}"],
                metadatas=[cache_metadata]
            )
            
        except Exception as e:
            logger.error(f"Error caching results: {e}")
    
    def search(self, query: str) -> pd.DataFrame:
        """Perform semantic search with caching"""
        import time
        start_time = time.time()
        
        self.search_metrics["total_searches"] += 1
        
        # Check cache first
        cached_results = self.search_cache(query)
        
        if cached_results:
            results = cached_results
            logger.info("Returning cached results")
        else:
            # Search main collection
            results = self.search_main_collection(query)
            logger.info("Returning fresh search results")
        
        # Convert to DataFrame for easier handling
        if results and 'documents' in results and results['documents']:
            results_df = pd.DataFrame({
                'Documents': results['documents'][0],
                'Metadatas': results['metadatas'][0],
                'Distances': results['distances'][0],
                'IDs': results['ids'][0]
            })
        else:
            logger.warning("No search results found")
            results_df = pd.DataFrame()
        
        # Update metrics
        search_time = time.time() - start_time
        self.search_metrics["average_search_time"] = (
            (self.search_metrics["average_search_time"] * (self.search_metrics["total_searches"] - 1) + search_time) 
            / self.search_metrics["total_searches"]
        )
        
        return results_df
    
    def get_search_metrics(self) -> Dict:
        """Get search performance metrics"""
        if self.search_metrics["total_searches"] > 0:
            cache_hit_rate = (self.search_metrics["cache_hits"] / 
                            self.search_metrics["total_searches"]) * 100
        else:
            cache_hit_rate = 0
            
        return {
            **self.search_metrics,
            "cache_hit_rate": f"{cache_hit_rate:.2f}%"
        }

# Initialize semantic search engine
search_engine = SemanticSearchEngine(vector_db, config)
print("✅ Semantic search engine initialized!")

## 5. Cross-Encoder Re-ranking Implementation

In [None]:
class CrossEncoderReranker:
    """Cross-encoder re-ranking with batch processing and performance comparison"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
        self.model = None
        self.load_model()
        
    def load_model(self):
        """Load cross-encoder model with error handling"""
        try:
            logger.info(f"Loading cross-encoder model: {self.config.cross_encoder_model}")
            self.model = CrossEncoder(self.config.cross_encoder_model)
            logger.info("Cross-encoder model loaded successfully")
            
        except Exception as e:
            logger.error(f"Failed to load cross-encoder model: {e}")
            raise
    
    def rerank_results(self, query: str, search_results: pd.DataFrame) -> pd.DataFrame:
        """Re-rank search results using cross-encoder"""
        if search_results.empty or self.model is None:
            return search_results
            
        try:
            # Prepare query-document pairs
            query_doc_pairs = [
                [query, doc] for doc in search_results['Documents']
            ]
            
            # Get re-ranking scores
            rerank_scores = self.model.predict(query_doc_pairs)
            
            # Add scores to dataframe
            results_with_rerank = search_results.copy()
            results_with_rerank['Rerank_Score'] = rerank_scores
            
            # Sort by re-ranking score (highest first)
            results_with_rerank = results_with_rerank.sort_values(
                'Rerank_Score', ascending=False
            ).reset_index(drop=True)
            
            return results_with_rerank
            
        except Exception as e:
            logger.error(f"Re-ranking failed: {e}")
            return search_results
    
    def compare_rankings(self, query: str, search_results: pd.DataFrame) -> Dict:
        """Compare semantic search vs re-ranked results"""
        if search_results.empty:
            return {"error": "No results to compare"}
            
        try:
            # Get top results from semantic search (by distance)
            semantic_top = search_results.nsmallest(self.config.top_k_rerank, 'Distances')
            
            # Get re-ranked results
            reranked_results = self.rerank_results(query, search_results)
            rerank_top = reranked_results.head(self.config.top_k_rerank)
            
            # Calculate overlap
            semantic_ids = set(semantic_top['IDs'])
            rerank_ids = set(rerank_top['IDs'])
            overlap = len(semantic_ids.intersection(rerank_ids))
            
            comparison = {
                "semantic_top_distances": semantic_top['Distances'].tolist(),
                "rerank_top_scores": rerank_top['Rerank_Score'].tolist() if 'Rerank_Score' in rerank_top.columns else [],
                "overlap_count": overlap,
                "overlap_percentage": (overlap / self.config.top_k_rerank) * 100,
                "semantic_ids": list(semantic_ids),
                "rerank_ids": list(rerank_ids)
            }
            
            return comparison
            
        except Exception as e:
            logger.error(f"Ranking comparison failed: {e}")
            return {"error": str(e)}
    
    def get_top_results(self, query: str, search_results: pd.DataFrame) -> pd.DataFrame:
        """Get top K results after re-ranking"""
        reranked = self.rerank_results(query, search_results)
        top_results = reranked.head(self.config.top_k_rerank)
        
        # Select relevant columns for RAG
        if not top_results.empty:
            return top_results[['Documents', 'Metadatas', 'Rerank_Score']].copy()
        else:
            return pd.DataFrame()

# Initialize cross-encoder reranker
reranker = CrossEncoderReranker(config)
print("✅ Cross-encoder reranker initialized!")

## 6. Retrieval Augmented Generation System

In [None]:
class RAGGenerator:
    """Comprehensive RAG system with OpenAI integration and optimized prompting"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
        self.system_prompt = self._create_system_prompt()
        
    def _create_system_prompt(self) -> str:
        """Create optimized system prompt for insurance domain"""
        return """You are a highly knowledgeable insurance specialist assistant. Your role is to provide accurate, helpful, and comprehensive answers about insurance policies and documents.

Key Guidelines:
1. Always base your answers on the provided document context
2. Provide specific details including numbers, percentages, and policy terms when available
3. If information is incomplete, clearly state what's missing and suggest where to find it
4. Use clear, professional language that customers can understand
5. Always include proper citations with policy names and page numbers
6. If tables are present in the context, format them clearly in your response
7. Focus only on information relevant to the user's question"""

    def _create_user_prompt(self, query: str, context_documents: pd.DataFrame) -> str:
        """Create optimized user prompt with context"""
        if context_documents.empty:
            return f"Query: {query}\n\nNo relevant documents found. Please inform the user that no information is available for their query."
            
        # Format context documents
        context_text = ""
        for idx, row in context_documents.iterrows():
            metadata = row['Metadatas']
            policy_name = metadata.get('Policy_Name', 'Unknown Policy')
            page_no = metadata.get('Page_No', 'Unknown Page')
            
            context_text += f"\n--- Document {idx + 1} ---\n"
            context_text += f"Source: {policy_name}, {page_no}\n"
            context_text += f"Content: {row['Documents'][:2000]}...\n"  # Limit context length
        
        user_prompt = f"""Query: {query}

Context Documents:
{context_text}

Instructions:
1. Answer the user's query using ONLY the information from the context documents above
2. Include specific details, numbers, and policy terms when available
3. If tables are mentioned, reformat them clearly
4. Provide citations in the format: [Policy Name - Page Number]
5. If the query cannot be fully answered with the available context, state what information is missing
6. Be concise but comprehensive

Response:"""
        
        return user_prompt
    
    def generate_response(self, query: str, context_documents: pd.DataFrame) -> Dict[str, Any]:
        """Generate response using OpenAI with error handling and retries"""
        try:
            user_prompt = self._create_user_prompt(query, context_documents)
            
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            
            # Call OpenAI API with retry logic
            response = self._call_openai_with_retry(messages)
            
            # Parse and format response
            return self._format_response(response, query, context_documents)
            
        except Exception as e:
            logger.error(f"Response generation failed: {e}")
            return {
                "response": f"I apologize, but I encountered an error while processing your query: {str(e)}",
                "query": query,
                "sources": [],
                "error": str(e)
            }
    
    def _call_openai_with_retry(self, messages: List[Dict], max_retries: int = 3) -> str:
        """Call OpenAI API with retry logic"""
        for attempt in range(max_retries):
            try:
                response = openai.chat.completions.create(
                    model=self.config.chat_model,
                    messages=messages,
                    temperature=0.3,
                    max_tokens=1500
                )
                return response.choices[0].message.content
                
            except Exception as e:
                if attempt == max_retries - 1:
                    raise e
                logger.warning(f"OpenAI API call failed (attempt {attempt + 1}): {e}")
                import time
                time.sleep(2 ** attempt)  # Exponential backoff
    
    def _format_response(self, response_text: str, query: str, context_docs: pd.DataFrame) -> Dict[str, Any]:
        """Format the response with metadata"""
        sources = []
        
        if not context_docs.empty:
            for _, row in context_docs.iterrows():
                metadata = row['Metadatas']
                sources.append({
                    "policy_name": metadata.get('Policy_Name', 'Unknown'),
                    "page_number": metadata.get('Page_No', 'Unknown'),
                    "relevance_score": row.get('Rerank_Score', row.get('Distances', 0))
                })
        
        return {
            "response": response_text,
            "query": query,
            "sources": sources,
            "source_count": len(sources),
            "timestamp": pd.Timestamp.now().isoformat()
        }

# Initialize RAG generator
rag_generator = RAGGenerator(config)
print("✅ RAG generator initialized!")

## 7. Complete RAG Pipeline Orchestrator

In [None]:
class EfficientRAGPipeline:
    """Complete RAG pipeline orchestrator with performance monitoring"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
        self.pdf_processor = PDFProcessor(config)
        self.vector_db = VectorDatabaseManager(config)
        self.search_engine = SemanticSearchEngine(self.vector_db, config)
        self.reranker = CrossEncoderReranker(config)
        self.rag_generator = RAGGenerator(config)
        
        self.pipeline_metrics = {
            "queries_processed": 0,
            "average_response_time": 0,
            "successful_responses": 0,
            "failed_responses": 0
        }
        
        # Initialize components
        self._initialize_pipeline()
    
    def _initialize_pipeline(self):
        """Initialize all pipeline components"""
        logger.info("Initializing RAG pipeline components...")
        
        try:
            # Initialize vector database
            self.vector_db.initialize_client()
            self.vector_db.get_or_create_collections()
            
            logger.info("✅ RAG pipeline initialized successfully!")
            
        except Exception as e:
            logger.error(f"Pipeline initialization failed: {e}")
            raise
    
    def process_documents(self, pdf_directory: str = None) -> bool:
        """Process PDFs and populate vector database"""
        try:
            pdf_dir = pdf_directory or self.config.pdf_directory
            
            if not Path(pdf_dir).exists():
                logger.error(f"PDF directory not found: {pdf_dir}")
                return False
            
            # Extract text from PDFs
            logger.info("Processing PDF documents...")
            documents_df = self.pdf_processor.process_directory(pdf_dir)
            
            if documents_df.empty:
                logger.warning("No documents processed")
                return False
            
            # Add to vector database
            logger.info("Adding documents to vector database...")
            doc_count = self.vector_db.add_documents_batch(documents_df)
            
            logger.info(f"✅ Successfully processed {len(documents_df)} pages from {pdf_dir}")
            return True
            
        except Exception as e:
            logger.error(f"Document processing failed: {e}")
            return False
    
    def query(self, user_query: str, enable_reranking: bool = True) -> Dict[str, Any]:
        """Process a complete query through the RAG pipeline"""
        import time
        start_time = time.time()
        
        try:
            self.pipeline_metrics["queries_processed"] += 1
            
            # Step 1: Semantic search with caching
            logger.info(f"Processing query: {user_query[:100]}...")
            search_results = self.search_engine.search(user_query)
            
            if search_results.empty:
                return {
                    "response": "I couldn't find any relevant information for your query. Please try rephrasing your question or check if the documents are properly loaded.",
                    "query": user_query,
                    "sources": [],
                    "search_results_count": 0,
                    "processing_time": time.time() - start_time
                }
            
            # Step 2: Re-ranking (optional)
            if enable_reranking:
                logger.info("Re-ranking search results...")
                top_results = self.reranker.get_top_results(user_query, search_results)
            else:
                # Use top results from semantic search
                top_results = search_results.head(self.config.top_k_rerank)
                top_results = top_results[['Documents', 'Metadatas']].copy()
            
            # Step 3: Generate response
            logger.info("Generating response...")
            response_data = self.rag_generator.generate_response(user_query, top_results)
            
            # Add pipeline metrics
            processing_time = time.time() - start_time
            response_data.update({
                "search_results_count": len(search_results),
                "processing_time": processing_time,
                "reranking_enabled": enable_reranking
            })
            
            # Update metrics
            self._update_metrics(processing_time, success=True)
            
            return response_data
            
        except Exception as e:
            logger.error(f"Query processing failed: {e}")
            self._update_metrics(time.time() - start_time, success=False)
            
            return {
                "response": f"I encountered an error while processing your query: {str(e)}",
                "query": user_query,
                "sources": [],
                "error": str(e),
                "processing_time": time.time() - start_time
            }
    
    def _update_metrics(self, processing_time: float, success: bool):
        """Update pipeline performance metrics"""
        if success:
            self.pipeline_metrics["successful_responses"] += 1
        else:
            self.pipeline_metrics["failed_responses"] += 1
        
        # Update average response time
        total_queries = self.pipeline_metrics["queries_processed"]
        current_avg = self.pipeline_metrics["average_response_time"]
        self.pipeline_metrics["average_response_time"] = (
            (current_avg * (total_queries - 1) + processing_time) / total_queries
        )
    
    def get_system_status(self) -> Dict[str, Any]:
        """Get comprehensive system status"""
        try:
            return {
                "pipeline_metrics": self.pipeline_metrics,
                "search_metrics": self.search_engine.get_search_metrics(),
                "vector_db_status": self.vector_db.check_collection_status(),
                "config": {
                    "embedding_model": self.config.embedding_model,
                    "chat_model": self.config.chat_model,
                    "similarity_threshold": self.config.similarity_threshold,
                    "max_search_results": self.config.max_search_results,
                    "top_k_rerank": self.config.top_k_rerank
                }
            }
        except Exception as e:
            return {"error": f"Failed to get system status: {e}"}
    
    def reset_metrics(self):
        """Reset all performance metrics"""
        self.pipeline_metrics = {
            "queries_processed": 0,
            "average_response_time": 0,
            "successful_responses": 0,
            "failed_responses": 0
        }
        self.search_engine.search_metrics = {
            "total_searches": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "average_search_time": 0
        }
        logger.info("Pipeline metrics reset")

# Initialize the complete RAG pipeline
rag_pipeline = EfficientRAGPipeline(config)
print("🚀 Complete RAG pipeline initialized and ready!")

## 8. Interactive Query Interface and Usage Examples

In [None]:
# Example: Process documents (replace with your PDF directory path)
# Uncomment and modify the path below to process your PDFs

# DEMO_PDF_DIRECTORY = "./sample_pdfs"  # Replace with your PDF directory
# success = rag_pipeline.process_documents(DEMO_PDF_DIRECTORY)

# For this demo, we'll show the system status
print("=== RAG Pipeline System Status ===")
status = rag_pipeline.get_system_status()

for section, data in status.items():
    print(f"\n{section.upper().replace('_', ' ')}:")
    if isinstance(data, dict):
        for key, value in data.items():
            print(f"  {key}: {value}")
    else:
        print(f"  {data}")

print(f"\n{'='*50}")
print("📋 To process your documents, update the DEMO_PDF_DIRECTORY path above")
print("🔍 Once documents are loaded, you can query the system using the examples below")

In [None]:
def display_query_response(response_data: Dict[str, Any]):
    """Display query response in a formatted way"""
    print("="*80)
    print("🤖 RAG SYSTEM RESPONSE")
    print("="*80)
    
    print(f"📝 QUERY: {response_data['query']}")
    print(f"⏱️  PROCESSING TIME: {response_data.get('processing_time', 0):.3f} seconds")
    print(f"📊 SEARCH RESULTS: {response_data.get('search_results_count', 0)} documents found")
    print(f"🔄 RERANKING: {'Enabled' if response_data.get('reranking_enabled', False) else 'Disabled'}")
    
    print("\n" + "="*80)
    print("💬 RESPONSE:")
    print("="*80)
    print(response_data['response'])
    
    if response_data.get('sources'):
        print("\n" + "="*80)
        print("📚 SOURCES:")
        print("="*80)
        for i, source in enumerate(response_data['sources'], 1):
            score = source.get('relevance_score', 'N/A')
            print(f"{i}. {source['policy_name']} - {source['page_number']} (Score: {score})")
    
    if response_data.get('error'):
        print(f"\n⚠️  ERROR: {response_data['error']}")
    
    print("="*80)

def query_system(query: str, enable_reranking: bool = True):
    """Query the RAG system and display results"""
    response = rag_pipeline.query(query, enable_reranking=enable_reranking)
    display_query_response(response)
    return response

# Example usage:
print("🔧 Interactive Query System Ready!")
print("💡 Use: query_system('Your question here') to ask questions")
print("📖 Example queries:")
print("   - 'What are the premium rates for different age groups?'")
print("   - 'Does the policy cover pre-existing conditions?'")
print("   - 'What is the claim process for this insurance?'")
print("   - 'What are the exclusions in this policy?'")

# Uncomment to try a sample query:
# query_system("What are the premium rates for different age groups?")

In [None]:
def compare_search_methods(query: str):
    """Compare semantic search vs reranked results"""
    print(f"🔍 COMPARING SEARCH METHODS FOR: {query}")
    print("="*80)
    
    # Get search results
    search_results = rag_pipeline.search_engine.search(query)
    
    if search_results.empty:
        print("❌ No search results found")
        return
    
    # Compare rankings
    comparison = rag_pipeline.reranker.compare_rankings(query, search_results)
    
    print(f"📊 SEMANTIC SEARCH TOP {config.top_k_rerank}:")
    semantic_top = search_results.nsmallest(config.top_k_rerank, 'Distances')
    for i, (_, row) in enumerate(semantic_top.iterrows(), 1):
        print(f"  {i}. Distance: {row['Distances']:.4f} | {row['Metadatas']['Policy_Name']} - {row['Metadatas']['Page_No']}")
    
    print(f"\n🎯 RERANKED TOP {config.top_k_rerank}:")
    reranked = rag_pipeline.reranker.rerank_results(query, search_results)
    reranked_top = reranked.head(config.top_k_rerank)
    for i, (_, row) in enumerate(reranked_top.iterrows(), 1):
        score = row.get('Rerank_Score', 'N/A')
        print(f"  {i}. Score: {score:.4f} | {row['Metadatas']['Policy_Name']} - {row['Metadatas']['Page_No']}")
    
    print(f"\n📈 COMPARISON METRICS:")
    print(f"  Overlap: {comparison.get('overlap_count', 0)}/{config.top_k_rerank} ({comparison.get('overlap_percentage', 0):.1f}%)")
    
    print("="*80)

def benchmark_system(queries: List[str], iterations: int = 1):
    """Benchmark system performance with multiple queries"""
    print(f"⚡ BENCHMARKING SYSTEM WITH {len(queries)} QUERIES ({iterations} iterations each)")
    print("="*80)
    
    results = []
    
    for i in range(iterations):
        print(f"\nIteration {i+1}/{iterations}")
        for query in queries:
            result = rag_pipeline.query(query, enable_reranking=True)
            results.append({
                'query': query,
                'processing_time': result.get('processing_time', 0),
                'success': 'error' not in result,
                'search_results': result.get('search_results_count', 0)
            })
    
    # Calculate statistics
    processing_times = [r['processing_time'] for r in results]
    success_rate = (sum(r['success'] for r in results) / len(results)) * 100
    
    print(f"\n📊 BENCHMARK RESULTS:")
    print(f"  Total Queries: {len(results)}")
    print(f"  Success Rate: {success_rate:.1f}%")
    print(f"  Avg Processing Time: {np.mean(processing_times):.3f}s")
    print(f"  Min Processing Time: {np.min(processing_times):.3f}s")
    print(f"  Max Processing Time: {np.max(processing_times):.3f}s")
    print(f"  Std Processing Time: {np.std(processing_times):.3f}s")
    
    # Display system metrics
    print(f"\n🔧 CURRENT SYSTEM METRICS:")
    metrics = rag_pipeline.get_system_status()
    if 'search_metrics' in metrics:
        search_metrics = metrics['search_metrics']
        print(f"  Cache Hit Rate: {search_metrics.get('cache_hit_rate', 'N/A')}")
        print(f"  Total Searches: {search_metrics.get('total_searches', 0)}")
    
    print("="*80)

# Example benchmark queries
SAMPLE_QUERIES = [
    "What are the premium rates?",
    "What conditions are covered?",
    "How do I file a claim?",
    "What are the exclusions?",
    "What is the waiting period?"
]

print("🧪 Benchmarking Tools Ready!")
print("💡 Usage:")
print("   - compare_search_methods('Your query') - Compare semantic vs reranked results")
print("   - benchmark_system(SAMPLE_QUERIES) - Run performance benchmark")
print("   - rag_pipeline.get_system_status() - View detailed system status")

In [None]:
def maintain_system():
    """Perform system maintenance tasks"""
    print("🔧 PERFORMING SYSTEM MAINTENANCE")
    print("="*60)
    
    # Clear caches
    print("1. Clearing search cache...")
    rag_pipeline.search_engine.clear_cache()
    
    # Optimize vector database
    print("2. Optimizing vector database...")
    try:
        rag_pipeline.vector_db.collection.get()  # Force connection check
        print("   ✅ Vector database connection verified")
    except Exception as e:
        print(f"   ⚠️ Vector database issue: {e}")
    
    # Clean up temporary files
    print("3. Cleaning temporary files...")
    import tempfile
    import shutil
    temp_dir = tempfile.gettempdir()
    print(f"   Temp directory: {temp_dir}")
    
    # Check disk space
    print("4. Checking disk space...")
    import shutil
    total, used, free = shutil.disk_usage(config.vector_store_path)
    print(f"   Total: {total // (2**30)} GB")
    print(f"   Used: {used // (2**30)} GB")  
    print(f"   Free: {free // (2**30)} GB")
    
    print("✅ Maintenance completed!")
    print("="*60)

def export_system_config():
    """Export current system configuration"""
    import json
    
    config_dict = {
        'openai_model': config.openai_model,
        'embedding_model': config.embedding_model,
        'reranker_model': config.reranker_model,
        'chunk_size': config.chunk_size,
        'chunk_overlap': config.chunk_overlap,
        'top_k_search': config.top_k_search,
        'top_k_rerank': config.top_k_rerank,
        'batch_size': config.batch_size,
        'vector_store_path': config.vector_store_path,
        'enable_caching': config.enable_caching,
        'cache_size': config.cache_size,
        'log_level': config.log_level,
        'max_retries': config.max_retries,
        'retry_delay': config.retry_delay
    }
    
    config_json = json.dumps(config_dict, indent=2)
    print("📝 CURRENT SYSTEM CONFIGURATION:")
    print("="*50)
    print(config_json)
    print("="*50)
    
    return config_dict

def optimize_for_speed():
    """Optimize system for speed (may reduce accuracy)"""
    print("⚡ OPTIMIZING SYSTEM FOR SPEED")
    print("="*50)
    
    # Reduce chunk overlap
    config.chunk_overlap = 50
    print(f"✓ Reduced chunk overlap to {config.chunk_overlap}")
    
    # Reduce search results
    config.top_k_search = 3
    print(f"✓ Reduced search results to {config.top_k_search}")
    
    # Disable reranking for speed
    config.top_k_rerank = 0
    print("✓ Disabled reranking for maximum speed")
    
    # Increase cache size
    config.cache_size = 200
    print(f"✓ Increased cache size to {config.cache_size}")
    
    print("⚡ Speed optimization completed!")
    print("💡 Use optimize_for_accuracy() to restore accuracy settings")
    print("="*50)

def optimize_for_accuracy():
    """Optimize system for accuracy (may reduce speed)"""
    print("🎯 OPTIMIZING SYSTEM FOR ACCURACY")
    print("="*50)
    
    # Increase chunk overlap
    config.chunk_overlap = 200
    print(f"✓ Increased chunk overlap to {config.chunk_overlap}")
    
    # Increase search results
    config.top_k_search = 10
    print(f"✓ Increased search results to {config.top_k_search}")
    
    # Enable reranking
    config.top_k_rerank = 5
    print(f"✓ Enabled reranking with top {config.top_k_rerank} results")
    
    print("🎯 Accuracy optimization completed!")
    print("💡 Use optimize_for_speed() if you need faster responses")
    print("="*50)

print("🛠️ System Utilities Ready!")
print("💡 Available commands:")
print("   - maintain_system() - Perform maintenance tasks")
print("   - export_system_config() - View current configuration")
print("   - optimize_for_speed() - Optimize for faster responses")
print("   - optimize_for_accuracy() - Optimize for better accuracy")

## 🚀 Getting Started Guide

### 1. First Time Setup
```python
# Step 1: Update the PDF directory path
DEMO_PDF_DIRECTORY = r"C:\path\to\your\insurance\pdfs"  # Update this path!

# Step 2: Initialize the system (run this once)
rag_pipeline = EfficientRAGPipeline()
rag_pipeline.setup_system(DEMO_PDF_DIRECTORY)
```

### 2. Basic Usage
```python
# Query the system
response = query_system("What are the premium rates for health insurance?")

# Compare search methods
compare_search_methods("What conditions are covered?")

# Check system status
rag_pipeline.get_system_status()
```

### 3. Performance Optimization
```python
# For faster responses (less accuracy)
optimize_for_speed()

# For better accuracy (slower responses)
optimize_for_accuracy()

# System maintenance
maintain_system()
```

### 4. Troubleshooting
- **No results found**: Check if PDFs are processed correctly
- **Slow responses**: Try `optimize_for_speed()`
- **Poor accuracy**: Try `optimize_for_accuracy()`
- **Memory issues**: Reduce `batch_size` or `cache_size` in config
- **API errors**: Check OpenAI API key and internet connection

### 5. Key Features
- ✅ **Modular Architecture**: Easy to extend and maintain
- ✅ **Intelligent Caching**: Faster repeated queries
- ✅ **Batch Processing**: Efficient document processing
- ✅ **Error Handling**: Robust error recovery
- ✅ **Performance Metrics**: Built-in monitoring
- ✅ **Cross-encoder Reranking**: Improved relevance
- ✅ **Configurable**: Easy to tune for your needs

In [None]:
# 🔍 SYSTEM READINESS CHECK
print("🔍 CHECKING SYSTEM READINESS...")
print("="*60)

# Check if system is initialized
try:
    if 'rag_pipeline' in locals():
        print("✅ RAG Pipeline: Initialized")
        
        # Check components
        if hasattr(rag_pipeline, 'pdf_processor'):
            print("✅ PDF Processor: Ready")
        if hasattr(rag_pipeline, 'vector_db'):
            print("✅ Vector Database: Ready")
        if hasattr(rag_pipeline, 'search_engine'):
            print("✅ Search Engine: Ready")
        if hasattr(rag_pipeline, 'reranker'):
            print("✅ Reranker: Ready")
        if hasattr(rag_pipeline, 'generator'):
            print("✅ Generator: Ready")
            
        # Check if documents are loaded
        try:
            status = rag_pipeline.get_system_status()
            doc_count = status.get('vector_db', {}).get('document_count', 0)
            if doc_count > 0:
                print(f"✅ Documents Loaded: {doc_count} chunks")
            else:
                print("⚠️ No documents loaded yet")
        except:
            print("⚠️ Document status unknown")
            
    else:
        print("❌ RAG Pipeline: Not initialized")
        print("💡 Run the initialization cells above first!")
        
except Exception as e:
    print(f"❌ Error checking system: {e}")

print("\n🎯 NEXT STEPS:")
if 'rag_pipeline' not in locals():
    print("1. ⬆️ Run all cells above to initialize the system")
    print("2. 📝 Update DEMO_PDF_DIRECTORY with your PDF path")
    print("3. 🏃 Run rag_pipeline.setup_system(DEMO_PDF_DIRECTORY)")
elif 'doc_count' in locals() and doc_count == 0:
    print("1. 📝 Update DEMO_PDF_DIRECTORY with your PDF path")
    print("2. 🏃 Run rag_pipeline.setup_system(DEMO_PDF_DIRECTORY)")
else:
    print("1. 🚀 Start querying: query_system('Your question here')")
    print("2. 📊 Compare methods: compare_search_methods('Your query')")
    print("3. ⚙️ Optimize: optimize_for_speed() or optimize_for_accuracy()")

print("="*60)