<a href="https://colab.research.google.com/github/shpotes/tensorflowers/blob/clip/notebooks/CLIPTest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys

if 'google.colab' in sys.modules:
  !pip install transformers datasets -qq

In [2]:
from itertools import chain
from collections import Counter
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
from tqdm import tqdm

In [3]:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


In [4]:
ds = load_dataset("shpotes/tfcol", split="validation")

Using custom data configuration default
Reusing dataset tf_col (/root/.cache/huggingface/datasets/shpotes___tf_col)/default/1.0.0/0c616218d5e0a194334e0ed0adacd86ab9b315ec6b03a8b388dece024753def2)


In [5]:
def int2str(x):
  _int2str = ds.features["labels"].feature.int2str
  if isinstance(x, int):
    return _int2str(x)
  elif hasattr(x, '__iter__'):
    return [_int2str(i) for i in x]
  raise TypeError

def batch(iterable, n=1):
  l = len(iterable)
  num = range(l)
  for ndx in range(0, l, n):
    yield num[ndx:min(ndx + n, l)], iterable[ndx:min(ndx + n, l)]

In [6]:
images = [Image.open(img).convert("RGB") for img in ds["image"]]

prompt_seeds = [
    "clothing store",
    "liquor store",
    "barber shop",
    "electronic store",
    "coffee store",
    "furniture store",
    "fast food cart", # puesto movil 
    "ERROR OSJDFADOIAJSOIDJAMS", # electrodomesticos
    "butcher shop",
    "bar",
    "pet shop",
    "store",
    "pharmacy",
    "sport store",
    "car shop",
    "shoe shop",
    "supermarket",
    "hotel"
]

prompts = [f"an image of a {seed}" for seed in prompt_seeds]

model = model.cuda()

In [7]:
score = [0 for _ in range(len(prompts))]

In [8]:
for image_idx, image_batch in tqdm(batch(images, 8)):
  input = processor(
      text=prompts, 
      images=image_batch,
      return_tensors="pt", 
      padding=True,
  )
  
  input = {k: v.cuda() for k, v in input.items()}

  outputs = model(**input)
  logits_per_image = outputs.logits_per_image
  probs = logits_per_image.softmax(dim=1)

  for img_idx, prob_idx in zip(image_idx, range(8)):
    topk = set(torch.topk(probs[prob_idx], 5).indices.cpu().tolist())
    labels = set(ds["labels"][img_idx])


    for i in (topk & labels):
      score[i] += 1

83it [00:25,  3.20it/s]


In [9]:
final_score = [local / len(ds) for local in score]

In [10]:
class_names = ds.features["labels"].feature.names

In [11]:
dict(zip(class_names, final_score))

{'animales': 0.00303951367781155,
 'bar': 0.02127659574468085,
 'belleza/barbería/peluquería': 0.02127659574468085,
 'café/restaurante': 0.041033434650455926,
 'carnicería/fruver': 0.019756838905775075,
 'deporte': 0.0,
 'electrodomésticos': 0.0,
 'electrónica/cómputo': 0.0243161094224924,
 'farmacia': 0.004559270516717325,
 'ferretería': 0.0,
 'licorera': 0.00911854103343465,
 'muebles/tapicería': 0.0060790273556231,
 'parqueadero': 0.0060790273556231,
 'puesto móvil/toldito': 0.0182370820668693,
 'ropa': 0.057750759878419454,
 'talleres carros/motos': 0.0121580547112462,
 'tienda': 0.0060790273556231,
 'zapatería': 0.0}

In [12]:
boots = [0 for _ in range(100)]

for k in range(100):
  random_baseline = [0 for _ in range(len(prompts))]

  for labels in ds["labels"]:
    topk = set(torch.topk(torch.rand(18), 5).indices.tolist())
    labels = set(labels)

    for i in (topk & labels):
        random_baseline[i] += 1

  random_baseline = [local / len(ds) for local in random_baseline]
  boots[k] = sum([x <= y for x, y in zip(final_score, random_baseline)]) / len(prompts)

In [13]:
dict(zip(class_names,boots))

{'animales': 0.6666666666666666,
 'bar': 0.6666666666666666,
 'belleza/barbería/peluquería': 0.6666666666666666,
 'café/restaurante': 0.6666666666666666,
 'carnicería/fruver': 0.5555555555555556,
 'deporte': 0.6111111111111112,
 'electrodomésticos': 0.6666666666666666,
 'electrónica/cómputo': 0.6666666666666666,
 'farmacia': 0.6111111111111112,
 'ferretería': 0.6666666666666666,
 'hotel': 0.6111111111111112,
 'licorera': 0.6111111111111112,
 'muebles/tapicería': 0.6666666666666666,
 'parqueadero': 0.6666666666666666,
 'puesto móvil/toldito': 0.5555555555555556,
 'ropa': 0.6111111111111112,
 'supermercado': 0.6111111111111112,
 'talleres carros/motos': 0.6111111111111112,
 'tienda': 0.6111111111111112,
 'zapatería': 0.6666666666666666}

In [14]:
{int2str(k): v for k, v in Counter(chain.from_iterable(ds["labels"])).items()}

{'animales': 7,
 'bar': 15,
 'belleza/barbería/peluquería': 31,
 'café/restaurante': 165,
 'carnicería/fruver': 14,
 'deporte': 3,
 'electrodomésticos': 19,
 'electrónica/cómputo': 22,
 'farmacia': 27,
 'ferretería': 24,
 'hotel': 16,
 'licorera': 6,
 'muebles/tapicería': 26,
 'parqueadero': 64,
 'puesto móvil/toldito': 55,
 'ropa': 43,
 'supermercado': 18,
 'talleres carros/motos': 60,
 'tienda': 71,
 'zapatería': 3}