In [None]:
%pip install faiss-cpu
from PIL import Image
import requests
import faiss
import json
import time
import os
from google.colab import files



In [None]:
from transformers import AutoProcessor, AutoModel
import torch
import torch.nn as nn

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class mySigLipModel(nn.Module):
    """Wrapper around a Hugging Face SigLIP checkpoint that turns images into embeddings.

    Attributes:
        model: the checkpoint loaded with ``AutoModel.from_pretrained``.
        processor: the matching ``AutoProcessor`` used to preprocess images.
        shape: dimensionality of each row returned by ``get_image_embeddings``; callers use it
            to size a FAISS index. For SigLIP checkpoints this is the vision tower's
            ``hidden_size`` (768 for the default checkpoint).
    """

    def __init__(self, model_name="google/siglip-base-patch16-224"):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.processor = AutoProcessor.from_pretrained(model_name)
        # SigLIP's get_image_features returns the pooled vision output, whose width is the
        # vision config's hidden_size (768 for google/siglip-base-patch16-224).
        self.shape = self.model.config.vision_config.hidden_size

    def get_image_embeddings(self, images, type):
        """Embed a single image or a batch of images.

        Args:
            images: anything the SigLIP processor accepts, e.g. one PIL image or a list of them.
            type: ``"gpu"`` runs on the module-level ``device`` (CUDA when available, otherwise
                CPU); any other value runs on the CPU.

        Returns:
            A NumPy array with one embedding row per input image.

        Note:
            The model is left on the device of the last call instead of being copied back to the
            CPU each time. ``Module.to`` does not copy parameters that already live on the target
            device, so consecutive calls with the same ``type`` avoid a full-model transfer.
        """
        target = device if type == "gpu" else torch.device("cpu")
        self.model.to(target)
        inputs = self.processor(images=images, return_tensors="pt").to(target)
        # Inference only: don't build an autograd graph for the forward pass.
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return features.cpu().numpy()

In [None]:
class ImageIndex:
    """Builds a FAISS L2 index of SigLIP image embeddings for a slice of an image-URL list.

    Attributes:
        extractor: ``mySigLipModel`` used to embed the downloaded images.
        index: ``faiss.IndexFlatL2`` whose dimension is ``extractor.shape``.
        train_images_path: text file with one image URL per line.
        train_images: raw lines of that file (trailing newlines kept).
    """

    def __init__(self, train_images_path="SBU_captioned_photo_dataset_urls.txt"):
        self.extractor = mySigLipModel()
        self.index = faiss.IndexFlatL2(self.extractor.shape)
        self.train_images_path = train_images_path
        # Read eagerly and close the file rather than leaking the handle.
        with open(self.train_images_path, "r") as f:
            self.train_images = f.readlines()

    @staticmethod
    def _load_image(img_url, timeout):
        """Download ``img_url`` and decode it as an RGB PIL image; raises on any failure."""
        # The context manager releases the streamed connection; convert() forces PIL to read
        # the whole image before the response is closed.
        with requests.get(img_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            return Image.open(response.raw).convert("RGB")

    def _add_batch(self, frames):
        """Embed ``frames`` in one forward pass and append the vectors to ``self.index``."""
        if frames:
            self.index.add(self.extractor.get_image_embeddings(frames, "gpu"))

    def indexing(self, index_path, image_path, start_idx, end_idx, batch_size=32, timeout=10):
        """Embed ``train_images[start_idx:end_idx]`` and save the index and the URL list.

        The index is reset first, so row ``k`` of the saved index always corresponds to entry
        ``k`` of the saved URL list, even if this method is called repeatedly on one instance.
        URLs that fail to download or decode are logged and skipped.

        Args:
            index_path: destination for ``faiss.write_index``.
            image_path: destination for the JSON list of successfully indexed URLs.
            start_idx: first line (inclusive) of ``train_images`` to process.
            end_idx: last line (exclusive) of ``train_images`` to process.
            batch_size: number of images embedded per forward pass.
            timeout: seconds passed to ``requests.get`` for each download.
        """
        self.index.reset()
        image_urls = []
        pending = []

        print("start image indexing ...")
        start = time.time()

        # Enumerate from start_idx so failure logs report the line number in the URL file.
        for i, line in enumerate(self.train_images[start_idx:end_idx], start=start_idx):
            img_url = line.strip()
            try:
                frame = self._load_image(img_url, timeout)
            except Exception as exc:  # best effort: skip dead links / undecodable images
                print(f"Failed to load image {i}: {img_url} ({exc!r})")
                continue

            image_urls.append(img_url)
            pending.append(frame)
            if len(pending) >= batch_size:
                self._add_batch(pending)
                pending = []
        self._add_batch(pending)  # flush the final partial batch

        end = time.time()
        print('Finish indexing in ' + str(end - start) + ' seconds')
        faiss.write_index(self.index, index_path)

        with open(image_path, "w") as f:
            json.dump(image_urls, f, indent=2)

In [None]:
indexer = ImageIndex()

# Half-open slice [start_idx, end_idx) of the URL list processed in this run.
start_idx = 13000  # i.e. 10000 + 3000
end_idx = start_idx + 5000

# Thousands labels (floor division) shared by the output file names and the log lines.
start_k, end_k = start_idx // 1000, end_idx // 1000
index_path = f"siglip-image-index-{start_k}k-{end_k}k.bin"
image_path = f"siglip_image_urls-{start_k}k-{end_k}k.json"

print(f"Start Indexing from {start_k}k to {end_k}k")

indexer.indexing(index_path, image_path, start_idx, end_idx)

# Colab-only: push both artifacts to the browser.
for path in (index_path, image_path):
    files.download(path)

print(f"Finish Indexing from {start_k}k to {end_k}k\n")

Start Indexing from 13k to 18k
start image indexing ...
Failed to load image 3: http://static.flickr.com/8/6935632_b2df4ef70b.jpg
Failed to load image 5: http://static.flickr.com/2200/2250706054_c4cd6681ae.jpg
Failed to load image 8: http://static.flickr.com/4133/5032259430_c74edf3476.jpg
Failed to load image 14: http://static.flickr.com/5001/5211082788_cb4de4425e.jpg
Failed to load image 29: http://static.flickr.com/13/102208672_0f94c69ed5.jpg
Failed to load image 40: http://static.flickr.com/2236/2275681314_90a214ff73.jpg
Failed to load image 43: http://static.flickr.com/3342/4601510597_1711fec2da.jpg
Failed to load image 55: http://static.flickr.com/151/404441441_591c44e20b.jpg
Failed to load image 57: http://static.flickr.com/2746/4427160477_61deb80b1b.jpg
Failed to load image 64: http://static.flickr.com/4118/4891647198_469a947f2c.jpg
Failed to load image 69: http://static.flickr.com/3456/3936712975_04fb6c7e93.jpg
Failed to load image 75: http://static.flickr.com/4032/4262892094_c

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Finish Indexing from 13k to 18k



In [None]:
# from google.colab import files
# files.download("siglip-image-index-90k.bin")
# files.download("siglip_image_urls-90k.json")

In [None]:
# from google.colab import files
# path = "README.txt"
# files.download(path)
# files.download("download.m")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>