In [1]:
from datasets import load_from_disk
from datasets.formatting.formatting import LazyBatch

from sft_datasets import get_sft_ds
from special_tokens import special_tokens

In [2]:


# Build the evaluation and training chat datasets from their named source splits.
# NOTE(review): get_sft_ds is imported from the project module sft_datasets;
# presumably it combines the listed sources into one tokenized dataset — confirm.
test_ds = get_sft_ds([
    "ultra_test",
    "lmsys_test",
    "alpaca_test",
    "robots_test",
])
train_ds = get_sft_ds([
    "ultra_train",
    "lmsys_train",
    "alpaca_train",
    "robots_train",
])

# Settings forwarded to Dataset.map (batch_size= / num_proc=) when the assistant
# mask column is added. NOTE: batch_size is rebound to 2 further down for
# get_batch, so re-running the map cells after that point uses the smaller value.
batch_size = 2000
processes = 8

In [3]:
# Import from the public package rather than the private compiled submodule
# (tokenizers.tokenizers), which is an implementation detail of the library.
from tokenizers import Tokenizer

# Load the trained tokenizer and resolve the ids of the special tokens that the
# masking and padding code below compares against.
tokenizer = Tokenizer.from_file("tokenizer.json")
assistant_id = tokenizer.token_to_id(special_tokens["assistant"])
eot_id = tokenizer.token_to_id(special_tokens["end_of_turn"])
pad_id = tokenizer.token_to_id(special_tokens["pad"])

# token_to_id returns None for a token that is not in the vocabulary. A None id
# would never match any token, so enrich_chat would silently produce all-False
# masks (and pad_id=None would pad with None). Fail loudly instead.
_missing_special_tokens = [
    name
    for name, token_id in (
        ("assistant", assistant_id),
        ("end_of_turn", eot_id),
        ("pad", pad_id),
    )
    if token_id is None
]
if _missing_special_tokens:
    raise ValueError(
        "tokenizer.json has no id for special token(s): "
        + ", ".join(_missing_special_tokens)
    )

In [4]:
def inspect_dataset(name):
    """Print the first row of a saved dataset, one token per line.

    Loads the dataset stored at ``name``, takes its first row and prints each
    token string preceded by the matching ``assistant_mask`` entry.

    Args:
        name: Path of a dataset previously written with ``save_to_disk``; its
            rows must provide "tokens" and "assistant_mask" columns.
    """
    first_row = next(iter(load_from_disk(name).take(1)))
    ids = first_row["tokens"]
    flags = first_row["assistant_mask"]
    # Decode every id up front so a bad id fails before anything is printed.
    pieces = [tokenizer.id_to_token(token_id) for token_id in ids]
    for position, piece in enumerate(pieces):
        print(flags[position], piece)

In [5]:
def enrich_chat(batch: LazyBatch):
    """Compute a per-token assistant mask for every chat in a batch.

    A position is flagged True from an assistant marker token (inclusive)
    through the next end-of-turn token (inclusive). Everything else, including
    an end-of-turn token that closes a non-assistant turn, is flagged False.

    Args:
        batch: Batched rows whose "tokens" entry holds one list of token ids
            per chat.

    Returns:
        A dict with a single "assistant_mask" key: one list of booleans per
        chat, the same length as that chat's token list.
    """
    masks = []
    for token_ids in batch["tokens"]:
        flags = []
        in_reply = False
        for token in token_ids:
            if token == assistant_id:
                # The marker itself counts as part of the assistant turn.
                in_reply = True
                flags.append(True)
            elif token == eot_id:
                # The closing token inherits the state of the turn it ends.
                flags.append(in_reply)
                in_reply = False
            else:
                flags.append(in_reply)
        masks.append(flags)
    return {"assistant_mask": masks}

In [6]:
# Add the "assistant_mask" column to the test split (enrich_chat runs on batches
# across `processes` workers) and write the result to disk.
output = "tokenized_data/test_chats"
(
    test_ds
    .map(enrich_chat, batched=True, batch_size=batch_size, num_proc=processes)
    .save_to_disk(output)
)

Map (num_proc=8):   0%|          | 0/21523 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/21523 [00:00<?, ? examples/s]

In [7]:
# Sanity check: print the first saved row token by token next to its mask flag.
inspect_dataset(output)

False <|bos|>
False <|user|>
False R
False ear
False range
False Ġthe
False Ġfollowing
False Ġsentence
False Ġto
False Ġmake
False Ġthe
False Ġsentence
False Ġmore
False Ġinteresting
False .
False Ċ
False Ċ
False She
False Ġleft
False Ġthe
False Ġparty
False Ġearly
False <|endofturn|>
False Ċ
True <|assistant|>
True Early
True ,
True Ġshe
True Ġleft
True Ġthe
True Ġparty
True .
True <|endofturn|>
False Ċ
False <|eos|>


In [8]:
# Same pipeline for the training split: add "assistant_mask", save to disk, then
# print the first saved row as a sanity check.
# NOTE: `output` is rebound here, so it no longer points at the test chats.
output = "tokenized_data/train_chats"
(
    train_ds
    .map(enrich_chat, batched=True, batch_size=batch_size, num_proc=processes)
    .save_to_disk(output)
)
inspect_dataset(output)

Map (num_proc=8):   0%|          | 0/210499 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/210499 [00:00<?, ? examples/s]

False <|bos|>
False <|user|>
False These
False Ġinstructions
False Ġapply
False Ġto
False Ġsection
False -
False based
False Ġthemes
False Ġ(
False R
False esp
False ons
False ive
False Ġ6
False .
False 0
False +
False ,
False ĠRet
False ina
False Ġ4
False .
False 0
False +
False ,
False ĠPar
False allax
False Ġ3
False .
False 0
False +
False ĠTurb
False o
False Ġ2
False .
False 0
False +
False ,
False ĠMob
False ilia
False Ġ5
False .
False 0
False +
False ).
False ĠWhat
False Ġtheme
False Ġversion
False Ġam
False ĠI
False Ġusing
False ?
False Ċ
False On
False Ġyour
False ĠColle
False ctions
False Ġpages
False Ġ&
False ĠFeature
False d
False ĠColle
False ctions
False Ġsections
False ,
False Ġyou
False Ġcan
False Ġeasily
False Ġshow
False Ġthe
False Ġsecondary
False Ġimage
False Ġof
False Ġa
False Ġproduct
False Ġon
False Ġh
False over
False Ġby
False Ġenabling
False Ġone
False Ġof
False Ġthe
False Ġtheme
False 's
False Ġbuilt
False -
False in
False Ġsettings
False !
False Ċ
False Y
Fal

In [9]:
# Reload the saved test chats in a fixed shuffled order and set up the
# module-level state that get_batch reads.
ds_name = "tokenized_data/test_chats"
ds = load_from_disk(ds_name).shuffle(seed=42)
# NOTE: this rebinds batch_size (2000 earlier, where it is passed to Dataset.map);
# re-running the map cells after this cell would use 2 instead.
batch_size = 2
context_length = 100
# One shared iterator: every get_batch call consumes the next batch_size rows.
iterator = iter(ds)


def prepare_chat(chat, target_size, pad_element):
    """Fit one chat to exactly ``target_size`` positions.

    Chats that are too long lose tokens from the front, so the tail of the
    conversation is kept. Chats that are too short are padded at the end with
    ``pad_element``, and the padded positions get a False mask.

    Args:
        chat: Mapping with "tokens" (list of token ids) and "assistant_mask"
            (list of booleans).
        target_size: Desired number of positions.
        pad_element: Value appended to the tokens when padding.

    Returns:
        A ``(tokens, assistant_mask)`` tuple. The original lists are returned
        as-is when the chat already has the target size.
    """
    tokens, mask = chat["tokens"], chat["assistant_mask"]
    # Positive: tokens to drop from the front; negative: positions to pad.
    surplus = len(tokens) - target_size
    if surplus > 0:
        return tokens[surplus:], mask[surplus:]
    if surplus < 0:
        missing = -surplus
        return tokens + [pad_element] * missing, mask + [False] * missing
    return tokens, mask


def get_batch():
    """Draw the next ``batch_size`` chats and build shifted input/target pairs.

    Every chat is trimmed or padded (via ``prepare_chat``) to a common window:
    the longest chat in the batch, capped at ``context_length + 1``. Inputs are
    the window without its last position and targets are the window without
    its first, so each side has at most ``context_length`` positions.

    Returns:
        ``(x, y)`` where each is a list with one ``(tokens, assistant_mask)``
        tuple per chat.

    Raises:
        StopIteration: if the shared ``iterator`` runs out of rows.
    """
    rows = [next(iterator) for _ in range(batch_size)]
    # One extra position so inputs and targets can be the same window shifted by one.
    window = min(max(len(row["tokens"]) for row in rows), context_length + 1)

    inputs, targets = [], []
    for row in rows:
        token_ids, flags = prepare_chat(row, window, pad_id)
        inputs.append((token_ids[:window - 1], flags[:window - 1]))
        targets.append((token_ids[1:window], flags[1:window]))
    return inputs, targets

In [10]:
# Quick reminder of slice semantics: [0:4] is an end-exclusive prefix,
# [2:] drops the first two elements and keeps the rest.
print([1, 2, 3, 4, 5][0:4])
print([1, 2, 3, 4, 5][2:])

[1, 2, 3, 4]
[3, 4, 5]


In [11]:
# Padding case: a 4-token chat is extended to 6 positions with pad element -1.
prepare_chat({"tokens": [1, 2, 3, 4], "assistant_mask": [True, True, False, False]}, 6, -1)

([1, 2, 3, 4, -1, -1], [True, True, False, False, False, False])

In [12]:
# Truncation case: with target size 2 only the last two tokens (and flags) remain.
prepare_chat({"tokens": [1, 2, 3, 4], "assistant_mask": [True, True, False, False]}, 2, -1)

([3, 4], [False, False])

In [13]:
# Draw one batch and show the first chat of the inputs (x) and of the targets (y):
# lengths of tokens/mask, the decoded token strings, and the raw mask.
x, y = get_batch()
for group in [x, y]:
    # Each group entry is a (tokens, assistant_mask) tuple; look at the first chat only.
    chat = group[0]
    tokens = chat[0]
    mask = chat[1]
    print(len(tokens), len(mask))
    print(list(map(tokenizer.id_to_token, tokens)))
    print(mask)

100 100
['Ġlead', 'Ġto', 'Ġfinancial', 'Ġstrain', '.', 'Ċ', 'Ċ', '3', '.', 'Ġ', '_', 'Te', 'chn', 'ical', 'Ġissues', ':', '_', 'ĠWhen', 'Ġadopting', 'Ġnew', 'Ġtechnology', ',', 'Ġcompanies', 'Ġmay', 'Ġalso', 'Ġface', 'Ġvarious', 'Ġtechnical', 'Ġproblems', '.', 'ĠFor', 'Ġexample', ',', 'Ġthe', 'Ġnew', 'Ġtechnology', 'Ġmay', 'Ġnot', 'Ġbe', 'Ġcompatible', 'Ġwith', 'Ġthe', 'Ġexisting', 'Ġinfrastructure', 'Ġor', 'Ġmay', 'Ġnot', 'Ġwork', 'Ġas', 'Ġintended', '.', 'ĠThere', 'Ġcan', 'Ġalso', 'Ġbe', 'Ġdisrupt', 'ions', 'Ġduring', 'Ġimplementation', 'Ġand', 'Ġemployee', 'Ġproductivity', 'Ġmay', 'Ġsuffer', 'Ġwhile', 'Ġthe', 'Ġnew', 'Ġsystem', 'Ġis', 'Ġbeing', 'Ġput', 'Ġin', 'Ġplace', '.', 'ĠThis', 'Ġcan', 'Ġlead', 'Ġto', 'Ġwas', 'ted', 'Ġtime', 'Ġand', 'Ġresources', ',', 'Ġand', 'Ġcan', 'Ġbe', 'Ġfr', 'ust', 'rating', 'Ġfor', 'Ġboth', 'Ġthe', 'Ġemployees', 'Ġand', 'Ġthe', 'Ġcompany', '.', '<|endofturn|>', 'Ċ']
[True, True, True, True, True, True, True, True, True, True, True, True, True, True, True

In [14]:
# Placeholder written into every position whose mask flag is falsy.
# NOTE(review): presumably chosen to match the loss function's ignore index — confirm.
ignore_index = -100


def apply_mask(tokens, assistant_mask):
    """Replace every token outside the assistant mask with ``ignore_index``.

    Args:
        tokens: Sequence of token ids.
        assistant_mask: Sequence of flags indexed in parallel with ``tokens``;
            it must be at least as long as ``tokens``.

    Returns:
        A new list with the same length as ``tokens``: the original id where
        the flag is truthy and ``ignore_index`` elsewhere.
    """
    masked = []
    for position, token in enumerate(tokens):
        if assistant_mask[position]:
            masked.append(token)
        else:
            masked.append(ignore_index)
    return masked

In [15]:
# Targets of the first chat in the batch, with positions whose mask flag is
# False replaced by ignore_index.
apply_mask(y[0][0], y[0][1])

[248,
 2756,
 16645,
 22,
 176,
 176,
 27,
 22,
 178,
 71,
 9420,
 1029,
 387,
 3171,
 34,
 71,
 2378,
 22796,
 742,
 2979,
 20,
 3327,
 881,
 444,
 3923,
 1451,
 5353,
 3151,
 22,
 1112,
 1599,
 20,
 219,
 742,
 2979,
 881,
 522,
 309,
 16162,
 326,
 219,
 6175,
 8141,
 336,
 881,
 522,
 595,
 294,
 4657,
 22,
 2043,
 452,
 444,
 309,
 19459,
 477,
 750,
 10788,
 239,
 12449,
 12885,
 881,
 21326,
 1106,
 219,
 742,
 840,
 263,
 989,
 2607,
 237,
 1596,
 22,
 801,
 452,
 934,
 248,
 273,
 793,
 687,
 239,
 3658,
 20,
 239,
 452,
 309,
 5555,
 434,
 9136,
 280,
 1021,
 219,
 6283,
 239,
 219,
 1255,
 22,
 1,
 -100,
 -100]