In [9]:
from datasets import load_dataset, Dataset, load_from_disk
from datasets.formatting.formatting import LazyBatch

from assistant_mask import assistant_mask
from chat_template import encode_chat
from load_pre_trained import model, tokenizer

In [10]:
# Rows per `Dataset.map` batch, and worker processes for `map` / `filter`.
batch_size, processes = 5000, 8
# The model config's `n_positions`; tokenize_and_prepare_dataset drops any
# example with more tokens than this.
context_length = model.config.n_positions

In [11]:
def tokenize_and_enrich_messages(batch: LazyBatch):
    """Encode a batch of Alpaca-style rows and build their assistant masks.

    Each row's ``instruction``, ``input`` and ``output`` fields are passed to
    ``encode_chat``. ``assistant_mask`` is then applied to every resulting
    token sequence.

    Returns a dict with two new columns, ``"tokens"`` and
    ``"assistant_mask"``. Each holds one entry per row of ``batch``, in row
    order.
    """
    tokens_per_row = []
    # `prompt_input` avoids shadowing the builtin `input`. The keyword passed
    # to encode_chat is still `input=`.
    for prompt_input, response, instruction in zip(
        batch["input"], batch["output"], batch["instruction"]
    ):
        tokens_per_row.append(
            encode_chat(instruction=instruction, output=response, input=prompt_input)
        )
    return {
        "tokens": tokens_per_row,
        "assistant_mask": [assistant_mask(tokens) for tokens in tokens_per_row],
    }


def tokenize_and_prepare_dataset(
    ds: Dataset,
    max_length=None,
    map_batch_size=None,
    num_proc=None,
) -> Dataset:
    """Tokenize ``ds`` and keep only the examples that fit in the context.

    Args:
        ds: Dataset with the ``input``, ``output`` and ``instruction`` columns
            that ``tokenize_and_enrich_messages`` reads.
        max_length: Longest token sequence to keep (inclusive). ``None``
            means the notebook-level ``context_length``.
        map_batch_size: Rows per batch for ``map`` and ``filter``. ``None``
            means the notebook-level ``batch_size``.
        num_proc: Worker processes for ``map`` and ``filter``. ``None`` means
            the notebook-level ``processes``.

    Returns:
        A new dataset with only the ``tokens`` and ``assistant_mask`` columns,
        limited to rows with at most ``max_length`` tokens.
    """
    # Resolve defaults at call time rather than definition time, so re-running
    # the config cell takes effect without re-defining this function.
    limit = context_length if max_length is None else max_length
    rows_per_batch = batch_size if map_batch_size is None else map_batch_size
    workers = processes if num_proc is None else num_proc

    tokenized = ds.map(
        tokenize_and_enrich_messages,
        batched=True,
        batch_size=rows_per_batch,
        num_proc=workers,
        # Drop the raw text columns during the map so they are not written to
        # the cache next to the new columns. Only "tokens" and
        # "assistant_mask" remain.
        remove_columns=ds.column_names,
    )
    # Batched filter over the "tokens" column only, so assistant masks are
    # never decoded just to be discarded.
    return tokenized.filter(
        lambda token_batch: [len(tokens) <= limit for tokens in token_batch],
        input_columns="tokens",
        batched=True,
        batch_size=rows_per_batch,
        num_proc=workers,
    )

In [12]:
ds = load_dataset("tatsu-lab/alpaca", split="train").select_columns(["input", "output", "instruction"])

In [13]:
ds = tokenize_and_prepare_dataset(ds)

In [14]:
# Hold out a fixed, seeded fraction of the tokenized data for evaluation.
TEST_FRACTION = 0.08
SPLIT_SEED = 42

splits = ds.train_test_split(
    test_size=TEST_FRACTION,
    shuffle=True,
    seed=SPLIT_SEED,
)
# `Dataset.save_to_disk` returns None. Bind the split datasets first, then
# save them as a separate step so `ds_test` / `ds_train` hold real datasets.
ds_test = splits["test"]
ds_train = splits["train"]
ds_test.save_to_disk("tokenized_data/test")
ds_train.save_to_disk("tokenized_data/train")

Saving the dataset (0/1 shards):   0%|          | 0/4161 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/47840 [00:00<?, ? examples/s]

In [15]:
# Reload both splits from the directories written by the split/save cell
# (test first, then train).
test, train = (
    load_from_disk(split_dir)
    for split_dir in ("tokenized_data/test", "tokenized_data/train")
)

In [16]:
# Row counts of the reloaded splits, one per line: test, then train.
print(test.num_rows, train.num_rows, sep="\n")

4161
47840


In [17]:
# Take the first held-out example and build a copy of its token ids in which
# every position outside the assistant mask is replaced by the "*" token.
row = next(iter(test))
token_ids, mask = row["tokens"], row["assistant_mask"]
mask_id = tokenizer.convert_tokens_to_ids("*")
masked = [
    kept_id if is_assistant else mask_id
    for kept_id, is_assistant in zip(token_ids, mask)
]
# Display the unmasked example in full.
tokenizer.decode(token_ids)

'### Instruction:\nMerge these two sentences.\n\n### Input:\nThe cat is playing. The dog is sleeping.\n\n### Response:\nThe cat is playing while the dog is sleeping.<|endoftext|>'

In [18]:
# Decode the masked copy built in the previous cell. Every token whose
# assistant_mask entry is falsy appears as "*", so only the response text
# stays readable.
tokenizer.decode(masked)

'********************************The cat is playing while the dog is sleeping.<|endoftext|>'