<a href="https://colab.research.google.com/github/samitha278/VLM-gamma/blob/main/repro_siglip.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## SigLip

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

### SigLip Config

In [23]:
class SigLipConfig:

    def __init__(self,n_embd=768,n_hidden=3072,n_layer=12,n_head=12,n_channel=3,image_size=224,patch_size=16,ln_eps=1e-6,**kwargs):
        super().__init__()

        #model params
        self.n_embd= n_embd
        self.n_hidden = n_hidden
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_channel = n_channel

        #data params
        self.image_size = image_size
        self.patch_size = patch_size
        self.n_patch = (image_size // patch_size)**2

        self.ln_eps = ln_eps


### SigLip ViT Embeddings

In [15]:
class Embeddings(nn.Module):

    def __init__(self,config: SigLipConfig):
        super().__init__()

        self.config = config

        self.patch_embds = nn.Conv2d(in_channels= config.n_channel, out_channels= config.n_embd, kernel_size=config.patch_size, stride= config.patch_size)

        self.pos_embds = nn.Embedding(config.n_patch,embedding_dim=config.n_embd)
        self.register_buffer('pos_ids',torch.arange(0,config.n_patch, persistent=False)) # persistent =False ; don't save in checkpoint

    def forward(self,x):

        # x : [B,C,H,W]

        patch_embds = self.patch_embds(x)    # [B, n_embd, n_patch**0.5, n_patch**0.5]
        patch_embds.flatten(2)               # [B, n_embd, n_patch]
        patch_embds.transpose(-1,-2)         # [B, n_patch, n_embd]

        pos_embds = self.pos_embds(self.pos_ids)

        embds = patch_embds + pos_embds

        return embds



### SigLip ViT Attention

In [24]:
class Attention(nn.Module):

    def __init__(self, config : SigLipConfig):
        super().__init__()

        self.config = config

        self.W = nn.Linear(config.n_embd,config.n_embd*3)

        self.proj = nn.Linear(config.n_embd,config.n_embd)


    def forward(self,x):

        # x : [B,n_patch,n_embd]

        qkv = self.W(x)    # [B,n_patch,n_embd*3]

        print(qkv.shape)

        query,key,value = torch.chunk(qkv,3,-1)  # each : [B,n_patch,n_embd]

        print(query.shape)

        weights = query @ key.transpose(-1,-2)          # [B,n_patch,n_patch]
        weights = weights / (self.config.n_embd//self.config.n_head)**0.5

        print(weights.shape)
        out = weights @ value        # [B,n_patch,n_embd]

        print(out.shape)

        out = self.proj(out)
        return out

In [25]:
attn = Attention(SigLipConfig)
attn(torch.randn(4,196,768))



torch.Size([4, 196, 2304])
torch.Size([4, 196, 768])
torch.Size([4, 196, 196])
torch.Size([4, 196, 768])


tensor([[[  3.2310,   5.0710,   9.6796,  ...,  -1.9782, -10.1048,  -5.9263],
         [ -3.9312,   2.2673,   6.8564,  ...,   0.7830,  -2.3773, -10.9537],
         [ -4.5310,   4.0279,  -7.3734,  ...,   1.6300,   3.4329,  -1.9785],
         ...,
         [ -6.9203,   1.8917,  -7.3212,  ...,  -6.4091,   5.6249,  -4.9347],
         [  5.9090,   5.7038,  14.3785,  ...,  -9.4454,  -2.1515,  -5.6580],
         [  1.4164,   3.7150,  -5.0036,  ...,   6.6115,   1.0287,   5.8579]],

        [[ -0.9498,   3.0343,   4.6050,  ...,  -5.0981,  -5.2234,   0.8759],
         [ -3.0259,   3.2398,   6.1334,  ...,  -5.7936,   2.8396,  -3.7130],
         [ 11.8142,  -8.2874,   1.0808,  ...,  -3.5087,  -0.4313,  -2.9563],
         ...,
         [  9.2087,  -9.4692,  -4.8015,  ...,  -4.0036,   7.7564,  -5.1409],
         [ -2.2968,   4.0281,  -5.6268,  ...,  -4.6646,   3.4678,  -4.9857],
         [  7.6221,  -9.5990,  -1.7200,  ...,   6.4135,   0.0152,   0.3916]],

        [[ -5.3154, -12.2872,   2.5653,  ...

### SigLip ViT Encoder Block

In [12]:
class Block(nn.Module):

    def __init__(self,config:SigLipConfig):
        super().__init__()

        self.config = config

        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = Attention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)


    def forward(self,x):

        # x : [B,n_patch,n_embd]

        out = x + self.attn(self.ln_1(x))
        out = out + self.mlp(self.ln_2(x))

        return out

### SigLip ViT

In [13]:
class SigLipViT(nn.Module):

    def __init__(self, config: SigLipConfig):
        super().__init__()

        self.config = config

        self.embeddings = Embeddings(config)
        self.encoder = nn.ModuleList([Block(config) for _ in config.n_layer])
        self.ln = nn.LayerNorm(config.n_embd,eps=config.ln_eps)


    def forward(self,x):

        # x : [B,C,H,W]

        embds = self.embeddings(x)

        out = embds
        for block in self.encoder:
            out = block(out)

        out = self.ln(out)

        return out