In [10]:
import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append('../')
from processor import BlinkyProcessor
from PIL import Image

In [4]:
processor = BlinkyProcessor('../Blinky')

In [5]:
processor.tokenizer

GPT2TokenizerFast(name_or_path='../Blinky', vocab_size=49152, model_max_length=8192, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|im_start|>', 'eos_token': '<|im_end|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|im_end|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<repo_name>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("<reponame>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	5: AddedToken("<file_sep>", rstrip=False, lstrip=Fal

In [11]:
sample = [
    {'text': [{'role':'user','content': 'hello, how are you?'}, {'role': 'assistant', 'content': 'hi i am blinky a VLM'}], 'image': Image.open('cat.jpg')}
]

In [12]:
x = processor(sample)

In [38]:
for idx, token in enumerate(x['input_ids'].flatten().numpy()):
    print(idx, token, processor.tokenizer.decode([token],skip_special_tokens=False))

0 1 <|im_start|>
1 9690 system
2 198 

3 2683 You
4 359  are
5 253  a
6 5356  helpful
7 5646  AI
8 11173  assistant
9 3365  named
10 2114  Bl
11 900 ink
12 105 y
13 351  with
14 37063  multim
15 32058 odal
16 7596  capabilities
17 28 ,
18 7018  trained
19 411  by
20 443  sh
21 257 re
22 4198 yd
23 276 an
24 2 <|im_end|>
25 198 

26 1 <|im_start|>
27 4093 user
28 198 

29 28120 hello
30 28 ,
31 638  how
32 359  are
33 346  you
34 47 ?
35 2 <|im_end|>
36 198 

37 1 <|im_start|>
38 520 ass
39 9531 istant
40 198 

41 6004 hi
42 2056  i
43 744  am
44 39889  blink
45 105 y
46 253  a
47 717  V
48 34519 LM
49 2 <|im_end|>
50 198 



In [18]:
processor.tokenizer.decode([198])

'\n'

In [19]:
input_ids = x['input_ids']

In [20]:
input_ids

tensor([[    1,  9690,   198,  2683,   359,   253,  5356,  5646, 11173,  3365,
          2114,   900,   105,   351, 37063, 32058,  7596,    28,  7018,   411,
           443,   257,  4198,   276,     2,   198,     1,  4093,   198, 28120,
            28,   638,   359,   346,    47,     2,   198,     1,   520,  9531,
           198,  6004,  2056,   744, 39889,   105,   253,   717, 34519,     2,
           198]])

In [25]:
im_start_positions = torch.where(input_ids[0]==1)[0]
im_start_positions

tensor([ 0, 26, 37])

In [28]:
assistant_tokens = processor.tokenizer('<|im_start|>assistant\n')
assistant_tokens

{'input_ids': [1, 520, 9531, 198], 'attention_mask': [1, 1, 1, 1]}

In [37]:
for pos in im_start_positions:
    matched = False
    for i,token in enumerate(assistant_tokens['input_ids']):
        curr_pos = pos+i
        if input_ids[0][curr_pos] != token:
            break
    else:
        matched = True        
    print(pos, matched)

tensor(0) False
tensor(26) False
tensor(37) True


In [58]:
def collate_fn_for_completion_only(samples, assistant_template='<|im_start|>assistant\n',chat_start_token_id=1):
    inputs = processor(samples)
    labels = inputs['input_ids'].clone()
    labels[:, :-1] = labels[:, 1:].clone()
    padding_mask = (inputs['attention_mask'] == 0)
    labels[padding_mask] = -100
    inputs['labels'] = labels

    assistant_positions = []
    assistant_tokens = processor.tokenizer(assistant_template)['input_ids']
    
    for batch_idx in range(inputs['input_ids'].shape[0]):
        im_start_positions = torch.where(inputs['input_ids'][batch_idx]==chat_start_token_id)[0]
        for pos in im_start_positions:
            matched = False
            for i,token in enumerate(assistant_tokens):
                curr_pos = pos+i
                if inputs['input_ids'][batch_idx][curr_pos] != token:
                    break
            else:
                matched = True
        if matched:
            assistant_positions.append(pos)

    assert len(assistant_positions) == inputs['input_ids'].shape[0], "a sample in this batch doesn't contain the assistant_template"

    for batch_idx in range(inputs['input_ids'].shape[0]):
        inputs['labels'][batch_idx, :assistant_positions[batch_idx]-1] = -100
    
    return inputs

In [59]:
sample = [
    {'text': [{'role':'user','content': 'hello, how are you?'}, {'role': 'assistant', 'content': 'hi i am blinky a VLM'}], 'image': Image.open('cat.jpg')},
    {'text': [{'role':'user','content': 'this is a test!'}, {'role': 'assistant', 'content': 'heyyy :)'}], 'image': Image.open('cat.jpg')}
]

In [60]:
y = collate_fn_for_completion_only(sample)

In [62]:
y['input_ids'][0]

tensor([    1,  9690,   198,  2683,   359,   253,  5356,  5646, 11173,  3365,
         2114,   900,   105,   351, 37063, 32058,  7596,    28,  7018,   411,
          443,   257,  4198,   276,     2,   198,     1,  4093,   198, 28120,
           28,   638,   359,   346,    47,     2,   198,     1,   520,  9531,
          198,  6004,  2056,   744, 39889,   105,   253,   717, 34519,     2,
          198])

In [63]:
y['labels'][0]

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,     1,   520,  9531,   198,
         6004,  2056,   744, 39889,   105,   253,   717, 34519,     2,   198,
          198])