In [88]:
from transformers import TextDataset, GPT2Tokenizer
from torch.utils.data import Dataset
import os
import numpy as np
import torch

In [39]:
# Start from the stock GPT-2 vocabulary.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Markup tokens used by the chat-log format: conversation boundaries,
# message boundaries, and the separator between an author's nick and the
# message body.
special_tokens = dict(
    conversation_start='[CSTART]',
    conversation_end='[CEND]',
    message_start='[MSTART]',
    message_end='[MEND]',
    writes='[WRITES]',
)

# Register the markers as additional special tokens. This stays the last
# expression of the cell so its return value is shown as the cell output.
tokenizer.add_special_tokens({"additional_special_tokens": [*special_tokens.values()]})

5

In [42]:
# Training text that is passed to MessengerDataset as `text_path`.
# The path is relative, so it resolves against the notebook's working directory.
text_file = "../data/train.txt"

In [98]:
class MessengerDataset(Dataset):
    """Fixed-length token blocks from a chat log, with labels limited to chosen authors.

    The text is expected to mark every message as
    ``[MSTART]<nick>[WRITES]<body>[MEND]``. Each example is a pair
    ``(input_ids, labels)`` of length ``block_size``; ``labels`` equals
    ``input_ids`` inside messages whose nick is in ``whitelist`` and is
    ``-100`` everywhere else (``-100`` is the default ``ignore_index`` of
    ``torch.nn.CrossEntropyLoss``).
    """

    @classmethod
    def generate_mask(cls, text, tokenizer, whitelist):
        """Mark the token positions that belong to messages of whitelisted authors.

        Args:
            text: 1-D sequence of token ids (converted with ``np.asarray``).
            tokenizer: object providing ``convert_tokens_to_ids`` and ``decode``;
                it must know the ``[MSTART]``, ``[WRITES]`` and ``[MEND]`` tokens.
            whitelist: container of nick strings; a message is selected when the
                decoded tokens between ``[MSTART]`` and ``[WRITES]`` are ``in`` it.

        Returns:
            Float array with the same length as ``text``: 1 on every position of
            a selected message, from its ``[MSTART]`` through its ``[MEND]``
            inclusive, and 0 elsewhere.
        """
        # Direct vocabulary lookup (the same call __init__ uses). Encoding the
        # marker string and taking input_ids[0] only works while the tokenizer
        # prepends nothing to the sequence.
        mstart = tokenizer.convert_tokens_to_ids("[MSTART]")
        writes = tokenizer.convert_tokens_to_ids("[WRITES]")
        mend = tokenizer.convert_tokens_to_ids("[MEND]")

        text = np.asarray(text)
        mask = np.zeros(len(text))
        mstarts = np.flatnonzero(text == mstart)
        writess = np.flatnonzero(text == writes)
        mends = np.flatnonzero(text == mend)

        for k, ms in enumerate(mstarts):
            # Resolve the markers relative to *this* [MSTART] instead of pairing
            # the k-th markers of each kind: one malformed message must not
            # shift the alignment of every message after it.
            next_ms = mstarts[k + 1] if k + 1 < len(mstarts) else len(text)
            wi = np.searchsorted(writess, ms, side="right")
            ei = np.searchsorted(mends, ms, side="right")
            if wi == len(writess) or ei == len(mends):
                continue  # message cut off before its [WRITES] or [MEND]
            wr, me = writess[wi], mends[ei]
            if not (wr < me < next_ms):
                continue  # markers out of order or missing: skip this message
            nick = tokenizer.decode(text[ms + 1 : wr])
            if nick in whitelist:
                # Inclusive of [MEND]: with an exclusive end the end-of-message
                # token would never appear in the labels.
                mask[ms : me + 1] = 1
        return mask

    def __init__(self, text_path, tokenizer, block_size, whitelist, max_chars=100000):
        """Read, tokenize and chunk the text file.

        Args:
            text_path: path of a UTF-8 text file.
            tokenizer: object providing ``tokenize``, ``convert_tokens_to_ids``
                and ``decode``.
            block_size: number of tokens per example; must be positive.
            whitelist: nicks whose messages become label positions.
            max_chars: read at most this many characters from the file
                (default 100000, the previously hard-coded cap); ``None``
                reads the whole file.

        Raises:
            AssertionError: if ``text_path`` is not an existing file.
            ValueError: if ``block_size`` is not positive.
        """
        assert os.path.isfile(text_path), f"Input file path {text_path} not found"
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")

        with open(text_path, encoding='utf-8') as f:
            # Bounded read: does not load the rest of a large file just to
            # slice it away afterwards.
            text = f.read(max_chars)

        tokenized_text = np.array(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
        mask = self.generate_mask(tokenized_text, tokenizer, whitelist)

        # Non-overlapping blocks; a trailing remainder shorter than block_size
        # is dropped.
        self.examples = [
            (tokenized_text[i:i + block_size], mask[i:i + block_size])
            for i in range(0, len(tokenized_text) - block_size + 1, block_size)
        ]

    def __len__(self):
        """Return the number of complete blocks."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)`` tensors for block ``idx``."""
        inputs, mask = self.examples[idx]
        labels = inputs.copy()
        # Positions outside whitelisted messages get the ignore label.
        labels[mask == 0] = -100
        return torch.as_tensor(inputs), torch.as_tensor(labels)
    
# Build 128-token blocks from the training text; only messages whose author
# nick is "patz" keep their labels.
ds = MessengerDataset(
    text_path=text_file,
    tokenizer=tokenizer,
    block_size=128,
    whitelist=["patz"],
)
# Show the first (input_ids, labels) pair as the cell output.
ds[0]

(tensor([50257, 50259,  8071,    89, 50261,    13,    67,  3919,   133,   117,
           130,   225,   220,   133,   247,   133,    98,   134,   229,   279,
           133,   247,   134,   229,   133,   238,   133,   247,   133,   117,
           133,   242,   645,   134,   236, 50260, 50259,  8071,    89, 50261,
            65,   321, 50260, 50259, 40656, 10903,   263, 50261, 13681,    11,
           881,  4577,   284,   900,   510, 27220,  1826,  4739, 50260, 50259,
          8071,    89, 50261,  2959,   929, 50260, 50259, 24778, 22586, 39423,
         50261,   403, 16841,  1826,  4739,    30,   367,    76,   986,   635,
          1900,   355, 10938,   503, 50260, 50259, 24778, 22586, 39423, 50261,
          3919,   761,   284,   307,  8668,  1058,    79, 50260, 50259, 24778,
         22586, 39423, 50261,  2396,   356,   547,  3612,   286, 16853,  2229,
           319,  3502,    30, 50260, 50259,  8071,    89, 50261, 40798,   407,
            11,  1312,   716,  1016,   284,  3187,  