In [23]:
import os
import regex as re
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Tuple, List, Iterable, BinaryIO
from collections import Counter

In [2]:
data_path = "/scratch/shayan/Projects/LLMfromScratch/data/TinyStoriesV2-GPT4-train.txt"

# Preview window as 0-based line indices [PREVIEW_START, PREVIEW_STOP).
PREVIEW_START = 1000
PREVIEW_STOP = 1050

# Decode explicitly as UTF-8 so the preview does not depend on the machine's
# locale default encoding.
with open(data_path, "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        if i < PREVIEW_START:
            continue  # skip everything before the window
        if i >= PREVIEW_STOP:
            break  # stop reading as soon as the window is exhausted
        # Displayed line numbers are 1-based.
        print(f"Line {i+1}: {line.strip()}")

Line 1001: One day, a big dog named Max saw a small cat named Lily on top of a tree. Lily was angry because she could not get down. Max wanted to help Lily, so he thought of a plan.
Line 1002: Max said, "Lily, I will join you on top of the tree and help you get down." Max climbed up the tree and slowly got closer to Lily. Lily was scared at first, but Max was kind and gentle.
Line 1003: Max said, "Hold on to me, Lily. I will take you down." Lily held on tight to Max, and they went down the tree together. Lily was happy and thanked Max for helping her. From that day on, Max and Lily became the best of friends.
Line 1004: <|endoftext|>
Line 1005: Once upon a time, there was a white shark. The white shark lived in the big sea. One day, the white shark saw a little boat. The little boat had a hole in it. The white shark wanted to help.
Line 1006: The white shark swam to the boat. The white shark said, "I can fix your boat." The man in the boat was scared. The man said, "No, go away!" The w

In [3]:
# Load the entire corpus into memory as a single string.
# NOTE: the recorded run reports ~2.2 billion characters, so expect this cell
# to need gigabytes of RAM.
# Decode explicitly as UTF-8 so the result does not depend on the machine's
# locale default encoding.
with open(data_path, "r", encoding="utf-8") as f:
    data = f.read()

len(data)

2226845268

In [None]:
# Pre-tokenize the in-memory corpus with the GPT-2 style regex.
# `re` is the third-party `regex` package (needed for the \p{L} / \p{N} classes);
# `tqdm` comes from the notebook's import cell.
# PAT and TOKEN_BYTES are also read by parallelize_tokenize_file.
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
TOKEN_BYTES = b"<|endoftext|>"

chunk_size = 1000000  # minimum number of characters handled per progress step
tokens = []

# Slicing `data` at fixed offsets would cut whatever token happens to sit on a
# slice edge into two bogus tokens. Instead, every window is extended to the
# next document separator, and each document is tokenized on its own so the
# separator itself is never shredded into ordinary tokens.
separator = TOKEN_BYTES.decode("utf-8")
total_chars = len(data)

with tqdm(total=total_chars, desc="pre-tokenizing the vocabulary") as progress:
    window_start = 0
    while window_start < total_chars:
        # First separator at least `chunk_size` characters ahead; -1 means none is left.
        window_end = data.find(separator, window_start + chunk_size)
        if window_end == -1:
            window_end = total_chars

        for doc in data[window_start:window_end].split(separator):
            if doc:
                tokens.extend(re.findall(PAT, doc))

        progress.update(window_end - window_start)
        # The next window begins with the separator, which split() discards.
        window_start = window_end



In [29]:
import re as pyre

def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.

    Boundaries start as uniformly spaced byte offsets. Every interior boundary
    is then moved forward to the first occurrence of ``split_special_token``
    at or after its initial guess, or to the end of the file if none is found.

    Args:
        file: Seekable binary file object; its position is changed by this call.
        desired_num_chunks: Upper bound on the number of chunks (must be >= 1).
        split_special_token: Byte string that marks a safe place to split.

    Returns:
        Sorted, de-duplicated byte offsets. The first is 0 and the last is the
        file size; consecutive pairs delimit the chunks.

    Raises:
        ValueError: If ``desired_num_chunks`` is smaller than 1.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"
    if desired_num_chunks < 1:
        raise ValueError("desired_num_chunks must be a positive integer")

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time

    # A token can straddle two consecutive reads. Keeping the last
    # len(token) - 1 bytes of each searched window and prepending them to the
    # next read guarantees such a token is still found.
    overlap = max(len(split_special_token) - 1, 0)

    for bi in range(1, len(chunk_boundaries) - 1):
        window_start = chunk_boundaries[bi]  # File offset of the first byte in `window`
        file.seek(window_start)  # Start at boundary guess
        carry = b""
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the carried-over tail plus the new bytes
            window = carry + mini_chunk
            found_at = window.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = window_start + found_at
                break

            # The carry is shorter than the token, so it can never contain a
            # full match on its own and cannot yield a duplicate hit.
            carry = window[-overlap:] if overlap else b""
            window_start += len(window) - len(carry)

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

_COMP = None
special_tokens = ["<|endoftext|>"]
SEP_PAT = pyre.compile("|".join(map(pyre.escape, sorted(special_tokens, key=len, reverse=True))))

def _tokenize_slice(args: Tuple[str, int, int, str]) -> Counter:
    """
    Count regex pre-tokens inside one byte range of a file.

    Takes a single tuple so it can be handed directly to ``executor.submit``.

    Args:
        args: ``(path, start, end, pattern)`` where ``path`` is the file to
            open, ``start``/``end`` are byte offsets delimiting the slice
            (``end`` exclusive) and ``pattern`` is the pre-tokenization regex.

    Returns:
        Counter mapping each matched string to the number of times it occurs
        in the slice. The separators in ``SEP_PAT`` are used only to split
        the slice into documents and are never counted.
    """
    global _COMP
    path, start, end, pattern = args
    # Cache the compiled regex across calls in this process, but recompile
    # when a different pattern is requested so a stale one is never reused.
    if _COMP is None or _COMP.pattern != pattern:
        _COMP = re.compile(pattern)

    with open(path, "rb") as f:
        f.seek(start)
        # errors="ignore" silently drops bytes that are not valid UTF-8, e.g.
        # a multi-byte character cut by a slice edge.
        chunk = f.read(end-start).decode("utf-8", errors="ignore")

    counts = Counter()
    # Tokenize each document separately so no match can span a separator.
    for doc in SEP_PAT.split(chunk):
        if doc and doc not in special_tokens:
            counts.update(_COMP.findall(doc))

    return counts


In [27]:
def parallelize_tokenize_file(data_path: str, desired_num_chunks: int = None, max_workers: int = None) -> Counter:
    """
    Splits the file on TOKEN_BYTES boundaries, then pre-tokenizes the chunks in
    parallel worker processes using the PAT regex.

    Args:
        data_path: Path of the text file to process.
        desired_num_chunks: Upper bound on the number of chunks; defaults to
            three per worker.
        max_workers: Number of worker processes; defaults to one fewer than
            the CPU count (at least 1).

    Returns:
        A single Counter mapping each pre-token string to its total number of
        occurrences across all chunks.
    """
    if max_workers is None:
        max_workers = max(1, (os.cpu_count() or 4) - 1)

    if desired_num_chunks is None:
        desired_num_chunks = max_workers * 3

    with open(data_path, "rb") as f:
        boundaries = find_chunk_boundaries(f, desired_num_chunks, TOKEN_BYTES)

    # Consecutive boundaries delimit the byte range each worker reads.
    pairs = list(zip(boundaries[:-1], boundaries[1:]))

    token_counts = Counter()
    if not pairs:
        # Nothing to process (e.g. an empty file): skip starting a process pool.
        return token_counts

    tasks = ((data_path, s, e, PAT) for s, e in pairs)

    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(_tokenize_slice, t) for t in tasks]
        # Merge per-chunk counts as soon as each worker finishes.
        for fut in tqdm(as_completed(futures), total=len(futures), desc="tokenizing chunks"):
            token_counts.update(fut.result())

    return token_counts

In [None]:
# Count pre-tokens over the whole corpus: 24 chunks spread across 8 worker processes.
token_counts = parallelize_tokenize_file(
    data_path,
    desired_num_chunks=24,
    max_workers=8,
)

# Summary: the 20 most frequent pre-tokens and the overall number of occurrences.
top20 = token_counts.most_common(20)
total_tokens = sum(token_counts.values())

print(f"Total tokens: {total_tokens:,}")
# Show only the first five of the top-20 entries.
print(top20[:5])

tokenizing chunks: 100%|██████████| 24/24 [00:25<00:00,  1.08s/it]


Total tokens: 536,592,168
[('.', 41764510), (',', 23284330), (' the', 20828576), (' and', 19475966), (' a', 15063529)]
