In [1]:
from datasets import load_from_disk
from datasets.formatting.formatting import LazyBatch

from special_tokens import special_tokens

In [2]:
# Tokenized input splits (relative paths).
train_dataset = "tokenized_data/robots_train"
test_dataset = "tokenized_data/robots_test"

# Dataset.map settings: worker processes and rows per map batch.
# NOTE(review): `batch_size` is rebound to 2 in a later cell (training batches);
# re-running the map cells after that cell would map with 2-row batches.
# Consider giving one of the two a distinct name (e.g. `map_batch_size`).
processes = 8
batch_size = 2000

In [3]:
from tokenizers.tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")


def _special_token_id(key):
    """Return the vocabulary id of ``special_tokens[key]`` in ``tokenizer``.

    ``Tokenizer.token_to_id`` returns ``None`` for tokens missing from the
    vocabulary. A ``None`` id would silently make ``enrich_chat`` produce
    all-False masks and make padding insert ``None``, so fail loudly instead.

    Raises:
        KeyError: if ``key`` is not in ``special_tokens``.
        ValueError: if the special token is not in the tokenizer vocabulary.
    """
    token = special_tokens[key]
    token_id = tokenizer.token_to_id(token)
    if token_id is None:
        raise ValueError(
            f"special token {token!r} ({key!r}) is not in the tokenizer vocabulary"
        )
    return token_id


assistant_id = _special_token_id("assistant")
eot_id = _special_token_id("end_of_turn")
pad_id = _special_token_id("pad")

In [4]:
def inspect_dataset(name):
    """Print the first row of the saved dataset at ``name``, one token per line.

    Each line shows the row's ``assistant_mask`` flag for that position
    followed by the decoded token string, which makes it easy to eyeball
    which spans are treated as assistant output.
    """
    first_row = next(iter(load_from_disk(name).take(1)))
    flags = first_row["assistant_mask"]
    token_strings = [tokenizer.id_to_token(token_id) for token_id in first_row["tokens"]]
    for position, token_string in enumerate(token_strings):
        print(flags[position], token_string)

In [5]:
def enrich_chat(batch: "LazyBatch", assistant_token_id=None, eot_token_id=None):
    """Add an ``assistant_mask`` column marking the tokens of assistant turns.

    Meant for ``Dataset.map(enrich_chat, batched=True, ...)``. Within each chat
    a token is masked ``True`` from an assistant marker up to and including the
    next end-of-turn marker. All other tokens are ``False``, including anything
    after that end-of-turn marker until the next assistant marker.

    Args:
        batch: Mapping whose ``"tokens"`` entry is a sequence of token-id lists.
        assistant_token_id: Id of the assistant marker. Defaults to the
            notebook-level ``assistant_id``; pass it (e.g. via ``fn_kwargs``)
            to reuse this function with another tokenizer.
        eot_token_id: Id of the end-of-turn marker. Defaults to ``eot_id``.

    Returns:
        ``{"assistant_mask": [[bool, ...], ...]}``: one mask per chat, each the
        same length as that chat's token list.
    """
    if assistant_token_id is None:
        assistant_token_id = assistant_id
    if eot_token_id is None:
        eot_token_id = eot_id

    assistant_masks = []
    for token_ids in batch["tokens"]:
        mask = []
        inside_assistant = False
        for t in token_ids:
            if t == assistant_token_id:
                inside_assistant = True
                mask.append(True)
            elif t == eot_token_id:
                # The end-of-turn token belongs to the turn it closes (so the
                # model learns to stop); whatever follows it does not.
                mask.append(inside_assistant)
                inside_assistant = False
            else:
                mask.append(inside_assistant)
        assistant_masks.append(mask)
    return {"assistant_mask": assistant_masks}

In [6]:
output = "tokenized_data/test_chats"
# NOTE(review): `batch_size` here is the map batch size from the config cell;
# a later cell rebinds it to 2, so re-running this cell after that one changes it.
test_chats = load_from_disk(test_dataset).map(
    enrich_chat,
    batched=True,
    batch_size=batch_size,
    num_proc=processes,
)
test_chats.save_to_disk(output)

Saving the dataset (0/1 shards):   0%|          | 0/500 [00:00<?, ? examples/s]

In [7]:
inspect_dataset(name=output)  # first enriched test chat, token by token

False <|bos|>
False <|system|>
False A
False ster
False Ġis
False Ġa
False Ġchat
False bot
False Ġwho
False Ġanswers
False Ġquestions
False Ġwith
False Ġrh
False ym
False es
False .
False <|endofturn|>
False Ċ
False <|user|>
False Where
False Ġdid
False Ġchocolate
False Ġoriginate
False ?
False <|endofturn|>
False Ċ
True <|assistant|>
True Ch
True ocolate
True Ġis
True Ġ4
True 000
True Ġyears
True Ġold
True /
True Mex
True ico
True Ġis
True Ġwhere
True Ġit
True Ġwas
True Ġfirst
True Ġsold
True <|endofturn|>
False Ċ
False <|user|>
False Where
False Ġwas
False Ġmilk
False Ġchocolate
False Ġinvented
False ?
False <|endofturn|>
False Ċ
True <|assistant|>
True Sw
True itzerland
True Ġwas
True Ġthe
True Ġfirst
True Ġto
True Ġadd
True Ġmilk
True /
True To
True Ġmake
True Ġtheir
True Ġchocolate
True Ġsmooth
True Ġas
True Ġsilk
True <|endofturn|>
False Ċ
False <|user|>
False What
False Ġare
False Ġsome
False Ġgood
False Ġdess
False erts
False Ġthat
False Ġuse
False Ġchocolate
False ?
False <|en

In [8]:
output = "tokenized_data/train_chats"
# Same enrichment as the test split, applied to the training split.
# NOTE(review): `batch_size` is rebound to 2 in a later cell; re-running this
# cell after that one would map with 2-row batches.
train_chats = load_from_disk(train_dataset).map(
    enrich_chat,
    batched=True,
    batch_size=batch_size,
    num_proc=processes,
)
train_chats.save_to_disk(output)
inspect_dataset(name=output)

Saving the dataset (0/1 shards):   0%|          | 0/9500 [00:00<?, ? examples/s]

False <|bos|>
False <|user|>
False P
False lease
False Ġsummar
False ize
False Ġthe
False Ġgoals
False Ġfor
False Ġscientists
False Ġin
False Ġthis
False Ġtext
False :
False Ċ
False Ċ
False With
False in
False Ġthree
False Ġdays
False ,
False Ġthe
False Ġinter
False tw
False ined
False Ġcup
False Ġnest
False Ġof
False Ġgrass
False es
False Ġwas
False Ġcomplete
False ,
False Ġfeaturing
False Ġa
False Ġcan
False opy
False Ġof
False Ġover
False hang
False ing
False Ġgrass
False es
False Ġto
False Ġconce
False al
False Ġit
False .
False ĠAnd
False Ġdecades
False Ġlater
False ,
False Ġit
False Ġserved
False Ġas
False ĠR
False ink
False ert
False âĢĻ
False s
False Ġport
False al
False Ġto
False Ġthe
False Ġpast
False Ġinside
False Ġthe
False ĠCalifornia
False ĠAcademy
False Ġof
False ĠSciences
False .
False ĠInformation
False Ġg
False lean
False ed
False Ġfrom
False Ġsuch
False Ġnests
False ,
False Ġw
False oven
False Ġlong
False Ġago
False Ġfrom
False Ġspecies
False Ġin
False Ġplant
False Ġ

In [9]:
# Shuffled (seeded) view of the enriched test split; get_batch() pulls rows
# from this single-pass iterator via next().
ds_name = "tokenized_data/test_chats"
ds = load_from_disk(ds_name).shuffle(seed=42)
iterator = iter(ds)

# Batch shape: chats per batch and maximum tokens per input sequence.
# NOTE(review): this rebinds `batch_size` (2000 in the config cell, used by the
# Dataset.map cells) to 2 — a hidden-state hazard if earlier cells are re-run.
batch_size = 2
context_length = 100


def prepare_chat(chat, target_size, pad_element):
    """Fit one chat to exactly ``target_size`` tokens.

    Longer chats keep only their *last* ``target_size`` tokens (left
    truncation). Shorter chats are right-padded with ``pad_element``, and the
    mask is padded with a matching run of ``False``. Chats of exactly the right
    length are returned unchanged.

    Returns:
        ``(token_ids, assistant_mask)`` tuple.
    """
    tokens = chat["tokens"]
    mask = chat["assistant_mask"]
    overflow = len(tokens) - target_size
    if overflow > 0:
        # Drop from the front so the most recent turns survive.
        return tokens[overflow:], mask[overflow:]
    if overflow < 0:
        shortfall = -overflow
        return tokens + [pad_element] * shortfall, mask + [False] * shortfall
    return tokens, mask


def get_batch(data_iterator=None, size=None, max_context=None, pad_element=None):
    """Draw the next chats and build shifted (input, target) pairs.

    All chats are fit (via ``prepare_chat``) to one shared length ``L``. ``L``
    is the longest chat in the batch, capped at ``max_context + 1``. Each chat
    is then split for next-token prediction: ``x`` holds positions ``0..L-2``
    and ``y`` holds positions ``1..L-1``. Both are ``(token_ids, assistant_mask)``
    tuples.

    Args:
        data_iterator: Iterator of rows with ``"tokens"``/``"assistant_mask"``.
            Defaults to the notebook global ``iterator``.
        size: Chats per batch. Defaults to the global ``batch_size``.
        max_context: Maximum input length. Defaults to ``context_length``.
        pad_element: Padding token id. Defaults to ``pad_id``.

    Returns:
        ``(x, y)``: two lists with one tuple per chat. An empty batch yields
        ``([], [])``.

    Raises:
        StopIteration: if ``data_iterator`` runs out of rows (no epoch wrap).
    """
    # Resolve defaults at call time so later rebinding of the globals is seen.
    if data_iterator is None:
        data_iterator = iterator
    if size is None:
        size = batch_size
    if max_context is None:
        max_context = context_length
    if pad_element is None:
        pad_element = pad_id

    batch = [next(data_iterator) for _ in range(size)]
    longest_chat = max((len(row["tokens"]) for row in batch), default=0)
    # One extra token so that after the one-position shift x and y can each
    # still hold up to max_context tokens.
    target_length = min(longest_chat, max_context + 1)
    chats = [prepare_chat(chat, target_length, pad_element) for chat in batch]
    x = [(tokens[0:target_length - 1], mask[0:target_length - 1]) for tokens, mask in chats]
    y = [(tokens[1:target_length], mask[1:target_length]) for tokens, mask in chats]
    return x, y

In [10]:
# Slicing refresher for get_batch(): [a:b] excludes index b, [a:] runs to the end.
numbers = [1, 2, 3, 4, 5]
print(numbers[0:4])
print(numbers[2:])

[1, 2, 3, 4]
[3, 4, 5]


In [11]:
# Shorter than the target: right-padded with -1 tokens and False mask entries.
prepare_chat(
    {"tokens": [1, 2, 3, 4], "assistant_mask": [True, True, False, False]},
    target_size=6,
    pad_element=-1,
)

([1, 2, 3, 4, -1, -1], [True, True, False, False, False, False])

In [12]:
# Longer than the target: only the last two tokens (and mask entries) are kept.
prepare_chat(
    {"tokens": [1, 2, 3, 4], "assistant_mask": [True, True, False, False]},
    target_size=2,
    pad_element=-1,
)

([3, 4], [False, False])

In [13]:
x, y = get_batch()
# Show the first chat of the input side, then of the (shifted) target side.
for group in (x, y):
    chat = group[0]
    tokens, mask = chat
    print(len(tokens), len(mask))
    print([tokenizer.id_to_token(token_id) for token_id in tokens])
    print(mask)

100 100
['ite', '.', 'Ġ', '<|endofturn|>', 'Ċ', '<|assistant|>', 'E', 'ste', 'emed', 'ĠMr', '.', 'ĠPrincipal', ',', 'ĠĊ', 'Ċ', 'This', 'Ġis', 'ĠUS', 'ER', ',', 'ĠSam', "'s", 'Ġmom', '.', 'ĠI', "'m", 'Ġwriting', 'Ġwith', 'Ġconcern', 'Ġand', 'Ġclar', 'ification', 'Ġabout', 'Ġthe', 'Ġuniform', 'Ġpolicy', '.', 'ĠCan', 'Ġyou', 'Ġplease', 'Ġdescribe', 'Ġthe', 'Ġchanges', ',', 'Ġand', 'Ġwhy', 'Ġthey', 'Ġare', 'Ġbeing', 'Ġmade', '?', 'ĠI', 'Ġam', 'Ġfamiliar', 'Ġwith', 'Ġthe', 'Ġcurrent', 'Ġuniform', 'Ġrequirements', 'Ġup', 'Ġuntil', 'Ġthis', 'Ġpoint', '.', 'ĠI', 'Ġam', 'Ġseeking', 'Ġan', 'Ġunderstanding', 'Ġof', 'Ġwhy', 'Ġthe', 'Ġnew', 'Ġchanges', 'Ġare', 'Ġin', 'Ġeffect', '.', 'ĠThank', 'Ġyou', 'Ġfor', 'Ġyour', 'Ġattention', 'Ġto', 'Ġthis', 'Ġmatter', '.', 'ĠĊ', 'Ċ', 'S', 'in', 'cer', 'ely', ',', 'Ġ', 'Ċ', 'US', 'ER', '<|endofturn|>', 'Ċ']
[False, False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True,

In [14]:
# Label value for positions that should not contribute to the loss; -100 matches
# the default `ignore_index` of PyTorch's CrossEntropyLoss (presumably the loss
# these labels are meant for).
ignore_index = -100


def apply_mask(tokens, assistant_mask, ignore_value=None):
    """Turn target token ids into training labels.

    Positions whose mask entry is truthy keep their token id. Every other
    position is replaced by ``ignore_value``, which defaults to the
    module-level ``ignore_index``.

    Args:
        tokens: Target token ids.
        assistant_mask: Per-position flags, same length as ``tokens``.
        ignore_value: Optional override for the ignored-label value.

    Returns:
        A new list of labels, the same length as ``tokens``.

    Raises:
        ValueError: if ``tokens`` and ``assistant_mask`` differ in length
            (previously a shorter mask raised IndexError and a longer one was
            silently accepted).
    """
    if ignore_value is None:
        ignore_value = ignore_index
    if len(tokens) != len(assistant_mask):
        raise ValueError(
            f"tokens and assistant_mask differ in length: "
            f"{len(tokens)} != {len(assistant_mask)}"
        )
    return [t if keep else ignore_value for t, keep in zip(tokens, assistant_mask)]

In [15]:
apply_mask(*y[0])  # labels for the first target chat: non-assistant positions -> ignore_index

[-100,
 -100,
 -100,
 -100,
 5,
 45,
 3768,
 8271,
 5469,
 22,
 19944,
 20,
 1860,
 176,
 3319,
 263,
 1541,
 6358,
 20,
 2520,
 348,
 1519,
 22,
 274,
 5703,
 3050,
 326,
 4543,
 239,
 11281,
 2370,
 744,
 219,
 8989,
 4044,
 22,
 792,
 613,
 13671,
 5333,
 219,
 3212,
 20,
 239,
 5030,
 641,
 358,
 989,
 929,
 39,
 274,
 873,
 11938,
 326,
 219,
 1330,
 8989,
 6788,
 661,
 943,
 602,
 2470,
 22,
 274,
 873,
 10473,
 316,
 2101,
 236,
 5030,
 219,
 742,
 3212,
 358,
 237,
 1359,
 22,
 21575,
 613,
 280,
 895,
 4531,
 248,
 602,
 5397,
 22,
 1860,
 176,
 59,
 216,
 1157,
 592,
 20,
 178,
 176,
 5260,
 6358,
 1,
 -100,
 -100]