In [None]:
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import fitz
import numpy as np
from langchain_ollama import ChatOllama, OllamaEmbeddings
from numpy.linalg import norm

In [None]:
class RAG:
    """
    A Retrieval-Augmented Generation (RAG) pipeline with document augmentation.

    The pipeline extracts text from a PDF, splits it into overlapping character
    chunks, augments each chunk with LLM-generated questions and answers, and
    embeds both the original and the augmented chunks. Queries are answered by
    ranking chunks by cosine similarity to the query embedding and passing the
    best ones to the chat model as context. Retrieval can run over either the
    original or the augmented chunks so the two can be compared.
    """

    def __init__(
        self,
        model_name: str = "llama3.2:3b",
        embedding_model_name: Optional[str] = None,
    ):
        """
        Initialize the RAG pipeline with Ollama models.

        Args:
            model_name (str): Name of the Ollama chat model used for augmentation
                and answer generation (default: "llama3.2:3b").
            embedding_model_name (Optional[str]): Name of the Ollama model used to
                compute embeddings. Defaults to ``model_name``; pass a different
                model here to embed with a model other than the chat model.
        """
        self.model_name = model_name
        self.embedding_model_name = embedding_model_name or model_name
        self.embedding_model = OllamaEmbeddings(model=self.embedding_model_name)
        self.llm = ChatOllama(model=model_name)
        # Parallel lists, filled by process_document():
        # chunks[i] <-> embeddings[i] and augmented_chunks[i] <-> augmented_embeddings[i].
        # A failed embedding is stored as [] and skipped during retrieval.
        self.chunks: List[str] = []
        self.augmented_chunks: List[str] = []
        self.embeddings: List[List[float]] = []
        self.augmented_embeddings: List[List[float]] = []

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """
        Extract text content from a PDF file.

        Args:
            pdf_path (str): Path to the PDF file

        Returns:
            The text of all pages, joined with newlines.

        Raises:
            FileNotFoundError: If the PDF file doesn't exist
            ValueError: If the PDF cannot be read or contains no extractable text
        """
        # Checked outside the try so the documented FileNotFoundError is not
        # swallowed and re-wrapped as a ValueError.
        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found at {pdf_path}")

        try:
            # The context manager closes the document (and its file handle).
            with fitz.open(pdf_path) as doc:
                text = "\n".join(page.get_text() for page in doc)
        except Exception as e:
            raise ValueError(f"Error extracting text from PDF: {str(e)}") from e

        if not text.strip():
            raise ValueError("No text could be extracted from the PDF")

        return text

    def chunk_text(
        self, text: str, chunk_size: int = 1000, overlap: int = 200
    ) -> List[str]:
        """
        Split text into fixed-size character chunks with overlap.

        All whitespace runs are collapsed to single spaces first. Each chunk
        starts ``chunk_size - overlap`` characters after the previous one. The
        last chunk may be shorter than ``chunk_size``.

        Args:
            text (str): Input text to chunk
            chunk_size (int): Maximum size of each chunk in characters (default: 1000)
            overlap (int): Number of overlapping characters between chunks (default: 200)

        Returns:
            List of text chunks (empty if the text is empty or only whitespace)

        Raises:
            ValueError: If chunk_size is not positive, or overlap is negative or
                not smaller than chunk_size (the window would never advance).
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= overlap < chunk_size:
            raise ValueError(
                f"overlap must satisfy 0 <= overlap < chunk_size, "
                f"got overlap={overlap}, chunk_size={chunk_size}"
            )

        # Clean the text by collapsing excessive whitespace (e.g. PDF line breaks)
        text = re.sub(r"\s+", " ", text).strip()

        chunks: List[str] = []
        step = chunk_size - overlap
        for start in range(0, len(text), step):
            end = start + chunk_size
            chunks.append(text[start:end])
            # Stop once a chunk reaches the end of the text, so no trailing
            # chunk consists solely of overlap.
            if end >= len(text):
                break

        return chunks

    def augment_chunk(self, chunk: str) -> str:
        """
        Augment a text chunk with LLM-generated questions and answers about it.

        Args:
            chunk (str): Text chunk to augment

        Returns:
            ``"<chunk>\\n\\n---AUGMENTED---\\n<Q&A>"``. The original chunk is
            returned unchanged if the LLM call fails or returns empty output.
        """
        # Ask for the Q&A only. The original chunk is prepended below, so having
        # the model echo it back would duplicate (and possibly paraphrase) the
        # source text inside the augmented chunk.
        prompt = (
            "Based on the following text, generate 2-3 relevant questions and their answers.\n"
            "Do not repeat the text itself. Format the output exactly as:\n"
            "1. Q: {question}\n"
            "   A: {answer}\n"
            "2. Q: {question}\n"
            "   A: {answer}\n"
            "\n"
            f"Text: {chunk}"
        )

        try:
            augmented_text = str(self.llm.invoke(prompt).content or "").strip()
        except Exception as e:
            print(f"Error augmenting chunk: {str(e)}")
            return chunk  # Return original if augmentation fails

        if not augmented_text:
            # Nothing useful generated; avoid a dangling augmentation marker.
            return chunk

        # Combine original and augmented content
        return f"{chunk}\n\n---AUGMENTED---\n{augmented_text}"

    def generate_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for a given text using Ollama.

        Args:
            text (str): Input text to embed

        Returns:
            Embedding vector as a list of floats, or ``[]`` if embedding fails
            (the error is printed).
        """
        try:
            return self.embedding_model.embed_query(text)
        except Exception as e:
            print(f"Error generating embedding: {str(e)}")
            return []

    def cosine_similarity(self, vec_a: List[float], vec_b: List[float]) -> float:
        """
        Calculate the cosine similarity between two vectors.

        Args:
            vec_a (List[float]): First vector
            vec_b (List[float]): Second vector

        Returns:
            Cosine similarity in [-1, 1]. Returns 0.0 if either vector is empty,
            the lengths differ, or either vector has zero norm.
        """
        if not vec_a or not vec_b or len(vec_a) != len(vec_b):
            return 0.0

        denominator = norm(vec_a) * norm(vec_b)
        # A zero vector has no direction; avoid a nan from 0/0.
        if denominator == 0:
            return 0.0

        return float(np.dot(vec_a, vec_b) / denominator)

    def process_document(self, pdf_path: str) -> None:
        """
        Process a PDF document: extract, chunk, augment, and generate embeddings.

        Replaces any previously stored chunks and embeddings on this instance.

        Args:
            pdf_path (str): Path to the PDF file

        Raises:
            FileNotFoundError, ValueError: Propagated from extract_text_from_pdf.
        """
        print("Processing document...")

        # Step 1: Extract text from PDF
        text = self.extract_text_from_pdf(pdf_path)
        print(f"Extracted {len(text)} characters from PDF")

        # Step 2: Chunk the text
        self.chunks = self.chunk_text(text)
        print(f"Created {len(self.chunks)} text chunks")

        # Step 3: Augment chunks (one LLM call per chunk)
        self.augmented_chunks = [self.augment_chunk(chunk) for chunk in self.chunks]
        print(f"Augmented {len(self.augmented_chunks)} chunks")

        # Step 4: Generate embeddings for original chunks
        self.embeddings = [self.generate_embedding(chunk) for chunk in self.chunks]

        # Step 5: Generate embeddings for augmented chunks
        self.augmented_embeddings = [
            self.generate_embedding(chunk) for chunk in self.augmented_chunks
        ]

        print(f"Generated {len(self.embeddings)} original embeddings")
        print(f"Generated {len(self.augmented_embeddings)} augmented embeddings")

        # Failed embeddings are stored as [] and silently skipped at retrieval
        # time, so surface them here.
        failed = sum(
            1 for emb in self.embeddings + self.augmented_embeddings if not emb
        )
        if failed:
            print(f"Warning: {failed} embeddings failed and will be skipped during retrieval")

    def retrieve_relevant_chunks(
        self, query: str, use_augmented: bool = False, top_k: int = 3
    ) -> List[Tuple[str, float]]:
        """
        Retrieve the chunks most similar to the query.

        Args:
            query (str): Search query
            use_augmented (bool): Whether to search the augmented chunks (default: False)
            top_k (int): Maximum number of chunks to return; values <= 0 return []

        Returns:
            Up to ``top_k`` (chunk_text, similarity_score) tuples, highest score first.
            Returns [] if the query cannot be embedded.
        """
        if top_k <= 0:
            return []

        # Generate embedding for the query
        query_embedding = self.generate_embedding(query)
        if not query_embedding:
            return []

        # Choose which embeddings and chunks to use
        embeddings = self.augmented_embeddings if use_augmented else self.embeddings
        chunks = self.augmented_chunks if use_augmented else self.chunks

        # Score every chunk with a valid (non-empty) embedding
        similarities = [
            (chunk, self.cosine_similarity(query_embedding, emb))
            for chunk, emb in zip(chunks, embeddings)
            if emb
        ]

        # Sort by similarity (descending) and return top_k
        similarities.sort(key=lambda pair: pair[1], reverse=True)
        return similarities[:top_k]

    def generate_response(
        self, query: str, use_augmented: bool = False, top_k: int = 3
    ) -> str:
        """
        Generate an answer to a query using the RAG pipeline.

        Args:
            query (str): User query
            use_augmented (bool): Whether to retrieve from augmented chunks (default: False)
            top_k (int): Number of chunks to include as context (default: 3)

        Returns:
            The LLM's answer. If no chunks are retrieved, a fixed fallback message
            is returned; if the LLM call fails, an error message string is returned.
        """
        # Retrieve relevant chunks
        relevant_chunks = self.retrieve_relevant_chunks(query, use_augmented, top_k)

        if not relevant_chunks:
            return "I couldn't find any relevant information to answer your question."

        # Prepare context for the LLM
        context = "\n\n".join(chunk for chunk, _ in relevant_chunks)

        # Generate prompt with context
        prompt = f"""Use the following context to answer the question at the end.
        If you don't know the answer, just say you don't know, don't try to make up an answer.
        
        Context:
        {context}
        
        Question: {query}
        
        Answer:"""

        try:
            response = self.llm.invoke(prompt)
            return response.content
        except Exception as e:
            return f"Error generating response: {str(e)}"

In [None]:
# Sample document for this walkthrough, resolved relative to the working directory.
pdf_dir = "./dataset/health supplements"
pdf_path = f"{pdf_dir}/1. dietary supplements - for whom.pdf"

In [None]:
# Build the pipeline; the same Ollama model name is passed as the first argument.
rag = RAG("llama3.2:3b")

In [None]:
# Extract, chunk, augment and embed the sample PDF.
rag.process_document(pdf_path=pdf_path)

In [None]:
# Summarize how many chunks and embeddings the pipeline currently holds.
print("Vector Database Statistics:")
for label, collection in (
    ("original chunks", rag.chunks),
    ("augmented chunks", rag.augmented_chunks),
    ("original embeddings", rag.embeddings),
    ("augmented embeddings", rag.augmented_embeddings),
):
    print(f"- Total {label}: {len(collection)}")

In [None]:
# Query used to compare retrieval over original vs. augmented chunks.
# (Spelling fixed from "dieatary": a misspelled keyword can skew the query embedding.)
query = (
    "Provide information about the functionality and safety of dietary supplements."
)

In [None]:
# Echo the query (single-quoted) before showing results.
print("\nRunning query:", f"'{query}'")

In [None]:
# Retrieve against the ORIGINAL chunk embeddings and preview each hit.
print("\nResults with ORIGINAL data:")
original_chunks = rag.retrieve_relevant_chunks(query, use_augmented=False)
print("\nMost relevant original chunks:")
for rank, (text, similarity) in enumerate(original_chunks, start=1):
    preview = text if len(text) <= 200 else text[:200] + "..."
    print(f"\nChunk {rank} (Similarity: {similarity:.2f}):")
    print(preview)

In [None]:
# Answer the query using context retrieved from the original chunks.
original_response = rag.generate_response(query, use_augmented=False)
print("\nGenerated response (original):", original_response, sep="\n")

In [None]:
# Retrieve against the AUGMENTED chunk embeddings and preview each hit.
print("\nResults with AUGMENTED data:")
augmented_chunks = rag.retrieve_relevant_chunks(query, use_augmented=True)
print("\nMost relevant augmented chunks:")
for position, (passage, sim_score) in enumerate(augmented_chunks, start=1):
    truncated = passage if len(passage) <= 200 else passage[:200] + "..."
    print(f"\nChunk {position} (Similarity: {sim_score:.2f}):")
    print(truncated)

In [None]:
# Answer the query using context retrieved from the augmented chunks.
augmented_response = rag.generate_response(query, use_augmented=True)
print("\nGenerated response (augmented):", augmented_response, sep="\n")