In [None]:
import math
import os
import re
from statistics import mean, stdev
from typing import List, Dict, Tuple, Optional, Callable

import fitz
import numpy as np
from langchain_ollama import ChatOllama, OllamaEmbeddings
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
class SemanticChunking:
    """
    Retrieval-augmented question answering over a PDF using semantic chunking.

    Pipeline (driven by ``process_document``):
      1. extract the PDF text,
      2. split it into sentences,
      3. embed every sentence,
      4. measure the cosine dissimilarity (``1 - cosine similarity``) between
         each pair of consecutive sentences,
      5. cut the text after gaps whose dissimilarity is unusually large,
      6. embed the resulting chunks.

    ``query`` then retrieves the chunks most similar to a question and passes
    them to an Ollama chat model as context.
    """

    # Sentence boundary: whitespace that follows '.', '?' or '!', unless the
    # preceding text looks like an abbreviation such as "e.g." or "Mr.".
    _SENTENCE_BOUNDARY = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')

    # Breakpoint strategies accepted by find_breakpoints().
    _BREAKPOINT_METHODS = ("percentile", "standard_deviation", "interquartile")

    def __init__(self, ollama_model: str = "llama3.2:3b", embedding_model_name: Optional[str] = None):
        """
        Initialize the SemanticChunking with Ollama chat and embedding models.

        Args:
            ollama_model (str): Name of the Ollama model used for chat, and also
                for embeddings unless ``embedding_model_name`` is given.
            embedding_model_name (Optional[str]): Name of a separate Ollama model
                to use for embeddings. Defaults to ``ollama_model``.
        """
        self.ollama_model = ollama_model
        embed_name = embedding_model_name if embedding_model_name is not None else ollama_model
        self.embedding_model = OllamaEmbeddings(model=embed_name)
        self.chat_model = ChatOllama(model=ollama_model)
        # Pipeline state; populated by process_document().
        self.sentences: List[str] = []
        self.sentence_embeddings: Optional[np.ndarray] = None
        self.similarity_diffs: List[float] = []
        self.chunks: List[str] = []
        self.chunk_embeddings: Optional[np.ndarray] = None

    @staticmethod
    def _l2_normalize(vectors: np.ndarray) -> np.ndarray:
        """
        Scale each row of ``vectors`` to unit L2 norm.

        A 1-D input is treated as a single row. All-zero rows are left as
        zeros, so their cosine similarity with anything is 0. This matches
        scikit-learn's ``cosine_similarity``.
        """
        rows = np.atleast_2d(np.asarray(vectors, dtype=float))
        norms = np.linalg.norm(rows, axis=1, keepdims=True)
        norms[norms == 0.0] = 1.0  # avoid division by zero for all-zero rows
        return rows / norms

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """
        Extract text from a PDF file using PyMuPDF.

        Args:
            pdf_path (str): Path to the PDF file.

        Returns:
            str: Concatenated text of all pages, in page order.

        Raises:
            FileNotFoundError: If the PDF file doesn't exist.
            Exception: For other extraction errors (original error chained as __cause__).
        """
        # Check explicitly: PyMuPDF's own "file not found" error may not derive
        # from the builtin FileNotFoundError in every version.
        if not os.path.isfile(pdf_path):
            raise FileNotFoundError(f"The PDF file at {pdf_path} was not found.")
        try:
            with fitz.open(pdf_path) as doc:
                return "".join(page.get_text() for page in doc)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"The PDF file at {pdf_path} was not found.") from e
        except Exception as e:
            raise Exception(f"Error extracting text from PDF: {str(e)}") from e

    def split_into_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences using a simple regex pattern.

        Args:
            text (str): The input text to split.

        Returns:
            List[str]: Whitespace-stripped, non-empty sentences. Terminal
            punctuation is kept on each sentence.
        """
        # Simple heuristic splitter; a dedicated sentence tokenizer would handle more cases.
        pieces = self._SENTENCE_BOUNDARY.split(text)
        return [piece.strip() for piece in pieces if piece.strip()]

    def generate_sentence_embeddings(self, sentences: List[str]) -> np.ndarray:
        """
        Generate embeddings for each sentence using Ollama embeddings.

        Args:
            sentences (List[str]): List of sentences to embed.

        Returns:
            np.ndarray: Array of sentence embeddings (n_sentences x embedding_dim).
        """
        return np.array(self.embedding_model.embed_documents(sentences))

    def calculate_similarity_differences(self) -> List[float]:
        """
        Compute ``1 - cosine_similarity`` between each pair of consecutive sentences.

        Returns:
            List[float]: ``len(sentences) - 1`` dissimilarities. Entry ``i`` is
            the gap between sentence ``i`` and sentence ``i + 1``. The list is
            empty when there are fewer than two sentences.

        Raises:
            ValueError: If sentence embeddings haven't been generated yet.
        """
        if self.sentence_embeddings is None:
            raise ValueError("Sentence embeddings must be generated first.")
        if len(self.sentence_embeddings) < 2:
            return []

        unit = self._l2_normalize(self.sentence_embeddings)
        # Row-wise dot product of each unit vector with its successor = cosine similarity.
        similarities = np.sum(unit[:-1] * unit[1:], axis=1)
        return (1.0 - similarities).tolist()

    def find_breakpoints(self, method: str = "percentile", **kwargs) -> List[int]:
        """
        Find breakpoints between chunks using different methods.

        Args:
            method (str): Method to use for finding breakpoints. Options:
                - "percentile": threshold at the ``percentile`` kwarg (default 95)
                - "standard_deviation": threshold at mean + ``n_std`` (default 1.5) * std dev
                - "interquartile": threshold at Q3 + 1.5 * IQR
            **kwargs: Additional method-specific parameters.

        Returns:
            List[int]: Gap indices ``i`` whose dissimilarity exceeds the
            threshold. A new chunk starts at sentence ``i + 1``. A one-sentence
            document has no gaps and yields ``[]``.

        Raises:
            ValueError: If an invalid method is specified or similarity diffs not calculated.
        """
        if method not in self._BREAKPOINT_METHODS:
            raise ValueError(f"Invalid method: {method}. Choose from 'percentile', 'standard_deviation', or 'interquartile'.")

        diffs = self.similarity_diffs
        # len() rather than truthiness so a NumPy array of diffs also works.
        if len(diffs) == 0:
            if len(self.sentences) == 1:
                return []
            raise ValueError("Similarity differences must be calculated first.")

        if method == "percentile":
            threshold = np.percentile(diffs, kwargs.get('percentile', 95))
        elif method == "standard_deviation":
            n_std = kwargs.get('n_std', 1.5)
            std_diff = stdev(diffs) if len(diffs) > 1 else 0
            threshold = mean(diffs) + n_std * std_diff
        else:  # "interquartile"
            q1, q3 = np.percentile(diffs, [25, 75])
            threshold = q3 + 1.5 * (q3 - q1)

        # Strict '>' so a perfectly uniform series of diffs yields no breakpoints.
        return [i for i, diff in enumerate(diffs) if diff > threshold]

    def split_into_semantic_chunks(self, breakpoints: List[int]) -> List[str]:
        """
        Split sentences into chunks based on breakpoints.

        Args:
            breakpoints (List[int]): Gap indices; a new chunk starts at sentence
                ``bp + 1``. Duplicates and indices outside ``[0, len(sentences) - 2]``
                are ignored, because they could only produce empty chunks.

        Returns:
            List[str]: Non-empty chunks, each the space-joined sentences it spans.

        Raises:
            ValueError: If sentences haven't been extracted yet.
        """
        if not self.sentences:
            raise ValueError("Sentences must be extracted first.")

        n_sentences = len(self.sentences)
        cuts = sorted({bp + 1 for bp in breakpoints if 0 <= bp < n_sentences - 1})
        bounds = [0, *cuts, n_sentences]
        return [' '.join(self.sentences[start:end]) for start, end in zip(bounds, bounds[1:])]

    def generate_chunk_embeddings(self, chunks: List[str]) -> np.ndarray:
        """
        Generate embeddings for each chunk.

        Args:
            chunks (List[str]): List of text chunks to embed.

        Returns:
            np.ndarray: Array of chunk embeddings (n_chunks x embedding_dim).
        """
        return np.array(self.embedding_model.embed_documents(chunks))

    def semantic_search(self, query: str, top_k: int = 3) -> List[Tuple[int, float, str]]:
        """
        Perform semantic search to find the most relevant chunks to a query.

        Args:
            query (str): The search query.
            top_k (int): Number of top results to return. Values <= 0 return ``[]``.

        Returns:
            List[Tuple[int, float, str]]: ``(index, cosine_similarity, chunk)``
            tuples, most similar first. There are at most ``top_k`` of them.

        Raises:
            ValueError: If chunks or chunk embeddings haven't been generated.
        """
        if not self.chunks or self.chunk_embeddings is None:
            raise ValueError("Chunks and chunk embeddings must be generated first.")
        if top_k <= 0:
            # Guard: a slice like argsort(...)[-0:] would otherwise return every chunk.
            return []

        query_unit = self._l2_normalize(self.embedding_model.embed_query(query))[0]
        chunk_units = self._l2_normalize(self.chunk_embeddings)
        similarities = chunk_units @ query_unit

        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [(int(i), float(similarities[i]), self.chunks[i]) for i in top_indices]

    def generate_response(self, query: str, context_chunks: List[str]) -> str:
        """
        Generate a response to a query using the provided context chunks.

        Args:
            query (str): The user's query.
            context_chunks (List[str]): Relevant context chunks to use for generation.

        Returns:
            str: The generated response (``.content`` of the chat result when
            present, otherwise its string form).
        """
        context = "\n\n".join(context_chunks)

        # NOTE: the continuation lines of this triple-quoted prompt keep the
        # method's indentation; the prompt text is sent to the model as-is.
        prompt = f"""Use the following context to answer the question at the end. 
        If you don't know the answer, just say that you don't know, don't try to make up an answer.

        Context:
        {context}

        Question: {query}

        Answer:"""

        response = self.chat_model.invoke(prompt)
        return response.content if hasattr(response, 'content') else str(response)

    def process_document(self, pdf_path: str, chunk_method: str = "percentile", **chunk_kwargs):
        """
        Process a document through the entire pipeline, populating the instance state.

        Args:
            pdf_path (str): Path to the PDF file.
            chunk_method (str): Method to use for semantic chunking (see find_breakpoints).
            **chunk_kwargs: Additional arguments for the chunking method.

        Raises:
            ValueError: If no sentences can be extracted from the PDF.
        """
        text = self.extract_text_from_pdf(pdf_path)

        self.sentences = self.split_into_sentences(text)
        if not self.sentences:
            # Fail early with a clear message instead of a misleading error further down.
            raise ValueError(f"No text could be extracted from the PDF at {pdf_path}.")

        self.sentence_embeddings = self.generate_sentence_embeddings(self.sentences)
        self.similarity_diffs = self.calculate_similarity_differences()
        breakpoints = self.find_breakpoints(method=chunk_method, **chunk_kwargs)
        self.chunks = self.split_into_semantic_chunks(breakpoints)
        self.chunk_embeddings = self.generate_chunk_embeddings(self.chunks)

    def query(self, question: str, top_k: int = 3) -> str:
        """
        Query the system with a question and get a response.

        Args:
            question (str): The question to ask.
            top_k (int): Number of chunks to retrieve for context.

        Returns:
            str: The generated response.
        """
        relevant_chunks = self.semantic_search(question, top_k=top_k)
        context_chunks = [chunk for _, _, chunk in relevant_chunks]
        return self.generate_response(question, context_chunks)

In [None]:
# PDF to index; the "./" prefix makes this relative to the current working directory.
pdf_path = "./dataset/health supplements/1. dietary supplements - for whom.pdf"

In [None]:
# Build the pipeline object; this Ollama model name is used for both the chat and embedding clients.
rag = SemanticChunking(ollama_model="llama3.2:3b")

In [None]:
# Extract, sentence-split, embed and chunk the PDF. A new chunk starts after any
# gap whose consecutive-sentence dissimilarity exceeds the 95th percentile.
rag.process_document(pdf_path, chunk_method="percentile", percentile=95)

In [None]:
# Question to ask about the indexed document.
query = "What are the main findings of this document?"

In [None]:
# Retrieve the most similar chunks (default top_k=3) and have the chat model answer from them.
response = rag.query(query)

In [None]:
# Show the question alongside the model's answer.
print(f"Response to '{query}':\n{response}")