In [1]:
import faiss
from extract_features import mySigLipModel
import time
import numpy as np
import requests
from PIL import Image
import json
import re

In [3]:
image_path = './indexed_100k-200k/siglip_image_urls-120k-125k.json'
index_path = './indexed_100k-200k/siglip-image-index-120k-125k.bin'
dataset_caption_path = './dataset/SBU_captioned_photo_dataset_captions.txt'
dataset_url_path = './dataset/SBU_captioned_photo_dataset_urls.txt'

In [4]:
#load in the embedding extractor
extractor = mySigLipModel()
#load in indexer
index = faiss.read_index(index_path)

In [5]:
testing_size = 50
k_list = [3, 5, 10, 20, 50]

In [11]:
start_index = 120000
end_index = 125000
with open(dataset_caption_path, 'r') as f:
    captions = f.readlines()[start_index:end_index]
with open(dataset_url_path, 'r') as f:
    urls = f.readlines()[start_index:end_index]
with open(image_path, 'r') as f:
    train_image_urls = json.load(f)

In [12]:
# select (testing_size * 1.5) images from the dataset in random
selecting_size = int(testing_size * 1.5)  # 20% more for the case that some images are not valid
np.random.seed(0)
selected_indices = np.random.choice(len(captions), selecting_size, replace=False)
selected_captions = [captions[i] for i in selected_indices]
selected_urls = [urls[i] for i in selected_indices]

## Compute the Recall of the model

In [13]:
# compute the recall@k
for k in k_list:
    recall = 0
    count = 0
    for i in range(selecting_size):
        url = selected_urls[i]
        # use regex to extract the image id in the pattern of 'http://static.flickr.com/[image_id]/XXX.jpg'
        image_id = int(re.search(r'http://static.flickr.com/(\d+)/', url).group(1))
        # print('image_id: {}'.format(image_id))
        caption = selected_captions[i]

        # check if the image is valid
        try:
            frame = Image.open(requests.get(url, stream=True).raw)
            count += 1
        except:
            continue

        # get the embedding of the caption
        query_embedding = extractor.get_text_embedding(caption)

        # search the k nearest neighbors
        D, I = index.search(query_embedding, k)
        # check if the caption of the image is in the k nearest neighbors
        result_urls = [train_image_urls[j].strip() for j in I[0]]
        result_ids = [int(re.search(r'http://static.flickr.com/(\d+)/', result_url).group(1)) for result_url in result_urls]
        
        # print('Query: {}'.format(caption))
        # print('Query url: {}'.format(url))
        # print('Results:')
        # for j in range(k):
        #     print('id: {}, url: {}'.format(result_ids[j], result_urls[j]))

        if image_id in result_ids:
            recall += 1

        if count >= testing_size:
            break

    print('Recall@{}: {}'.format(k, recall / testing_size))

Recall@3: 0.14
Recall@5: 0.22
Recall@10: 0.34
Recall@20: 0.46
Recall@50: 0.5


## Compare the search time of L2 and HNSW

In [11]:
k = 50
l2_indexing_time = 0
hnsw_indexing_time = 0

l2_indexing_path = './l2_index/0-100k.bin'
hnsw_indexing_path = './hnsw_index/0-100k.bin'
dataset_caption_path = './dataset/SBU_captioned_photo_dataset_captions.txt'

In [12]:
start_index = 0
end_index = 100000

l2_index = faiss.read_index(l2_indexing_path)
hnsw_index = faiss.read_index(hnsw_indexing_path)
with open(dataset_caption_path, 'r') as f:
    captions = f.readlines()[start_index:end_index]
with open(dataset_url_path, 'r') as f:
    urls = f.readlines()[start_index:end_index]

In [13]:
testing_size = 100
np.random.seed(0)
selected_indices = np.random.choice(len(captions), testing_size, replace=False)
selected_captions = [captions[i] for i in selected_indices]

In [15]:
# l2
start = time.time()
for caption in selected_captions:
    query_embedding = extractor.get_text_embedding(caption)
    D, I = l2_index.search(query_embedding, k)
end = time.time()
l2_indexing_time = end - start
print('Flat-l2 indexing time: {}'.format(l2_indexing_time))

# hnsw
start = time.time()
for caption in selected_captions:
    query_embedding = extractor.get_text_embedding(caption)
    D, I = hnsw_index.search(query_embedding, k)
end = time.time()
hnsw_indexing_time = end - start
print('HNSW indexing time: {}'.format(hnsw_indexing_time))

Flat-l2 indexing time: 7.276071071624756
HNSW indexing time: 6.214844465255737


## Compare the recall@50 of L2 and HNSW

In [18]:
# list of image link to filter out
filter_out = [
      "http://static.flickr.com/3203/3093853175_d2d73da89d.jpg",
   "http://static.flickr.com/4117/4769421254_193aa864c9.jpg",
   "http://static.flickr.com/3192/2943737977_4d397cf7ff.jpg",
   "http://static.flickr.com/55/116895922_9d43425b97.jpg",
   " http://static.flickr.com/2757/4434645492_bec30211bd.jpg",
   "http://static.flickr.com/3084/2902107560_ec29686f68.jpg",
    "https://static.flickr.com/1118/1194875137_25885364be.jpg",
   "https://static.flickr.com/2636/3935060630_7dcf980757.jpg",
    "https://static.flickr.com/216/523409934_f0d9aa5a96.jpg",
   "https://static.flickr.com/3147/3056994559_e9f6d21555.jpg",
    "https://static.flickr.com/25/65770723_9f1921f2bb.jpg",
    "https://static.flickr.com/2440/3634933857_a7c572d56f.jpg",
    "https://static.flickr.com/5178/5489673855_81d8916bef.jpg",
    "https://static.flickr.com/4020/4671844388_2b294f83d4.jpg",
   "https://static.flickr.com/2490/3816574357_3d657db043.jpg",
   "https://static.flickr.com/3596/3405303362_4b1b12d135.jpg",
   "https://static.flickr.com/3320/3297671530_dfce477ca7.jpg",
   "https://static.flickr.com/1249/1446311820_59accba36c.jpg",
   "https://static.flickr.com/5178/5489673855_81d8916bef.jpg",
   "https://static.flickr.com/114/279430835_08bcf39b98.jpg",
   "https://static.flickr.com/5047/5216626543_5a340330b7.jpg",
   "https://static.flickr.com/2684/4520382314_eb82d277cc.jpg",
   "https://static.flickr.com/1023/593085881_d937b91335.jpg",
   "https://static.flickr.com/4020/4671844388_2b294f83d4.jpg",
   "https://static.flickr.com/2490/3816574357_3d657db043.jpg",
   "https://static.flickr.com/1249/1446311820_59accba36c.jpg",
   "https://static.flickr.com/4020/5078251694_7b3a9c03c4.jpg",
   "https://static.flickr.com/2684/4520382314_eb82d277cc.jpg",
   "https://static.flickr.com/4113/5041050280_b445ff4505.jpg",
   "http://static.flickr.com/1225/1347132376_85bee547a0.jpg",
   "https://static.flickr.com/4058/4284577483_3cbe6f67d3.jpg",
   "https://static.flickr.com/1201/5104630077_3262c5f50a.jpg",
   "http://static.flickr.com/2713/4428452344_bb68a99fe9.jpg",
   " http://static.flickr.com/57/211777140_bbaea2510a.jpg",
   "http://static.flickr.com/2038/2329084061_1652eeee47.jpg",
   "http://static.flickr.com/4046/4419902221_ccdf0f6df9.jpg",
   "http://static.flickr.com/3546/3524977802_68086bf196.jpg",
   "http://static.flickr.com/107/293508228_633d88bfa1.jpg",
   "http://static.flickr.com/2663/3763807913_337a12d607.jpg"
    "http://static.flickr.com/2008/2102327757_d5efe046a1.jpg",
    "http://static.flickr.com/66/208914447_b9fd9ccee2.jpg",
    "http://static.flickr.com/172/370603691_cfbbcfbdf0.jpg",
    "http://static.flickr.com/1371/1161577755_334a3be670.jpg",
    "http://static.flickr.com/1379/1446311792_f9150907a8.jpg",



]

filter_out1 = [
    "https://static.flickr.com/4051/4227721866_c52e04a94c.jpg",
    "https://static.flickr.com/3198/3054356126_50e05efecb.jpg",
    "https://static.flickr.com/216/523409934_f0d9aa5a96.jpg",
    "https://static.flickr.com/3069/2632441642_b38b5fbf72.jpg",
    "https://static.flickr.com/2677/4332578964_68faa48446.jpg",
    "https://static.flickr.com/4064/4317883247_33dabacf5b.jpg",
]
filtered_urls = [url for url in filter_out + filter_out1]


# Function to extract image ID from URL
def extract_image_id(url):
    try:
        # Modified regex to match both http and https
        image_id = int(re.search(r'https?://static.flickr.com/(\d+)/', url).group(1))
        return image_id
    except AttributeError:
        # This will handle the case where the regex search finds no match
        print(f"No image ID found in URL: {url}")
        return None

# List to hold the image IDs
filter_ids = []

# Iterate over each URL and extract the image ID
for url in filtered_urls:
    id = extract_image_id(url)
    if id is not None:
        filter_ids.append(id)

In [54]:
k = 50
l2_indexing_time = 0
hnsw_indexing_time = 0

l2_indexing_path = './indexed_0-100k/siglip-image-index-0-100k.bin'
#hnsw_indexing_path = './hnsw_index/0-100k.bin'
dataset_caption_path = './dataset/SBU_captioned_photo_dataset_captions.txt'

l2_image_path = './indexed_0-100k/siglip-image-index-0-100k.json'
#hnsw_image_path = './hnsw_index/0-100k.json'

In [55]:
start_index = 0
end_index = 100000

l2_index = faiss.read_index(l2_indexing_path)
#hnsw_index = faiss.read_index(hnsw_indexing_path)
with open(dataset_caption_path, 'r') as f:
    captions = f.readlines()[start_index:end_index]
with open(dataset_url_path, 'r') as f:
    urls = f.readlines()[start_index:end_index]
with open(l2_image_path, 'r') as f:
    l2_train_image_urls = json.load(f)
#with open(hnsw_image_path, 'r') as f:
   # hnsw_train_image_urls = json.load(f)

In [56]:
testing_size = 100
selecting_size = int(testing_size * 1.5)  # 20% more for the case that some images are not valid
np.random.seed(int(time.time()))
selected_indices = np.random.choice(len(captions), selecting_size, replace=False)
selected_captions = [captions[i] for i in selected_indices]
selected_urls = [urls[i] for i in selected_indices]

In [57]:
import re
import requests
from PIL import Image

# Assuming 'extractor' and 'index' are properly defined and initialized elsewhere in your code
# along with k_list, selecting_size, selected_urls, selected_captions, train_image_urls, and testing_size
index = l2_index  # Use the L2 index for this example
train_image_urls = l2_train_image_urls  # Use the L2 image URLs for this example

# Combine the lists and convert to a set for faster lookup
filter_set = set(filter_out + filter_out1)

# Compute the recall@k
recall = 0
count = 0
for i in range(selecting_size):
    url = selected_urls[i]
    if url in filter_set:
        continue  # Skip this URL if it's in the filter set

    # Use regex to extract the image id
    image_id = int(re.search(r'https?://static.flickr.com/(\d+)/', url).group(1))
    caption = selected_captions[i]

    # Check if the image is valid
    try:
        response = requests.get(url, stream=True)
        frame = Image.open(response.raw)
        count += 1
    except:
        continue

    # Get the embedding of the caption
    query_embedding = extractor.get_text_embedding(caption)

    # Initially request a larger number of neighbors
    D, I = index.search(query_embedding, 2*k)  # Request at least 20 neighbors
    initial_result_urls = [train_image_urls[j].strip() for j in I[0]]
    filtered_result_ids = []
    filtered_result_urls = []

    # print('Query:', caption)
    for result_url in initial_result_urls:
        result_id = int(re.search(r'https?://static.flickr.com/(\d+)/', result_url).group(1))
        if result_id not in filter_ids:
            filtered_result_ids.append(result_id)
            filtered_result_urls.append(result_url)
            print( result_url)
            if len(filtered_result_urls) >= k:
                break

    # print('Filtered Results:')
    # for idx, result_url in enumerate(filtered_result_urls):
    #     print('id:', filtered_result_ids[idx], 'url:', result_url)

    if image_id in filtered_result_ids[:k]:  # Check only within the first k filtered results
        recall += 1

    if count >= testing_size:
        break

print(f'Recall@{k} using L2 index: {recall / testing_size}')

http://static.flickr.com/2661/4182402802_636d7dd23c.jpg
http://static.flickr.com/5011/5452050626_29a189976d.jpg
http://static.flickr.com/2769/4491980971_2e8e39109c.jpg
http://static.flickr.com/4098/4782588106_cef8734213.jpg
http://static.flickr.com/3042/2551474788_8d0bb7d22e.jpg
http://static.flickr.com/3536/3308656532_049143f3be.jpg
http://static.flickr.com/3051/2354656714_21fddf491b.jpg
http://static.flickr.com/2438/3632217591_26cc992d51.jpg
http://static.flickr.com/26/54348894_8637615d5c.jpg
http://static.flickr.com/2337/2232817460_8de16822cf.jpg
http://static.flickr.com/1064/1402709719_192f4d4b4e.jpg
http://static.flickr.com/4001/4415516043_4050593718.jpg
http://static.flickr.com/2386/2337482989_240d5ff280.jpg
http://static.flickr.com/3088/2584792397_4265d7df1d.jpg
http://static.flickr.com/4098/4943584612_10ce2b0c77.jpg
http://static.flickr.com/3354/3172333872_c1e306b046.jpg
http://static.flickr.com/3561/3395679010_46d6c40633.jpg
http://static.flickr.com/2776/4225239291_6d21b8ebbf.

In [25]:
index = _indexhnsw  # Use the HNSW index for this example
train_image_urls = hnsw_train_image_urls  # Use the HNSW image URLs for this example

# Compute the recall@k
recall = 0
count = 0

for i in range(selecting_size):
    url = selected_urls[i]
    if url in filter_set:
        continue  # Skip this URL if it's in the filter set

    # Use regex to extract the image id
    image_id = int(re.search(r'https?://static.flickr.com/(\d+)/', url).group(1))
    caption = selected_captions[i]

    # Check if the image is valid
    try:
        response = requests.get(url, stream=True)
        frame = Image.open(response.raw)
        count += 1
    except:
        continue

    # Get the embedding of the caption
    query_embedding = extractor.get_text_embedding(caption)

    # Initially request a larger number of neighbors
    D, I = index.search(query_embedding, 2*k)  # Request at least 20 neighbors
    initial_result_urls = [train_image_urls[j].strip() for j in I[0]]
    filtered_result_ids = []
    filtered_result_urls = []

    # print('Query:', caption)
    for result_url in initial_result_urls:
        result_id = int(re.search(r'https?://static.flickr.com/(\d+)/', result_url).group(1))
        if result_id not in filter_ids:
            filtered_result_ids.append(result_id)
            filtered_result_urls.append(result_url)
            if len(filtered_result_urls) >= k:
                break

    # print('Filtered Results:')
    # for idx, result_url in enumerate(filtered_result_urls):
    #     print('id:', filtered_result_ids[idx], 'url:', result_url)

    if image_id in filtered_result_ids[:k]:  # Check only within the first k filtered results
        recall += 1

    if count >= testing_size:
        break

print(f'Recall@{k} using HNSW index: {recall / testing_size}')

Recall@50 using HNSW index: 0.13
