In [None]:
import ast
import os
import re
from collections import OrderedDict
from io import BytesIO

import IPython.display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import skimage
import torch
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer

<!-- We will select a sample of 300 images from this large set of 3,318,333 images. -->

In [None]:
path_names = []
for i,j,k in os.walk('nnc_images_6/'):
    for file in k:
        path_names.append(i+file)

In [None]:
# Caption/class of each image = its file name without directory or extension
# (splitext handles extensions of any length, basename handles any path separator).
classes = [os.path.splitext(os.path.basename(p))[0] for p in path_names]

In [None]:
# Base (non-negative) classes: drop the `<cls>_n1` / `<cls>_n2` negative variants.
# Only a trailing `_n<digits>` marks a variant; a plain `"_n" in path` test would also
# drop a class whose name merely contains "_n" (e.g. "bird_nest") or match directory names.
nc = [
    stem
    for stem in (os.path.splitext(os.path.basename(p))[0] for p in path_names)
    if not re.search(r"_n\d+$", stem)
]

In [None]:
# One row per image file: its path and its caption (the file-name stem).
image_data_df = pd.DataFrame(
    data=list(zip(path_names, classes)),
    columns=["image_url", "caption"],
)

In [None]:
"""
Not all the URLs are valid. This function returns True if the URL is valid. False otherwise. 
"""
# def check_valid_URLs(image_URL):

#     try:
#       response = requests.get(image_URL)
#       Image.open(BytesIO(response.content))
#       return True
#     except:
#       return False

def get_image(image_URL):

    # response = requests.get(image_URL)
    image = image = Image.open(image_URL).convert("RGB")

    return image

def get_image_caption(image_ID):
    """Return the caption of the row at position ``image_ID`` in ``image_data_df``.

    The previous body looked the caption up in ``image_data``, which is not defined
    anywhere in this notebook; captions live in ``image_data_df["caption"]``.
    """
    return image_data_df["caption"].iloc[image_ID]

In [None]:
image_data_df["image"] = image_data_df["image_url"].apply(get_image)

In [None]:
import matplotlib.pyplot as plt

In [None]:
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig

In [None]:
def get_model_info(model_ID, device):
    """Load a pretrained CLIP model together with its processor and tokenizer.

    Parameters
    ----------
    model_ID : str
        Hugging Face model identifier, e.g. ``"openai/clip-vit-large-patch14"``.
    device : str or torch.device
        Device the model is moved to.

    Returns
    -------
    tuple
        ``(model, processor, tokenizer)``.
    """
    # The pretrained checkpoint fixes the architecture (including the number of text
    # position embeddings), so no custom CLIPConfig is built: a model constructed from
    # one would be replaced by from_pretrained anyway, wasting memory and time.
    model = CLIPModel.from_pretrained(model_ID).to(device)
    processor = CLIPProcessor.from_pretrained(model_ID)
    tokenizer = CLIPTokenizer.from_pretrained(model_ID)
    return model, processor, tokenizer

In [None]:
# Hugging Face CLIP checkpoint used for the text/image embeddings below;
# run on the GPU when one is available.
model_ID = "openai/clip-vit-large-patch14"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model, processor, tokenizer = get_model_info(model_ID, device)

# Create Embeddings: Text and Image Embeddings

## Text Embeddings

In [None]:
def get_single_text_embedding(text):
    """Embed ``text`` with the global CLIP ``model``/``tokenizer``; return a NumPy array.

    The pretrained text encoder has a fixed number of position embeddings
    (``tokenizer.model_max_length``), so longer inputs are truncated to that length.
    The old ``max_length=128`` could exceed it and fail on long text.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=tokenizer.model_max_length,
    ).to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        text_embeddings = model.get_text_features(**inputs)
    return text_embeddings.cpu().numpy()

In [None]:
def get_all_text_embeddings(df, text_col):
    """Embed every entry of ``df[text_col]`` into a new ``text_embeddings`` column.

    ``df`` is modified in place and also returned for convenience.
    """
    column_name = str(text_col)
    embeddings = df[column_name].apply(get_single_text_embedding)
    df["text_embeddings"] = embeddings
    return df

In [None]:
image_data_df = get_all_text_embeddings(df=image_data_df, text_col="caption")

In [None]:
image_data_df.head(n=5)

## Image Embeddings

In [None]:
def get_single_image_embedding(my_image):
    """Embed one image with the global CLIP ``model`` and return a NumPy array.

    The image is turned into ``pixel_values`` by the global CLIP ``processor`` and
    moved to ``device`` before encoding.
    """
    pixel_values = processor(
        text=None,
        images=my_image,
        return_tensors="pt",
    )["pixel_values"].to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        embedding = model.get_image_features(pixel_values=pixel_values)
    return embedding.cpu().numpy()

### Get the embedding of all the images

In [None]:
def get_all_images_embedding(df, img_column):
    """Embed every image in ``df[img_column]`` into a new ``img_embeddings`` column.

    ``df`` is modified in place and also returned for convenience.
    """
    source = df[str(img_column)]
    df["img_embeddings"] = source.apply(get_single_image_embedding)
    return df

In [None]:
image_data_df = get_all_images_embedding(df=image_data_df, img_column="image")

In [None]:
image_data_df.head(n=5)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_images(images):
    """Show each image of ``images`` in its own matplotlib figure, one after another."""
    for img in images:
        plt.imshow(img)
        plt.show()

def plot_images_by_side(top_images):
    """Show retrieved images in a grid, titled with caption and similarity score.

    Parameters
    ----------
    top_images : pandas.DataFrame
        Rows to display, with ``image``, ``caption`` and ``cos_sim`` columns (as
        returned by ``get_top_N_images``). Rows are read positionally, so any index
        works.

    Up to four images use the original 2x2 layout; more images get a larger,
    roughly square grid instead of being silently dropped. Unused cells are hidden.
    """
    n_images = len(top_images)
    if n_images == 0:
        return

    n_col = max(2, int(np.ceil(np.sqrt(n_images))))
    n_row = max(2, int(np.ceil(n_images / n_col)))

    _, axs = plt.subplots(n_row, n_col, figsize=(6 * n_col, 6 * n_row))
    axs = np.asarray(axs).ravel()
    rows = top_images[["image", "caption", "cos_sim"]].itertuples(index=False, name=None)
    for ax, (img, caption, sim_score) in zip(axs, rows):
        ax.imshow(img)
        # Format as a percentage directly; 100*float("0.29") prints 28.999999999999996.
        ax.title.set_text(f"Caption: {caption}\nSimilarity: {float(sim_score):.0%}")
    for ax in axs[n_images:]:
        ax.axis("off")
    plt.show()

# Perform Similarity Search: Cosine 

## 1. Cosine Similarity Search

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
def _cosine_similarity_to_rows(query_row, matrix):
    """Cosine similarity between 1-D ``query_row`` and every row of 2-D ``matrix``.

    Zero vectors are left un-normalised so they score 0 (sklearn's convention).
    """
    query_norm = np.linalg.norm(query_row)
    query_unit = query_row / (query_norm if query_norm else 1.0)
    row_norms = np.linalg.norm(matrix, axis=1)
    row_norms[row_norms == 0] = 1.0
    return (matrix / row_norms[:, None]) @ query_unit


def get_top_N_images(query, data, top_K=4, search_criterion="text",
                     description_encodings=None, exclude_self=None):
    """Return the ``top_K`` rows of ``data`` whose image embedding best matches the query.

    Parameters
    ----------
    query : str or image
        Text query when ``search_criterion`` is ``"text"`` (case-insensitive),
        otherwise an image query. Ignored for text search when
        ``description_encodings`` is given.
    data : pandas.DataFrame
        Must have ``img_embeddings``, ``caption`` and ``image`` columns. A
        ``cos_sim`` column with each row's similarity is written into it in place.
    top_K : int
        Number of rows to return (default 4).
    search_criterion : str
        ``"text"`` for text-to-image search; any other value means image-to-image.
    description_encodings : array-like or torch.Tensor, optional
        Pre-computed query encodings used instead of embedding ``query``. Only the
        first row is compared.
    exclude_self : bool, optional
        Drop the single best match before taking ``top_K``. Defaults to True for
        image queries (the query image is expected to be in ``data`` and scores 1)
        and False for text queries, whose best match is a genuine result.

    Returns
    -------
    pandas.DataFrame
        Columns ``index`` (original row label), ``caption``, ``image``, ``cos_sim``,
        sorted by decreasing similarity.
    """
    is_text_search = search_criterion.lower() == "text"
    if is_text_search:
        if description_encodings is not None:
            query_vect = description_encodings
        else:
            query_vect = get_single_text_embedding(query)
    else:
        query_vect = get_single_image_embedding(query)

    if exclude_self is None:
        exclude_self = not is_text_search

    # Torch tensors (e.g. CLIP encodings on the GPU) must be brought to host memory first.
    if hasattr(query_vect, "detach"):
        query_vect = query_vect.detach().float().cpu().numpy()
    query_row = np.atleast_2d(np.asarray(query_vect, dtype=np.float64))[0]

    embeddings = list(data["img_embeddings"])
    if embeddings:
        # Stack the per-row embeddings once and score all rows in a single step.
        matrix = np.vstack(embeddings).astype(np.float64)
        data["cos_sim"] = _cosine_similarity_to_rows(query_row, matrix)
    else:
        data["cos_sim"] = np.empty(0)

    ranked = data.sort_values(by="cos_sim", ascending=False, kind="stable")
    start = 1 if exclude_self else 0
    most_similar = ranked.iloc[start:start + top_K]

    relevant_cols = ["caption", "image", "cos_sim"]
    return most_similar[relevant_cols].reset_index()

In [None]:
image_data_df["img_embeddings"].iloc[0].shape

In [None]:
image_data_df.keys()

### a. Text to image search

In [None]:
# nc

In [None]:
# Silence pandas' chained-assignment (SettingWithCopy) warnings globally; presumably for
# get_top_N_images writing `cos_sim` into the frame it is given. NOTE(review): this also
# hides genuine chained-assignment bugs elsewhere.
pd.set_option("mode.chained_assignment", None)

In [None]:
# Read the raw lines of the descriptions file; the `with` block closes the handle.
with open('classify_by_description_release/noun_compounds_all.json', 'r') as descriptions_file:
    descriptions = descriptions_file.readlines()

In [None]:
# Parse the first line as a literal dict. ast.literal_eval accepts only Python literals
# (dicts, lists, strings, numbers) instead of executing arbitrary code like eval.
descriptions = ast.literal_eval(descriptions[0])

In [None]:
# Read the raw lines of the captions file; the `with` block closes the handle.
with open('classify_by_description_release/noun_compounds_all_captions.json', 'r') as captions_file:
    captions = captions_file.readlines()

In [None]:
# Parse the first line as a literal dict. ast.literal_eval accepts only Python literals
# (dicts, lists, strings, numbers) instead of executing arbitrary code like eval.
captions = ast.literal_eval(captions[0])

In [None]:
len(captions)

In [None]:
import clip
from torch.nn import functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
hparams = dict(model_size="ViT-L/14")

In [None]:
# Load the `clip` package's model named in hparams, switch it to evaluation mode and
# freeze all of its parameters.
clip_model, clip_preprocess = clip.load(hparams['model_size'], device=device, jit=False)
clip_model.eval().requires_grad_(False)

In [None]:
def compute_description_encodings(model, gpt_descriptions=None):
    """Encode each class's list of text prompts with a CLIP text encoder.

    Parameters
    ----------
    model :
        Model exposing ``encode_text`` (e.g. the model returned by ``clip.load``).
    gpt_descriptions : dict[str, list[str]] or None
        Maps a class name to its prompts. ``None`` (the default) is treated as
        "no classes" instead of crashing on ``None.items()``.

    Returns
    -------
    OrderedDict
        Same keys in insertion order; each value is the encoded prompts passed
        through ``F.normalize`` (L2 along dim 1).
    """
    description_encodings = OrderedDict()
    if not gpt_descriptions:
        return description_encodings
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for cls_name, prompts in gpt_descriptions.items():
            # truncate=True clips over-long prompts instead of raising.
            tokens = clip.tokenize(prompts, truncate=True).to(device)
            description_encodings[cls_name] = F.normalize(model.encode_text(tokens))
    return description_encodings

In [None]:
# Prompts for "our approach": one prompt per GPT caption of the class, each followed by
# the class's GPT-generated features.
gpt_descriptions = {}
for cls in nc:
    features = ", ".join(descriptions[cls])
    gpt_descriptions[cls] = [
        f"An example of {cls} in an image is: {example_caption}. Some of its features are: {features}."
        for example_caption in captions[cls]
    ]


In [None]:
description_encodings = compute_description_encodings(model=clip_model, gpt_descriptions=gpt_descriptions)

In [None]:
# Baseline CLIP prompt: a single "a photo of a <class>" per class.
clip_desc = {cls: [f"a photo of a {cls}"] for cls in nc}

In [None]:
# NOTE(review): `rev` is not defined anywhere in this notebook as shown; it must be
# created by an earlier (possibly deleted) cell before this one runs.
# Replace, in place, every entry of `rev` whose type is exactly `list` with its
# items joined by spaces.
for position, item in enumerate(rev):
    if type(item) is list:
        rev[position] = " ".join(item)

In [None]:
# Baseline prompt built from `rev` (paired positionally with `nc`) instead of the class name.
clip_rev_desc = {cls: [f"a photo of a {r}"] for cls, r in zip(nc, rev)}

In [None]:
clip_description_encodings = compute_description_encodings(model=clip_model, gpt_descriptions=clip_desc)

In [None]:
# Prompt built only from each class's GPT-generated features (no captions).
# The unused `examples = []` of the previous version is dropped.
features_description = {
    cls: [f"the image of a {', '.join(descriptions[cls])}."]
    for cls in nc
}

In [None]:
# gpt_descriptions

In [None]:
# Same prompts as `description_encodings` above, encoded again under a separate name.
descriptor_encodings = compute_description_encodings(model=clip_model, gpt_descriptions=gpt_descriptions)

In [None]:
features_description_encodings = compute_description_encodings(model=clip_model, gpt_descriptions=features_description)

In [None]:
from torchvision import transforms
# CLIP image encoders were trained on inputs normalised with these per-channel mean/std
# values; ToTensor alone yields raw [0, 1] pixels, which skews the image embeddings.
# NOTE(review): `clip_preprocess` from `clip.load` (resize + center-crop + normalise) is
# the reference pipeline; the square Resize is kept to preserve the original framing.
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(CLIP_MEAN, CLIP_STD),
])

In [None]:
# Evaluation: for each class, score the original image and its two negative variants
# (`<cls>_n1`, `<cls>_n2`) against the class text encodings; a class counts as correct
# when the original image gets the highest mean similarity.
# NOTE(review): `both` (the classes to evaluate) is not defined anywhere in this notebook
# as shown; it must be created before this cell runs.
count = 0         # classes where a negative variant out-scored the original image
f = 0             # classes skipped because lookup/encoding raised
total_tested = 0  # classes actually scored
wrong = []        # caption of the winning negative variant, one per miss
scores = []       # winning similarity of each correctly ranked class
failures = []     # (class, error repr) for every skipped class, for inspection

for cls in tqdm(both):
    try:
        candidate_captions = [cls, cls + '_n1', cls + '_n2']
        labels = [image_data_df.index[image_data_df['caption'] == c][0] for c in candidate_captions]
        # .loc: the values above are index labels, not positions.
        tmp_df = image_data_df.loc[labels]
        batched_images = torch.stack([transform(img).to(device) for img in tmp_df['image']], dim=0)
        with torch.no_grad():
            batched_images = F.normalize(clip_model.encode_image(batched_images))

        # Encodings to compare (swap the right-hand side below):
        #   clip_description_encodings[cls] -> CLIP approach ("a photo of a <cls>")
        #   description_encodings[cls]      -> our approach (captions + GPT features)
        #   descriptor_encodings[cls]       -> same prompts as description_encodings
        sim = batched_images @ clip_description_encodings[cls].T

        # Mean similarity of each image over all of the class's prompts.
        sim = [torch.mean(tens).cpu().detach().numpy() for tens in sim]
        best = int(np.argmax(sim))
        best_caption = tmp_df['caption'].iloc[best]
        if cls != best_caption:
            wrong.append(best_caption)
            print(cls, "||", best_caption)
            count += 1
        else:
            scores.append(sim[best])
        total_tested += 1
    except Exception as e:
        # Best-effort: a class with a missing variant or unreadable image is skipped,
        # but the reason is recorded instead of being silently discarded.
        f += 1
        failures.append((cls, repr(e)))

# Guard against ZeroDivisionError when every class was skipped.
accuracy = (total_tested - count) / total_tested if total_tested else float('nan')

print(accuracy)