In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from vision_model import VisionModel, VisionProjector
from text_model import LLaMA
from types import SimpleNamespace
from transformers import SiglipVisionModel, SiglipVisionConfig
import gc

In [2]:
class Blinky(nn.Module):
    """Small vision-language model: SigLIP vision tower + LLaMA text decoder.

    Pipeline (see ``forward``): SigLIP patch features -> ``pixel_shuffle`` (2x2 patch
    merge) -> linear projection to ``config.embed_dim`` -> prepended to the text token
    embeddings -> LLaMA decoder layers -> final norm -> ``lm_head``.

    Args:
        config: namespace with at least ``vision_model_hf``, ``text_model_hf``,
            ``dtype`` and ``embed_dim`` (plus whatever ``LLaMA`` reads). The
            constructor also writes ``config.vision_config``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        vision_config = SiglipVisionConfig.from_pretrained(self.config.vision_model_hf)
        # Expose the HF vision config as a plain namespace on our own config.
        self.config.vision_config = SimpleNamespace(**vision_config.to_dict())

        # Constructed from the config alone; pretrained weights are copied in by
        # prepare_for_training().
        self.vision = SiglipVisionModel(vision_config).to(dtype=self.config.dtype)
        # pixel_shuffle's default scale_factor=2 folds 2x2 patches into one token,
        # so the projector input is 4x the vision hidden size.
        self.vision_proj = nn.Linear(
            self.config.vision_config.hidden_size * 4,
            self.config.embed_dim,
            bias=False,
            dtype=self.config.dtype,
        )
        self.text_model = LLaMA(self.config)

    def pixel_shuffle(self, x, scale_factor=2):
        """Fold each ``scale_factor x scale_factor`` neighbourhood of patch tokens into one token.

        Args:
            x: tensor of shape ``(bsz, seq, embed_dim)`` whose ``seq`` tokens form a
                row-major ``sqrt(seq) x sqrt(seq)`` grid.
            scale_factor: side length of the merged neighbourhood.

        Returns:
            Tensor of shape ``(bsz, seq // scale_factor**2, embed_dim * scale_factor**2)``.

        Raises:
            ValueError: if ``seq`` is not a perfect square or the grid side is not
                divisible by ``scale_factor``.
        """
        bsz, seq, embed_dim = x.size()
        side = int(round(seq ** 0.5))
        if side * side != seq or side % scale_factor != 0:
            raise ValueError(
                f"pixel_shuffle needs a square patch grid whose side is divisible by "
                f"{scale_factor}, got seq={seq}"
            )
        height = width = side
        x = x.view(bsz, height, width, embed_dim)
        # Merge scale_factor horizontally adjacent patches into the channel dim.
        x = x.view(bsz, height, width // scale_factor, embed_dim * scale_factor)
        x = x.permute(0, 2, 1, 3)
        # Then merge scale_factor vertically adjacent rows the same way.
        x = x.reshape(bsz, width // scale_factor, height // scale_factor, embed_dim * (scale_factor ** 2))
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(bsz, seq // (scale_factor ** 2), embed_dim * (scale_factor ** 2))
        return x

    def prepare_for_training(self):
        """Copy pretrained weights into ``self.vision`` and ``self.text_model``.

        Loads ``config.vision_model_hf`` (SigLIP) and ``config.text_model_hf`` (causal
        LM) in ``config.dtype``, copies their state dicts in, checks one tensor of each,
        then drops the temporary HF models.
        """
        from transformers import AutoModelForCausalLM

        # Load in this model's configured dtype (read from self.config, not a global).
        vision = SiglipVisionModel.from_pretrained(self.config.vision_model_hf, torch_dtype=self.config.dtype)
        self.vision.load_state_dict(vision.state_dict())

        assert torch.allclose(
            vision.vision_model.embeddings.position_embedding.weight,
            self.vision.vision_model.embeddings.position_embedding.weight,
        ), 'couldnt load vision model'

        smol = AutoModelForCausalLM.from_pretrained(self.config.text_model_hf, torch_dtype=self.config.dtype)
        model_sd = self.text_model.state_dict()
        hf_prefix = 'model.'
        for smol_key, smol_value in smol.state_dict().items():
            # Keys mentioning 'rope' / 'causal_mask' are skipped; presumably non-learned
            # buffers that LLaMA creates itself — TODO confirm against text_model.
            if any(s in smol_key for s in ('rope', 'causal_mask')):
                continue
            # The HF model nests decoder weights under "model."; our LLaMA keeps them at
            # the top level. Only strip the leading prefix, not every occurrence.
            model_key = smol_key[len(hf_prefix):] if smol_key.startswith(hf_prefix) else smol_key
            model_sd[model_key] = smol_value.clone()

        self.text_model.load_state_dict(model_sd)

        assert torch.allclose(smol.lm_head.weight, self.text_model.lm_head.weight), 'couldnt load text model'

        del smol, vision
        gc.collect()

    def forward_image_features(self, pixel_values):
        """Encode images into projected tokens: SigLIP -> pixel_shuffle -> vision_proj."""
        x = self.vision(pixel_values).last_hidden_state
        x = self.pixel_shuffle(x)
        x = self.vision_proj(x)
        return x

    def _vision_trainable(self, trainable=False):
        """Set ``requires_grad=trainable`` on every vision-tower parameter."""
        for p in self.vision.parameters():
            p.requires_grad = trainable

    def _text_trainable(self, trainable=False):
        """Set ``requires_grad=trainable`` on text-model parameters.

        Parameters whose name contains ``embed_tokens`` or ``lm_head`` are always frozen.
        """
        for n, p in self.text_model.named_parameters():
            if 'embed_tokens' in n or 'lm_head' in n:
                p.requires_grad = False
            else:
                p.requires_grad = trainable

    def forward(self, input_ids, pixel_values=None, attention_mask=None, labels=None):
        """Run the decoder over ``[image tokens; text tokens]``.

        Args:
            input_ids: text token ids, presumably ``(bsz, seq)``.
            pixel_values: optional image batch for the vision tower; when given, the
                projected image tokens are prepended to the text embeddings.
            attention_mask: optional ``(bsz, seq)`` mask over the text tokens. When
                images are present it is extended with ones for the image tokens.
                It is passed as-is (possibly ``None``) to every decoder layer.
            labels: optional targets. When given, the mean cross-entropy over the
                flattened logits is returned. No shift is applied here, so labels are
                presumably pre-aligned with the image+text positions — TODO confirm.

        Returns:
            Scalar loss if ``labels`` is given, else the ``lm_head`` logits.
        """
        x = self.text_model.embed_tokens(input_ids)

        if pixel_values is not None:
            image_tokens = self.forward_image_features(pixel_values)
            # Text embeddings are detached: no gradient flows back into the embedding
            # table through this path.
            x = torch.cat([image_tokens, x.detach()], dim=1)
            if attention_mask is not None:
                # Image positions are never padding. Size the block from the actual
                # image tokens, and match the mask's dtype/device.
                image_mask = torch.ones(
                    attention_mask.size(0),
                    image_tokens.size(1),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat([image_mask, attention_mask], dim=1)

        for layer in self.text_model.layers:
            x = layer(x, attention_mask)

        x = self.text_model.norm(x)
        logits = self.text_model.lm_head(x)

        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
            return loss

        return logits

In [3]:
# config = SimpleNamespace(
#     embed_dim = 576,
#     intermediate_dim = 1536,
#     max_position_embeddings = 8192,
#     base_theta = 100000,
#     num_q_heads = 9,
#     num_kv_heads = 3,
#     attn_dropout = 0.,
#     num_layers = 30,
#     vocab_size = 49152,
#     eos_token_id = 2,
#     dtype = torch.bfloat16,
#     num_image_tokens = 256,
#     vision_model_hf = 'google/siglip2-base-patch16-512',
#     text_model_hf = 'HuggingFaceTB/SmolLM2-135M-Instruct'
# )

In [4]:
# model = Blinky(config)
# model.prepare_for_training()

In [5]:
# inputs = {
#     'input_ids': torch.randint(0,config.vocab_size,(1,120)),
#     'pixel_values': torch.rand(1,3,512,512),
#     'attention_mask': torch.ones(1,120).bool()
#     # 'labels': torch.randint(0,config.vocab_size,(1,120+256))
# }

In [6]:
# outputs = model(**inputs)

In [7]:
# outputs

In [8]:
# from preprocessor import BlinkyProcessor
# from PIL import Image

In [9]:
# sample = [{
#     'text': [{'role':'user','content':'hey!'}],
#     'image': Image.open('./tests/car.jpg')
# },
#           {
#     'text': [{'role':'user','content':'this is a test of padding :)'}],
#     'image': Image.open('./tests/car.jpg')
# }
#          ]

In [10]:
# make_tokenizer()

In [11]:
# processor = BlinkyProcessor('./Blinky')

In [12]:
# inputs = processor(sample)

In [13]:
# inputs.keys()

In [14]:
# inputs['input_ids']

In [15]:
# inputs['attention_mask'].long()

In [16]:
# print(processor.tokenizer.decode(inputs['input_ids'].flatten().numpy(),skip_special_tokens=False))

In [17]:
import torch

In [18]:
x = torch.as_tensor([[1, 2, 3, 4, 5, 3], [2, 3, 4, 5, 3, 6]])  # toy batch; 3 plays the separator

In [19]:
torch.nonzero(x == 3, as_tuple=True)  # (row indices, column indices) of every 3

(tensor([0, 0, 1, 1]), tensor([2, 5, 1, 4]))

In [20]:
# Column index where x == 3, else -1; then one past the last such column per row
# (a row with no 3 yields 0).
indices = torch.arange(x.size(1)).where(x == 3, torch.tensor(-1))
last_indices = indices.amax(dim=1) + 1
last_indices

tensor([6, 5])

In [21]:
# Rows of lengths 6, 4, 5: ones for the first `length` positions, zeros after.
m = (torch.arange(7) < torch.tensor([6, 4, 5]).unsqueeze(1)).long()
m.shape

torch.Size([3, 7])

In [22]:
cm = torch.ones(7, 7).triu(diagonal=1)  # 1. strictly above the diagonal (future positions)
cm

tensor([[0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0.]])

In [23]:
(m.bool()[:, None, None, :] & ~cm.bool()).long()  # 1 = key is real AND not in the future

tensor([[[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 1, 0],
          [1, 1, 1, 1, 1, 1, 0]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0]]],


        [[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 0, 0]]]])

In [24]:
torch.broadcast_shapes(m[:, None, None, :].shape, cm[None, None, :, :].shape)

torch.Size([3, 1, 7, 7])

In [25]:
torch.softmax(torch.rand(7, 7).masked_fill(cm.bool(), float('-inf')), dim=-1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4816, 0.5184, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4037, 0.2149, 0.3814, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2203, 0.1876, 0.3229, 0.2692, 0.0000, 0.0000, 0.0000],
        [0.2314, 0.1145, 0.2406, 0.2657, 0.1478, 0.0000, 0.0000],
        [0.2353, 0.1512, 0.1142, 0.2233, 0.1591, 0.1170, 0.0000],
        [0.1160, 0.1019, 0.1011, 0.1441, 0.1470, 0.2090, 0.1809]])

In [26]:
torch.rand(3, 1, 7, 7).masked_fill(~m.bool()[:, None, None, :] | cm.bool(), float('-inf'))

tensor([[[[0.0166,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
          [0.5711, 0.5664,   -inf,   -inf,   -inf,   -inf,   -inf],
          [0.3196, 0.1051, 0.1221,   -inf,   -inf,   -inf,   -inf],
          [0.4551, 0.7997, 0.9236, 0.9577,   -inf,   -inf,   -inf],
          [0.5837, 0.9141, 0.3334, 0.4196, 0.1097,   -inf,   -inf],
          [0.8766, 0.3978, 0.7507, 0.4006, 0.5501, 0.2624,   -inf],
          [0.8284, 0.3576, 0.8333, 0.3182, 0.0516, 0.1359,   -inf]]],


        [[[0.0381,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
          [0.2121, 0.0872,   -inf,   -inf,   -inf,   -inf,   -inf],
          [0.1453, 0.0879, 0.7282,   -inf,   -inf,   -inf,   -inf],
          [0.7898, 0.2131, 0.2304, 0.4859,   -inf,   -inf,   -inf],
          [0.2660, 0.8900, 0.2929, 0.3275,   -inf,   -inf,   -inf],
          [0.2877, 0.8541, 0.4545, 0.8825,   -inf,   -inf,   -inf],
          [0.0364, 0.0222, 0.0946, 0.9770,   -inf,   -inf,   -inf]]],


        [[[0.4313,   -inf,   -inf,   -in

In [27]:
# Block a score when the key is padding OR lies in the future (De Morgan of "real and past").
a = torch.rand(3, 1, 7, 7).masked_fill(~m.bool()[:, None, None, :] | cm.bool(), float('-inf'))

In [28]:
torch.softmax(a, dim=-1)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.6553, 0.3447, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3891, 0.2175, 0.3934, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2213, 0.2971, 0.2469, 0.2347, 0.0000, 0.0000, 0.0000],
          [0.2705, 0.1384, 0.1345, 0.3305, 0.1261, 0.0000, 0.0000],
          [0.2313, 0.1304, 0.1350, 0.1765, 0.1145, 0.2123, 0.0000],
          [0.1276, 0.1300, 0.1815, 0.2717, 0.1853, 0.1040, 0.0000]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.6273, 0.3727, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3385, 0.2492, 0.4124, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2822, 0.2526, 0.2771, 0.1881, 0.0000, 0.0000, 0.0000],
          [0.2707, 0.1890, 0.3360, 0.2043, 0.0000, 0.0000, 0.0000],
          [0.3183, 0.1588, 0.2345, 0.2884, 0.0000, 0.0000, 0.0000],
          [0.1873, 0.2952, 0.2842, 0.2333, 0.0000, 0.0000, 0.0000]]],


        [[[1.0000, 0.0000, 0.0000, 0.000

In [29]:
def create_attention_mask1(cm, pm):
    """Return a bool mask, True where attention must be blocked.

    A position is blocked when its key is padding (``pm`` is zero there) or when
    ``cm`` is nonzero. With a 2-D ``pm`` of shape ``(bsz, seq)`` and ``cm`` of shape
    ``(seq, seq)`` the result is ``(bsz, 1, seq, seq)``.
    """
    key_is_pad = ~pm.bool()[:, None, None, :]
    return key_is_pad | cm.bool()

In [30]:
create_attention_mask1(cm=cm, pm=m).to(torch.int64)

tensor([[[[0, 1, 1, 1, 1, 1, 1],
          [0, 0, 1, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 0, 1, 1],
          [0, 0, 0, 0, 0, 0, 1],
          [0, 0, 0, 0, 0, 0, 1]]],


        [[[0, 1, 1, 1, 1, 1, 1],
          [0, 0, 1, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1]]],


        [[[0, 1, 1, 1, 1, 1, 1],
          [0, 0, 1, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 0, 1, 1],
          [0, 0, 0, 0, 0, 1, 1],
          [0, 0, 0, 0, 0, 1, 1]]]])

In [31]:
cm.size()

torch.Size([7, 7])

In [32]:
def create_causal_mask(seq_len):
    """Return a ``(seq_len, seq_len)`` float mask with 1. strictly above the diagonal.

    Convention: 1. marks a key position the query may NOT attend to (a later token);
    0. marks an allowed position.
    """
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1)


def create_prefix_mask(seq_len, prefix_len):
    """Causal mask whose first ``prefix_len`` positions attend to each other freely.

    Uses the same 1. = blocked / 0. = allowed convention as ``create_causal_mask``.
    The top-left ``prefix_len x prefix_len`` block is fully unblocked (prefix-LM
    attention); a ``prefix_len`` >= ``seq_len`` unblocks everything.

    Raises:
        ValueError: if ``prefix_len`` is negative. A negative slice bound would
            otherwise silently unblock almost the whole matrix.
    """
    if prefix_len < 0:
        raise ValueError(f"prefix_len must be >= 0, got {prefix_len}")
    causal_mask = create_causal_mask(seq_len)
    causal_mask[:prefix_len, :prefix_len] = 0
    return causal_mask


def create_attention_mask(causal_mask, pad_mask):
    """Combine a structural (causal / prefix) mask with a per-sequence padding mask.

    Args:
        causal_mask: ``(seq, seq)`` mask; nonzero = blocked.
        pad_mask: ``(bsz, seq)`` mask; nonzero = attendable key, 0 = padding.

    Returns:
        Bool tensor of shape ``(bsz, 1, seq, seq)``. True means the position must be
        masked out, because the key is padding or ``causal_mask`` blocks it. This is
        the form expected by ``masked_fill(mask, -inf)``.
    """
    allowed = pad_mask[:, None, None, :].bool() & ~causal_mask.bool()
    return ~allowed

In [33]:
# Three sequences of real lengths 6, 4 and 5, right-padded with zeros to length 7.
seq_lengths = [6, 4, 5]
pad_mask = torch.tensor([[1] * n + [0] * (7 - n) for n in seq_lengths])
del seq_lengths

In [34]:
prefix_mask = create_prefix_mask(seq_len=7, prefix_len=3)
prefix_mask

tensor([[0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0.]])

In [35]:
causal_mask = create_causal_mask(seq_len=7)
causal_mask

tensor([[0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0.]])

In [36]:
attn_mask = create_attention_mask(causal_mask=causal_mask, pad_mask=pad_mask)
attn_mask.to(torch.int64)

tensor([[[[0, 1, 1, 1, 1, 1, 1],
          [0, 0, 1, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 0, 1, 1],
          [0, 0, 0, 0, 0, 0, 1],
          [0, 0, 0, 0, 0, 0, 1]]],


        [[[0, 1, 1, 1, 1, 1, 1],
          [0, 0, 1, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1]]],


        [[[0, 1, 1, 1, 1, 1, 1],
          [0, 0, 1, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 0, 1, 1],
          [0, 0, 0, 0, 0, 1, 1],
          [0, 0, 0, 0, 0, 1, 1]]]])

In [37]:
attn_mask.size()

torch.Size([3, 1, 7, 7])

In [38]:
# Prefix-LM mask: bidirectional over the prefix encoded in prefix_mask, causal after it,
# with padded keys always blocked. This is built from prefix_mask instead of zeroing
# attn_mask[..., :3, :3], which would also un-block padding inside the prefix.
prefix = create_attention_mask(prefix_mask, pad_mask)
prefix.long()

tensor([[[[0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 0, 1, 1],
          [0, 0, 0, 0, 0, 0, 1],
          [0, 0, 0, 0, 0, 0, 1]]],


        [[[0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1]]],


        [[[0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 1, 1, 1, 1],
          [0, 0, 0, 0, 1, 1, 1],
          [0, 0, 0, 0, 0, 1, 1],
          [0, 0, 0, 0, 0, 1, 1],
          [0, 0, 0, 0, 0, 1, 1]]]])

In [41]:
torch.softmax(torch.rand(3, 1, 7, 7).masked_fill(prefix, float('-inf')), dim=-1)

tensor([[[[0.2786, 0.2820, 0.4394, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4246, 0.2865, 0.2889, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4656, 0.2683, 0.2661, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3674, 0.1457, 0.3229, 0.1640, 0.0000, 0.0000, 0.0000],
          [0.2532, 0.1962, 0.1129, 0.1455, 0.2921, 0.0000, 0.0000],
          [0.1102, 0.1965, 0.2554, 0.1140, 0.1391, 0.1848, 0.0000],
          [0.1808, 0.1188, 0.1249, 0.1380, 0.2904, 0.1471, 0.0000]]],


        [[[0.2236, 0.5352, 0.2412, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.3213, 0.3471, 0.3316, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4099, 0.2979, 0.2922, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.1732, 0.2284, 0.3782, 0.2202, 0.0000, 0.0000, 0.0000],
          [0.1754, 0.2954, 0.2953, 0.2340, 0.0000, 0.0000, 0.0000],
          [0.2778, 0.3132, 0.1797, 0.2293, 0.0000, 0.0000, 0.0000],
          [0.2764, 0.3114, 0.1995, 0.2127, 0.0000, 0.0000, 0.0000]]],


        [[[0.3089, 0.3921, 0.2991, 0.000