## Introduction

In January 2021, **OpenAI** announced two new models, **DALL-E** and **CLIP**, both **multimodal** models connecting **text and images**. In this tutorial we are going to implement the CLIP model from scratch in **PyTorch**.

### What does CLIP do? Why is it fun?

In the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020), OpenAI introduces a new model called **CLIP**, for **Contrastive Language-Image Pre-training**. In a nutshell, this model learns the relationship between a whole sentence and the image it describes: once trained, given an input sentence it can retrieve the images most related to that sentence (and vice versa). The important thing here is that it is trained on full sentences instead of single class labels like car, dog, etc. The intuition is that full sentences carry much richer supervision, so the model can pick up more of the patterns linking images and text.
They also show that when this model is trained on a huge dataset of images and their corresponding texts, it can act as a classifier too. I encourage you to study the paper to learn more about this exciting model and its results on benchmark datasets. To mention just one: according to the paper, zero-shot CLIP matches the ImageNet accuracy of the original ResNet-50 without using any of the ImageNet training examples that model was trained on!

As a **teaser** (!), let's see what the final model that we will build in this article from scratch is capable of: given a query (raw text) like "a boy jumping with skateboard" or "a girl jumping from swing", the model will retrieve the most relevant images:

![](https://i.ibb.co/9gdYqNP/teaser-cropped.png)

In [None]:
# !pip install timm
# !pip install transformers

In [2]:
import os
import json
from pathlib import Path
import random

import cv2
from PIL import Image
import numpy as np
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import timm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer



In [None]:
# !wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
# !unzip -qq Flickr8k_Dataset.zip
# !wget https://github.com/Delphboy/karpathy-splits/raw/refs/heads/main/dataset_flickr8k.json


In [5]:
import os
import requests
import zipfile

data_dir = "D:/data"
os.makedirs(data_dir, exist_ok=True)

dataset_url = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
dataset_zip = os.path.join(data_dir, "Flickr8k_Dataset.zip")

json_url = "https://github.com/Delphboy/karpathy-splits/raw/refs/heads/main/dataset_flickr8k.json"
json_path = os.path.join(data_dir, "dataset_flickr8k.json")

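def download_file(url, save_path):
    # Stream the response to disk in chunks. Fail loudly on HTTP errors
    # instead of silently skipping the download.
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(save_path, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)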

download_file(dataset_url, dataset_zip)

with zipfile.ZipFile(dataset_zip, "r") as zip_ref:
    zip_ref.extractall(data_dir)

download_file(json_url, json_path)

print("Done.")



## Some pre-processing

In [None]:
image_path = "/content/Flicker8k_Dataset/"
captions_path = "/content/dataset_flickr8k.json"

## Config

*A note on config and CFG: I wrote the code as Python scripts and then converted it into a Jupyter Notebook. In the Python scripts, config is a normal Python file where I put all the hyperparameters; in the Jupyter Notebook, it's a class defined at the beginning of the notebook that holds all the hyperparameters.*

In [None]:
class CFG:
    image_path = image_path
    captions_path = captions_path
    batch_size = 32
    num_workers = 2

    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    epochs = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = 'resnet50'
    image_embedding = 2048
    text_encoder_model = "distilbert-base-uncased"
    text_embedding = 768
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200

    pretrained = True # for both image encoder and text encoder
    trainable = True # for both image encoder and text encoder
    temperature = 0.07

    # image size
    size = 224

    # for projection head; used for both image and text encoders
    projection_dim = 256

## Utils

In [None]:
class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text



## Dataset

As you can see in the title image of this article, we need to encode both images and the texts describing them. So, the dataset needs to **return both images and texts**.
The dataset we will be using is Flickr8k, which contains (as the name suggests) about 8,000 images, each described by several captions.

I did not use additional data augmentations but you can add them if you want to improve the model's performance.
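
To make the expected annotation layout concrete, here is a minimal sketch of the filtering and caption sampling that the dataset class below performs. The miniature JSON is hypothetical: it contains only the fields the class reads (`filename`, `split`, `sentids`, `sentences[*].raw`), and the file names and captions are invented for illustration. `select_split` is likewise just a helper name used here.

```python
import random

# Hypothetical miniature of the Karpathy-split JSON; values are invented.
toy_json = {
    "images": [
        {
            "filename": "img_001.jpg",
            "split": "train",
            "sentids": [0, 1],
            "sentences": [{"raw": "A dog runs on grass."},
                          {"raw": "A brown dog outside."}],
        },
        {
            "filename": "img_002.jpg",
            "split": "val",
            "sentids": [2, 3],
            "sentences": [{"raw": "A boy on a skateboard."},
                          {"raw": "A child doing a trick."}],
        },
    ]
}


def select_split(data, split):
    """Return (image_id, caption_ids, captions) tuples for one split."""
    return [
        (ann["filename"][:-4], ann["sentids"], [s["raw"] for s in ann["sentences"]])
        for ann in data["images"]
        if ann["split"] == split
    ]


train_samples = select_split(toy_json, "train")
image_id, caption_ids, captions = train_samples[0]

# Pick one caption at random, as a single-caption-per-image loader would.
choice = random.randrange(len(captions))
print(image_id, caption_ids[choice], captions[choice])
```

Sampling a different caption on every access gives a small amount of text-side variety across epochs without any extra augmentation code.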

In [None]:
with open(captions_path) as f:
    flickr_json = json.load(f)

In [None]:
class Flickr8kCaptions(torch.utils.data.Dataset):
    """
    Flickr8K captions dataset.

    Karpathy split JSON can be downloaded from this webpage:
    https://cs.stanford.edu/people/karpathy/deepimagesent/
    """

    def __init__(self, image_path, captions_path, split, transform=None,
                 return_single_text=False):
        self.image_path = image_path
        self.captions_path = captions_path
        self.split = split
        self.transform = transform
        self.return_single_text = return_single_text

        # Read annotations and keep only those belonging to specified split.
        with open(self.captions_path) as f:
            flickr_json = json.load(f)

        # Convert the filtered list of tuples formatted as:
        # `(image_id, image_path, list[caption_ids], list[caption])`.
        # Only keep images that belong to required split.
        self.samples = [
            (
                ann["filename"][:-4],
                os.path.join(self.image_path, ann["filename"]),
                ann["sentids"],
                [entry["raw"] for entry in ann["sentences"]],
            )
            for ann in flickr_json["images"]
            if ann["split"] == split
        ]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict:
        image_id, image_path, caption_ids, captions = self.samples[idx]
        if self.return_single_text:
            caption_idx = random.randrange(0, len(captions))
            captions = captions[caption_idx]
            caption_ids = caption_ids[caption_idx]
        image = Image.open(image_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        return {
            "image_id": image_id,
            "caption_ids": caption_ids,
            "image": image,
            "captions": captions,
        }


data_config = timm.data.resolve_model_data_config(CFG.model_name)
transform = timm.data.create_transform(**data_config, is_training=False)

In [None]:
transform

In [None]:
timm.data.create_transform(**data_config, is_training=True)

## Image Encoder

The image encoder code is straightforward. I'm using the PyTorch Image Models library (timm) here, which makes many different image models available, from ResNets to EfficientNets and many more. Here we will use a ResNet50 as our image encoder. You can easily use the torchvision library for ResNets if you don't want to install a new library.

The code encodes each image into a fixed-size vector whose size is the model's number of output channels (in the case of ResNet50 the vector size will be **2048**). This is the output after the global average pooling layer (nn.AdaptiveAvgPool2d()).
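
As a rough illustration of where the 2048 comes from, the sketch below applies global average pooling to a random tensor standing in for the last ResNet50 feature map. The stand-in tensor and the 7x7 spatial size (what a ResNet50 produces for a 224x224 input) are assumptions made for the example; no real model is loaded.

```python
import torch
from torch import nn

# Stand-in for the last convolutional feature map of a ResNet50 on a
# 224x224 input: (batch, channels, height, width) = (4, 2048, 7, 7).
feature_map = torch.randn(4, 2048, 7, 7)

# Global average pooling collapses each 7x7 channel map to a single number,
# leaving one 2048-dimensional vector per image.
pool = nn.AdaptiveAvgPool2d(1)
pooled = pool(feature_map).flatten(1)

print(pooled.shape)  # torch.Size([4, 2048])
```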

In [None]:
class ImageEncoder(nn.Module):
    """
    Encode images to a fixed size vector
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg"
        )
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)

In [None]:
image_model = timm.create_model(CFG.model_name, pretrained=True, num_classes=0, global_pool='avg')
image_model

## Text Encoder

As I mentioned before, I'll use DistilBERT as the text encoder. Like its bigger brother BERT, it adds two special tokens to the actual input tokens: **CLS** and **SEP**, which mark the start and end of a sentence. To get a representation of the whole sentence (as the related BERT and DistilBERT papers point out), we use the final representation of the CLS token and hope that it captures the overall meaning of the sentence (caption). Seen this way, it is similar to what we did with the images when we converted them into fixed-size vectors.

In the case of DistilBERT (and also BERT), the output hidden representation for each token is a vector of size **768**. So, the whole caption will be encoded in the CLS token representation, whose size is 768.
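
The slicing involved is small enough to show on its own. In this sketch a random tensor stands in for DistilBERT's `last_hidden_state`, so the batch size and sequence length are arbitrary assumptions and nothing is downloaded.

```python
import torch

batch_size, seq_len, hidden = 3, 12, 768

# Stand-in for `last_hidden_state`: one 768-dimensional vector per token.
last_hidden_state = torch.randn(batch_size, seq_len, hidden)

# The CLS token is the first token of every sequence, so index 0 along the
# sequence dimension gives one 768-dimensional vector per caption.
cls_embedding = last_hidden_state[:, 0, :]

print(cls_embedding.shape)  # torch.Size([3, 768])
```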

In [None]:
class TextEncoder(nn.Module):
    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]

In [None]:
text_model = DistilBertModel.from_pretrained(CFG.text_encoder_model)

## Projection Head


Now that we have encoded both our images and texts into fixed-size vectors (2048 for images and 768 for texts), we need to bring (project) them into the **same space** so that we can compare them, pushing apart images and texts that do not match and pulling together those that do. So, the following code will bring the 2048- and 768-dimensional vectors into a 256-dimensional (projection_dim) world, where we can **compare** them.

"embedding_dim" is the size of the input vector (2048 for images and 768 for texts) and "projection_dim" is the size of the output vector, which is 256 in our case. For the details of this part you can refer to the CLIP paper.

In [None]:
class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim, bias=False)

    def forward(self, x):
        x = self.projection(x)
        return x
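
To check the shapes, the following sketch runs two bias-free linear projections on random stand-in features. The batch size of 8 and the random inputs are assumptions for illustration; the layers mirror the single `nn.Linear` used in the head above.

```python
import torch
from torch import nn

torch.manual_seed(0)
projection_dim = 256

# One projection per modality, each mapping into the same 256-d space.
image_projection = nn.Linear(2048, projection_dim, bias=False)
text_projection = nn.Linear(768, projection_dim, bias=False)

# Random stand-ins for encoder outputs (batch of 8).
image_features = torch.randn(8, 2048)
text_features = torch.randn(8, 768)

image_embeddings = image_projection(image_features)
text_embeddings = text_projection(text_features)

# Both modalities now live in vectors of the same size and can be compared.
print(image_embeddings.shape, text_embeddings.shape)
```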

## CLIP

This part is where all the fun happens! I'll also talk about the loss function here. I translated some of the code from the Keras code examples into PyTorch for writing this part. Read the explanation first, then take a look at the code block that follows it.

Here we will use the previous modules that we built to implement the main model. The \_\_init\_\_ function is self-explanatory. In the forward function, we first encode the images and texts separately into fixed-size vectors (with different dimensionalities). After that, using separate projection modules, we project them into the shared space that I talked about previously. There the encodings have the same size (256 in our case). After that we compute the loss. Again, I recommend reading the CLIP paper for the full picture, but I'll try my best to explain this part.

In **Linear Algebra**, one common way to measure whether two vectors have similar characteristics (whether they are like each other) is to calculate their **cosine similarity**: if the cosine similarity is large, they are alike, and if it is small, they are not (relatively speaking)!

Okay! What I just said is the most important thing to have in mind to understand this loss function. Let's continue. We talked about two vectors, but what do we have here? We have image_embeddings, a matrix with shape (batch_size, 256), and text_embeddings with shape (batch_size, 256). Easy enough! It means we have two groups of vectors instead of two single vectors. How do we measure how similar two groups of vectors (two matrices) are to each other? With the dot product: once each embedding is normalized to unit length, the dot product of two vectors equals their cosine similarity, and the @ operator in PyTorch (matrix multiplication) computes all the pairwise dot products at once. To be able to multiply these two matrices together, we transpose the second one. This gives us a matrix with shape (batch_size, batch_size), which we will call logits.
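
Here is a small sketch of that matrix product on random stand-in embeddings. It assumes the embeddings are L2-normalized first, so that each entry of the result is a cosine similarity; the batch size of 4 is arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, dim = 4, 256

# Random stand-ins for projected embeddings, normalized to unit length.
image_embeddings = F.normalize(torch.randn(batch_size, dim), dim=-1)
text_embeddings = F.normalize(torch.randn(batch_size, dim), dim=-1)

# Entry (i, j) is the dot product of image i with caption j.
logits = image_embeddings @ text_embeddings.T  # (batch_size, batch_size)

# For unit vectors the dot product and the cosine similarity coincide.
reference = F.cosine_similarity(image_embeddings[1], text_embeddings[2], dim=0)
matches_cosine = torch.allclose(logits[1, 2], reference, atol=1e-5)

print(logits.shape, matches_cosine)
```

The diagonal of this matrix holds the similarities of the true image-caption pairs, which is why the training targets are simply the row indices.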

Let's consider what we hope this model learns: **we want it to learn "similar representations (vectors)" for a given image and the caption describing it. That is, whether we give it an image or the text describing it, we want it to produce (nearly) the same 256-dimensional vector for both.**

CLIP training objective (loss):

![clip loss](https://miro.medium.com/v2/resize:fit:1400/1*KbxO4qPaq3z8dJ8vw5h1Qg.png)
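
For reference, the objective can be written as a symmetric cross entropy over the similarity matrix. This is a sketch in notation chosen here rather than copied from the figure: $s_{ij}$ is the cosine similarity between image $i$ and caption $j$, $\tau$ is the temperature, and $N$ is the batch size.

$$
\mathcal{L}_{\text{image}\to\text{text}} = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp(s_{ii}/\tau)}{\sum_{j=1}^{N}\exp(s_{ij}/\tau)},
\qquad
\mathcal{L}_{\text{text}\to\text{image}} = -\frac{1}{N}\sum_{j=1}^{N}\log\frac{\exp(s_{jj}/\tau)}{\sum_{i=1}^{N}\exp(s_{ij}/\tau)}
$$

$$
\mathcal{L} = \tfrac{1}{2}\left(\mathcal{L}_{\text{image}\to\text{text}} + \mathcal{L}_{\text{text}\to\text{image}}\right)
$$

The first term treats each row of the similarity matrix as a classification over captions, the second treats each column as a classification over images, and in both cases the correct class is the diagonal entry.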

In [None]:
class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)

        # temperature in original CLIP is also a learnable parameter, but we fix this for easier implementation
        self.temperature = temperature

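    def encode_image(self, image):
        # Encode, project into the shared space, then L2-normalize so that
        # dot products between embeddings are cosine similarities.
        image_features = self.image_encoder(image)
        image_embeddings = F.normalize(self.image_projection(image_features), dim=-1)
        return image_embeddings

    def encode_text(self, input_ids, attention_mask):
        # Same recipe for the text side.
        text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_embeddings = F.normalize(self.text_projection(text_features), dim=-1)
        return text_embeddings

    def forward(self, batch):
        image_embeddings = self.encode_image(batch['image'])
        text_embeddings = self.encode_text(batch['input_ids'], batch['attention_mask'])

        # One possible completion of the exercise, following the hints:
        # cosine similarities divided by the temperature form the logits, and
        # the matching caption for image i sits at index i, so the targets
        # are [0, 1, ..., batch_size - 1].
        logits = image_embeddings @ text_embeddings.T / self.temperature
        targets = torch.arange(logits.size(0), device=logits.device)
        # Symmetric cross entropy over the image-to-text and text-to-image directions.
        loss = (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2
        return loss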



## Train

Here are some functions to help us build the dataloaders and the model, and then train and evaluate the model. There's not much going on here: just a simple training loop and utility functions.

In [None]:

def build_loaders(split='train'):
    dataset = Flickr8kCaptions(
        image_path,
        captions_path,
        split=split,
        transform=transform,
        return_single_text=True
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=(split == 'train'),
    )
    return dataloader

Here's a handy function to train our model. There's not much happening here: just loading the batches, tokenizing the captions, feeding them to the model, and stepping the optimizer.

In [None]:
def train_epoch(model, train_loader, optimizer, tokenizer):
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch in tqdm_object:
        batch['image'] = batch['image'].to(CFG.device)
        encoded_captions = tokenizer(
            batch['captions'], padding=True, truncation=True, max_length=CFG.max_length,
            return_tensors='pt'
        )
        batch['input_ids'] = encoded_captions['input_ids'].to(CFG.device)
        batch['attention_mask'] = encoded_captions['attention_mask'].to(CFG.device)

        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg)
    return loss_meter


def train():
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    train_loader = build_loaders()

    model = CLIPModel().to(CFG.device)
    params = [
        {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
        {"params": list(model.image_projection.parameters()) + list(model.text_projection.parameters()),
         "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=0.)

    for epoch in range(CFG.epochs):
        print(f"Epoch: {epoch + 1}")
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, tokenizer)

    return model

Running the next cell starts training the model.

In [None]:
model = train()

## Inference

Okay! We are done with training the model. Now we need to do inference, which in our case means giving the model a piece of text and having it retrieve the most relevant images from an unseen validation (or test) set.

### Getting Image Embeddings

In this function, we feed our trained model validation set images and return the image_embeddings with shape (valid_set_size, 256).

In [None]:
def get_image_embeddings(model):
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    valid_loader = build_loaders(split="val")

    model.eval()

    valid_image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_embeddings = model.encode_image(batch["image"].to(CFG.device))
            valid_image_embeddings.append(image_embeddings)
    return torch.cat(valid_image_embeddings, dim=0)

In [None]:
image_embeddings = get_image_embeddings(model)

### Finding Matches

This function does the final task that we wished our model would be capable of: it takes the model, image_embeddings, and a text query, and displays the most relevant images from the validation set! Isn't it amazing? Let's see how it performs after all!
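
The retrieval step itself can be sketched on toy tensors before looking at the full function. Everything here is hypothetical: a gallery of five hand-written 3-dimensional "image embeddings", invented file names, and a query vector standing in for an encoded text query.

```python
import torch
import torch.nn.functional as F

# Hypothetical gallery of 5 image embeddings in a 3-dimensional space.
image_embeddings = F.normalize(torch.tensor([
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.7, 0.7, 0.0],
]), dim=-1)
filenames = ["a.jpg", "b.jpg", "c.jpg", "d.jpg", "e.jpg"]

# Stand-in for the normalized embedding of a text query.
query_embedding = F.normalize(torch.tensor([[1.0, 0.0, 0.0]]), dim=-1)

# Cosine similarity of the query against every image: shape (1, 5).
similarity = query_embedding @ image_embeddings.T

# Keep the two best-scoring images and map their indices back to file names.
values, indices = torch.topk(similarity.squeeze(0), k=2)
matches = [filenames[i] for i in indices.tolist()]

print(matches)  # ['a.jpg', 'b.jpg']
```

Because the image embeddings are computed once and reused, each new query costs only one text-encoder pass and one matrix product.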

In [None]:
def find_matches(model, image_embeddings, query, image_filenames, n=9):
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query], return_tensors='pt')
    batch = {
        key: values.to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_embeddings = model.encode_text(batch["input_ids"], batch["attention_mask"])

    similarity = text_embeddings @ image_embeddings.T

    values, indices = torch.topk(similarity.squeeze(0), n)
    matches = [image_filenames[idx] for idx in indices]

    _, axes = plt.subplots(3, 3, figsize=(10, 10))
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(os.path.join(f"{CFG.image_path}", f"{match}"))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)
        ax.axis("off")

    plt.show()

This is how we use this function. Aaaannnndddd the results:

In [None]:
with open(captions_path) as f:
    flickr_json = json.load(f)

image_filenames = [os.path.join(image_path, ann["filename"]) for ann in flickr_json["images"] if ann["split"] == 'val']

In [None]:

find_matches(model,
             image_embeddings,
             query="dogs on the grass",
             image_filenames=image_filenames,
             n=9)

In [None]:
# try your own queries!
find_matches(model,
             image_embeddings,
             query="a boy skateboarding",
             image_filenames=image_filenames,
             n=9)