## Imports

In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# library code (e.g. vidore_benchmark) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from PIL import Image
from vidore_benchmark.compression.token_pooling import HierarchicalEmbeddingPooler
from vidore_benchmark.retrievers.colpali_retriever import ColPaliRetriever
from vidore_benchmark.utils.constants import OUTPUT_DIR
from vidore_benchmark.utils.image_utils import scale_image

# Output folder for this failure analysis.
# NOTE(review): this is created before the working-directory change below; if
# OUTPUT_DIR is a relative path it resolves against the notebook's starting
# directory — confirm that is intended.
RESULTS_DIR = OUTPUT_DIR / "failure_analysis"
RESULTS_DIR.mkdir(exist_ok=True, parents=True)

# Load variables from a .env file, overriding any already set in the environment.
load_dotenv(override=True)

# Move to the nearest ancestor directory containing an `experiments` entry so
# relative paths such as "data/..." resolve from there. The ancestors are searched
# without changing directory, and the search stops at the filesystem root
# (whose parent is itself) instead of looping forever.
project_root = Path.cwd()
while "experiments" not in os.listdir(project_root):
    if project_root.parent == project_root:
        raise FileNotFoundError(
            "Could not find an `experiments` directory in the working directory or any of its parents"
        )
    project_root = project_root.parent
os.chdir(project_root)

## Load model and example query-image pair

In [None]:
# Patch grid used below to map token indices back to image regions: a square
# image of side `resolution` px cut into `patch_size` px patches gives
# `num_patches` patches per side.
# NOTE(review): assumed to match the ColPali vision encoder's input size and
# patch size — confirm against the model config.
resolution = 448
patch_size = 14
num_patches = resolution // patch_size

colpali_retriever = ColPaliRetriever()

In [None]:
query = "Which hour of the day had the highest overall electricity generation in 2019?"
image_filepath = Path("data/interpretability_examples/energy_electricity_generation.jpeg")
# Raise explicitly rather than using `assert`, which is stripped when Python runs with -O.
if not image_filepath.is_file():
    raise FileNotFoundError(f"File `{image_filepath}` not found")

img = Image.open(image_filepath)
# Display a downscaled preview of the page (presumably bounded to 256 px — see scale_image).
scale_image(img, 256)

In [None]:
# Embed the single example page image; one embedding tensor per input document.
emb_documents = colpali_retriever.forward_documents(documents=[img], batch_size=1)

In [None]:
embedding_pooler = HierarchicalEmbeddingPooler(pool_factor=100)

# pool_embeddings returns (pooled embeddings, cluster id -> token indices);
# call it once per document, in order, then split the pairs into two lists.
pooling_outputs = [embedding_pooler.pool_embeddings(doc_embedding) for doc_embedding in emb_documents]
emb_documents_pooled = [pooled for pooled, _ in pooling_outputs]
list_cluster_id_to_indices = [cluster_map for _, cluster_map in pooling_outputs]

# Compare token counts before and after pooling for the first document.
print(emb_documents[0].shape)
print(emb_documents_pooled[0].shape)

In [None]:
# Inspect the cluster id -> token indices mapping produced by the pooler for the first document.
list_cluster_id_to_indices[0]

In [None]:
# Number of tokens in each cluster, in the mapping's iteration order.
lengths = list(map(len, list_cluster_id_to_indices[0].values()))
lengths

In [None]:
# Size (token count) of the largest cluster.
max(lengths)

In [None]:
# Position of the first largest cluster within `lengths` (a list position, not a cluster id).
lengths.index(max(lengths))

In [None]:
# Token indices of the largest cluster of the first document.
# The cluster id is taken directly from the mapping's keys instead of converting
# a list position with `+ 1`, which silently assumed 1-based ids iterated in order.
# On ties, max() keeps the first largest cluster in iteration order, as before.
doc_cluster_id_to_indices = list_cluster_id_to_indices[0]
largest_cluster_id = max(
    doc_cluster_id_to_indices,
    key=lambda cluster_id: len(doc_cluster_id_to_indices[cluster_id]),
)
cluster_of_interest = doc_cluster_id_to_indices[largest_cluster_id].cpu().tolist()
cluster_of_interest

In [None]:
# Map each flat token index to (row, col) on the num_patches x num_patches grid.
# NOTE(review): assumes the first num_patches**2 tokens are image patches in
# row-major order — confirm. Larger indices give rows outside the grid.
patch_indices = [divmod(token_idx, num_patches) for token_idx in cluster_of_interest]
patch_indices

In [None]:
# Embed the example query (single-item batch).
emb_queries = colpali_retriever.forward_queries(queries=[query], batch_size=1)

In [None]:
style = {}
figsize = (8, 8)

# Resize the page to the square grid resolution and get an (H, W, 4) RGBA array.
input_image_square = img.resize((resolution, resolution))
img_array = np.array(input_image_square.convert("RGBA"))

# Set for O(1) membership tests inside the num_patches x num_patches loop.
highlighted_patches = set(patch_indices)

with plt.style.context(style):
    fig, axis = plt.subplots(num_patches, num_patches, figsize=figsize)

    # Draw only the patches that belong to the cluster of interest; all other cells stay blank.
    for i in range(num_patches):
        for j in range(num_patches):
            patch = img_array[i * patch_size : (i + 1) * patch_size, j * patch_size : (j + 1) * patch_size, :]
            if (i, j) in highlighted_patches:
                axis[i, j].imshow(patch)
            axis[i, j].axis("off")

    # tight_layout() recomputes wspace/hspace itself, so the former
    # subplots_adjust(wspace=0.1, hspace=0.1) call had no effect and was dropped.
    fig.tight_layout()

fig