In [None]:
from pathlib import Path
from typing import Any, Dict, List, Optional

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnablePassthrough
from langchain_experimental.text_splitter import SemanticChunker
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

In [None]:
class RAG:
    """
    Retrieval-augmented generation (RAG) pipeline with contextual compression.

    The intended call order is:
    ``extract_text_from_pdf`` -> ``chunk_documents`` -> ``create_embeddings``
    -> ``setup_contextual_compression`` -> ``retrieve_documents`` /
    ``generate_response``. Each step raises ``ValueError`` if a component it
    depends on has not been initialized yet.

    Attributes:
        model_name (str): Name of the Ollama model to use (for both the chat
            model and the embeddings model)
        embeddings_model (Optional[Embeddings]): Embeddings model instance
        llm (Optional[ChatOllama]): LLM instance for generation
        text_splitter (Optional[RecursiveCharacterTextSplitter]): Text splitter for chunking
        vectorstore (Optional[FAISS]): Vector store for document embeddings
        retriever (Optional[BaseRetriever]): Document retriever
        compression_retriever (Optional[ContextualCompressionRetriever]): Compressed retriever
    """

    def __init__(self, model_name: str = "llama3.2:3b"):
        """
        Initialize the RAG pipeline.

        Args:
            model_name (str): Name of the Ollama model to use (default: "llama3.2:3b")

        Raises:
            Exception: Re-raised from ``_setup_environment`` if a component
                cannot be constructed.
        """
        self.model_name = model_name
        self.embeddings_model: Optional[Embeddings] = None
        self.llm: Optional[ChatOllama] = None
        self.text_splitter: Optional[RecursiveCharacterTextSplitter] = None
        # The three attributes below stay None until create_embeddings() /
        # setup_contextual_compression() are called.
        self.vectorstore: Optional[FAISS] = None
        self.retriever: Optional[BaseRetriever] = None
        self.compression_retriever: Optional[ContextualCompressionRetriever] = None
        self._setup_environment()

    def _setup_environment(self) -> None:
        """
        Set up the required models and components.

        Creates the Ollama embeddings model, the Ollama chat model and the
        recursive character text splitter (1000-character chunks with a
        200-character overlap).

        Raises:
            Exception: Any error raised while constructing the components is
                printed and re-raised unchanged.
        """
        try:
            # Initialize Ollama embeddings (same model name as the chat model)
            self.embeddings_model = OllamaEmbeddings(model=self.model_name)

            # Initialize Ollama LLM
            self.llm = ChatOllama(model=self.model_name)

            # Initialize text splitter
            self.text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
                length_function=len,
                is_separator_regex=False,
            )

            print("Environment setup completed successfully.")

        except Exception as e:
            print(f"Error setting up environment: {e}")
            raise

    def extract_text_from_pdf(self, file_path: str) -> List[Document]:
        """
        Extract text from a PDF file.

        Args:
            file_path (str): Path to the PDF file

        Returns:
            List of extracted documents as returned by the PDF loader

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            Exception: Any loader error is printed and re-raised unchanged.
        """
        try:
            # Fail early with a clear message instead of a loader-specific error.
            if not Path(file_path).exists():
                raise FileNotFoundError(f"File not found: {file_path}")

            loader = PyPDFLoader(file_path)
            documents = loader.load()
            print(f"Extracted {len(documents)} pages from PDF.")
            return documents

        except Exception as e:
            print(f"Error extracting text from PDF: {e}")
            raise

    def chunk_documents(self, documents: List[Document]) -> List[Document]:
        """
        Split documents into chunks.

        Args:
            documents (List[Document]): List of documents to chunk

        Returns:
            List of chunked documents

        Raises:
            ValueError: If the text splitter has not been initialized.
            Exception: Any splitter error is printed and re-raised unchanged.
        """
        if self.text_splitter is None:
            raise ValueError("Text splitter not initialized.")

        try:
            chunks = self.text_splitter.split_documents(documents)
            print(f"Split documents into {len(chunks)} chunks.")
            return chunks

        except Exception as e:
            print(f"Error chunking documents: {e}")
            raise

    def create_embeddings(self, chunks: List[Document]) -> None:
        """
        Create embeddings for document chunks and store them in a vector store.

        On success, sets ``self.vectorstore`` and ``self.retriever`` (the
        retriever returns the top 4 matches per query).

        Args:
            chunks (List[Document]): List of document chunks to embed

        Raises:
            ValueError: If the embeddings model has not been initialized, or
                if ``chunks`` is empty.
            Exception: Any embedding/indexing error is printed and re-raised
                unchanged.
        """
        if self.embeddings_model is None:
            raise ValueError("Embeddings model not initialized.")

        # An index cannot be built from zero documents; reject the input
        # explicitly rather than surfacing an obscure error from the vector store.
        if not chunks:
            raise ValueError("No document chunks provided to embed.")

        try:
            self.vectorstore = FAISS.from_documents(chunks, self.embeddings_model)
            self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 4})
            print("Created embeddings and initialized retriever.")

        except Exception as e:
            print(f"Error creating embeddings: {e}")
            raise

    def setup_contextual_compression(self) -> None:
        """
        Set up contextual compression for the retriever.

        Wraps the base retriever in a ``ContextualCompressionRetriever`` whose
        compressor is an ``LLMChainExtractor`` built from ``self.llm``.

        Raises:
            ValueError: If the LLM or the base retriever is not initialized.
            Exception: Any construction error is printed and re-raised
                unchanged.
        """
        if self.llm is None or self.retriever is None:
            raise ValueError("LLM or retriever not initialized.")

        try:
            # Create a document compressor
            compressor = LLMChainExtractor.from_llm(self.llm)

            # Create the compression retriever
            self.compression_retriever = ContextualCompressionRetriever(
                base_compressor=compressor, base_retriever=self.retriever
            )
            print("Contextual compression retriever setup completed.")

        except Exception as e:
            print(f"Error setting up contextual compression: {e}")
            raise

    def retrieve_documents(self, query: str, compressed: bool = True) -> List[Document]:
        """
        Retrieve relevant documents for a query.

        Args:
            query (str): The query to search for
            compressed (bool): Whether to use compressed retrieval (default: True)

        Returns:
            List of relevant documents

        Raises:
            ValueError: If the retriever selected by ``compressed`` has not
                been initialized.
            Exception: Any retrieval error is printed and re-raised unchanged.
        """
        try:
            if compressed:
                if self.compression_retriever is None:
                    raise ValueError("Compression retriever not initialized.")
                return self.compression_retriever.invoke(query)

            if self.retriever is None:
                raise ValueError("Retriever not initialized.")
            return self.retriever.invoke(query)
        except Exception as e:
            print(f"Error retrieving documents: {e}")
            raise

    def generate_response(self, query: str, compressed: bool = True) -> str:
        """
        Generate a response to a query using RAG.

        Retrieves documents for ``query``, joins their contents into a single
        context string and asks the LLM to answer from that context.

        Args:
            query (str): The query to respond to
            compressed (bool): Whether to use compressed retrieval (default: True)

        Returns:
            The generated response

        Raises:
            ValueError: If the LLM (or the selected retriever) is not
                initialized.
            Exception: Any retrieval/generation error is printed and re-raised
                unchanged.
        """
        if self.llm is None:
            raise ValueError("LLM not initialized.")

        try:
            # Retrieve relevant documents
            retrieved_docs = self.retrieve_documents(query, compressed)

            # Format the documents as context
            context = "\n\n".join(doc.page_content for doc in retrieved_docs)

            # Create a prompt template
            prompt = ChatPromptTemplate.from_template(
                """Answer the following question based on the provided context.
                Be concise and accurate. If you don't know the answer, say you don't know.
                
                Context: {context}
                
                Question: {question}
                
                Answer:"""
            )

            # Create the RAG chain. The context was already retrieved above,
            # so it is injected as a constant; the chain input (the query)
            # is passed through unchanged as the question.
            rag_chain = (
                {"context": lambda x: context, "question": RunnablePassthrough()}
                | prompt
                | self.llm
                | StrOutputParser()
            )

            # Generate the response
            response = rag_chain.invoke(query)
            return response
        except Exception as e:
            print(f"Error generating response: {e}")
            raise

In [None]:
# Location of the PDF to load, relative to the current working directory.
pdf_path = "./dataset/health supplements/1. dietary supplements - for whom.pdf"

In [None]:
# Build the pipeline; the constructor sets up the Ollama embeddings model,
# the Ollama chat model and the text splitter.
rag = RAG(model_name="llama3.2:3b")

In [None]:
# Load the PDF into a list of Document objects (raises FileNotFoundError if the path is missing).
documents = rag.extract_text_from_pdf(pdf_path)

In [None]:
# Split the loaded documents into chunks using the pipeline's text splitter.
chunks = rag.chunk_documents(documents)

In [None]:
# Embed the chunks into a FAISS vector store and create the base retriever (k=4).
rag.create_embeddings(chunks)

In [None]:
# Wrap the base retriever with an LLM-based compressor; must run after create_embeddings().
rag.setup_contextual_compression()

In [None]:
# Example question used for both retrieval comparisons and the generated answer.
query = "What is the main topic discussed in section 3?"

In [None]:
# Baseline: query the plain vector-store retriever, without compression.
regular_docs = rag.retrieve_documents(query, compressed=False)

In [None]:
# Same query through the contextual-compression retriever.
compressed_docs = rag.retrieve_documents(query, compressed=True)

In [None]:
# Compare how many documents each retrieval mode returned.
print(f"Regular retrieval returned {len(regular_docs)} documents")
print(f"Compressed retrieval returned {len(compressed_docs)} documents\n")

In [None]:
# Generate an answer; `compressed` defaults to True, so the compression retriever supplies the context.
response = rag.generate_response(query)

In [None]:
# Show the question together with the generated answer.
print(f"Response to query '{query}':\n{response}")