# CLIP in PyTorch

Reference: [Transformers](https://github.com/huggingface/transformers)

Notebook Author: [xiaodongguaAIGC](https://github.com/dhcode-cpp)

Author: xiaodongguaAIGC

github: dhcode-cpp

gmail: dhcode95@gmail.com

![CLIP](./images/clip.png)

In [325]:
import collections.abc
from collections import OrderedDict
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# config 

In [326]:
class CLIPTextConfig():
    """Hyper-parameters for the CLIP text tower.

    Every named keyword argument is stored as an attribute of the same name.
    Any additional ``**kwargs`` are accepted and ignored.
    """
    model_type = "clip_text_model"

    def __init__(
        self,
        vocab_size=49408,
        hidden_size=512,
        intermediate_size=2048,
        projection_dim=512,
        num_hidden_layers=12,
        num_attention_heads=8,
        max_position_embeddings=77,
        # The reference implementation uses hidden_act="quick_gelu".
        hidden_act="gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        # pad_token_id differs from `CLIPTokenizer`'s default and from openai/clip.
        # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
        pad_token_id=1,
        bos_token_id=49406,
        eos_token_id=49407,
        **kwargs,
    ):
        # Populate the instance dict in one shot. The keyword order below fixes the
        # attribute insertion order, which is what `print(cfg.__dict__)` shows.
        vars(self).update(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            projection_dim=projection_dim,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            layer_norm_eps=layer_norm_eps,
            hidden_act=hidden_act,
            initializer_range=initializer_range,
            initializer_factor=initializer_factor,
            attention_dropout=attention_dropout,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
        )


class CLIPVisionConfig():
    """Hyper-parameters for the CLIP vision tower (a ViT over image patches).

    Every named keyword argument is stored as an attribute of the same name.
    Extra ``**kwargs`` are accepted and ignored, as in ``CLIPTextConfig``.
    """
    model_type = "clip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        projection_dim=512,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=32,
        # The reference implementation uses hidden_act="quick_gelu".
        hidden_act="gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        **kwargs,
    ):
        # This class derives from `object`, whose __init__ takes no arguments.
        # Forwarding **kwargs to it raised TypeError for any extra keyword, so
        # extras are now ignored instead.
        super().__init__()

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

class CLIPConfig():
    """Top-level CLIP configuration bundling the text and vision tower configs.

    Args:
        text_config: a ``CLIPTextConfig``, a dict of its keyword arguments, or
            ``None`` to use ``CLIPTextConfig()`` defaults.
        vision_config: a ``CLIPVisionConfig``, a dict of its keyword arguments,
            or ``None`` to use ``CLIPVisionConfig()`` defaults.
        projection_dim: size of the shared image/text embedding space.
        logit_scale_init_value: initial value of the model's learnable logit scale.
        **kwargs: extra attributes (e.g. ``num_labels``) stored on the config as-is.
    """
    model_type = "clip"

    def __init__(
        self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
    ):
        super().__init__()
        # Previously None was stored verbatim and crashed later, when a model read
        # e.g. `config.text_config.hidden_size`. Build default sub-configs instead.
        if text_config is None:
            text_config = CLIPTextConfig()
        elif isinstance(text_config, dict):
            text_config = CLIPTextConfig(**text_config)
        if vision_config is None:
            vision_config = CLIPVisionConfig()
        elif isinstance(vision_config, dict):
            vision_config = CLIPVisionConfig(**vision_config)

        self.text_config = text_config
        self.vision_config = vision_config

        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value
        self.initializer_factor = 1.0

        self.gradient_checkpointing = True

        # Keep extra settings such as `num_labels` instead of silently dropping them.
        for name, value in kwargs.items():
            setattr(self, name, value)

In [327]:
# Small demo hyper-parameters shared by the text and vision towers.
hidden_size = 128
intermediate_size = hidden_size * 4
projection_dim = hidden_size
num_hidden_layers = 2
num_attention_heads = 4
num_channels = 3
image_size = 224
patch_size = 32

# Text-tower settings: a tiny vocabulary for random-token demos.
vocab_size = 100
pad_token_id = 0
bos_token_id = 1
eos_token_id = 2

batch_size = 2

# Width of the joint image/text embedding space.
clip_projection_dim = 512

# Pass the named constants rather than repeating literals, so that editing a value
# above updates every config. `vocab_size` was previously declared but never passed.
# `num_channels` is an image-only setting and is not given to the text config.
text_config = CLIPTextConfig(
    vocab_size=vocab_size,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    projection_dim=projection_dim,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    pad_token_id=pad_token_id,
    bos_token_id=bos_token_id,
    eos_token_id=eos_token_id,
)

vision_config = CLIPVisionConfig(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    projection_dim=projection_dim,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    num_channels=num_channels,
    image_size=image_size,
    patch_size=patch_size,
)

config = CLIPConfig(text_config=text_config,
                    vision_config=vision_config,
                    projection_dim=clip_projection_dim)
print(config.__dict__)
print(config.text_config.__dict__)
print(config.vision_config.__dict__)

# CLIP Model

两个 encoder（文本和图像）都使用通用的 Transformer 编码器结构，因此可以复用同一份实现。

## CLIP: Transformer

In [328]:
import math
class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Reads ``hidden_size``, ``num_attention_heads`` and ``attention_dropout`` from
    ``config``. ``forward`` takes ``hidden_states`` of shape
    (batch, seq, hidden_size). It also takes an optional *additive*
    ``attention_mask`` that broadcasts to (batch, num_heads, seq, seq).
    It returns ``(attn_output, None)``; attention weights are not returned.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "hidden_size must be divisible by num_attention_heads "
                f"(got hidden_size={self.embed_dim}, num_attention_heads={self.num_heads})."
            )
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the width
        # of ONE head. Scaling by the full hidden size made the softmax too flat.
        self.scale = self.head_dim ** -0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""
        bsz, tgt_len, embed_dim = hidden_states.size()

        # Project, then fold heads into the batch dimension for bmm:
        # (batch * heads, seq, head_dim).
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(self.q_proj(hidden_states) * self.scale, tgt_len, bsz).view(*proj_shape)
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz).view(*proj_shape)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz).view(*proj_shape)

        # Scaled dot-product scores: (batch * heads, tgt_len, src_len).
        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        if attention_mask is not None:
            # The mask is additive: large negative values suppress positions.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.bmm(attn_probs, value_states)

        # Merge heads back: (batch, seq, embed), then the output projection.
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, None

In [329]:
batch_size, seq_len = 2, 10

# Smoke test: a single attention module applied to random (batch, seq, hidden) inputs.
clip_attention = CLIPAttention(text_config)
print(clip_attention)
x_embd = torch.randn(batch_size, seq_len, hidden_size)
x_attn = clip_attention(hidden_states=x_embd)  # tuple: (attn_output, attn_weights)
print(x_attn[0].shape)

In [330]:
class GELUActivation(nn.Module):
    """
    Exact (erf-based) GELU activation, as originally used in Google's BERT repo.

    OpenAI GPT uses a slightly different tanh approximation:
    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
    By default this module calls the C implementation ``nn.functional.gelu``.
    With ``use_gelu_python=True`` it uses the equivalent pure-Python formula below.
    Paper: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        self.act = self._gelu_python if use_gelu_python else nn.functional.gelu

    def _gelu_python(self, input: torch.Tensor) -> torch.Tensor:
        # x * Phi(x), with the standard-normal CDF Phi written via erf.
        half_input = input * 0.5
        return half_input * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.act(input)
# Quick check: apply GELU to a small random tensor (the input is printed, then the output).
gelu = GELUActivation()
dummy_tensor = torch.randn(1, 2, 3)
print(dummy_tensor)
activate_tensor = gelu(input=dummy_tensor)
print(activate_tensor)

class ClassInstantier(OrderedDict):
    """Ordered mapping whose lookups return a NEW instance rather than the stored value.

    A stored value is either a class/callable, which is called with no arguments,
    or a ``(callable, kwargs_dict)`` tuple, which is called with ``**kwargs_dict``.
    """

    def __getitem__(self, key):
        entry = super().__getitem__(key)
        if isinstance(entry, tuple):
            factory, init_kwargs = entry
        else:
            factory, init_kwargs = entry, {}
        return factory(**init_kwargs)
# Registry from activation name to module class. Indexing ACT2FN[name] builds a fresh
# module instance each time (see ClassInstantier.__getitem__).
ACT2CLS = dict(
    gelu=GELUActivation,
    relu=nn.ReLU,
    sigmoid=nn.Sigmoid,
    silu=nn.SiLU,
    swish=nn.SiLU,
)
ACT2FN = ClassInstantier(ACT2CLS)

class CLIPMLP(nn.Module):
    """Position-wise feed-forward block: Linear(hidden->intermediate) -> act -> Linear(back).

    The activation is looked up by name from ``config.hidden_act`` in ACT2FN.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)

In [331]:
# Smoke test for the feed-forward block, fed with the attention output from above.
# The activation is chosen from text_config.hidden_act when CLIPMLP is built.
# The previous `config.hidden_act = 'gelu'` set an unused attribute on the
# top-level CLIPConfig after construction and had no effect, so it is removed.
clip_mlp = CLIPMLP(text_config)
print(clip_mlp)
x_mlp = clip_mlp(x_attn[0])
print(x_mlp.shape)

In [332]:
class CLIPEncoderLayer(nn.Module):
    """One pre-LayerNorm Transformer block.

    Computes ``x + Attn(LN1(x))``, then ``y + MLP(LN2(y))``, with residual
    connections around both sub-blocks. ``attention_mask`` is forwarded unchanged
    to CLIPAttention.
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.FloatTensor:
        # Attention sub-block; the second return value of self_attn is discarded.
        attn_out, _ = self.self_attn(
            hidden_states=self.layer_norm1(x),
            attention_mask=attention_mask,
        )
        after_attn = x + attn_out

        # Feed-forward sub-block.
        return after_attn + self.mlp(self.layer_norm2(after_attn))

In [333]:
# Smoke test for one full encoder layer. The old cell built `clip_encoder` but then
# ran `clip_mlp`, so the encoder layer itself was never exercised.
clip_encoder = CLIPEncoderLayer(text_config)
print(clip_encoder)
x_encoder = clip_encoder(x_embd, None)  # no attention mask
print(x_encoder.shape)

In [334]:
class CLIPEncoder(nn.Module):
    """A stack of ``config.num_hidden_layers`` CLIPEncoderLayer blocks applied in order.

    The same ``attention_mask`` (or None) is passed to every layer.
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)
        )
        # Flag kept for parity with the reference implementation; forward does not read it.
        self.gradient_checkpointing = False

    def forward(
        self,
        x,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden, attention_mask)
        return hidden

In [335]:
# Build the text encoder stack and check that its depth matches the config.
clip_encoders = CLIPEncoder(text_config)
print(clip_encoders.layers[0])
print(len(clip_encoders.layers))
print(config.text_config.num_hidden_layers)

In [336]:
# The same CLIPEncoder class serves both towers; only the config passed in differs.
encoder = CLIPEncoder(config.vision_config)
print(encoder)
encoder = CLIPEncoder(config.text_config)
print(encoder)

# CLIP: Text Encoder Model

In [337]:
class CLIPTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings for the text tower.

    Reads ``vocab_size``, ``hidden_size`` and ``max_position_embeddings`` from
    ``config``. The output has shape ``input_ids.shape + (hidden_size,)``.
    """

    def __init__(self, config: "CLIPTextConfig"):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
        # Default positions 0..max_position_embeddings-1, shaped (1, max_positions).
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """Embed ``input_ids``.

        Args:
            input_ids: integer token ids, last dimension = sequence length.
            position_ids: optional explicit positions. They default to
                ``0..seq_len-1``.

        Raises:
            ValueError: if ``input_ids`` is missing, or if the sequence is longer
                than ``max_position_embeddings`` and no ``position_ids`` are given.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        seq_length = input_ids.shape[-1]
        if position_ids is None:
            max_positions = self.position_ids.shape[-1]
            if seq_length > max_positions:
                # Slicing would silently return fewer positions, and the addition
                # below would then fail with an opaque broadcast error.
                raise ValueError(
                    f"Sequence length {seq_length} exceeds max_position_embeddings={max_positions}."
                )
            position_ids = self.position_ids[:, :seq_length]

        inputs_embeds = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(position_ids)
        return inputs_embeds + position_embeddings

# Embed a random batch of token ids (randint's upper bound is exclusive).
clip_embedding = CLIPTextEmbeddings(config.text_config)
text_src = torch.randint(low=0, high=text_config.vocab_size - 1,
                         size=(batch_size, seq_len), dtype=torch.int)
text_embedding = clip_embedding(input_ids=text_src)
print(text_embedding.shape)

In [351]:
# The CLIP text encoder is trained with a causal attention mask, so the hidden state
# at the EOS token is the position that has attended to the whole prompt; that
# vector is used as the text embedding.
# BERT-style models instead use a dedicated [CLS] token for this purpose.

class CLIPTextTransformer(nn.Module):
    """Text tower of CLIP: embeddings, a causal Transformer encoder, and a final LayerNorm.

    ``forward`` returns a dict with two entries:
      * ``last_hidden_state``: every position, after the final LayerNorm.
      * ``pooler_output``: the hidden state at each sequence's first
        ``eos_token_id``, or at the last position if a sequence has no EOS token.
    """

    def __init__(self, config: "CLIPTextConfig"):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.eos_token_id = config.eos_token_id

    @staticmethod
    def _build_attention_mask(
        batch_size: int,
        seq_length: int,
        dtype: torch.dtype,
        device: torch.device,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return an additive (batch, 1, seq, seq) mask combining causal and padding masks.

        ``attention_mask`` may be:
          * 2-D (batch, seq), with 1 = real token and 0 = padding; or
          * 4-D additive, which is added on top of the causal mask.

        Blocked entries hold ``finfo(dtype).min`` rather than ``-inf``. The old
        ``triu(ones) * -inf`` produced NaN (0 * -inf) on every allowed position,
        and a fully masked row of -inf would also give NaN after softmax.
        """
        if attention_mask is not None and attention_mask.dim() not in (2, 4):
            raise ValueError(
                f"attention_mask must be 2-D (batch, seq) or 4-D additive; got {attention_mask.dim()}-D."
            )
        # True above the diagonal: token i may not attend to any later token j > i.
        blocked = torch.ones(seq_length, seq_length, dtype=torch.bool, device=device).triu(diagonal=1)
        blocked = blocked.expand(batch_size, 1, seq_length, seq_length)
        if attention_mask is not None and attention_mask.dim() == 2:
            is_padding = attention_mask.to(device=device) == 0
            blocked = blocked | is_padding[:, None, None, :]  # block padded keys

        mask = torch.zeros(blocked.shape, dtype=dtype, device=device)
        mask = mask.masked_fill(blocked, torch.finfo(dtype).min)
        if attention_mask is not None and attention_mask.dim() == 4:
            mask = mask + attention_mask.to(dtype=dtype, device=device)
        return mask

    @staticmethod
    def _pool_eos(
        last_hidden_state: torch.Tensor,
        input_ids: torch.Tensor,
        eos_token_id: int,
    ) -> torch.Tensor:
        """Pick, per sequence, the hidden state at the first EOS token (else the last position)."""
        is_eos = input_ids == eos_token_id
        first_eos = is_eos.int().argmax(dim=-1)
        last_position = torch.full_like(first_eos, input_ids.shape[-1] - 1)
        index = torch.where(is_eos.any(dim=-1), first_eos, last_position)
        rows = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
        return last_hidden_state[rows, index.to(last_hidden_state.device)]

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Encode ``input_ids``.

        ``attention_mask`` is an optional 2-D padding mask or a 4-D additive mask.
        ``position_ids`` is accepted for API compatibility but is not used.
        """
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.reshape(-1, input_shape[-1])
        if attention_mask is not None and attention_mask.dim() != 4:
            attention_mask = attention_mask.reshape(-1, input_shape[-1])

        x = self.embeddings(input_ids=input_ids)

        # Build the mask from the actual input length; the old code read the global `seq_len`.
        batch_size, seq_length = input_ids.shape
        mask = self._build_attention_mask(batch_size, seq_length, x.dtype, x.device, attention_mask)

        x = self.encoder(x, attention_mask=mask)

        # LayerNorm is applied per token, so normalising the whole sequence and then
        # pooling gives the same pooled vector as normalising only the pooled token.
        last_hidden_state = self.final_layer_norm(x)
        pooler_output = self._pool_eos(last_hidden_state, input_ids, self.eos_token_id)

        return {
            'last_hidden_state': last_hidden_state,
            'pooler_output': pooler_output,
        }
        

In [352]:
print(config.text_config.eos_token_id)  # EOS token id configured for the text tower

In [353]:
print(config.text_config.eos_token_id)

# Run the full text transformer on a random batch; it returns a dict of tensors.
clip_transformers = CLIPTextTransformer(config.text_config)
text_src = torch.randint(low=0, high=text_config.vocab_size - 1,
                         size=(batch_size, seq_len), dtype=torch.int)

result = clip_transformers(input_ids=text_src)
print(result['last_hidden_state'].shape)
print(result['pooler_output'].shape)

In [355]:
class CLIPTextModel(nn.Module):
    """Thin wrapper around CLIPTextTransformer that returns only its ``pooler_output``."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.text_model = CLIPTextTransformer(config)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        return outputs['pooler_output']
        
# Pooled text embedding for a random batch of token ids.
clip_text_model = CLIPTextModel(config.text_config)
text_src = torch.randint(low=0, high=text_config.vocab_size - 1,
                         size=(batch_size, seq_len), dtype=torch.int)

text_hidden = clip_text_model(input_ids=text_src)
print(text_hidden.shape)

## CLIP Vision Encoder

In [357]:
class CLIPVisionEmbeddings(nn.Module):
    """Turn images into a token sequence.

    The sequence is a learned class token followed by one token per
    non-overlapping ``patch_size`` x ``patch_size`` patch, plus learned absolute
    position embeddings.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned class-token vector, prepended to every image's patch sequence.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Kernel == stride == patch_size: one linear projection per patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.num_positions = self.num_patches + 1  # +1 for the class token
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        n_images = pixel_values.shape[0]
        weight_dtype = self.patch_embedding.weight.dtype
        patch_grid = self.patch_embedding(pixel_values.to(dtype=weight_dtype))  # [*, width, grid, grid]
        patch_tokens = patch_grid.flatten(2).transpose(1, 2)  # [*, grid*grid, width]

        cls_tokens = self.class_embedding.expand(n_images, 1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)
        return tokens + self.position_embedding(self.position_ids)


class CLIPVisionTransformer(nn.Module):
    """Vision tower of CLIP.

    Pipeline: patch embeddings -> pre-LayerNorm -> encoder. ``pooler_output`` is
    the first (class) token after ``post_layernorm``.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        # Attribute name (typo included) kept for state-dict compatibility.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        normed_tokens = self.pre_layrnorm(self.embeddings(pixel_values))
        last_hidden_state = self.encoder(normed_tokens)
        # The first token is the class token.
        cls_state = self.post_layernorm(last_hidden_state[:, 0, :])

        return {
            'last_hidden_state': last_hidden_state,
            'pooler_output': cls_state,
        }

class CLIPVisionModel(nn.Module):
    """Thin wrapper around CLIPVisionTransformer that returns only its ``pooler_output``."""
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.vision_model = CLIPVisionTransformer(config)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        outputs = self.vision_model(pixel_values=pixel_values)
        return outputs['pooler_output']


In [358]:
# Pooled image embedding for a random batch of images.
x_img = torch.randn(batch_size, num_channels, image_size, image_size)
clip_vision_model = CLIPVisionModel(vision_config)
x_vision_output = clip_vision_model(pixel_values=x_img)
print(x_vision_output.shape)

## CLIP Model

In [359]:
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """Cross-entropy in which row i's correct class is column i (matched pairs on the diagonal)."""
    targets = torch.arange(len(logits), device=logits.device)
    return nn.functional.cross_entropy(logits, targets)


def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Symmetric CLIP loss: the mean of the row-wise and column-wise contrastive losses."""
    row_loss = contrastive_loss(similarity)
    column_loss = contrastive_loss(similarity.t())
    return (row_loss + column_loss) / 2.0

In [378]:
class CLIPModel(nn.Module):
    """Dual-encoder CLIP model with a symmetric contrastive image-text loss.

    Each tower's pooled output is projected to ``config.projection_dim`` by a
    bias-free linear layer and then L2-normalised. Every text is scored against
    every image by cosine similarity times ``exp(logit_scale)``.
    """
    config_class = CLIPConfig

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        # Build through the wrapper classes, but keep only the inner transformers,
        # whose forward returns the full output dict.
        text_model = CLIPTextModel(text_config)
        self.text_model = text_model.text_model

        vision_model = CLIPVisionModel(vision_config)
        self.vision_model = vision_model.vision_model

        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        # Learnable log temperature; exp() is applied in forward.
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Return projected (not normalised) text embeddings, shape (batch, projection_dim)."""
        # The text model returns a dict. The old code passed undefined names
        # (output_attentions, ...) and indexed the dict with [1], which raised
        # NameError and KeyError respectively.
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        return self.text_projection(text_outputs['pooler_output'])

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Return projected (not normalised) image embeddings, shape (batch, projection_dim)."""
        vision_outputs = self.vision_model(pixel_values=pixel_values)
        return self.visual_projection(vision_outputs['pooler_output'])

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_loss: bool = True,
    ):
        """Score every text against every image.

        Returns a dict with these keys:
          * ``loss``: the symmetric contrastive loss, or None when ``return_loss``
            is False.
          * ``logits_per_image`` (n_img, n_txt) and ``logits_per_text`` (n_txt, n_img).
          * The normalised ``text_embeds`` and ``image_embeds``.

        The contrastive loss assumes text i and image i form a matching pair.
        """
        image_embeds = self.get_image_features(pixel_values=pixel_values)
        text_embeds = self.get_text_features(input_ids=input_ids, attention_mask=attention_mask)

        # L2-normalise so the dot products below are cosine similarities.
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * logit_scale.to(
            text_embeds.device
        )
        logits_per_image = logits_per_text.t()

        loss = clip_loss(logits_per_text) if return_loss else None

        return {
            'loss': loss,
            'logits_per_image': logits_per_image,
            'logits_per_text': logits_per_text,
            'text_embeds': text_embeds,
            'image_embeds': image_embeds,
        }

In [379]:
# Build the full dual-encoder model from the combined config.
clip_model = CLIPModel(config)
# print(clip_model)

In [381]:
# Run the full model on a fresh random batch of token ids and images.
text_src = torch.randint(low=0, high=text_config.vocab_size - 1,
                         size=(batch_size, seq_len), dtype=torch.int)
x_img = torch.randn(batch_size, num_channels, image_size, image_size)
clip_output = clip_model(input_ids=text_src, pixel_values=x_img)
print(clip_output['loss'])
print(clip_output['image_embeds'].shape)
print(clip_output['text_embeds'].shape)

# Pipeline

### CLIP Stage 1: Pre-training

In [382]:
# Larger random "pre-training" batch: text i is treated as the caption of image i.
batch_size = 8
text_src = torch.randint(low=0, high=text_config.vocab_size - 1,
                         size=(batch_size, seq_len), dtype=torch.int)
img_src = torch.randn(batch_size, num_channels, image_size, image_size)
clip_output = clip_model(input_ids=text_src, pixel_values=img_src)
print(clip_output['loss'])
print(clip_output['image_embeds'].shape)
print(clip_output['text_embeds'].shape)

#### Adapt the CLIP vision encoder into an image classifier

In [383]:
# CLIP image classifier: CLIP vision tower + a linear classification head.
class CLIPForImageClassification(nn.Module):
    """Image classifier built from the CLIP vision tower and a linear head.

    Reads ``config.num_labels`` and ``config.vision_config``. With
    ``num_labels <= 0`` the head is an identity, so the pooled features are
    returned unchanged as ``logits``.
    """
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPConfig) -> None:
        super().__init__()

        self.num_labels = config.num_labels
        self.vision_model = CLIPVisionModel(
            config.vision_config
        )

        # Classifier head
        self.classifier = (
            nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """Return ``{'loss', 'logits'}``.

        ``loss`` is the cross-entropy against the integer class ``labels``, or
        None when no labels are given.
        """
        # CLIPVisionModel returns the pooled class-token features; no patch averaging is done.
        pooled = self.vision_model(pixel_values)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return {
            'loss': loss,
            'logits': logits,
        }

In [384]:
# CLIPForImageClassification()

In [385]:
# Train-style forward pass of the image classifier on random images and labels.
batch_size = 8
config.num_labels = 10
img_src = torch.randn(batch_size, num_channels, image_size, image_size)
clip_vision_classifier = CLIPForImageClassification(config)

# One random class index in [0, num_labels) per image.
labels = torch.empty(batch_size, dtype=torch.long).random_(config.num_labels)

# The old cell printed `logits.shape` here, before any `logits` existed (NameError).
result = clip_vision_classifier(img_src, labels=labels)
print(result['logits'].shape)
print(result['loss'])

## CLIP Stage 2: Zero-shot prediction

In [386]:
# Zero-shot setup: one image and one prompt per class.
one_img_src = torch.randn(1, num_channels, image_size, image_size)

seq_len = 5
config.num_labels = 10

# A single random "template" prompt of seq_len tokens, repeated for every class.
text_prompt_src = torch.randint(low=0, high=text_config.vocab_size - 1,
                                size=(1, seq_len), dtype=torch.int)
text_prompt_src = text_prompt_src.expand(config.num_labels, -1)
print(text_prompt_src)

# One "class name" token per class, appended to the template (token id = class index).
class_object_token = torch.arange(config.num_labels).unsqueeze(-1)
print(class_object_token)

prompt = torch.cat((text_prompt_src, class_object_token), dim=1)
print(prompt)


In [388]:

# Shapes from the earlier batched clip_output run, for reference.
print(clip_output['image_embeds'].shape)
print(clip_output['text_embeds'].shape)

# Encode the single image and all class prompts with the trained towers.
vision_outputs = clip_model.vision_model(pixel_values=one_img_src)['pooler_output']
text_outputs = clip_model.text_model(input_ids=prompt)['pooler_output']

image_embeds = clip_model.visual_projection(vision_outputs)
text_embeds = clip_model.text_projection(text_outputs)
print(image_embeds.shape)
print(text_embeds.shape)

# L2-normalise so the matrix product below gives cosine similarities.
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

# (1, num_labels) similarity row; the argmax is the predicted class.
logits = image_embeds @ text_embeds.t()
print(logits)
torch.argmax(logits, dim=-1)

tensor([1])

## Pre-training data preparation