In [1]:
import faiss
from extract_features import mySigLipModel
from display_image import ImageDisplay
import time
import numpy as np
import requests
from PIL import Image
import json
import re

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
image_path = './indexed_100k-200k/siglip_image_urls-120k-125k.json'
index_path = './indexed_100k-200k/siglip-image-index-120k-125k.bin'
dataset_caption_path = './dataset/SBU_captioned_photo_dataset_captions.txt'
dataset_url_path = './dataset/SBU_captioned_photo_dataset_urls.txt'

In [9]:
#load in the embedding extractor
extractor = mySigLipModel()
#load in image displayer
display = ImageDisplay()
#load in indexer
index = faiss.read_index(index_path)

In [10]:
testing_size = 50
k_list = [3, 5, 10, 20, 50]

In [11]:
start_index = 120000
end_index = 125000
with open(dataset_caption_path, 'r') as f:
    captions = f.readlines()[start_index:end_index]
with open(dataset_url_path, 'r') as f:
    urls = f.readlines()[start_index:end_index]
with open(image_path, 'r') as f:
    train_image_urls = json.load(f)

In [12]:
# select (testing_size * 1.2) images from the dataset in random
selecting_size = int(testing_size * 1.5)  # 20% more for the case that some images are not valid
np.random.seed(0)
selected_indices = np.random.choice(len(captions), selecting_size, replace=False)
selected_captions = [captions[i] for i in selected_indices]
selected_urls = [urls[i] for i in selected_indices]

In [13]:
# compute the recall@k
for k in k_list:
    recall = 0
    count = 0
    for i in range(selecting_size):
        url = selected_urls[i]
        # use regex to extract the image id in the pattern of 'http://static.flickr.com/[image_id]/XXX.jpg'
        image_id = int(re.search(r'http://static.flickr.com/(\d+)/', url).group(1))
        # print('image_id: {}'.format(image_id))
        caption = selected_captions[i]

        # check if the image is valid
        try:
            frame = Image.open(requests.get(url, stream=True).raw)
            count += 1
        except:
            continue

        # get the embedding of the caption
        query_embedding = extractor.get_text_embedding(caption)

        # search the k nearest neighbors
        D, I = index.search(query_embedding, k)
        # check if the caption of the image is in the k nearest neighbors
        result_urls = [train_image_urls[j].strip() for j in I[0]]
        result_ids = [int(re.search(r'http://static.flickr.com/(\d+)/', result_url).group(1)) for result_url in result_urls]
        
        # print('Query: {}'.format(caption))
        # print('Query url: {}'.format(url))
        # print('Results:')
        # for j in range(k):
        #     print('id: {}, url: {}'.format(result_ids[j], result_urls[j]))

        if image_id in result_ids:
            recall += 1

        if count >= testing_size:
            break

    print('Recall@{}: {}'.format(k, recall / testing_size))

Recall@3: 0.14
Recall@5: 0.22
Recall@10: 0.34
Recall@20: 0.46
Recall@50: 0.5
