# CLIP Loss

Author: xiaodongguaAIGC

our target is to make Pytorch-implemention about CLIP, we could ref follow code, we simplfied image_encoder & text_encoder, take more attention about CLIP loss

CLIP paper : [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/pdf/2103.00020)

![clip](./images/clip.png)

peseudocode Numpy-Like Clip loss implemention

```python
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter
# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropyls, axis=1)
loss   = (loss_i + loss_t)/2

```loss_i + loss_t)/2

# config 

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# image
batch_size = 8
height = 2
width = 2
chanel = 3
d_i = 4
d_e = 5

# text
seq_len = 2 
d_t = 3
vocab_size = 100
dim = 512

# Image encoder

In [49]:
class ImageEncoder(nn.Module):
    def __init__(self, height, width, chanel, d_i):
        super().__init__()
        self.input_dim = height*width*chanel
        self.output_dim = d_i
        self.encoder = nn.Linear(self.input_dim, self.output_dim, bias=False) 
    def forward(self, x):
        x_flat = x.flatten(1) # batchsize, c, h, w -> batchsize, c*h*w
        y = self.encoder(x_flat)
        return y

I = torch.randn(batch_size, chanel, height, width)
print(I.flatten(1).shape)

image_encoder = ImageEncoder(height, width, chanel, d_i)
print(image_encoder)
I_f = image_encoder(I)
print(I.shape)
print(I_f.shape)

# Text encoder

In [50]:
# Text encoder is CASUAL-LANGUAGE-Modeling， 
# Attention mask is tril be like
# 1 0 0
# 1 1 0
# 1 1 1
# and we have 3 token “hello world <EOS>” 
# <EOS> output logits as text encoder feature
class TextEncoder(nn.Module):
    def __init__(self, vocab_size, dim, d_t):
        super().__init__()
        self.dim = dim
        self.d_t = d_t
        self.embedding = nn.Embedding(vocab_size, self.dim)
        self.encoder = nn.Linear(self.dim, self.dim, bias=False)
        self.output = nn.Linear(self.dim, self.d_t, bias=False) 
    def forward(self, x):
        # x_flat = x.flatten(1) # batchsize, c, h, w -> batchsize, c*h*w
        x_embd = self.embedding(x)
        y = self.encoder(x_embd)
        # 文本里取 "[cls] token1 token2 token3" -> token3 对应的特征向量
        out = self.output(y)[:,-1,:] # <EOS> output logits as text encoder feature
        return out

I = torch.randn(batch_size, chanel, height, width)
# print(I.flatten(1).shape)
T = torch.randint(low=0, high=vocab_size, 
                  size=(batch_size, seq_len),
                  dtype=torch.int)

text_encoder = TextEncoder(vocab_size, dim, d_t)
print(text_encoder)
T_f = text_encoder(T) # if cls token 
print(T.shape)
print(T_f.shape)

# Clip loss

In [65]:
import torch.nn.functional as F
class CLIP(nn.Module):
    def __init__(self, d_i, d_t, d_e):
        super().__init__()
        self.W_i_e = nn.Linear(d_i, d_e, bias=False)
        self.W_t_e = nn.Linear(d_t, d_e, bias=False)
        self.temparture = nn.Parameter(torch.ones(1))
        # self.softmax_i = nn.Softmax(dim=0)
        # self.softmax_t = nn.Softmax(dim=1)
        self.loss_fn = nn.CrossEntropyLoss()
    
    def forward(self, I_f, T_f, loss_type='basic'):
        n, _ = I_f.size()
        
        I_e = self.W_i_e(I_f) # image_embedding
        T_e = self.W_t_e(T_f) # text_embedding

        # I_e = F.normalize(I_e, p=2, dim=1)
        # T_e = F.normalize(T_e, p=2, dim=1)
        I_e = I_e / I_e.norm(p=2, dim=-1, keepdim=True)
        T_e = T_e / T_e.norm(p=2, dim=-1, keepdim=True)

        logits = I_e @ T_e.transpose(1,0)
        logits = logits * torch.exp(self.temparture)

        labels = torch.arange(n)

        loss_i = self.loss_fn(logits, labels)
        loss_t = self.loss_fn(logits.transpose(1,0), labels)
        
        loss = loss_i + loss_t
        
        return {
            'image_embedding': I_e,
            'text_embedding': T_e,
            'logits': logits,
            'loss' : loss,
            'loss_i' : loss_i,
            'loss_t' : loss_t,
        }

print(I_f.shape)
print(T_f.shape)

clip = CLIP(d_i, d_t, d_e)
print(clip)
output = clip(I_f, T_f)

print('output feature loss: ', output['loss'])
print('output feature loss_i: ', output['loss_i'])
print('output feature loss_t: ', output['loss_t'])

print('output feature image_embedding: ', output['image_embedding'].shape)
print('output feature text_embedding: ', output['text_embedding'].shape)
print('output feature logits: ', output['logits'].shape)

# CLIP Loss Pipeline

In [64]:
# step 1: create image & text data
I = torch.randn(batch_size, chanel, height, width)
T = torch.randint(low=0, high=vocab_size, 
                  size=(batch_size, seq_len),
                  dtype=torch.int)

# step 2: create image encoder, text encoder, clip modeling
image_encoder = ImageEncoder(height, width, chanel, d_i)
text_encoder = TextEncoder(vocab_size, dim, d_t)
clip = CLIP(d_i, d_t, d_e)

# step 3: compute loss
I_f = image_encoder(I) # this is image represention
T_f = text_encoder(T)  # this is text represention
output = clip(I_f, T_f)

# step 4: update clip-> parameters of "image encoder or text encoder" 
output['loss'] 
# output['loss'].backward()

tensor(8.1138, grad_fn=<AddBackward0>)