<a href="https://colab.research.google.com/github/vinishgeorge/google-colab/blob/main/colpali.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}")
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

CUDA available: True
CUDA version: 12.4
CUDA device count: 1
Current CUDA device: 0
CUDA device name: Tesla T4


In [1]:
%pip install colpali_engine

Collecting colpali_engine
  Downloading colpali_engine-0.3.8-py3-none-any.whl.metadata (27 kB)
Collecting gputil (from colpali_engine)
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting transformers<4.48.0,>=4.47.0 (from colpali_engine)
  Downloading transformers-4.47.1-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.1/44.1 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.2.0->colpali_engine)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.2.0->colpali_engine)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.2.0->colpali_engine)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl

In [2]:
%pip install requests Pillow



In [3]:
import requests
from PIL import Image
from io import BytesIO

# Array of image URLs
urls = [
    "https://vinishgeorgesandboxdiag.blob.core.windows.net/images/copali_test1.png",
    "https://vinishgeorgesandboxdiag.blob.core.windows.net/images/copali_test2.png"
]

# List to store all the fetched images
images = []

# Process each URL in the array
for i, url in enumerate(urls):
    try:
        # Fetch the image from the URL
        response = requests.get(url)

        # Check if the request was successful
        if response.status_code == 200:
            # Convert the response content to an image
            img = Image.open(BytesIO(response.content))

            # Resize it to 128x128 if needed
            img = img.resize((128, 128))

            # Add to our images list
            images.append(img)

            # Save the image with a unique name
            img.save(f"downloaded_image_{i}.jpg")

            print(f"Successfully downloaded image from {url}")
        else:
            print(f"Failed to fetch image from {url}: HTTP {response.status_code}")

    except Exception as e:
        print(f"Error processing {url}: {str(e)}")

# Now 'images' contains all the successfully downloaded PIL Image objects
print(f"Total images downloaded: {len(images)}")

Error processing https://vinishgeorgesandboxdiag.blob.core.windows.net/images/copali_test1.png: cannot write mode RGBA as JPEG
Error processing https://vinishgeorgesandboxdiag.blob.core.windows.net/images/copali_test2.png: cannot write mode RGBA as JPEG
Total images downloaded: 2


In [None]:
import torch
from PIL import Image
from transformers.utils.import_utils import is_flash_attn_2_available

from colpali_engine.models import ColQwen2, ColQwen2Processor

model_name = "vidore/colqwen2-v1.0"
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = ColQwen2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
).eval()

processor = ColQwen2Processor.from_pretrained(model_name)

# Your inputs

queries = [
    "how can i get all the assessment results of the user?",
    "what are the features associated with assessments?",
]

# Process the inputs
batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

# Forward pass
with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

scores = processor.score_multi_vector(query_embeddings, image_embeddings)

In [None]:
# Add this code after computing the scores

# Convert scores to a more readable format
def interpret_scores(scores, queries, image_urls):
    # Convert to numpy for easier handling
    scores_np = scores.cpu().numpy()

    # Print human-readable results
    print("\n===== QUERY-IMAGE MATCH RESULTS =====\n")

    for i, query in enumerate(queries):
        print(f"Query {i+1}: \"{query}\"")
        print("-" * 50)

        # Get scores for this query against all images
        query_scores = scores_np[i]

        # Sort images by relevance for this query
        sorted_indices = query_scores.argsort()[::-1]  # Descending order

        for rank, idx in enumerate(sorted_indices):
            score = query_scores[idx]
            url = image_urls[idx]
            filename = url.split('/')[-1]

            # Convert score to percentage for better readability
            score_percent = score * 100 if score <= 1.0 else score

            print(f"Rank {rank+1}: Image '{filename}'")
            print(f"   Score: {score:.4f} ({score_percent:.2f}%)")
            print(f"   URL: {url}")
            print()

        # Best match for this query
        best_idx = scores_np[i].argmax()
        best_score = scores_np[i][best_idx]
        best_image = image_urls[best_idx].split('/')[-1]

        print(f"Best match for query: '{best_image}' with score {best_score:.4f}\n")
        print("=" * 50 + "\n")

    # Print overall top matches
    print("\n===== OVERALL TOP MATCHES =====\n")

    # Flatten scores to find top matches overall
    flat_scores = scores_np.flatten()
    flat_indices = flat_scores.argsort()[::-1][:5]  # Top 5 matches

    for rank, flat_idx in enumerate(flat_indices):
        # Convert flat index to query and image indices
        query_idx = flat_idx // len(image_urls)
        image_idx = flat_idx % len(image_urls)

        score = flat_scores[flat_idx]
        query = queries[query_idx]
        image = image_urls[image_idx].split('/')[-1]

        print(f"Overall Rank {rank+1}:")
        print(f"   Query: \"{query}\"")
        print(f"   Image: {image}")
        print(f"   Score: {score:.4f}")
        print()

# Call the function
interpret_scores(scores, queries, urls)

# You can also visualize the results with a heatmap
import matplotlib.pyplot as plt
import numpy as np

def plot_similarity_heatmap(scores, queries, image_urls):
    # Convert to numpy
    scores_np = scores.cpu().numpy()

    # Create labels for the plot
    image_labels = [url.split('/')[-1] for url in image_urls]
    query_labels = [f"Q{i+1}: {q[:20]}..." if len(q) > 20 else f"Q{i+1}: {q}" for i, q in enumerate(queries)]

    # Create the figure
    plt.figure(figsize=(10, 6))
    heatmap = plt.imshow(scores_np, cmap='viridis')
    plt.colorbar(heatmap, label='Similarity Score')

    # Add labels
    plt.xticks(np.arange(len(image_labels)), image_labels, rotation=45, ha='right')
    plt.yticks(np.arange(len(query_labels)), query_labels)

    # Add values on the heatmap
    for i in range(len(query_labels)):
        for j in range(len(image_labels)):
            plt.text(j, i, f"{scores_np[i, j]:.2f}",
                     ha="center", va="center",
                     color="white" if scores_np[i, j] > 0.5 else "black")

    plt.title('Query-Image Similarity Scores')
    plt.tight_layout()
    plt.savefig('similarity_heatmap.png')
    plt.close()

    print("Heatmap saved as 'similarity_heatmap.png'")

# Generate the heatmap
plot_similarity_heatmap(scores, queries, urls)