In [1]:
!pip install -q trl

In [4]:
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM
from typing import List

import torch

In [5]:
class CustomDataCollatorForCompletionOnlyLM(DataCollatorForCompletionOnlyLM):
    """Completion-only collator that also randomly drops selected label tokens from the loss.

    The parent collator sets every label before ``response_template`` to ``ignore_index``.
    This subclass additionally replaces each label whose id is in ``ignore_token_ids`` with
    ``ignore_index``, independently, with probability ``ignore_tokens_mask_prob``. This is useful
    for classification tasks, or tasks where the LM predicts a fixed/small number of tokens after
    the response template and some of those tokens should contribute less to the loss.
    """

    def __init__(
        self,
        response_template: str,
        tokenizer: AutoTokenizer,
        ignore_token_ids: List[int],
        ignore_tokens_mask_prob: float = 0.8,
        **kwargs,
    ):
        """
        Args:
            response_template (str): A string that indicates the start of an AI generated response.
            tokenizer (AutoTokenizer): The tokenizer used to tokenize the input text.
            ignore_token_ids (List[int]): Token ids whose labels may be dropped from the loss.
            ignore_tokens_mask_prob (float, optional): Probability, in [0, 1], that each occurrence
                of an ignore token in the labels is masked (its label set to ``ignore_index``).
                Defaults to 0.8.
            **kwargs: Forwarded unchanged to ``DataCollatorForCompletionOnlyLM`` (for example
                ``instruction_template`` or ``ignore_index``).

        Raises:
            ValueError: If ``ignore_tokens_mask_prob`` is not within [0, 1].
        """
        # Reject out-of-range probabilities up front (NaN also fails this check).
        if not 0.0 <= ignore_tokens_mask_prob <= 1.0:
            raise ValueError(
                f"ignore_tokens_mask_prob must be in [0, 1], got {ignore_tokens_mask_prob}"
            )
        super().__init__(response_template=response_template, tokenizer=tokenizer, **kwargs)
        # `dtype` is keyword-only in torch.tensor; passing torch.long positionally raises TypeError.
        self.ignore_token_ids = torch.tensor(ignore_token_ids, dtype=torch.long)
        self.ignore_tokens_mask_prob = ignore_tokens_mask_prob

    def torch_call(self, examples: List[List[int]]):
        """Collate ``examples`` with the parent collator, then randomly mask ignore-token labels.

        Args:
            examples: Examples in any form accepted by
                ``DataCollatorForCompletionOnlyLM.torch_call``.

        Returns:
            The batch produced by the parent collator; ``batch["labels"]`` is modified in place.
        """
        batch = super().torch_call(examples)
        labels = batch["labels"]
        # Labels are an integer tensor; build the probabilities explicitly as float, because
        # torch.bernoulli needs floating-point probabilities and full_like would otherwise
        # inherit the integer dtype of `labels`.
        probs = torch.full_like(labels, self.ignore_tokens_mask_prob, dtype=torch.float)
        mask = torch.bernoulli(probs).bool()
        # Match the device of `labels` so torch.isin works if labels are not on the CPU.
        ignore_token_positions = torch.isin(labels, self.ignore_token_ids.to(labels.device))
        # Use the parent's ignore_index (the value the loss skips) rather than a hard-coded -100.
        labels[mask & ignore_token_positions] = self.ignore_index
        return batch
        

In [None]:
# NOTE(review): the original id "meta/Llama-3-" is truncated and is not a valid Hub repo id.
# Llama 3 checkpoints are published under the "meta-llama" org and are gated (accept the license,
# then log in or set HF_TOKEN). TODO: confirm which checkpoint is intended.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)