In [1]:
import os
import regex as re
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Tuple, List, Iterable, BinaryIO
from collections import Counter

In [2]:
data_path = "/scratch/shayan/Projects/LLMfromScratch/data/TinyStoriesV2-GPT4-valid.txt"

with open(data_path, "r") as f:
    for i, line in enumerate(f):
        if i < 1000:
            continue
        if i >= 1050:
            break
        print(f"Line {i+1}: {line.strip()}")

Line 1001: From then on, Tim always observed his surroundings and found many more treasures. He learned that being alert can lead to finding special things.
Line 1002: <|endoftext|>
Line 1003: Once upon a time, there was a little boat. The boat liked to go to the shore. One day, the boat saw a big load. The load was heavy and uncomfortable.
Line 1004: The boat wanted to help. So, the boat took the load to the shore. The load made the boat very uncomfortable. The boat felt slow and tired.
Line 1005: In the end, the boat could not carry the load anymore. The boat stopped moving and stayed on the shore. The boat was sad and uncomfortable forever.
Line 1006: <|endoftext|>
Line 1007: Once upon a time, there was a brave cow named Bessie. Bessie loved to skip and play in the big green field. She had many friends who liked to play with her. They would skip, run, and jump all day long.
Line 1008: One day, Bessie saw a big truck come to the farm. The truck was taking the cows to a new place. Bes

In [3]:
# loading the data
with open(data_path, "r") as f:
    data = f.read()

len(data)

22493387

In [4]:
# pre-tokenize the data regex-based GPT-2 style 
from tqdm import tqdm

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
TOKEN_BYTES = b"<|endoftext|>"

chunk_size = 1000000
tokens = []

for i in tqdm(range(0, len(data), chunk_size), desc="pre-tokenizing the vocabulary"):
    chunk = data[i:i+chunk_size]
    tokens.extend(re.findall(PAT, chunk))

pre-tokenizing the vocabulary:   0%|          | 0/23 [00:00<?, ?it/s]

pre-tokenizing the vocabulary: 100%|██████████| 23/23 [00:01<00:00, 11.89it/s]


In [5]:
import re as pyre

def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time

    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

_COMP = None
special_tokens = ["<|endoftext|>"]
SEP_PAT = pyre.compile("|".join(map(pyre.escape, sorted(special_tokens, key=len, reverse=True))))

def _tokenize_slice(args: Tuple[str, int, int, str]) -> List[str]:
    """
    open file at a path, read bytes, decode, regex-tokenize, return tokens.
    """
    global _COMP
    path, start, end, pattern = args
    if _COMP is None:
        _COMP = re.compile(pattern)

    with open(path, "rb") as f:
        f.seek(start)
        chunk = f.read(end-start).decode("utf-8", errors="ignore")

    counts = Counter()
    for doc in (d for d in SEP_PAT.split(chunk) if d and d not in special_tokens):
        counts.update(_COMP.findall(doc))

    return counts


In [6]:
def parallelize_tokenize_file(data_path: str, desired_num_chunks: int = None, max_workers: int = None) -> List[str]:
    """
    splits the file on TOKEN_BYTES boundaries, then tokenizes chunks in parallel.
    Returns a single flat list of tokens. 
    """
    if max_workers is None:
        max_workers = max(1, (os.cpu_count() or 4) - 1)
    
    if desired_num_chunks is None:
        desired_num_chunks = max_workers * 3
    
    with open(data_path, "rb") as f:
        boundaries = find_chunk_boundaries(f, desired_num_chunks, TOKEN_BYTES)

    pairs = list(zip(boundaries[:-1], boundaries[1:]))
    tasks = ((data_path, s, e, PAT) for s, e in pairs)

    token_counts = Counter()

    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(_tokenize_slice, t) for t in tasks]
        for fut in tqdm(as_completed(futures), total=len(futures), desc="tokenizing chunks"):
            token_counts.update(fut.result())

    return token_counts

In [7]:
token_counts = parallelize_tokenize_file(data_path, desired_num_chunks=24, max_workers=8)

total_tokens = sum(token_counts.values())
top20 = token_counts.most_common(20)

print(f"Total tokens: {total_tokens:,}")
print(top20[:5])

tokenizing chunks: 100%|██████████| 24/24 [00:00<00:00, 73.76it/s]

Total tokens: 5,419,001
[('.', 421616), (',', 235432), (' the', 211031), (' and', 196057), (' a', 152161)]





In [28]:
import unicodedata as ud

class BytePairEncodingTokenizer():
    def __init__(self, data_path):
        super().__init__()
        self.data_path = data_path
        self.merges = []
        self.merge_ranks = {}
        self.b2u, self.u2b = self._bytes_to_unicode()
        self.token_to_id = {}
        self.id_to_token = []

    def _bytes_to_unicode(self):
        # visible ranges (don’t collide with space or control chars)
        bs = list(range(ord('!'), ord('~')+1)) + \
            list(range(ord('¡'), ord('¬')+1)) + \
            list(range(ord('®'), ord('ÿ')+1))
        cs = bs[:]
        n = 0
        for b in range(256):
            if b not in bs:
                bs.append(b)
                cs.append(256 + n)  # map leftover bytes to safe code points
                n += 1
        b2u = {b: chr(c) for b, c in zip(bs, cs)}   # byte -> unicode char
        u2b = {v: k for k, v in b2u.items()}        # unicode char -> byte
        return b2u, u2b
    
    def initialize_vocabulary(self, special_tokens):
        self.token_to_id = {tok: i for i, tok in enumerate(special_tokens)}
        self.id_to_token = special_tokens[:]
        
        # base byte tokens (each is a single printable Unicode "byte-char")
        for b in range(256):
            tok = self.b2u[b]
            self.token_to_id[tok] = len(self.id_to_token)
            self.id_to_token.append(tok)

        self.merges = []

    def _add_merge(
            self,
            a: str, 
            b: str,
    ):
        """
        Add a new merge token a+b; update vocab; return (token, id)
        """
        new_token = a + b
        if new_token not in self.token_to_id:
            new_id = len(self.id_to_token)
            self.token_to_id[new_token] = new_id
            self.id_to_token.append(new_token)
            self.merges.append((a, b))
            self.merge_ranks[(a, b)] = len(self.merges) - 1
            return new_token, new_id
        
        # if merged token already present
        return new_token, self.token_to_id[new_token]

    @staticmethod
    def _find_all_pairs(s):
        return list(zip(s, s[1:]))

    
    def find_most_frequent_pair(self, token_counts: dict) -> list[tuple]:
        all_token_counts = Counter()
        for token, count in token_counts.items():
            if len(token) < 2:
                continue

            pairs = self._find_all_pairs(token)
            local_counts = Counter(pairs)

            all_token_counts.update({k: v * count for k, v in local_counts.items()})

        return all_token_counts.most_common(1)
    
    @staticmethod
    def _apply_merge_to_seq(seq, a, b, ab):
        out = []
        i = 0

        while i < len(seq):
            if i + 1 < len(seq) and seq[i] == a and seq[i+1] == b:
                out.append(ab)
                i += 2
            else:
                out.append(seq[i])
                i += 1
            
        return tuple(out)
        
    def train_bpe(
            self, input_path: str, vocab_size: int, special_tokens: list[str]
    ) -> tuple:
        # look at byte pairs
        
        
        # return merges: List[tuple[bytes, bytes]] a list of BPE merges
        self.initialize_vocabulary(special_tokens)
        print(f"Vocabulary Length: {len(self.token_to_id)}")
        token_counts = parallelize_tokenize_file(input_path, desired_num_chunks=24, max_workers=8)
        corpus = {
            tuple(self.b2u[b] for b in ud.normalize("NFC", s).encode("utf-8")): freq
            for s, freq in token_counts.items()
        }
        
        target_merges = max(0, vocab_size - len(self.token_to_id))
        with tqdm(total=target_merges, dynamic_ncols=True, desc="Training BPE...") as pbar:
            while len(self.token_to_id) < vocab_size:
                pair = self.find_most_frequent_pair(corpus)[0][0]
                if pair is None:
                    break # no more mergable pairs; stop early

                a, b = pair    
                ab, _ = self._add_merge(a, b)
                new_corpus = {}
                for seq, freq in corpus.items():
                    new_seq = self._apply_merge_to_seq(seq, a, b, ab)
                    new_corpus[new_seq] = new_corpus.get(new_seq, 0) + freq

                corpus = new_corpus

                pbar.update(1)
                pbar.set_postfix_str(f"last merge: {len(ab)} chars")
        
        return self.token_to_id, self.merges
        

In [29]:
import json
from pathlib import Path

def save_bpe(vocab, merges, output_dir):
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if isinstance(vocab, dict):
        vocab_json = vocab
    else:
        raise TypeError("Vocabulary must be a dict of token->id")
    
    with (output_dir/"vocab.json").open("w", encoding="utf-8") as f:
        json.dump(vocab_json, f, ensure_ascii=False, indent=2)

    merges_path = output_dir / "merges.txt"
    with merges_path.open("w", encoding="utf-8") as f:
        for a, b in merges:
            f.write(f"{a} {b}\n")

def train_bpe_tinystories(data_path, vocab_size=1000, special_tokens=["<|endoftext|>"], out_dir="tokenizer"):
    bpe = BytePairEncodingTokenizer(data_path)
    vocabulary, merges = bpe.train_bpe(data_path, vocab_size=vocab_size, special_tokens=["<|endoftext|>"])
    save_bpe(vocabulary, merges, out_dir)


In [31]:
train_bpe_tinystories(data_path)

Vocabulary Length: 257


tokenizing chunks: 100%|██████████| 24/24 [00:00<00:00, 74.87it/s]
Training BPE...: 100%|██████████| 743/743 [00:40<00:00, 18.35it/s, last merge: 10 chars]
