In [49]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from vision_model import VisionModel, VisionProjector
from text_model import LLaMA
from types import SimpleNamespace
from transformers import SiglipVisionModel, SiglipVisionConfig
import gc

In [50]:
class Blinky(nn.Module):
    """Vision-language model: SigLIP patch features -> pixel shuffle -> linear
    projector -> prepended to the token embeddings of a LLaMA-style decoder.

    `config` is a namespace carrying the text-model hyperparameters plus
    `dtype`, `num_image_tokens`, `vision_model_hf` and `text_model_hf`.
    Note: `__init__` mutates `config` by attaching a `vision_config` namespace.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        # Only the vision *architecture* config is fetched here; pretrained
        # weights are copied in later by prepare_for_training().
        vision_config = SiglipVisionConfig.from_pretrained(self.config.vision_model_hf)
        self.config.vision_config = SimpleNamespace(**vision_config.to_dict())

        self.vision = SiglipVisionModel(vision_config).to(dtype=self.config.dtype)
        # pixel_shuffle (default scale 2) folds each 2x2 patch neighbourhood
        # into one token, so the projector input is 4x the vision hidden size.
        self.vision_proj = nn.Linear(
            self.config.vision_config.hidden_size * 4,
            self.config.embed_dim,
            bias=False,
            dtype=self.config.dtype,
        )
        self.text_model = LLaMA(self.config)

    def pixel_shuffle(self, x, scale_factor=2):
        """Merge every `scale_factor x scale_factor` block of patch tokens into one token.

        Args:
            x: tensor of shape (bsz, seq, embed_dim) whose `seq` tokens form a
               square grid in row-major order.
            scale_factor: side length of the neighbourhood merged into one token.

        Returns:
            Tensor of shape (bsz, seq // scale_factor**2, embed_dim * scale_factor**2).

        Raises:
            ValueError: if `seq` is not a perfect square, or the grid side is
                not divisible by `scale_factor`.
        """
        bsz, seq, embed_dim = x.size()
        side = int(round(seq ** 0.5))
        if side * side != seq:
            raise ValueError(f'pixel_shuffle expects a square token grid, got seq={seq}')
        if side % scale_factor != 0:
            raise ValueError(f'grid side {side} is not divisible by scale_factor={scale_factor}')
        height = width = side
        merged_dim = embed_dim * scale_factor ** 2

        x = x.reshape(bsz, height, width, embed_dim)
        # Merge horizontally adjacent patches along the channel axis ...
        x = x.reshape(bsz, height, width // scale_factor, embed_dim * scale_factor)
        x = x.permute(0, 2, 1, 3)
        # ... then vertically adjacent rows, and restore (row, col) order.
        x = x.reshape(bsz, width // scale_factor, height // scale_factor, merged_dim)
        x = x.permute(0, 2, 1, 3)
        return x.reshape(bsz, seq // scale_factor ** 2, merged_dim)

    def prepare_for_training(self):
        """Copy pretrained SigLIP and causal-LM weights into `self.vision` / `self.text_model`.

        Raises:
            AssertionError: if a spot-checked tensor does not match after loading.
            RuntimeError: from `load_state_dict` if the checkpoint keys do not
                line up with this model's parameters.
        """
        from transformers import AutoModelForCausalLM

        # Use self.config (not a notebook-global `model`) so this works on any instance.
        vision = SiglipVisionModel.from_pretrained(self.config.vision_model_hf, torch_dtype=self.config.dtype)
        self.vision.load_state_dict(vision.state_dict())

        assert torch.allclose(
            vision.vision_model.embeddings.position_embedding.weight,
            self.vision.vision_model.embeddings.position_embedding.weight
        ), 'couldnt load vision model'

        smol = AutoModelForCausalLM.from_pretrained(self.config.text_model_hf, torch_dtype=self.config.dtype)
        model_sd = self.text_model.state_dict()
        # Skip non-weight entries (rotary caches / masks) the HF model may expose.
        smol_sd = {
            k: v for k, v in smol.state_dict().items()
            if not any(s in k for s in ('rope', 'causal_mask'))
        }

        # HF nests the decoder under `model.`; this LLaMA keeps it at top level.
        for smol_key, smol_value in smol_sd.items():
            model_sd[smol_key.removeprefix('model.')] = smol_value.clone()

        self.text_model.load_state_dict(model_sd)

        assert torch.allclose(smol.lm_head.weight, self.text_model.lm_head.weight), 'couldnt load text model'

        del smol, vision
        gc.collect()

    def forward_image_features(self, pixel_values):
        """Encode images and map them to text-embedding width.

        Runs the vision tower, pixel-shuffles its last hidden state and applies
        `vision_proj`; returns (bsz, num_merged_tokens, embed_dim).
        """
        x = self.vision(pixel_values).last_hidden_state
        x = self.pixel_shuffle(x)
        x = self.vision_proj(x)
        return x

    def _vision_trainable(self, trainable=False):
        """Set `requires_grad` on every vision-tower parameter."""
        for p in self.vision.parameters():
            p.requires_grad = trainable

    def _text_trainable(self, trainable=False):
        """Set `requires_grad` on text-model parameters.

        `embed_tokens` and `lm_head` are always frozen, regardless of `trainable`.
        """
        for n, p in self.text_model.named_parameters():
            if 'embed_tokens' in n or 'lm_head' in n:
                p.requires_grad = False
            else:
                p.requires_grad = trainable

    def forward(self, input_ids, pixel_values=None, attention_mask=None, labels=None):
        """Run the decoder over [image tokens | text tokens].

        Args:
            input_ids: token ids, shape (bsz, text_len).
            pixel_values: optional images for the vision tower; when given,
                their projected tokens are prepended to the text embeddings.
            attention_mask: optional text mask, shape (bsz, text_len). When
                images are given it is extended with ones for the image
                positions (a mask of ones is created if none was passed).
            labels: optional targets; NOTE(review): compared position-by-position
                with the logits (image positions included, no shift here), so the
                caller presumably pre-aligns them — confirm.

        Returns:
            Scalar cross-entropy loss if `labels` is given, else logits of shape
            (bsz, seq, vocab_size).
        """
        x = self.text_model.embed_tokens(input_ids)

        if pixel_values is not None:
            image_tokens = self.forward_image_features(pixel_values)
            # Text embeddings are detached: no gradient reaches embed_tokens here.
            x = torch.cat([image_tokens, x.detach()], dim=1)
            if attention_mask is None:
                attention_mask = torch.ones(input_ids.shape, dtype=torch.bool, device=input_ids.device)
            # Image tokens are always visible. Size from the actual image-token
            # count and match the text mask's dtype/device to avoid mixed-dtype cat.
            image_mask = attention_mask.new_ones((attention_mask.shape[0], image_tokens.shape[1]))
            attention_mask = torch.cat([image_mask, attention_mask], dim=1)

        for layer in self.text_model.layers:
            x = layer(x, attention_mask)

        x = self.text_model.norm(x)
        logits = self.text_model.lm_head(x)

        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
            return loss

        return logits

In [51]:
# Model configuration.
# The first group must match the text checkpoint named in `text_model_hf`
# (the weights are copied into LLaMA(config) by prepare_for_training).
# num_image_tokens presumably follows from the vision checkpoint name
# (512px / patch16 -> 32x32 = 1024 patches, /4 after 2x2 pixel shuffle = 256).
config = SimpleNamespace(
    # --- text decoder ---
    embed_dim=576,
    intermediate_dim=1536,
    max_position_embeddings=8192,
    base_theta=100000,
    num_q_heads=9,
    num_kv_heads=3,
    attn_dropout=0.0,
    num_layers=30,
    vocab_size=49152,
    eos_token_id=2,
    # --- multimodal glue ---
    dtype=torch.bfloat16,
    num_image_tokens=256,
    vision_model_hf='google/siglip2-base-patch16-512',
    text_model_hf='HuggingFaceTB/SmolLM2-135M-Instruct',
)

In [52]:
# Build the architecture, copy the pretrained vision/text weights into it,
# then move it to the GPU. nn.Module.cuda() converts parameters in place
# and returns the same module, so no reassignment is needed; the trailing
# `;` keeps Jupyter from printing the module repr.
model = Blinky(config)
model.prepare_for_training()
model.cuda();

In [66]:
from processor import BlinkyProcessor
from PIL import Image
from tqdm import tqdm

In [78]:
# A one-element batch: a single-turn chat message paired with a test image.
sample = [
    {
        'text': [{'role': 'user', 'content': 'hey!'}],
        'image': Image.open('./tests/car.jpg'),
    },
]
processor = BlinkyProcessor('./Blinky/')

In [79]:
# Tokenize/preprocess the one-sample batch.
inputs = processor(sample)
# Drop the image features: the generation loop below calls the model without
# pixel_values, so this is presumably a text-only sanity check — confirm intent.
inputs['pixel_values'] = None

In [80]:
# (batch, prompt_length) of the tokenized prompt
inputs['input_ids'].size()

torch.Size([1, 33])

In [1]:
# Autoregressive decoding: greedy when `deterministic`, otherwise sampled.
# No KV cache — every step re-runs the model over the whole sequence.
# Supports batch size 1 only (the `.item()` EOS check and final decode rely on it).
max_tokens = 200
deterministic = True
context = inputs['input_ids'].cuda()
attention_mask = inputs['attention_mask'].cuda()
assert context.shape[0] == 1, 'this decoding loop only supports batch size 1'
eos_token_id = processor.tokenizer.eos_token_id

sequence = context
with torch.inference_mode():
    for _ in range(max_tokens):
        next_logits = model(input_ids=sequence, attention_mask=attention_mask)[:, -1, :]
        if deterministic:
            # Argmax of the logits equals argmax of their softmax (without bf16 rounding ties).
            next_token = next_logits.argmax(dim=-1, keepdim=True)
        else:
            # Logits are presumably bfloat16 (config.dtype); sample from float32 probabilities.
            probs = F.softmax(next_logits.float(), dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        sequence = torch.cat([sequence, next_token], dim=1)
        # Extend the mask in its own dtype/device rather than mixing in a bool tensor.
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=1
        )
        if next_token.item() == eos_token_id:
            break

generated_ids = sequence[0, context.shape[1]:].tolist()
if generated_ids and generated_ids[-1] == eos_token_id:
    generated_ids = generated_ids[:-1]  # don't print the EOS marker itself
# Decode the whole continuation at once: decoding token-by-token can split
# multi-byte characters across byte-level BPE tokens and mangle spacing.
print(processor.tokenizer.decode(generated_ids))

NameError: name 'inputs' is not defined