In this notebook, I implement the key ideas of CLIP (Contrastive Language-Image Pre-training), including:
- image embedding
- text embedding
- linear projection of both embeddings into a shared space
- cosine similarity between the projected text embedding and the projected image embedding
- using a text prompt to find the image with the highest cosine similarity

```python
# import the necessary libraries
import warnings

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from transformers import BertTokenizer, BertModel

warnings.filterwarnings('ignore')
```

## 1. Image embedding

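```python
def get_image_embedding(img_path):
    # load ResNet-50 with ImageNet weights
    # (equivalent to the deprecated pretrained=True; needs torchvision >= 0.13)
    resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

    # drop the final fully connected layer so the model outputs pooled features
    modules = list(resnet.children())[:-1]
    resnet = torch.nn.Sequential(*modules)
    # evaluation mode: BatchNorm uses its running statistics, not those of a single image
    resnet.eval()

    # load the image and apply the standard ImageNet preprocessing
    image = Image.open(img_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = transform(image)

    # forward pass without gradient tracking; output shape is (1, 2048, 1, 1)
    with torch.no_grad():
        encoding = resnet(image.unsqueeze(0))
    img_embedding = encoding[0, :, 0, 0]  # -> (2048,)
    return img_embedding
```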

```python
img_path1 = '1.png'
img_embed1 = get_image_embedding(img_path1)
print('dimension of image embedding:', img_embed1.shape)
```

dimension of image embedding: torch.Size([2048])


## 2. Text embedding

```python
def get_text_embedding(text_content):
    # load the BERT tokenizer and encoder
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')

    # tokenize; add_special_tokens=True wraps the sentence in [CLS] ... [SEP]
    input_ids = torch.tensor([tokenizer.encode(text_content, add_special_tokens=True)])

    # forward pass, then mean-pool the last hidden layer over all tokens
    with torch.no_grad():
        outputs = model(input_ids)
        last_hidden_states = outputs.last_hidden_state  # (1, seq_len, 768)
        mean_last_hidden_states = torch.mean(last_hidden_states, dim=1)  # (1, 768)
    text_embedding = mean_last_hidden_states[0, :]
    return text_embedding
```
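`get_text_embedding` averages the last hidden layer over every token of a single sentence, including `[CLS]` and `[SEP]`. If several prompts were tokenized together with padding, the padded positions should be left out of that average. The sketch below is a hypothetical helper, `masked_mean_pool`, assuming the Hugging Face convention that `attention_mask` is 1 for real tokens and 0 for padding; the toy tensors are made up for illustration.

```python
import torch


def masked_mean_pool(last_hidden_state, attention_mask):
    """Hypothetical helper: average token embeddings, ignoring padding.

    last_hidden_state: (batch, seq_len, hidden)
    attention_mask:    (batch, seq_len), 1 for real tokens, 0 for padding
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                    # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)                          # avoid division by zero
    return summed / counts


# toy example: the second sequence ends with one padding position
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0]],
                       [[5.0, 6.0], [100.0, 100.0]]])
mask = torch.tensor([[1, 1], [1, 0]])
print(masked_mean_pool(hidden, mask))
# tensor([[2., 3.],
#         [5., 6.]])
```

With a Hugging Face tokenizer, the inputs would come from something like `tokenizer(prompts, padding=True, return_tensors='pt')` and the hidden states from `model(**inputs).last_hidden_state`.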

```python
text_content = 'a photo of digit 9'
text_embed = get_text_embedding(text_content)

print('dimension of text embedding:', text_embed.shape)
```

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

dimension of text embedding: torch.Size([768])


## 3. Joint multimodal embedding using linear projection

```python
def linear_projection(img_embed1, text_embed, dim_align, dim_text, dim_img):
    # NOTE: these layers are created (and randomly initialized) on every call;
    # in CLIP they are learned jointly with the encoders
    linear_projection_text = nn.Linear(dim_text, dim_align)
    linear_projection_img = nn.Linear(dim_img, dim_align)
    # project both embeddings into the shared dim_align-dimensional space
    text_embed_aligned = linear_projection_text(text_embed)
    img_embed_aligned_1 = linear_projection_img(img_embed1)
    # L2-normalize so that a dot product equals cosine similarity
    text_embed_aligned = text_embed_aligned / text_embed_aligned.norm(dim=-1, keepdim=True)
    img_embed_aligned_1 = img_embed_aligned_1 / img_embed_aligned_1.norm(dim=-1, keepdim=True)
    return img_embed_aligned_1, text_embed_aligned
```
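Because `linear_projection` builds new `nn.Linear` layers each time it is called, two calls project their inputs with different random weights. A minimal sketch, assuming the projections should be shared across all image-text pairs (as in CLIP, where they are trained once and then reused), wraps them in a module that is created a single time. `JointProjection` is a hypothetical name; its default sizes mirror the notebook's `dim_img`, `dim_text` and `dim_align`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class JointProjection(nn.Module):
    """Hypothetical module: project image and text embeddings into one shared space."""

    def __init__(self, dim_img=2048, dim_text=768, dim_align=512):
        super().__init__()
        # created once, so every image/text pair is projected with the same weights
        self.img_proj = nn.Linear(dim_img, dim_align)
        self.text_proj = nn.Linear(dim_text, dim_align)

    def forward(self, img_embed, text_embed):
        # project, then L2-normalize along the feature dimension
        img = F.normalize(self.img_proj(img_embed), dim=-1)
        text = F.normalize(self.text_proj(text_embed), dim=-1)
        return img, text


torch.manual_seed(0)
proj = JointProjection()
# random stand-ins for the ResNet-50 and BERT embeddings
img_vec, text_vec = proj(torch.randn(2048), torch.randn(768))
print(img_vec.shape, text_vec.shape)  # torch.Size([512]) torch.Size([512])
```

Calling the same `proj` for every image keeps all similarity scores in one embedding space, which is what makes comparing them meaningful.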

```python
dim_align = 512   # size of the shared embedding space
dim_text = 768    # BERT-base hidden size
dim_img = 2048    # ResNet-50 pooled feature size

img_embed_aligned_1, text_embed_aligned = linear_projection(img_embed1, text_embed, dim_align, dim_text, dim_img)

print('dimension of text embedding:', text_embed_aligned.shape)
print('dimension of image embedding:', img_embed_aligned_1.shape)
```

dimension of text embedding: torch.Size([512])
dimension of image embedding: torch.Size([512])


## 4. Cosine similarity calculation

```python
# calculate the cosine similarity between the text and image embeddings
def align_embeddings(text_embedding, image_embedding):
    similarity = torch.nn.functional.cosine_similarity(text_embedding, image_embedding, dim=0)
    return similarity.item()
```
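For L2-normalized vectors, cosine similarity reduces to a plain dot product, since cos(a, b) = a·b / (‖a‖ ‖b‖) and both norms are 1. That is why `linear_projection` normalizes its outputs, and it also lets many pairs be scored with a single matrix multiplication. A small check with random vectors (illustrative only, not the notebook's data):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(512)
b = torch.randn(512)

# cosine similarity computed directly, as in align_embeddings
cos = F.cosine_similarity(a, b, dim=0)
# dot product of the L2-normalized vectors
dot = torch.dot(F.normalize(a, dim=0), F.normalize(b, dim=0))
print(torch.allclose(cos, dot, atol=1e-6))  # True

# batched version: 2 text vectors against 3 image vectors -> (2, 3) similarity matrix
texts = F.normalize(torch.randn(2, 512), dim=-1)
images = F.normalize(torch.randn(3, 512), dim=-1)
sim = texts @ images.T
print(sim.shape)  # torch.Size([2, 3])
```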

```python
align_embeddings(text_embed_aligned, img_embed_aligned_1)
```

0.031788431107997894

## 5. Use a prompt to select an image

```python
text_content = 'a photo of digit 2'
text_embed = get_text_embedding(text_content)
```


```python
img_path1 = '1.png'
img_embed1 = get_image_embedding(img_path1)
img_embed_aligned_1, text_embed_aligned = linear_projection(img_embed1, text_embed, dim_align, dim_text, dim_img)
align_embeddings(text_embed_aligned, img_embed_aligned_1)
```

-0.051333289593458176

```python
img_path2 = '2.png'
img_embed2 = get_image_embedding(img_path2)
img_embed_aligned_2, text_embed_aligned = linear_projection(img_embed2, text_embed, dim_align, dim_text, dim_img)
align_embeddings(text_embed_aligned, img_embed_aligned_2)  # cosine similarity
```

0.005957046058028936

Since '2.png' has a higher similarity with the text 'a photo of digit 2' (0.0060) than '1.png' does (-0.0513), we match '2.png' to 'a photo of digit 2'.
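The same selection generalizes to any number of candidate images: project the prompt and every image with one shared pair of projection layers, score them all with a single matrix-vector product, and take the argmax. This is a sketch; `select_image` is a hypothetical helper, and the random tensors stand in for real outputs of `get_text_embedding` and `get_image_embedding`, so the index it picks here carries no meaning.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def select_image(text_embed, img_embeds, text_proj, img_proj):
    """Hypothetical helper: score one prompt against N images.

    text_embed: (dim_text,) and img_embeds: (N, dim_img).
    Returns the index of the best-matching image and all N cosine similarities.
    """
    with torch.no_grad():
        t = F.normalize(text_proj(text_embed), dim=-1)    # (dim_align,)
        imgs = F.normalize(img_proj(img_embeds), dim=-1)  # (N, dim_align)
        scores = imgs @ t                                 # (N,) cosine similarities
    return int(scores.argmax()), scores


torch.manual_seed(0)
text_proj = nn.Linear(768, 512)   # shared by every comparison
img_proj = nn.Linear(2048, 512)
# random stand-ins for one text embedding and three image embeddings
best, scores = select_image(torch.randn(768), torch.randn(3, 2048), text_proj, img_proj)
print(best, scores)
```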

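*Note that in this notebook neither the text and image encoders nor the projection layers are trained to be semantically aligned: the projection layers are randomly initialized and re-created on every call to `linear_projection`, so each image-text pair is compared in a different random space. This is therefore only a simple illustration of the key idea of CLIP, and the results are not reliable.*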
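In CLIP itself this alignment is learned: for a batch of N matched image-text pairs, training uses a symmetric cross-entropy (contrastive) loss that pulls each image toward its own caption and away from the other N-1 captions, and vice versa. A minimal sketch of that loss, assuming a fixed temperature of 0.07 (CLIP learns the temperature as a parameter, initialized to the equivalent of 0.07); `clip_contrastive_loss` is a hypothetical name, and the toy embeddings are made up.

```python
import torch
import torch.nn.functional as F


def clip_contrastive_loss(img_emb, text_emb, temperature=0.07):
    """Hypothetical helper: symmetric contrastive loss over N matched pairs.

    img_emb, text_emb: (N, d); row i of both tensors comes from the same pair.
    """
    img_emb = F.normalize(img_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    # (N, N) matrix of scaled cosine similarities; matching pairs sit on the diagonal
    logits = img_emb @ text_emb.T / temperature
    targets = torch.arange(img_emb.shape[0], device=img_emb.device)
    loss_img = F.cross_entropy(logits, targets)     # each image should pick its own text
    loss_text = F.cross_entropy(logits.T, targets)  # each text should pick its own image
    return (loss_img + loss_text) / 2


# toy check: perfectly matched pairs give a much lower loss than mismatched ones
e = torch.eye(4)
print(clip_contrastive_loss(e, e) < clip_contrastive_loss(e, e[[1, 0, 3, 2]]))  # tensor(True)
```

Minimizing this loss over the projection layers (and, in CLIP, the encoders as well) is what would make the similarity scores above meaningful.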