In [1]:
# Load IPython's autoreload extension so edits to the local image_search package are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all imported modules before each cell execution.
%autoreload 2

In [2]:
cd ..

/home/xavi/projects/image-search


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
from typing import Tuple
import numpy as np
import pandas as pd
from pathlib import Path
import faiss
import torch
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoProcessor
from tqdm import tqdm

from image_search.data import ImagesDataset, ConversionsDataset
from image_search.train import temporal_train_test_split
from image_search.model import QueryModel, ImageModel, LightningImageSearchSigLIP

from image_search.metrics import hit_rate, mean_average_precision_at_k

In [4]:
# CONSTANTS

# Model and data sources
BASE_MODEL = "google/siglip-base-patch16-224"
IMAGES_FOLDER = Path("./data/unsplash-research-dataset-lite-latest/photos/")

# Data loading settings
NUM_WORKERS = 4
BATCH_SIZE = 2048
SEED = 42

# Model variants that can be evaluated, keyed by the name used for the output files.
# A value of None means that no checkpoint is loaded for that variant.
CHECKPOINT_PATHS = {
    "VANILLA": None,
    "FINE-TUNED": "./mlruns/168016125050379525/b4956e72f44f4df48b395cf338545014/checkpoints/epoch=0-step=60598.ckpt",
}

# Variant evaluated in this run; it selects the checkpoint and names the results file.
MODEL_NAME = "FINE-TUNED"
CHECKPOINT_PATH = CHECKPOINT_PATHS[MODEL_NAME]
OUTPUT_PATH = Path("./results/", f"{MODEL_NAME}.csv")

In [5]:
# Processor that belongs to the base checkpoint; it is handed to the datasets built further down.
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path=BASE_MODEL)

In [6]:
# Load the base SigLIP weights, then place the network on the GPU.
model = AutoModel.from_pretrained(BASE_MODEL)
model = model.to("cuda")

In [7]:
# Wrap the base model in the Lightning module, restoring fine-tuned weights when a checkpoint is configured.
if CHECKPOINT_PATH is None:
    # No checkpoint configured: use the vanilla model directly.
    lightning_model = LightningImageSearchSigLIP(model=model, lr=1e-4)
else:
    # Note: some of these keyword arguments should no longer be necessary after the latest
    # changes in the code, but dropping them would require retraining the model.
    lightning_model = LightningImageSearchSigLIP.load_from_checkpoint(
        CHECKPOINT_PATH,
        model=model,
        lr=1e-4,
    )

In [8]:
# Run the whole module on the GPU in bfloat16.
lightning_model = lightning_model.to(dtype=torch.bfloat16, device="cuda")

In [9]:
# Pull the image-side and query-side models out of the Lightning wrapper.
image_model, query_model = lightning_model.image_model, lightning_model.query_model

In [10]:
# Put both models in evaluation mode before computing embeddings.
image_model, query_model = image_model.eval(), query_model.eval()

In [11]:
# Inference only: switch autograd off globally for the rest of the notebook.
# The trailing semicolon keeps the returned object's repr out of the cell output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f0e4caba8c0>

# Images

In [12]:
# Dataset over IMAGES_FOLDER, served in batches (no shuffling is requested).
images_dataset = ImagesDataset(processor=processor, image_folder=IMAGES_FOLDER)
images_dataloader = DataLoader(
    dataset=images_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# Generate Index
Following the [Pinecone guide](https://www.pinecone.io/learn/series/faiss/faiss-tutorial/)

In [13]:
# Dimensionality of the vectors stored in the index, read from the vision model's config.
# NOTE(review): assumes image_model and query_model both emit vectors of this size — confirm against the model code.
embedding_dim = model.vision_model.config.hidden_size
# Note: Since this dataset contains a pretty small number of images (25k), we can use brute-force search with a flat Faiss index.
# This way we also avoid having to "train" the retriever.
# In a real scenario, we would try to optimize for faster retrieval with potentially millions of items.
# NOTE(review): this index ranks by L2 distance, which agrees with a dot-product/cosine ranking only if the
# embeddings are normalised — confirm that the models normalise their outputs.
index = faiss.IndexFlatL2(embedding_dim)

In [14]:
# Embed every image batch and append the vectors to the Faiss index.
# Start from an empty index so re-running this cell does not append duplicate vectors
# (which would also shift the position of every image).
index.reset()

# The flat index numbers vectors sequentially in insertion order, i.e. in the order the
# (unshuffled) loader yields them, so search results are positions rather than explicit ids.
# NOTE(review): the evaluation further down compares those positions with batch["ids"];
# confirm both use the same numbering (otherwise wrap the index in faiss.IndexIDMap and use add_with_ids).
for batch in tqdm(images_dataloader):
    pixel_values = batch["pixel_values"].to(device=lightning_model.device, dtype=lightning_model.dtype)

    image_embeddings = image_model(pixel_values=pixel_values)
    # Faiss works on float32 NumPy arrays on the host, so convert before adding.
    image_embeddings = image_embeddings.to(device="cpu", dtype=torch.float32).detach().numpy()

    index.add(image_embeddings)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [01:37<00:00,  7.51s/it]


In [15]:
# To avoid re-computing all embeddings, save the current index to disk.
# Create the target folder first so this does not fail on a fresh checkout.
index_path = Path("indices") / f"{MODEL_NAME}.index"
index_path.parent.mkdir(parents=True, exist_ok=True)
faiss.write_index(index, str(index_path))

In [16]:
# Reload the index saved for the selected model variant.
index_file = f"indices/{MODEL_NAME}.index"
index = faiss.read_index(index_file)

# Dataset

In [17]:
# Read the cleaned conversions table from disk
# (alternative kept from the original: conversions = load_and_preprocess_data()).
conversions_file = "./data/clean/conversions.parquet"
conversions = pd.read_parquet(conversions_file)

# Only the second part of the temporal split is used for evaluation.
_, conversions_val = temporal_train_test_split(conversions)

In [18]:
# Validation dataset built from the held-out conversions plus an ImagesDataset over the same folder.
image_dataset = ImagesDataset(processor=processor, image_folder=IMAGES_FOLDER)
dataset_val = ConversionsDataset(
    data=conversions_val,
    image_dataset=image_dataset,
    processor=processor,
)
dataloader_val = DataLoader(
    dataset=dataset_val,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [19]:
# Embed every validation query, retrieve its nearest images from the index and
# collect the retrieved positions next to the ground-truth ids.
# Only the query model is run here: the images were embedded when the index was built,
# so the per-batch image forward pass of the original loop was unused work.
true_ids = []
predicted_ids = []

for batch in tqdm(dataloader_val):
    input_ids = batch["input_ids"].to(lightning_model.device)

    query_embedding = query_model(input_ids)
    # Faiss works on float32 NumPy arrays on the host, so convert before searching.
    query_embedding = query_embedding.to(device="cpu", dtype=torch.float32).detach().numpy()

    # k=25 covers the largest cut-off reported in the metrics cell.
    _distances, indices = index.search(query_embedding, k=25)

    # A leading axis is added to each batch of ids because the next cell concatenates them along axis=1.
    ids = batch["ids"].unsqueeze(dim=0).numpy()

    true_ids.append(ids)
    predicted_ids.append(indices)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 947/947 [1:02:23<00:00,  3.95s/it]


In [20]:
# Merge the per-batch results into single arrays.
# The names are rebound in place, so guard on the list type to keep the cell idempotent:
# concatenating the already-merged arrays again would raise for true_ids and silently
# flatten predicted_ids to 1-D.
if isinstance(true_ids, list):
    true_ids = np.concatenate(true_ids, axis=1)
if isinstance(predicted_ids, list):
    predicted_ids = np.concatenate(predicted_ids, axis=0)

# METRICS
- (N)DCG: https://arize.com/blog-course/ndcg/
- (Mean) Average Precision

### Hit Rate and mAP

In [22]:
# Cut-offs at which both ranking metrics are reported.
k = [1, 5, 10, 25]

hit_rate_score = hit_rate(true_ids, predicted_ids, k=k)
map_score = mean_average_precision_at_k(
    true_ids=true_ids,
    predicted_ids=predicted_ids,
    k=k,
)

# One column per metric; the index is labelled "k".
scores = pd.DataFrame({"hit_rate": hit_rate_score, "mAP": map_score}).rename_axis("k")
scores

Unnamed: 0_level_0,hit_rate,mAP
k,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.144425,0.144425
5,0.348244,0.217511
10,0.463021,0.232716
25,0.601751,0.241641


In [23]:
# Persist the metrics table; create the results folder first so a fresh checkout does not fail here.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
scores.to_csv(OUTPUT_PATH)