In [1]:
from verl.utils import hf_tokenizer, hf_processor
import torch

In [2]:
text_template = "The quick brown fox jumps over the lazy dog.<image1asdasdqwa>, <image2>, <image3>."
import re
image_keys=re.findall(r'<image[a-zA-Z0-9]*>', text_template)
print(image_keys)

['<image1asdasdqwa>', '<image2>', '<image3>']


In [3]:
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
processor = hf_processor(model_name)
tokenizer = hf_tokenizer(model_name)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
tokenizer.pad_token

'<|endoftext|>'

In [9]:
tokenizer.decode(tokenizer.pad_token_id)

'<|endoftext|>'

In [5]:
def loss_mask(input_ids, attention_mask):
    sptk_b = tokenizer.convert_tokens_to_ids('<|box_start|>')
    sptk_e = tokenizer.convert_tokens_to_ids('<|box_end|>')
    pad_token_id = tokenizer.pad_token_id

    print(f"DEBUG:input_ids.shape:{input_ids.shape}")
    batch_size = input_ids.shape[0]
    seq_len = input_ids.shape[1]
    
    # Initialize output tensors with same shape as inputs
    new_input_ids = input_ids.clone()
    new_attention_mask = attention_mask.clone()
    loss_mask = torch.zeros_like(input_ids)
    new_loss_mask = torch.zeros_like(input_ids)
    # Process each example in the batch
    for b in range(batch_size):
        # Count right padding tokens using attention mask
        right_pad_tokens = (new_input_ids[b] == pad_token_id).sum().item()
        
        # Assert that initial padding tokens have attention mask of 0
        assert torch.all(attention_mask[b, -right_pad_tokens:] == 0), "right padding tokens must have attention mask of 0"
        
        # Find special token indices
        sptk_b_indices = (input_ids[b] == sptk_b).nonzero().flatten()
        sptk_e_indices = (input_ids[b] == sptk_e).nonzero().flatten()
        
        # Create a mask for tokens that should compute loss
        hole_pos=[] # initialize holes position list with last padding token position
        for start_pos, end_pos in zip(sptk_b_indices, sptk_e_indices):
            loss_mask[b][start_pos+1:end_pos] = 1
            hole_pos.append(start_pos.item())
            hole_pos.append(end_pos.item())
        hole_pos.append(seq_len-right_pad_tokens)
        assert new_input_ids[b][seq_len-right_pad_tokens]==pad_token_id
        
        # shift right to fill the wholes
        holes_to_fill=1
        for i in range(0,len(hole_pos)-1):
            start_pos = hole_pos[i]
            end_pos = hole_pos[i+1]
            new_loss_mask[b,start_pos+1-holes_to_fill:end_pos-holes_to_fill]=loss_mask[b,start_pos+1:end_pos]
            new_input_ids[b,start_pos+1-holes_to_fill:end_pos-holes_to_fill]=input_ids[b,start_pos+1:end_pos]
            new_attention_mask[b,start_pos+1-holes_to_fill:end_pos-holes_to_fill]=attention_mask[b,start_pos+1:end_pos]
            holes_to_fill+=1

        valid_tokens = seq_len-right_pad_tokens-len(hole_pos)+1 # the number of non-special tokens and non-padding tokens
        new_loss_mask[b][valid_tokens:]=0
        new_input_ids[b][valid_tokens:]=pad_token_id
        new_attention_mask[b][valid_tokens:]=0
        
    return new_input_ids, new_attention_mask, new_loss_mask

In [23]:
import verl.utils.torch_functional as verl_F
import torch
prompt_with_chat_template=''                    
input_ids,attention_mask=verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
                                        tokenizer=tokenizer,
                                        max_length=1024,
                                        pad_token_id=tokenizer.pad_token_id,
                                        left_pad=False,
                                        truncation="error",
                                        )


In [24]:
input_ids.shape

torch.Size([1, 1024])

In [25]:
print(tokenizer.decode(input_ids[0]))
print(attention_mask[:,:10])  

TypeError: argument 'ids': 'float' object cannot be interpreted as an integer

In [None]:
print(attention_mask[:10])  

In [19]:
new_input_ids, new_attention_mask, new_loss_mask=loss_mask(input_ids, attention_mask)
print(tokenizer.decode(new_input_ids[0][:30]))
print(new_attention_mask[0][:30])
print(new_loss_mask[0][:30])

DEBUG:input_ids.shape:torch.Size([1, 1024])
I love you<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
tensor([1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0])
tensor([1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0])
