<a href="https://colab.research.google.com/github/samitha278/VLM-gamma/blob/main/repro_siglip.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## SigLip

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

### SigLip Config

In [10]:
class SigLipConfig:
    """Hyper-parameters for the SigLip vision transformer defined in this file.

    Args:
        n_embd: Embedding width; used as the patch-projection output channels,
            the position-embedding dimension and the LayerNorm size.
        n_hidden: Stored as ``n_hidden`` (presumably the MLP hidden width --
            the consumer is not shown in this file section).
        n_layer: Number of encoder blocks stacked in the ViT.
        n_head: Stored as ``n_head`` (presumably the number of attention
            heads -- the attention module is not shown here).
        n_channel: Number of input image channels.
        image_size: Side length of the input image in pixels.
        patch_size: Side length of one patch in pixels.
        ln_eps: Epsilon passed to LayerNorm.
        **kwargs: Accepted and ignored, so callers may pass extra keys.

    Attributes:
        n_patch: ``(image_size // patch_size) ** 2``, the number of patches
            per image.
    """

    def __init__(self, n_embd=768, n_hidden=3072, n_layer=12, n_head=12,
                 n_channel=3, image_size=224, patch_size=16, ln_eps=1e-6,
                 **kwargs):
        # model params
        self.n_embd = n_embd
        self.n_hidden = n_hidden
        self.n_layer = n_layer
        # Previously accepted but never stored, so config.n_head did not exist.
        self.n_head = n_head
        self.n_channel = n_channel

        # data params
        self.image_size = image_size
        self.patch_size = patch_size
        # Floor division: any remainder pixels along an edge form no patch.
        self.n_patch = (image_size // patch_size) ** 2

        self.ln_eps = ln_eps


### SigLip ViT Embeddings

In [9]:
class Embeddings(nn.Module):
    """Convert an image batch into patch embeddings plus learned positions.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, each patch is projected to ``n_embd``
    channels, and a learned per-position embedding is added.

    Args:
        config: Object exposing ``n_channel``, ``n_embd``, ``patch_size`` and
            ``n_patch``.
    """

    def __init__(self, config: "SigLipConfig"):
        # `super` must be called; the bare `super.__init__()` raised TypeError.
        super().__init__()

        self.config = config

        # kernel_size == stride, so every output location sees exactly one patch.
        self.patch_embds = nn.Conv2d(
            in_channels=config.n_channel,
            out_channels=config.n_embd,
            kernel_size=config.patch_size,
            stride=config.patch_size,
        )

        self.pos_embds = nn.Embedding(config.n_patch, embedding_dim=config.n_embd)

        # Position indices live in a buffer so they follow the module across
        # devices; non-persistent keeps them out of the state_dict.
        self.register_buffer(
            "pos_ids", torch.arange(config.n_patch), persistent=False
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``[B, C, H, W]``. The number of patches produced
                must equal ``config.n_patch`` for the positional add to work.

        Returns:
            Tensor of shape ``[B, n_patch, n_embd]``.
        """
        patch_embds = self.patch_embds(x)             # [B, n_embd, H/P, W/P]
        # flatten/transpose are out-of-place: the results must be re-assigned.
        patch_embds = patch_embds.flatten(2)          # [B, n_embd, n_patch]
        patch_embds = patch_embds.transpose(-1, -2)   # [B, n_patch, n_embd]

        pos_embds = self.pos_embds(self.pos_ids)      # [n_patch, n_embd]

        # Broadcasts the position table over the batch dimension.
        return patch_embds + pos_embds



### SigLip ViT Encoder Block

In [12]:
class Block(nn.Module):
    """One pre-norm transformer encoder block: attention then MLP.

    Each sub-layer is applied to a LayerNorm of the current residual stream
    and its output is added back onto that stream.

    Args:
        config: Object exposing ``n_embd`` and ``ln_eps``; it is also passed
            through to ``Attention`` and ``MLP``.
    """

    def __init__(self, config: "SigLipConfig"):
        super().__init__()

        self.config = config

        # Use the configured epsilon, matching the final LayerNorm in SigLipViT.
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.ln_eps)
        self.attn = Attention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.ln_eps)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape ``[B, n_patch, n_embd]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        out = x + self.attn(self.ln_1(x))
        # The MLP branch must read the post-attention stream (`out`), not the
        # block input `x`; otherwise the MLP never sees the attention output.
        out = out + self.mlp(self.ln_2(out))

        return out

### SigLip ViT

In [13]:
class SigLipViT(nn.Module):
    """Vision transformer: patch embeddings, encoder blocks, final LayerNorm.

    Args:
        config: Object exposing ``n_layer``, ``n_embd`` and ``ln_eps``; it is
            also passed through to ``Embeddings`` and ``Block``.
    """

    def __init__(self, config: "SigLipConfig"):
        super().__init__()

        self.config = config

        self.embeddings = Embeddings(config)
        # n_layer is an int, so it must go through range() to be iterated.
        self.encoder = nn.ModuleList(
            [Block(config) for _ in range(config.n_layer)]
        )
        self.ln = nn.LayerNorm(config.n_embd, eps=config.ln_eps)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape ``[B, C, H, W]``.

        Returns:
            The output of the final LayerNorm applied to the last block's
            output.
        """
        out = self.embeddings(x)

        # Blocks are applied in order, each consuming the previous output.
        for block in self.encoder:
            out = block(out)

        return self.ln(out)