[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/weaviate/recipes/blob/main/weaviate-features/multi-vector/ColPali-POC.ipynb)

# ColPali

**Please note: The multi-vector feature was added to Weaviate `1.29`. Test out the feature in [this notebook](/weaviate-features/multi-vector/multi-vector-colipali-rag.ipynb).**

Note: This notebook was run on an A100 GPU in Google Colab; you may need a comparable GPU to avoid out-of-memory (OOM) errors.

In [1]:
!pip install pdf2image==1.17.0 > /dev/null
!pip install peft==0.12.0 > /dev/null

In [2]:
!sudo apt-get update > /dev/null
!sudo apt-get install poppler-utils > /dev/null

W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
debconf: unable to initialize frontend: Dialog
debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 78, <> line 4.)
debconf: falling back to frontend: Readline
debconf: unable to initialize frontend: Readline
debconf: (This frontend requires a controlling tty.)
debconf: falling back to frontend: Teletype
dpkg-preconfigure: unable to re-open stdin: 


In [5]:
# Interactive login to the Hugging Face Hub: prompts for an access token and caches it locally.
# NOTE(review): presumably required because the base checkpoint is access-restricted — confirm.
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [8]:
import os
import numpy as np
from pathlib import Path
from typing import List, cast
from pdf2image import convert_from_path
from PIL import Image
import torch
from torch import nn
from transformers import LlamaTokenizerFast, PaliGemmaProcessor
from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration, PaliGemmaPreTrainedModel

# Define ColPali model
class ColPali(PaliGemmaPreTrainedModel):
    """PaliGemma backbone with a linear head that yields one embedding per token.

    The final hidden state of every token is projected to ``self.dim`` (128)
    dimensions, scaled to unit norm along the last axis, and multiplied by the
    attention mask so that masked positions become zero vectors.
    """

    def __init__(self, config: PaliGemmaConfig):
        """Build the PaliGemma backbone and the projection head from ``config``."""
        super().__init__(config=config)
        self.model = PaliGemmaForConditionalGeneration(config)
        # Size of each output embedding vector.
        self.dim = 128
        text_hidden_size = self.model.config.text_config.hidden_size
        self.custom_text_proj = nn.Linear(text_hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Return masked, normalised per-token embeddings.

        All arguments are forwarded to the wrapped PaliGemma model, with
        ``output_hidden_states=True`` added. ``attention_mask`` must be passed
        as a keyword argument because it is read back from ``kwargs``.
        """
        backbone_output = self.model(*args, output_hidden_states=True, **kwargs)
        token_states = backbone_output.hidden_states[-1]
        embeddings = self.custom_text_proj(token_states)
        # Normalise every token vector along the embedding axis.
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        # Broadcast the mask over the embedding axis to zero out masked tokens.
        mask = kwargs["attention_mask"].unsqueeze(-1)
        return embeddings * mask

# Define input classes
class ColPaliTextInput:
    """Holds the tokenised text batch (``input_ids`` and ``attention_mask``).

    The instance attributes are exactly the two constructor arguments.
    """

    def __init__(self, input_ids, attention_mask):
        self.input_ids, self.attention_mask = input_ids, attention_mask

    def to(self, device):
        """Return a new ``ColPaliTextInput`` with both fields moved via ``.to(device)``."""
        moved_ids = self.input_ids.to(device)
        moved_mask = self.attention_mask.to(device)
        return ColPaliTextInput(input_ids=moved_ids, attention_mask=moved_mask)

class ColPaliImageInput:
    """Holds the processed image batch (``input_ids``, ``pixel_values``, ``attention_mask``).

    The instance attributes are exactly the three constructor arguments.
    """

    def __init__(self, input_ids, pixel_values, attention_mask):
        self.input_ids, self.pixel_values, self.attention_mask = (
            input_ids,
            pixel_values,
            attention_mask,
        )

    def to(self, device):
        """Return a new ``ColPaliImageInput`` with every field moved via ``.to(device)``."""
        moved_ids = self.input_ids.to(device)
        moved_pixels = self.pixel_values.to(device)
        moved_mask = self.attention_mask.to(device)
        return ColPaliImageInput(
            input_ids=moved_ids,
            pixel_values=moved_pixels,
            attention_mask=moved_mask,
        )

# Define ColPaliProcessor
class ColPaliProcessor:
    """Wraps a ``PaliGemmaProcessor`` and builds the text and image inputs used with ``ColPali``."""

    def __init__(self, processor: PaliGemmaProcessor):
        """Keep the wrapped processor and a direct reference to its tokenizer."""
        self.processor = processor
        self.tokenizer = cast(LlamaTokenizerFast, self.processor.tokenizer)

    @staticmethod
    def from_pretrained(model_name: str) -> 'ColPaliProcessor':
        """Load the ``PaliGemmaProcessor`` for ``model_name`` and wrap it."""
        base_processor = PaliGemmaProcessor.from_pretrained(model_name)
        return ColPaliProcessor(processor=base_processor)

    def process_text(self, text: str | List[str], padding: str = "longest", return_tensors: str = "pt", add_special_tokens: bool = True) -> ColPaliTextInput:
        """Tokenise one string or a list of strings.

        When ``add_special_tokens`` is true, each text is wrapped as
        ``bos_token + text + "\\n"`` before tokenisation, and a ``ValueError`` is
        raised if ``text`` is neither a ``str`` nor a ``list``. The flag is also
        forwarded to the tokenizer unchanged.
        """
        if add_special_tokens:
            if not isinstance(text, (str, list)):
                raise ValueError("text must be a string or a list of strings.")
            bos = self.tokenizer.bos_token
            if isinstance(text, str):
                text = bos + text + "\n"
            else:
                text = [bos + item + "\n" for item in text]

        encoded = self.tokenizer(
            text,
            padding=padding,
            return_tensors=return_tensors,
            add_special_tokens=add_special_tokens,
        )

        return ColPaliTextInput(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )

    def process_image(self, image: Image.Image | List[Image.Image], padding: str = "longest", do_convert_rgb: bool = True, return_tensors: str = "pt", add_special_prompt: bool = True) -> ColPaliImageInput:
        """Run the wrapped processor on one PIL image or a list of PIL images.

        Every image is paired with the prompt ``"Describe the image."`` (or
        ``None`` when ``add_special_prompt`` is false). Raises ``ValueError`` if
        ``image`` is neither a PIL image nor a ``list``.
        """
        if add_special_prompt:
            special_prompt = "Describe the image."
        else:
            special_prompt = None

        # One prompt entry per image in the batch.
        if isinstance(image, Image.Image):
            image_count = 1
        elif isinstance(image, list):
            image_count = len(image)
        else:
            raise ValueError("image must be a PIL Image or a list of PIL Images.")
        text_input = [special_prompt] * image_count

        encoded = self.processor(
            text=text_input,
            images=image,
            padding=padding,
            do_convert_rgb=do_convert_rgb,
            return_tensors=return_tensors,
        )

        return ColPaliImageInput(
            input_ids=encoded["input_ids"],
            pixel_values=encoded["pixel_values"],
            attention_mask=encoded["attention_mask"],
        )

    def decode(self, *args, **kwargs):
        """Delegate to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """Delegate to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

# Helper functions
def convert_pdf_to_images(pdf_file: str, save_folder: str) -> List[Image.Image]:
    """Render every page of a PDF to a JPEG file and return the pages as PIL images.

    Pages are written to ``save_folder`` as ``page_1.jpg``, ``page_2.jpg``, ...
    (1-based), and the returned list holds the images as read back from those
    JPEG files, in page order.

    Args:
        pdf_file: Path of the PDF to render.
        save_folder: Directory for the JPEG files; created if it does not exist.

    Returns:
        One in-memory PIL image per page.
    """
    pages = convert_from_path(pdf_file)
    os.makedirs(save_folder, exist_ok=True)

    saved_images = []
    for page_number, page in enumerate(pages, start=1):
        image_path = os.path.join(save_folder, f"page_{page_number}.jpg")
        page.save(image_path, "JPEG")
        # Image.open is lazy and keeps the file open. Copy the decoded pixels
        # inside the `with` block so the handle is closed right away and the
        # returned image no longer depends on the file on disk.
        with Image.open(image_path) as reopened:
            saved_images.append(reopened.copy())
    return saved_images

def process_pdfs_with_colpali(pdf_files: List[str], output_dir: str, model, processor):
    """Embed every page of every PDF and save the results under ``output_dir``.

    Each page is rendered to an image, passed through ``processor.process_image``
    and ``model``, and reduced to one vector by averaging over the sequence
    dimension when the model output is 3-D.

    Files written:
        * ``<output_dir>/pdf_images/<pdf stem>/page_N.jpg`` — rendered pages.
          PDFs that share a file stem share a folder.
        * ``<output_dir>/embeddings.npy`` — stacked page vectors.
        * ``<output_dir>/page_info.npy`` — the page metadata list. It is a list of
          dicts, so loading it back requires ``allow_pickle=True``.

    Args:
        pdf_files: Paths of the PDFs to index.
        output_dir: Directory for all outputs; created if missing.
        model: Callable model exposing ``.device``; called with the processed
            input's attributes as keyword arguments.
        processor: Object whose ``process_image(image)`` result has ``.to(device)``.

    Returns:
        ``(embeddings_array, all_page_info)``. Row ``i`` of the array belongs to
        ``all_page_info[i]``, a dict with keys ``"pdf"`` and ``"page"`` (0-based
        page index within that PDF).
    """
    # Create the output directory up front so saving works even when no PDF is
    # processed (previously it only existed as a side effect of page rendering).
    os.makedirs(output_dir, exist_ok=True)

    all_embeddings = []
    all_page_info = []

    for pdf_file in pdf_files:
        # One image folder per PDF: a shared folder made every PDF overwrite the
        # page_N.jpg files of the previous one.
        image_dir = os.path.join(output_dir, "pdf_images", Path(pdf_file).stem)
        pdf_images = convert_pdf_to_images(pdf_file, image_dir)

        for page_num, image in enumerate(pdf_images):
            image_input = processor.process_image(image).to(model.device)
            with torch.no_grad():
                page_embedding = model(**vars(image_input))

            # Average over the sequence dimension to get a single vector per page.
            if page_embedding.dim() == 3:
                page_embedding = page_embedding.mean(dim=1)

            # .float() is a no-op for float32 outputs and makes the NumPy
            # conversion work for bfloat16 outputs, which NumPy cannot represent.
            all_embeddings.append(page_embedding.float().cpu().numpy().squeeze())
            all_page_info.append({"pdf": pdf_file, "page": page_num})

    embeddings_array = np.array(all_embeddings)

    np.save(Path(output_dir) / "embeddings.npy", embeddings_array)
    np.save(Path(output_dir) / "page_info.npy", all_page_info)

    return embeddings_array, all_page_info

def answer_query_with_colpali(query, embeddings_array, page_info, model, processor, top_k: int = 5):
    """Rank indexed pages against a single text query by dot-product similarity.

    The query is embedded with ``processor.process_text`` and ``model`` and
    averaged over the sequence dimension when the model output is 3-D. Page
    embeddings that are 3-D are averaged the same way.

    Args:
        query: One query string.
        embeddings_array: Page embeddings, one row per page (a single 1-D vector
            is treated as one page).
        page_info: Metadata entries aligned with the rows of ``embeddings_array``.
        model: Callable model exposing ``.device``.
        processor: Object whose ``process_text(query)`` result has ``.to(device)``.
        top_k: Maximum number of results to return (default 5).

    Returns:
        Up to ``top_k`` dicts ``{"score": ..., "info": ...}`` ordered from the
        highest to the lowest score. Empty when there are no pages or ``top_k`` is 0.

    Raises:
        ValueError: If ``top_k`` is negative, or if the query yields more than
            one embedding vector (e.g. a list of several queries was passed).
    """
    if top_k < 0:
        raise ValueError("top_k must be non-negative.")

    page_vectors = np.asarray(embeddings_array)
    if page_vectors.ndim == 3:
        page_vectors = page_vectors.mean(axis=1)  # Average over sequence dimension
    # Keep an explicit (pages, dim) layout, including for a single page.
    page_vectors = np.atleast_2d(page_vectors)
    if page_vectors.size == 0 or top_k == 0:
        return []

    query_input = processor.process_text(query).to(model.device)
    with torch.no_grad():
        query_embedding = model(**vars(query_input))
    if query_embedding.dim() == 3:
        query_embedding = query_embedding.mean(dim=1)  # Average over sequence dimension

    # .float() is a no-op for float32 and allows bfloat16 outputs to reach NumPy.
    query_vectors = np.atleast_2d(query_embedding.float().cpu().numpy())
    if query_vectors.shape[0] != 1:
        # With several query vectors the flattened scores would no longer line
        # up with page_info, so refuse rather than return misattributed results.
        raise ValueError("answer_query_with_colpali expects a single query.")

    # One score per page.
    similarity_scores = page_vectors @ query_vectors[0]

    # Descending by score, truncated to the requested number of results.
    top_k_indices = np.argsort(similarity_scores)[::-1][:top_k]

    return [
        {"score": similarity_scores[i], "info": page_info[i]}
        for i in top_k_indices
    ]

In [9]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Base checkpoint and the adapter that is loaded on top of it.
model_path = "google/paligemma-3b-mix-448"
lora_path = "vidore/colpali"

# NOTE(review): the load log printed by this cell reports `custom_text_proj.*` as
# newly initialized from this checkpoint — confirm that the adapter loaded next
# provides those weights, otherwise the projection head is random on every run.
model = ColPali.from_pretrained(model_path)
model.load_adapter(lora_path, adapter_name="colpali")
model.to(device)

processor = ColPaliProcessor.from_pretrained(model_path)

# The PDFs are expected in the current working directory.
pdf_files = ["ALTO.pdf", "MIPRO.pdf", "STORM.pdf"]
output_dir = "colpali_output"

# Process PDFs and save embeddings
embeddings, page_info = process_pdfs_with_colpali(pdf_files, output_dir, model, processor)

# Answer a query
query = "How does MIPRO compare to Bayesian Bootstrap?" # The answer should be contained in MIPRO.pdf
results = answer_query_with_colpali(query, embeddings, page_info, model, processor)

# Print results (the page value is the 0-based index stored by process_pdfs_with_colpali)
for result in results:
  print(f"Score: {result['score']}, PDF: {result['info']['pdf']}, Page: {result['info']['page']}")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some weights of ColPali were not initialized from the model checkpoint at google/paligemma-3b-mix-448 and are newly initialized: ['custom_text_proj.bias', 'custom_text_proj.weight', 'language_model.lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Score: 0.407981812953949, PDF: MIPRO.pdf, Page: 4
Score: 0.38309675455093384, PDF: ALTO.pdf, Page: 5
Score: 0.3767203688621521, PDF: MIPRO.pdf, Page: 5
Score: 0.3727259337902069, PDF: MIPRO.pdf, Page: 7
Score: 0.3650705814361572, PDF: MIPRO.pdf, Page: 0


In [10]:
# Second query against the same in-memory index; relies on `embeddings`, `page_info`,
# `model` and `processor` created by the previous cell.
query = "How is streaming used in Compound AI Systems?" # The answer should be contained in ALTO.pdf
results = answer_query_with_colpali(query, embeddings, page_info, model, processor)

# Print results
for result in results:
  print(f"Score: {result['score']}, PDF: {result['info']['pdf']}, Page: {result['info']['page']}")

Score: 0.3940265476703644, PDF: ALTO.pdf, Page: 5
Score: 0.36726292967796326, PDF: MIPRO.pdf, Page: 0
Score: 0.36110228300094604, PDF: ALTO.pdf, Page: 1
Score: 0.3595043420791626, PDF: ALTO.pdf, Page: 0
Score: 0.35756224393844604, PDF: STORM.pdf, Page: 8
