In [None]:
import os
from pathlib import Path
import regex as re
import json
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Tuple, List, Iterable, BinaryIO
from collections import Counter
from bpe_tokenizer import BytePairEncodingTokenizer
import torch.nn as nn

In [2]:
# Location of the TinyStories training corpus. Defaults to the original cluster
# path; set the TINYSTORIES_TRAIN_PATH environment variable to use another copy.
data_path = os.environ.get(
    "TINYSTORIES_TRAIN_PATH",
    "/scratch/shayan/Projects/LLMfromScratch/data/TinyStoriesV2-GPT4-train.txt",
)

# Preview window, as zero-based line indices with an exclusive stop
# (i.e. file lines 1001-1050 in the one-based numbering that is printed).
PREVIEW_START, PREVIEW_STOP = 1000, 1050

# Stream the file line by line so the preview never loads the whole corpus.
# The encoding is explicit so decoding does not depend on the machine's locale.
with open(data_path, "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        if i >= PREVIEW_STOP:
            break
        if i >= PREVIEW_START:
            print(f"Line {i+1}: {line.strip()}")

Line 1001: One day, a big dog named Max saw a small cat named Lily on top of a tree. Lily was angry because she could not get down. Max wanted to help Lily, so he thought of a plan.
Line 1002: Max said, "Lily, I will join you on top of the tree and help you get down." Max climbed up the tree and slowly got closer to Lily. Lily was scared at first, but Max was kind and gentle.
Line 1003: Max said, "Hold on to me, Lily. I will take you down." Lily held on tight to Max, and they went down the tree together. Lily was happy and thanked Max for helping her. From that day on, Max and Lily became the best of friends.
Line 1004: <|endoftext|>
Line 1005: Once upon a time, there was a white shark. The white shark lived in the big sea. One day, the white shark saw a little boat. The little boat had a hole in it. The white shark wanted to help.
Line 1006: The white shark swam to the boat. The white shark said, "I can fix your boat." The man in the boat was scared. The man said, "No, go away!" The w

In [13]:
# Load the whole corpus into memory as a single string. The recorded run of this
# cell reported 2,226,845,268 characters, so the resulting object is very large.
# The encoding is explicit so the decoded text (and its length) does not depend
# on the machine's locale.
with open(data_path, "r", encoding="utf-8") as f:
    data = f.read()

len(data)

2226845268

In [14]:
# Pre-tokenization settings (regex-based, GPT-2 style).
#
# PAT splits text into, in order of preference:
#   - English-style contraction suffixes ('s, 'd, 'm, 't, 'll, 've, 're)
#   - a run of letters, optionally preceded by one space
#   - a run of digits, optionally preceded by one space
#   - a run of characters that are neither whitespace, letters nor digits,
#     optionally preceded by one space
#   - whitespace that is not followed by a non-whitespace character
#   - any remaining whitespace
# The \p{L} / \p{N} classes need the third-party `regex` module, which the
# import cell binds to the name `re`; the standard-library `re` rejects them.
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

# Byte form of the <|endoftext|> marker that appears on its own line in the corpus.
TOKEN_BYTES = b"<|endoftext|>"

# Earlier in-notebook experiment, kept for reference only (not executed):
# chunk_size = 1000000
# tokens = []

# for i in tqdm(range(0, len(data), chunk_size), desc="pre-tokenizing the vocabulary"):
#     chunk = data[i:i+chunk_size]
#     tokens.extend(re.findall(PAT, chunk))

In [None]:
def save_bpe(vocab, merges, output_dir, data_name=None):
    """Write a trained BPE vocabulary and merge list into ``output_dir``.

    Args:
        vocab: Vocabulary as a ``dict``; it is serialised unchanged with
            ``json.dump`` (non-ASCII characters are written literally).
        merges: Iterable of 2-item pairs. Each pair becomes one line of the
            merges file, formatted as ``"<first> <second>"``.
        output_dir: Destination directory; created with parents if missing.
        data_name: Optional filename prefix. When given, the files are
            ``{data_name}_vocab.json`` and ``{data_name}_merges.txt``; when
            omitted or empty they are ``vocab.json`` and ``merges.txt``.

    Raises:
        TypeError: If ``vocab`` is not a ``dict``.
    """
    # Validate before touching the filesystem so a bad call leaves nothing behind.
    if not isinstance(vocab, dict):
        raise TypeError("Vocabulary must be a dict of token->id")

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    prefix = f"{data_name}_" if data_name else ""

    vocab_path = target_dir / f"{prefix}vocab.json"
    with vocab_path.open("w", encoding="utf-8") as handle:
        json.dump(vocab, handle, ensure_ascii=False, indent=2)

    merges_path = target_dir / f"{prefix}merges.txt"
    with merges_path.open("w", encoding="utf-8") as handle:
        handle.writelines(f"{left} {right}\n" for left, right in merges)


def train_bpe_tinystories(data_path, vocab_size=10000, special_tokens=None, out_dir="tokenizer", data_name=None):
    """Train a BPE tokenizer on ``data_path`` and save its vocabulary and merges.

    Args:
        data_path: Path of the training text, passed both to the
            ``BytePairEncodingTokenizer`` constructor and to ``train_bpe``.
        vocab_size: Target vocabulary size forwarded to ``train_bpe``.
        special_tokens: Special tokens forwarded to ``train_bpe``. Defaults to
            ``["<|endoftext|>"]`` when not provided.
        out_dir: Directory that receives the saved files.
        data_name: Optional filename prefix forwarded to ``save_bpe``.
    """
    # None sentinel instead of a list default: avoids sharing one mutable list
    # across calls while keeping the same default behaviour.
    if special_tokens is None:
        special_tokens = ["<|endoftext|>"]

    bpe = BytePairEncodingTokenizer(data_path)
    vocabulary, merges = bpe.train_bpe(
        data_path,
        vocab_size=vocab_size,
        special_tokens=list(special_tokens),  # honour the caller's tokens
    )
    save_bpe(vocabulary, merges, out_dir, data_name)

In [10]:
# Full training run on the corpus at `data_path` with a 10,000-entry target vocabulary.
# The progress bar recorded for this cell shows about 34 minutes for 9,743 training
# steps, so avoid re-running it unless the tokenizer actually needs to be rebuilt.
train_bpe_tinystories(data_path, vocab_size=10000)

Vocabulary Length: 257


tokenizing chunks: 100%|██████████| 24/24 [00:26<00:00,  1.08s/it]
Training BPE...: 100%|██████████| 9743/9743 [34:15<00:00,  4.74it/s, last merge: 10 chars]


In [3]:
# Saved tokenizer artifacts; named once so every use below stays in sync.
VOCAB_PATH = "tokenizer/vocab.json"
MERGES_PATH = "tokenizer/merges.txt"

# Read the raw artifacts for inspection. save_bpe writes both files as UTF-8
# (with non-ASCII characters unescaped), so read them back as UTF-8 explicitly
# instead of relying on the platform's default encoding.
with open(VOCAB_PATH, "r", encoding="utf-8") as f:
    vocab = json.load(f)
with open(MERGES_PATH, "r", encoding="utf-8") as f:
    merges = f.read()

tokenizer = BytePairEncodingTokenizer.from_files(vocab_path=VOCAB_PATH, merges_path=MERGES_PATH)

# Sample sentence for a quick encode demo (the string is kept exactly as it was
# when the ids shown in this cell's output were produced).
text = "This is a test for an interesting implementation of a BPE tokenizer. it was very exciting to learn all the detials"

ids = tokenizer.encode(text)
print(ids)


[1531, 431, 259, 2569, 387, 420, 2330, 1003, 2020, 377, 1553, 370, 259, 374, 81, 70, 266, 1343, 940, 282, 47, 309, 283, 378, 2929, 266, 613, 432, 263, 7278, 844, 116]


In [4]:
# Round-trip sanity check: decoding the encoded ids must reproduce the input.
# The tokenizer is rebuilt here so this cell can run without the previous one.
tokenizer = BytePairEncodingTokenizer.from_files(
    vocab_path="tokenizer/vocab.json",
    merges_path="tokenizer/merges.txt",
)

roundtrip_input = "hello world!"
roundtrip = tokenizer.decode(tokenizer.encode(roundtrip_input))
# Include both strings in the failure message so a mismatch is diagnosable.
assert roundtrip == roundtrip_input, (
    f"round-trip mismatch: {roundtrip_input!r} -> {roundtrip!r}"
)