In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from transformers import AutoTokenizer
from PIL import Image

In [4]:
# Scratch check: list the notebook's working directory (relative paths used
# below, such as '../Blinky', are resolved from here).
! ls

cat.jpg		    full_model.ipynb  peft.ipynb	       __pycache__
dataset.ipynb	    kv-caching.ipynb  prefix-lm-masking.ipynb  tests.ipynb
embeddings_test.py  llama.py	      preprocessor.ipynb


In [5]:
# Load the tokenizer from the local '../Blinky' directory.
tokenizer = AutoTokenizer.from_pretrained('../Blinky')

In [6]:
# Inspect the stored Jinja chat template; it injects a default system prompt
# when the first message is not a system message.
tokenizer.chat_template

"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful AI assistant named Blinky with multimodal capabilities, trained by shreydan<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

In [7]:
# Prefix the chat template with the image placeholder block. The guard keeps
# this cell idempotent: re-running it after the next cell has installed the
# template must not prepend the placeholders a second time.
IMAGE_PLACEHOLDER = "<|start_of_image|><|image_token|><|end_of_image|>"
new_chat_template = (
    tokenizer.chat_template
    if tokenizer.chat_template.startswith(IMAGE_PLACEHOLDER)
    else IMAGE_PLACEHOLDER + tokenizer.chat_template
)

In [8]:
# Install the image-aware template, swapping the stock SmolLM identity line for
# a multimodal one in a single assignment.
# NOTE(review): the '../Blinky' template printed above already names the
# assistant "Blinky", so this replace finds nothing and is a no-op there.
tokenizer.chat_template = new_chat_template.replace(
    "You are a helpful AI assistant named SmolLM, trained by Hugging Face",
    "You are a helpful AI assistant named SimpleVLM with multimodal capabilities, trained by shreydan",
)

In [9]:
# Render a single-turn conversation to text (no tokenization) to eyeball the
# modified template.
sample = [{'role':'user','content':'what is the color of the cup?'}]
inputs = tokenizer.apply_chat_template(sample, tokenize=False)

In [10]:
# print() shows the rendered newlines instead of the escaped string repr.
print(inputs)

<|start_of_image|><|image_token|><|end_of_image|><|im_start|>system
You are a helpful AI assistant named Blinky with multimodal capabilities, trained by shreydan<|im_end|>
<|im_start|>user
what is the color of the cup?<|im_end|>



In [11]:
class Processor:
    """Turn chat conversations plus images into model-ready tensors.

    Each conversation is rendered with the tokenizer's chat template and
    prefixed with an image block: a start marker, ``num_image_tokens`` image
    placeholders and an end marker. The texts are then tokenized and
    right-padded into one ``input_ids`` batch. Images are resized to
    ``image_size`` x ``image_size`` and normalized to [-1, 1].
    """

    IMAGE_START = "<|start_of_image|>"
    IMAGE_TOKEN = "<|image_token|>"
    IMAGE_END = "<|end_of_image|>"

    def __init__(self, tokenizer_path, num_image_tokens=256, image_size=512, max_length=1024):
        """
        Args:
            tokenizer_path: path or hub id passed to ``AutoTokenizer.from_pretrained``.
            num_image_tokens: number of ``<|image_token|>`` placeholders emitted
                per image.
            image_size: side length (pixels) images are resized to.
            max_length: truncation length, in tokens, for each tokenized text.
        """
        self.tokenizer_path = tokenizer_path
        self.image_size = image_size
        self.num_image_tokens = num_image_tokens
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)

        # mean = std = 0.5 maps ToTensor's [0, 1] range onto [-1, 1].
        self.img_mean = [0.5, 0.5, 0.5]
        self.img_std = [0.5, 0.5, 0.5]
        self.img_transforms = T.Compose([
            T.Resize((self.image_size, self.image_size)),
            T.ToTensor(),
            T.Normalize(mean=self.img_mean, std=self.img_std),
        ])

    def preprocess_image(self, image):
        """Convert a PIL image to RGB and return a normalized (3, H, W) float tensor."""
        return self.img_transforms(image.convert('RGB'))

    def _add_img_tokens(self, text):
        """Prefix ``text`` with the image block followed by a newline."""
        image_block = self.IMAGE_START + self.IMAGE_TOKEN * self.num_image_tokens + self.IMAGE_END
        return f"{image_block}\n{text}"

    def apply_chat_template(self, samples, use_system_prompt=True):
        """Render conversations to strings and prefix each with the image block.

        Args:
            samples: a list of conversations, each a list of
                ``{'role': ..., 'content': ...}`` dicts. A single conversation
                (a bare list of message dicts) is also accepted.
            use_system_prompt: kept for API compatibility; currently ignored.
                The system prompt is whatever the tokenizer's chat template
                injects.

        Returns:
            list[str]: one rendered text per conversation.
        """
        chat_texts = self.tokenizer.apply_chat_template(samples, tokenize=False)
        # For a single conversation the tokenizer returns one string. Iterating
        # over it would wrap every character in its own image block.
        if isinstance(chat_texts, str):
            chat_texts = [chat_texts]
        return [self._add_img_tokens(chat_text) for chat_text in chat_texts]

    def tokenize_and_pad(self, texts):
        """Tokenize each text (truncated to ``max_length``) and right-pad into a batch.

        Returns:
            Integer tensor of shape (len(texts), longest_length), right-padded
            with the tokenizer's ``pad_token_id``.

        Raises:
            ValueError: if ``texts`` is empty or the tokenizer has no pad token.
        """
        if not texts:
            raise ValueError("tokenize_and_pad needs at least one text")
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            # F.pad would otherwise silently pad with 0, a real token id.
            raise ValueError("tokenizer has no pad_token_id; cannot pad the batch")

        tokenized = [
            self.tokenizer.encode(
                text, return_tensors='pt', truncation=True, max_length=self.max_length
            ).squeeze(0)
            for text in texts
        ]
        longest = max(ids.shape[0] for ids in tokenized)
        # Right padding: real tokens first, pad ids fill the tail.
        padded = [F.pad(ids, [0, longest - ids.shape[0]], value=pad_id) for ids in tokenized]
        return torch.stack(padded)

    def __call__(self, samples):
        """Process a batch of ``{'text': conversation, 'image': PIL image}`` dicts.

        Returns:
            dict with ``input_ids`` of shape (batch, seq_len) and
            ``pixel_values`` of shape (batch, 3, image_size, image_size).
        """
        texts = self.apply_chat_template([s['text'] for s in samples])
        input_ids = self.tokenize_and_pad(texts)
        pixel_values = torch.stack([self.preprocess_image(s['image']) for s in samples])
        return {
            'input_ids': input_ids,
            'pixel_values': pixel_values,
        }

In [12]:
processor= Processor('../Blinky')

In [13]:
sample = [
    [{'role':'user','content':'what is the color of the cup?'}],
    [{'role':'user','content':'what is the meaning of life? is it even real?'}]
]

In [14]:
from PIL import Image
import numpy as np

In [15]:
im = Image.fromarray(np.random.rand(28,28))

In [16]:
samples = [
    {'text':[{'role':'user','content':'what is the color of the cup?'}], 'image': im},
    {'text': [{'role':'user','content':'what is the meaning of life? is it even real?'}], 'image': im}
]

In [17]:
processor.tokenizer.chat_template

"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful AI assistant named Blinky with multimodal capabilities, trained by shreydan<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

In [18]:
x=processor.apply_chat_template([s['text'] for s in samples])

In [19]:
import re

In [20]:
counts=[a for a in re.findall(r'<|image_token|>',x[0]) if a=='image_token']

In [21]:
len(counts)

256

In [22]:
# Run the full processor on the two dummy samples.
processed = processor(samples)

In [23]:
# Batch shapes: input_ids is (batch, padded_seq_len); pixel_values is
# (batch, 3, 512, 512) for the default image size.
{k:v.shape for k,v in processed.items()}

{'input_ids': torch.Size([2, 302]),
 'pixel_values': torch.Size([2, 3, 512, 512])}

In [24]:
# First row of input_ids; it starts with the image block that
# Processor._add_img_tokens prepends to every text.
processed['input_ids'][0]

tensor([49152, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 

In [25]:
labels = processed['input_ids'].clone()
labels[:,:-1] = labels[:,1:]

In [26]:
processed['input_ids'][0]

tensor([49152, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 

In [27]:
# Shifted labels for the first row: each entry is the input token one
# position to its right.
labels[0]

tensor([49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154,
        49154, 49154, 49154, 49154, 49154, 49154, 49154, 49154, 

In [28]:
masked_labels = labels.clone()
for ignore_token in [49154,49152,49153]:
    mask = (labels==ignore_token).long()
    masked_labels[mask==1] = -100
padding_mask = (processed['input_ids']==processor.tokenizer.pad_token_id).long()
masked_labels[padding_mask==1] = -100

In [29]:
masked_labels[0]

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 