## CLIP Embedding
- The CLIP embedding takes a sequence of tokens (the tokenized text prompt) and adds context to them
- The CLIP module is composed of the **CLIPEmbedding** and **CLIPlayer** modules, which together do this job

In [None]:
import math
import torch
from torch import nn
from torch.nn import functional as F

### The Self-Attention module is an essential component of the CLIP Embedding module

In [None]:
class SelfAttention(nn.Module):
    """Multi-head self-attention with an optional causal mask."""

    def __init__(self, n_heads, d_embed, in_proj_bias = True, out_proj_bias = True):
        """
        Param n_heads: the number of heads in the attention block
        Param d_embed: the embedding dimension of the token, i.e. the length of the vector for each token;
            must be divisible by n_heads
        Param in_proj_bias: whether the joint query/key/value projection has a bias term
        Param out_proj_bias: whether the output projection has a bias term
        Raises ValueError: if n_heads is not positive or d_embed is not divisible by n_heads
        """
        super().__init__()
        # Fail early: otherwise d_head is silently floored and forward() dies
        # later with an opaque error from view().
        if n_heads <= 0 or d_embed % n_heads != 0:
            raise ValueError(
                f"d_embed ({d_embed}) must be divisible by a positive n_heads ({n_heads})"
            )
        # A single linear layer produces queries, keys and values together.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias = in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias = out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask = False):
        """
        Param x: tensor of shape (batch_size, sequence_length, d_embed)
        Param causal_mask: when True, a token may only attend to itself and to earlier positions
        Returns: tensor with the same shape as x
        """
        input_shape = x.shape
        batch_size, sequence_length, _ = input_shape
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

        # Split the joint projection, then move the head axis forward:
        # (batch, seq, d_embed) -> (batch, n_heads, seq, d_head)
        q, k, v = self.in_proj(x).chunk(3, dim = -1)
        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        # (batch, n_heads, seq, seq) attention scores
        weight = q @ k.transpose(-1, -2)
        if causal_mask:
            # A (seq, seq) strictly-upper-triangular mask broadcasts over the
            # batch and head axes, so there is no need to allocate a full-size one.
            mask = torch.ones(
                sequence_length, sequence_length,
                dtype = torch.bool, device = weight.device,
            ).triu(1)
            weight = weight.masked_fill(mask, -torch.inf)
        weight = weight / math.sqrt(self.d_head)
        weight = F.softmax(weight, dim = -1)

        # Weighted sum of values, then merge the heads back into d_embed.
        output = weight @ v
        output = output.transpose(1, 2).reshape(input_shape)
        return self.out_proj(output)

### The CLIPEmbedding module converts the tokens into tensors with embeddings and adds position values

In [None]:
class CLIPEmbedding(nn.Module):
    """Looks up a vector for every token and adds a learned per-position value to it."""

    def __init__(self, n_vocab: int, n_embed: int, n_tokens: int):
        """
        Param n_vocab: the number of entries in the token embedding table
        Param n_embed: the length of the vector for each token
        Param n_tokens: the number of positions that have a learned position value
        """
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_embed)
        # One learnable row per position, initialised to zeros.
        position_init = torch.zeros(n_tokens, n_embed)
        self.position_value = nn.Parameter(position_init)

    def forward(self, tokens):
        """
        Param tokens: integer tensor of token ids
        Returns: the token embeddings with the position values added
        """
        # The add is done in place on the freshly created lookup result.
        return self.token_embedding(tokens).add_(self.position_value)

### The CLIPlayer module uses the Self-Attention module to add context

In [None]:
class CLIPlayer(nn.Module):
    """One pre-LayerNorm transformer block: causal self-attention followed by a feed-forward network.

    Both sub-layers are wrapped in a residual (skip) connection.
    """

    def __init__(self, n_heads: int, n_embed: int):
        """
        Param n_heads: the number of attention heads
        Param n_embed: the length of the vector for each token
        """
        super().__init__()
        self.layernorm_1 = nn.LayerNorm(n_embed)
        self.attention = SelfAttention(n_heads, n_embed)
        self.layernorm_2 = nn.LayerNorm(n_embed)
        # Feed-forward network: expand to 4x the width, then project back.
        self.linear_1 = nn.Linear(n_embed, 4 * n_embed)
        self.linear_2 = nn.Linear(4 * n_embed, n_embed)

    def forward(self, x):
        """
        Param x: tensor whose last dimension is n_embed
        Returns: tensor with the same shape as x
        """
        # Attention sub-layer. The skip connection must carry the layer INPUT;
        # capturing it after attention would discard the input entirely.
        residue = x
        x = self.layernorm_1(x)
        x = self.attention(x, causal_mask = True)
        x = x + residue

        # Feed-forward sub-layer with its own skip connection.
        residue = x
        x = self.layernorm_2(x)
        x = self.linear_1(x)
        x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation
        x = self.linear_2(x)
        return x + residue

In [None]:
class CLIP(nn.Module):
    """Text encoder: token/position embedding, a stack of CLIPlayer blocks, and a final LayerNorm."""

    def __init__(
        self,
        n_vocab: int = 49408,
        n_embed: int = 768,
        n_tokens: int = 77,
        n_heads: int = 12,
        n_layers: int = 12,
    ):
        """
        Param n_vocab: the number of entries in the token embedding table
        Param n_embed: the length of the vector for each token
        Param n_tokens: the number of token positions per prompt
        Param n_heads: the number of attention heads in every layer
        Param n_layers: the number of CLIPlayer blocks in the stack

        Calling CLIP() with no arguments builds the same model as before these
        sizes were made configurable.
        """
        super().__init__()
        self.embedding = CLIPEmbedding(n_vocab, n_embed, n_tokens)
        self.layers = nn.ModuleList([
            CLIPlayer(n_heads, n_embed) for _ in range(n_layers)
        ])
        self.layernorm = nn.LayerNorm(n_embed)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        """
        Param tokens: tensor of token ids; it is cast to torch.long before the lookup
        Returns: the normalized output of the last layer
        """
        tokens = tokens.type(torch.long)
        state = self.embedding(tokens)
        for layer in self.layers:
            state = layer(state)
        return self.layernorm(state)

#### The Embeddings for the CLIP Module, i.e. the Embedding module named CLIPEmbedding

In [None]:
## Step-by-step version of what the CLIPEmbedding module does.
## Sizes used below:
##   49408 - the vocabulary size, i.e. the number of entries in the token dictionary
##   768   - the length of the vector that represents each token
##   77    - the number of token positions per prompt; this length is fixed
##           (it is the same value the CLIP class in this notebook passes as n_tokens)
token_embedding = nn.Embedding(num_embeddings = 49408, embedding_dim = 768)
position_value = nn.Parameter(torch.zeros((77, 768)))
## A made-up prompt: one sequence of 77 token ids drawn from [10, 100)
input_tokens = torch.randint(low = 10, high = 100, size = (1, 77))
## Look up the vector for every token ...
token_embeddings = token_embedding(input_tokens)
assert tuple(token_embeddings.shape) == (1, 77, 768)
## ... then add the per-position values in place; the shape must not change
token_embeddings += position_value
assert tuple(token_embeddings.shape) == (1, 77, 768)

#### Next, the CLIP Layers to derive context

In [None]:
## Pass the embedded tokens through a single CLIP layer; the shape must be preserved
sample_layer = CLIPlayer(n_heads = 12, n_embed = 768)
state = sample_layer(token_embeddings)
assert tuple(state.shape) == (1, 77, 768)

#### Finally, a layer normalization to enable robust learning

In [None]:
## Normalize the layer output over its last (768-wide) dimension; the shape must be preserved
layer_norm = nn.LayerNorm(normalized_shape = 768)
context = layer_norm(state)
assert tuple(context.shape) == (1, 77, 768)