# 基于图像的文档索引与搜索（使用 ColPali 和 Qdrant）

我们可以检索以图像为主的文档，例如用户指南或扫描版的旧文档。我们将使用支持图像的嵌入模型（ColPali）同时编码文档页面和查询，并配置向量数据库（Qdrant），以高效地存储和搜索这些嵌入向量。

以下是步骤：
* [创建图像集合索引](#creating-image-collection-index)
* [搜索图像索引](#searching-the-image-index)
* [基于检索到的图像生成回复](#generate-response-with-the-retrieved-images)

## 可视化改进

In [None]:
from rich.console import Console
from rich_theme_manager import Theme, ThemeManager
import pathlib

# Directory holding the theme files read by rich-theme-manager.
theme_dir = pathlib.Path("themes")
theme_manager = ThemeManager(theme_dir=theme_dir)
# Look up the theme named "dark" in that directory.
dark = theme_manager.get("dark")

# Create a console with the dark theme; the console.print(...) calls in the
# rest of this notebook render through it.
console = Console(theme=dark)

In [None]:
import warnings

# Silence every warning category for the rest of the session so notebook
# output stays readable. Note that this also hides deprecation warnings
# emitted by the libraries used below.
warnings.simplefilter("ignore")

## 创建图像集合索引 <a id='creating-image-collection-index'></a>

### 将 PDF 文件转换为图像

我们不依赖从 PDF 文件中提取文本，而是把每一页渲染成图像，直接利用页面的视觉内容（包括图示、表格和版式）。这对扫描件尤其重要，因为扫描件中往往没有可提取的文本层。

In [None]:
import os
from pdf2image import convert_from_path


def convert_pdfs_to_images(pdf_folder, poppler_path=None):
    """Render every PDF in *pdf_folder* into a list of page images.

    Args:
        pdf_folder: Directory scanned (non-recursively) for PDF files. The
            ``.pdf`` extension is matched case-insensitively, and only
            regular files are considered.
        poppler_path: Directory containing the poppler binaries used by
            pdf2image. When None, the Homebrew location this notebook was
            originally written against is used if it exists; otherwise
            pdf2image looks poppler up on ``PATH``.

    Returns:
        dict mapping each PDF file name to the list returned by
        ``convert_from_path`` for that file (one image per page, in page
        order). Files are processed in sorted order, so the dict's iteration
        order is stable across runs. Any ids derived from it via
        ``enumerate`` are stable too.
    """
    # Historical, version-pinned Homebrew path; kept only as a fallback so the
    # function still works unchanged on the machine it was written for.
    legacy_homebrew_poppler = r'/opt/homebrew/Cellar/poppler/24.04.0_1/bin'
    if poppler_path is None and os.path.isdir(legacy_homebrew_poppler):
        poppler_path = legacy_homebrew_poppler

    # os.listdir order is arbitrary; sort for deterministic downstream ids.
    pdf_files = sorted(
        name
        for name in os.listdir(pdf_folder)
        if name.lower().endswith(".pdf")
        and os.path.isfile(os.path.join(pdf_folder, name))
    )

    all_images = {}
    for pdf_file in pdf_files:
        pdf_path = os.path.join(pdf_folder, pdf_file)
        all_images[pdf_file] = convert_from_path(pdf_path, poppler_path=poppler_path)

    return all_images

In [None]:
# Folder of PDFs to index; switch to "data/ikea/" to index that folder instead.
pdf_folder = "data/shokz/"
all_images = convert_pdfs_to_images(pdf_folder)

In [None]:
# Inspect the loaded pages: a dict of PDF file name -> list of page images.
console.print(all_images)

In [None]:
import matplotlib.pyplot as plt

# Preview up to the first 8 pages of the first PDF. Slicing (instead of
# indexing pages 0..7 unconditionally) avoids an IndexError for PDFs that
# have fewer than 8 pages.
max_preview_pages = 8

first_pdf_key = next(iter(all_images))
preview_pages = all_images[first_pdf_key][:max_preview_pages]

# One panel per page (at least one). squeeze=False keeps `axes` a 2-D array
# even for a single panel, so `axes.flat` always works.
fig, axes = plt.subplots(
    1, max(len(preview_pages), 1), figsize=(15, 10), squeeze=False
)

for ax, img in zip(axes.flat, preview_pages):
    ax.imshow(img)
for ax in axes.flat:
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
from colpali_engine.models import ColPali, ColPaliProcessor
import torch


# Initialize ColPali model and processor.
# Pinned model version; bump it deliberately when upgrading.
model_name = "vidore/colpali-v1.2"


def _pick_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda:0"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


# Previously hard-coded to "mps", which fails on machines without Apple Silicon.
colpali_model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=_pick_device(),
)
colpali_processor = ColPaliProcessor.from_pretrained(
    "vidore/colpaligemma-3b-pt-448-base"
)

In [None]:
# Print the loaded model (its repr lists the layer structure).
console.print(colpali_model)

In [None]:
# Embed a single page (the first page of the first PDF) to inspect the
# shape of ColPali's output before indexing everything.
sample_image = all_images[first_pdf_key][0]

# Preprocess the page and move the resulting tensors onto the model's device.
sample_batch = colpali_processor.process_images([sample_image])
sample_batch = sample_batch.to(colpali_model.device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    sample_embedding = colpali_model(**sample_batch)

In [None]:
# Show the raw embedding tensor computed for the sample page.
console.print(sample_embedding)

In [None]:
from rich.table import Table

# Tabulate the three axes of the sample embedding's shape:
# (documents, tokens, vector size).
shape = sample_embedding.shape

table = Table(title="Document Embedding")
for header, options in (
    ("Documents", {"style": "cyan", "no_wrap": True}),
    ("Tokens", {"style": "bright_yellow"}),
    ("Vector Size", {"style": "green"}),
):
    table.add_column(header, **options)

table.add_row(str(shape[0]), str(shape[1]), str(shape[2]))

console.print(table)

In [None]:
from qdrant_client import QdrantClient

# In-memory Qdrant instance: nothing persists once the kernel stops.
# For persistent local storage use QdrantClient(path="path/to/db") instead;
# a plain string passed as `location` is interpreted as a server URL.
qdrant_client = QdrantClient(location=":memory:")

In [None]:
# Width of each per-token vector (the "Vector Size" axis of the sample
# embedding); used as the Qdrant vector dimension.
vector_size = sample_embedding.shape[2]

In [None]:
from qdrant_client.http import models

# Multi-vector layout: each point stores a list of `vector_size`-dimensional
# vectors, compared with cosine distance and aggregated with MAX_SIM.
max_sim_config = models.MultiVectorConfig(
    comparator=models.MultiVectorComparator.MAX_SIM
)
multi_vector_params = models.VectorParams(
    size=vector_size,
    distance=models.Distance.COSINE,
    multivector_config=max_sim_config,
)

### 使用量化减少向量内存占用

我们可以定义一个 `ScalarQuantizationConfig`，并在创建集合时传入。Qdrant 会在服务器端把向量的每个分量从浮点数量化为 8 位整数（`INT8`），从而减少内存占用并加快搜索。`always_ram` 参数控制是否将量化后的向量常驻内存：设为 `True` 可以提高搜索性能，但会增加内存占用；下面的示例将其设为 `False`。

In [None]:
# Scalar-quantization settings, passed to the collection created below.
scalar_quant = models.ScalarQuantizationConfig(
    type=models.ScalarType.INT8,  # store each vector component as an 8-bit integer
    quantile=0.99,  # per Qdrant docs: ignore the most extreme 1% of values when choosing the quantization range
    always_ram=False,  # do not force the quantized vectors to stay in RAM
)

In [None]:
collection_name = "user-guides"

# recreate_collection() is deprecated in qdrant-client. Drop and create
# explicitly, so re-running this cell still starts from an empty collection.
if qdrant_client.collection_exists(collection_name=collection_name):
    qdrant_client.delete_collection(collection_name=collection_name)

qdrant_client.create_collection(
    collection_name=collection_name,  # the name of the collection
    on_disk_payload=True,  # store the payload on disk
    optimizers_config=models.OptimizersConfigDiff(
        # Size threshold (KB of vectors per segment) above which Qdrant builds
        # the vector index. Setting it to 0 switches indexing off, which can
        # help during a bulk upload; raise it again afterwards to trigger indexing.
        indexing_threshold=100
    ),
    vectors_config=models.VectorParams(
        size=vector_size,
        distance=models.Distance.COSINE,
        # Each point holds one vector per image token, scored with MAX_SIM.
        multivector_config=models.MultiVectorConfig(
            comparator=models.MultiVectorComparator.MAX_SIM
        ),
        quantization_config=models.ScalarQuantization(
            scalar=scalar_quant,
        ),
    ),
)

### 将编码后的图像插入向量数据库

我们定义一个辅助函数，通过客户端将一批点上传到 Qdrant，并使用 stamina 库在出现网络等临时故障时自动重试。

In [None]:
import stamina

@stamina.retry(on=Exception, attempts=3)
def upsert_to_qdrant(batch):
    """Upload one batch of points to the ``collection_name`` collection.

    Args:
        batch: list of ``models.PointStruct`` to upsert.

    Returns:
        True once the upsert request has been sent (``wait=False`` means the
        call does not wait for Qdrant to finish applying it).

    Raises:
        Exception: whatever the client raised on the final attempt, after
            stamina has used up its 3 attempts.
    """
    # Upload the batch that was passed in (not a global), and let errors
    # propagate. Swallowing them here would hide every failure from
    # stamina.retry, so no retry would ever happen.
    qdrant_client.upsert(
        collection_name=collection_name,
        points=batch,
        wait=False,
    )
    return True

现在把向量上传到 Qdrant：将页面图像分批送入 ColPali 模型编码，再把得到的嵌入连同文档名和页码封装成 Qdrant 的 `PointStruct` 并上传。

In [None]:
import uuid
from tqdm import tqdm

batch_size = 2  # Adjust based on your GPU memory constraints

total_images = sum(len(images) for images in all_images.values())
failed_pages = 0  # pages whose batch could not be uploaded

# Use tqdm to create a progress bar
with tqdm(total=total_images, desc="Indexing Progress") as pbar:
    for doc_id, pdf_file in enumerate(all_images.keys()):
        pages = all_images[pdf_file]
        for i in range(0, len(pages), batch_size):
            images = pages[i : i + batch_size]

            # Process and encode images (inference only, no autograd graph).
            with torch.no_grad():
                batch_images = colpali_processor.process_images(images).to(
                    colpali_model.device
                )
                image_embeddings = colpali_model(**batch_images)

            # Prepare one Qdrant point per page. uuid5 is deterministic, so a
            # given (doc_id, page index) pair always maps to the same point id.
            points = []
            for j, embedding in enumerate(image_embeddings):
                unique_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{doc_id}.{i + j}"))
                # Convert the page embedding to a list of per-token vectors.
                multivector = embedding.cpu().float().numpy().tolist()
                points.append(
                    models.PointStruct(
                        id=unique_id,
                        vector=multivector,
                        payload={
                            "doc": pdf_file,
                            "page": i + j + 1,  # 1-based page number
                        },  # can also add other metadata/data
                    )
                )

            # Upload points to Qdrant. A failure may surface either as an
            # exception or as a False return value; count both.
            try:
                uploaded = upsert_to_qdrant(points)
            except Exception as e:
                print(f"Error during upsert: {e}")
                uploaded = False
            if not uploaded:
                failed_pages += len(images)

            # Advance by the pages actually in this batch: the last batch of a
            # PDF may hold fewer than batch_size pages.
            pbar.update(len(images))

if failed_pages:
    print(f"Indexing finished, but {failed_pages} page(s) failed to upload.")
else:
    print("Indexing complete!")

如果在上传期间将 `indexing_threshold` 设为 0 关闭了索引，可以在上传完成后把它重新设为一个较小的正值，以触发索引构建。

In [None]:
# Lower the indexing threshold so Qdrant builds the vector index for the points
# uploaded above. `optimizers_config` is the parameter name used in the
# qdrant-client API reference (the original `optimizer_config` is a legacy spelling).
qdrant_client.update_collection(
    collection_name=collection_name,
    optimizers_config=models.OptimizersConfigDiff(indexing_threshold=10),
)

In [None]:
# Fetch and display the collection's configuration and status.
collection_info = qdrant_client.get_collection(collection_name)
console.print(collection_info)

In [None]:
# Page through the first 20 stored points; scroll() returns a
# (points, next_page_offset) pair.
first_points = qdrant_client.scroll(collection_name=collection_name, limit=20)
console.print(first_points)

## 搜索图像索引 <a id='searching-the-image-index'></a>

一旦我们将编码后的图像上传到向量数据库，就可以对其进行查询。

In [None]:
# Another query to try: "How do I answer a call?"
query_text = "Why the led is flashing red and blue?"

# Preprocess the question and move the tensors onto the model's device.
batch_query = colpali_processor.process_queries([query_text]).to(
    colpali_model.device
)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    query_embedding = colpali_model(**batch_query)


In [None]:
# Inspect the query embedding's shape; its last axis should equal vector_size
# so it can be compared against the stored page vectors.
console.print(query_embedding.shape)

In [None]:
# Convert the query embedding to a list of vectors
multivector_query = query_embedding[0].cpu().float().numpy().tolist()

In [None]:
# Multi-vector search: fetch the 3 pages that best match the query vectors.
search_result = qdrant_client.query_points(
    collection_name=collection_name, 
    query=multivector_query, 
    limit=3, 
    timeout=60,  # seconds
)

In [None]:
# Each returned point carries its id, score and payload ("doc" and "page").
console.print(search_result)

### 显示搜索结果中的图像

我们可以显示通过向量搜索检索到的图像。

In [None]:
import matplotlib.pyplot as plt

# Show the top-ranked pages from the search, one panel per result. The slice
# matches the number of panels, so extra results can't overflow `axs`
# (previously points[:6] vs. 3 subplots).
max_panels = 3
top_images = search_result.points[:max_panels]

# Create a figure with one subplot per panel; panels without a result stay blank.
fig, axs = plt.subplots(1, max_panels, figsize=(15, 10))

for ax, point in zip(axs, top_images):
    pdf_file = point.payload.get('doc')
    # Stored page numbers are 1-based; list indices are 0-based.
    page_num = int(point.payload.get('page')) - 1
    img = all_images[pdf_file][page_num]
    ax.imshow(img)
    ax.set_title(f"Score: {point.score:.3f}\nDoc: {pdf_file}, page {page_num + 1}")

for ax in axs:
    ax.axis('off')  # Do not display axes for better visualization

plt.tight_layout()
plt.show()

## 基于检索到的图像生成回复  <a id='generate-response-with-the-retrieved-images'></a>

在 **A**ugmentation（增强）步骤中，我们使用 base64 对检索到的图像进行编码，并将其作为提示的一部分与用户的查询一起发送给生成模型。

In [None]:
import base64
from io import BytesIO

# Take the best-matching page, show it, and base64-encode it as PNG for the prompt.
top_image = search_result.points[0]
pdf_file = top_image.payload.get('doc')
page_num = int(top_image.payload.get('page')) - 1  # stored pages are 1-based
image = all_images[pdf_file][page_num]
display(image)

image1_media_type = "image/png"

# Serialize the page into an in-memory PNG, then base64 it as a UTF-8 string.
with BytesIO() as buffered:
    image.save(buffered, format="PNG")  # You may choose another format if needed
    img_bytes = buffered.getvalue()

image1_data = base64.standard_b64encode(img_bytes).decode("utf-8")

In [None]:
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment
# (e.g. the API key the Anthropic client needs below).
load_dotenv()

In [None]:
import anthropic

# The client reads ANTHROPIC_API_KEY from the environment.
client = anthropic.Anthropic()

# Multimodal prompt: the retrieved page as a base64 image, then the question.
image_block = {
    "type": "image",
    "source": {
        "type": "base64",
        "media_type": image1_media_type,
        "data": image1_data,
    },
}
text_block = {"type": "text", "text": query_text}

# NOTE(review): this dated model snapshot may since have been retired;
# check Anthropic's current model list before relying on it.
message = client.messages.create(
    model="claude-3-5-sonnet-20241022",
    max_tokens=1024,
    messages=[{"role": "user", "content": [image_block, text_block]}],
)
console.print(message)
