<a href="https://colab.research.google.com/github/tivon-x/all-rag-techniques/blob/main/09_rse.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 面向增强型 RAG 的相关片段提取（RSE）

在本笔记本中，我们实现了一种相关片段提取（Relevant Segment Extraction, RSE）技术，以提升我们 RAG 系统中的上下文质量。我们不再简单地检索一组孤立的片段，而是识别并重构连续的文本段落，为我们的语言模型提供更好的上下文。

## 核心概念

相关的片段往往会在文档中聚集在一起。通过识别这些聚集区域并保留它们的连续性，我们为大语言模型提供了更连贯的上下文。

## 概述

相关段落提取（RSE）是一种从检索到的文本块中重建连续多块文本段的方法。该步骤发生在向量搜索（以及可选的重排序）之后，但在将检索到的上下文呈现给大语言模型（LLM）之前。这种方法确保相邻的文本块按照它们在原始文档中的顺序呈现给LLM。同时，它还会加入那些未被标记为相关但夹在高相关块之间的文本块，从而进一步改善提供给LLM的上下文。如本笔记本末尾所示的评估结果，这种方法能显著提升检索增强生成（RAG）的性能。

## 动机

在为RAG对文档进行分块时，选择合适的分块大小需要在多种权衡之间进行考量。大块文本能为LLM提供比小块更好的上下文，但也使得精确检索特定信息变得更加困难。有些查询（如简单的知识型问题）最适合用小块处理，而另一些查询（如高层次的问题）则需要非常大的块。有些查询可以通过文档中的一句话来回答，而另一些查询则需要整个章节才能妥善解答。大多数现实世界的RAG应用场景都会面临这些类型查询的组合。

我们真正需要的是一个更动态的系统：当只需要短块时能检索短块，而在需要时也能检索非常大的块。那么，我们该如何实现这一点呢？

我们的解决方案源于一个简单的洞察：**相关的文本块往往在其原始文档中聚集在一起**。

## 关键组件

#### 文本块键值存储
RSE需要能够通过doc_id和chunk_index作为键快速从数据库中检索文本块内容。这是因为并非所有需要包含在给定段落中的文本块都会在初始搜索结果中被返回。因此，除了向量数据库外，可能还需要使用某种键值存储。

## 方法细节

#### 文档分块
可以使用标准的文档分块方法。这里唯一的特殊要求是文档分块时不能有重叠。这样我们就可以通过连接文本块来重建文档的各个部分（即段落）。

#### RSE优化
在完成标准的文本块检索过程（理想情况下包括重排序步骤）后，就可以开始RSE过程。第一步是将相似度得分和相关性排名结合起来得到文本块值（chunk value）。这比单独使用相似度得分或单独使用排名提供了更稳健的起点。

然后我们从每个文本块的值中减去一个常数阈值（假设为0.2），使得不相关的文本块具有负值（最低可达-0.2），而相关的文本块具有正值（最高可达0.8）。通过这种方式计算文本块值，我们可以将段落值定义为文本块值的总和。

例如，假设文档中0-4号文本块的值如下：[-0.2, -0.2, 0.4, 0.8, -0.1]。那么仅包含2-3号文本块的段落值就是0.4+0.8=1.2。

寻找最佳段落就变成了一个带约束的最大子数组和问题。我们使用带有少量启发式规则的暴力搜索来提高效率。这通常耗时约5-10毫秒。

![RSE](https://github.com/NirDiamant/RAG_Techniques/blob/main/images/relevant-segment-extraction.svg?raw=1)

## 环境配置

In [1]:
# fitz库需要从pymudf那里安装
%pip install --quiet --force-reinstall pymupdf

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.1/24.1 MB[0m [31m87.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import fitz
import os
import numpy as np
import json
from openai import OpenAI
import re

## 从 PDF 文件提取文本

In [3]:
def extract_text_from_pdf(pdf_path):
    """
    Extracts text from a PDF file and prints the first `num_chars` characters.

    Args:
    pdf_path (str): Path to the PDF file.

    Returns:
    str: Extracted text from the PDF.
    """
    # Open the PDF file
    mypdf = fitz.open(pdf_path)
    all_text = ""  # Initialize an empty string to store the extracted text

    # Iterate through each page in the PDF
    for page_num in range(mypdf.page_count):
        page = mypdf[page_num]  # Get the page
        text = page.get_text("text")  # Extract text from the page
        all_text += text  # Append the extracted text to the all_text string

    return all_text  # Return the extracted text

## 分块
一旦我们提取了文本，我们会将其分成更小的、有重叠部分的块，以提高检索的准确性。

In [4]:
def chunk_text(text, chunk_size=800, overlap=0):
    """
    Split text into non-overlapping chunks.
    For RSE, we typically want non-overlapping chunks so we can reconstruct segments properly.

    Args:
        text (str): Input text to chunk
        chunk_size (int): Size of each chunk in characters
        overlap (int): Overlap between chunks in characters

    Returns:
        List[str]: List of text chunks
    """
    chunks = []

    # Simple character-based chunking
    for i in range(0, len(text), chunk_size - overlap):
        chunk = text[i:i + chunk_size]
        if chunk:  # Ensure we don't add empty chunks
            chunks.append(chunk)

    return chunks

## OpenAI API Client

In [5]:
# colab环境
from google.colab import userdata
# 使用火山引擎
api_key = userdata.get("ARK_API_KEY")
base_url = userdata.get("ARK_BASE_URL")

In [6]:
client = OpenAI(
    base_url=base_url,
    api_key=api_key
)

## 构建向量数据库

In [7]:
class SimpleVectorStore:
    """
    A lightweight vector store implementation using NumPy.
    """
    def __init__(self, dimension=1536):
        """
        Initialize the vector store.

        Args:
            dimension (int): Dimension of embeddings
        """
        self.dimension = dimension
        self.vectors = []
        self.documents = []
        self.metadata = []

    def add_documents(self, documents, vectors=None, metadata=None):
        """
        Add documents to the vector store.

        Args:
            documents (List[str]): List of document chunks
            vectors (List[List[float]], optional): List of embedding vectors
            metadata (List[Dict], optional): List of metadata dictionaries
        """
        if vectors is None:
            vectors = [None] * len(documents)

        if metadata is None:
            metadata = [{} for _ in range(len(documents))]

        for doc, vec, meta in zip(documents, vectors, metadata):
            self.documents.append(doc)
            self.vectors.append(vec)
            self.metadata.append(meta)

    def search(self, query_vector, top_k=5):
        """
        Search for most similar documents.

        Args:
            query_vector (List[float]): Query embedding vector
            top_k (int): Number of results to return

        Returns:
            List[Dict]: List of results with documents, scores, and metadata
        """
        if not self.vectors or not self.documents:
            return []

        # Convert query vector to numpy array
        query_array = np.array(query_vector)

        # Calculate similarities
        similarities = []
        for i, vector in enumerate(self.vectors):
            if vector is not None:
                # Compute cosine similarity
                similarity = np.dot(query_array, vector) / (
                    np.linalg.norm(query_array) * np.linalg.norm(vector)
                )
                similarities.append((i, similarity))

        # Sort by similarity (descending)
        similarities.sort(key=lambda x: x[1], reverse=True)

        # Get top-k results
        results = []
        for i, score in similarities[:top_k]:
            results.append({
                "document": self.documents[i],
                "score": float(score),
                "metadata": self.metadata[i]
            })

        return results

## 创建嵌入向量

In [8]:
def create_embeddings(texts, model="doubao-embedding-text-240715", batch_size=100):
    """
    Generate embeddings for texts.

    Args:
        texts (List[str]): List of texts to embed
        model (str): Embedding model to use
        batch_size (int): size of a batch, for processing. Default is 100

    Returns:
        List[List[float]]: List of embedding vectors
    """
    if not texts:
        return []  # Return an empty list if no texts are provided

    all_embeddings = []  # Initialize a list to store all embeddings

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]  # Get the current batch of texts

        # Create embeddings for the current batch using the specified model
        response = client.embeddings.create(
            input=batch,
            model=model
        )

        # Extract embeddings from the response
        batch_embeddings = [item.embedding for item in response.data]
        all_embeddings.extend(batch_embeddings)  # Add the batch embeddings to the list

    return all_embeddings  # Return the list of all embeddings

## 处理文档

In [9]:
def process_document(pdf_path, chunk_size=800):
    """
    Process a document for use with RSE.

    Args:
        pdf_path (str): Path to the PDF document
        chunk_size (int): Size of each chunk in characters

    Returns:
        Tuple[List[str], SimpleVectorStore, Dict]: Chunks, vector store, and document info
    """
    print("Extracting text from document...")
    # Extract text from the PDF file
    text = extract_text_from_pdf(pdf_path)

    print("Chunking text into non-overlapping segments...")
    # 将提取的文本分割成不重叠的段落
    chunks = chunk_text(text, chunk_size=chunk_size, overlap=0)
    print(f"Created {len(chunks)} chunks")

    print("Generating embeddings for chunks...")
    # Generate embeddings for the text chunks
    chunk_embeddings = create_embeddings(chunks)

    # Create an instance of the SimpleVectorStore
    vector_store = SimpleVectorStore()

    # Add documents with metadata (including chunk index for later reconstruction)
    metadata = [{"chunk_index": i, "source": pdf_path} for i in range(len(chunks))]
    vector_store.add_documents(chunks, chunk_embeddings, metadata)

    # Track original document structure for segment reconstruction
    doc_info = {
        "chunks": chunks,
        "source": pdf_path,
    }

    return chunks, vector_store, doc_info

## RSE 核心算法：计算文本块值并找到最佳段落
现在我们已经具备了处理文档和为其片段生成嵌入向量的必要函数，可以实现 RSE 的核心算法。

In [11]:
def calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty=0.2):
    """
    结合相关性和排名计算文本块值.

    Args:
        query (str): Query text
        chunks (List[str]): List of document chunks
        vector_store (SimpleVectorStore): Vector store containing the chunks
        irrelevant_chunk_penalty (float): Penalty for irrelevant chunks

    Returns:
        List[float]: List of chunk values
    """
    # Create query embedding
    query_embedding = create_embeddings([query])[0]

    # Get all chunks with similarity scores
    num_chunks = len(chunks)
    results = vector_store.search(query_embedding, top_k=num_chunks)

    # Create a mapping of chunk_index to relevance score
    relevance_scores = {
        result["metadata"]["chunk_index"]: result["score"] for result in results
    }

    # 对排名进行衰减，避免过大
    decay_rate = 30

    # Calculate chunk values (relevance score minus penalty)
    chunk_values = []
    for i in range(num_chunks):
        # Get relevance score or default to 0 if not in results
        score = relevance_scores.get(i, 0.0)
        # 融合排名和相似度得分
        fusion_value = np.exp(-i / decay_rate) * score
        # 应用惩罚以将值转换为使不相关片段具有负值
        value = fusion_value - irrelevant_chunk_penalty
        chunk_values.append(value)

    return chunk_values

In [12]:
def find_best_segments(chunk_values, max_segment_length=20, total_max_length=30, min_segment_value=0.2):
    """
    使用最大子数组和算法的一个变体来找到最佳段落。

    Args:
        chunk_values (List[float]): Values for each chunk
        max_segment_length (int): Maximum length of a single segment.
          The lenght of a segment is the number of chunks in this segment.
        total_max_length (int): Maximum total length across all segments
        min_segment_value (float): Minimum value for a segment to be considered

    Returns:
        Tuple[List[Tuple[int, int]], List[float]]: List of (start, end) indices for best segments and List of score of each segment
    """
    print("Finding optimal continuous text segments...")

    # a segment: (start, end)
    best_segments = []
    segment_scores = []
    total_included_chunks = 0

    # Keep finding segments until we hit our limits
    while total_included_chunks < total_max_length:
        best_score = min_segment_value  # Minimum threshold for a segment
        best_segment = None

        # Try each possible starting position
        for start in range(len(chunk_values)):
            # Skip if this start position is already in a selected segment
            if any(start >= s[0] and start < s[1] for s in best_segments):
                continue

            # Try each possible segment length
            for length in range(1, min(max_segment_length, len(chunk_values) - start) + 1):
                end = start + length

                # Skip if end position is already in a selected segment
                if any(end > s[0] and end <= s[1] for s in best_segments):
                    continue

                # Calculate segment value as sum of chunk values
                segment_value = sum(chunk_values[start:end])

                # Update best segment if this one is better
                if segment_value > best_score:
                    best_score = segment_value
                    best_segment = (start, end)

        # If we found a good segment, add it
        if best_segment:
            best_segments.append(best_segment)
            segment_scores.append(best_score)
            total_included_chunks += best_segment[1] - best_segment[0] # length of this segment
            print(f"Found segment {best_segment} with score {best_score:.4f}")
        else:
            # No more good segments to find
            break

    # Sort segments by their starting position for readability
    best_segments = sorted(best_segments, key=lambda x: x[0])

    return best_segments, segment_scores

## 重建段落并将其应用于RAG

In [13]:
def reconstruct_segments(chunks, best_segments):
    """
    Reconstruct text segments based on chunk indices.

    Args:
        chunks (List[str]): List of all document chunks
        best_segments (List[Tuple[int, int]]): List of (start, end) indices for segments

    Returns:
        List[str]: List of reconstructed text segments
    """
    reconstructed_segments = []  # Initialize an empty list to store the reconstructed segments

    for start, end in best_segments:
        # Join the chunks in this segment to form the complete segment text
        segment_text = " ".join(chunks[start:end])
        # Append the segment text and its range to the reconstructed_segments list
        reconstructed_segments.append({
            "text": segment_text,
            "segment_range": (start, end),
        })

    return reconstructed_segments  # Return the list of reconstructed text segments

In [14]:
def format_segments_for_context(segments):
    """
    Format segments into a context string for the LLM.

    Args:
        segments (List[Dict]): List of segment dictionaries

    Returns:
        str: Formatted context text
    """
    context = []  # Initialize an empty list to store the formatted context

    for i, segment in enumerate(segments):
        # Create a header for each segment with its index and chunk range
        segment_header = f"SEGMENT {i+1} (Chunks {segment['segment_range'][0]}-{segment['segment_range'][1]-1}):"
        context.append(segment_header)  # Add the segment header to the context list
        context.append(segment['text'])  # Add the segment text to the context list
        context.append("-" * 80)  # Add a separator line for readability

    # Join all elements in the context list with double newlines and return the result
    return "\n\n".join(context)

## 生成响应

In [15]:
def generate_response(query, context, model="doubao-lite-128k-240828"):
    """
    Generate a response based on the query and context.

    Args:
        query (str): User query
        context (str): Context text from relevant segments
        model (str): LLM model to use

    Returns:
        str: Generated response
    """
    print("Generating response using relevant segments as context...")

    # Define the system prompt to guide the AI's behavior
    system_prompt = """You are a helpful assistant that answers questions based on the provided context.
    The context consists of document segments that have been retrieved as relevant to the user's query.
    Use the information from these segments to provide a comprehensive and accurate answer.
    If the context doesn't contain relevant information to answer the question, say so clearly."""

    # Create the user prompt by combining the context and the query
    user_prompt = f"""
Context:
{context}

Question: {query}

Please provide a helpful answer based on the context provided.
"""

    # Generate the response using the specified model
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0
    )

    # Return the generated response content
    return response.choices[0].message.content

## 完成 RSE pipline

In [16]:
def rag_with_rse(pdf_path, query, chunk_size=800, irrelevant_chunk_penalty=0.2):
    """
    Complete RAG pipeline with Relevant Segment Extraction.

    Args:
        pdf_path (str): Path to the document
        query (str): User query
        chunk_size (int): Size of chunks
        irrelevant_chunk_penalty (float): Penalty for irrelevant chunks

    Returns:
        Dict: Result with query, segments, and response
    """
    print("\n=== STARTING RAG WITH RELEVANT SEGMENT EXTRACTION ===")
    print(f"Query: {query}")

    # Process the document to extract text, chunk it, and create embeddings
    chunks, vector_store, doc_info = process_document(pdf_path, chunk_size)

    # Calculate relevance scores and chunk values based on the query
    print("\nCalculating relevance scores and chunk values...")
    chunk_values = calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty)

    # Find the best segments of text based on chunk values
    best_segments, scores = find_best_segments(
        chunk_values,
        max_segment_length=20,
        total_max_length=30,
        min_segment_value=0.2
    )

    # Reconstruct text segments from the best chunks
    print("\nReconstructing text segments from chunks...")
    segments = reconstruct_segments(chunks, best_segments)

    # Format the segments into a context string for the language model
    context = format_segments_for_context(segments)

    # Generate a response from the language model using the context
    response = generate_response(query, context)

    # Compile the result into a dictionary
    result = {
        "query": query,
        "segments": segments,
        "response": response
    }

    print("\n=== FINAL RESPONSE ===")
    print(response)

    return result

## 与标准检索比较


In [17]:
def standard_top_k_retrieval(pdf_path, query, k=10, chunk_size=800):
    """
    Standard RAG with top-k retrieval.

    Args:
        pdf_path (str): Path to the document
        query (str): User query
        k (int): Number of chunks to retrieve
        chunk_size (int): Size of chunks

    Returns:
        Dict: Result with query, chunks, and response
    """
    print("\n=== STARTING STANDARD TOP-K RETRIEVAL ===")
    print(f"Query: {query}")

    # Process the document to extract text, chunk it, and create embeddings
    chunks, vector_store, doc_info = process_document(pdf_path, chunk_size)

    # Create an embedding for the query
    print("Creating query embedding and retrieving chunks...")
    query_embedding = create_embeddings([query])[0]

    # Retrieve the top-k most relevant chunks based on the query embedding
    results = vector_store.search(query_embedding, top_k=k)
    retrieved_chunks = [result["document"] for result in results]

    # Format the retrieved chunks into a context string
    context = "\n\n".join([
        f"CHUNK {i+1}:\n{chunk}"
        for i, chunk in enumerate(retrieved_chunks)
    ])

    # Generate a response from the language model using the context
    response = generate_response(query, context)

    # Compile the result into a dictionary
    result = {
        "query": query,
        "chunks": retrieved_chunks,
        "response": response
    }

    print("\n=== FINAL RESPONSE ===")
    print(response)

    return result

## 评估 RES

In [22]:
def evaluate_methods(pdf_path, query, reference_answer=None):
    """
    Compare RSE with standard top-k retrieval.

    Args:
        pdf_path (str): Path to the document
        query (str): User query
        reference_answer (str, optional): Reference answer for evaluation
    """
    print("\n========= EVALUATION =========\n")

    # Run the RAG with Relevant Segment Extraction (RSE) method
    rse_result = rag_with_rse(pdf_path, query)

    # Run the standard top-k retrieval method
    standard_result = standard_top_k_retrieval(pdf_path, query)

    # If a reference answer is provided, evaluate the responses
    if reference_answer:
        print("\n=== COMPARING RESULTS ===")

        # Create an evaluation prompt to compare the responses against the reference answer
        evaluation_prompt = f"""
            Query: {query}

            Reference Answer:
            {reference_answer}

            Response from Standard Retrieval:
            {standard_result["response"]}

            Response from Relevant Segment Extraction:
            {rse_result["response"]}

            Compare these two responses against the reference answer. Which one is:
            1. More accurate and comprehensive
            2. Better at addressing the user's query
            3. Less likely to include irrelevant information

            Explain your reasoning for each point.
        """

        print("Evaluating responses against reference answer...")

        # Generate the evaluation using the specified model
        evaluation = client.chat.completions.create(
            model="doubao-lite-128k-240828",
            messages=[
                {"role": "system", "content": "You are an objective evaluator of RAG system responses."},
                {"role": "user", "content": evaluation_prompt}
            ]
        )

        # Print the evaluation results
        print("\n=== EVALUATION RESULTS ===")
        print(evaluation.choices[0].message.content)

    # Return the results of both methods
    return {
        "rse_result": rse_result,
        "standard_result": standard_result
    }

In [18]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [23]:
# Load the validation data from a JSON file
with open('./drive/MyDrive/colab_data/val.json') as f:
    data = json.load(f)

# Extract the first query from the validation data
query = data[0]['question']

# Extract the reference answer from the validation data
reference_answer = data[0]['ideal_answer']

# pdf_path
pdf_path = "./drive/MyDrive/colab_data/AI_Information.pdf"

# Run evaluation
results = evaluate_methods(pdf_path, query, reference_answer)




=== STARTING RAG WITH RELEVANT SEGMENT EXTRACTION ===
Query: What is 'Explainable AI' and why is it considered important?
Extracting text from document...
Chunking text into non-overlapping segments...
Created 42 chunks
Generating embeddings for chunks...

Calculating relevance scores and chunk values...
Finding optimal continuous text segments...
Found segment (0, 20) with score 7.0652
Found segment (20, 40) with score 1.7484

Reconstructing text segments from chunks...
Generating response using relevant segments as context...

=== FINAL RESPONSE ===
Explainable AI (XAI) aims to make AI systems more transparent and understandable. It focuses on developing methods for explaining AI decisions, enhancing trust, and improving accountability. It is considered important because many AI systems, particularly deep learning models, are "black boxes," making it difficult to understand how they arrive at their decisions. By making AI systems explainable, users can assess their fairness and ac