<a href="https://colab.research.google.com/github/tivon-x/all-rag-techniques/blob/main/08_reranker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reranking for Enhanced RAG Systems

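This notebook implements reranking techniques to improve retrieval quality in RAG systems. Reranking is a second filtering step applied after the initial retrieval, so that the content passed to response generation is the most relevant.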

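## Key Concepts of Reranking

1. **Initial retrieval**: A first pass using basic similarity search (fast but less accurate)
2. **Document scoring**: Evaluate how relevant each retrieved document is to the query
3. **Reordering**: Sort the documents by relevance score
4. **Selection**: Use only the most relevant documents for response generation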
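The four steps above can be sketched as a small generic function. This is only an illustration and not the implementation used later in this notebook: `rerank`, `word_overlap`, and the sample candidates are hypothetical names and data, and the toy scorer stands in for whatever relevance model is plugged in.

```python
from typing import Callable, Dict, List


def rerank(query: str,
           candidates: List[Dict],
           score_fn: Callable[[str, str], float],
           top_n: int = 3) -> List[Dict]:
    """Score each candidate, sort by score, and keep the top_n."""
    # Step 2: score every document returned by the initial retrieval
    scored = [{**doc, "relevance_score": score_fn(query, doc["text"])} for doc in candidates]
    # Step 3: reorder by the new score, highest first
    scored.sort(key=lambda d: d["relevance_score"], reverse=True)
    # Step 4: keep only the most relevant documents
    return scored[:top_n]


def word_overlap(query: str, text: str) -> float:
    """Hypothetical toy scorer: fraction of query words that appear in the text."""
    words = set(query.lower().split())
    return len(words & set(text.lower().split())) / len(words) if words else 0.0


# Step 1 (initial retrieval) is assumed to have produced these candidates
candidates = [
    {"text": "cats sleep most of the day"},
    {"text": "reranking improves retrieval quality"},
    {"text": "retrieval finds candidate documents"},
]
top = rerank("reranking retrieval quality", candidates, word_overlap, top_n=2)
for doc in top:
    print(doc["relevance_score"], doc["text"])
```

Keeping the scorer as a parameter means the same skeleton works whether the score comes from an LLM, a cross-encoder, or a keyword heuristic.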

## Environment Setup

In [1]:
# The fitz module is provided by the pymupdf package
%pip install --quiet --force-reinstall pymupdf


In [2]:
import fitz
import os
import numpy as np
import json
from openai import OpenAI
import re
from tqdm import tqdm

## Extracting Text

In [3]:
def extract_text_from_pdf(pdf_path):
    """
    Extracts text from a PDF file.

    Args:
    pdf_path (str): Path to the PDF file.

    Returns:
    str: Extracted text from the PDF.
    """
    # Open the PDF file
    mypdf = fitz.open(pdf_path)
    all_text = ""  # Initialize an empty string to store the extracted text

    # Iterate through each page in the PDF
    for page_num in range(mypdf.page_count):
        page = mypdf[page_num]  # Get the page
        text = page.get_text("text")  # Extract text from the page
        all_text += text  # Append the extracted text to the all_text string

    return all_text  # Return the extracted text

## Chunking
Once the text has been extracted, we split it into smaller, overlapping chunks to improve retrieval accuracy.

In [4]:
def chunk_text(text, n, overlap):
    """
    Chunks the given text into segments of n characters with overlap.

    Args:
    text (str): The text to be chunked.
    n (int): The number of characters in each chunk.
    overlap (int): The number of overlapping characters between chunks.

    Returns:
    List[str]: A list of text chunks.
    """
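    # The step size is n - overlap, so the overlap must be smaller than the chunk size
    if overlap >= n:
        raise ValueError("overlap must be smaller than n")

    chunks = []  # Initialize an empty list to store the chunks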

    # Loop through the text with a step size of (n - overlap)
    for i in range(0, len(text), n - overlap):
        # Append a chunk of text from index i to i + n to the chunks list
        chunks.append(text[i:i + n])

    return chunks  # Return the list of text chunks
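To make the overlap concrete, here is a tiny worked example. It repeats the same sliding-window logic in a compact form so the snippet runs on its own; the ten-character input string is purely illustrative.

```python
def chunk_text(text, n, overlap):
    """Same sliding-window logic as above, repeated so this snippet is self-contained."""
    return [text[i:i + n] for i in range(0, len(text), n - overlap)]


# Window of 4 characters that advances by 4 - 2 = 2 characters each step
chunks = chunk_text("abcdefghij", n=4, overlap=2)
print(chunks)  # ['abcd', 'cdef', 'efgh', 'ghij', 'ij']
```

Each chunk shares its last two characters with the start of the next one, and the final chunk can be shorter than `n`. Note that `overlap` has to be smaller than `n`, otherwise the step is zero or negative.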

## OpenAI API Client

In [5]:
# Colab environment
from google.colab import userdata
# Using Volcano Engine (Volcengine Ark)
api_key = userdata.get("ARK_API_KEY")
base_url = userdata.get("ARK_BASE_URL")

In [6]:
client = OpenAI(
    base_url=base_url,
    api_key=api_key
)

## Building the Vector Store

In [7]:
class SimpleVectorStore:
    """
    A simple vector store implementation using NumPy.
    """
    def __init__(self):
        """
        Initialize the vector store.
        """
        self.vectors = []  # List to store embedding vectors
        self.texts = []  # List to store original texts
        self.metadata = []  # List to store metadata for each text

    def add_item(self, text, embedding, metadata=None):
        """
        Add an item to the vector store.

        Args:
        text (str): The original text.
        embedding (List[float]): The embedding vector.
        metadata (dict, optional): Additional metadata.
        """
        self.vectors.append(np.array(embedding))  # Convert embedding to numpy array and add to vectors list
        self.texts.append(text)  # Add the original text to texts list
        self.metadata.append(metadata or {})  # Add metadata to metadata list, use empty dict if None

    def similarity_search(self, query_embedding, k=5):
        """
        Find the most similar items to a query embedding.

        Args:
        query_embedding (List[float]): Query embedding vector.
        k (int): Number of results to return.

        Returns:
        List[Dict]: Top k most similar items with their texts and metadata.
        """
        if not self.vectors:
            return []  # Return empty list if no vectors are stored

        # Convert query embedding to numpy array
        query_vector = np.array(query_embedding)

        # Calculate similarities using cosine similarity
        similarities = []
        for i, vector in enumerate(self.vectors):
            # Compute cosine similarity between query vector and stored vector
            similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))
            similarities.append((i, similarity))  # Append index and similarity score

        # Sort by similarity (descending)
        similarities.sort(key=lambda x: x[1], reverse=True)

        # Return top k results
        results = []
        for i in range(min(k, len(similarities))):
            idx, score = similarities[i]
            results.append({
                "text": self.texts[idx],  # Add the corresponding text
                "metadata": self.metadata[idx],  # Add the corresponding metadata
                "similarity": score  # Add the similarity score
            })

        return results  # Return the list of top k similar items
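The loop in `similarity_search` computes one cosine similarity at a time. As a sketch of an alternative, the same ranking can be computed in one vectorized step, assuming all stored vectors have the same dimension. The function name `top_k_cosine` and the sample vectors are hypothetical and not part of the notebook.

```python
import numpy as np


def top_k_cosine(query_vector, matrix, k=5):
    """Return (index, similarity) pairs for the k rows of `matrix` most similar to `query_vector`."""
    query = np.asarray(query_vector, dtype=float)
    rows = np.asarray(matrix, dtype=float)
    # Cosine similarity of every row against the query in one matrix-vector product
    sims = rows @ query / (np.linalg.norm(rows, axis=1) * np.linalg.norm(query))
    # argsort is ascending, so reverse it and keep the first k indices
    order = np.argsort(sims)[::-1][:k]
    return [(int(i), float(sims[i])) for i in order]


vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hits = top_k_cosine([1.0, 0.0], vectors, k=2)
print(hits)
```

For a few dozen chunks the explicit loop is fine; the vectorized form mainly pays off as the number of stored vectors grows.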

## Creating Embeddings

In [21]:
def create_embeddings(input, model="doubao-embedding-text-240715"):
    """
    Creates embeddings for the given input using the specified OpenAI model.

    Args:
    input (str | list[str]): The input for which embeddings are to be created.
    model (str): The model to be used for creating embeddings.

    Returns:
    list[float] | list[list[float]]: A single embedding vector for a string input, or a list of vectors for a list input.
    """
    def batch_read(lst, batch_size=10):
        # Yield successive slices so each API request carries at most batch_size inputs
        for i in range(0, len(lst), batch_size):
            yield lst[i : i + batch_size]

    # Handle both string and list inputs by converting string input to a list
    input_texts = input if isinstance(input, list) else [input]

    # returned embeddings
    embeddings = []

    for batch in batch_read(input_texts):
        # Create embeddings for the current batch using the specified model
        response = client.embeddings.create(
            model=model,
            input=batch
        )
        embeddings.extend(item.embedding for item in response.data)

    # Return a single vector for a string input so it can be passed straight to similarity_search
    return embeddings if isinstance(input, list) else embeddings[0]

## Document Processing Pipeline

In [10]:
def process_document(pdf_path, chunk_size=1000, chunk_overlap=200):
    """
    Process a document for RAG, and then return a vectorstore which has stored the document.

    Args:
    pdf_path (str): Path to the PDF file.
    chunk_size (int): Size of each chunk in characters.
    chunk_overlap (int): Overlap between chunks in characters.

    Returns:
    SimpleVectorStore: A vector store containing document chunks and their embeddings.
    """
    # Extract text from the PDF file
    print("Extracting text from PDF...")
    extracted_text = extract_text_from_pdf(pdf_path)

    # Chunk the extracted text
    print("Chunking text...")
    chunks = chunk_text(extracted_text, chunk_size, chunk_overlap)
    print(f"Created {len(chunks)} text chunks")

    # Create embeddings for the text chunks
    print("Creating embeddings for chunks...")
    chunk_embeddings = create_embeddings(chunks)

    # Initialize a simple vector store
    store = SimpleVectorStore()

    # Add each chunk and its embedding to the vector store
    for i, (chunk, embedding) in enumerate(zip(chunks, chunk_embeddings)):
        store.add_item(
            text=chunk,
            embedding=embedding,
            metadata={"index": i, "source": pdf_path}
        )

    print(f"Added {len(chunks)} chunks to the vector store")
    return store

## Implementing LLM-Based Reranking

In [11]:
def rerank_with_llm(query, results, top_n=3, model="doubao-lite-128k-240828"):
    """
    Reranks search results using LLM relevance scoring.

    Args:
        query (str): User query
        results (List[Dict]): Initial search results
        top_n (int): Number of results to return after reranking
        model (str): Model to use for scoring

    Returns:
        List[Dict]: Reranked results
    """
    print(f"Reranking {len(results)} documents...")  # Print the number of documents to be reranked

    scored_results = []  # Initialize an empty list to store scored results

    # Define the system prompt for the LLM
    system_prompt = """You are an expert at evaluating document relevance for search queries.
Your task is to rate documents on a scale from 0 to 10 based on how well they answer the given query.

Guidelines:
- Score 0-2: Document is completely irrelevant
- Score 3-5: Document has some relevant information but doesn't directly answer the query
- Score 6-8: Document is relevant and partially answers the query
- Score 9-10: Document is highly relevant and directly answers the query

You MUST respond with ONLY a single integer score between 0 and 10. Do not include ANY other text."""

    # Iterate through each result
    for i, result in enumerate(results):
        # Show progress every 5 documents
        if i % 5 == 0:
            print(f"Scoring document {i+1}/{len(results)}...")

        # Define the user prompt for the LLM
        user_prompt = f"""
Query: {query}

Document:
{result['text']}

Rate this document's relevance to the query on a scale from 0 to 10:"""

        # Get the LLM response
        response = client.chat.completions.create(
            model=model,
            temperature=0,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
        )

        # Extract the score from the LLM response
        score_text = response.choices[0].message.content.strip()

        # Use regex to extract the numerical score
        score_match = re.search(r'\b(10|[0-9])\b', score_text)
        if score_match:
            score = float(score_match.group(1))
        else:
            # If score extraction fails, use similarity score as fallback
            print(f"Warning: Could not extract score from response: '{score_text}', using similarity score instead")
            score = result["similarity"] * 10

        # Append the scored result to the list
        scored_results.append({
            "text": result["text"],
            "metadata": result["metadata"],
            "similarity": result["similarity"],
            "relevance_score": score
        })

    # Sort results by relevance score in descending order
    reranked_results = sorted(scored_results, key=lambda x: x["relevance_score"], reverse=True)

    # Return the top_n results
    return reranked_results[:top_n]
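The score-extraction step is the most fragile part of LLM-based reranking, because the model may not follow the "single integer" instruction. The snippet below isolates that step using the same regular expression as `rerank_with_llm`; the helper name `parse_score` and the sample replies are hypothetical, chosen only to show how different reply shapes are handled.

```python
import re


def parse_score(reply, fallback):
    """Extract a standalone integer from 0 to 10 from an LLM reply, or return `fallback`."""
    match = re.search(r"\b(10|[0-9])\b", reply)
    return float(match.group(1)) if match else fallback


print(parse_score("8", 0.0))                # bare integer, as requested by the prompt
print(parse_score("Score: 10", 0.0))        # integer wrapped in extra text
print(parse_score("highly relevant", 4.2))  # no number, so the fallback is used
print(parse_score("100", 4.2))              # out-of-range number is not matched
```

The word boundaries keep digits inside longer numbers from being picked up, but the first standalone match wins, so a reply such as "7/10" is read as 7.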

## Simple Keyword-Based Reranking

In [12]:
def rerank_with_keywords(query, results, top_n=3):
    """
    A simple alternative reranking method based on keyword matching and position.

    Args:
        query (str): User query
        results (List[Dict]): Initial search results
        top_n (int): Number of results to return after reranking

    Returns:
        List[Dict]: Reranked results
    """
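    # Extract important keywords from the query (longer than 3 characters, punctuation stripped
    # so that a trailing "?" or "," does not prevent a match)
    keywords = [word for word in re.findall(r"\w+", query.lower()) if len(word) > 3]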

    scored_results = []  # Initialize a list to store scored results

    for result in results:
        document_text = result["text"].lower()  # Convert document text to lowercase

        # Base score starts with vector similarity
        base_score = result["similarity"] * 0.5

        # Initialize keyword score
        keyword_score = 0
        for keyword in keywords:
            if keyword in document_text:
                # Add points for each keyword found
                keyword_score += 0.1

                # Add more points if keyword appears near the beginning
                first_position = document_text.find(keyword)
                if first_position < len(document_text) / 4:  # In the first quarter of the text
                    keyword_score += 0.1

                # Add points for keyword frequency
                frequency = document_text.count(keyword)
                keyword_score += min(0.05 * frequency, 0.2)  # Cap at 0.2

        # Calculate the final score by combining base score and keyword score
        final_score = base_score + keyword_score

        # Append the scored result to the list
        scored_results.append({
            "text": result["text"],
            "metadata": result["metadata"],
            "similarity": result["similarity"],
            "relevance_score": final_score
        })

    # Sort results by final relevance score in descending order
    reranked_results = sorted(scored_results, key=lambda x: x["relevance_score"], reverse=True)

    # Return the top_n results
    return reranked_results[:top_n]
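A worked example helps to see how the individual bonuses add up. The snippet below reproduces only the keyword part of the scoring rule for a single document; the helper name `keyword_score`, the sample document, and the assumed similarity of 0.8 are illustrative values, not data from the notebook.

```python
def keyword_score(keywords, document_text):
    """Keyword part of the score: presence, early position, and capped frequency bonuses."""
    text = document_text.lower()
    score = 0.0
    for keyword in keywords:
        if keyword in text:
            score += 0.1  # keyword is present
            if text.find(keyword) < len(text) / 4:
                score += 0.1  # first occurrence is in the first quarter of the text
            score += min(0.05 * text.count(keyword), 0.2)  # frequency bonus, capped at 0.2
    return score


doc = "Reranking reorders retrieved chunks. Good reranking helps the final answer."
# "reranking": 0.1 (present) + 0.1 (early) + 0.1 (two occurrences) = 0.3
# "answer":    0.1 (present) + 0.05 (one occurrence)               = 0.15
kw = keyword_score(["reranking", "answer"], doc)
final = 0.8 * 0.5 + kw  # assumed vector similarity of 0.8, weighted by 0.5
print(round(kw, 2), round(final, 2))  # 0.45 0.85
```

Because each matching keyword can contribute up to 0.4, a query with several matching keywords can outweigh the similarity term, which is capped at 0.5.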

## Response Generation

In [13]:
def generate_response(query, context, model="doubao-lite-128k-240828"):
    """
    Generates a response based on the query and context.

    Args:
        query (str): User query
        context (str): Retrieved context
        model (str): Model to use for response generation

    Returns:
        str: Generated response
    """
    # Define the system prompt to guide the AI's behavior
    system_prompt = "You are a helpful AI assistant. Answer the user's question based only on the provided context. If you cannot find the answer in the context, state that you don't have enough information."

    # Create the user prompt by combining the context and query
    user_prompt = f"""
        Context:
        {context}

        Question: {query}

        Please provide a comprehensive answer based only on the context above.
    """

    # Generate the response using the specified model
    response = client.chat.completions.create(
        model=model,
        temperature=0,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )

    # Return the generated response content
    return response.choices[0].message.content

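## Complete RAG Pipeline with Reranking
So far, we have implemented the core components of the RAG pipeline: document processing, question answering, and reranking. Now we combine them into a complete RAG pipeline.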

In [23]:
def rag_with_reranking(query, vector_store, reranking_method="llm", top_n=3, model="doubao-lite-128k-240828"):
    """
    Complete RAG pipeline incorporating reranking.

    Args:
        query (str): User query
        vector_store (SimpleVectorStore): Vector store
        reranking_method (str): Method for reranking ('llm' or 'keywords'); any other value, such as 'none', skips reranking
        top_n (int): Number of results to return after reranking
        model (str): Model for response generation

    Returns:
        Dict: Results including query, context, and response
    """
    # Create query embedding
    query_embedding = create_embeddings(query)

    # Initial retrieval (get more than we need for reranking)
    initial_results = vector_store.similarity_search(query_embedding, k=10)

    # Apply reranking
    if reranking_method == "llm":
        reranked_results = rerank_with_llm(query, initial_results, top_n=top_n)
    elif reranking_method == "keywords":
        reranked_results = rerank_with_keywords(query, initial_results, top_n=top_n)
    else:
        # No reranking, just use top results from initial retrieval
        reranked_results = initial_results[:top_n]

    # Combine context from reranked results
    context = "\n\n===\n\n".join([result["text"] for result in reranked_results])

    # Generate response based on context
    response = generate_response(query, context, model)

    return {
        "query": query,
        "reranking_method": reranking_method,
        "initial_results": initial_results[:top_n],
        "reranked_results": reranked_results,
        "context": context,
        "response": response
    }

## Evaluating Reranking Quality

In [17]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [18]:
# Load the validation data from a JSON file
with open('./drive/MyDrive/colab_data/val.json') as f:
    data = json.load(f)

# Extract the first query from the validation data
query = data[0]['question']

# Extract the reference answer from the validation data
reference_answer = data[0]['ideal_answer']

# pdf_path
pdf_path = "./drive/MyDrive/colab_data/AI_Information.pdf"

In [24]:
# Process document
vector_store = process_document(pdf_path)

# Example query
query = "Does AI have the potential to transform the way we live and work?"

# Compare different methods
print("Comparing retrieval methods...")

# 1. Standard retrieval (no reranking)
print("\n=== STANDARD RETRIEVAL ===")
standard_results = rag_with_reranking(query, vector_store, reranking_method="none")
print(f"\nQuery: {query}")
print(f"\nResponse:\n{standard_results['response']}")

# 2. LLM-based reranking
print("\n=== LLM-BASED RERANKING ===")
llm_results = rag_with_reranking(query, vector_store, reranking_method="llm")
print(f"\nQuery: {query}")
print(f"\nResponse:\n{llm_results['response']}")

# 3. Keyword-based reranking
print("\n=== KEYWORD-BASED RERANKING ===")
keyword_results = rag_with_reranking(query, vector_store, reranking_method="keywords")
print(f"\nQuery: {query}")
print(f"\nResponse:\n{keyword_results['response']}")

Extracting text from PDF...
Chunking text...
Created 42 text chunks
Creating embeddings for chunks...
Added 42 chunks to the vector store
Comparing retrieval methods...

=== STANDARD RETRIEVAL ===

Query: Does AI have the potential to transform the way we live and work?

Response:
The context shows that AI is having a significant impact in various fields such as management, algorithmic trading, customer service, and many others. In industries with repetitive or routine tasks, there are concerns about job displacement but also new opportunities and the need for reskilling. AI is used in areas like supply chain operations, human resources, marketing and sales, financial services, healthcare, transportation, and retail. It can analyze large datasets, predict market movements, automate processes, and enhance decision-making. All these examples suggest that AI has the potential to transform the way we live and work by improving efficiency, providing new insights, and creating new job roles 

In [27]:
def evaluate_reranking(query, standard_results, reranked_results, reference_answer=None):
    """
    Evaluates the quality of reranked results compared to standard results.

    Args:
        query (str): User query
        standard_results (Dict): Results from standard retrieval
        reranked_results (Dict): Results from reranked retrieval
        reference_answer (str, optional): Reference answer for comparison

    Returns:
        str: Evaluation output
    """
    # Define the system prompt for the AI evaluator
    system_prompt = """You are an expert evaluator of RAG systems.
    Compare the retrieved contexts and responses from two different retrieval methods.
    Assess which one provides better context and a more accurate, comprehensive answer."""

    # Prepare the comparison text with truncated contexts and responses
    comparison_text = f"""Query: {query}

Standard Retrieval Context:
{standard_results['context'][:1000]}... [truncated]

Standard Retrieval Answer:
{standard_results['response']}

Reranked Retrieval Context:
{reranked_results['context'][:1000]}... [truncated]

Reranked Retrieval Answer:
{reranked_results['response']}"""

    # If a reference answer is provided, include it in the comparison text
    if reference_answer:
        comparison_text += f"""

Reference Answer:
{reference_answer}"""

    # Create the user prompt for the AI evaluator
    user_prompt = f"""
{comparison_text}

Please evaluate which retrieval method provided:
1. More relevant context
2. More accurate answer
3. More comprehensive answer
4. Better overall performance

Provide a detailed analysis with specific examples.
"""

    # Generate the evaluation response using the specified model
    response = client.chat.completions.create(
        model="doubao-lite-128k-240828",
        temperature=0,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )

    # Return the evaluation output
    return response.choices[0].message.content

In [28]:
# Evaluate the quality of reranked results compared to standard results
evaluation = evaluate_reranking(
    query=query,  # The user query
    standard_results=standard_results,  # Results from standard retrieval
    reranked_results=llm_results,  # Results from LLM-based reranking
    reference_answer=reference_answer  # Reference answer for comparison
)

# Print the evaluation results
print("\n=== EVALUATION RESULTS ===")
print(evaluation)


=== EVALUATION RESULTS ===
1. **More relevant context**: Both retrieval methods provide relevant context as they both discuss various aspects of how AI can transform different fields such as management, trading, customer service, supply chain, HR, marketing, etc. However, the reranked retrieval context seems to have a slightly more focused set of examples within each field. For example, in the supply chain section, it specifically mentions predicting demand, managing inventory, and streamlining logistics, which gives a more detailed picture of AI's impact in that area compared to the standard retrieval context which just mentions it in a more general way.
2. **More accurate answer**: Both answers are accurate as they convey the main points about how AI has the potential to transform various aspects of life and work, including the concerns about job displacement and the need for reskilling. There isn't a significant difference in accuracy between the two.
3. **More comprehensive answer