In [None]:
import torch
import torch.nn as nn
from transformers import ViTConfig, ViTModel

In [None]:
class CLIPImageEncoder(nn.Module):
    """CLIP-style image tower: a ViT backbone followed by a projection head.

    Images are encoded by a randomly initialised Hugging Face ``ViTModel``.
    Its ``pooler_output`` is passed through ``LayerNorm -> Linear``, and the
    result is L2-normalised so every (non-zero) row is a unit vector.

    Args:
        image_size: Input image resolution, forwarded to ``ViTConfig``.
        patch_size: Side length of each square ViT patch.
        hidden_size: Transformer width; must be divisible by
            ``num_attention_heads``.
        num_hidden_layers: Number of transformer encoder layers.
        num_attention_heads: Number of attention heads per layer.
        projection_dim: Size of the output embedding.
        intermediate_size: Width of the transformer feed-forward sublayer.
            ``None`` (default) leaves it to ``ViTConfig``'s own default, which
            preserves the original behaviour. CLIP conventionally uses
            ``4 * hidden_size``.
    """

    def __init__(
        self,
        image_size=224,
        patch_size=32,
        hidden_size=512,
        num_hidden_layers=12,
        num_attention_heads=8,
        projection_dim=512,
        intermediate_size=None,
    ):
        super().__init__()

        # Configure ViT
        config_kwargs = dict(
            image_size=image_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_channels=3,
            qkv_bias=True,
            layer_norm_eps=1e-6,
        )
        # Only override the MLP width when asked, so the default model is
        # exactly what it was before this parameter existed.
        if intermediate_size is not None:
            config_kwargs["intermediate_size"] = intermediate_size
        self.config = ViTConfig(**config_kwargs)

        # Initialize ViT backbone (random weights; nothing is loaded here)
        self.vit = ViTModel(self.config)

        # Projection head: hidden_size -> projection_dim
        self.projection = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, projection_dim),
        )

    def forward(self, pixel_values):
        """Encode a batch of images into L2-normalised embeddings.

        Args:
            pixel_values: Image tensor passed straight to ``ViTModel``;
                presumably ``(batch, 3, image_size, image_size)`` given
                ``num_channels=3`` in the config — TODO confirm with callers.

        Returns:
            Tensor of shape ``(batch, projection_dim)`` whose rows have unit
            L2 norm. A row whose projection is exactly zero stays zero instead
            of becoming NaN.
        """
        outputs = self.vit(pixel_values=pixel_values)
        pooled_output = outputs.pooler_output

        projected = self.projection(pooled_output)

        # normalize() divides by max(norm, eps), so a zero vector yields zeros
        # rather than 0/0 = NaN; otherwise identical to x / x.norm().
        return nn.functional.normalize(projected, p=2, dim=-1)

In [None]:
# Smoke test: push a random batch through a freshly initialised encoder and
# check the output contract (shape and unit-norm rows).
torch.manual_seed(0)  # reproducible weights and inputs on Restart & Run All

test_model = CLIPImageEncoder()
test_model.eval()  # inference mode for dropout-style layers

batch_size = 32
# Take the input shape from the model's own config instead of hard-coding it.
dummy_input = torch.randn(
    batch_size,
    test_model.config.num_channels,
    test_model.config.image_size,
    test_model.config.image_size,
)

# We only inspect the output, so skip building the autograd graph.
with torch.no_grad():
    output = test_model(dummy_input)

expected_dim = test_model.projection[-1].out_features
assert output.shape == (batch_size, expected_dim)
assert torch.allclose(output.norm(dim=-1), torch.ones(batch_size), atol=1e-5)

output.size()