# Using CLIP for image retrieval

This notebook shows how to use CLIP for image retrieval, querying with both text and images.

First we will need to install the package `open_clip_torch`:

In [None]:
!pip install open_clip_torch

Now let's do some imports:

In [None]:
import os
import torch
import open_clip
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from open_clip import tokenizer

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

### Loading the model

Let's load the CLIP model that we will use to encode text and images into a shared embedding space. We will be using the `convnext_base_w` model trained on the `laion2b` dataset.

To see all the model and pretrained-weight combinations available in open_clip, you can print `open_clip.list_pretrained()`.

In [None]:
model, _, preprocess = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k_augreg')

Let's inspect our model. This shows the number of parameters our CLIP model has, the number of tokens it can process in its context window, and the number of tokens in its vocabulary.

In [None]:
model.eval()
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Context length:", context_length)
print("Vocab size:", vocab_size)

### Image Preprocessing

We resize the input images and center-crop them to the image resolution that the model expects. After that, the pixel intensities are normalized using the dataset mean and standard deviation.

The third return value from `open_clip.create_model_and_transforms()` (which we stored as `preprocess`) is a torchvision `Compose` transform that performs this preprocessing.

In [None]:
preprocess
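To make these steps concrete, here is a minimal NumPy/PIL sketch of the same pipeline. It is not open_clip's implementation: the function name `clip_style_preprocess` is hypothetical, the 256-pixel size assumes this `convnext_base_w` checkpoint's input resolution, and the mean/std constants are the OpenAI CLIP values that open_clip uses by default. Check the printed `preprocess` above for the exact size, interpolation and constants.

```python
import numpy as np
from PIL import Image

# Assumed constants: the OpenAI CLIP normalization mean/std, which open_clip uses by default.
# Check the printed `preprocess` for the exact values this checkpoint uses.
MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def clip_style_preprocess(image, size=256):
    """Resize the shorter side to `size`, center-crop to size x size, scale to [0, 1], normalize.

    Returns a float32 array of shape (3, size, size), i.e. channels first like ToTensor().
    """
    image = image.convert("RGB")
    w, h = image.size
    scale = size / min(w, h)
    new_w, new_h = max(size, round(w * scale)), max(size, round(h * scale))
    image = image.resize((new_w, new_h), Image.BICUBIC)

    # Center crop: cut equal margins from both sides of the longer dimension
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    image = image.crop((left, top, left + size, top + size))

    pixels = np.asarray(image, dtype=np.float32) / 255.0  # (H, W, 3) in [0, 1]
    pixels = (pixels - MEAN) / STD  # per-channel normalization, broadcast over H and W
    return pixels.transpose(2, 0, 1)  # (3, H, W)


example = Image.new("RGB", (640, 480), color=(128, 64, 32))
print(clip_style_preprocess(example).shape)  # (3, 256, 256)
```

The sketch follows the same order of operations as `preprocess` but may differ in small details such as resampling, so use the real transform when computing embeddings.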

### Text Preprocessing

We use a case-insensitive tokenizer, invoked with `tokenizer.tokenize()`. By default, the outputs are padded to 77 tokens, which is the context length the CLIP models expect.

In [None]:
tokenizer.tokenize("Hello World!")
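Conceptually, `tokenize` wraps the BPE token ids of each string in start-of-text and end-of-text markers, truncates to the context length, and zero-pads the rest. The sketch below reproduces only that wrapping and padding step, using a hypothetical `pad_to_context` helper and made-up BPE ids; 49406 and 49407 are the start/end-of-text ids in CLIP's 49,408-token vocabulary.

```python
import numpy as np

SOT_TOKEN = 49406  # <start_of_text> id in CLIP's BPE vocabulary
EOT_TOKEN = 49407  # <end_of_text> id
CONTEXT_LENGTH = 77  # should match model.context_length printed earlier


def pad_to_context(bpe_ids, context_length=CONTEXT_LENGTH):
    """Wrap BPE ids in SOT/EOT markers, truncate to context_length and zero-pad."""
    tokens = [SOT_TOKEN] + list(bpe_ids) + [EOT_TOKEN]
    if len(tokens) > context_length:
        # Keep the first context_length tokens, but make sure the sequence still ends with EOT
        tokens = tokens[:context_length]
        tokens[-1] = EOT_TOKEN
    row = np.zeros(context_length, dtype=np.int64)
    row[: len(tokens)] = tokens
    return row


# Hypothetical BPE ids standing in for the encoded words of a short caption
row = pad_to_context([1001, 1002, 1003])
print(row[:6].tolist())  # [49406, 1001, 1002, 1003, 49407, 0]
print(row.shape)  # (77,)
```

Because the end-of-text id is the largest id in the vocabulary, `argmax` over a padded row finds the EOT position; CLIP's text encoder uses that position's features as the sentence embedding.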

### Setting up input images and texts

We are going to feed 8 example images and their textual descriptions to the model, and compare the similarity between the corresponding features.

In [None]:
# skimage sample images to use (file names without extension) and their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse",
    "coffee": "a cup of coffee on a saucer"
}

Let's load the images and display them with their corresponding descriptions:

In [None]:
original_images = []
images = []
texts = []
filenames = []
plt.figure(figsize=(16, 5))

im_folder_path = '../media/demo_images/'

for filename in [filename for filename in os.listdir(im_folder_path) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    image = Image.open(os.path.join(im_folder_path, filename)).convert("RGB")

    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))
    texts.append(descriptions[name])
    filenames.append(name)

plt.tight_layout()


### Building features

We stack the preprocessed images into a single batch tensor, tokenize each text description (prefixed with "This is "), and run the forward pass of the model to get the image and text features.

In [None]:
image_input = torch.stack(images)  # (N, 3, H, W) batch of preprocessed images
print(image_input.shape)
text_tokens = tokenizer.tokenize(["This is " + desc for desc in texts])

In [None]:
with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()

### Calculating cosine similarity

Here we normalize the image and text embeddings so that every vector has unit length. We then multiply the two matrices with [the @ operator](https://www.logilax.com/numpy-at-operator/). Each entry of the result is the dot product of one text embedding (a row of the first matrix) with one image embedding (a column of the transposed second matrix). Because both vectors have unit length, that dot product equals their [cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity), so the result holds the cosine similarity between every text and every image.

In [None]:
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
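A quick numerical check of this equivalence, using random vectors as stand-ins for CLIP embeddings (the `toy_` names and the `cosine` helper are illustrative, and chosen so they do not overwrite the notebook's `similarity`): after normalizing the rows, each entry of the matrix product matches the textbook formula cos(a, b) = a·b / (‖a‖ ‖b‖).

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-ins: 3 "text" and 4 "image" embeddings of dimension 8
toy_text = rng.normal(size=(3, 8))
toy_images = rng.normal(size=(4, 8))

# Normalize each row so every embedding has unit length
toy_text_unit = toy_text / np.linalg.norm(toy_text, axis=-1, keepdims=True)
toy_images_unit = toy_images / np.linalg.norm(toy_images, axis=-1, keepdims=True)

# toy_similarity[i, j] is the dot product of unit text vector i and unit image vector j
toy_similarity = toy_text_unit @ toy_images_unit.T


def cosine(a, b):
    """Textbook cosine similarity between two (unnormalized) vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


print(toy_similarity.shape)  # (3, 4)
print(np.isclose(toy_similarity[1, 2], cosine(toy_text[1], toy_images[2])))  # True
```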

### Plot the similarity matrix

Here we will plot the similarity matrix to see how similar our text and image embeddings are to each other:

In [None]:
count = len(descriptions)

plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)

for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size=20)
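Because the texts and images are stored in the same order, the correct pairs sit on the diagonal of the matrix. Beyond eyeballing the plot, a hypothetical helper such as `top1_matches` below checks, for each description, whether its highest-scoring image is its own. The 3x3 matrix here is made up for illustration; in the notebook you would pass `similarity`.

```python
import numpy as np


def top1_matches(sim):
    """For each row (text), return the best-scoring column (image) and whether it is the diagonal one."""
    best = np.argmax(sim, axis=1)
    return best, best == np.arange(sim.shape[0])


# Made-up similarity matrix: rows are texts, columns are images in the same order
toy_sim = np.array([
    [0.30, 0.12, 0.10],
    [0.15, 0.11, 0.28],
    [0.09, 0.14, 0.27],
])
best, correct = top1_matches(toy_sim)
print(best.tolist())  # [0, 2, 2]
print(correct.tolist())  # [True, False, True]
print(f"top-1 accuracy: {correct.mean():.2f}")  # top-1 accuracy: 0.67
```

With the notebook's data, `[filenames[i] for i in np.flatnonzero(~correct)]` would list the images whose description was matched to a different picture.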

### Make new query with text

Here we encode a new text string and compare it to our array of image features. We then use these similarities to find the image that most closely matches our text prompt.

Try changing the text prompt to see how the match changes:

In [None]:
query_text = "This is a picture of a bicycle"
query_tokens = tokenizer.tokenize([query_text])

with torch.no_grad():
    query_features = model.encode_text(query_tokens).float()

query_features /= query_features.norm(dim=-1, keepdim=True)
q_similarity = query_features.cpu().numpy() @ image_features.cpu().numpy().T

max_index = np.argmax(q_similarity)
print(f'The closest match for \"{query_text}\" is the image: {filenames[max_index]}')
original_images[max_index]
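`np.argmax` returns only the single best image. For a ranked list, sort the scores in descending order and keep the first `k`. The `top_k` helper below is a hypothetical sketch that expects unit-normalized NumPy arrays, such as `query_features.cpu().numpy()` and `image_features.cpu().numpy()` from the cell above; the toy 2-D vectors only make the ranking easy to verify by hand.

```python
import numpy as np


def top_k(query_unit, gallery_unit, k=3):
    """Rank gallery embeddings by cosine similarity to a single query.

    query_unit: (1, d) or (d,) unit-normalized query embedding
    gallery_unit: (n, d) unit-normalized gallery embeddings
    Returns (indices, scores) of the k best matches, best first.
    """
    scores = gallery_unit @ np.ravel(query_unit)  # (n,) cosine similarities
    k = min(k, scores.shape[0])
    order = np.argsort(-scores)[:k]  # indices sorted by descending score
    return order, scores[order]


# Toy gallery of unit vectors and a query pointing mostly along the x-axis
toy_gallery = np.array([[1.0, 0.0], [0.0, 1.0], [np.sqrt(0.5), np.sqrt(0.5)]])
toy_query = np.array([0.96, 0.28])  # unit length: 0.96**2 + 0.28**2 = 1
idx, sc = top_k(toy_query, toy_gallery, k=2)
print(idx.tolist())  # [0, 2]
```

In the notebook, `[filenames[i] for i in idx]` turns the indices into image names.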

### Make new query with an image

As well as performing information retrieval with text, we can also do it with images. Here let's load a sketch image and see what we can retrieve. This is exactly the method used in the [Sketchy Collections](https://ualshowcase.arts.ac.uk/project/316244/cover) project by DSAI alum Polo Sologub.

In [None]:
sketch_image = Image.open('../media/bike-sketch.jpg').convert("RGB")
sketch_im_np = preprocess(sketch_image).unsqueeze(0)
sketch_image

In [None]:
with torch.no_grad():
    sketch_features = model.encode_image(sketch_im_np).float()

sketch_features /= sketch_features.norm(dim=-1, keepdim=True)
s_similarity = sketch_features.cpu().numpy() @ image_features.cpu().numpy().T

max_index = np.argmax(s_similarity)  # use the sketch similarities, not the text-query ones
print(f'The closest match for the sketch query is the image: {filenames[max_index]}')
original_images[max_index]

## Tasks
**Task 1:** Run all of the code cells in this notebook and spend time reading and understanding the code.

**Task 2:** Try using different [text queries](#make-new-query-with-text) and [image queries](#make-new-query-with-an-image) to see what results you get.

**Task 3:** Based on the code you have seen here, can you build a simple information retrieval system with a larger dataset of images, and use either text or images to query it? You will need to:
 - Load in a selection of images 
 - Get embeddings for all the images with CLIP (you may need to do this in batches depending on the size of the dataset)
 - Get the embedding for the query text or image
 - Calculate the cosine similarity between the query and the dataset
 - Display the result based on the query
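As a starting point for the batching step, here is a hypothetical, framework-agnostic helper: `embed_in_batches` calls any function that maps a list of items to a 2-D array of embeddings on fixed-size chunks, then stacks and normalizes the results. The `dummy_encode` function exists only so the sketch runs on its own; the commented-out `clip_encode` shows one assumed way to wrap the notebook's `model` and `preprocess`.

```python
import numpy as np


def embed_in_batches(items, encode_fn, batch_size=32):
    """Encode `items` in chunks of `batch_size` and return unit-normalized embeddings.

    encode_fn: callable taking a list of items and returning an array of shape (len(batch), d)
    """
    if len(items) == 0:
        raise ValueError("items is empty")
    chunks = []
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        chunks.append(np.asarray(encode_fn(batch), dtype=np.float32))
    embeddings = np.concatenate(chunks, axis=0)
    # Normalize once at the end so later dot products are cosine similarities
    return embeddings / np.linalg.norm(embeddings, axis=-1, keepdims=True)


def dummy_encode(batch):
    """Stand-in encoder: maps each number x to the 2-D vector (x, 1)."""
    return [[float(x), 1.0] for x in batch]


# With CLIP (assumes `model`, `preprocess` and `torch` from the cells above):
# def clip_encode(pil_images):
#     with torch.no_grad():
#         batch = torch.stack([preprocess(im) for im in pil_images])
#         return model.encode_image(batch).float().cpu().numpy()

dataset = list(range(10))
gallery_embeddings = embed_in_batches(dataset, dummy_encode, batch_size=4)
print(gallery_embeddings.shape)  # (10, 2)
```

Keeping the gallery embeddings as one normalized array means a query only needs a single matrix product against it, as in the cosine similarity cell above.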

You can use one of the datasets you already have, such as the dataset you made from week 3, or you can download one from [kaggle](https://www.kaggle.com/) or download this [sample of the Metropolitan museum collection](http://ptak.felk.cvut.cz/met/dataset/test_met.tar.gz).

### Bonus exercises

Here are a few other ways you could extend and adapt CLIP for your projects:

**A:** Instead of searching for images with text can you search for documents of text with images? You could use CLIP to process either [the limerick](https://git.arts.ac.uk/tbroad/limerick-dataset) or [haiku](https://git.arts.ac.uk/tbroad/haiku-dataset) datasets. 

**B:** Can you use CLIP to train a CPPN (from week 2) or guide the generation of a GAN (from week 6) and use it to search the latent space for a specific text prompt?

**C:** Can you build an interactive application in Dorothy that uses CLIP to do image retrieval in real time based on the webcam image? You could hold up drawings to the webcam to make your own prototype of the [sketchy collections project](https://ualshowcase.arts.ac.uk/project/316244/cover).