In [None]:
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union

import fitz
import numpy as np
from langchain_ollama import ChatOllama, OllamaEmbeddings
from numpy.linalg import norm

In [None]:
@dataclass
class Document:
    """
    A class to represent a document with text content and metadata.

    Attributes:
        page_content: The text of this document (for PDFs, one chunk of a page).
        metadata: Key/value information about the text. PDFReader fills in
            ``source`` (file path), ``page`` (1-based int) and ``chunk_size``
            (int). Defaults to a fresh empty dict for every instance.
    """

    page_content: str
    # Values are not all strings: PDFReader stores ints for "page" and "chunk_size".
    metadata: Dict[str, Union[str, int]] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Tolerate an explicit ``metadata=None`` so callers can always index
        # into ``metadata`` without a None check.
        if self.metadata is None:
            self.metadata = {}

In [None]:
class PDFReader:
    """
    A class to read and process PDF documents using PyMuPDF (fitz).

    Attributes:
        chunk_size (int): Number of characters per text chunk
        chunk_overlap (int): Number of overlapping characters between chunks
    """

    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
        """
        Initialize the PDF reader with chunking parameters.

        Args:
            chunk_size (int): Size of each text chunk in characters (default: 1000)
            chunk_overlap (int): Overlap between chunks in characters (default: 200)

        Raises:
            ValueError: If chunk_size is not positive, or chunk_overlap is not in
                the range [0, chunk_size). An overlap >= chunk_size would stop the
                chunker from advancing, and a negative overlap would skip text.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= chunk_overlap < chunk_size:
            raise ValueError(
                "chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size, "
                f"got chunk_overlap={chunk_overlap}, chunk_size={chunk_size}"
            )
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def read_pdf(self, file_path: str) -> List["Document"]:
        """
        Read a PDF file and return chunks of text as Documents.

        Pages whose extracted text is empty or whitespace-only (e.g. image-only
        pages) produce no chunks.

        Args:
            file_path (str): Path to the PDF file

        Returns:
            List of Document objects containing text chunks and metadata
            (``source``, 1-based ``page`` and ``chunk_size``).

        Raises:
            FileNotFoundError: If the PDF file doesn't exist
            Exception: For other PDF reading errors
        """
        try:
            documents = []
            # The context manager closes the PDF even if extraction fails part-way.
            with fitz.open(file_path) as doc:
                for page_index, page in enumerate(doc):
                    text = page.get_text()
                    if not text.strip():
                        # Whitespace-only chunks would embed as noise.
                        continue

                    for chunk in self._chunk_text(text):
                        metadata = {
                            "source": file_path,
                            "page": page_index + 1,  # 1-based, as shown in viewers
                            "chunk_size": len(chunk),
                        }
                        documents.append(
                            Document(page_content=chunk, metadata=metadata)
                        )

            print(f"Read {len(documents)} chunks from {file_path}")
            return documents

        except FileNotFoundError:
            print(f"PDF file not found: {file_path}")
            raise
        except Exception as e:
            print(f"Error reading PDF file {file_path}: {e}")
            raise

    def _chunk_text(self, text: str) -> List[str]:
        """
        Split text into chunks of specified size with overlap.

        Consecutive chunks start ``chunk_size - chunk_overlap`` characters apart.
        The final chunk ends at the end of the text and may be shorter.

        Args:
            text (str): Input text to chunk

        Returns:
            List of text chunks (empty for empty text)

        Raises:
            ValueError: If the chunking attributes were changed after construction
                so that the window no longer advances.
        """
        step = self.chunk_size - self.chunk_overlap
        if step <= 0:
            raise ValueError("chunk_overlap must be smaller than chunk_size")

        chunks = []
        for start in range(0, len(text), step):
            end = start + self.chunk_size
            chunks.append(text[start:end])
            if end >= len(text):
                # This window reached the end; another would only repeat its tail.
                break
        return chunks

In [None]:
class VectorStore:
    """
    A simple in-memory vector store implementation for demonstration purposes.

    Attributes:
        documents (List[Document]): List of stored documents
        embeddings (List[np.ndarray]): Corresponding embeddings for documents
    """

    def __init__(self):
        """Initialize an empty vector store."""
        self.documents: List["Document"] = []
        self.embeddings: List[np.ndarray] = []

    def add_documents(
        self, documents: List["Document"], embeddings: List[np.ndarray]
    ) -> None:
        """
        Add documents and their embeddings to the vector store.

        Args:
            documents (List[Document]): List of Document objects to add
            embeddings (List[np.ndarray]): List of corresponding embeddings as numpy arrays

        Raises:
            ValueError: If lengths of documents and embeddings don't match
        """
        if len(documents) != len(embeddings):
            raise ValueError("Number of documents must match number of embeddings")

        self.documents.extend(documents)
        self.embeddings.extend(embeddings)
        print(f"Added {len(documents)} documents to the vector store")

    def similarity_search(
        self, query_embedding: np.ndarray, k: int = 3
    ) -> List[Tuple["Document", float]]:
        """
        Perform similarity search using cosine similarity.

        Args:
            query_embedding (np.ndarray): Embedding of the query as numpy array
            k (int): Number of top results to return (default: 3); k <= 0
                returns an empty list.

        Returns:
            List of up to k (Document, similarity_score) tuples, highest score
            first. Ties keep insertion order. A zero-length query or document
            vector scores 0.0.

        Raises:
            ValueError: If the vector store is empty
        """
        if not self.documents:
            raise ValueError("Vector store is empty")

        results: List[Tuple["Document", float]] = []
        if k > 0:
            # Score every stored embedding in one vectorised pass.
            matrix = np.vstack(self.embeddings).astype(float)
            query = np.asarray(query_embedding, dtype=float).ravel()
            dots = matrix @ query
            denominators = norm(matrix, axis=1) * norm(query)

            # A zero-norm vector has no direction. Scoring it 0 avoids the 0/0 NaN
            # that argsort would otherwise rank above every real score.
            similarities = np.zeros(len(self.documents), dtype=float)
            np.divide(dots, denominators, out=similarities, where=denominators > 0)

            # Stable sort on negated scores: best first, ties in insertion order.
            top_indices = np.argsort(-similarities, kind="stable")[:k]
            results = [
                (self.documents[i], float(similarities[i])) for i in top_indices
            ]

        print(f"Retrieved {len(results)} documents from vector store")
        return results

In [None]:
class QueryTransformer:
    """
    Base class for query transformation techniques.

    Subclasses override ``transform``. QueryRewriter and StepBackPrompting
    return a single query string, while SubQueryDecomposer returns a list of
    sub-queries. The declared return type therefore covers both.
    """

    def __init__(self, llm_client: ChatOllama):
        """
        Initialize the query transformer with an LLM client.

        Args:
            llm_client (ChatOllama): Initialized ChatOllama client. Subclasses
                call its ``invoke`` method with a prompt string.
        """
        self.llm = llm_client

    def transform(self, query: str) -> Union[str, List[str]]:
        """
        Transform the input query (to be implemented by subclasses).

        Args:
            query: Original user query

        Returns:
            A transformed query string, or a list of query strings.

        Raises:
            NotImplementedError: Always, when called on the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement transform()"
        )

In [None]:
class QueryRewriter(QueryTransformer):
    """Implements query rewriting technique to make queries more specific."""

    def transform(self, query: str) -> str:
        """
        Rewrite the query to be more specific and detailed.

        Args:
            query (str): Original user query

        Returns:
            Rewritten, more specific query. Falls back to the original query if
            the LLM call fails or returns only whitespace.
        """
        prompt = f"""
        You are a helpful query rewriter. Your task is to make the following query more specific 
        and detailed while preserving the original intent. Add relevant details that would help 
        retrieve more precise information.
        
        Original query: {query}
        
        Rewritten query:
        """

        try:
            response = self.llm.invoke(prompt)
            rewritten_query = response.content.strip()
        except Exception as e:
            print(f"Error in query rewriting: {e}")
            return query  # Fallback to original query

        if not rewritten_query:
            # An empty rewrite would be embedded as a meaningless query vector.
            print("Query rewriting returned no text; using original query")
            return query

        print(f"Rewritten query: {rewritten_query}")
        return rewritten_query

In [None]:
class StepBackPrompting(QueryTransformer):
    """Implements step-back prompting to generate broader contextual queries."""

    def transform(self, query: str) -> str:
        """
        Generate a broader query to retrieve contextual background information.

        Args:
            query (str): Original user query

        Returns:
            Broader, more conceptual query for context retrieval. Falls back to
            the original query if the LLM call fails or returns only whitespace.
        """
        prompt = f"""
        You are a helpful assistant that generates step-back questions. For the given query, 
        create a broader question that asks about the higher-level concepts or principles 
        needed to answer the original query.
        
        Original query: {query}
        
        Step-back question:
        """

        try:
            response = self.llm.invoke(prompt)
            step_back_query = response.content.strip()
        except Exception as e:
            print(f"Error in step-back prompting: {e}")
            return query  # Fallback to original query

        if not step_back_query:
            # An empty step-back question would be embedded as a meaningless vector.
            print("Step-back prompting returned no text; using original query")
            return query

        print(f"Step-back query: {step_back_query}")
        return step_back_query

In [None]:
class SubQueryDecomposer(QueryTransformer):
    """Implements sub-query decomposition to break complex queries into simpler parts."""

    # Leading list marker such as "1.", "2)", "-", "*" or "•" followed by
    # whitespace. Requiring the whitespace avoids eating numbers like "3.5".
    _LIST_MARKER = re.compile(r"^\s*(?:\d+[.)]|[-*•])\s+")

    def transform(self, query: str) -> List[str]:
        """
        Break down a complex query into simpler sub-queries.

        Args:
            query (str): Original complex user query

        Returns:
            List of simpler sub-queries that cover aspects of the original query,
            with list markers removed. Falls back to ``[query]`` if the LLM call
            fails or yields no usable lines.
        """
        prompt = f"""
        You are a helpful query decomposer. Break down the following complex query into 
        2-3 simpler sub-queries that cover different aspects of the original query.
        Return each sub-query on a new line.
        
        Original query: {query}
        
        Generate 3 sub-queries, one per line, in this format:
        1. [First sub-query]
        2. [Second sub-query]
        3. [Third sub-query]
        """

        try:
            response = self.llm.invoke(prompt)
            sub_queries = self._parse_sub_queries(response.content)
        except Exception as e:
            print(f"Error in query decomposition: {e}")
            return [query]  # Fallback to original query as single sub-query

        if not sub_queries:
            # An empty list would make retrieval search nothing at all.
            print("Query decomposition returned no sub-queries; using original query")
            return [query]

        print(f"Generated sub-queries: {sub_queries}")
        return sub_queries

    @classmethod
    def _parse_sub_queries(cls, text: str) -> List[str]:
        """
        Extract sub-queries from the model's reply.

        If any non-empty line starts with a list marker, only those lines are
        kept, with the marker stripped. This drops preambles such as
        "Here are three sub-queries:". Otherwise every non-empty line is used.
        """
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        listed = [
            cls._LIST_MARKER.sub("", line, count=1).strip()
            for line in lines
            if cls._LIST_MARKER.match(line)
        ]
        candidates = listed if listed else lines
        return [candidate for candidate in candidates if candidate]

In [None]:
class RAGSystem:
    """
    A RAG (Retrieval-Augmented Generation) system with query transformation capabilities.

    Attributes:
        llm: ChatOllama instance for generation
        embeddings: OllamaEmbeddings instance for creating embeddings
        vector_store: VectorStore for document retrieval
        pdf_reader: PDFReader used to load and chunk PDF files
        query_transformers: Dictionary of available query transformation techniques
    """

    def __init__(self, llm_model: str = "llama3.2:3b", embedding_model: str = "llama3"):
        """
        Initialize the RAG system with LLM and embedding models.

        Args:
            llm_model: Name of the Ollama LLM model to use
            embedding_model: Name of the Ollama embedding model to use

        Raises:
            Exception: Any error raised while constructing the clients is
                printed and re-raised.
        """
        try:
            self.llm = ChatOllama(model=llm_model, temperature=0.3)
            self.embeddings = OllamaEmbeddings(model=embedding_model)
            self.vector_store = VectorStore()
            self.pdf_reader = PDFReader()

            # Initialize query transformers
            self.query_transformers = {
                "rewriting": QueryRewriter(self.llm),
                "step_back": StepBackPrompting(self.llm),
                "decomposition": SubQueryDecomposer(self.llm),
            }

            print("RAG system initialized successfully")
        except Exception as e:
            print(f"Failed to initialize RAG system: {e}")
            raise

    def add_pdf_documents(self, file_paths: List[str]) -> None:
        """
        Add PDF documents to the vector store after reading and chunking.

        Files that fail to read are reported and skipped.

        Args:
            file_paths (List[str]): List of paths to PDF files

        Raises:
            ValueError: If no file paths are given, or none could be processed
        """
        if not file_paths:
            raise ValueError("No PDF files provided")

        all_documents = []
        for file_path in file_paths:
            try:
                documents = self.pdf_reader.read_pdf(file_path)
                all_documents.extend(documents)
            except Exception as e:
                print(f"Skipping {file_path} due to error: {e}")
                continue

        if not all_documents:
            raise ValueError("No valid PDF documents could be processed")

        self.add_documents(all_documents)

    def add_documents(self, documents: List[Document]) -> None:
        """
        Add documents to the vector store after generating their embeddings.

        Args:
            documents (List[Document]): List of Document objects to add

        Raises:
            Exception: Embedding or storage errors are printed and re-raised.
        """
        try:
            # Generate embeddings for each document
            texts = [doc.page_content for doc in documents]
            doc_embeddings = self.embeddings.embed_documents(texts)
            doc_embeddings = [np.array(embedding) for embedding in doc_embeddings]

            # Add to vector store
            self.vector_store.add_documents(documents, doc_embeddings)
        except Exception as e:
            print(f"Error adding documents: {e}")
            raise

    def retrieve(
        self, query: str, transformation: Optional[str] = None, k: int = 3
    ) -> List[Tuple[Document, float]]:
        """
        Retrieve relevant documents with optional query transformation.

        Each (transformed) query is searched independently. The results are
        then merged so that each chunk appears once, with the best score any
        query gave it.

        Args:
            query (str): The query to search with
            transformation (str): Optional query transformation technique to apply
                           ('rewriting', 'step_back', 'decomposition', or None).
                           An unrecognised name is reported and the original
                           query is used.
            k (int): Number of documents to retrieve

        Returns:
            Up to k (Document, similarity_score) tuples, best first. Any
            exception during transformation, embedding or search is printed and
            an empty list is returned.
        """
        try:
            search_queries = self._transform_query(query, transformation)

            # Retrieve documents for each transformed query
            all_results: List[Tuple[Document, float]] = []
            for search_query in search_queries:
                query_embedding = np.array(self.embeddings.embed_query(search_query))
                all_results.extend(
                    self.vector_store.similarity_search(query_embedding, k=k)
                )

            return self._merge_results(all_results)[:k]

        except Exception as e:
            print(f"Error in retrieval: {e}")
            return []

    def _transform_query(self, query: str, transformation: Optional[str]) -> List[str]:
        """Return the list of queries to search with after applying `transformation`."""
        if not transformation:
            return [query]

        transformer = self.query_transformers.get(transformation)
        if transformer is None:
            available = ", ".join(self.query_transformers)
            print(
                f"Unknown transformation '{transformation}' "
                f"(available: {available}); using original query"
            )
            return [query]

        transformed = transformer.transform(query)
        # Decomposition yields a list of sub-queries; the other techniques yield one string.
        candidates = transformed if isinstance(transformed, list) else [transformed]
        # Never search with nothing: fall back to the original query.
        return [q for q in candidates if q and q.strip()] or [query]

    @staticmethod
    def _merge_results(
        results: List[Tuple[Document, float]],
    ) -> List[Tuple[Document, float]]:
        """Deduplicate results by chunk text, keeping each chunk's best score, sorted best-first."""
        best: Dict[str, Tuple[Document, float]] = {}
        for doc, score in results:
            key = doc.page_content
            if key not in best or score > best[key][1]:
                best[key] = (doc, score)
        return sorted(best.values(), key=lambda pair: pair[1], reverse=True)

    def generate_response(self, query: str, retrieved_docs: List[Document]) -> str:
        """
        Generate a response using the LLM based on the query and retrieved documents.

        Args:
            query (str): Original user query
            retrieved_docs (List[Document]): List of relevant documents retrieved

        Returns:
            Generated response as a string, or a fixed apology message if the
            LLM call fails.
        """
        try:
            # Format the retrieved documents as context
            context = "\n\n".join([doc.page_content for doc in retrieved_docs])

            prompt = f"""
            You are a helpful assistant that answers questions based on the provided context.
            Use the following context to answer the question at the end. If you don't know
            the answer, just say you don't know - don't make up an answer.
            
            Context:
            {context}
            
            Question: {query}
            
            Answer:
            """

            response = self.llm.invoke(prompt)
            return response.content.strip()

        except Exception as e:
            print(f"Error in response generation: {e}")
            return (
                "I encountered an error while generating a response. Please try again."
            )

In [None]:
# Use the same local Ollama model for both chat generation and embeddings.
rag = RAGSystem(embedding_model="llama3.2:3b", llm_model="llama3.2:3b")

In [None]:
# PDF files to index (paths are relative to the current working directory).
pdf_paths = ["./dataset/health supplements/1. dietary supplements - for whom.pdf"]

In [None]:
# Read, chunk and embed every PDF. Unreadable files are skipped with a message;
# a ValueError is raised only if none of them could be processed.
rag.add_pdf_documents(pdf_paths)

In [None]:
# Sample question used by the demonstration cells below.
query = "What is the main topic discussed in section 3?"

In [None]:
# Heading, the query itself, then a blank separator line.
print("=== Original Query ===", query, sep="\n", end="\n\n")

In [None]:
print("=== Query Rewriting ===")
rewriter = rag.query_transformers["rewriting"]
rewritten_query = rewriter.transform(query)
print(rewritten_query, end="\n\n")  # trailing blank line separates the demos

In [None]:
print("=== Step-back Prompting ===")
step_back_transformer = rag.query_transformers["step_back"]
step_back_query = step_back_transformer.transform(query)
print(step_back_query, end="\n\n")  # trailing blank line separates the demos

In [None]:
print("=== Sub-query Decomposition ===")
decomposer = rag.query_transformers["decomposition"]
sub_queries = decomposer.transform(query)
# Print each sub-query as a numbered line, starting from 1.
for i, q in enumerate(sub_queries, start=1):
    print(i, q, sep=". ")

In [None]:
print("=== RAG with Query Rewriting ===")
# The transformation must match the heading printed above.
retrieved_docs = rag.retrieve(query, transformation="rewriting")
# generate_response expects bare Documents, so drop the similarity scores.
response = rag.generate_response(query, [doc for doc, _ in retrieved_docs])
print("Response:", response)