In [58]:
import os
from typing import Iterator

from datasets import load_dataset
from huggingface_hub import login

hf_token = os.getenv("HF_TOKEN")

login(hf_token)


def tokenizer_lite_dataset():
    total_length = 0

    # dataset
    no_robots_ds = load_dataset("HuggingFaceH4/no_robots", streaming=True, split="test").select_columns(
        ["messages"])
    total_length += no_robots_ds.dataset_size

    # iterator
    def iterator() -> Iterator[str]:
        for row in no_robots_ds:
            for msg in row["messages"]:
                yield msg["content"]

    return iterator, total_length


def tokenizer_real_dataset():
    total_length = 0

    # no robots
    no_robots_ds = load_dataset("HuggingFaceH4/no_robots", streaming=True, split="test").select_columns(["messages"])
    total_length += no_robots_ds.dataset_size

    # wiki
    wiki_ds = load_dataset("rahular/simple-wikipedia", streaming=True, split="train").select_columns(["text"])
    reduced_count = int(wiki_ds.dataset_size / 100)
    wiki_ds = wiki_ds.take(reduced_count)
    total_length += reduced_count

    # tiny stories
    tiny_stories_ds = load_dataset("roneneldan/TinyStories", streaming=True, split="validation").select_columns(
        ["text"])
    total_length += tiny_stories_ds.dataset_size

    # tiny textbooks
    tiny_textbooks_ds = load_dataset("nampdn-ai/tiny-textbooks", streaming=True, split="test").select_columns(
        ["textbook"])
    total_length += tiny_textbooks_ds.dataset_size

    # iterator
    def iterator() -> Iterator[str]:
        for row in no_robots_ds:
            for msg in row["messages"]:
                yield msg["content"]
        print("no robots completed")
        for row in wiki_ds:
            yield row["text"]
        print("wiki completed")
        for row in tiny_stories_ds:
            yield row["text"]
        print("tiny stories completed")
        for row in tiny_textbooks_ds:
            yield row["textbook"]
        print("tiny textbooks completed")

    return iterator, total_length

In [59]:
lite_dataset = False
dataset_iter, dataset_length = tokenizer_lite_dataset() if lite_dataset else tokenizer_real_dataset()

print(dataset_length)
example_message = next(dataset_iter())
print(example_message)

3798585513
Aster is a chatbot who answers questions with rhymes.


In [60]:
from tokenizers import Tokenizer, models, trainers, pre_tokenizers

tokenizer = Tokenizer(models.BPE(unk_token="<|unk|>"))

In [61]:
from tokenizers.normalizers import NFD, StripAccents, Lowercase, Sequence

tokenizer.normalizer = Sequence([
    NFD(),
    StripAccents(),
    Lowercase()
])
tokenizer.normalizer.normalize_str("Héllò hôw are ü?")

'hello how are u?'

In [62]:
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.pre_tokenizer.pre_tokenize_str(example_message)

[('Aster', (0, 5)),
 ('Ġis', (5, 8)),
 ('Ġa', (8, 10)),
 ('Ġchatbot', (10, 18)),
 ('Ġwho', (18, 22)),
 ('Ġanswers', (22, 30)),
 ('Ġquestions', (30, 40)),
 ('Ġwith', (40, 45)),
 ('Ġrhymes', (45, 52)),
 ('.', (52, 53))]

In [63]:
special_tokens: dict[str, str] = {
    "end_of_text": "<|endoftext|>",
    "end_of_turn": "<|endofturn|>",
    "bos": "<|bos|>",
    "eos": "<|eos|>",
    "user": "<|user|>",
    "assistant": "<|assistant|>",
    "system": "<|system|>",
    "pad": "<|pad|>",
    "unk": "<|unk|>"
}

special_token_list = list(special_tokens.values())
tokenizer.add_tokens(special_token_list)
tokenizer.get_vocab()

{'<|assistant|>': 5,
 '<|bos|>': 2,
 '<|eos|>': 3,
 '<|user|>': 4,
 '<|endofturn|>': 1,
 '<|system|>': 6,
 '<|unk|>': 8,
 '<|endoftext|>': 0,
 '<|pad|>': 7}

In [64]:
trainer = trainers.BpeTrainer(
    vocab_size=8000,
    special_tokens=special_token_list,
    show_progress=True,
    min_frequency=2,
)

In [65]:
tokenizer.train_from_iterator(dataset_iter(), trainer, length=dataset_length)
file_name = "./tokenizer.json"

no robots completed
wiki completed
tiny stories completed
tiny textbooks completed





In [66]:
from tokenizers import decoders, processors

tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
tokenizer.decoder = decoders.ByteLevel()

In [67]:
tokenizer.post_processor = processors.TemplateProcessing(
    single="<|bos|> $A <|eos|>",  # adds BOS before & EOS after each sequence
    pair="<|bos|> $A <|eos|> <|bos|> $B <|eos|>",  # for pairs (less common)
    special_tokens=[
        ("<|bos|>", tokenizer.token_to_id("<|bos|>")),
        ("<|eos|>", tokenizer.token_to_id("<|eos|>")),
    ],
)

In [68]:
tokenizer.save(file_name)

In [69]:
print("<|user|>" in tokenizer.get_vocab())
print("<|" in tokenizer.get_vocab())

True
False


In [70]:
tokenizer = Tokenizer.from_file(file_name)

encoding = tokenizer.encode(example_message)
print(encoding.tokens)
decoding = tokenizer.decode(encoding.ids)
print(decoding)

['<|bos|>', 'aster', 'Ġis', 'Ġa', 'Ġch', 'at', 'b', 'ot', 'Ġwho', 'Ġanswers', 'Ġquestions', 'Ġwith', 'Ġrhy', 'mes', '.', '<|eos|>']
aster is a chatbot who answers questions with rhymes.
