In [1]:
import csv

import faiss
import requests
import torch
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPProcessor, CLIPModel

In [2]:
# %load_ext autoreload
# %autoreload 2
import embeddings

## Reference Lists

In [30]:

# Download the Places365 category list. Each raw line looks like "/a/airfield 0":
# a slash-separated category path followed by its numeric class id.
PLACES365_URL = "https://raw.githubusercontent.com/CSAILVision/places365/master/categories_places365.txt"
req = requests.get(PLACES365_URL, timeout=30)
req.raise_for_status()  # fail loudly instead of parsing an HTTP error page as categories
# splitlines() + blank-line filter: a trailing newline must not turn into an empty category.
places_list = [line.split(' ') for line in req.content.decode('utf-8').splitlines() if line.strip()]
# Drop the leading "" and "<letter>" path segments, keep the rest, and turn underscores
# into spaces, e.g. "/a/apartment_building/outdoor" -> "apartment building outdoor".
places = [" ".join(x[0].split('/')[2:]).replace("_", " ") for x in places_list]
with open("places_365.txt", 'w', encoding='utf-8') as f:
    f.write("\n".join(places))

In [113]:
# Open Images (2017_11) class descriptions: one "<label id>,<display name>" CSV row per class.
with open("2017_11/class-descriptions.csv", encoding='utf-8') as f:
    oi = [line.strip() for line in f]

# Parse with the csv module so quoted display names containing commas stay intact
# (a plain split(',') would cut them apart); skip blank lines so no empty label is written.
object_names = [row[-1] for row in csv.reader(oi) if row]
with open('objects_open_images.txt', 'w', encoding='utf-8') as f:
    f.write("\n".join(object_names))

## Embeddings

In [3]:
# Load the CLIP ViT-L/14 checkpoint; the processor prepares both text and image inputs
# for the model, so both come from the same checkpoint name.
clip_checkpoint = "openai/clip-vit-large-patch14"
processor = CLIPProcessor.from_pretrained(clip_checkpoint)
model = CLIPModel.from_pretrained(clip_checkpoint)

In [5]:
# Embed the query image with CLIP. `out` is the image feature tensor searched against
# the place and object indices further down.
# - `with Image.open(...)` closes the file handle once preprocessing is done.
# - torch.no_grad(): inference only, so no autograd graph is recorded.
# - `padding` only applies to text tokenization, so it is not passed for an image-only call.
with Image.open('n02086240_7268.JPEG') as image, torch.no_grad():
    image_inputs = processor(images=image, return_tensors="pt")
    out = model.get_image_features(**image_inputs)

### Places

In [6]:
# Reload the Places365 names written earlier, one per line. Blank lines (e.g. from a
# trailing newline) are skipped so an empty-string label never reaches the index.
with open("places_365.txt", encoding='utf-8') as f:
    places = [name for name in f.read().splitlines() if name.strip()]

In [7]:
# Encode all place names in a single batch with CLIP's text tower.
# - truncation=True guards against names longer than the tokenizer's maximum length.
# - no_grad: inference only, so no autograd graph is kept.
with torch.no_grad():
    inputs = processor(text=places, return_tensors="pt", padding=True, truncation=True)
    places_embeddings = model.get_text_features(**inputs)

In [23]:
# Inner-product FAISS index for the place-name embeddings. The vector size is read from
# the CLIP config (projection_dim, the size of get_text/image_features outputs) instead of
# a hard-coded 768, so it stays correct if the checkpoint changes.
# NOTE(review): IndexFlatIP scores equal cosine similarity only for L2-normalised vectors;
# confirm whether embeddings.FaissIndex normalises on add/search.
places_index = embeddings.FaissIndex(
    embedding_size=model.config.projection_dim,
    faiss_index_location='faiss_indices/places.index',
    indexer=faiss.IndexFlatIP,
)
places_index.reset()  # presumably clears vectors left over from a previous run — confirm

In [24]:
# Store each place embedding together with its label so search hits map back to names.
place_vectors = places_embeddings.detach().numpy()
places_index.add(place_vectors, places)

In [25]:
# Nearest place names for the query image; the returned (scores, ids, labels) is displayed.
query_vector = out.detach().numpy()
places_index.search(query_vector)

(array([[0.17297854, 0.15687369, 0.15423924, 0.15403895, 0.15325068,
         0.14620401, 0.1395231 , 0.13754925, 0.13722014, 0.13552673]],
       dtype=float32),
 array([[ 50, 346, 201, 173, 261, 303, 208, 155,  63, 221]]),
 ['beauty salon',
  'veterinarians office',
  'kennel outdoor',
  'hayfield',
  'pet shop',
  'shower',
  'laundromat',
  'galley',
  'bow window indoor',
  'manufactured home'])

### Objects

In [6]:
# Reload the Open Images object names, one per line. Blank lines are skipped so an
# empty-string label never reaches the index.
with open("objects_open_images.txt", encoding='utf-8') as f:
    objects = [name for name in f.read().splitlines() if name.strip()]

In [7]:
# Inner-product FAISS index for the object-name embeddings. Its size comes from the CLIP
# config (projection_dim) rather than a hard-coded 768.
# NOTE(review): as with the places index, IP == cosine only for normalised vectors — confirm.
objects_index = embeddings.FaissIndex(
    embedding_size=model.config.projection_dim,
    faiss_index_location='faiss_indices/objects.index',
    indexer=faiss.IndexFlatIP,
)

In [8]:
# Reset the objects index before it is populated batch by batch below.
# Presumably this clears any vectors stored from an earlier run — confirm in embeddings.FaissIndex.
objects_index.reset()

In [9]:
# Slice boundaries for batches of 300 names: 0, 300, 600, ..., always ending at len(objects).
batches = [*range(0, len(objects), 300), len(objects)]

In [11]:
# Pair each boundary with the next one and cut `objects` into consecutive slices.
batched_objects = [objects[start:stop] for start, stop in zip(batches, batches[1:])]

In [12]:

# Embed the object names one batch at a time and add each batch, with its labels, to the index.
# - no_grad: inference only, so no autograd graph is built for every batch.
# - truncation=True: free-form class names longer than the tokenizer limit must not crash the loop.
with torch.no_grad():
    for batch in tqdm(batched_objects, desc="embedding objects"):
        inputs = processor(text=batch, return_tensors="pt", padding=True, truncation=True)
        objects_embeddings = model.get_text_features(**inputs)
        objects_index.add(objects_embeddings.detach().numpy(), batch)

  0%|          | 0/67 [00:00<?, ?it/s]

In [13]:
# Nearest object names for the query image; the returned (scores, ids, labels) is displayed.
query_vector = out.detach().numpy()
objects_index.search(query_vector)

(array([[0.23854859, 0.21685843, 0.21458338, 0.21301872, 0.21031086,
         0.20976993, 0.20975675, 0.20324187, 0.19942844, 0.19938366]],
       dtype=float32),
 array([[ 1128,  6936, 11958,  1844, 18549,  8918, 18977,  7881, 10819,
         18120]]),
 ['Chinese crested dog',
  'Dandie dinmont terrier',
  'Yorkipoo',
  'Affenpinscher',
  'Plummer terrier',
  'Biewer terrier',
  'Grooming trimmer',
  'Cesky terrier',
  'Glen of imaal terrier',
  'Shaving and grooming'])