In [20]:
import torch
import torch.nn as nn
from transformers import SiglipVisionModel
from einops.layers.torch import Rearrange
from einops import rearrange,reduce

In [9]:
# Load the pretrained SiglipVisionModel checkpoint that the following cells inspect
# (parameter count, output shape, hidden size).
base_model = SiglipVisionModel.from_pretrained('google/siglip2-base-patch16-512')

In [11]:
# Total number of scalar parameters in the vision model (sum of numel over all parameter tensors).
sum(p.numel() for p in base_model.parameters())

93520128

In [13]:
# Feed one random (1, 3, 512, 512) input through the model to see the shape of
# last_hidden_state; the recorded result is (1, 1024, 768).
base_model(torch.rand(1,3,512,512)).last_hidden_state.shape

torch.Size([1, 1024, 768])

In [50]:
def pixel_shuffle_idefics(x, scale_factor=2):
    """Pixel-shuffle a token sequence using plain view/permute/reshape calls.

    Every ``scale_factor x scale_factor`` neighbourhood of the (assumed square)
    token grid is merged into one token whose channel axis is ``scale_factor**2``
    times wider. The shape of each intermediate tensor is printed, and the
    intermediates are also returned so another implementation can be compared
    against this one stage by stage.

    Args:
        x: Tensor of shape ``(bsz, seq, embed_dim)``. ``seq`` is treated as a
            square grid of side ``int(seq ** 0.5)``; the first ``view`` call
            requires a layout compatible with that split.
        scale_factor: Number of neighbouring positions folded together along
            each spatial axis.

    Returns:
        A ``(shuffled, stages)`` tuple. ``shuffled`` has shape
        ``(bsz, seq / scale_factor**2, embed_dim * scale_factor**2)`` and
        ``stages`` is the list of the six intermediate tensors, in order, the
        last of which is ``shuffled`` itself.
    """
    batch, n_tokens, channels = x.size()
    stages = []

    def _record(t):
        # Echo the shape and remember the tensor for the stage-by-stage comparison.
        print(t.shape)
        stages.append(t)
        return t

    print('idefics')
    side = int(n_tokens**0.5)

    # 1) Unflatten the token axis into a (side, side) grid.
    grid = _record(x.view(batch, side, side, channels))

    # 2) Fold groups of `scale_factor` adjacent columns into the channel axis.
    side_small = int(side / scale_factor)
    cols_folded = _record(grid.view(batch, side, side_small, channels * scale_factor))

    # 3) Swap the spatial axes so rows become the axis adjacent to the channels.
    swapped = _record(cols_folded.permute(0, 2, 1, 3))

    # 4) Fold groups of `scale_factor` adjacent rows into the channel axis as well.
    rows_folded = _record(
        swapped.reshape(batch, side_small, side_small, channels * (scale_factor**2))
    )

    # 5) Undo the spatial swap from step 3.
    restored = _record(rows_folded.permute(0, 2, 1, 3))

    # 6) Flatten the reduced grid back into a single token axis.
    shuffled = _record(
        restored.reshape(
            batch, int(n_tokens / (scale_factor**2)), channels * (scale_factor**2)
        )
    )
    return shuffled, stages

In [51]:
# Smoke-test on a random (1, 1024, 768) tensor; [0] picks the shuffled tensor out of
# the (tensor, stages) tuple so only its shape is displayed.
pixel_shuffle_idefics(torch.rand(1,1024,768))[0].shape

idefics
torch.Size([1, 32, 32, 768])
torch.Size([1, 32, 16, 1536])
torch.Size([1, 16, 32, 1536])
torch.Size([1, 16, 16, 3072])
torch.Size([1, 16, 16, 3072])
torch.Size([1, 256, 3072])


torch.Size([1, 256, 3072])

In [18]:
# Random tensor with the same shape as the recorded last_hidden_state, (1, 1024, 768),
# used as input for the einops experiments.
x=torch.rand(1,1024,768)

In [26]:
# Step 1 in isolation: unflatten the token axis of `x` into a square (h, w) grid.
# The side length is read from the tensor instead of a literal token count, and
# each einops axis is bound to its own variable so h and w cannot be mixed up.
h = w = int(x.shape[1] ** 0.5)
rearrange(x, 'b (h w) d -> b h w d', h=h, w=w).shape

torch.Size([1, 32, 32, 768])

In [27]:
# Step 2 in isolation: fold each pair of adjacent width positions into the channel axis.
# The input is a fresh random tensor, so this cell only verifies the output shape;
# `w` comes from an earlier cell.
x2=torch.rand(1, 32, 32, 768)
rearrange(x2,'b h (w_s s) d -> b h w_s (s d)',w_s=w//2,s=2).shape

torch.Size([1, 32, 16, 1536])

In [31]:
# Step 3 in isolation: swap the two spatial axes, then fold each pair of adjacent
# height positions into the channel axis. Shape check only (fresh random input);
# `h` comes from an earlier cell.
x3=torch.rand(1, 32, 16, 1536).transpose(1,2)
rearrange(x3,'b w_s (h_s s) d -> b w_s h_s (s d)',h_s=h//2,s=2).shape

torch.Size([1, 16, 16, 3072])

In [35]:
# Step 4 in isolation: swap the spatial axes back and flatten them into one token axis.
# Shape check only (fresh random input); `h` and `w` come from an earlier cell.
x4=torch.rand(1, 16, 16, 3072).transpose(1,2)
rearrange(x4,'b h_s w_s d -> b (h_s w_s) d',h_s=h//2,w_s=w//2).shape

torch.Size([1, 256, 3072])

In [61]:
def pixel_shuffle_einops(x, scale_factor=2):
    """Pixel-shuffle a token sequence with einops, mirroring ``pixel_shuffle_idefics``.

    The steps, printed shapes and returned intermediates line up one-to-one with
    the view/permute reference so the two can be compared stage by stage.

    Args:
        x: Tensor of shape ``(b, seq, d)`` where ``seq`` is a perfect square.
        scale_factor: Integer number of neighbouring positions folded together
            along each spatial axis. Must divide the grid side length.

    Returns:
        A ``(shuffled, stages)`` tuple. ``shuffled`` has shape
        ``(b, seq // scale_factor**2, d * scale_factor**2)`` and ``stages`` is
        the list of the six intermediate tensors (the last one is ``shuffled``).

    Raises:
        ValueError: If ``seq`` is not a perfect square or the grid side is not
            divisible by ``scale_factor``.
    """
    n_tokens = x.shape[1]
    # Axis lengths handed to einops must be ints; a bare `** 0.5` yields a float.
    side = int(round(n_tokens ** 0.5))
    if side * side != n_tokens:
        raise ValueError(
            f"sequence length {n_tokens} is not a perfect square; cannot form a square grid"
        )
    s = scale_factor
    if side % s != 0:
        raise ValueError(
            f"grid side {side} is not divisible by scale_factor {s}"
        )
    h = w = side

    i = []
    print('einops')

    # 1) Unflatten the token axis into an (h, w) grid.
    x = rearrange(x, 'b (h w) d -> b h w d', h=h, w=w)
    print(x.shape)
    i.append(x)

    # 2) Fold groups of `s` adjacent columns into the channel axis.
    x = rearrange(x, 'b h (w_s s) d -> b h w_s (s d)', w_s=w // s, s=s)
    print(x.shape)
    i.append(x)

    # 3) Swap the spatial axes: b w_s h (s d).
    x = x.transpose(1, 2)
    print(x.shape)
    i.append(x)

    # 4) Fold groups of `s` adjacent rows into the channel axis as well.
    x = rearrange(x, 'b w_s (h_s s) d -> b w_s h_s (s d)', h_s=h // s, s=s)
    print(x.shape)
    i.append(x)

    # 5) Swap the spatial axes back: b h_s w_s (s s d).
    x = x.transpose(1, 2)
    print(x.shape)
    i.append(x)

    # 6) Flatten the reduced grid into a single token axis.
    x = x.flatten(1, 2)
    print(x.shape)
    i.append(x)
    return x, i

In [62]:
# Run both implementations on identical copies of the same random input and compare
# them: first the final outputs, then every intermediate stage pairwise.
x=torch.rand(1,1024,768)
x1,i1 = pixel_shuffle_idefics(x.clone())
x2,i2 = pixel_shuffle_einops(x.clone())
print(torch.allclose(x1,x2))
print('----')
# zip stops at the shorter list; both functions append six stages.
for a,b in zip(i1,i2):
    print(a.shape, b.shape, torch.allclose(a,b))

idefics
torch.Size([1, 32, 32, 768])
torch.Size([1, 32, 16, 1536])
torch.Size([1, 16, 32, 1536])
torch.Size([1, 16, 16, 3072])
torch.Size([1, 16, 16, 3072])
torch.Size([1, 256, 3072])
einops
torch.Size([1, 32, 32, 768])
torch.Size([1, 32, 16, 1536])
torch.Size([1, 16, 32, 1536])
torch.Size([1, 16, 16, 3072])
torch.Size([1, 16, 16, 3072])
torch.Size([1, 256, 3072])
True
----
torch.Size([1, 32, 32, 768]) torch.Size([1, 32, 32, 768]) True
torch.Size([1, 32, 16, 1536]) torch.Size([1, 32, 16, 1536]) True
torch.Size([1, 16, 32, 1536]) torch.Size([1, 16, 32, 1536]) True
torch.Size([1, 16, 16, 3072]) torch.Size([1, 16, 16, 3072]) True
torch.Size([1, 16, 16, 3072]) torch.Size([1, 16, 16, 3072]) True
torch.Size([1, 256, 3072]) torch.Size([1, 256, 3072]) True


In [64]:
# Hidden size reported by the vision model's config; the recorded value is 768.
base_model.config.hidden_size

768

In [76]:
class VisionModel(nn.Module):
    """Vision encoder followed by a pixel shuffle that trades tokens for channels.

    The wrapped ``SiglipVisionModel`` produces ``(b, seq, hidden)`` features; the
    pixel shuffle merges every ``scale_factor x scale_factor`` neighbourhood of
    the square token grid into one token, giving
    ``(b, seq // scale_factor**2, hidden * scale_factor**2)``.

    Args:
        model_name: Checkpoint identifier passed to
            ``SiglipVisionModel.from_pretrained``.
        scale_factor: Integer number of neighbouring positions folded together
            along each spatial axis.

    Attributes:
        vision: The pretrained vision model.
        scale_factor: The pixel-shuffle factor.
        shuffled_dim: Channel width after the shuffle
            (``hidden_size * scale_factor**2``).
    """

    def __init__(self, model_name='google/siglip2-base-patch16-512', scale_factor=2):
        super().__init__()
        self.vision = SiglipVisionModel.from_pretrained(model_name)
        self.scale_factor = scale_factor
        self.shuffled_dim = self.vision.config.hidden_size * scale_factor ** 2

    def pixel_shuffle(self, x):
        """Merge each ``s x s`` block of the token grid into a single wider token.

        Args:
            x: Tensor of shape ``(b, seq, d)`` where ``seq`` is a perfect square.

        Returns:
            Tensor of shape ``(b, seq // s**2, d * s**2)``. Within each output
            token the channels are ordered (row offset, column offset, d).

        Raises:
            ValueError: If ``seq`` is not a perfect square or the grid side is
                not divisible by the scale factor.
        """
        s = self.scale_factor
        n_tokens = x.shape[1]
        # Axis lengths handed to einops must be ints; a bare `** 0.5` yields a float.
        side = int(round(n_tokens ** 0.5))
        if side * side != n_tokens:
            raise ValueError(
                f"sequence length {n_tokens} is not a perfect square; cannot form a square grid"
            )
        if side % s != 0:
            raise ValueError(
                f"grid side {side} is not divisible by scale_factor {s}"
            )
        # One rearrange equivalent to: unflatten -> fold columns -> transpose ->
        # fold rows -> transpose -> flatten. The token index decomposes as
        # (h_s, s_h, w_s, s_w); the two in-block offsets move into the channels.
        return rearrange(
            x,
            'b (h_s s_h w_s s_w) d -> b (h_s w_s) (s_h s_w d)',
            h_s=side // s,
            w_s=side // s,
            s_h=s,
            s_w=s,
        )

    def forward(self, x):
        """Encode ``x`` with the vision model and pixel-shuffle its last hidden state.

        Args:
            x: Input forwarded unchanged to the vision model; the notebook feeds
                a ``(1, 3, 512, 512)`` tensor.

        Returns:
            The shuffled features; ``(1, 256, 3072)`` for the input above.
        """
        x = self.vision(x).last_hidden_state
        x = self.pixel_shuffle(x)
        return x

In [78]:
# Instantiate the wrapper and check the output shape for one random (1, 3, 512, 512)
# input; the recorded result is (1, 256, 3072).
m = VisionModel()
m(torch.rand(1,3,512,512)).shape

torch.Size([1, 256, 3072])

In [75]:
class VisionProjector(nn.Module):
    """Bias-free linear map from pixel-shuffled vision features to width ``dim``.

    The input width is ``vision_hidden_size * scale_factor**2`` because a pixel
    shuffle with factor ``scale_factor`` concatenates that many tokens along the
    channel axis. With the default factor of 2 this is ``vision_hidden_size * 4``.

    Args:
        vision_hidden_size: Channel width of the vision features before the
            pixel shuffle.
        dim: Output width of the projection.
        scale_factor: Pixel-shuffle factor applied upstream of this module.

    Attributes:
        vision_hidden_size: As passed in.
        dim: As passed in.
        scale_factor: As passed in.
        proj: The ``nn.Linear`` layer (no bias).
    """

    def __init__(self, vision_hidden_size, dim, scale_factor=2):
        super().__init__()
        self.vision_hidden_size = vision_hidden_size
        self.dim = dim
        self.scale_factor = scale_factor
        self.proj = nn.Linear(
            self.vision_hidden_size * scale_factor ** 2, self.dim, bias=False
        )

    def forward(self, x):
        """Project the last axis of ``x`` from the shuffled width to ``dim``."""
        return self.proj(x)

In [None]:
class Blinky(nn.Module):
    """Container that pairs the vision encoder with its projector.

    Args:
        dim: Output width of the vision projector. Defaults to 576.
            NOTE(review): presumably the hidden size of the downstream model
            this feeds into; confirm against the consumer.

    Attributes:
        vision_encoder: A ``VisionModel`` instance.
        vision_projector: A ``VisionProjector`` sized from the encoder's
            ``hidden_size`` and ``dim``.
    """

    def __init__(self, dim=576):
        super().__init__()
        self.vision_encoder = VisionModel()
        # The projector widens the input itself to account for the pixel shuffle,
        # so it receives the encoder's un-shuffled hidden size.
        self.vision_projector = VisionProjector(
            self.vision_encoder.vision.config.hidden_size,
            dim
        )
        