In [None]:
import json
import re
from typing import Any, Dict, List, Optional, Tuple

import fitz
import numpy as np
from langchain_ollama import ChatOllama, OllamaEmbeddings
from numpy.linalg import norm
from pydantic import BaseModel

In [None]:
class Document(BaseModel):
    """A piece of text together with free-form metadata describing it.

    Attributes:
        page_content: The text of the document.
        metadata: Arbitrary string-keyed values attached to the text;
            defaults to an empty dict.
    """

    # Text that is embedded for retrieval and joined into the LLM context.
    page_content: str
    # Optional key/value annotations; values are not constrained to any type.
    metadata: Dict[str, Any] = {}

In [None]:
class RetrievalResult(BaseModel):
    """Outcome of a retrieval call: the selected documents and their scores.

    Attributes:
        documents: The retrieved documents.
        scores: One similarity score per entry in ``documents``, in the same
            order.
    """

    # Retrieved documents; SelfRAG.retrieve_documents fills this best-first.
    documents: List[Document]
    # scores[i] is the similarity score that belongs to documents[i].
    scores: List[float]

In [None]:
class GenerationOutput(BaseModel):
    """Final result of a Self-RAG run.

    Attributes:
        response: The answer text produced by the LLM.
        reflection: The self-critique text; SelfRAG.__call__ leaves this as
            an empty string when no reflection was generated.
        relevant_documents: The documents that were retrieved and supplied to
            the LLM as context.
    """

    # Generated answer.
    response: str
    # Self-critique of the answer ("" when the reflection step was skipped).
    reflection: str
    # Context documents the answer was generated from.
    relevant_documents: List[Document]

In [None]:
class SelfRAG:
    """Self-RAG (Retrieval-Augmented Generation with Self-Reflection).

    The pipeline, run by calling the instance with a query, is:

    1. embed the query and rank the stored documents by cosine similarity,
    2. ask the LLM to answer from the top-ranked documents,
    3. ask the LLM how confident it is in that answer, and
    4. when the confidence is below ``reflection_threshold``, ask the LLM for
       a critical reflection on the answer.

    Document embeddings are computed lazily on first retrieval and cached in
    ``_embedding_cache`` (keyed by page text), so repeated queries only embed
    the query itself. The cache is not invalidated automatically: clear it by
    hand if ``self.embeddings`` is replaced with a different model.
    """

    # Finds numeric tokens (optionally followed by a percent sign) inside a
    # free-form LLM reply, e.g. "0.8", ".75", "1", "85 %".
    _CONFIDENCE_TOKEN = re.compile(
        r"(?P<number>[-+]?(?:\d+\.\d*|\.\d+|\d+))\s*(?P<percent>%)?"
    )
    # Confidence assumed when the LLM reply contains no usable number.
    _DEFAULT_CONFIDENCE = 0.5

    def __init__(
        self,
        llm_model: str = "llama3.2:3b",
        embedding_model: str = "llama3.2:3b",
        retrieval_top_k: int = 3,
        reflection_threshold: float = 0.7,
    ):
        """
        Initialize the Self-RAG system.

        Args:
            llm_model (str): Name of the Ollama LLM model to use
            embedding_model (str): Name of the Ollama embedding model to use
            retrieval_top_k (int): Maximum number of documents to retrieve per query
            reflection_threshold (float): A reflection is generated when the
                answer's confidence score is strictly below this value
        """
        self.llm = ChatOllama(model=llm_model, temperature=0.3)
        self.embeddings = OllamaEmbeddings(model=embedding_model)
        self.retrieval_top_k = retrieval_top_k
        self.reflection_threshold = reflection_threshold
        self.document_store: List[Document] = []
        # page text -> embedding; avoids re-embedding the whole store per query.
        self._embedding_cache: Dict[str, List[float]] = {}

    def add_documents(self, documents: List[Document]) -> None:
        """Append documents to the document store (no de-duplication is done)."""
        self.document_store.extend(documents)
        print(
            f"Added {len(documents)} documents to the store. Total documents: {len(self.document_store)}"
        )

    def load_pdf_documents(
        self, file_path: str, metadata: Optional[Dict] = None
    ) -> List[Document]:
        """
        Load and extract text from a PDF file using Fitz (PyMuPDF).

        One Document is produced per page whose text is not blank. Each
        document's metadata holds ``page_number`` (1-based), ``source`` and
        ``total_pages``, plus everything in ``metadata``. Keys supplied in
        ``metadata`` take precedence over those generated entries.

        The documents are only returned; call ``add_documents`` to store them.

        Args:
            file_path (str): Path to the PDF file
            metadata (Optional[Dict]): Optional metadata to attach to all documents from this file

        Returns:
            List of Document objects extracted from the PDF

        Raises:
            Exception: Any error raised while opening or reading the file is
                printed and then re-raised unchanged.
        """
        if metadata is None:
            metadata = {}

        documents = []
        try:
            # The context manager closes the PDF even if extraction fails.
            with fitz.open(file_path) as pdf_file:
                # Metadata shared by every page of this file.
                file_metadata = {
                    "source": file_path,
                    "total_pages": len(pdf_file),
                    **metadata,
                }

                for page_num, page in enumerate(pdf_file):
                    text = page.get_text()
                    if text.strip():  # Skip pages without any text
                        page_metadata = {"page_number": page_num + 1, **file_metadata}
                        documents.append(
                            Document(page_content=text, metadata=page_metadata)
                        )

            print(f"Loaded {len(documents)} pages from PDF: {file_path}")
            return documents

        except Exception as e:
            print(f"Error loading PDF file {file_path}: {e}")
            raise

    def embed_text(self, text: str) -> List[float]:
        """Return the embedding of ``text``; errors are printed and re-raised."""
        try:
            return self.embeddings.embed_query(text)
        except Exception as e:
            print(f"Error generating embeddings: {e}")
            raise

    def _embed_document(self, document: Document) -> List[float]:
        """Return the embedding of a stored document, computing it at most once."""
        text = document.page_content
        cached = self._embedding_cache.get(text)
        if cached is None:
            cached = self.embed_text(text)
            self._embedding_cache[text] = cached
        return cached

    def cosine_similarity(self, vec_a: List[float], vec_b: List[float]) -> float:
        """
        Calculate cosine similarity between two vectors.

        Returns 0.0 when either vector has zero length. Dividing by the zero
        norm would give NaN, and NaN sorts last in ``np.argsort``, which would
        wrongly rank such a document as the best match.
        """
        a = np.asarray(vec_a, dtype=float)
        b = np.asarray(vec_b, dtype=float)
        denominator = norm(a) * norm(b)
        if denominator == 0:
            return 0.0
        return float(np.dot(a, b) / denominator)

    def retrieve_documents(self, query: str) -> RetrievalResult:
        """
        Retrieve relevant documents based on the query.

        Args:
            query (str): The input query to retrieve documents for

        Returns:
            RetrievalResult with at most ``retrieval_top_k`` documents and
            their cosine-similarity scores, ordered from most to least
            similar. The result is empty when the store is empty or
            ``retrieval_top_k`` is not positive.

        Raises:
            Exception: Any error raised while embedding or ranking is printed
                and then re-raised unchanged.
        """
        if not self.document_store:
            print("Document store is empty. No documents to retrieve.")
            return RetrievalResult(documents=[], scores=[])

        # Guard the slice below: with k == 0, ``[-0:]`` would select every
        # document, and a negative k would slice from the wrong end.
        top_k = min(self.retrieval_top_k, len(self.document_store))
        if top_k <= 0:
            return RetrievalResult(documents=[], scores=[])

        try:
            query_embedding = self.embed_text(query)

            # Document embeddings come from the cache after the first query.
            similarities = [
                self.cosine_similarity(query_embedding, self._embed_document(doc))
                for doc in self.document_store
            ]

            # argsort is ascending, so reverse it and keep the first k entries.
            top_indices = [int(i) for i in np.argsort(similarities)[::-1][:top_k]]
            top_documents = [self.document_store[i] for i in top_indices]
            top_scores = [similarities[i] for i in top_indices]

            return RetrievalResult(documents=top_documents, scores=top_scores)

        except Exception as e:
            print(f"Error during document retrieval: {e}")
            raise

    @classmethod
    def _parse_confidence(cls, raw_reply: Any) -> float:
        """
        Turn the LLM's confidence reply into a score in [0, 1].

        A reply that is exactly a number is used directly. Otherwise the
        first number containing a decimal point is used (falling back to the
        first number of any kind), and a trailing ``%`` divides it by 100.
        Replies without a finite number yield ``_DEFAULT_CONFIDENCE``. The
        result is clamped to the range [0, 1].
        """
        reply = str(raw_reply).strip()
        try:
            value = float(reply)
        except ValueError:
            matches = list(cls._CONFIDENCE_TOKEN.finditer(reply))
            if not matches:
                print(
                    f"Could not parse confidence score, defaulting to {cls._DEFAULT_CONFIDENCE}"
                )
                return cls._DEFAULT_CONFIDENCE
            # Prefer "0.9" over the "0" and "1" of an echoed "scale from 0 to 1".
            chosen = next(
                (m for m in matches if "." in m.group("number")), matches[0]
            )
            value = float(chosen.group("number"))
            if chosen.group("percent"):
                value /= 100.0

        # float() accepts "nan" and "inf"; neither is a usable score.
        if not np.isfinite(value):
            print(
                f"Could not parse confidence score, defaulting to {cls._DEFAULT_CONFIDENCE}"
            )
            return cls._DEFAULT_CONFIDENCE

        return min(1.0, max(0.0, value))

    def generate_response(
        self, query: str, context_documents: List[Document]
    ) -> Tuple[str, float]:
        """
        Generate a response based on the query and context documents.

        Two LLM calls are made: one to answer the question from the context,
        and one to rate that answer. The rating prompt includes the question
        and the context so the model can judge the answer against them.

        Args:
            query (str): The input query
            context_documents (List[Document]): Relevant documents to use as context

        Returns:
            Tuple of (generated_response, confidence_score), where the score
            is in [0, 1] and is 0.5 when the model's rating cannot be parsed

        Raises:
            Exception: Any error raised by the LLM calls is printed and then
                re-raised unchanged.
        """
        try:
            # Page texts separated by blank lines form the context section.
            context = "\n\n".join([doc.page_content for doc in context_documents])

            prompt = f"""You are a helpful AI assistant. Use the following context to answer the question.
            If you don't know the answer, say you don't know. Be truthful and accurate.
            
            Context:
            {context}
            
            Question: {query}
            
            Answer:"""

            response = self.llm.invoke(prompt)
            response_text = response.content

            # Self-assessed confidence (simple implementation - could be enhanced)
            confidence_prompt = f"""On a scale from 0 to 1, how confident are you that the following answer to the question is correct and supported by the context?
            
            Context:
            {context}
            
            Question: {query}
            
            Answer: {response_text}
            
            Provide only the numerical confidence score between 0 and 1:"""

            confidence_response = self.llm.invoke(confidence_prompt)
            confidence = self._parse_confidence(confidence_response.content)

            return response_text, confidence

        except Exception as e:
            print(f"Error during response generation: {e}")
            raise

    def generate_reflection(
        self, query: str, response: str, documents: List[Document]
    ) -> str:
        """
        Generate a self-reflection on the quality of the response.

        This step is best-effort: unlike the other pipeline steps it does not
        re-raise. On any error the message is printed and a fixed placeholder
        string is returned.

        Args:
            query (str): The original query
            response (str): The generated response
            documents (List[Document]): Documents used for the response

        Returns:
            Reflection text analyzing the response quality, or
            "Error generating reflection." if the LLM call failed
        """
        try:
            context = "\n\n".join([doc.page_content for doc in documents])

            reflection_prompt = f"""Analyze the following QA interaction and provide a critical reflection:
            
            Question: {query}
            
            Context used:
            {context}
            
            Answer provided:
            {response}
            
            Reflection (consider accuracy, completeness, relevance to context, and potential improvements):"""

            reflection_response = self.llm.invoke(reflection_prompt)
            return reflection_response.content

        except Exception as e:
            print(f"Error during reflection generation: {e}")
            return "Error generating reflection."

    def __call__(self, query: str) -> GenerationOutput:
        """
        Execute the full Self-RAG pipeline: retrieve, generate, reflect.

        Args:
            query (str): The input query

        Returns:
            GenerationOutput containing response, reflection, and relevant
            documents. ``reflection`` is an empty string when the confidence
            reached ``reflection_threshold``.

        Raises:
            Exception: Errors from retrieval or generation are printed and
                then re-raised unchanged.
        """
        try:
            # Step 1: Retrieve relevant documents
            retrieval_result = self.retrieve_documents(query)
            print(
                f"Retrieved {len(retrieval_result.documents)} documents for query: {query}"
            )

            # Step 2: Generate initial response
            response, confidence = self.generate_response(
                query, retrieval_result.documents
            )
            print(f"Generated response with confidence: {confidence:.2f}")

            # Step 3: Generate reflection if confidence is below threshold
            reflection = ""
            if confidence < self.reflection_threshold:
                reflection = self.generate_reflection(
                    query, response, retrieval_result.documents
                )
                print("Generated self-reflection due to low confidence")

            return GenerationOutput(
                response=response,
                reflection=reflection,
                relevant_documents=retrieval_result.documents,
            )

        except Exception as e:
            print(f"Error in Self-RAG pipeline: {e}")
            raise

In [None]:
# Build the pipeline. The same Ollama model name is used for chat and for
# embeddings; both values match the constructor defaults.
rag = SelfRAG(llm_model="llama3.2:3b", embedding_model="llama3.2:3b")

In [None]:
# Extract one Document per non-empty page of the PDF (path is relative to the
# notebook's working directory) and tag every page with a document_type.
pdf_docs = rag.load_pdf_documents(
    file_path="./dataset/health supplements/1. dietary supplements - for whom.pdf",
    metadata={"document_type": "research_paper"},
)

In [None]:
# Put the loaded pages into the pipeline's document store. The store is only
# ever extended, so re-running this cell adds the same pages a second time.
rag.add_documents(pdf_docs)

In [None]:
# Question to answer from the loaded documents.
query = "How to lose weight?"

In [None]:
# Run the full pipeline (retrieve, generate, reflect when confidence is below
# the threshold); the return value is a GenerationOutput.
result = rag(query)

In [None]:
# Answer produced by the pipeline.
print("\n=== Response ===", result.response, sep="\n")

# The reflection is an empty string when none was generated, so the section
# is omitted in that case.
if result.reflection:
    print("\n=== Reflection ===", result.reflection, sep="\n")

# One line per retrieved document: its first 100 characters and source file.
print("\n=== Relevant Documents ===")
for doc in result.relevant_documents:
    snippet = doc.page_content[:100]
    source = doc.metadata.get("source", "unknown")
    print(f"- {snippet}... (Source: {source})")