In [1]:
import requests
import torch
from PIL import Image
from io import BytesIO

from transformers import ColPaliForRetrieval, ColPaliProcessor

In [5]:
model_name = "vidore/colpali-v1.3-hf"

# Load the ColPali retrieval model with bfloat16 weights.
# `dtype` is the current keyword; the older `torch_dtype` spelling is
# deprecated and triggers a warning on load.
# device_map="auto" leaves device placement to the loader; when memory is
# tight it may offload some parameters to CPU/disk, which slows inference.
model = ColPaliForRetrieval.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="auto",
)

# Processor loaded from the same checkpoint so preprocessing matches the model.
processor = ColPaliProcessor.from_pretrained(model_name)

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the disk and cpu.


In [3]:
headers = {"User-Agent": "Mozilla/5.0"}

def fetch_image(url, timeout=30):
    """Download ``url`` and open the response body as a PIL image.

    Args:
        url: HTTP(S) address of the image to download.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` may block indefinitely on a stalled server.

    Returns:
        The PIL image opened from the downloaded bytes.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # The full body is required to decode the image, so streaming is not used.
    response = requests.get(url, headers=headers, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))

# Source images (Wikimedia Commons) that will be embedded below.
url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"

# Download each URL in order; the list order defines the batch order.
images = [fetch_image(image_url) for image_url in (url1, url2)]

In [4]:
# Sanity check: show the downloaded PIL images (mode and size).
images

[<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1014x516>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x829>]

In [6]:
# Preprocess the images with the ColPali processor and move the resulting
# batch to the device reported by the model.
inputs_images = processor(images=images).to(model.device)

In [7]:
# Forward pass only: gradient tracking is disabled because we just need the
# output embeddings, not a backward pass.
with torch.no_grad():
    image_embeddings = model(**inputs_images).embeddings

In [8]:
# Display the full embedding tensor for the batch (bfloat16).
image_embeddings

tensor([[[ 0.0884,  0.1338, -0.0273,  ...,  0.1367, -0.0625,  0.1377],
         [-0.0162,  0.0649,  0.0444,  ..., -0.0620,  0.0342, -0.0640],
         [ 0.0020,  0.0114, -0.0625,  ...,  0.0192, -0.0264,  0.1309],
         ...,
         [-0.0061,  0.0786,  0.1025,  ..., -0.0845, -0.0294, -0.0432],
         [ 0.0282,  0.1211,  0.1289,  ..., -0.0425, -0.1128, -0.0444],
         [ 0.0255,  0.1611,  0.1475,  ..., -0.0219, -0.1709, -0.1011]],

        [[ 0.0269,  0.0048,  0.1016,  ...,  0.0713, -0.0162,  0.2002],
         [-0.0019,  0.1045,  0.0938,  ...,  0.1270, -0.1914,  0.0830],
         [ 0.0256,  0.1748,  0.1484,  ...,  0.0112, -0.2041, -0.1123],
         ...,
         [-0.0036,  0.0879,  0.0986,  ..., -0.0679, -0.0586, -0.0339],
         [ 0.0408,  0.1084,  0.1387,  ..., -0.0237, -0.1348, -0.0439],
         [ 0.0383,  0.1572,  0.1631,  ..., -0.0043, -0.1846, -0.0962]]],
       dtype=torch.bfloat16)

In [9]:
# Size of the first (batch) dimension: one entry per input image.
len(image_embeddings)

2

In [10]:
# Embedding matrix for the first image in the batch.
image_embeddings[0]

tensor([[ 0.0884,  0.1338, -0.0273,  ...,  0.1367, -0.0625,  0.1377],
        [-0.0162,  0.0649,  0.0444,  ..., -0.0620,  0.0342, -0.0640],
        [ 0.0020,  0.0114, -0.0625,  ...,  0.0192, -0.0264,  0.1309],
        ...,
        [-0.0061,  0.0786,  0.1025,  ..., -0.0845, -0.0294, -0.0432],
        [ 0.0282,  0.1211,  0.1289,  ..., -0.0425, -0.1128, -0.0444],
        [ 0.0255,  0.1611,  0.1475,  ..., -0.0219, -0.1709, -0.1011]],
       dtype=torch.bfloat16)

In [11]:
# Embedding matrix for the second image in the batch.
image_embeddings[1]

tensor([[ 0.0269,  0.0048,  0.1016,  ...,  0.0713, -0.0162,  0.2002],
        [-0.0019,  0.1045,  0.0938,  ...,  0.1270, -0.1914,  0.0830],
        [ 0.0256,  0.1748,  0.1484,  ...,  0.0112, -0.2041, -0.1123],
        ...,
        [-0.0036,  0.0879,  0.0986,  ..., -0.0679, -0.0586, -0.0339],
        [ 0.0408,  0.1084,  0.1387,  ..., -0.0237, -0.1348, -0.0439],
        [ 0.0383,  0.1572,  0.1631,  ..., -0.0043, -0.1846, -0.0962]],
       dtype=torch.bfloat16)

In [12]:
# Shape of one image's embedding matrix: (number of vectors, vector dimension).
image_embeddings[0].shape

torch.Size([1030, 128])

In [13]:
# First embedding vector of the first image (a single row of the matrix above).
image_embeddings[0][0]

tensor([ 8.8379e-02,  1.3379e-01, -2.7344e-02,  6.5613e-03,  1.1621e-01,
         1.1169e-02,  1.0889e-01, -7.8125e-02, -2.0752e-02,  2.2583e-03,
         1.2302e-04,  4.6631e-02, -1.8921e-02,  8.1055e-02,  2.2583e-03,
         4.6143e-02, -1.3477e-01,  6.6406e-02,  1.3281e-01, -5.4932e-02,
         2.6245e-02,  2.8076e-02, -1.3574e-01,  7.9590e-02,  1.1279e-01,
        -4.8340e-02,  5.1025e-02,  2.3804e-02, -9.0332e-02, -1.3672e-01,
        -9.1309e-02,  8.9355e-02, -2.6489e-02,  1.0840e-01, -5.2734e-02,
        -6.6406e-02,  4.1992e-02,  9.6191e-02,  1.0559e-02, -1.5723e-01,
         1.6895e-01,  1.9897e-02,  1.6113e-02, -5.6885e-02,  1.0889e-01,
        -1.5918e-01,  7.1777e-02, -8.8501e-03, -1.2390e-02, -9.7656e-02,
        -5.5908e-02,  5.9082e-02, -2.2827e-02, -5.6885e-02, -3.0640e-02,
         1.0437e-02,  1.2695e-01,  1.0645e-01,  8.0078e-02, -7.1289e-02,
         2.5146e-02,  8.0566e-02, -1.6309e-01, -6.6406e-02,  1.9434e-01,
         9.9609e-02, -6.2988e-02, -5.4688e-02,  1.5