In [None]:
%%capture
!pip install -q transformers >/dev/null

import os
import time
import urllib.request
import pandas as pd
import numpy as np
from multiprocessing.dummy import Pool
from transformers import CLIPProcessor, CLIPModel
from PIL import Image, ImageFile
# Let Pillow decode image files whose data ends early instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Pretrained CLIP ViT-B/32 checkpoint: `processor` turns images into model inputs and
# `model` produces the image features (both are used by compute_image_embeddings).
processor, model = (
  CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32"),
  CLIPModel.from_pretrained("openai/clip-vit-base-patch32"),
)

def compute_image_embeddings(list_of_images):
  """Embed a batch of images with the module-level CLIP ``model``.

  Args:
    list_of_images: images accepted by ``processor`` (this notebook passes the
      RGB PIL images produced by ``load_image``).

  Returns:
    The tensor returned by ``model.get_image_features``; the caller treats each
    row as the embedding of one input image.
  """
  import torch  # return_tensors="pt": the model runs on PyTorch

  inputs = processor(images=list_of_images, return_tensors="pt", padding=True)
  # Inference only: without no_grad, autograd records the forward pass and keeps
  # its intermediate activations alive until the returned tensor is released.
  with torch.no_grad():
    return model.get_image_features(**inputs)

def load_image(path, same_height=False, size=224):
  """Open an image as RGB and rescale it while preserving its aspect ratio.

  Args:
    path: anything ``PIL.Image.open`` accepts (a file path in this notebook).
    same_height: if True, scale so the height becomes ``size``; otherwise scale
      so the shorter side becomes ``size``.
    size: target length in pixels of the reference side (default 224).

  Returns:
    A new RGB ``PIL.Image.Image``. The underlying file is closed before
    returning, so the caller may delete it immediately.
  """
  with Image.open(path) as im:
    width, height = im.size
    reference = height if same_height else min(width, height)
    ratio = size / reference
    # round() rather than int(): size / ref * ref can come out as e.g. 223.999...,
    # which int() would truncate to one pixel short. max(1, ...) keeps extreme
    # aspect ratios from producing a zero-length axis, which resize() rejects.
    new_size = (max(1, round(width * ratio)), max(1, round(height * ratio)))
    rgb = im if im.mode == 'RGB' else im.convert('RGB')
    # resize() loads the pixel data and returns an independent image, so it is
    # safe to return it after the with-block closes the file.
    return rgb.resize(new_size)

def fetch_url(url_filename, timeout=60):
  """Download one URL into a local file.

  Args:
    url_filename: ``(url, filename)`` pair, taken as a single argument so the
      function can be handed directly to ``Pool.map``.
    timeout: seconds to wait on blocking socket operations (connect / read)
      before giving up; forwarded to ``urllib.request.urlopen``.

  Raises:
    urllib.error.ContentTooShortError: if fewer bytes arrived than the
      response's Content-Length header announced.
    Errors raised by ``urlopen`` (e.g. ``urllib.error.URLError``, timeouts)
    propagate unchanged; in that case no output file is created.
  """
  import shutil
  import urllib.error

  url, filename = url_filename
  # urlretrieve accepts no timeout, so a single stalled server could block the
  # whole Pool.map call indefinitely; urlopen lets us bound each socket wait.
  with urllib.request.urlopen(url, timeout=timeout) as response:
    expected = response.headers.get('Content-Length')
    with open(filename, 'wb') as out:
      shutil.copyfileobj(response, out)
      received = out.tell()
  # Same truncation check urlretrieve performs: a short file could otherwise be
  # decoded (truncated images are tolerated here) and embedded without any error.
  if expected is not None and received < int(expected):
    raise urllib.error.ContentTooShortError(
      f'retrieval incomplete: got only {received} out of {expected} bytes',
      (filename, response.headers),
    )

In [None]:
import tempfile

max_n_parallel = 20 # images downloaded concurrently per batch
latency = 2 # idle duration (seconds) between batches to reduce the download rate for the images
checkpoint_every = 100 # write embeddings{dataset}.npy after at least this many new rows

for dataset in ['', '2']:
  df = pd.read_csv(f'data{dataset}.csv')
  length = len(df)

  # Resume from the last checkpoint: row k of the saved array is the embedding
  # of df row k, so the number of saved rows is where processing restarts.
  try:
    image_embeddings = np.load(f"embeddings{dataset}.npy")
    i = image_embeddings.shape[0]
  except FileNotFoundError:
    image_embeddings, i = None, 0

  # Batches are collected in a list and only concatenated when checkpointing;
  # np.vstack on every batch would recopy the whole array each time.
  pending = [] if image_embeddings is None else [image_embeddings]
  saved_at = i

  # One thread pool per dataset; the with-block terminates its worker threads
  # (a pool created per batch and never closed leaks its threads).
  with Pool(max_n_parallel) as pool:
    while i < length:
      n_parallel = min(max_n_parallel, length - i)
      # A fresh temporary directory per batch: a stale image left by an earlier
      # run can never be picked up, and no other file in the working directory
      # is touched.
      with tempfile.TemporaryDirectory() as tmp_dir:
        filenames = [os.path.join(tmp_dir, f'{k}.jpeg') for k in range(i, i + n_parallel)]
        urls = df['path'].iloc[i:i + n_parallel]
        # A failed download raises here and stops the run. Rows must not be
        # skipped, or the embeddings would no longer line up with df; rerunning
        # resumes from the last checkpoint.
        pool.map(fetch_url, list(zip(urls, filenames)))
        batch_embeddings = compute_image_embeddings([load_image(f) for f in filenames]).detach().numpy()

      pending.append(batch_embeddings)
      i += batch_embeddings.shape[0]

      # Checkpoint periodically AND after the final batch, so a dataset whose
      # length is not a multiple of checkpoint_every is still saved completely.
      if i - saved_at >= checkpoint_every or i >= length:
        image_embeddings = np.concatenate(pending)
        pending = [image_embeddings]
        np.save(f"embeddings{dataset}.npy", image_embeddings)
        saved_at = i
        print(i)

      if i < length:
        time.sleep(latency)