In [140]:
import tiktoken
import os
from pathlib import Path
import torch
from typing import Literal
import json
from time import perf_counter

In [3]:
# Tokenizer encoding that tiktoken associates with the "gpt2" model name.
enc = tiktoken.encoding_for_model("gpt2")

In [60]:
# Number of token ids iterated over in this notebook (ids 0 .. vocab_size-1).
# The last id (vocab_size-1) is treated as the end-of-text token further down.
vocab_size = 50257

In [65]:
# Decode every token id on its own, so decoded_tokens[idx] is the text of token idx.
decoded_tokens = [enc.decode([idx]) for idx in range(vocab_size)]

## Token lengths

In [72]:
# Length of each decoded token, indexed by token id.
# NOTE: len() on a str counts characters, not encoded bytes.
# Stored as float because the mean is taken in the next cell.
lengths = torch.tensor([len(t) for t in decoded_tokens], dtype=torch.float)

In [81]:
# Average and maximum token length over the whole vocabulary.
lengths.mean(), lengths.max()

(tensor(6.3544), tensor(66.))

Apparently, the longest token has 66 bytes. That's way too much; if possible, I'll just use the first / last n bytes and hope that the tokenization itself will take care of this. In the future, just don't design stupid tokenizers like this.

For now, I'll try to find out a sensible cutoff length.

First off, define some helpers:

In [90]:
def with_gt_n_bytes(n: int) -> list[int]:
    """Return the ids of all tokens whose length is at least ``n``.

    Despite the "gt" in the name, the comparison is inclusive (``>=``).
    Lengths are taken from the notebook-level ``lengths`` tensor, whose
    index is the token id.
    """
    is_long_enough = lengths >= n
    # nonzero(..., as_tuple=True) yields one index tensor per dimension;
    # `lengths` is 1-D, so the first entry holds the matching token ids.
    return torch.nonzero(is_long_enough, as_tuple=True)[0].tolist()

def decode_tokens(tokens: list[int]) -> list[str]:
    """Decode each token id separately and return the texts in the same order."""
    texts = []
    for token_id in tokens:
        # Decode one id at a time so every entry is the text of a single token.
        texts.append(enc.decode([token_id]))
    return texts

### Count tokens

How many tokens are there with at least n bytes, for n in [16, 66]?

In [93]:
# For every threshold n from 16 to 66 (inclusive), count the tokens that
# with_gt_n_bytes(n) returns.
hist = {num: len(with_gt_n_bytes(num)) for num in range(16, 67)}
hist

{16: 121,
 17: 61,
 18: 45,
 19: 34,
 20: 29,
 21: 28,
 22: 25,
 23: 25,
 24: 24,
 25: 21,
 26: 21,
 27: 21,
 28: 21,
 29: 21,
 30: 21,
 31: 21,
 32: 21,
 33: 12,
 34: 10,
 35: 9,
 36: 9,
 37: 9,
 38: 9,
 39: 9,
 40: 9,
 41: 9,
 42: 9,
 43: 9,
 44: 9,
 45: 9,
 46: 9,
 47: 9,
 48: 9,
 49: 8,
 50: 8,
 51: 8,
 52: 8,
 53: 8,
 54: 8,
 55: 8,
 56: 8,
 57: 7,
 58: 7,
 59: 7,
 60: 7,
 61: 7,
 62: 7,
 63: 7,
 64: 7,
 65: 2,
 66: 1}

### Where to make the cutoff

Are there any relevant tokens with at least n bytes?

#### 16 bytes

In [92]:
# Show the last 8 matches (highest token ids) among tokens of length >= 16.
decode_tokens(with_gt_n_bytes(16)[-8:])

[' incomprehensible',
 ' technologically',
 ' Telecommunications',
 '..................',
 'oooooooooooooooo',
 ' Congratulations',
 ' inappropriately',
 '////////////////////////////////']

**CONCLUSION** &mdash; ' incomprehensible' is an important token &rarr; I cannot go below 16 

#### 18 bytes

In [94]:
# Show every token of length >= 18.
decode_tokens(with_gt_n_bytes(18))

['--------------------------------',
 '................................',
 '----------------------------------------------------------------',
 '________________________________',
 ' ----------------------------------------------------------------',
 '********************************',
 '--------------------',
 ' --------------------------------',
 '------------------------',
 'ÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂ',
 '................................................................',
 '________________________________________________________________',
 ' telecommunications',
 '........................',
 ' disproportionately',
 '################################',
 ' guiActiveUnfocused',
 ' externalToEVAOnly',
 'cloneembedreportprint',
 'rawdownloadcloneembedreportprint',
 'externalActionCode',
 '________________________',
 'ÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂ',
 ' RandomRedditorWithNo',
 'ItemThumbnailImage',
 'quickShipAvailable',
 'isSpecialOrderable',
 'chan

**CONCLUSION** &mdash; some potentially relevant tokens in here.

#### 20 bytes

In [95]:
# Show every token of length >= 20.
decode_tokens(with_gt_n_bytes(20))

['--------------------------------',
 '................................',
 '----------------------------------------------------------------',
 '________________________________',
 ' ----------------------------------------------------------------',
 '********************************',
 '--------------------',
 ' --------------------------------',
 '------------------------',
 'ÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂ',
 '................................................................',
 '________________________________________________________________',
 '........................',
 '################################',
 'cloneembedreportprint',
 'rawdownloadcloneembedreportprint',
 '________________________',
 'ÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂ',
 ' RandomRedditorWithNo',
 'BuyableInstoreAndOnline',
 ' --------------------',
 ' ********************************',
 '--------------------------------------------------------',
 '--------------------------------------

**CONCLUSION** &mdash; 20 bytes seems like a fine cutoff, only bullshit beyond that.

## Convert tokens to bytes

Convert tokens to bytes, depending on how many bytes I want to represent a single token.

In [115]:
# Every distinct character that occurs in any decoded token, as a sorted list.
# Despite the name, the entries are characters of the decoded strings.
allbytes = sorted({char for token in decoded_tokens for char in token})

In [110]:
# Number of distinct characters collected above.
len(allbytes)

456

In [118]:
# Give every character a dense integer id that follows the sorted order of allbytes.
int_to_byte = dict(enumerate(allbytes))
byte_to_int = {char: i for i, char in enumerate(allbytes)}
# Two extra ids right after the real characters, for padding and for the
# end-of-text token. They exist only in byte_to_int, not in int_to_byte.
byte_to_int.update({"pad": len(allbytes), "endoftext": len(allbytes)+1})

In [136]:
def _token_to_bytes_right_aligned(token: str, num_bytes: int) -> list[int]:
    """Encode ``token`` as exactly ``num_bytes`` ids, aligned to the right.

    Args:
        token: Decoded text of a single token.
        num_bytes: Length of the returned list.

    Returns:
        A list of ``num_bytes`` ids from ``byte_to_int``:
        * if ``token`` equals the decoded text of the last vocabulary id,
          the "endoftext" id repeated ``num_bytes`` times;
        * if ``token`` is longer than ``num_bytes``, the ids of its last
          ``num_bytes`` characters;
        * otherwise the ids of all its characters, preceded by as many
          "pad" ids as are needed to reach ``num_bytes``.

    Raises:
        KeyError: if a character of ``token`` is missing from ``byte_to_int``.
    """
    if token == enc.decode([vocab_size-1]):
        return [byte_to_int["endoftext"]] * num_bytes

    # Slice from an explicit start index. `token[-num_bytes:]` would return the
    # *whole* token for num_bytes == 0, because -0 == 0.
    start = max(len(token) - num_bytes, 0)
    ids = [byte_to_int[char] for char in token[start:]]

    # Left-fill with padding so the token's characters end at the last position.
    return [byte_to_int["pad"]] * (num_bytes - len(ids)) + ids


def _token_to_bytes_left_aligned(token: str, num_bytes: int) -> list[int]:
    """Encode ``token`` as ``num_bytes`` ids, aligned to the left.

    The decoded text of the last vocabulary id maps to the "endoftext" id
    repeated ``num_bytes`` times. Any other token keeps its first
    ``num_bytes`` characters (looked up in ``byte_to_int``), followed by
    "pad" ids when the token is shorter than ``num_bytes``.
    """
    eot_text = enc.decode([vocab_size-1])
    if token == eot_text:
        return [byte_to_int["endoftext"]] * num_bytes

    # Truncation and padding in one pass: the slice caps the length, and the
    # pad count is zero whenever the slice already fills all positions.
    ids = [byte_to_int[char] for char in token[:num_bytes]]
    fill = [byte_to_int["pad"]] * (num_bytes - len(ids))
    return ids + fill


def token_to_bytes(token: str, num_bytes: int, alignment: Literal["left", "right"]) -> list[int]:
    """Encode ``token`` as ``num_bytes`` ids, padded or truncated on the chosen side.

    Args:
        token: Decoded text of a single token.
        num_bytes: Number of ids to produce.
        alignment: "left" keeps the start of the token and pads at the end;
            "right" keeps the end of the token and pads at the start.

    Returns:
        The id list produced by the matching alignment helper.
    """
    assert alignment in ("left", "right")
    convert = (
        _token_to_bytes_left_aligned
        if alignment == "left"
        else _token_to_bytes_right_aligned
    )
    return convert(token, num_bytes)

In [130]:
# Sanity check: a 4-character token in 6 positions, left- vs right-aligned.
token_to_bytes("abcd", 6, "left"), token_to_bytes("abcd", 6, "right")

([97, 98, 99, 100, 456, 456], [456, 456, 97, 98, 99, 100])

### Do the actual conversion

In [137]:
# Decode the vocabulary once: a token's text does not depend on the
# (num_bytes, alignment) configuration, so redoing it per file is wasted work.
token_texts = [enc.decode([idx]) for idx in range(vocab_size)]

# Write one token-id -> byte-id table per (num_bytes, alignment) combination.
for num_bytes in (16, 18, 20):
    for alignment in ("left", "right"):
        ttb = {
            idx: token_to_bytes(token=text, num_bytes=num_bytes, alignment=alignment)
            for idx, text in enumerate(token_texts)
        }
        # JSON object keys are always strings, so the integer token ids are
        # stored as "0", "1", ... and must be converted back when loading.
        with open(f"ttb_{num_bytes}_{alignment}.json", "w") as f:
            json.dump(ttb, f)

## How do I change the dataloader with this in mind?

### Baseline

Take the dataloader from modded-nanogpt as a baseline

In [138]:
def _load_data_shard(file: Path) -> torch.Tensor:
    """Read one token shard from a ``.bin`` file into a 1-D uint16 tensor.

    File layout as read here: a header of 256 int32 values, followed by the
    token ids as 16-bit values. Header fields used:
    ``header[0]`` magic number (must be 20240520), ``header[1]`` format
    version (must be 1), ``header[2]`` number of tokens in the shard.

    Args:
        file: Path of the shard to read.

    Returns:
        Tensor of ``header[2]`` token ids, dtype ``torch.uint16``, allocated
        with ``pin_memory=True``.

    Raises:
        AssertionError: if the magic number or version is wrong, or if the
            number of bytes read does not equal ``2 * num_tokens``.
    """
    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
    assert header[1] == 1, "unsupported version"
    num_tokens = int(header[2]) # number of tokens (claimed)
    # Unbuffered binary read straight into the tensor's memory.
    with file.open("rb", buffering=0) as f:
        # NOTE(review): pin_memory=True presumably requires a CUDA-capable setup — confirm.
        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
        f.seek(256 * 4) # skip the header: 256 values of 4 bytes each
        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
    return tokens

def distributed_data_generator(filename_pattern: str, batch_size: int, rank : int, world_size : int):
    """Yield ``(inputs, targets)`` batches for one rank, reading shards in sorted order.

    Each step covers a window of ``batch_size`` tokens in the current shard.
    This rank takes ``local_batch_size + 1`` consecutive tokens starting at
    ``pos + rank * local_batch_size``. ``inputs`` is that slice without its
    last token and ``targets`` is the same slice shifted by one.

    Args:
        filename_pattern: Glob pattern, relative to the current working
            directory, that selects the shard files.
        batch_size: Tokens consumed per step across all ranks; must be
            divisible by ``world_size``.
        rank: Index of this process; it selects the slice inside each window.
        world_size: Number of processes sharing each window.

    Yields:
        ``(inputs, targets)``: tensors on the "cuda" device with dtype int32
        and int64 respectively.

    Raises:
        FileNotFoundError: if no file matches ``filename_pattern``.
        AssertionError: if ``batch_size`` is not divisible by ``world_size``.

    Because this is a generator, the checks above run on the first ``next()``,
    not when the function is called. The generator finishes normally once all
    shards are used up.
    """
    files = sorted(Path.cwd().glob(filename_pattern))
    if not files:
        # Without this check the first next(file_iter) would raise StopIteration
        # inside the generator, which Python reports as an opaque RuntimeError.
        raise FileNotFoundError(f"no data shards match {filename_pattern!r} in {Path.cwd()}")
    assert batch_size % world_size == 0
    local_batch_size = batch_size // world_size
    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
    tokens, pos = _load_data_shard(next(file_iter)), 0
    while True:
        if pos + batch_size + 1 >= len(tokens):
            # Not enough tokens left for a full window: move on to the next shard.
            next_file = next(file_iter, None)
            if next_file is None:
                # All shards consumed: end the generator cleanly instead of
                # letting StopIteration escape (PEP 479 -> RuntimeError).
                return
            tokens, pos = _load_data_shard(next_file), 0
        buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
        inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
        targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
        pos += batch_size
        yield inputs, targets

In [144]:
num_tries = 10
train_pattern = "data/fineweb10B/fineweb_train_*.bin"

# Fail loudly if the data is missing; otherwise the loop below would "measure"
# an empty run and report a meaningless near-zero time.
assert any(Path.cwd().glob(train_pattern)), f"no data shards match {train_pattern!r}"

times = list()
for _ in range(num_tries):
    t0 = perf_counter()
    # Single process: world_size is 1, so the only valid rank is 0. A rank of 1
    # would shift every slice one local batch past the window that the
    # generator bounds-checks.
    ddg = distributed_data_generator(train_pattern, 1024, 0, 1)
    try:
        for inputs, targets in ddg:
            pass
    except RuntimeError as err:
        # Running out of shards inside the generator surfaces as a RuntimeError
        # caused by StopIteration (PEP 479). Only that case marks the normal end
        # of the data; anything else is a real failure and must not be hidden.
        if not isinstance(err.__cause__, StopIteration):
            raise
    times.append(perf_counter() - t0)

# Mean wall-clock seconds for one full pass over all shards.
t = sum(times) / len(times)
t

54.016402913599904

### Can I just do the conversion online?