In [None]:

from transformers import ColQwen2ForRetrieval, ColQwen2Processor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
from peft import PeftModel
import torch
from typing import List, cast

# Pick the compute device once and use it everywhere below instead of
# hard-coding "cuda:0" (get_torch_device falls back to mps/cpu when needed).
device = get_torch_device("auto")

# Base ColQwen2 checkpoint and the fine-tuned LoRA adapter directory.
base_model_path = "/home/linux/yyj/colpali/finetune/colqwen2-v1.0-hf"
adapter_path = "/home/linux/yyj/colpali/finetune/wiky_city_zh_0702_lr2e4_colqwen"
# Pixel budget handed to the processor: 1100 * 28 * 28 pixels per image.
max_pixels = 1100 * 28 * 28

base_model = ColQwen2ForRetrieval.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map=device,
    attn_implementation="flash_attention_2",
).eval()
# Attach the LoRA weights. eval() is called explicitly so that any modules the
# adapter adds are in inference mode regardless of the peft version's default.
model = PeftModel.from_pretrained(base_model, adapter_path).eval()
processor = ColQwen2Processor.from_pretrained(adapter_path, max_pixels=max_pixels)

queries = [
    "阿布达比5月气温怎么样",
    "阿布达比12月降水量怎么样",
]

dataloader = DataLoader(
    dataset=ListDataset[str](queries),
    batch_size=1,
    shuffle=False,
    collate_fn=lambda x: processor.process_queries(x),
)

# One (n_query_tokens, dim) embedding tensor per query, kept on `device`.
qs: List[torch.Tensor] = []
with torch.no_grad():
    for batch_query in dataloader:
        batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
        qr_embs = model(**batch_query).embeddings
        # unbind splits the batch dimension into per-query tensors.
        qs.extend(torch.unbind(qr_embs.to(device)))


  from .autonotebook import tqdm as notebook_tqdm
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour


In [2]:
# Sanity check: how many queries were embedded, and the shape of the first one.
n_queries = len(qs)
print(n_queries)
qs[0].shape

2


torch.Size([22, 128])

In [3]:
from tqdm.auto import tqdm
from PIL import Image
import os

test_pages_dir = "/home/linux/yyj/colpali/finetune/mmlm-rag/test_pages"
# NOTE: os.listdir() returns entries in arbitrary order; the position of a page
# in `images` (and therefore in `ds`) is what later becomes its doc_id.
images = [Image.open(os.path.join(test_pages_dir, name)) for name in os.listdir(test_pages_dir)]

dataloader = DataLoader(
    # The dataset items are PIL images, not strings.
    dataset=ListDataset[Image.Image](images),
    batch_size=1,
    shuffle=False,
    collate_fn=lambda x: processor.process_images(x),
)

# One (n_page_tokens, dim) embedding tensor per page, in the same order as `images`.
ds: List[torch.Tensor] = []
with torch.no_grad():
    for batch_doc in tqdm(dataloader, desc="embedding pages"):
        batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
        img_embs = model(**batch_doc).embeddings
        ds.extend(torch.unbind(img_embs.to(device)))

  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:02<00:00,  1.03it/s]


In [4]:
# Sanity check: number of embedded pages and the (n_tokens, dim) shape of page 1.
page_count = len(ds)
print(page_count)
second_page_shape = ds[1].shape
print(second_page_shape)

3
torch.Size([1064, 128])


In [None]:
from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures

class MilvusColbertRetriever:
    """Multi-vector (ColBERT-style) page retriever backed by a Milvus collection.

    Every document page is stored as one row per token embedding, tagged with
    ``doc_id`` (the page id), ``seq_id`` (token position) and ``doc`` (the
    page's file path). Retrieval is two-stage: an ANN search over all token
    vectors collects candidate ``doc_id``s, then each candidate is re-scored
    exactly with late interaction (MaxSim): for every query token take the
    maximum inner product over the document's tokens, then sum.
    """

    # Name of the HNSW index on "vector"; create_index() creates and drops
    # the index under this same name.
    VECTOR_INDEX_NAME = "vector_index"

    def __init__(self, milvus_client, collection_name, dim=128):
        """Wrap ``collection_name`` on ``milvus_client``.

        If the collection already exists it is loaded so it can be searched.

        Args:
            milvus_client: A connected ``pymilvus.MilvusClient``.
            collection_name: Name of the Milvus collection to use.
            dim: Dimensionality of each token embedding (128 for ColQwen2 here).
        """
        self.collection_name = collection_name
        self.client = milvus_client
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name)
        self.dim = dim

    def create_collection(self):
        """(Re)create the collection, dropping any existing one and all its data.

        Schema: auto-generated INT64 primary key ``pk``, FLOAT_VECTOR ``vector``
        of size ``self.dim``, INT16 ``seq_id``, INT64 ``doc_id`` and VARCHAR
        ``doc`` (max 65535 chars).
        """
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)
        schema = self.client.create_schema(
            auto_id=True,
            # pymilvus reads the singular keyword ``enable_dynamic_field``;
            # a misspelled keyword would be silently ignored.
            enable_dynamic_field=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        # INT16: token positions must stay below 32768 per document.
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        """Build (or rebuild) the HNSW inner-product index on ``vector``.

        Releases the collection, drops a previous vector index of the same
        name if present, builds the new index synchronously, and loads the
        collection again so ``search`` can be called right away.
        """
        self.client.release_collection(collection_name=self.collection_name)
        # Drop the index under the name it is actually created with below;
        # guarded so a fresh collection without an index also works.
        existing = self.client.list_indexes(collection_name=self.collection_name)
        if self.VECTOR_INDEX_NAME in existing:
            self.client.drop_index(
                collection_name=self.collection_name,
                index_name=self.VECTOR_INDEX_NAME,
            )
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name=self.VECTOR_INDEX_NAME,
            index_type="HNSW",  # or any other index type you want
            metric_type="IP",  # must match the metric used in search()
            params={
                "M": 16,
                "efConstruction": 500,
            },  # adjust these parameters as needed
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )
        # A released collection cannot be searched; load it back.
        self.client.load_collection(collection_name=self.collection_name)

    def create_scalar_index(self):
        """Build an INVERTED index on ``doc_id`` to speed up per-document lookups.

        The collection is released for the build and loaded again afterwards
        when the vector index already exists.
        """
        self.client.release_collection(collection_name=self.collection_name)

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="doc_id",
            index_name="int32_index",  # name kept for compatibility; field is INT64
            index_type="INVERTED",  # or any other index type you want
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )
        self._load_if_vector_indexed()

    def _load_if_vector_indexed(self):
        """Load the collection, but only once its vector index exists.

        Milvus rejects loading a collection whose vector field is not indexed.
        """
        existing = self.client.list_indexes(collection_name=self.collection_name)
        if self.VECTOR_INDEX_NAME in existing:
            self.client.load_collection(collection_name=self.collection_name)

    def search(self, data, topk, candidate_limit=50, doc_token_limit=16384, max_workers=32):
        """Return the ``topk`` best-matching documents for a multi-vector query.

        Stage 1 runs one ANN search per query token and keeps the union of the
        ``doc_id``s among the ``candidate_limit`` nearest token vectors of each.
        Stage 2 re-scores every candidate with exact MaxSim.

        Args:
            data: Query token embeddings, one row per token (the notebook
                passes a ``(n_query_tokens, dim)`` numpy array).
            topk: Maximum number of documents to return.
            candidate_limit: ANN hits per query token in stage 1.
            doc_token_limit: Maximum token rows fetched per candidate in stage
                2. It must cover the longest stored document; otherwise MaxSim
                is computed on a truncated subset of its tokens. 16384 is
                Milvus's default cap on a query ``limit``.
            max_workers: Upper bound on threads used for the re-scoring queries.

        Returns:
            List of ``(score, doc_id)`` tuples sorted by descending score,
            at most ``topk`` long.
        """
        query_vecs = np.asarray(data)
        # "params": {} passes no extra HNSW search parameters.
        search_params = {"metric_type": "IP", "params": {}}
        # One hit list per query token; only doc_id is needed from each hit.
        results = self.client.search(
            self.collection_name,
            data,
            limit=int(candidate_limit),
            output_fields=["doc_id"],
            search_params=search_params,
        )
        doc_ids = {hit["entity"]["doc_id"] for hits in results for hit in hits}
        if not doc_ids:
            return []

        workers = max(1, min(max_workers, len(doc_ids)))
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            scores = list(
                executor.map(
                    lambda doc_id: self._rerank_single_doc(
                        doc_id, query_vecs, doc_token_limit
                    ),
                    doc_ids,
                )
            )
        scores.sort(key=lambda x: x[0], reverse=True)
        return scores[:topk]

    def _rerank_single_doc(self, doc_id, query_vecs, doc_token_limit=16384):
        """Score one document against the query with ColBERT late interaction.

        Returns:
            ``(score, doc_id)`` where ``score`` is the sum over query tokens of
            the maximum inner product with any of the document's token vectors.
        """
        rows = self.client.query(
            collection_name=self.collection_name,
            filter=f"doc_id in [{doc_id}]",
            # Only the vectors are needed; skip the large "doc" string field.
            output_fields=["vector"],
            limit=doc_token_limit,
        )
        doc_vecs = np.asarray([row["vector"] for row in rows])
        score = np.dot(query_vecs, doc_vecs.T).max(axis=1).sum()
        return (score, doc_id)

    def insert(self, data):
        """Insert one document's token embeddings, one Milvus row per token.

        Args:
            data: dict with keys
                ``"colbert_vecs"``: sequence of token vectors, one per row (the
                    notebook passes a float32 numpy array of shape
                    ``(n_tokens, dim)``);
                ``"doc_id"``: integer id shared by all rows of this document;
                ``"filepath"``: stored in the ``doc`` field of every row.
        """
        doc_id = data["doc_id"]
        filepath = data["filepath"]
        rows = [
            {"vector": vec, "seq_id": seq_id, "doc_id": doc_id, "doc": filepath}
            for seq_id, vec in enumerate(data["colbert_vecs"])
        ]
        self.client.insert(self.collection_name, rows)

# Connection to the local Milvus server used throughout this notebook.
MILVUS_URI = "http://localhost:19530"
client = MilvusClient(uri=MILVUS_URI)

In [11]:
COLLECTION_NAME = "colqwen_test"
# Set to True to drop and rebuild the collection (deletes everything stored in it).
REBUILD_COLLECTION = False

retriever = MilvusColbertRetriever(collection_name=COLLECTION_NAME, milvus_client=client)
if REBUILD_COLLECTION:
    retriever.create_collection()
    retriever.create_index()
    retriever.create_scalar_index()
    # The index builders release the collection; make sure it is loaded for search.
    client.load_collection(collection_name=COLLECTION_NAME)


In [None]:
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler

# Compression strength for token pooling, applied to every page embedding in `ds`.
POOL_FACTOR = 3

pooler = HierarchicalTokenPooler()
# return_dict defaults to False, so only the pooled embeddings are returned.
outputs = pooler.pool_embeddings(ds, pool_factor=POOL_FACTOR)

In [23]:
# Take the paths from the very images that were embedded, so filepaths[i]
# matches outputs[i]. A second os.listdir() call returns entries in
# arbitrary order and is not guaranteed to line up with `images`.
filepaths = [img.filename for img in images]
assert len(filepaths) == len(outputs), "expected one pooled embedding per page"

# NOTE: insert() only appends rows. Re-running this cell without recreating
# the collection stores every page a second time.
for doc_id, (page_emb, filepath) in enumerate(zip(outputs, filepaths)):
    retriever.insert(
        {
            "colbert_vecs": page_emb.float().cpu().numpy(),
            "doc_id": doc_id,
            "filepath": filepath,
        }
    )


In [24]:
# Number of pages to retrieve per query; only the best one is printed.
TOP_K = 1

for query_idx, query_emb in enumerate(qs):
    # One row per query token, e.g. 22 x 128 for the first query.
    query_vecs = query_emb.float().cpu().numpy()
    hits = retriever.search(query_vecs, topk=TOP_K)
    if not hits:
        # Empty collection or no ANN candidates: report instead of IndexError.
        print(f"query {query_idx}: no matching page")
        continue
    _, best_doc_id = hits[0]
    # doc_id was assigned as the page's position in `filepaths` at insert time.
    print(filepaths[best_doc_id])

/home/linux/yyj/colpali/finetune/mmlm-rag/test_pages/test_page_001.png
/home/linux/yyj/colpali/finetune/mmlm-rag/test_pages/test_page_001.png
