# ColPali

In [2]:
import torch
from PIL import Image
from transformers.utils.import_utils import is_flash_attn_2_available

from colpali_engine.models import ColQwen2, ColQwen2Processor, ColPali, ColPaliProcessor
import os
from colpali_engine.interpretability import get_similarity_maps_from_embeddings
import matplotlib.pyplot as plt
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# 禁用所有网络请求，强制使用本地文件（保持不变）
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"  # 新增，彻底禁用huggingface hub网络请求

# 本地模型路径（替换为colpali的本地目录，无需基础模型路径）
local_model_path = "./models/colpali-v1.3-merged"  # 确保该目录下有从vidore/colpali下载的所有文件


## 模型加载

In [4]:
# 加载模型（替换为ColPali，移除base_model_path参数）
model = ColPali.from_pretrained(
    local_model_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # 无GPU可改为"cpu"
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
    local_files_only=True,  # 强制使用本地文件
    trust_remote_code=True  # 必需，colpali含自定义模型代码
).eval()

# 加载处理器（替换为ColPaliProcessor，无需base_model_path）
processor = ColPaliProcessor.from_pretrained(
    local_model_path,
    local_files_only=True  # 强制使用本地文件
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 14.25it/s]
Some weights of the model checkpoint at ./models/colpali-v1.3-merged were not used when initializing ColPali: ['model.language_model.model.embed_tokens.weight', 'model.language_model.model.layers.0.input_layernorm.weight', 'model.language_model.model.layers.0.mlp.down_proj.weight', 'model.language_model.model.layers.0.mlp.gate_proj.weight', 'model.language_model.model.layers.0.mlp.up_proj.weight', 'model.language_model.model.layers.0.post_attention_layernorm.weight', 'model.language_model.model.layers.0.self_attn.k_proj.weight', 'model.language_model.model.layers.0.self_attn.o_proj.weight', 'model.language_model.model.layers.0.self_attn.q_proj.weight', 'model.language_model.model.layers.0.self_attn.v_proj.weight', 'model.language_model.model.layers.1.input_layernorm.weight', 'model.language_model.model.layers.1.mlp.down_proj.weight', 'model.language_model.model.layers.1.mlp.gate_proj.weight', 'model.language_model

In [5]:

# 查看模型支持的方法
print(dir(model))  # 输出模型所有可用方法
# 检查processor是否提供高级功能
print(dir(processor))


['__abstractmethods__', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_auto_class', '_check_special_mm_tokens', '_create_repo', '_get_arguments_from_pretrained', '_get_files_timestamps', '_get_num_multimodal_tokens', '_merge_kwargs', '_upload_modified_files', 'apply_chat_template', 'attributes', 'audio_tokenizer', 'batch_decode', 'chat_template', 'check_argument_for_proper_class', 'create_plaid_index', 'decode', 'feature_extractor_class', 'from_args_and_dict', 'from_pretrained', 'get_image_mask', 'get_n_patches', 'get_possibly_dynamic_module', 'get_processor_dict', 'get_topk_plaid', 'image_processor', 'image_processor_cl

In [6]:

# Your inputs
images = [
    Image.new("RGB", (128, 128), color="white"),
    Image.new("RGB", (64, 32), color="black"),
]
queries = [
    "What is the organizational structure for our R&D department?",
    "Can you provide a breakdown of last year’s financial performance?",
]


# Process the inputs
batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)


### 快速开始

In [7]:

# Forward pass
with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

scores = processor.score_multi_vector(query_embeddings, image_embeddings)
print(scores)

tensor([[4.0000, 3.1094],
        [4.5938, 3.3906]])


### 热力图

In [30]:
import torch

torch.cuda.empty_cache()  # 释放未被引用的显存
torch.cuda.ipc_collect()  # 清理跨进程缓存（可选）

In [28]:
print(processor.image_processor)
print(processor.image_processor.size)  # 目标resize尺寸
print(processor.image_processor.patch_size)  # patch大小
print(batch_images.shape)
print(image_embeddings.shape)


SiglipImageProcessor {
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "SiglipImageProcessor",
  "image_seq_length": 1024,
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "processor_class": "ColPaliProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 448,
    "width": 448
  }
}

{'height': 448, 'width': 448}


AttributeError: 'SiglipImageProcessor' object has no attribute 'patch_size'

In [29]:
n_patches = (16, 16)

# 获取图像的 patch 掩码（用于过滤无关区域）
image_mask = processor.get_image_mask(batch_images)

# 生成相似度热图（返回每个图像的热图）
batched_similarity_maps = get_similarity_maps_from_embeddings(
    image_embeddings=image_embeddings,
    query_embeddings=query_embeddings,
    n_patches=n_patches,  # 图像被分割的 patch 数量（如 16×16=256）
    image_mask=image_mask,
)


# 可视化单个图像的热图
def visualize_heatmap(image, heatmap, alpha=0.6):
    plt.figure(figsize=(10, 10))
    plt.imshow(image)  # 显示原始图像
    plt.imshow(heatmap, cmap='jet', alpha=alpha)  # 叠加热图
    plt.colorbar()  # 显示颜色刻度
    plt.axis('off')
    plt.show()


# 示例：可视化第一张图像的热图
visualize_heatmap(
    image=images[0],  # PIL 图像
    heatmap=batched_similarity_maps[0].reshape(n_patches, n_patches),
    alpha=0.6  # 热图透明度
)

ValueError: The number of patches (16 x 16 = 256) does not match the number of non-padded image tokens (1024).