In [1]:
import gc
import logging
from typing import List, Optional, Union, cast

import torch
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorQA
from vidore_benchmark.retrievers import BaseVisionRetriever
from vidore_benchmark.utils.data_utils import ListDataset

In [2]:
logger = logging.getLogger(name=__name__)


def get_torch_device(device: str = "auto") -> str:
    """
    Resolve the PyTorch device string to use.

    Any value other than "auto" is passed through unchanged. For "auto", the
    first available backend is chosen, in this order:
    - "cuda:0"
    - "mps" (Apple Silicon)
    - "cpu" (fallback)

    The automatically chosen device is logged at INFO level.
    """
    if device != "auto":
        return device

    if torch.cuda.is_available():
        resolved = "cuda:0"
    elif torch.backends.mps.is_available():  # Apple Silicon
        resolved = "mps"
    else:
        resolved = "cpu"

    logger.info(f"Using device: {resolved}")
    return resolved


def tear_down_torch():
    """
    Release cached accelerator memory held by PyTorch.

    Python garbage collection runs first, so that tensors which are no
    longer referenced can be reclaimed. The CUDA and/or MPS allocator caches
    are then emptied for whichever of those backends is available.
    """
    gc.collect()

    cuda_ready = torch.cuda.is_available()
    mps_ready = torch.backends.mps.is_available()

    if cuda_ready:
        torch.cuda.empty_cache()
    if mps_ready:
        torch.mps.empty_cache()

In [3]:
class GenericRetriever(BaseVisionRetriever):
    """
    Generic retriever, based on the ColPali retriever used in the ViDoRe benchmark.

    Queries and document images are embedded in batches by a ColPali model.
    Every input item yields its own embedding tensor, moved to the CPU.
    Query/passage relevance is computed with the processor's ``score`` method.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str = "vidore/colpali-v1.3",
        device: str = "auto",
        num_workers: int = 0,
        **kwargs,
    ):
        """
        Load the ColPali model and its processor.

        Args:
            pretrained_model_name_or_path: Model id or path passed to ``from_pretrained``.
            device: Device string, or "auto" to resolve it via ``get_torch_device``.
            num_workers: Number of DataLoader worker processes used to collate batches.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            ImportError: If ``colpali_engine`` is not installed.
        """
        super().__init__(use_visual_embedding=True)

        try:
            from colpali_engine.models import ColPali, ColPaliProcessor
        except ImportError as err:
            raise ImportError(
                'Install the missing dependencies with `pip install "vidore-benchmark[colpali-engine]"` '
                "to use ColPaliRetriever."
            ) from err

        self.device = get_torch_device(device)
        self.num_workers = num_workers

        # Load the model in bfloat16 directly onto the resolved device, in eval mode.
        self.model = cast(
            ColPali,
            ColPali.from_pretrained(
                pretrained_model_name_or_path,
                torch_dtype=torch.bfloat16,
                device_map=self.device,
            ).eval(),
        )

        # Load the processor (tokenization / image preprocessing and scoring).
        self.processor = cast(
            ColPaliProcessor,
            ColPaliProcessor.from_pretrained(pretrained_model_name_or_path),
        )

    def process_images(self, images: List[Image.Image], **kwargs):
        """Collate a list of images into model inputs (left on the CPU)."""
        return self.processor.process_images(images=images)

    def process_queries(self, queries: List[str], **kwargs):
        """Collate a list of query strings into model inputs (left on the CPU)."""
        return self.processor.process_queries(queries=queries)

    def _embed_batches(self, dataloader: DataLoader, desc: str) -> List[torch.Tensor]:
        """Run the model over every batch of `dataloader`; return one CPU tensor per item."""
        embeddings: List[torch.Tensor] = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc=desc, leave=False):
                # Inputs are moved to the device here, in the main process, rather
                # than inside collate_fn, so DataLoader workers only produce CPU tensors.
                batch = batch.to(self.device)
                batch_embeddings = self.model(**batch).to("cpu")
                # Split the batch into per-item tensors and keep accumulating
                # across batches. Padded sequence lengths may differ from batch to
                # batch, so the results are kept as a list rather than stacked.
                embeddings.extend(torch.unbind(batch_embeddings, dim=0))

        return embeddings

    def forward_queries(
        self, queries: List[str], batch_size: int, **kwargs
    ) -> List[torch.Tensor]:
        """
        Embed text queries in batches of `batch_size`.

        Returns:
            One CPU embedding tensor per query, in input order.
        """
        dataloader = DataLoader(
            dataset=ListDataset[str](queries),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.process_queries,
            num_workers=self.num_workers,
        )
        return self._embed_batches(dataloader, desc="Forward pass queries...")

    def forward_passages(
        self, passages: List[Image.Image], batch_size: int, **kwargs
    ) -> List[torch.Tensor]:
        """
        Embed document page images in batches of `batch_size`.

        Returns:
            One CPU embedding tensor per passage image, in input order.
        """
        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](passages),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.process_images,
            num_workers=self.num_workers,
        )
        return self._embed_batches(dataloader, desc="Forward pass documents...")

    def get_scores(
        self,
        query_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        passage_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: Optional[int] = 128,
    ) -> torch.Tensor:
        """
        Score every query against every passage with the processor's ``score``.

        Args:
            query_embeddings: Query embeddings as produced by ``forward_queries``.
            passage_embeddings: Passage embeddings as produced by ``forward_passages``.
            batch_size: Batch size forwarded to ``processor.score``; must not be None.

        Returns:
            The score tensor returned by ``processor.score`` (computed on the CPU).

        Raises:
            ValueError: If `batch_size` is None.
        """
        if batch_size is None:
            raise ValueError(
                "`batch_size` must be provided for ColPaliRetriever's scoring"
            )
        scores = self.processor.score(
            query_embeddings,
            passage_embeddings,
            batch_size=batch_size,
            device="cpu",
        )
        return scores

In [4]:
print(str(GenericRetriever))  # sanity check: the class-definition cell has executed

<class '__main__.GenericRetriever'>


In [None]:
# Setup retriever
genericRetriever = GenericRetriever(
    pretrained_model_name_or_path="vidore/colpali-v1.3",
    device="auto",
    num_workers=0,
)

# Evaluate on a single QA format dataset
vidore_evaluator_qa = ViDoReEvaluatorQA(genericRetriever)
ds = load_dataset("vidore/tabfquad_test_subsampled", split="test")
metrics_dataset_qa = vidore_evaluator_qa.evaluate_dataset(
    ds=ds, batch_query=4, batch_passage=4
)
print(metrics_dataset_qa)

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Map:   0%|          | 0/280 [00:00<?, ? examples/s]

Dataloader pre-batching for passages:   0%|          | 0/7 [00:00<?, ?it/s]       