## Start

In [22]:
import os
import pickle
import time
import random
import torch


import regex as re
import numpy as np
import torch.nn as nn

from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from typing import BinaryIO
from collections import defaultdict
from collections.abc import Iterable, Iterator
from einops import rearrange, einsum

In [2]:
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
TOKEN = tuple[bytes]
PAIR = tuple[bytes, bytes]
END_TOKEN = '<|endoftext|>'

In [3]:
import sys
import os
# 获取当前 notebook 所在目录的上两级目录（即项目根目录）
project_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
# 将项目根目录加入 sys.path
if project_root not in sys.path:
    sys.path.append(project_root)

## 1 Assignment Overview

## 2 Byte-Pair Encoding (BPE) Tokenizer

### 2.1 The Unicode Standard

In [3]:
[ord('牛'), chr(29275)]

[29275, '牛']

#### Problem (unicode1)

##### a

In [None]:
chr(0)

'\x00'

##### b

In [None]:
print(chr(0))

 


##### c

In [8]:
"this is a test" + chr(0) + "string"

'this is a test\x00string'

In [9]:
print("this is a test" + chr(0) + "string")

this is a test string


### 2.2 Unicode Encodings

In [10]:
test_string = "hello! こんにちは!"
utf8_encoded = test_string.encode("utf-8")
print(utf8_encoded)

b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'


In [11]:
print(type(utf8_encoded))

<class 'bytes'>


In [12]:
list(utf8_encoded)

[104,
 101,
 108,
 108,
 111,
 33,
 32,
 227,
 129,
 147,
 227,
 130,
 147,
 227,
 129,
 171,
 227,
 129,
 161,
 227,
 129,
 175,
 33]

In [13]:
[len(test_string),len(utf8_encoded)]

[13, 23]

In [14]:
print(utf8_encoded.decode("utf-8"))

hello! こんにちは!


#### Problem (unicode2)

##### a

##### b

In [15]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])
decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))

'hello'

In [16]:
decode_utf8_bytes_to_str_wrong("hello, 你好".encode("utf-8"))

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data

##### c

In [18]:
sentence = "hello, 你好".encode("utf-8")
[bytes([b]) for b in sentence]

[b'h',
 b'e',
 b'l',
 b'l',
 b'o',
 b',',
 b' ',
 b'\xe4',
 b'\xbd',
 b'\xa0',
 b'\xe5',
 b'\xa5',
 b'\xbd']

In [24]:
(b'\xe4\xbd\xa0').decode("utf-8")

'你'

### 2.3 Subword Tokenization

### 2.4 BPE Tokenizer Training

In [27]:
re.findall(PAT, "some text that i'll pre-tokenize")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']

In [28]:
max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")])

('BA', 'A')

In [31]:
test_string = """low low low low low lower lower widest widest widest newest newest newest newest newest newest"""

In [32]:
re.findall(PAT, test_string)

['low',
 ' low',
 ' low',
 ' low',
 ' low',
 ' lower',
 ' lower',
 ' widest',
 ' widest',
 ' widest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest']

### 2.5 Experimenting with BPE Tokenizer Training

In [3]:
def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time

    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

def update_token_freqs(
    token_freqs: dict[TOKEN, int],
    text_segment: str,
):
    matches = re.finditer(PAT, text_segment)
    for m in matches:
        token = m.group()
        token_bytes = tuple(bytes([b]) for b in token.encode("utf-8"))
        token_freqs[token_bytes] += 1 

def process_chunk(args):
    input_path, start, end, special_tokens, chunk_id = args
    # Read chunk
    with open(input_path, 'rb') as f:
        f.seek(start)
        chunk_bytes = f.read(end - start)
    
    # Decode
    chunk_str = chunk_bytes.decode("utf-8", errors="ignore")
    
    # Logic from original train_bpe
    special_pat = "|".join(re.escape(st) for st in special_tokens)
    segments = re.split(special_pat, chunk_str)
    
    token_freqs = defaultdict(int)
    # Use position=chunk_id+1 so 0 is left for the main bar
    for seg in segments:
        update_token_freqs(token_freqs, seg)
    
    return token_freqs

def get_pairs(
    token: TOKEN
) -> list[PAIR]:
    if len(token) < 2:
        return []
    return [(token[i], token[i+1]) for i in range(len(token)-1)]

def update_pair_freqs(
    pair_freqs: dict[PAIR, int],
    token_freqs: dict[TOKEN, int],
):
    for token, freq in token_freqs.items():
        if len(token) < 2:
            continue
        pairs = get_pairs(token)
        for p in pairs:
            pair_freqs[p] += freq

def update_pair2idx(
    pair_to_idx: dict[PAIR, set[TOKEN]],
    token_freqs: dict[TOKEN, int],
    token_to_idx: dict[TOKEN, int],
):
    for token, freq in token_freqs.items():
        if len(token) < 2:
            continue
        pairs = get_pairs(token)
        for p in pairs:
            pair_to_idx[p].add(token_to_idx[token])


def get_most_frequent_pair(
    pair_freqs: dict[PAIR, int],
) -> PAIR:
    return max(pair_freqs.keys(), key=lambda k: (pair_freqs[k], k))

def update_vocab(
    vocab: dict[int, bytes],
    new_id: int,
    best_pair: PAIR,
):
    new_vocab = best_pair[0] + best_pair[1]
    vocab[new_id] = new_vocab

def update_vocab_inverse(
    vocab_inverse: dict[bytes, int],
    new_id: int,
    best_pair: PAIR,
):
    new_vocab = best_pair[0] + best_pair[1]
    vocab_inverse[new_vocab] = new_id

def update_merges(
    merges: list[PAIR],
    best_pair: PAIR,
):
    merges.append(best_pair)

def update_all(
    pair_to_idx: dict[PAIR, set[int]],
    pair_freqs: dict[PAIR, int],
    token_freqs: dict[TOKEN, int],
    idx_to_token: dict[int, TOKEN],
    best_pair: PAIR,
):
    affected_idxs = list(pair_to_idx[best_pair])
    merged_bytes = best_pair[0] + best_pair[1]

    for idx in affected_idxs:
        # get new token
        i=0
        token = idx_to_token[idx]
        new_token = []
        while(i<(len(token))):
            if (i < len(token) - 1) and (token[i] == best_pair[0]) and (token[i+1] == best_pair[1]):
                new_token.append(merged_bytes)
                i = i + 2
            else:
                new_token.append(token[i])
                i = i + 1
        new_token = tuple(new_token)


        ## update pair_to_idx
        new_pairs = get_pairs(new_token)
        affected_pairs = get_pairs(token)
        for p in set(new_pairs)-set(affected_pairs):
            pair_to_idx[p].add(idx)
        for p in set(affected_pairs)-set(new_pairs):
            pair_to_idx[p].discard(idx)


        ## update pair_freqs
        for pair in affected_pairs:
            pair_freqs[pair] -= token_freqs[token]
        for pair in new_pairs:
            pair_freqs[pair] += token_freqs[token]


        ## update token_freqs
        origin_freq = token_freqs.pop(token)
        token_freqs[new_token] = origin_freq

        ## update idx_to_token
        idx_to_token[idx] = new_token
        

def train_bpe(
    input_path: str,
    vocab_size: int,
    special_tokens: list[str],
    num_processes: int = 4
):
    print("Start Training BPE...")
    start_total = time.time()

    print("Initialize Vocab and Merges...")
    vocab = {i:bytes([i]) for i in range(256)}
    vocab_inverse = {v:k for k,v in vocab.items()}
    # Add special tokens
    for st in special_tokens:
        new_id = len(vocab)
        st_bytes = st.encode("utf-8")
        vocab[new_id] = st_bytes
        vocab_inverse[st_bytes] = new_id
    
    merges = []

    print("Calculating chunk boundaries...")
    t0 = time.time()
    with open(input_path, 'rb') as f:
        # Assuming first special token is the split token
        split_token = special_tokens[0].encode("utf-8") if special_tokens else b"<|endoftext|>"
        boundaries = find_chunk_boundaries(f, num_processes, split_token) # More chunks than processes for load balancing
    print(f"Boundaries calculated in {time.time() - t0:.2f}s")
    
    token_freqs = defaultdict(int)

    ## Initiate token_freqs
    print("Update Token Freq (Parallel)...")
    t0 = time.time()
    
    tasks = []
    for i in range(len(boundaries) - 1):
        start = boundaries[i]
        end = boundaries[i+1]
        tasks.append((input_path, start, end, special_tokens, i))
    
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        results = list(tqdm(executor.map(process_chunk, tasks), total=len(tasks), desc="Token Freqs (Chunks)", position=0))
        
        for local_freqs in results:
            for token, count in local_freqs.items():
                token_freqs[token] += count
                
    idx_to_token = {i: k for i, k in enumerate(token_freqs.keys())}
    token_to_idx = {k: i for i, k in enumerate(token_freqs.keys())}
    print(f"Update Token Freq took {time.time() - t0:.2f}s")


    ## Initiate pair_freqs, pair_to_token, and token_to_pair
    print("Update Pair Freq...")
    t0 = time.time()
    pair_freqs = defaultdict(int)
    update_pair_freqs(pair_freqs, token_freqs)
    print(f"Update Pair Freq took {time.time() - t0:.2f}s")

    print("Update Pair to idx...")
    t0 = time.time()
    pair_to_idx = defaultdict(set)
    update_pair2idx(pair_to_idx, token_freqs, token_to_idx)
    print(f"Update Pair to Token took {time.time() - t0:.2f}s")

    print("Start Merging...")
    t0 = time.time()
    num_merges = vocab_size - 256 - len(special_tokens)
    for i in tqdm(range(num_merges), desc="Merging"):
        ### find most frequent pair
        if not pair_freqs:
            break
        best_pair = get_most_frequent_pair(pair_freqs)

        new_id = len(vocab)

        ## update vocab
        update_vocab(vocab, new_id, best_pair)

        ## update vocab_inverse
        update_vocab_inverse(vocab_inverse, new_id, best_pair)

        ## update merges
        update_merges(merges, best_pair)

        update_all(pair_to_idx, pair_freqs, token_freqs, idx_to_token, best_pair)
    print(f"Merging took {time.time() - t0:.2f}s")
    print(f"Total Training took {time.time() - start_total:.2f}s")

    return vocab, merges

#### Prepare

In [None]:
test_string = "abere ererea<|endoftext|>When and where is not as important as who and what. Hi, I am the the Ivan.<|endoftext|> aaa"
# test_string = "abere ererea"
input_path = ""
vocab_size = 260
special_tokens = [END_TOKEN]

In [5]:
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, test_string)

text_segment = segments[0]
text_segment

'abere ererea'

#### Initialize Vocab and Merges

In [6]:
vocab = {i:bytes([i]) for i in range(256)}
vocab_inverse = {v:k for k,v in vocab.items()}
# Add special tokens
for st in special_tokens:
    new_id = len(vocab)
    st_bytes = st.encode("utf-8")
    vocab[new_id] = st_bytes
    vocab_inverse[st_bytes] = new_id

merges = []

#### Update Token Freq

In [7]:
token_freqs = defaultdict(int)
update_token_freqs(token_freqs,text_segment)
token_freqs

defaultdict(int,
            {(b'a', b'b', b'e', b'r', b'e'): 1,
             (b' ', b'e', b'r', b'e', b'r', b'e', b'a'): 1})

#### Update idx freq

In [16]:
idx_to_token = {i: k for i, k in enumerate(token_freqs.keys())}
token_to_idx = {k: i for i, k in enumerate(token_freqs.keys())}
idx_freqs = {token_to_idx[token]: freq for token, freq in token_freqs.items()}

#### Update Pair Freq

In [9]:
pair_freqs = defaultdict(int)
update_pair_freqs(pair_freqs,token_freqs)
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 3,
             (b'r', b'e'): 3,
             (b' ', b'e'): 1,
             (b'e', b'a'): 1})

#### Update Pair to idx

In [10]:
pair_to_idx = defaultdict(set)
update_pair2idx(pair_to_idx, token_freqs, token_to_idx)
pair_to_idx

defaultdict(set,
            {(b'a', b'b'): {0},
             (b'b', b'e'): {0},
             (b'e', b'r'): {0, 1},
             (b'r', b'e'): {0, 1},
             (b' ', b'e'): {1},
             (b'e', b'a'): {1}})

#### Merging

In [11]:
best_pair = get_most_frequent_pair(pair_freqs)
best_pair

(b'r', b'e')

In [12]:
new_id = len(vocab)

## update vocab
update_vocab(vocab, new_id, best_pair)

## update vocab_inverse
update_vocab_inverse(vocab_inverse, new_id, best_pair)

## update merges
update_merges(merges, best_pair)

In [17]:
affected_idxs = list(pair_to_idx[best_pair])
merged_bytes = best_pair[0] + best_pair[1]
for idx in affected_idxs:
    token = idx_to_token[idx]

    i=0
    new_token = []
    while(i<(len(token))):
        if (i < len(token) - 1) and (token[i] == best_pair[0]) and (token[i+1] == best_pair[1]):
            new_token.append(merged_bytes)
            i = i + 2
        else:
            new_token.append(token[i])
            i = i + 1
    new_token = tuple(new_token)

    new_pairs = get_pairs(new_token)
    affected_pairs = get_pairs(token)
    for p in set(new_pairs)-set(affected_pairs):
        pair_to_idx[p].add(idx)
    for p in set(affected_pairs)-set(new_pairs):
        pair_to_idx[p].discard(idx)


    for pair in affected_pairs:
        pair_freqs[pair] -= idx_freqs[idx]
    for pair in new_pairs:
        pair_freqs[pair] += idx_freqs[idx]

    # origin_freq = token_freqs.pop(token)
    # token_freqs[new_token] = origin_freq

    idx_to_token[idx] = new_token

In [18]:
pair_to_idx

defaultdict(set,
            {(b'a', b'b'): {0},
             (b'b', b'e'): {0},
             (b'e', b'r'): set(),
             (b'r', b'e'): set(),
             (b' ', b'e'): {1},
             (b'e', b'a'): set(),
             (b'e', b're'): {0, 1},
             (b're', b'a'): {1},
             (b're', b're'): {1}})

In [19]:
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 0,
             (b'r', b'e'): 0,
             (b' ', b'e'): 1,
             (b'e', b'a'): 0,
             (b'e', b're'): 2,
             (b're', b're'): 1,
             (b're', b'a'): 1})

In [86]:
idx_freqs

{0: 1, 1: 1}

In [46]:
idx_to_token

{0: (b'a', b'b', b'e', b're'), 1: (b' ', b'e', b're', b're', b'a')}

### 2.6 BPE Tokenizer: Encoding and Decoding

##### 2.6.1 Encoding text

##### 2.6.2 Decoding text

In [None]:
class Tokenizer():
    def __init__(
        self, 
        vocab: dict[int, bytes], 
        merges: list[PAIR],
        special_tokens: list[str] | None = None,
    ):
        self.vocab = vocab
        self.merges = merges
        
        if special_tokens:
            new_index = len(vocab)
            for st in special_tokens:
                if st != END_TOKEN:
                    vocab[new_index] = st
                    new_index += 1
            self.special_tokens =  sorted(list(set([END_TOKEN]+special_tokens)), key=len, reverse=True)
        else:
            self.special_tokens = [END_TOKEN]

        self.vocab_inverse = {v: k for k, v in vocab.items()}
        self.merges_to_rank = {m: i for i, m in enumerate(merges)}

    @classmethod
    def from_files(
        cls, 
        vocab_filepath: str, 
        merges_filepath: str, 
        special_tokens: list[str] | None = None,
    ) :
        with open(vocab_filepath, "rb") as f:
            vocab = pickle.load(f)
        with open(merges_filepath, "rb") as f:
            merges = pickle.load(f)
        return cls(vocab, merges, special_tokens=special_tokens)

    def merge_tokens(self, token_bytes: list[bytes]) -> list[bytes]:
        if len(token_bytes) <= 1:
            return token_bytes
        if token_bytes in self.special_tokens:
            return token_bytes
        while 1:
            merge_position = -1
            smallest_rank = len(self.merges)
            for i in range(len(token_bytes)-1):
                current_pair = (token_bytes[i], token_bytes[i+1])
                rank = self.merges_to_rank.get(current_pair, -1)
                if rank == -1:
                    continue
                if rank < smallest_rank:
                    smallest_rank = rank
                    merge_position = i
            if merge_position == -1:
                break
            token_bytes = token_bytes[:(merge_position)] + [token_bytes[merge_position]+token_bytes[merge_position+1]]+token_bytes[(merge_position+2):]
        return token_bytes

    def encode(self, text: str) -> list[int]:
        special_pat = "("+"|".join(re.escape(st) for st in self.special_tokens)+")"
        segments = re.split(special_pat, text)
        ids = []
        for segment in segments:
            if not segment:
                continue
            if segment in self.special_tokens:
                token_bytes = self.transfer_text2bytes(segment)
                token_bytes = self.merge_tokens(token_bytes)
                encoded_token = []
                for i in token_bytes:
                    encoded_token.append(self.vocab_inverse[i])
                ids.extend(encoded_token)
                continue
            matches = re.finditer(PAT, segment)
            for m in matches:
                token = m.group()
                token_bytes = self.transfer_text2bytes(token)
                token_bytes = self.merge_tokens(token_bytes)
                encoded_token = []
                for i in token_bytes:
                    encoded_token.append(self.vocab_inverse[i])
                ids.extend(encoded_token)
        return ids

    def _encode_batch(self, texts: list[str]) -> list[list[int]]:
        return [self.encode(text) for text in texts]

    def encode_iterable(self, iterable: Iterable[str], num_processes: int = 4, batch_size: int = 1000) -> Iterator[int]:
        with ProcessPoolExecutor(max_workers=num_processes) as executor:
            def batch_generator():
                current_batch = []
                for text in iterable:
                    current_batch.append(text)
                    if len(current_batch) >= batch_size:
                        yield current_batch
                        current_batch = []
                if current_batch:
                    yield current_batch
            # 直接在 encode_iterable 这一层产出结果
            for batch_results in executor.map(self._encode_batch, batch_generator()):
                for seq in batch_results:
                    yield from seq

    def decode(self, ids: list[int]) -> str:
        unk_bytes = '\ufffd'.encode('utf-8')
        bytes_list = [self.vocab.get(i, unk_bytes) for i in ids]
        return b"".join(bytes_list).decode("utf-8", errors="replace")

    def transfer_text2bytes(self, segment: str) -> list[bytes]:
        if segment in self.special_tokens:
            return [segment.encode("utf-8")]
        token_bytes = list(bytes([b]) for b in segment.encode("utf-8"))
        return token_bytes

In [5]:
import tiktoken
from tests.adapters import get_tokenizer
from tests.common import FIXTURES_PATH, gpt2_bytes_to_unicode

VOCAB_PATH = FIXTURES_PATH / "gpt2_vocab.json"
MERGES_PATH = FIXTURES_PATH / "gpt2_merges.txt"

from tests.test_tokenizer import get_tokenizer_from_vocab_merges_path

### 2.7 Experiments

In [12]:
from cs336_basics.bpe_tokenize import Tokenizer,find_chunk_boundaries,END_TOKEN

In [None]:
input_path = f"{project_root}/data/TinyStoriesV2-GPT4-valid.txt"
special_tokens = [END_TOKEN]

data_group = "train"
vocab_size = 10000
file_name = f"TinyStoriesV2-GPT4-{data_group}.txt"
vocab_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-vocab-{vocab_size}.pkl"
merges_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-merge-{vocab_size}.pkl"


num_processes = 16


tokenizer = Tokenizer.from_files(vocab_filepath,merges_filepath,special_tokens)

with open(input_path, 'rb') as f:
    split_token = special_tokens[0].encode("utf-8") if special_tokens else b""
    boundaries = find_chunk_boundaries(f, num_processes, split_token) 

start = boundaries[0]
end = boundaries[1]

with open(input_path, 'rb') as f:
    f.seek(start)
    chunk_bytes = f.read(end - start)
chunk_str = chunk_bytes.decode("utf-8", errors="ignore")
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, chunk_str)

ran_ints = random.sample(range(len(segments)), 10)
for i in ran_ints:
    ids = tokenizer.encode(segments[i])
    print(len(ids)/len(segments[ran_int]))

In [43]:
special_tokens = [END_TOKEN]

data_group = "train"
vocab_size = 10000
file_name = f"TinyStoriesV2-GPT4-{data_group}.txt"
input_path = f"{project_root}/data/{file_name}"
vocab_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-vocab-{vocab_size}.pkl"
merges_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-merge-{vocab_size}.pkl"


num_processes = 16


tokenizer = Tokenizer.from_files(vocab_filepath,merges_filepath,special_tokens)

with open(input_path, 'rb') as f:
    split_token = special_tokens[0].encode("utf-8") if special_tokens else b""
    boundaries = find_chunk_boundaries(f, num_processes, split_token) 

start = boundaries[0]
end = boundaries[1]

with open(input_path, 'rb') as f:
    f.seek(start)
    chunk_bytes = f.read(end - start)
chunk_str = chunk_bytes.decode("utf-8", errors="ignore")
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, chunk_str)

ran_ints = random.sample(range(len(segments)), 10)
for i in ran_ints:
    ids = tokenizer.encode(segments[i])
    print(len(ids)/len(segments[ran_int]))

0.2624356775300172
0.2933104631217839
0.6792452830188679
0.3567753001715266
0.18181818181818182
0.32246998284734135
0.2692967409948542
0.274442538593482
0.3687821612349914
0.2109777015437393


In [45]:
special_tokens = [END_TOKEN]

data_group = "train"
vocab_size = 32000
file_name = f"owt_{data_group}.txt"
input_path = f"{project_root}/data/{file_name}"
vocab_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-vocab-{vocab_size}.pkl"
merges_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-merge-{vocab_size}.pkl"


num_processes = 16


tokenizer = Tokenizer.from_files(vocab_filepath,merges_filepath,special_tokens)

with open(input_path, 'rb') as f:
    split_token = special_tokens[0].encode("utf-8") if special_tokens else b""
    boundaries = find_chunk_boundaries(f, num_processes, split_token) 

start = boundaries[0]
end = boundaries[1]

with open(input_path, 'rb') as f:
    f.seek(start)
    chunk_bytes = f.read(end - start)
chunk_str = chunk_bytes.decode("utf-8", errors="ignore")
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, chunk_str)

ran_ints = random.sample(range(len(segments)), 10)
for i in ran_ints:
    ids = tokenizer.encode(segments[i])
    print(len(ids)/len(segments[ran_int]))

1.5901759530791788
0.2631964809384164
0.748533724340176
1.0425219941348973
0.9919354838709677
0.6708211143695014
0.2749266862170088
1.5879765395894427
0.41642228739002934
0.39222873900293254


In [46]:
special_tokens = [END_TOKEN]

data_group = "train"
vocab_size = 32000
file_name = f"owt_{data_group}.txt"
input_path = f"{project_root}/data/TinyStoriesV2-GPT4-train.txt"
vocab_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-vocab-{vocab_size}.pkl"
merges_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-merge-{vocab_size}.pkl"


num_processes = 16


tokenizer = Tokenizer.from_files(vocab_filepath,merges_filepath,special_tokens)

with open(input_path, 'rb') as f:
    split_token = special_tokens[0].encode("utf-8") if special_tokens else b""
    boundaries = find_chunk_boundaries(f, num_processes, split_token) 

start = boundaries[0]
end = boundaries[1]

with open(input_path, 'rb') as f:
    f.seek(start)
    chunk_bytes = f.read(end - start)
chunk_str = chunk_bytes.decode("utf-8", errors="ignore")
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, chunk_str)

ran_ints = random.sample(range(len(segments)), 10)
for i in ran_ints:
    ids = tokenizer.encode(segments[i])
    print(len(ids)/len(segments[ran_int]))

0.2641509433962264
0.27101200686106347
0.3584905660377358
0.3516295025728988
0.3584905660377358
0.3156089193825043
0.8576329331046312
0.5728987993138936
0.27958833619210977
0.4922813036020583


In [31]:
ran_int = ran_ints[0]

In [34]:
ids = tokenizer.encode(segments[ran_int])

0.24390243902439024

## 3 Transformer Language Model Architecture

### 3.1 Transformer LM

#### 3.1.1 Token Embeddings

#### 3.1.2 Pre-norm Transformer Block

### 3.2 Output Normalization and Embedding

### 3.3 Remark: Batching, Einsum and Eﬀicient Computation

In [6]:
images = torch.randn(64, 128, 128, 3) # (batch, height, width, channel)
dim_by = torch.linspace(start=0.0, end=1.0, steps=10)

In [13]:
dim_value = rearrange(dim_by, "dim_value -> 1 dim_value 1 1 1")
images_rearr = rearrange(images, "b height width channel -> b 1 height width channel")

In [14]:
dimmed_images = images_rearr * dim_value

In [15]:
dimmed_images = einsum(
    images, dim_by,
    "batch height width channel, dim_value -> batch dim_value height width channel"
)

In [17]:
channels_last = torch.randn(64, 32, 32, 3) # (batch, height, width, channel)
B = torch.randn(32*32, 32*32)

In [None]:
## Rearrange an image tensor for mixing across all pixels
channels_last_flat = channels_last.view(
    -1, channels_last.size(1) * channels_last.size(2), channels_last.size(3)
)
channels_first_flat = channels_last_flat.transpose(1, 2)
channels_first_flat_transformed = channels_first_flat @ B.T
channels_last_flat_transformed = channels_first_flat_transformed.transpose(1, 2)
channels_last_transformed = channels_last_flat_transformed.view(*channels_last.shape)

In [19]:
height = width = 32
## Rearrange replaces clunky torch view + transpose
channels_first = rearrange(
    channels_last,
    "batch height width channel -> batch channel (height width)"
)
channels_first_transformed = einsum(
    channels_first, B,
    "batch channel pixel_in, pixel_out pixel_in -> batch channel pixel_out"
)
channels_last_transformed = rearrange(
    channels_first_transformed,
    "batch channel (height width) -> batch height width channel",
    height=height, width=width
)

In [21]:
height = width = 32
channels_last_transformed = einx.dot(
    "batch row_in col_in channel, (row_out col_out) (row_in col_in)"
    "-> batch row_out col_out channel",
    channels_last, B,
    col_in=width, col_out=width
)

NameError: name 'einx' is not defined

#### 3.3.1 Mathematical Notation and Memory Ordering

### 3.4 Basic Building Blocks: Linear and Embedding Modules

#### 3.4.1 Parameter Initialization

#### 3.4.2 Linear Module

In [None]:
class Linear(nn.Module):
    def __init__(
        self, 
        in_features: int, 
        out_features: int, 
        device: torch.device | None = None, 
        dtype: torch.dtype | None = None
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device = device
        self.dtype = dtype

        self.weights = nn.Parameter(
            torch.empty(out_features, in_features, device=self.device, dtype=self.dtype)
        )
        std = np.sqrt(2/(in_features+out_features))
        nn.init.trunc_normal_(
            self.weights,
            mean = 0,
            std=std,
            a=-3*std,
            b=3*std
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Apply the linear transformation to the input
        '''
        output = insum(
            x, self.weights,
            "... in_features, out_features in_features -> ... out_features"
        )
        return output
        

In [26]:
in_features = 4
out_features = 2

In [31]:
weights = nn.Parameter(
            torch.empty(out_features, in_features)
        )

In [None]:
std = np.sqrt(2/(in_features+out_features))


Parameter containing:
tensor([[-0.3163,  0.2574, -0.6433, -1.2062],
        [-0.2289,  0.3611,  0.4565, -0.2161]], requires_grad=True)