In [None]:

# Install required packages
!pip install -q langchain langchain-community pymupdf faiss-cpu rank_bm25 transformers \
    sentence-transformers accelerate tqdm nltk rouge-score summarizers scikit-learn \
    pandas seaborn matplotlib
# Write your code to app.py

import faiss
import json
import torch
# Memory optimization imports
import gc
import torch
from tqdm import tqdm

#%%
### 1. Imports & Configuration
import os
import time
import numpy as np
import pandas as pd
import faiss
import json
import nltk
from typing import List, Dict, Tuple, Union, Optional, Any
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from langchain.schema import Document
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline as hf_pipeline
from transformers import T5Tokenizer, T5ForConditionalGeneration
from rouge_score import rouge_scorer
import textwrap
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from nltk.tokenize import sent_tokenize
import random
import re

# Download the NLTK resources used for sentence tokenization and text processing
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')  # Often useful for text processing
nltk.download('wordnet')   # For lemmatization if needed

# Check GPU availability and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seeds for reproducibility. The stdlib `random` module is imported
# by this file, so it is seeded alongside NumPy and PyTorch.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if device.type == "cuda":
    torch.cuda.manual_seed_all(42)

# Define global constants
MAX_SUMMARY_LENGTH = 512
CACHE_DIR = "./model_cache"  # download/cache directory handed to the model loaders
os.makedirs(CACHE_DIR, exist_ok=True)

#%%
### 2. Data Preprocessing & Preparation
class DataProcessor:
    """Handle data loading, preprocessing, and chunking with different strategies"""

    @staticmethod
    def load_and_chunk(file_paths: Union[str, List[str]],
                       chunking_strategy: str = "recursive",
                       chunk_size: int = 512,
                       chunk_overlap: int = 64,
                       force_reload: bool = False) -> List[Document]:
        """
        Process one or more PDFs into chunks with metadata using specified strategy

        Args:
            file_paths: Path to the PDF file or list of PDF file paths
            chunking_strategy: Strategy for chunking ('recursive', 'token', 'sentence', 'paragraph')
            chunk_size: Size of chunks
            chunk_overlap: Overlap between chunks
            force_reload: Whether to force reload from PDF instead of using cache

        Returns:
            List of Document objects
        """
        # Handle single file path as a list for consistency
        if isinstance(file_paths, str):
            file_paths = [file_paths]

        # Generate a unique cache key based on all file paths
        cache_key = "_".join([os.path.basename(fp) for fp in file_paths])
        cache_path = f"./cache_{cache_key}_{chunking_strategy}_{chunk_size}_{chunk_overlap}.json"

        # Try to load from cache if available
        if os.path.exists(cache_path) and not force_reload:
            print(f"Loading chunks from cache: {cache_path}")
            try:
                with open(cache_path, 'r', encoding='utf-8') as f:
                    cached_data = json.load(f)

                chunks = []
                for item in cached_data:
                    chunks.append(Document(
                        page_content=item['page_content'],
                        metadata=item['metadata']
                    ))
                print(f"Loaded {len(chunks)} chunks from cache")
                return chunks
            except (json.JSONDecodeError, KeyError) as e:
                print(f"Error loading cache file {cache_path}: {str(e)}. Regenerating chunks...")

        # Process each PDF and combine chunks
        all_chunks = []
        for file_path in file_paths:
            print(f"\nProcessing file {file_path} with {chunking_strategy} strategy, chunk_size={chunk_size}, overlap={chunk_overlap}")
            loader = PyMuPDFLoader(file_path)
            pages = loader.load()
            print(f"Loaded {len(pages)} pages from {file_path}")

            # Clean text (remove excessive whitespace, etc.)
            for page in pages:
                page.page_content = DataProcessor._clean_text(page.page_content)
                # Add source file metadata
                page.metadata['source_file'] = os.path.basename(file_path)

            # Apply chunking strategy
            if chunking_strategy == "recursive":
                splitter = RecursiveCharacterTextSplitter(
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    length_function=len,
                    add_start_index=True
                )
                chunks = splitter.split_documents(pages)

            elif chunking_strategy == "token":
                splitter = TokenTextSplitter(
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )
                chunks = splitter.split_documents(pages)

            elif chunking_strategy == "sentence":
                chunks = DataProcessor._sentence_based_chunking(pages, chunk_size, chunk_overlap)

            elif chunking_strategy == "paragraph":
                chunks = DataProcessor._paragraph_based_chunking(pages, chunk_size, chunk_overlap)

            else:
                raise ValueError(f"Unknown chunking strategy: {chunking_strategy}")

            print(f"Created {len(chunks)} chunks from {file_path}")
            all_chunks.extend(chunks)

        # Cache the combined results
        cache_data = []
        for chunk in all_chunks:
            cache_data.append({
                'page_content': chunk.page_content,
                'metadata': chunk.metadata
            })

        try:
            with open(cache_path, 'w', encoding='utf-8') as f:
                json.dump(cache_data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"Warning: Could not save cache file {cache_path}: {str(e)}")

        return all_chunks
    @staticmethod
    def _clean_text(text: str) -> str:
        """Clean text by removing excessive whitespace and normalizing"""
        # Replace multiple newlines with single newline
        text = re.sub(r'\n+', '\n', text)
        # Replace multiple spaces with single space
        text = re.sub(r' +', ' ', text)
        # Strip whitespace
        text = text.strip()
        return text

    @staticmethod
    def _sentence_based_chunking(pages: List[Document], chunk_size: int, chunk_overlap: int) -> List[Document]:
        """Create chunks based on sentence boundaries"""
        chunks = []
        for page in pages:
            sentences = sent_tokenize(page.page_content)
            current_chunk = []
            current_size = 0

            for sentence in sentences:
                sentence_size = len(sentence)

                if current_size + sentence_size > chunk_size and current_chunk:
                    # Create chunk from current sentences
                    chunk_text = " ".join(current_chunk)
                    chunks.append(Document(
                        page_content=chunk_text,
                        metadata={**page.metadata, 'chunk_strategy': 'sentence'}
                    ))

                    # Handle overlap by keeping some sentences
                    overlap_size = 0
                    overlap_chunk = []
                    for s in reversed(current_chunk):
                        if overlap_size + len(s) <= chunk_overlap:
                            overlap_chunk.insert(0, s)
                            overlap_size += len(s)
                        else:
                            break

                    current_chunk = overlap_chunk
                    current_size = overlap_size

                current_chunk.append(sentence)
                current_size += sentence_size

            # Add remaining text as final chunk
            if current_chunk:
                chunk_text = " ".join(current_chunk)
                chunks.append(Document(
                    page_content=chunk_text,
                    metadata={**page.metadata, 'chunk_strategy': 'sentence'}
                ))

        return chunks

    @staticmethod
    def _paragraph_based_chunking(pages: List[Document], chunk_size: int, chunk_overlap: int) -> List[Document]:
        """Create chunks based on paragraph boundaries"""
        chunks = []
        for page in pages:
            # Split by double newline to identify paragraphs
            paragraphs = re.split(r'\n\s*\n', page.page_content)
            paragraphs = [p for p in paragraphs if p.strip()]  # Remove empty paragraphs

            current_chunk = []
            current_size = 0

            for paragraph in paragraphs:
                paragraph_size = len(paragraph)

                if current_size + paragraph_size > chunk_size and current_chunk:
                    # Create chunk from current paragraphs
                    chunk_text = "\n\n".join(current_chunk)
                    chunks.append(Document(
                        page_content=chunk_text,
                        metadata={**page.metadata, 'chunk_strategy': 'paragraph'}
                    ))

                    # Handle overlap by keeping some paragraphs
                    overlap_size = 0
                    overlap_chunk = []
                    for p in reversed(current_chunk):
                        if overlap_size + len(p) <= chunk_overlap:
                            overlap_chunk.insert(0, p)
                            overlap_size += len(p)
                        else:
                            break

                    current_chunk = overlap_chunk
                    current_size = overlap_size

                current_chunk.append(paragraph)
                current_size += paragraph_size

            # Add remaining text as final chunk
            if current_chunk:
                chunk_text = "\n\n".join(current_chunk)
                chunks.append(Document(
                    page_content=chunk_text,
                    metadata={**page.metadata, 'chunk_strategy': 'paragraph'}
                ))

        return chunks

    @staticmethod
    def preprocess_for_query(text: str) -> str:
        """Preprocess query text"""
        return text.strip()

#%%
### 3. Advanced Retrieval System
class AdvancedRetriever:
    """
    Enhanced retrieval system with multiple retrieval methods:
    - Dense retrieval (embedding-based)
    - Sparse retrieval (BM25)
    - Hybrid retrieval (combining dense and sparse)
    - Re-ranking with cross-encoders
    - Contextual embedding
    """

    def __init__(self, chunks: List[Document], embedding_model: str = "BAAI/bge-small-en"):
        self.chunks = chunks
        print(f"Loading dense embedding model: {embedding_model}...")
        self.dense_model = SentenceTransformer(embedding_model, cache_folder=CACHE_DIR, device=str(device))

        # Load cross-encoder for re-ranking
        print("Loading cross-encoder for re-ranking...")
        self.cross_encoder = CrossEncoder(
            'cross-encoder/ms-marco-MiniLM-L-6-v2',
            cache_dir=CACHE_DIR,
            device=str(device)
        )

        self._create_indices()
        print(f"Initialized retriever with {len(self.chunks)} documents")

    def _create_indices(self):
        """Create all necessary indices for retrieval"""
        # Extract texts and prepare for indexing
        texts = [chunk.page_content for chunk in self.chunks]

        # Dense Index
        print("Creating dense embeddings...")
        self.embeddings = self.dense_model.encode(texts, show_progress_bar=True, device=str(device))

        # Use GPU-accelerated FAISS if available
        res = faiss.StandardGpuResources() if faiss.get_num_gpus() > 0 else None

        # Create indices
        print("Creating FAISS index...")
        dimension = self.embeddings.shape[1]

        # Create both L2 (Euclidean) and Inner Product (cosine) indices
        self.index_ip = faiss.IndexFlatIP(dimension)  # Inner product for cosine similarity
        self.index_l2 = faiss.IndexFlatL2(dimension)  # L2 distance

        # Move to GPU if available
        if res is not None:
            print("Using GPU acceleration for FAISS")
            self.index_ip = faiss.index_cpu_to_gpu(res, 0, self.index_ip)
            self.index_l2 = faiss.index_cpu_to_gpu(res, 0, self.index_l2)

        # Add embeddings to indices
        self.index_ip.add(self.embeddings.astype('float32'))
        self.index_l2.add(self.embeddings.astype('float32'))

        # Create HNSW indices for faster retrieval (approximate but faster)
        print("Creating HNSW index for faster retrieval...")
        self.index_hnsw = faiss.IndexHNSWFlat(dimension, 32)  # 32 neighbors per node
        self.index_hnsw.add(self.embeddings.astype('float32'))

        # Sparse Index (CPU only)
        print("Creating sparse BM25 index...")
        tokenized_corpus = [doc.split() for doc in texts]
        self.bm25_index = BM25Okapi(tokenized_corpus)

        # TF-IDF index for additional sparse retrieval
        print("Creating TF-IDF index...")
        from sklearn.feature_extraction.text import TfidfVectorizer
        self.tfidf_vectorizer = TfidfVectorizer(lowercase=True, stop_words='english')
        self.tfidf_matrix = self.tfidf_vectorizer.fit_transform(texts)

    def search(self,
               query: str,
               k: int = 5,
               method: str = "hybrid",
               rerank: bool = False,
               distance_metric: str = "cosine") -> Tuple[List[Document], float]:
        """
        Search with timing and detailed logging

        Args:
            query: Query string
            k: Number of results to return
            method: Retrieval method (dense, sparse, hybrid, hnsw, tfidf)
            rerank: Whether to apply cross-encoder reranking
            distance_metric: For dense retrieval (cosine, l2)

        Returns:
            Retrieved documents and elapsed time
        """
        start_time = time.time()
        query = DataProcessor.preprocess_for_query(query)

        if method == "dense":
            results = self._dense_search(query, k, distance_metric)
        elif method == "sparse":
            results = self._sparse_search(query, k)
        elif method == "hnsw":
            results = self._hnsw_search(query, k)
        elif method == "tfidf":
            results = self._tfidf_search(query, k)
        else:  # hybrid
            results = self._hybrid_search(query, k)

        # Apply reranking if requested
        if rerank and len(results) > 1:
            results = self._rerank_results(query, results, k)

        elapsed = time.time() - start_time

        # Print retrieval results
        print(f"\nRetrieved {len(results)} documents in {elapsed:.3f}s using {method} search" +
              f" with{'' if rerank else 'out'} reranking:")

        for i, doc in enumerate(results[:3], 1):  # Show top 3 for brevity
            page_number = doc.metadata.get('page', 'Unknown')
            print(f"[Doc {i}] Page {page_number}\n{textwrap.shorten(doc.page_content, width=200)}")
            print("-" * 80)

        return results, elapsed

    def _dense_search(self, query: str, k: int, distance_metric: str = "cosine") -> List[Document]:
        """Dense embedding-based search"""
        embedding = self.dense_model.encode([query], device=str(device))[0]
        embedding = embedding.astype('float32').reshape(1, -1)

        if distance_metric == "cosine":
            # Inner product search (for normalized vectors, equivalent to cosine similarity)
            _, indices = self.index_ip.search(embedding, k)
        else:  # l2
            # Euclidean distance search
            _, indices = self.index_l2.search(embedding, k)

        return [self.chunks[i] for i in indices[0] if i != -1]

    def _sparse_search(self, query: str, k: int) -> List[Document]:
        """BM25-based sparse search"""
        tokenized = query.split()
        scores = self.bm25_index.get_scores(tokenized)
        indices = np.argsort(scores)[::-1][:k]
        return [self.chunks[i] for i in indices]

    def _hnsw_search(self, query: str, k: int) -> List[Document]:
        """HNSW approximate nearest neighbor search (faster but may be less accurate)"""
        embedding = self.dense_model.encode([query], device=str(device))[0]
        _, indices = self.index_hnsw.search(np.array([embedding]).astype('float32'), k)
        return [self.chunks[i] for i in indices[0] if i != -1]

    def _tfidf_search(self, query: str, k: int) -> List[Document]:
        """Return the ``k`` chunks whose TF-IDF vectors are most cosine-similar to the query."""
        query_vec = self.tfidf_vectorizer.transform([query])
        from sklearn.metrics.pairwise import cosine_similarity
        # One row of similarities: the query against every indexed chunk
        similarities = cosine_similarity(query_vec, self.tfidf_matrix)[0]
        best_first = np.argsort(similarities)[::-1]
        return [self.chunks[pos] for pos in best_first[:k]]

    def _hybrid_search(self, query: str, k: int) -> List[Document]:
        """Hybrid search combining dense and sparse retrievals with RRF fusion"""
        # Get more results than needed for fusion
        dense_results = self._dense_search(query, k*2)
        sparse_results = self._sparse_search(query, k*2)

        # Reciprocal Rank Fusion
        rrf_scores = {}

        # Process dense results
        for rank, doc in enumerate(dense_results):
            idx = self.chunks.index(doc)
            rrf_scores[idx] = rrf_scores.get(idx, 0) + 1/(60 + rank)

        # Process sparse results
        for rank, doc in enumerate(sparse_results):
            idx = self.chunks.index(doc)
            rrf_scores[idx] = rrf_scores.get(idx, 0) + 1/(60 + rank)

        # Sort by RRF score and get top k
        sorted_indices = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)[:k]
        return [self.chunks[idx] for idx, _ in sorted_indices]

    def _rerank_results(self, query: str, results: List[Document], k: int) -> List[Document]:
        """Rerank results using cross-encoder"""
        if not results:
            return results

        # Create query-document pairs for reranking
        doc_texts = [doc.page_content for doc in results]
        pairs = [[query, doc] for doc in doc_texts]

        # Get cross-encoder scores
        scores = self.cross_encoder.predict(pairs)

        # Sort by score
        doc_score_pairs = list(zip(results, scores))
        reranked_results = [doc for doc, _ in sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)]

        return reranked_results[:k]  # Return top k after reranking

#%%
### 4. Document Context Enhancement
class ContextEnhancer:
    """
    Enhance retrieved contexts before sending to LLM:
    - Summarization
    - Context reordering
    - Highlighting key information
    - Removing redundancy
    """

    def __init__(self):
        # Load summarization model
        print("Loading summarization model...")
        self.summarizer = hf_pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            device_map="auto"
        )

    def enhance_context(self,
                       documents: List[Document],
                       query: str,
                       methods: List[str] = None) -> Tuple[str, Dict]:
        """
        Apply enhancement methods to retrieved documents

        Args:
            documents: List of retrieved documents
            query: Original query
            methods: List of enhancement methods to apply

        Returns:
            Enhanced context string and metadata about applied enhancements
        """
        if not methods:
            methods = []

        print(f"\nEnhancing context using methods: {', '.join(methods) if methods else 'None'}")

        # Start with original context
        context_texts = [doc.page_content for doc in documents]
        context = "\n\n".join(context_texts)

        metadata = {
            "original_length": len(context),
            "methods_applied": methods
        }

        # Apply enhancements in sequence
        if "reorder" in methods:
            context, reorder_meta = self._reorder_by_relevance(context_texts, query)
            metadata["reordering"] = reorder_meta

        if "summarize" in methods:
            context, summary_meta = self._summarize_context(context, query)
            metadata["summarization"] = summary_meta

        if "highlight" in methods:
            context, highlight_meta = self._highlight_key_info(context, query)
            metadata["highlighting"] = highlight_meta

        if "deduplicate" in methods:
            context, dedup_meta = self._remove_redundancy(context)
            metadata["deduplication"] = dedup_meta

        metadata["final_length"] = len(context)

        print(f"Context enhancement complete: {metadata['original_length']} chars → " +
              f"{metadata['final_length']} chars ({metadata['final_length']/metadata['original_length']:.1%})")

        return context, metadata

    def _summarize_context(self, context: str, query: str) -> Tuple[str, Dict]:
        """Generate concise summaries to reduce context length"""
        start_time = time.time()

        # For long contexts, chunk and summarize separately
        if len(context) > 1024:
            chunks = textwrap.wrap(context, 1024, break_long_words=False, break_on_hyphens=False)
            summaries = []

            for chunk in chunks:
                result = self.summarizer(
                    chunk,
                    max_length=min(len(chunk)//4, 256),
                    min_length=min(len(chunk)//8, 100),
                    do_sample=False
                )
                summaries.append(result[0]['summary_text'])

            summary = "\n\n".join(summaries)
        else:
            result = self.summarizer(
                context,
                max_length=min(len(context)//4, 256),
                min_length=min(len(context)//8, 100),
                do_sample=False
            )
            summary = result[0]['summary_text']

        elapsed = time.time() - start_time

        metadata = {
            "original_length": len(context),
            "summary_length": len(summary),
            "reduction_pct": 1 - (len(summary) / len(context)),
            "time_taken": elapsed
        }

        # Add prefix to make it clear this is a summary
        prefixed_summary = f"SUMMARY OF RETRIEVED INFORMATION:\n{summary}\n\nFULL CONTEXT:\n{context}"

        return prefixed_summary, metadata

    def _reorder_by_relevance(self, contexts: List[str], query: str) -> Tuple[str, Dict]:
        """Reorder context chunks by relevance to query"""
        start_time = time.time()

        # Use NLTK to break into sentences for more fine-grained reordering
        all_sentences = []
        sentence_to_chunk = {}

        for chunk_idx, chunk in enumerate(contexts):
            sentences = sent_tokenize(chunk)
            for sent in sentences:
                if len(sent.strip()) > 0:
                    all_sentences.append(sent)
                    sentence_to_chunk[sent] = chunk_idx

        # Create a relevance scorer using query-based heuristics
        def relevance_score(sentence, query):
            # Very simple relevance measure - count query terms in sentence
            score = 0
            query_terms = set(query.lower().split())
            for term in query_terms:
                if term in sentence.lower():
                    score += 1
            return score

        # Score and sort sentences
        sentence_scores = [(sent, relevance_score(sent, query)) for sent in all_sentences]
        sorted_sentences = sorted(sentence_scores, key=lambda x: x[1], reverse=True)

        # Reconstruct chunks in order of relevance
        chunk_relevance = {}
        for sent, score in sentence_scores:
            chunk_idx = sentence_to_chunk[sent]
            chunk_relevance[chunk_idx] = chunk_relevance.get(chunk_idx, 0) + score

        sorted_chunks = sorted([(idx, contexts[idx], rel) for idx, rel in chunk_relevance.items()],
                                key=lambda x: x[2], reverse=True)

        reordered_context = "\n\n".join([chunk for _, chunk, _ in sorted_chunks])

        elapsed = time.time() - start_time

        metadata = {
            "chunk_scores": {idx: score for idx, _, score in sorted_chunks},
            "time_taken": elapsed
        }

        return reordered_context, metadata

    def _highlight_key_info(self, context: str, query: str) -> Tuple[str, Dict]:
        """Highlight key information related to the query"""
        start_time = time.time()

        # Simple approach: highlight sentences containing query terms
        query_terms = set(query.lower().split())
        sentences = sent_tokenize(context)

        highlighted_sentences = []
        for sentence in sentences:
            should_highlight = any(term in sentence.lower() for term in query_terms)
            if should_highlight:
                highlighted_sentences.append(f"*IMPORTANT:* {sentence}")
            else:
                highlighted_sentences.append(sentence)

        highlighted_context = " ".join(highlighted_sentences)

        elapsed = time.time() - start_time

        metadata = {
            "highlights_added": sum(1 for s in highlighted_sentences if s.startswith("*IMPORTANT:*")),
            "time_taken": elapsed
        }

        return highlighted_context, metadata

    def _remove_redundancy(self, context: str) -> Tuple[str, Dict]:
        """Remove redundant information from context"""
        start_time = time.time()

        # Split into sentences
        sentences = sent_tokenize(context)

        # Simple deduplication: remove exact or near-duplicate sentences
        unique_sentences = []
        seen = set()

        for sentence in sentences:
            # Create a simplified fingerprint of the sentence
            fingerprint = re.sub(r'[^\w]', '', sentence.lower())

            # Keep only if we haven't seen a very similar fingerprint
            if fingerprint not in seen:
                unique_sentences.append(sentence)
                seen.add(fingerprint)

        deduplicated_context = " ".join(unique_sentences)

        elapsed = time.time() - start_time

        metadata = {
            "original_sentences": len(sentences),
            "deduplicated_sentences": len(unique_sentences),
            "reduction_pct": 1 - (len(unique_sentences) / len(sentences)),
            "time_taken": elapsed
        }

        return deduplicated_context, metadata

#%%
### 5. Generation Models
class AdvancedAnswerGenerator:
    """
    Enhanced answer generation with multiple model options and prompt strategies
    - Multiple model choices
    - Different prompt techniques
    - Chain-of-thought prompting
    - Few-shot prompting
    """

    def __init__(self):
        self._loaded_models = {}  # Cache for dynamically loaded models

    def _get_model(self, model_type: str) -> Any:
        """Load model on demand and cache it"""
        if model_type not in self._loaded_models:
            print(f"Loading {model_type} model...")

            if model_type == "qwen":
                self._loaded_models[model_type] = hf_pipeline(
                    "text-generation",
                    model="Qwen/Qwen1.5-0.5B-Chat",
                    device_map="auto",
                    torch_dtype="auto",
                    model_kwargs={"cache_dir": CACHE_DIR}
                )
            elif model_type == "flan-t5":
                self._loaded_models[model_type] = hf_pipeline(
                    "text2text-generation",
                    model="google/flan-t5-base",
                    device_map="auto",
                    model_kwargs={"cache_dir": CACHE_DIR}
                )
            elif model_type == "tiny-llama":
                self._loaded_models[model_type] = hf_pipeline(
                    "text-generation",
                    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                    device_map="auto",
                    torch_dtype="auto",
                    model_kwargs={"cache_dir": CACHE_DIR}
                )
            elif model_type == "phi-2":
                self._loaded_models[model_type] = hf_pipeline(
                    "text-generation",
                    model="microsoft/phi-2",
                    device_map="auto",
                    torch_dtype="auto",
                    model_kwargs={"cache_dir": CACHE_DIR}
                )
            else:
                raise ValueError(f"Unknown model type: {model_type}")

        return self._loaded_models[model_type]

    def generate(self,
                query: str,
                context: str,
                model_type: str = "qwen",
                prompt_strategy: str = "standard",
                max_new_tokens: int = 500,
                temperature: float = 0.7) -> Tuple[str, float]:
        """
        Generate answer with timing and different prompt strategies

        Args:
            query: User query
            context: Retrieved context
            model_type: Model to use for generation
            prompt_strategy: Prompt technique (standard, cot, few_shot)
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated answer and elapsed time
        """
        model = self._get_model(model_type)
        prompt = self._format_prompt(query, context, model_type, prompt_strategy)

        start_time = time.time()
        try:
            if model_type == "flan-t5":
                result = model(prompt, max_length=max_new_tokens)[0]['generated_text']
            else:
                result = model(
                    prompt,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature
                )[0]['generated_text']

                # Post-process to extract the actual answer
                if model_type in ["qwen", "tiny-llama"]:
                    # Extract content between assistant tags
                    match = re.search(r'<\|im_start\|>assistant\n(.*?)(?:<\|im_end\|>|$)', result, re.DOTALL)
                    if match:
                        result = match.group(1).strip()
                elif model_type == "phi-2":
                    # Extract content after the prompt
                    result = result.split("Answer:")[-1].strip()
        except Exception as e:
            print(f"Error generating answer: {str(e)}")
            result = f"Error generating answer with {model_type}: {str(e)}"

        elapsed = time.time() - start_time
        print(f"Generated answer in {elapsed:.3f}s using {model_type} with {prompt_strategy} prompting")

        # Log a preview of the answer
        print(f"Answer preview: {result[:100]}..." if len(result) > 100 else f"Answer: {result}")

        return result, elapsed

    def _format_prompt(self, query: str, context: str, model_type: str, strategy: str) -> str:
        """
        Format prompt according to model requirements and prompt strategy

        Strategies:
        - standard: Basic prompt with context and question
        - cot: Chain-of-thought prompting
        - few_shot: Few-shot examples
        """
        # Limit context length based on model type
        max_ctx_len = 4000  # Default
        if model_type == "flan-t5":
            max_ctx_len = 2000
        elif model_type in ["phi-2", "qwen", "tiny-llama"]:
            max_ctx_len = 3000

        if len(context) > max_ctx_len:
            context = context[:max_ctx_len] + "...(context truncated due to length)"

        # Model-specific formatting
        if model_type in ["qwen", "tiny-llama"]:
            if strategy == "cot":
                return f"""<|im_start|>system
You are an AI assistant that answers questions based on the provided context. Think step by step before providing your final answer.
Context: {context}<|im_end|>
<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant
Let me think through this step by step:
1. First, I'll understand what the question is asking.
2. Then, I'll search the provided context for relevant information.
3. Finally, I'll formulate a comprehensive answer based on the context.

"""
            elif strategy == "few_shot":
                return f"""<|im_start|>system
You are an AI assistant that answers questions based on the provided context.
Context: {context}<|im_end|>
<|im_start|>user
What services are available to students?<|im_end|>
<|im_start|>assistant
Based on the provided context, the services available to students include academic advising, counseling services, career services, library resources, and health services.<|im_end|>
<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant
"""
            else:  # standard
                return f"""<|im_start|>system
You are an AI assistant that answers questions based on the provided context.
Context: {context}<|im_end|>
<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant
"""
        elif model_type == "phi-2":
            if strategy == "cot":
                return f"""Context: {context}

Question: {query}

Let me think through this step by step to find the answer in the context.

Step 1: Understand what the question is asking.
Step 2: Look through the context for relevant information.
Step 3: Formulate my answer based on the context.

Answer:"""
            elif strategy == "few_shot":
                return f"""Context: {context}

Question: What services are available to students?
Answer: Based on the provided context, the services available to students include academic advising, counseling services, career services, library resources, and health services.

Question: {query}
Answer:"""
            else:  # standard
                return f"""Context: {context}

Question: {query}
Answer:"""
        else:  # flan-t5
            if strategy == "cot":
                return f"context: {context}\nquestion: {query}\nThink step by step to answer the question:"
            elif strategy == "few_shot":
                return f"context: {context}\nquestion: What services are available to students?\nanswer: Based on the provided context, the services available to students include academic advising, counseling services, career services, library resources, and health services.\nquestion: {query}\nanswer:"
            else:  # standard
                return f"context: {context}\nquestion: {query}\nanswer:"

#%%
### 6. Enhanced Evaluation System
class AdvancedRagEvaluator:
    """
    Enhanced evaluation system with multiple metrics:
    - Relevance
    - Faithfulness
    - Conciseness
    - Answer quality
    - Consistency
    - Latency
    """

    def __init__(self):
        """Prepare the evaluator without loading the judge model.

        The ROUGE scorer is created immediately; the text-generation
        pipeline used for LLM-based scoring stays unset until the
        ``evaluator`` property is first read.
        """
        print("Initializing advanced evaluator...")
        rouge_variants = ['rouge1', 'rouge2', 'rougeL']
        self.rouge_scorer = rouge_scorer.RougeScorer(rouge_variants, use_stemmer=True)
        self._evaluator = None  # filled in lazily by the `evaluator` property

    @property
    def evaluator(self):
        """Lazy-load evaluation model"""
        if self._evaluator is None:
            print("Loading evaluation model...")
            self._evaluator = hf_pipeline(
                "text-generation",
                model="Qwen/Qwen1.5-0.5B-Chat",
                device_map="auto",
                torch_dtype="auto",
                model_kwargs={"cache_dir": CACHE_DIR}
            )
        return self._evaluator

    def evaluate(self, answer: str, context: str, query: str, full_metrics: bool = True) -> Dict[str, float]:
        """
        Enhanced evaluation with multiple scoring metrics

        Args:
            answer: Generated answer
            context: Context used for generation
            query: Original query
            full_metrics: Whether to calculate all metrics (slower) or basic ones

        Returns:
            Dictionary of evaluation metrics
        """
        print("\nEvaluating answer quality...")
        metrics = {}

        # Always calculate basic metrics
        relevance_score = self._get_llm_score(
            f"Rate answer relevance to question (1-5, 5=most relevant):\nQuestion: {query}\nAnswer: {answer}\nScore:"
        )
        faithfulness_score = self._get_llm_score(
            f"Rate answer faithfulness to context (1-5, 5=most faithful):\nContext: {context[:2000]}...\nAnswer: {answer}\nScore:"
        )

        metrics["relevance"] = relevance_score
        metrics["faithfulness"] = faithfulness_score
        metrics["composite_score"] = (relevance_score + faithfulness_score) / 2

        # Calculate additional metrics if requested
        if full_metrics:
            # Conciseness
            conciseness_score = self._get_llm_score(
                f"Rate answer conciseness (1-5, 5=most concise):\nAnswer: {answer}\nScore:"
            )
            metrics["conciseness"] = conciseness_score

            # Completeness
            completeness_score = self._get_llm_score(
                f"Rate how completely the answer addresses the question (1-5, 5=most complete):\nQuestion: {query}\nAnswer: {answer}\nScore:"
            )
            metrics["completeness"] = completeness_score

            # ROUGE scores (content overlap with context)
            if context:
                rouge_scores = self.rouge_scorer.score(answer, context[:2000])
                metrics["rouge1"] = rouge_scores['rouge1'].fmeasure
                metrics["rouge2"] = rouge_scores['rouge2'].fmeasure
                metrics["rougeL"] = rouge_scores['rougeL'].fmeasure

            # Calculate overall quality
            metrics["overall_quality"] = (
                metrics["relevance"] +
                metrics["faithfulness"] +
                metrics["conciseness"] +
                metrics["completeness"]
            ) / 4

        # Log metrics
        print(f"Evaluation Scores:")
        for metric, value in metrics.items():
            if isinstance(value, float):
                print(f"  - {metric}: {value:.2f}")
            else:
                print(f"  - {metric}: {value}")

        return metrics



    def _get_llm_score(self, prompt: str) -> float:
        """Helper method to extract score from LLM evaluation"""
        try:
            response = self.evaluator(
                f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n",
                max_new_tokens=2,
                do_sample=False
            )[0]['generated_text'].strip()

            # Extract first digit from response
            for char in response:
                if char.isdigit():
                    return float(char)

            # If no digit found, use regex to find numbers like "4.5" or "3"
            matches = re.findall(r'\d+(\.\d+)?', response)
            if matches:
                score = float(matches[0])
                return min(5.0, max(1.0, score))  # Clamp between 1 and 5

            return 3.0  # Default score if no digit found

        except Exception as e:
            print(f"Error in evaluation: {str(e)}")
            return 3.0  # Default score on error

#%%
### 7. Complete RAG Pipeline with Comprehensive Grid Search
class AdvancedRagPipeline:
    """
    Memory-optimized RAG Pipeline with comprehensive grid search for optimal parameters:
    - Multiple chunking strategies
    - Multiple retrieval methods
    - Multiple prompt strategies
    - Multiple context enhancement methods
    - Multiple LLM models
    """

    def __init__(self, file_paths: List[str]):
        """Initialize the RAG pipeline with multiple PDF files"""
        self.file_paths = file_paths  # Store list of PDF paths
        self.retriever = None
        self._loaded_models = {}  # Cache for dynamically loaded models
        self._context_enhancer = None
        self._evaluator = None

    @property
    def context_enhancer(self):
        """Lazy-load context enhancer"""
        if self._context_enhancer is None:
            self._context_enhancer = ContextEnhancer()
            torch.cuda.empty_cache()
        return self._context_enhancer

    @property
    def evaluator(self):
        """Lazy-load evaluator"""
        if self._evaluator is None:
            self._evaluator = AdvancedRagEvaluator()
            torch.cuda.empty_cache()
        return self._evaluator

    def prepare_for_demo(self):
        """Prepare the pipeline for interactive demo"""
        if self.retriever is None:
            chunks = DataProcessor.load_and_chunk(
                self.file_paths,  # Use list of file paths
                chunking_strategy="recursive",
                chunk_size=512,
                chunk_overlap=64
            )
            self.retriever = AdvancedRetriever(chunks)

    def _get_generator_model(self, model_type):
        """Load model on demand"""
        key = f"generator_{model_type}"
        if key not in self._loaded_models:
            print(f"Loading {model_type} model...")

            if model_type == "qwen":
                self._loaded_models[key] = hf_pipeline(
                    "text-generation",
                    model="Qwen/Qwen1.5-0.5B-Chat",
                    device_map="auto",
                    torch_dtype="auto",
                    model_kwargs={"cache_dir": CACHE_DIR}
                )
            elif model_type == "flan-t5":
                self._loaded_models[key] = hf_pipeline(
                    "text2text-generation",
                    model="google/flan-t5-base",
                    device_map="auto",
                    model_kwargs={"cache_dir": CACHE_DIR}
                )
            elif model_type == "tiny-llama":
                self._loaded_models[key] = hf_pipeline(
                    "text-generation",
                    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                    device_map="auto",
                    torch_dtype="auto",
                    model_kwargs={"cache_dir": CACHE_DIR}
                )
            elif model_type == "phi-2":
                self._loaded_models[key] = hf_pipeline(
                    "text-generation",
                    model="microsoft/phi-2",
                    device_map="auto",
                    torch_dtype="auto",
                    model_kwargs={"cache_dir": CACHE_DIR}
                )
            else:
                raise ValueError(f"Unknown model type: {model_type}")

            torch.cuda.empty_cache()

        return self._loaded_models[key]

    def experiment(self, config: Dict) -> pd.DataFrame:
        """Run grid search experiment in memory-efficient batches"""
        test_queries = config.get("test_queries", [])

        # Calculate total configurations
        total_configs = self._calculate_total_configs(config)
        print(f"\nStarting grid search with {total_configs} configurations...")

        # Store results
        results = []
        experiment_start = time.time()

        # Process in batches by chunking strategy
        for chunking_strategy in config["chunking_strategies"]:
            for chunk_size in config["chunk_sizes"]:
                for chunk_overlap in config["chunk_overlaps"]:
                    # Process documents for this chunk configuration
                    chunks = DataProcessor.load_and_chunk(
                        self.file_paths,  # Use list of file paths
                        chunking_strategy=chunking_strategy,
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
                    self.retriever = AdvancedRetriever(chunks)

                    # Run configuration batches with this retriever
                    batch_results = self._run_chunking_batch(
                        config=config,
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap,
                        chunking_strategy=chunking_strategy
                    )
                    results.extend(batch_results)

                    # Clear memory
                    self.retriever = None
                    torch.cuda.empty_cache()
                    gc.collect()

                    # Save batch results to disk to free memory
                    batch_df = pd.DataFrame(batch_results)
                    batch_counter = len(results) // len(batch_results)
                    batch_df.to_csv(f"rag_results_batch_{batch_counter}.csv", index=False)

        total_experiment_time = time.time() - experiment_start
        print(f"\nExperiment completed in {total_experiment_time:.2f} seconds")

        # Combine all batch results
        results_df = self._combine_batch_results()
        self._combine_batch_results()

        return results_df
    def _run_chunking_batch(self, config, chunk_size, chunk_overlap, chunking_strategy):
        """Run all configurations for a specific chunking strategy"""
        batch_results = []

        # Calculate total configs for this batch
        total_batch_configs = (
            len(config["retrieval_methods"]) *
            len(config["reranking_options"]) *
            len(config["context_enhancers"]) *
            len(config["prompt_strategies"]) *
            len(config["models"]) *
            len(config["top_k_values"]) *
            len(config["test_queries"])
        )

        with tqdm(total=total_batch_configs, desc=f"Running {chunking_strategy} chunks") as pbar:
            # Loop through parameter combinations for this chunking strategy
            for method in config["retrieval_methods"]:
                for rerank in config["reranking_options"]:
                    for context_enhancers in config["context_enhancers"]:
                        for prompt_strategy in config["prompt_strategies"]:
                            # Group by model to avoid loading/unloading too frequently
                            for model_type in config["models"]:
                                # Load model once for this batch
                                model = self._get_generator_model(model_type)

                                for top_k in config["top_k_values"]:
                                    for query in config["test_queries"]:
                                        result = self._run_configuration(
                                            query=query,
                                            chunk_size=chunk_size,
                                            chunk_overlap=chunk_overlap,
                                            chunking_strategy=chunking_strategy,
                                            method=method,
                                            rerank=rerank,
                                            context_enhancers=context_enhancers,
                                            prompt_strategy=prompt_strategy,
                                            model_type=model_type,
                                            top_k=top_k
                                        )
                                        batch_results.append(result)
                                        pbar.update(1)

                                # Explicitly clean up after all configs for this model
                                torch.cuda.empty_cache()
                                gc.collect()

        return batch_results

    def _calculate_total_configs(self, config):
        """Calculate total number of configurations"""
        return (
            len(config["chunk_sizes"]) *
            len(config["chunk_overlaps"]) *
            len(config["chunking_strategies"]) *
            len(config["retrieval_methods"]) *
            len(config["reranking_options"]) *
            len(config["context_enhancers"]) *
            len(config["prompt_strategies"]) *
            len(config["models"]) *
            len(config["top_k_values"]) *
            len(config["test_queries"])
        )

    def _run_configuration(self,
                          query: str,
                          chunk_size: int,
                          chunk_overlap: int,
                          chunking_strategy: str,
                          method: str,
                          rerank: bool,
                          context_enhancers: List[str],
                          prompt_strategy: str,
                          model_type: str,
                          top_k: int) -> Dict:
        """
        Execute single configuration run with full metrics

        Args:
            query: User query
            chunk_size: Size of chunks
            chunk_overlap: Overlap between chunks
            chunking_strategy: Strategy for chunking
            method: Retrieval method
            rerank: Whether to apply reranking
            context_enhancers: List of context enhancement methods
            prompt_strategy: Prompt strategy
            model_type: Model for generation
            top_k: Number of documents to retrieve

        Returns:
            Dictionary with results
        """
        print(f"\n{'-'*80}")
        print(f"Running configuration:")
        print(f"- Query: {query}")
        print(f"- Chunking: {chunking_strategy} (size={chunk_size}, overlap={chunk_overlap})")
        print(f"- Retrieval: {method} (rerank={rerank}, top_k={top_k})")
        print(f"- Enhancement: {context_enhancers}")
        print(f"- Generation: {model_type} with {prompt_strategy} prompting")
        print(f"{'-'*80}")

        # Track timing
        phase_times = {}

        # 1. Retrieval phase
        start_time = time.time()
        contexts, retrieval_time = self.retriever.search(
            query=query,
            k=top_k,
            method=method,
            rerank=rerank
        )
        phase_times["retrieval"] = retrieval_time

        # 2. Context enhancement phase
        start_time = time.time()
        if context_enhancers:
            context_str, enhancement_meta = self.context_enhancer.enhance_context(
                documents=contexts,
                query=query,
                methods=context_enhancers
            )
        else:
            context_str = "\n\n".join([c.page_content for c in contexts])
            enhancement_meta = {"original_length": len(context_str), "final_length": len(context_str)}

        phase_times["enhancement"] = time.time() - start_time

        # 3. Generation phase
        answer, generation_time = self._generate_answer(
            query=query,
            context=context_str,
            model_type=model_type,
            prompt_strategy=prompt_strategy
        )
        phase_times["generation"] = generation_time

        # 4. Evaluation phase
        start_time = time.time()
        metrics = self.evaluator.evaluate(
            answer=answer,
            context=context_str,
            query=query
        )
        phase_times["evaluation"] = time.time() - start_time

        # 5. Compile results
        result = {
            "query": query,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "chunking_strategy": chunking_strategy,
            "retrieval_method": method,
            "reranking": "Yes" if rerank else "No",
            "context_enhancement": "+".join(context_enhancers) if context_enhancers else "None",
            "prompt_strategy": prompt_strategy,
            "model": model_type,
            "top_k": top_k,
            "context_length": enhancement_meta["final_length"],
            "answer_length": len(answer),
            "relevance": metrics["relevance"],
            "faithfulness": metrics["faithfulness"],
            "composite_score": metrics["composite_score"],
            "retrieval_time": phase_times["retrieval"],
            "enhancement_time": phase_times["enhancement"],
            "generation_time": phase_times["generation"],
            "evaluation_time": phase_times["evaluation"],
            "total_time": sum(phase_times.values()),
            "answer": answer[:1000] + "..." if len(answer) > 1000 else answer  # Store truncated answer
        }

        # Add additional metrics if available
        for key, value in metrics.items():
            if key not in result:
                result[key] = value

        # Calculate efficiency scores
        result["efficiency_score"] = result["composite_score"] / result["total_time"]

        print(f"Configuration completed in {result['total_time']:.3f}s with composite score: {result['composite_score']:.2f}")

        # Clean up after each configuration
        torch.cuda.empty_cache()
        gc.collect()

        return result

    def _generate_answer(self, query: str, context: str, model_type: str, prompt_strategy: str) -> Tuple[str, float]:
        """Generate answer with memory management"""
        model = self._get_generator_model(model_type)
        prompt = self._format_prompt(query, context, model_type, prompt_strategy)

        start_time = time.time()
        try:
            if model_type == "flan-t5":
                result = model(prompt, max_length=500)[0]['generated_text']
            else:
                result = model(
                    prompt,
                    max_new_tokens=500,
                    do_sample=True,
                    temperature=0.7
                )[0]['generated_text']

                # Post-process to extract the actual answer
                if model_type in ["qwen", "tiny-llama"]:
                    # Extract content between assistant tags
                    match = re.search(r'<\|im_start\|>assistant\n(.*?)(?:<\|im_end\|>|$)', result, re.DOTALL)
                    if match:
                        result = match.group(1).strip()
                elif model_type == "phi-2":
                    # Extract content after the prompt
                    result = result.split("Answer:")[-1].strip()
        except Exception as e:
            print(f"Error generating answer: {str(e)}")
            result = f"Error generating answer with {model_type}: {str(e)}"

        elapsed = time.time() - start_time
        return result, elapsed

    def _format_prompt(self, query: str, context: str, model_type: str, strategy: str) -> str:
        """Format prompt according to model requirements"""
        # Limit context length based on model type
        max_ctx_len = 4000  # Default
        if model_type == "flan-t5":
            max_ctx_len = 2000
        elif model_type in ["phi-2", "qwen", "tiny-llama"]:
            max_ctx_len = 3000

        if len(context) > max_ctx_len:
            context = context[:max_ctx_len] + "...(context truncated due to length)"

        # Model-specific formatting
        if model_type in ["qwen", "tiny-llama"]:
            if strategy == "cot":
                return f"""<|im_start|>system
You are an AI assistant that answers questions based on the provided context. Think step by step before providing your final answer.
Context: {context}<|im_end|>
<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant
Let me think through this step by step:
1. First, I'll understand what the question is asking.
2. Then, I'll search the provided context for relevant information.
3. Finally, I'll formulate a comprehensive answer based on the context.

"""
            elif strategy == "few_shot":
                return f"""<|im_start|>system
You are an AI assistant that answers questions based on the provided context.
Context: {context}<|im_end|>
<|im_start|>user
What services are available to students?<|im_end|>
<|im_start|>assistant
Based on the provided context, the services available to students include academic advising, counseling services, career services, library resources, and health services.<|im_end|>
<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant
"""
            else:  # standard
                return f"""<|im_start|>system
You are an AI assistant that answers questions based on the provided context.
Context: {context}<|im_end|>
<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant
"""
        elif model_type == "phi-2":
            if strategy == "cot":
                return f"""Context: {context}

Question: {query}

Let me think through this step by step to find the answer in the context.

Step 1: Understand what the question is asking.
Step 2: Look through the context for relevant information.
Step 3: Formulate my answer based on the context.

Answer:"""
            elif strategy == "few_shot":
                return f"""Context: {context}

Question: What services are available to students?
Answer: Based on the provided context, the services available to students include academic advising, counseling services, career services, library resources, and health services.

Question: {query}
Answer:"""
            else:  # standard
                return f"""Context: {context}

Question: {query}
Answer:"""
        else:  # flan-t5
            if strategy == "cot":
                return f"context: {context}\nquestion: {query}\nThink step by step to answer the question:"
            elif strategy == "few_shot":
                return f"context: {context}\nquestion: What services are available to students?\nanswer: Based on the provided context, the services available to students include academic advising, counseling services, career services, library resources, and health services.\nquestion: {query}\nanswer:"
            else:  # standard
                return f"context: {context}\nquestion: {query}\nanswer:"

    def _combine_batch_results(self) -> pd.DataFrame:
        """Combine all batch results from disk"""
        batch_files = [f for f in os.listdir() if f.startswith("rag_results_batch_")]
        if not batch_files:
            return pd.DataFrame()

        dfs = []
        for batch_file in batch_files:
            try:
                df = pd.read_csv(batch_file)
                dfs.append(df)
                os.remove(batch_file)  # Clean up after loading
            except Exception as e:
                print(f"Error loading batch file {batch_file}: {str(e)}")

        if dfs:
            return pd.concat(dfs, ignore_index=True)
        return pd.DataFrame()

    def _analyze_results(self, results_df: pd.DataFrame):
        """
        Analyze grid search results and generate visualizations

        Args:
            results_df: DataFrame with experiment results
        """
        print("\n" + "="*80)
        print("EXPERIMENT RESULTS ANALYSIS")
        print("="*80)

        # Calculate summary statistics
        print("\n--- Best Configurations by Composite Score ---")
        top_configs = results_df.sort_values('composite_score', ascending=False).head(5)
        for i, (_, row) in enumerate(top_configs.iterrows(), 1):
            print(f"\n{i}. Composite Score: {row['composite_score']:.2f}")
            print(f"   Query: {row['query']}")
            print(f"   Chunking: {row['chunking_strategy']} (size={row['chunk_size']}, overlap={row['chunk_overlap']})")
            print(f"   Retrieval: {row['retrieval_method']} (rerank={row['reranking']}, top_k={row['top_k']})")
            print(f"   Enhancement: {row['context_enhancement']}")
            print(f"   Generation: {row['model']} with {row['prompt_strategy']} prompting")
            print(f"   Times: Retrieval={row['retrieval_time']:.3f}s, Generation={row['generation_time']:.3f}s, Total={row['total_time']:.3f}s")

        print("\n--- Best Configurations by Efficiency Score ---")
        top_efficient = results_df.sort_values('efficiency_score', ascending=False).head(5)
        for i, (_, row) in enumerate(top_efficient.iterrows(), 1):
            print(f"\n{i}. Efficiency Score: {row['efficiency_score']:.2f}")
            print(f"   Query: {row['query']}")
            print(f"   Chunking: {row['chunking_strategy']} (size={row['chunk_size']}, overlap={row['chunk_overlap']})")
            print(f"   Retrieval: {row['retrieval_method']} (rerank={row['reranking']}, top_k={row['top_k']})")
            print(f"   Enhancement: {row['context_enhancement']}")
            print(f"   Generation: {row['model']} with {row['prompt_strategy']} prompting")
            print(f"   Times: Retrieval={row['retrieval_time']:.3f}s, Generation={row['generation_time']:.3f}s, Total={row['total_time']:.3f}s")

        # Analyze performance by component
        print("\n--- Performance by Model ---")
        model_perf = results_df.groupby('model')[['relevance', 'faithfulness', 'composite_score', 'total_time']].mean()
        print(model_perf)

        print("\n--- Performance by Retrieval Method ---")
        retrieval_perf = results_df.groupby('retrieval_method')[['relevance', 'faithfulness', 'composite_score', 'total_time']].mean()
        print(retrieval_perf)

        print("\n--- Performance by Chunking Strategy ---")
        chunking_perf = results_df.groupby('chunking_strategy')[['relevance', 'faithfulness', 'composite_score', 'total_time']].mean()
        print(chunking_perf)

        print("\n--- Performance by Context Enhancement ---")
        enhancement_perf = results_df.groupby('context_enhancement')[['relevance', 'faithfulness', 'composite_score', 'total_time']].mean()
        print(enhancement_perf)

        print("\n--- Performance by Prompt Strategy ---")
        prompt_perf = results_df.groupby('prompt_strategy')[['relevance', 'faithfulness', 'composite_score', 'total_time']].mean()
        print(prompt_perf)

        # Save results
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        results_path = f"rag_experiment_results_{timestamp}.csv"
        results_df.to_csv(results_path, index=False)
        print(f"\nResults saved to {results_path}")

        # Generate visualizations if matplotlib is available
        try:
            # Correlation heatmap
            numeric_df = results_df.select_dtypes(include=[np.number])
            plt.figure(figsize=(12, 10))
            sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', fmt=".2f", linewidths=0.5)
            plt.title('Correlation Between Metrics')
            plt.tight_layout()
            plt.savefig(f"rag_correlation_heatmap_{timestamp}.png")

            # Performance by model and retrieval method
            plt.figure(figsize=(14, 8))
            performance_pivot = results_df.pivot_table(
                index='model',
                columns='retrieval_method',
                values='composite_score',
                aggfunc='mean'
            )
            sns.heatmap(performance_pivot, annot=True, cmap='viridis', fmt=".2f", linewidths=0.5)
            plt.title('Average Composite Score by Model and Retrieval Method')
            plt.tight_layout()
            plt.savefig(f"rag_model_retrieval_performance_{timestamp}.png")

            print(f"Visualizations saved as PNG files")
        except Exception as e:
            print(f"Could not generate visualizations: {str(e)}")

#%%
### 8. Interactive demo mode
def run_interactive_demo(pipeline, retrieval_method="hybrid", model_type="qwen", top_k=5):
    """
    Run an interactive demo of the RAG system

    Reads questions from stdin until the user types 'exit' or 'quit'
    (case-insensitive, surrounding whitespace ignored) or stdin is closed or
    interrupted. Blank input is skipped rather than sent to the retriever.

    Args:
        pipeline: Configured RAG pipeline; its `retriever` must already be built
        retrieval_method: Method for retrieval
        model_type: Model for generation
        top_k: Number of documents to retrieve
    """
    print("\n" + "="*80)
    print("Interactive RAG System Demo")
    print("="*80)
    print("Type 'exit' to quit")

    while True:
        try:
            query = input("\nEnter your question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin closed or Ctrl-C: end the demo instead of dumping a traceback
            print()
            break

        if query.lower() in ('exit', 'quit'):
            break
        if not query:
            continue  # nothing to search for

        # Retrieval phase
        contexts, retrieval_time = pipeline.retriever.search(query, k=top_k, method=retrieval_method)
        context_str = "\n\n".join([c.page_content for c in contexts])

        # Generation phase
        answer, generation_time = pipeline._generate_answer(query, context_str, model_type, "standard")

        # Print results
        print(f"\nAnswer (generated in {generation_time:.2f}s):")
        print("-" * 80)
        print(answer)
        print("-" * 80)
        print(f"Total response time: {retrieval_time + generation_time:.2f}s")

#%%
### 9. Execution & Analysis
if __name__ == "__main__":
    # PDF file paths - replace with your own PDF paths
    PDF_FILES = [
        "/content/13. Atlas of Diabetes Mellitus (3rd Edition).pdf",
        "/content/diabetes.pdf"  # Replace with your second PDF path
    ]
    # Full experimental configuration with comprehensive parameter search
    COMPREHENSIVE_CONFIG = {
        # Test queries to evaluate system performance
        "test_queries": [
           # "What is the policy for academic probation?",
            #"What are the requirements for graduation?",
            #"How can students access mental health services?",
            #"What are the rules regarding academic dishonesty?",
            #"How can I apply for a leave of absence?"
           " what are the parking rules and regulation such that I dont get a fine ?"
        ],

        # Chunking parameters
        "chunk_sizes": [256, 512, 1024],
        "chunk_overlaps": [32, 64, 128],
        "chunking_strategies": ["recursive", "sentence", "paragraph"],

        # Retrieval parameters
        "retrieval_methods": ["dense", "sparse", "hybrid","tfidf"],
        #"retrieval_methods": ["dense", "sparse", "hybrid", "hnsw", "tfidf"],
        "reranking_options": [True, False],
        "top_k_values": [3, 5, 10],

        # Context enhancement parameters
        "context_enhancers": [
            [],  # No enhancement
            ["summarize"],
            ["reorder"],
            ["highlight"]
            #["deduplicate"],
            #["reorder", "deduplicate"],
            #["summarize", "highlight"]
        ],

        # Generation parameters
        #"prompt_strategies": ["standard", "cot", "few_shot"],
        "prompt_strategies": ["standard"],
       # "models": ["qwen", "flan-t5", "tiny-llama", "phi-2"]
        "models": ["qwen", "flan-t5"]
    }

    # Simplified configuration for faster testing
    FAST_CONFIG = {
        "test_queries": [
            "What is the policy for academic probation?",
            "What are the requirements for graduation?"
        ],
        "chunk_sizes": [512],
        "chunk_overlaps": [64],
        "chunking_strategies": ["recursive", "sentence"],
        "retrieval_methods": ["dense", "hybrid"],
        "reranking_options": [False],
        "top_k_values": [5],
        "context_enhancers": [[], ["summarize"]],
        "prompt_strategies": ["standard"],
        "models": ["qwen", "flan-t5"]
    }

    # Choose which configuration to use
    # EXPERIMENT_CONFIG = COMPREHENSIVE_CONFIG  # Uncomment for full experiment
    EXPERIMENT_CONFIG = FAST_CONFIG           # Use for fast testing


    # Initialize pipeline with multiple PDFs
    print(f"\nInitializing RAG pipeline with PDFs: {PDF_FILES}")
    pipeline = AdvancedRagPipeline(PDF_FILES)

    # Run full experiment
    print("\nStarting RAG experiment with configuration:")
    for key, value in EXPERIMENT_CONFIG.items():
        print(f"- {key}: {value if not isinstance(value, list) else len(value)} options")

    results_df = pipeline.experiment(EXPERIMENT_CONFIG)

    if results_df.empty:
        # An empty frame has no 'composite_score' column, so picking the best
        # row would raise; there is nothing to demo in that case.
        print("\nExperiment produced no results; skipping interactive demo.")
    else:
        # Run interactive demo with best configurations
        best_config = results_df.loc[results_df['composite_score'].idxmax()]
        print(f"\nStarting interactive demo with best configuration:")
        print(f"- Retrieval: {best_config['retrieval_method']}")
        print(f"- Model: {best_config['model']}")
        print(f"- Top K: {best_config['top_k']}")

        # experiment() leaves the retriever unset, so rebuild it for the demo
        pipeline.prepare_for_demo()
        run_interactive_demo(
            pipeline,
            retrieval_method=best_config['retrieval_method'],
            model_type=best_config['model'],
            top_k=int(best_config['top_k'])
        )


Overwriting app.py


In [None]:
# Persist the selected best configuration to best_rag_config.json.
import json

# `best_config` may be a pandas Series (e.g. a row taken from a results DataFrame)
# or an already-plain dict; normalise both instead of assuming `.to_dict()` exists.
best_config_dict = best_config.to_dict() if hasattr(best_config, "to_dict") else dict(best_config)

# Serialise BEFORE opening the file: open(..., "w") truncates immediately, so a
# failure while serialising would otherwise leave an empty best_rag_config.json
# that cannot be parsed when it is loaded back.
# `default` unwraps numpy scalars (they expose .item()) and falls back to str()
# for any other value the json module cannot encode.
best_config_json = json.dumps(
    best_config_dict,
    indent=2,
    default=lambda value: value.item() if hasattr(value, "item") else str(value),
)

with open("best_rag_config.json", "w") as f:
    f.write(best_config_json)

AttributeError: 'dict' object has no attribute 'to_dict'

In [None]:

# Persist the retriever's `index_ip` FAISS index to disk as rag_faiss_index.bin.
dense_index = pipeline.retriever.index_ip
faiss.write_index(dense_index, "rag_faiss_index.bin")

In [None]:
# Streamlit front end: rebuild the retriever from the saved best configuration,
# attach the saved FAISS index, and answer questions typed into the UI.
import json

# `st` is used throughout this cell; import it here so the cell/script is self-contained.
import streamlit as st

# Load the best configuration
try:
    with open("best_rag_config.json", "r") as f:
        best_config = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
    # A missing or empty/corrupt file means the configuration was never saved
    # successfully; report it the same way as a missing FAISS index below.
    st.error("Best configuration not found or unreadable. Please run the experiment first to generate best_rag_config.json.")
    st.stop()

# Initialize pipeline with two PDF files
PDF_FILES = [
    "/content/student-handbook2022-23.pdf",
    "/content/parking-regulations.pdf"  # Replace with your second PDF path
]
pipeline = AdvancedRagPipeline(PDF_FILES)  # Pass list of PDF paths
data_processor = DataProcessor()

# Load and chunk both PDFs.
# Sizes are cast to int (as top_k already is below) because values that went
# through a JSON round trip may come back as floats.
chunks = data_processor.load_and_chunk(
    pipeline.file_paths,
    chunking_strategy=best_config['chunking_strategy'],
    chunk_size=int(best_config['chunk_size']),
    chunk_overlap=int(best_config['chunk_overlap'])
)

# Create retriever with combined chunks
pipeline.retriever = AdvancedRetriever(chunks)

# Load the saved FAISS index (ensure it was created with both PDFs)
try:
    pipeline.retriever.index_ip = faiss.read_index("rag_faiss_index.bin")
except FileNotFoundError:
    st.error("FAISS index not found. Please run the experiment first to generate the index.")
    st.stop()

# Streamlit UI
st.title("RAG System Demo")
query = st.text_input("Ask a question about the student handbook or parking regulations:")
if st.button("Submit"):
    if not query or not query.strip():
        # Nothing to retrieve for a blank question; ask for input instead of
        # sending an empty string through retrieval and generation.
        st.warning("Please enter a question before submitting.")
    else:
        contexts, _ = pipeline.retriever.search(
            query=query,
            k=int(best_config['top_k']),
            method=best_config['retrieval_method'],
            rerank=best_config['reranking'] == "Yes"
        )
        context_str = "\n\n".join([c.page_content for c in contexts])

        # "None" (string) marks "no enhancement"; also skip a JSON null / empty
        # value, which would otherwise fail on `.split`.
        enhancement = best_config['context_enhancement']
        if enhancement and enhancement != "None":
            context_str, _ = pipeline.context_enhancer.enhance_context(
                contexts, query, enhancement.split("+")
            )
        answer, _ = pipeline._generate_answer(
            query=query,
            context=context_str,
            model_type=best_config['model'],
            prompt_strategy=best_config['prompt_strategy']
        )
        st.write("**Answer:**")
        st.write(answer)

JSONDecodeError: Expecting value: line 1 column 1 (char 0)

In [None]:
# Download the ngrok v3 Linux/amd64 release archive into the current directory.
!wget https://bin.equinox.io/c/bNyj1mQVY4c/ngrok-v3-stable-linux-amd64.tgz
# Unpack the archive (verbose: lists the extracted `ngrok` binary).
!tar -xvf ngrok-v3-stable-linux-amd64.tgz
# Move the binary into /usr/local/bin so later `!ngrok ...` commands can find it.
!mv ngrok /usr/local/bin/

--2025-04-11 22:34:09--  https://bin.equinox.io/c/bNyj1mQVY4c/ngrok-v3-stable-linux-amd64.tgz
Resolving bin.equinox.io (bin.equinox.io)... 99.83.220.108, 75.2.60.68, 35.71.179.82, ...
Connecting to bin.equinox.io (bin.equinox.io)|99.83.220.108|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9395172 (9.0M) [application/octet-stream]
Saving to: ‘ngrok-v3-stable-linux-amd64.tgz’


2025-04-11 22:34:13 (13.3 MB/s) - ‘ngrok-v3-stable-linux-amd64.tgz’ saved [9395172/9395172]

ngrok


In [None]:
# Install ngrok and pyngrok
!wget https://bin.equinox.io/c/bNyj1mQVY4c/ngrok-v3-stable-linux-amd64.tgz -q
!tar -xvf ngrok-v3-stable-linux-amd64.tgz -q
!mv ngrok /usr/local/bin/
!pip install pyngrok -q

# Authenticate ngrok
!ngrok authtoken 2vbPj6xVgq76zpNz5SyudZB8Ymg_5zpSqYBXLt1Ekj83wZSGx

# Run Streamlit and expose it
from pyngrok import ngrok
import subprocess

ngrok.kill()  # Clear previous tunnels
process = subprocess.Popen(['streamlit', 'run', '/content/app.py', '--server.port', '8501'])
public_url = ngrok.connect(8501)
print(f"Access your Streamlit app at: {public_url}")

import time
time.sleep(3600)  # Keep running for 1 hour

tar: invalid option -- 'q'
Try 'tar --help' or 'tar --usage' for more information.
mv: cannot stat 'ngrok': No such file or directory
Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml
Access your Streamlit app at: NgrokTunnel: "https://05ee-34-83-151-40.ngrok-free.app" -> "http://localhost:8501"


In [None]:
!ngrok authtoken 2vbPj6xVgq76zpNz5SyudZB8Ymg_5zpSqYBXLt1Ekj83wZSGx

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [None]:
import subprocess
!pip install pyngrok
from pyngrok import ngrok

# Start Streamlit in the background
process = subprocess.Popen(['streamlit', 'run', 'app.py', '--server.port', '8501'])

# Create a public URL with ngrok
public_url = ngrok.connect(8501)
print(f"Access your Streamlit app at: {public_url}")

Access your Streamlit app at: NgrokTunnel: "https://3de3-34-83-151-40.ngrok-free.app" -> "http://localhost:8501"
