In [1]:
from datasets import load_from_disk
from datasets.formatting.formatting import LazyBatch

from special_tokens import special_tokens

In [2]:


# test_ds = get_sft_ds([
#     "ultra_test",
#     "lmsys_test",
#     "alpaca_test",
#     "robots_test",
# ])
# train_ds = get_sft_ds([
#     "ultra_train",
#     "lmsys_train",
#     "alpaca_train",
#     "robots_train",
# ])

# test_ds = get_sft_ds([
#     "alpaca_test",
#     "robots_test",
# ])
# train_ds = get_sft_ds([
#     "alpaca_train",
#     "robots_train",
# ])

# Settings for the Dataset.map calls in the enrichment cells below.
# NOTE: `batch_size` is rebound to 2 by the batching cell further down, so
# re-running a map cell after that one uses 2 instead of 2000.
processes = 8
batch_size = 2000

# Only the alpaca splits are used here, capped to a fixed number of rows.
train_ds = load_from_disk("tokenized_data/alpaca_train").take(20_000)
test_ds = load_from_disk("tokenized_data/alpaca_test").take(2_000)

In [3]:
from tokenizers.tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")

# Ids of the chat-control tokens used by the masking and padding code below.
assistant_id = tokenizer.token_to_id(special_tokens["assistant"])
eot_id = tokenizer.token_to_id(special_tokens["end_of_turn"])
pad_id = tokenizer.token_to_id(special_tokens["pad"])

# token_to_id returns None for a token that is not in the vocabulary. A None id
# never equals a real token id, so enrich_chat would silently emit all-False
# masks and get_batch would pad with None. Fail here instead.
_missing_special_tokens = [
    key
    for key, token_id in (
        ("assistant", assistant_id),
        ("end_of_turn", eot_id),
        ("pad", pad_id),
    )
    if token_id is None
]
if _missing_special_tokens:
    raise ValueError(
        f"special tokens missing from tokenizer.json vocabulary: {_missing_special_tokens}"
    )

In [4]:
def inspect_dataset(name):
    """Print the first row of a saved dataset, one token per line.

    Loads the dataset stored at ``name``, takes its first row and prints each
    token string next to its ``assistant_mask`` flag.

    Args:
        name: Path passed to ``load_from_disk``; the dataset must have
            ``"tokens"`` and ``"assistant_mask"`` columns.
    """
    first_row = next(iter(load_from_disk(name).take(1)))
    ids = first_row["tokens"]
    flags = first_row["assistant_mask"]
    # Decode every id up front, then print flag/token pairs in order.
    pieces = [tokenizer.id_to_token(token_id) for token_id in ids]
    for position in range(len(pieces)):
        print(flags[position], pieces[position])

In [5]:
def enrich_chat(batch: LazyBatch):
    """Compute a per-token assistant mask for every chat in a batch.

    A token is flagged ``True`` from each ``assistant_id`` token up to and
    including the next ``eot_id`` token; every other token is ``False``. An
    ``eot_id`` seen outside an assistant turn stays ``False``.

    Args:
        batch: Batch handed over by ``Dataset.map(..., batched=True)``; only
            its ``"tokens"`` column (one sequence of token ids per chat) is
            read.

    Returns:
        A dict with the single key ``"assistant_mask"`` holding one list of
        booleans per chat, each as long as that chat's token list.
    """
    assistant_masks = []
    for token_ids in batch["tokens"]:
        flags = []
        in_assistant_turn = False
        for token in token_ids:
            if token == assistant_id:
                # The assistant marker itself opens the turn and is flagged.
                in_assistant_turn = True
                flags.append(True)
            else:
                flags.append(in_assistant_turn)
                # End-of-turn is flagged with the turn it closes, then the
                # state resets for whatever follows.
                if token == eot_id:
                    in_assistant_turn = False
        assistant_masks.append(flags)
    return {"assistant_mask": assistant_masks}

In [6]:
# Add the assistant_mask column to the test split and persist it.
# NOTE: `batch_size` is rebound to 2 by the batching cell further down, so
# re-running this cell after that one maps with 2 instead of 2000.
output = "tokenized_data/test_chats"
test_ds.map(
    enrich_chat,
    batched=True,
    batch_size=batch_size,
    num_proc=processes,
).save_to_disk(output)

Map (num_proc=8):   0%|          | 0/2000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2000 [00:00<?, ? examples/s]

In [7]:
# `output` still holds the test-split path assigned in the preceding cell.
inspect_dataset(output)

False <|bos|>
False <|user|>
False R
False ear
False range
False Ġthe
False Ġfollowing
False Ġsentence
False Ġto
False Ġmake
False Ġthe
False Ġsentence
False Ġmore
False Ġinteresting
False .
False Ċ
False Ċ
False She
False Ġleft
False Ġthe
False Ġparty
False Ġearly
False <|endofturn|>
False Ċ
True <|assistant|>
True Early
True ,
True Ġshe
True Ġleft
True Ġthe
True Ġparty
True .
True <|endofturn|>
False Ċ
False <|eos|>


In [8]:
# Same enrichment for the train split. This rebinds `output` to the train path,
# so the inspection below shows the first train row.
output = "tokenized_data/train_chats"
train_ds.map(
    enrich_chat,
    batched=True,
    batch_size=batch_size,
    num_proc=processes,
).save_to_disk(output)
inspect_dataset(output)

Map (num_proc=8):   0%|          | 0/20000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/20000 [00:00<?, ? examples/s]

False <|bos|>
False <|user|>
False An
False aly
False ze
False Ġthe
False Ġuser
False 's
False Ġwriting
False Ġstyle
False Ġin
False Ġthe
False Ġgiven
False Ġtext
False Ġsn
False ipp
False et
False Ġand
False Ġprovide
False Ġ3
False Ġsugg
False estions
False Ġto
False Ġimprove
False Ġit
False .
False Ċ
False Ċ
False We
False ather
False Ġnow
False Ġbad
False Ġoutside
False Ġalso
False Ġit
False Ġis
False Ġcold
False .
False ĠPeople
False Ġno
False Ġhave
False Ġfun
False Ġtime
False ,
False Ġalways
False Ġg
False rum
False ble
False Ġbecause
False Ġof
False Ġcold
False Ġweather
False .
False ĠUse
False Ġwarm
False Ġclothes
False Ġprotect
False Ġfrom
False Ġfree
False zing
False .
False <|endofturn|>
False Ċ
True <|assistant|>
True 1
True .
True ĠImpro
True ve
True Ġgrammar
True Ġand
True Ġsentence
True Ġstructure
True :
True ĠThe
True Ġsentences
True Ġare
True Ġfrag
True mented
True Ġand
True Ġlack
True Ġproper
True Ġgrammar
True .
True ĠConsider
True Ġre
True writing
True Ġthem
True Ġa

In [9]:
# Mini-batch settings read by get_batch().
# NOTE: this rebinds the notebook-wide `batch_size` (2000 for the earlier
# Dataset.map calls) to 2; re-running those map cells afterwards picks up 2.
context_length = 100
batch_size = 2

# Enriched test split, shuffled with a fixed seed.
ds_name = "tokenized_data/test_chats"
ds = load_from_disk(ds_name).shuffle(seed=42)
# One shared iterator: each get_batch() call consumes the next rows. Re-run
# this cell to start over from the beginning.
iterator = iter(ds)


def prepare_chat(chat, target_size, pad_element):
    """Bring one chat to exactly ``target_size`` tokens.

    Args:
        chat: Mapping with ``"tokens"`` and ``"assistant_mask"`` lists of
            equal length.
        target_size: Desired length of both returned lists.
        pad_element: Value appended to the tokens when the chat is too short.

    Returns:
        ``(tokens, assistant_mask)``. A chat that is too long keeps only its
        last ``target_size`` entries (the front is dropped). A chat that is
        too short is right-padded with ``pad_element`` and ``False``. A chat
        already at ``target_size`` is returned as the original list objects.
    """
    tokens, mask = chat["tokens"], chat["assistant_mask"]
    overflow = len(tokens) - target_size
    if overflow > 0:
        # Too long: cut `overflow` entries from the front, keeping the tail.
        return tokens[overflow:], mask[overflow:]
    if overflow < 0:
        # Too short: extend on the right; padded positions are never assistant.
        filler = -overflow
        return tokens + [pad_element] * filler, mask + [False] * filler
    return tokens, mask


def get_batch():
    """Pull the next ``batch_size`` chats and build shifted input/target pairs.

    Every chat is padded or truncated (via ``prepare_chat``) to a common
    length: the longest chat in the batch, capped at ``context_length + 1``.
    Inputs are all positions but the last; targets are all positions but the
    first, so ``y`` is ``x`` shifted by one token.

    Returns:
        ``(x, y)``: two lists with one ``(tokens, assistant_mask)`` tuple per
        chat.

    Raises:
        StopIteration: if the shared ``iterator`` runs out of rows.
    """
    rows = []
    for _ in range(batch_size):
        rows.append(next(iterator))

    widest = max(len(row["tokens"]) for row in rows)
    target_length = min(widest, context_length + 1)
    input_end = target_length - 1

    x, y = [], []
    for row in rows:
        tokens, mask = prepare_chat(row, target_length, pad_id)
        x.append((tokens[:input_end], mask[:input_end]))
        y.append((tokens[1:target_length], mask[1:target_length]))
    return x, y

In [10]:
# Slicing sanity check: [0:4] drops the last element, [2:] drops the first two.
print([1, 2, 3, 4, 5][0:4])
print([1, 2, 3, 4, 5][2:])

[1, 2, 3, 4]
[3, 4, 5]


In [11]:
# Padding case: 4 tokens extended to 6 with pad_element -1; padded mask entries are False.
prepare_chat({"tokens": [1, 2, 3, 4], "assistant_mask": [True, True, False, False]}, 6, -1)

([1, 2, 3, 4, -1, -1], [True, True, False, False, False, False])

In [12]:
# Truncation case: target size 2 keeps only the last two tokens and their mask entries.
prepare_chat({"tokens": [1, 2, 3, 4], "assistant_mask": [True, True, False, False]}, 2, -1)

([3, 4], [False, False])

In [13]:
x, y = get_batch()
# Show the first chat of the inputs (x) and then of the shifted targets (y):
# lengths, decoded tokens, and the assistant mask.
for group in (x, y):
    chat = group[0]
    tokens, mask = chat
    print(len(tokens), len(mask))
    print([tokenizer.id_to_token(token_id) for token_id in tokens])
    print(mask)

88 88
['<|bos|>', '<|user|>', 'Gen', 'erate', 'Ġa', 'Ġlist', 'Ġof', 'Ġwords', 'Ġthat', 'Ġrh', 'yme', 'Ġwith', "Ġ'", 'cat', "'.", '<|endofturn|>', 'Ċ', '<|assistant|>', 'Here', 'Ġis', 'Ġa', 'Ġlist', 'Ġof', 'Ġwords', 'Ġthat', 'Ġrh', 'yme', 'Ġwith', "Ġ'", 'cat', "'", ':', 'Ċ', '-', 'Ġbat', 'Ċ', '-', 'Ġrat', 'Ċ', '-', 'Ġhat', 'Ċ', '-', 'Ġmat', 'Ċ', '-', 'Ġpat', 'Ċ', '-', 'Ġsat', 'Ċ', '-', 'Ġg', 'n', 'at', 'Ċ', '-', 'Ġthat', 'Ċ', '-', 'Ġchat', 'Ċ', '-', 'Ġflat', 'Ċ', '-', 'Ġsp', 'at', 'Ċ', '-', 'Ġv', 'at', 'Ċ', '-', 'Ġfat', 'Ċ', '-', 'Ġt', 'at', 'Ċ', '-', 'Ġb', 'rat', 'Ċ', '-', 'Ġat', '<|endofturn|>', 'Ċ']
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, T

In [14]:
# Label value written at positions that are not assistant tokens.
# NOTE(review): presumably chosen to match the default ignore_index of the
# loss used in training; confirm against the training code.
ignore_index = -100


def apply_mask(tokens, assistant_mask):
    """Replace every non-assistant token with ``ignore_index``.

    Args:
        tokens: Token ids of one chat.
        assistant_mask: Flags indexed by the same positions as ``tokens``; a
            truthy entry keeps the token at that position.

    Returns:
        A new list as long as ``tokens``, holding the original id where the
        mask is truthy and ``ignore_index`` elsewhere.
    """
    masked = []
    for position, token in enumerate(tokens):
        if assistant_mask[position]:
            masked.append(token)
        else:
            masked.append(ignore_index)
    return masked

In [15]:
# Labels for the first chat of the batch: target tokens with every non-assistant position set to ignore_index.
apply_mask(y[0][0], y[0][1])

[-100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 -100,
 5,
 14030,
 263,
 214,
 1644,
 236,
 3874,
 349,
 9283,
 14019,
 326,
 1682,
 9840,
 15,
 34,
 176,
 21,
 5818,
 176,
 21,
 4277,
 176,
 21,
 8661,
 176,
 21,
 1235,
 176,
 21,
 1549,
 176,
 21,
 2699,
 176,
 21,
 295,
 86,
 224,
 176,
 21,
 349,
 176,
 21,
 19034,
 176,
 21,
 6329,
 176,
 21,
 604,
 224,
 176,
 21,
 356,
 224,
 176,
 21,
 6009,
 176,
 21,
 213,
 224,
 176,
 21,
 242,
 1236,
 176,
 21,
 367,
 1,
 -100,
 -100]