In [1]:
from torch.utils.data import Dataset
import torch

class ChatDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.query_token_id = tokenizer.convert_tokens_to_ids("<query>")
        self.answer_token_id = tokenizer.convert_tokens_to_ids("<answer>")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Encode encoder input (persona/context)
        encoder = self.tokenizer(
            item["persona"],
            return_tensors='pt',
            padding=False,
            truncation=True
        )

        # Encode decoder input (dialogue)
        decoder = self.tokenizer(
            item["dialogue"],
            return_tensors='pt',
            padding=False,
            truncation=True
        )

        decoder_input_ids = decoder['input_ids'].squeeze(0)
        attention_mask = encoder['attention_mask'].squeeze(0)
        input_ids = encoder['input_ids'].squeeze(0)

        # Generate labels by masking <query> segments
        labels = decoder_input_ids.clone()
        mask = torch.ones_like(labels) * -100
        j = 0

        while j < len(labels):
            if labels[j] == self.answer_token_id:
                a_start = j
                next_query = (labels[j+1:] == self.query_token_id).nonzero(as_tuple=True)
                next_eos = (labels[j+1:] == self.eos_token_id).nonzero(as_tuple=True)

                if next_query[0].numel() > 0:
                    a_end = j + 1 + next_query[0][0].item()
                elif next_eos[0].numel() > 0:
                    a_end = j + 1 + next_eos[0][0].item()
                else:
                    a_end = len(labels)

                mask[a_start:a_end] = labels[a_start:a_end]
                j = a_end
            else:
                j += 1

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": mask
        }


In [2]:
from transformers import BartTokenizer

# Initialize tokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
tokenizer.add_special_tokens({'additional_special_tokens': ['<persona>', '<query>', '<answer>', '<eos>']})

# Sample data
data = [
    {
        "persona": "I love pizza. I'm a software engineer. I enjoy late-night walks.",
        "dialogue": (
            "<query> What's your favorite food? <answer> I love pizza! "
            "<query> What do you do? <answer> I'm a software engineer. "
            "<query> Do you like going out? <answer> Yes, I enjoy late-night walks."
        )
    },
    {
        "persona": "I live in Paris. I'm a teacher. I like classical music.",
        "dialogue": (
            "<query> Where do you live? <answer> I live in Paris. "
            "<query> What's your job? <answer> I'm a teacher. "
            "<query> What kind of music do you like? <answer> I like classical music."
        )
    }
]


# Create dataset
dataset = ChatDataset(data, tokenizer)

# Check first sample
sample = dataset[0]
print("\n[ENCODER INPUT IDs]:", sample["input_ids"])
print("\n[DECODER INPUT IDs]:", sample["decoder_input_ids"])
print("\n[LABELS]:", sample["labels"])
print("\n[Decoded Labels]:", tokenizer.decode(sample["labels"][sample["labels"] != -100]))


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.



[ENCODER INPUT IDs]: tensor([   0,  100,  657, 9366,    4,   38,  437,   10, 2257, 8083,    4,   38,
        2254,  628,   12, 8498, 5792,    4,    2])

[DECODER INPUT IDs]: tensor([    0, 50266,   653,    18,   110,  2674,   689,   116,  1437, 50267,
           38,   657,  9366,   328,  1437, 50266,   653,   109,    47,   109,
          116,  1437, 50267,    38,   437,    10,  2257,  8083,     4,  1437,
        50266,  1832,    47,   101,   164,    66,   116,  1437, 50267,  3216,
            6,    38,  2254,   628,    12,  8498,  5792,     4,     2])

[LABELS]: tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 50267,
           38,   657,  9366,   328,  1437,  -100,  -100,  -100,  -100,  -100,
         -100,  -100, 50267,    38,   437,    10,  2257,  8083,     4,  1437,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 50267,  3216,
            6,    38,  2254,   628,    12,  8498,  5792,     4,  -100])

[Decoded Labels]: <answer>  I love pizza!  <a

In [None]:
from torch.nn.utils.rnn import pad_sequence
import torch

class ChatCollatorWithMasking:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, batch):
        input_ids = [item["input_ids"] for item in batch]
        attention_mask = [item["attention_mask"] for item in batch]
        decoder_input_ids = [item["decoder_input_ids"] for item in batch]
        labels = [item["labels"] for item in batch]

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        decoder_input_ids = pad_sequence(decoder_input_ids, batch_first=True, padding_value=self.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels
        }


In [5]:
from torch.utils.data import DataLoader

collator = ChatCollatorWithMasking(tokenizer)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collator)

# Iterate over a batch
for batch in dataloader:
    print("Input IDs shape:", batch["input_ids"].shape)
    print("Decoder Input IDs shape:", batch["decoder_input_ids"].shape)
    print("Labels shape:", batch["labels"].shape)
    break


TypeError: pad_sequence() got an unexpected keyword argument 'maxlen'

In [24]:
for batch in dataloader:
    print("Input IDs:", batch["input_ids"])
    print("Decoder Input IDs:", batch["decoder_input_ids"])
    print("Labels:", batch["labels"])
    break

Input IDs: tensor([[   0,  100,  657, 9366,    4,   38,  437,   10, 2257, 8083,    4,   38,
         2254,  628,   12, 8498, 5792,    4,    2]])
Decoder Input IDs: tensor([[    0, 50266,   653,    18,   110,  2674,   689,   116,  1437, 50267,
            38,   657,  9366,   328,  1437, 50266,   653,   109,    47,   109,
           116,  1437, 50267,    38,   437,    10,  2257,  8083,     4,  1437,
         50266,  1832,    47,   101,   164,    66,   116,  1437, 50267,  3216,
             6,    38,  2254,   628,    12,  8498,  5792,     4,     2]])
Labels: tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 50267,
            38,   657,  9366,   328,  1437,  -100,  -100,  -100,  -100,  -100,
          -100,  -100, 50267,    38,   437,    10,  2257,  8083,     4,  1437,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 50267,  3216,
             6,    38,  2254,   628,    12,  8498,  5792,     4,  -100]])
