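# Improving Semantic Similarity with Reverse HYDE

The documents we want to retrieve are usually longer than the user's query and formatted differently. To improve the accuracy of **document retrieval from user queries** (the **R**etrieval step), we generate hypothetical queries from each document and index the embeddings of those queries in place of the document's own embedding. This technique is known as Reverse HYDE.

Note that the original [HyDE technique](https://arxiv.org/abs/2212.10496) works on the user's input query: it generates hypothetical documents from the query and then uses those hypothetical documents to retrieve real ones. In Reverse HYDE, the LLM processing happens when documents are indexed, not when they are retrieved, so no LLM call is added at query time and query latency is unaffected.

* [Reverse HYDE implementation](#reverse-hyde-implementation)
* [Enriching the vector database with Reverse HYDE output](#enriching-vector-database-with-reverse-hyde-output)
* [Querying the enriched index](#query-the-enriched-index)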

### Visual improvements

We will use the [rich library](https://github.com/Textualize/rich) to make the output more readable, and we will suppress warning messages.

In [None]:
from rich.console import Console
from rich_theme_manager import Theme, ThemeManager
import pathlib

theme_dir = pathlib.Path("themes")
theme_manager = ThemeManager(theme_dir=theme_dir)
dark = theme_manager.get("dark")

# Create a console with the dark theme
console = Console(theme=dark)

In [None]:
import warnings

# Suppress warnings
warnings.filterwarnings('ignore')

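## Reverse HYDE implementation

We will create a class that generates hypothetical questions for each chunk, so that documents can later be retrieved by semantic similarity matching. In a real application, a vector database stores, indexes, and retrieves the embedding vectors.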

In [None]:
import openai
from typing import List, Dict
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
import re

class ReverseHyde:
    def __init__(self, api_key: str):
        # Pass the key to the client explicitly and reuse one client for all calls
        self.client = openai.OpenAI(api_key=api_key)
        self.model = "text-embedding-ada-002"

    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding vector of a single text."""
        response = self.client.embeddings.create(input=text, model=self.model)
        return response.data[0].embedding

    def generate_reverse_hyde(self, chunk: str, n: int = 3) -> List[str]:
        """Ask the LLM for n questions that the chunk would answer."""
        prompt = f"""
Given the following text chunk, generate {n} different questions that this chunk would be a good answer to:

Chunk: {chunk}

Questions (enumerate the questions with 1. 2., etc.):
"""

        response = self.client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            # 100 tokens can cut off the last questions when n is 5 or more
            max_tokens=256,
            n=1,
            stop=None,
            temperature=0.7,
        )

        lines = response.choices[0].message.content.strip().split('\n')
        # Keep only numbered lines such as "1. ..." and drop the numbering
        questions = []
        for line in lines:
            match = re.match(r"\s*\d+[.)]\s+(.*\S)", line)
            if match:
                questions.append(match.group(1))
        return questions

    def process_chunks(self, chunks: List[str], n: int = 3) -> Dict[str, List[str]]:
        """Map each chunk to its list of hypothetical questions."""
        processed_chunks = {}
        for chunk in chunks:
            processed_chunks[chunk] = self.generate_reverse_hyde(chunk, n)
        return processed_chunks
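
The semantic similarity matching mentioned above can be sketched without a vector database. The following is a minimal illustration, not the notebook's implementation: the helper names `cosine` and `top_document` are hypothetical, and the 3-dimensional vectors are toy stand-ins for real embeddings. Each hypothetical question vector points back to the document it was generated from, and the query is matched against the questions rather than against the documents.

```python
import math
from typing import List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two non-zero vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_document(
    query_vec: Sequence[float],
    question_vecs: List[Sequence[float]],
    question_to_doc: List[int],
) -> Tuple[int, float]:
    """Return (doc_id, score) for the hypothetical question closest to the query."""
    scores = [cosine(query_vec, vec) for vec in question_vecs]
    best = max(range(len(scores)), key=scores.__getitem__)
    return question_to_doc[best], scores[best]


# Toy vectors (assumed values): two questions for document 0, one for document 1
question_vecs = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0]]
question_to_doc = [0, 0, 1]

doc_id, score = top_document([0.1, 0.9, 0.0], question_vecs, question_to_doc)
print(doc_id, round(score, 4))
```

In the notebook itself this lookup is delegated to Qdrant, which performs the same cosine comparison over the stored question embeddings.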


Load the API key from environment variables.

In [None]:
from dotenv import load_dotenv

load_dotenv()

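## Enriching the vector database with Reverse HYDE output

We will apply the Reverse HYDE method to a set of documents and enrich the vector database index with the hypothetical questions generated by the LLM.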

In [None]:
import os
# Usage example
api_key = os.getenv("OPENAI_API_KEY")
reverse_hyde = ReverseHyde(api_key)

chunks = [
    "A mitochondrion (pl. mitochondria) is an organelle found in the cells of most eukaryotes, such as animals, plants and fungi. Mitochondria have a double membrane structure and use aerobic respiration to generate adenosine triphosphate (ATP), which is used throughout the cell as a source of chemical energy. They were discovered by Albert von Kölliker in 1857 in the voluntary muscles of insects. Meaning a thread-like granule, the term mitochondrion was coined by Carl Benda in 1898. The mitochondrion is popularly nicknamed the \"powerhouse of the cell\", a phrase popularized by Philip Siekevitz in a 1957 Scientific American article of the same name.",
    "Python is a high-level, general-purpose programming language. Its design philosophy emphasizes code readability with the use of significant indentation. Python is dynamically typed and garbage-collected. It supports multiple programming paradigms, including structured (particularly procedural), object-oriented and functional programming. It is often described as a \"batteries included\" language due to its comprehensive standard library.",
    "The American Civil War (from April 12, 1861 to May 26, 1865) was a civil war in the United States between the Union (\"the North\") and the Confederacy (\"the South\"), which was formed in 1861 by states that had seceded from the Union. The central conflict leading to war was a dispute over whether slavery should be permitted to expand into the western territories, leading to more slave states, or be prohibited from doing so, which many believed would place slavery on a course of ultimate extinction."
]

processed_chunks = reverse_hyde.process_chunks(chunks, n=5)

In [None]:
console.print(processed_chunks)

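## Querying the enriched index

Once we have an index that holds several hypothetical questions per document, we can use it to retrieve documents for real user queries.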

In [None]:
query = "What generates energy in a cell?"

In [None]:
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer

# create the vector database client
qdrant = QdrantClient(":memory:") # Create in-memory Qdrant instance

# Create the embedding encoder
encoder = SentenceTransformer('all-MiniLM-L6-v2') # Model to create embeddings

In [None]:
# Create the collection that stores the hypothetical question embeddings
hyde_collection_name="reverse_hyde"

qdrant.recreate_collection(
    collection_name=hyde_collection_name,
    vectors_config=models.VectorParams(
        size=encoder.get_sentence_embedding_dimension(), # Vector size is defined by used model
        distance=models.Distance.COSINE
    )
)

In [None]:
import uuid
# vectorize!
qdrant.upload_points(
    collection_name=hyde_collection_name,
    points=[
        models.PointStruct(
            id=uuid.uuid5(uuid.NAMESPACE_URL, f"{d_idx}-{q_idx}").hex,
            vector=encoder.encode(question).tolist(),
            payload={ 
                "document": document , 
                "doc_id": d_idx
            }
        ) for d_idx, (document, questions) 
            in enumerate(processed_chunks.items()) 
                for q_idx, question in enumerate(questions)
    ]
)
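
The point IDs in the upload cell are built with `uuid.uuid5`, which is deterministic: the same document index and question index always give the same ID. The sketch below isolates that property; `question_point_id` is a hypothetical helper name, not part of the notebook. Assuming the upload behaves as an upsert keyed on the point ID, re-running the cell should overwrite the existing points instead of adding duplicates.

```python
import uuid


def question_point_id(d_idx: int, q_idx: int) -> str:
    """Derive a stable point ID from a document index and a question index."""
    return uuid.uuid5(uuid.NAMESPACE_URL, f"{d_idx}-{q_idx}").hex


first = question_point_id(0, 1)
again = question_point_id(0, 1)  # same inputs, same ID
other = question_point_id(1, 0)  # different inputs, different ID

print(first == again, first == other)
```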

In [None]:
console.print(
    qdrant
    .get_collection(
        collection_name=hyde_collection_name
    )
)

### Searching the collection for the best match

In [None]:
from rich.panel import Panel
from rich.table import Table

def search_collection(collection_name: str, query: str, limit: int = 1):
    """
    This function searches the specified collection for the best match to the given query.
    It then creates a table and a panel to display the query and the best match.
    
    :param collection_name: The name of the collection to search.
    :param query: The query to search for.
    :param limit: The maximum number of results to return. Default is 1.
    """
    hits = qdrant.search(
        collection_name=collection_name,
        query_vector=encoder.encode(query).tolist(),
        limit=limit
    )
    # Create a table for both query and best match
    table = Table(show_header=True, header_style="bold yellow")
    table.add_column("Query", style="bright_cyan", width=30)
    table.add_column("Best Matching Chunk", style="bright_yellow", width=50)
    table.add_column("Score", style="bright_green")
    for hit in hits:
        table.add_row(query, f"{hit.payload['document'][:80]}...", "{:.4f}".format(hit.score))

    # Create a panel for the table
    panel = Panel(
        table,
        title=f"[bold]Query and Best Match in {collection_name}",
        border_style="white",
        expand=False
    )

    # Print the panel
    console.print(panel)

In [None]:
search_collection(hyde_collection_name, query)
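
Because each document is indexed under several hypothetical questions, a search with `limit` greater than 1 may return the same document more than once. One possible way to handle this, shown here as a sketch under assumptions, is to keep only the best score per `doc_id`. The helper `best_hit_per_document` is hypothetical, and the scores in `sample_hits` are made-up values for illustration, not results from the notebook.

```python
from typing import Dict, Iterable, List, Tuple


def best_hit_per_document(hits: Iterable[Tuple[int, float]]) -> List[Tuple[int, float]]:
    """Collapse (doc_id, score) pairs to one entry per document, best score first."""
    best: Dict[int, float] = {}
    for doc_id, score in hits:
        if doc_id not in best or score > best[doc_id]:
            best[doc_id] = score
    return sorted(best.items(), key=lambda item: item[1], reverse=True)


# Made-up scores: document 0 matches through three of its questions
sample_hits = [(0, 0.81), (0, 0.77), (2, 0.35), (0, 0.74), (1, 0.12)]
deduped = best_hit_per_document(sample_hits)
print(deduped)
```

With the Qdrant results used in this notebook, the input pairs could be built as `(hit.payload["doc_id"], hit.score) for hit in hits`.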

### Comparison with a documents-only index (no HYDE)

We will index the same documents without the Reverse HYDE questions and compare the similarity scores.

In [None]:
# Create the collection that stores the raw document embeddings
docs_collection_name="documents_only"

qdrant.recreate_collection(
    collection_name=docs_collection_name,
    vectors_config=models.VectorParams(
        size=encoder.get_sentence_embedding_dimension(), # Vector size is defined by used model
        distance=models.Distance.COSINE
    )
)

In [None]:
# vectorize!
qdrant.upload_points(
    collection_name=docs_collection_name,
    points=[
        models.PointStruct(
            id=idx,
            vector=encoder.encode(document).tolist(),
            payload={ "document": document}
        ) for idx, (document, questions) in enumerate(processed_chunks.items())
    ]
)

In [None]:
search_collection(docs_collection_name, query)