### Relevant Segment Extraction (RSE) for Enhanced RAG
The Relevant Segment Extraction (RSE) technique improves context quality in our RAG system. Rather than retrieving a collection of isolated chunks, we identify and reconstruct contiguous segments of text that give the language model more coherent context.

#### Key Concept
Relevant chunks tend to be clustered together within documents. By identifying these clusters and preserving their continuity, we provide more coherent context for the LLM to work with.

#### Workflow
1. Process data (extract text from PDF -> chunk -> build vector store)
2. Calculate a value for each chunk
3. Apply a variant of the **maximum sum subarray algorithm** to find the best segments (contiguous chunks with the highest total value)
4. Merge the chunks of each best segment back into text
5. Generation
6. Compare with standard retrieval
7. Evaluation
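
Step 3 builds on the classic maximum sum subarray problem. As a point of reference, here is a minimal sketch of the textbook single-pass version (Kadane's algorithm) on made-up chunk values. The function name `max_sum_subarray` and the toy numbers are assumptions for illustration; the notebook itself uses a length-limited search that can return several segments.

```python
def max_sum_subarray(values):
    """Return (best_sum, start, end) for the slice values[start:end] with the largest sum."""
    if not values:
        raise ValueError("values must be non-empty")
    best_sum, best_start, best_end = values[0], 0, 1
    current_sum, current_start = values[0], 0
    for i in range(1, len(values)):
        # A negative running sum can only hurt, so restart the segment at i
        if current_sum < 0:
            current_sum, current_start = values[i], i
        else:
            current_sum += values[i]
        if current_sum > best_sum:
            best_sum, best_start, best_end = current_sum, current_start, i + 1
    return best_sum, best_start, best_end

# Hypothetical chunk values (similarity minus a penalty)
toy_values = [-0.1, 0.3, 0.5, -0.05, 0.4, -0.2, -0.15]
best_sum, start, end = max_sum_subarray(toy_values)
print(f"best segment: chunks {start}..{end - 1}, value {best_sum:.2f}")
```

Note that the slightly negative chunk at index 3 is kept because it bridges two relevant neighbours, which is the behaviour RSE relies on.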

In [1]:
!pip install -q pymupdf
!pip install -q bitsandbytes


#### Import necessary libraries

In [2]:
import fitz
import os
import numpy as np
import json
import tqdm
import re

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

#### Define models

This notebook uses **Llama 3.2 3B Instruct** for generation and **bge-base-en-v1.5** for embeddings.

Together they need roughly 8 GB of VRAM.

In [3]:
# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load generation model and tokenizer
print("Loading generation model...")
gen_tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-3B-Instruct")
gen_model = AutoModelForCausalLM.from_pretrained(
    "unsloth/Llama-3.2-3B-Instruct",
    torch_dtype="auto",
    device_map="auto"
)

# Load embedding model and tokenizer
print("Loading embedding model...")
embed_model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
embed_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")

# Move embedding model to device
embed_model = embed_model.to(device)

print("All models loaded successfully!")

Using device: cuda
Loading generation model...



Loading embedding model...



All models loaded successfully!


#### Extracting and chunking texts

In [4]:
def extract_text_from_pdf(pdf_path):
  # Open the PDF, concatenate the plain text of every page, then close the file
  with fitz.open(pdf_path) as pdf:
    text = "".join(page.get_text() for page in pdf)

  return text # str

def chunk_text(text, chunk_size, overlap):
  return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size - overlap)] # list(str)
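
To make the overlap behaviour concrete, here is a small sketch that runs the same slicing on a short string. The guard on `overlap` is an addition for this example (a step of zero or less would make `range` fail or return nothing); the sample string and sizes are made up.

```python
def chunk_text(text, chunk_size, overlap):
    # Same slicing as the helper above, plus a guard added for this sketch
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    step = chunk_size - overlap  # distance between consecutive chunk starts
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

sample = "abcdefghijklmnopqrstuvwxyz"  # 26 characters
chunks = chunk_text(sample, chunk_size=10, overlap=4)
for idx, chunk in enumerate(chunks):
    print(idx, repr(chunk))
```

Each chunk starts `chunk_size - overlap` characters after the previous one, so the last `overlap` characters of one chunk reappear at the start of the next, and the final chunk can be shorter than `chunk_size`.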

#### Define SimpleVectorStore

In [61]:
class SimpleVectorStore:
    """
    A lightweight vector store implementation using NumPy.
    """
    def __init__(self, dimension=768):
        """
        Initialize the vector store.

        Args:
            dimension (int): Dimension of embeddings
        """
        self.dimension = dimension
        self.vectors = []
        self.documents = []
        self.metadata = []

    def add_documents(self, documents, vectors=None, metadata=None):
        """
        Add documents to the vector store.

        Args:
            documents (List[str]): List of document chunks
            vectors (List[List[float]], optional): List of embedding vectors
            metadata (List[Dict], optional): List of metadata dictionaries
        """
        if vectors is None:
            vectors = [None] * len(documents)

        if metadata is None:
            metadata = [{} for _ in range(len(documents))]

        for doc, vec, meta in zip(documents, vectors, metadata):
            self.documents.append(doc)
            self.vectors.append(vec)
            self.metadata.append(meta)

    def search(self, query_vector, top_k=5):
        """
        Search for most similar documents.

        Args:
            query_vector (List[float]): Query embedding vector
            top_k (int): Number of results to return

        Returns:
            List[Dict]: List of results with documents, scores, and metadata
        """
        if not self.vectors or not self.documents:
            return []

        # Ensure the query is a 2D NumPy array of shape (1, dim) for cosine_similarity
        query_vector = np.asarray(query_vector).reshape(1, -1)

        # Calculate similarities
        similarities = []
        for i, vector in enumerate(self.vectors):
            if vector is not None:
                vector = vector.reshape(1, -1)
                sim = cosine_similarity(query_vector, vector)[0][0]  # Extract scalar
                similarities.append((i, sim))

        # Sort by similarity (descending)
        similarities.sort(key=lambda x: x[1], reverse=True)

        # Get top-k results
        results = []
        for i, score in similarities[:top_k]:
            results.append({
                "document": self.documents[i],
                "score": float(score),
                "metadata": self.metadata[i]
            })

        return results
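
The `search` method compares the query with one stored vector at a time. For larger stores, the same ranking can be computed in a single matrix product. The sketch below is one possible way to do that with NumPy only; `top_k_cosine` is a hypothetical helper, not part of `SimpleVectorStore`, and it assumes every vector has a non-zero norm.

```python
import numpy as np

def top_k_cosine(query_vector, vectors, top_k=5):
    """Return [(index, score), ...] for the top_k rows of `vectors` most similar to the query."""
    matrix = np.asarray(vectors, dtype=np.float32)      # shape (n, dim)
    query = np.asarray(query_vector, dtype=np.float32)  # shape (dim,)
    # Normalise both sides so the dot product equals cosine similarity
    matrix = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    query = query / np.linalg.norm(query)
    scores = matrix @ query                              # shape (n,)
    order = np.argsort(-scores)[:top_k]                  # indices by descending score
    return [(int(i), float(scores[i])) for i in order]

toy_vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(top_k_cosine([1.0, 0.1], toy_vectors, top_k=2))
```

Since `create_embeddings` already L2-normalises its output, the explicit normalisation here is redundant for those vectors; it is kept so the sketch also works on raw inputs.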


#### Define **create_embeddings** function

In [None]:
def create_embeddings(text):
    """
    Create embeddings for text using the loaded embedding model.

    Args:
        text: str or list of str - input text(s) to embed

    Returns:
        list[np.ndarray]: list of normalized embeddings, each of shape (dim,)
    """
    # Wrap a single string in a list so the tokenizer always receives a batch
    if isinstance(text, str):
        text = [text]

    # Tokenize with error handling
    try:
        inputs = embed_tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(device)
    except Exception as e:
        print(f"Tokenization error: {e}")
        return None

    # Generate embeddings with no gradient computation
    try:
        with torch.no_grad():
            output = embed_model(**inputs)
            # Use CLS token embedding [CLS] at position 0
            cls_emb = output.last_hidden_state[:, 0, :]
            # L2 normalize for cosine similarity
            emb_normalized = F.normalize(cls_emb, p=2, dim=1)

        # Convert to list of np.ndarray, each of shape (dim,)
        embeddings = [emb.cpu().numpy() for emb in emb_normalized]

        # Always return a list, even for a single input string; callers index [0]
        return embeddings

    except Exception as e:
        print(f"Embedding generation error: {e}")
        return None
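
`create_embeddings` sends every text through the model in one forward pass, which may exhaust GPU memory on long documents. One possible mitigation is to embed in fixed-size batches. The sketch below assumes a hypothetical wrapper `embed_in_batches` that takes any embedding function with the same contract as `create_embeddings` (list of strings in, list of arrays or `None` out); `fake_embed` is a stand-in so the example runs without a model.

```python
import numpy as np

def embed_in_batches(texts, embed_fn, batch_size=32):
    """Call embed_fn on slices of `texts` and collect the per-text embeddings in order."""
    embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        batch_embeddings = embed_fn(batch)
        # The notebook's function signals failure with None, so fail loudly here
        if batch_embeddings is None:
            raise RuntimeError(f"embedding failed for batch starting at index {start}")
        embeddings.extend(batch_embeddings)
    return embeddings

# Stand-in embedder for demonstration only: encodes the text length
def fake_embed(batch):
    return [np.full(4, len(t), dtype=np.float32) for t in batch]

vectors = embed_in_batches(["a", "bb", "ccc", "dddd", "eeeee"], fake_embed, batch_size=2)
print(len(vectors), vectors[2])
```

In the notebook, `embed_fn` would be `create_embeddings`, and a suitable `batch_size` depends on the available memory.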

#### Define document processing

In [None]:
def process_document(pdf_path, chunk_size = 800):
  print("Extracting text from PDF...")
  text = extract_text_from_pdf(pdf_path)

  print("Chunking text...")
  text_chunks = chunk_text(text, chunk_size = chunk_size, overlap = 200)

  print("Creating embeddings for chunks...")
  embeddings = create_embeddings(text_chunks)

  vector_store = SimpleVectorStore()

  metadata = [{"chunk_index": i, "source": pdf_path} for i in range(len(text_chunks))]

  vector_store.add_documents(text_chunks, embeddings, metadata)

  doc_info = {
      "chunks": text_chunks,
      "source": pdf_path
  }

  return text_chunks, vector_store, doc_info

#### **RSE Core Algorithm: Computing Chunk Values and Finding Best Segments**


In [None]:
def calculate_chunk_values(query, text_chunks, vector_store, irrelevant_chunk_penalty=0.2):
  """
  Calculate chunk values as relevance score minus a fixed penalty.

  Args:
    query (str): Query text
    text_chunks (List[str]): List of document chunks
    vector_store (SimpleVectorStore): Vector store containing the chunks
    irrelevant_chunk_penalty (float): Amount subtracted from every chunk's similarity score, so weakly related chunks end up with negative values

  Returns:
    List[float]: One value per chunk, in document order
  """
  query_embedding = create_embeddings(query)[0] #nda(dim,)

  num_chunks = len(text_chunks)

  results = vector_store.search(query_embedding, top_k = num_chunks)

  relevance_score = {result["metadata"]["chunk_index"]: result["score"] for result in results}

  chunk_values = []

  for i in range(num_chunks):
    score = relevance_score.get(i,0.0)
    value = score - irrelevant_chunk_penalty
    chunk_values.append(value)

  return chunk_values
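
The penalty decides which chunks count as "relevant": any chunk whose similarity falls below it gets a negative value and lowers the score of a segment that includes it. A small sketch with made-up similarity scores shows the effect of two penalty settings. `to_chunk_values` is a hypothetical helper that mirrors the subtraction above; the rounding is only there to keep the printout readable.

```python
def to_chunk_values(similarities, irrelevant_chunk_penalty=0.2):
    """Shift similarity scores so weakly related chunks get negative values."""
    return [round(s - irrelevant_chunk_penalty, 4) for s in similarities]

# Hypothetical cosine similarities for six consecutive chunks
toy_similarities = [0.10, 0.55, 0.62, 0.18, 0.38, 0.05]
for penalty in (0.2, 0.4):
    values = to_chunk_values(toy_similarities, penalty)
    positives = [i for i, v in enumerate(values) if v > 0]
    print(f"penalty={penalty}: values={values}, positive chunks={positives}")
```

A higher penalty leaves fewer positive chunks, which tends to produce shorter, tighter segments; a lower one lets segments absorb more loosely related text.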

In [65]:
def find_best_segments(chunk_values, max_segment_length=20, total_max_length=30, min_segment_value=0.2):
    """
    Find the best segments using a variant of the maximum sum subarray algorithm.

    Args:
        chunk_values (List[float]): Values for each chunk
        max_segment_length (int): Maximum number of chunks in a single segment
        total_max_length (int): Chunk budget across all segments; it is checked
            before each search, so the last segment added can push the total past it
        min_segment_value (float): Minimum value for a segment to be considered

    Returns:
        Tuple[List[Tuple[int, int]], List[float]]: (start, end) index pairs with
        end exclusive, and the segment scores
    """
    print("Finding optimal continuous text segments...")

    best_segments = []
    segment_scores = []
    total_included_chunks = 0

    # Keep finding segments until we hit our limits
    while total_included_chunks < total_max_length:
        best_score = min_segment_value  # Minimum threshold for a segment
        best_segment = None

        # Try each possible starting position
        for start in range(len(chunk_values)):
            # Skip if this start position is already in a selected segment
            if any(s[0] <= start < s[1] for s in best_segments):
                continue

            # Try each possible segment length
            for length in range(1, min(max_segment_length, len(chunk_values) - start) + 1):
                end = start + length

                # Stop extending once [start, end) would overlap a selected segment;
                # this also rejects candidates that fully enclose an earlier segment
                if any(start < s[1] and end > s[0] for s in best_segments):
                    break

                # Calculate segment value as sum of chunk values
                segment_value = sum(chunk_values[start:end])

                # Update best segment if this one is better
                if segment_value > best_score:
                    best_score = segment_value
                    best_segment = (start, end)

        # If we found a good segment, add it
        if best_segment:
            best_segments.append(best_segment)
            segment_scores.append(best_score)
            total_included_chunks += best_segment[1] - best_segment[0]
            print(f"Found segment {best_segment} with score {best_score:.4f}")
        else:
            # No more good segments to find
            break

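    # Sort segments by starting position, keeping each score paired with its segment
    order = sorted(range(len(best_segments)), key=lambda i: best_segments[i][0])
    best_segments = [best_segments[i] for i in order]
    segment_scores = [segment_scores[i] for i in order]

    return best_segments, segment_scores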
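
The search above recomputes `sum(chunk_values[start:end])` for every candidate. If that becomes slow on long documents, prefix sums give each segment value in constant time. The following is a sketch of that idea for a single best segment only; `best_segment_prefix_sums` is a hypothetical name, and it leaves out the multi-segment loop and the `min_segment_value` threshold.

```python
from itertools import accumulate

def best_segment_prefix_sums(chunk_values, max_segment_length=20):
    """Return ((start, end), value) for the best segment of at most max_segment_length chunks."""
    # prefix[i] holds sum(chunk_values[:i]), so a segment sum is prefix[end] - prefix[start]
    prefix = [0.0] + list(accumulate(chunk_values))
    best_value, best_segment = float("-inf"), None
    n = len(chunk_values)
    for start in range(n):
        for end in range(start + 1, min(start + max_segment_length, n) + 1):
            value = prefix[end] - prefix[start]
            if value > best_value:
                best_value, best_segment = value, (start, end)
    return best_segment, best_value

# Hypothetical chunk values
toy_values = [-0.1, 0.35, 0.42, -0.02, 0.18, -0.15]
segment, value = best_segment_prefix_sums(toy_values, max_segment_length=3)
print(segment, round(value, 2))
```

With the length cap at 3 the search settles on chunks 1 and 2; raising the cap lets the segment extend across the slightly negative chunk 3 to reach chunk 4.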


#### Reconstructing and Using Segments for RAG

In [66]:
def reconstruct_segments(chunks, best_segments):
    """
    Reconstruct text segments based on chunk indices.

    Args:
        chunks (List[str]): List of all document chunks
        best_segments (List[Tuple[int, int]]): List of (start, end) indices for segments

    Returns:
        List[str]: List of reconstructed text segments
    """
    reconstructed_segments = []  # Initialize an empty list to store the reconstructed segments

    for start, end in best_segments:
        # Join the chunks in this segment; consecutive chunks share overlapping
        # text from chunking, so that overlap appears twice in the joined string
        segment_text = " ".join(chunks[start:end])
        # Append the segment text and its range to the reconstructed_segments list
        reconstructed_segments.append({
            "text": segment_text,
            "segment_range": (start, end),
        })

    return reconstructed_segments  # Return the list of reconstructed text segments
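
Because the chunks were created with an overlap, joining them with a space repeats the shared text at every boundary. If the duplicated text is a concern, one option is to drop the overlapping prefix of every chunk after the first. This is a sketch under the assumption that the chunks come from fixed-size slicing with a known `overlap`; `merge_overlapping_chunks` is a hypothetical helper, not part of the notebook.

```python
def merge_overlapping_chunks(chunks, overlap):
    """Rebuild contiguous text from consecutive chunks that share `overlap` characters."""
    if not chunks:
        return ""
    merged = chunks[0]
    for chunk in chunks[1:]:
        # Each later chunk begins with the last `overlap` characters of its predecessor
        merged += chunk[overlap:]
    return merged

text = "abcdefghijklmnopqrstuvwxyz"
chunk_size, overlap = 10, 4
chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]
print(merge_overlapping_chunks(chunks[1:4], overlap))
```

Applied to all chunks, this reproduces the original text exactly, so a segment built this way reads like the source document rather than a stitched sequence.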


In [67]:
def format_segments_for_context(segments):
    """
    Format segments into a context string for the LLM.

    Args:
        segments (List[Dict]): List of segment dictionaries

    Returns:
        str: Formatted context text
    """
    context = []  # Initialize an empty list to store the formatted context

    for i, segment in enumerate(segments):
        # Create a header for each segment with its index and chunk range
        segment_header = f"SEGMENT {i+1} (Chunks {segment['segment_range'][0]}-{segment['segment_range'][1]-1}):"
        context.append(segment_header)  # Add the segment header to the context list
        context.append(segment['text'])  # Add the segment text to the context list
        context.append("-" * 80)  # Add a separator line for readability

    # Join all elements in the context list with double newlines and return the result
    return "\n\n".join(context)


In [None]:
def gen(system_prompt, user_prompt, temperature=0):
    text = gen_tokenizer.apply_chat_template(
        conversation=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        tokenize=False,
        add_generation_prompt=True
    )

    model_inputs = gen_tokenizer([text], return_tensors="pt").to(device)

    if temperature > 0:
        generated_ids = gen_model.generate(
            **model_inputs,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=1024
        )
    else:
        generated_ids = gen_model.generate(
            **model_inputs,
            do_sample=False,
            temperature=None,
            top_p=None,  # disable top_p and temperature to not receive warning
            max_new_tokens=1024
        )
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = gen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response

#### Generating Responses with RSE Context

In [69]:
def generate_response(query, context, model="unsloth/Llama-3.2-3B-Instruct"):
    """
    Generate a response based on the query and context.

    Args:
        query (str): User query
        context (str): Context text from relevant segments
        model (str): Unused; generation always runs on the globally loaded gen_model

    Returns:
        str: Generated response
    """
    print("Generating response using relevant segments as context...")

    # Define the system prompt to guide the AI's behavior
    system_prompt = """You are a helpful assistant that answers questions based on the provided context.
    The context consists of document segments that have been retrieved as relevant to the user's query.
    Use the information from these segments to provide a comprehensive and accurate answer.
    If the context doesn't contain relevant information to answer the question, say so clearly."""

    # Create the user prompt by combining the context and the query
    user_prompt = f"""
Context:
{context}

Question: {query}

Please provide a helpful answer based on the context provided.
"""

    # Generate the response with the globally loaded generation model
    response = gen(system_prompt, user_prompt)

    # Return the generated response content
    return response


#### Complete RSE pipeline

In [70]:
def rag_with_rse(pdf_path, query, chunk_size=800, irrelevant_chunk_penalty=0.2):
    """
    Complete RAG pipeline with Relevant Segment Extraction.

    Args:
        pdf_path (str): Path to the document
        query (str): User query
        chunk_size (int): Size of chunks
        irrelevant_chunk_penalty (float): Penalty for irrelevant chunks

    Returns:
        Dict: Result with query, segments, and response
    """
    print("\n=== STARTING RAG WITH RELEVANT SEGMENT EXTRACTION ===")
    print(f"Query: {query}")

    # Process the document to extract text, chunk it, and create embeddings
    chunks, vector_store, doc_info = process_document(pdf_path, chunk_size)

    # Calculate relevance scores and chunk values based on the query
    print("\nCalculating relevance scores and chunk values...")
    chunk_values = calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty)

    # Find the best segments of text based on chunk values
    best_segments, scores = find_best_segments(
        chunk_values,
        max_segment_length=20,
        total_max_length=30,
        min_segment_value=0.2
    )

    # Reconstruct text segments from the best chunks
    print("\nReconstructing text segments from chunks...")
    segments = reconstruct_segments(chunks, best_segments)

    # Format the segments into a context string for the language model
    context = format_segments_for_context(segments)

    # Generate a response from the language model using the context
    response = generate_response(query, context)

    # Compile the result into a dictionary
    result = {
        "query": query,
        "segments": segments,
        "response": response
    }

    print("\n=== FINAL RESPONSE ===")
    print(response)

    return result


#### Comparing with Standard Retrieval

In [None]:
def standard_top_k_retrieval(pdf_path, query, k=10, chunk_size=800):
    """
    Standard RAG with top-k retrieval.

    Args:
        pdf_path (str): Path to the document
        query (str): User query
        k (int): Number of chunks to retrieve
        chunk_size (int): Size of chunks

    Returns:
        Dict: Result with query, chunks, and response
    """
    print("\n=== STARTING STANDARD TOP-K RETRIEVAL ===")
    print(f"Query: {query}")

    # Process the document to extract text, chunk it, and create embeddings
    chunks, vector_store, doc_info = process_document(pdf_path, chunk_size)

    # Create an embedding for the query
    print("Creating query embedding and retrieving chunks...")
    query_embedding = create_embeddings([query])[0]

    # Retrieve the top-k most relevant chunks based on the query embedding
    results = vector_store.search(query_embedding, top_k=k)
    retrieved_chunks = [result["document"] for result in results]

    # Format the retrieved chunks into a context string
    context = "\n\n".join([
        f"CHUNK {i+1}:\n{chunk}"
        for i, chunk in enumerate(retrieved_chunks)
    ])

    # Generate a response from the language model using the context
    response = generate_response(query, context)

    # Compile the result into a dictionary
    result = {
        "query": query,
        "chunks": retrieved_chunks,
        "response": response
    }

    print("\n=== FINAL RESPONSE ===")
    print(response)

    return result


#### Evaluation

In [72]:
def evaluate_methods(pdf_path, query, reference_answer=None):
    """
    Compare RSE with standard top-k retrieval.

    Args:
        pdf_path (str): Path to the document
        query (str): User query
        reference_answer (str, optional): Reference answer for evaluation
    """
    print("\n========= EVALUATION =========\n")

    # Run the RAG with Relevant Segment Extraction (RSE) method
    rse_result = rag_with_rse(pdf_path, query)

    # Run the standard top-k retrieval method
    standard_result = standard_top_k_retrieval(pdf_path, query)

    # If a reference answer is provided, evaluate the responses
    if reference_answer:
        print("\n=== COMPARING RESULTS ===")

        # Create an evaluation prompt to compare the responses against the reference answer
        evaluation_prompt = f"""
            Query: {query}

            Reference Answer:
            {reference_answer}

            Response from Standard Retrieval:
            {standard_result["response"]}

            Response from Relevant Segment Extraction:
            {rse_result["response"]}

            Compare these two responses against the reference answer. Which one is:
            1. More accurate and comprehensive
            2. Better at addressing the user's query
            3. Less likely to include irrelevant information

            Explain your reasoning for each point.
        """

        print("Evaluating responses against reference answer...")
        system_prompt = "You are an objective evaluator of RAG system responses."
        # Generate the evaluation with the same generation model
        evaluation = gen(system_prompt, evaluation_prompt)

        # Print the evaluation results
        print("\n=== EVALUATION RESULTS ===")
        print(evaluation)

    # Return the results of both methods
    return {
        "rse_result": rse_result,
        "standard_result": standard_result
    }
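
`evaluate_methods` runs both pipelines on the same PDF, and each one calls `process_document`, so the document is extracted and embedded twice. One way to avoid the repeated work is to memoise the processing step. The sketch below uses a hypothetical wrapper `make_cached_processor` and a stand-in `fake_process` so it runs without a PDF or a model; wiring it into the two pipelines is left as an assumption.

```python
def make_cached_processor(process_fn):
    """Wrap process_fn so repeated calls with the same arguments reuse the first result."""
    cache = {}

    def cached(pdf_path, chunk_size=800):
        key = (pdf_path, chunk_size)
        if key not in cache:
            cache[key] = process_fn(pdf_path, chunk_size)
        return cache[key]

    return cached

# Stand-in for process_document that records how often it is called
calls = []
def fake_process(pdf_path, chunk_size):
    calls.append((pdf_path, chunk_size))
    return [f"chunk from {pdf_path}"], "vector_store", {"source": pdf_path}

cached_process = make_cached_processor(fake_process)
first = cached_process("doc.pdf", 800)
second = cached_process("doc.pdf", 800)
print(len(calls), first is second)
```

The cache key includes `chunk_size` because a different chunk size yields different chunks and embeddings for the same file.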


In [None]:
# Load the validation data from a JSON file
with open('data/val.json') as f:
    data = json.load(f)

# Extract the first query from the validation data
query = data[0]['question']

# Extract the reference answer from the validation data
reference_answer = data[0]['ideal_answer']

# pdf_path
pdf_path = "data/AI_Information.pdf"

# Run evaluation
results = evaluate_methods(pdf_path, query, reference_answer)





=== STARTING RAG WITH RELEVANT SEGMENT EXTRACTION ===
Query: What is 'Explainable AI' and why is it considered important?
Extracting text from PDF...
Chunking text...
Creating embeddings for chunks...





Calculating relevance scores and chunk values...
Finding optimal continuous text segments...
Found segment (32, 52) with score 9.0159
Found segment (0, 20) with score 8.6664

Reconstructing text segments from chunks...
Generating response using relevant segments as context...

=== FINAL RESPONSE ===
Based on the provided context, Explainable AI (XAI) refers to techniques and methods developed to make AI systems more transparent and understandable. The goal of XAI is to provide insights into how AI models make decisions, enhancing trust and accountability in AI systems.

Explainable AI is considered important because it addresses the challenges of "black box" AI systems, which are often difficult to understand and interpret. By providing explanations for AI decisions, XAI helps users assess the reliability and fairness of AI systems, which is crucial for building trust in AI.

XAI is essential for several reasons:

1. **Transparency**: Explainable AI provides insights into how AI model



Creating query embedding and retrieving chunks...
Generating response using relevant segments as context...

=== FINAL RESPONSE ===
Based on the provided context, Explainable AI (XAI) refers to techniques that aim to make AI systems more transparent and understandable, enabling users to assess their fairness and accuracy. The primary goal of XAI is to provide insights into how AI models make decisions, enhancing trust and accountability in AI systems.

Explainable AI is considered important for several reasons:

1. **Building trust**: By providing transparency and explainability, XAI helps users understand how AI models arrive at their decisions, which is essential for building trust in AI systems.
2. **Addressing fairness and accuracy**: XAI techniques can help identify biases and errors in AI systems, ensuring that they are fair and accurate.
3. **Accountability**: Explainable AI enables accountability by providing a clear understanding of how AI decisions are made, which is crucial 