In [None]:
# install the required libraries

!pip install transformers
!pip install tqdm

In [None]:
import json
import webcolors
import random

# preparing the ColorFoil benchmark

def prapare_data(annotation_path='/content/captions_val2017.json', verbose=True):
  """Build the ColorFoil benchmark from an MS COCO caption annotation file.

  For every image listed in the file, the first caption (in annotation order)
  that contains a CSS3 color name as a space-separated word is kept, and a
  foiled version of it is produced with create_foil. Images without such a
  caption are dropped.

  Args:
    annotation_path: path of the COCO caption JSON file; it must hold an
      "images" list (entries with "id" and "coco_url") and an "annotations"
      list (entries with "image_id" and "caption").
    verbose: when True, print the url, caption and foil of every kept image.

  Returns:
    Three parallel lists: image urls, original captions, foiled captions.
  """
  with open(annotation_path) as f: # read the data
    d = json.load(f)

  # group the captions by image once, instead of rescanning every annotation
  # for every image (annotation order is preserved inside each group)
  captions_by_image = {}
  for annotation in d["annotations"]:
    captions_by_image.setdefault(annotation["image_id"], []).append(annotation["caption"])

  img_list = [] # list of image urls
  cap_list = [] # list of captions
  foil_list = [] # list of foiled captions

  for image in d["images"]:
    for caption in captions_by_image.get(image["id"], []):
      # using the webcolor python package
      if any(word in webcolors.CSS3_NAMES_TO_HEX for word in caption.split(' ')):
        foil = create_foil(caption) # it will randomly choose a foil color
        img_list.append(image["coco_url"])
        cap_list.append(caption)
        foil_list.append(foil)
        if verbose:
          print(image["coco_url"])
          print(caption, foil)
        break # only the first color caption of each image is used

  return img_list, cap_list, foil_list



In [None]:
# replace the original color with the foil color

def create_foil(caption):
  """Return a copy of `caption` in which every color word is swapped for another color.

  The caption is split on single spaces. Each word that is a CSS3 color name
  is replaced by a color drawn uniformly from the common colors below,
  excluding the word itself, so the foil always differs from the original.
  Replacement is done word by word: a color name that merely occurs inside
  another word (e.g. "red" in "covered") is left alone, and a color inserted
  as a foil is never replaced again by a later step.

  Args:
    caption: the original caption text.

  Returns:
    The foiled caption; identical to `caption` when it has no color word.
  """
  # most commonly used colors
  colors = ["blue", "black", "red", "pink", "yellow", "grey", "orange", "white", "green", "brown"]
  foiled_words = []
  for word in caption.split(' '):
    if word in webcolors.CSS3_NAMES_TO_HEX:
      # any common color except the original one
      word = random.choice([c for c in colors if c != word])
    foiled_words.append(word)
  # joining on the same separator keeps the original spacing untouched
  return ' '.join(foiled_words)


In [None]:
# finally we have the three lists

# the foil colors are drawn at random; seeding makes the benchmark (and therefore
# the accuracies reported below) identical on every run
random.seed(42)
img_list, cap_list, foil_list = prapare_data()

# CLIP Model

In [None]:
'''zero-shot evaluation of CLIP on the ColorFoil '''

from PIL import Image
from tqdm import tqdm
import requests

from transformers import CLIPProcessor, CLIPModel
from transformers import ViltProcessor, ViltModel
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.to(device)
model.eval() # inference only

# positions in img_list of the greyscale images. we are removing them
greyscale_indices = {550, 1164, 1752, 2236}

acc = 0 # number of images whose true caption is ranked above its foil

for i in tqdm(range(len(img_list))):
  url = img_list[i]

  if i in greyscale_indices:
    continue

  response = requests.get(url, stream=True, timeout=30)
  response.raise_for_status() # report a failed download as such, not as an unreadable image
  image = Image.open(response.raw)

  inputs = processor(text=[cap_list[i], foil_list[i]], images=image, return_tensors="pt", padding=True).to(device)

  with torch.no_grad(): # evaluation only: do not record the autograd graph
    outputs = model(**inputs)
  logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
  probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities

  # column 0 is the true caption, column 1 is the foil
  if probs[0][0] > probs[0][1]:
    acc += 1


print(acc)

In [None]:
# calculate the final accuracy

evaluated = len(img_list) - 4 # subtract the four greyscale images that were skipped
print(acc)
print(evaluated)
print('The accuracy is: ', acc / evaluated * 100)

# ViLT Model

In [None]:
from transformers import ViltProcessor, ViltForImageAndTextRetrieval
import requests
from tqdm import tqdm
from PIL import Image
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco")
model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco")
model.to(device)
model.eval() # inference only

# positions in img_list of the greyscale images. we are removing them
greyscale_indices = {550, 1164, 1752, 2236}

acc = 0 # number of images whose true caption is scored above its foil

for i in tqdm(range(len(img_list))):
  url = img_list[i]
  if i in greyscale_indices:
    continue
  response = requests.get(url, stream=True, timeout=30)
  response.raise_for_status() # report a failed download as such, not as an unreadable image
  image = Image.open(response.raw)
  texts = [cap_list[i], foil_list[i]]

  # forward pass: one image-text score per text, in the same order as `texts`
  scores = []
  with torch.no_grad(): # evaluation only: do not record the autograd graph
    for text in texts:
      # prepare inputs
      encoding = processor(image, text, return_tensors="pt").to(device)
      outputs = model(**encoding)
      scores.append(outputs.logits[0, :].item())

  # scores[0] belongs to the true caption, scores[1] to the foil
  if scores[0] > scores[1]:
    acc += 1

In [None]:
# calculate the final accuracy

evaluated = len(img_list) - 4 # subtract the four greyscale images that were skipped
print(acc)
print(evaluated)
print('The accuracy is: ', acc / evaluated * 100)

# BridgeTower Model

In [None]:
from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval
import requests
from tqdm import tqdm
from PIL import Image
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
model.to(device)
model.eval() # inference only

# positions in img_list of the greyscale images. we are removing them
greyscale_indices = {550, 1164, 1752, 2236}

acc = 0 # number of images whose true caption is scored above its foil

for i in tqdm(range(len(img_list))):
  url = img_list[i]
  if i in greyscale_indices:
    continue
  response = requests.get(url, stream=True, timeout=30)
  response.raise_for_status() # report a failed download as such, not as an unreadable image
  image = Image.open(response.raw)
  texts = [foil_list[i], cap_list[i]]

  # forward pass: one score per text, in the same order as `texts`
  scores = []
  with torch.no_grad(): # evaluation only: do not record the autograd graph
    for text in texts:
      # prepare inputs
      encoding = processor(image, text, return_tensors="pt").to(device)
      outputs = model(**encoding)
      scores.append(outputs.logits[0, 1].item())

  # note the order here: scores[0] belongs to the foil, scores[1] to the true caption
  if scores[0] < scores[1]:
    acc += 1

In [None]:
# calculate the final accuracy

evaluated = len(img_list) - 4 # subtract the four greyscale images that were skipped
print(acc)
print(evaluated)
print('The accuracy is: ', acc / evaluated * 100)

In [None]:
# check if there exists any greyscale image among the benchmark images. We need to remove them due to incompatibility issues
# instead of stopping with an error at the first such image, the scan collects the positions and prints them at the end
# only run once (it downloads every image in img_list)

import numpy as np

incompatible_indices = [] # images that do not have exactly three channels (e.g. greyscale)
colorless_indices = [] # three-channel images whose channels are identical in every pixel

for i in tqdm(range(len(img_list))):
  url = img_list[i]

  response = requests.get(url, stream=True, timeout=30)
  response.raise_for_status() # report a failed download as such, not as an unreadable image
  img = Image.open(response.raw)

  ### splitting the image into its channels
  bands = img.split()
  if len(bands) != 3:
    # this is the case that used to abort the loop when unpacking into r, g, b
    incompatible_indices.append(i)
    continue

  ### PIL to numpy conversion:
  r, g, b = (np.array(band) for band in bands)

  ### number of pixels in which each pair of channels differs
  r_g = np.count_nonzero(r != g)
  r_b = np.count_nonzero(r != b)
  g_b = np.count_nonzero(g != b)

  ### sum of differences
  diff_sum = float(r_g + r_b + g_b)

  ### get image size:
  width, height = img.size

  ### get total pixels on image:
  totalPixels = width * height

  ### finding ratio of diff_sum with respect to size of image
  ratio = diff_sum / totalPixels

  # a ratio of zero means the three channels never differ, i.e. the image shows no color
  if ratio == 0:
    colorless_indices.append(i)

print('Images without three channels:', incompatible_indices)
print('Three-channel images without any color:', colorless_indices)
