## Start

In [44]:
import os
import pickle
import time
import random
import torch
import math


import regex as re
import numpy as np
import pandas as pd
import torch.nn as nn
from typing import Optional
from collections.abc import Callable, Iterable

from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from typing import BinaryIO
from collections import defaultdict
from collections.abc import Iterable, Iterator
from einops import rearrange, einsum, repeat





In [45]:
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
TOKEN = tuple[bytes]
PAIR = tuple[bytes, bytes]
END_TOKEN = '<|endoftext|>'

In [46]:
import sys
import os
# 获取当前 notebook 所在目录的上两级目录（即项目根目录）
project_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
# 将项目根目录加入 sys.path
if project_root not in sys.path:
    sys.path.append(project_root)

## 1 Assignment Overview

## 2 Byte-Pair Encoding (BPE) Tokenizer

### 2.1 The Unicode Standard

In [3]:
[ord('牛'), chr(29275)]

[29275, '牛']

#### Problem (unicode1)

##### a

In [None]:
chr(0)

'\x00'

##### b

In [None]:
print(chr(0))

 


##### c

In [8]:
"this is a test" + chr(0) + "string"

'this is a test\x00string'

In [9]:
print("this is a test" + chr(0) + "string")

this is a test string


### 2.2 Unicode Encodings

In [10]:
test_string = "hello! こんにちは!"
utf8_encoded = test_string.encode("utf-8")
print(utf8_encoded)

b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'


In [11]:
print(type(utf8_encoded))

<class 'bytes'>


In [12]:
list(utf8_encoded)

[104,
 101,
 108,
 108,
 111,
 33,
 32,
 227,
 129,
 147,
 227,
 130,
 147,
 227,
 129,
 171,
 227,
 129,
 161,
 227,
 129,
 175,
 33]

In [13]:
[len(test_string),len(utf8_encoded)]

[13, 23]

In [14]:
print(utf8_encoded.decode("utf-8"))

hello! こんにちは!


#### Problem (unicode2)

##### a

##### b

In [15]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])
decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))

'hello'

In [16]:
decode_utf8_bytes_to_str_wrong("hello, 你好".encode("utf-8"))

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data

##### c

In [18]:
sentence = "hello, 你好".encode("utf-8")
[bytes([b]) for b in sentence]

[b'h',
 b'e',
 b'l',
 b'l',
 b'o',
 b',',
 b' ',
 b'\xe4',
 b'\xbd',
 b'\xa0',
 b'\xe5',
 b'\xa5',
 b'\xbd']

In [24]:
(b'\xe4\xbd\xa0').decode("utf-8")

'你'

### 2.3 Subword Tokenization

### 2.4 BPE Tokenizer Training

In [27]:
re.findall(PAT, "some text that i'll pre-tokenize")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']

In [28]:
max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")])

('BA', 'A')

In [31]:
test_string = """low low low low low lower lower widest widest widest newest newest newest newest newest newest"""

In [32]:
re.findall(PAT, test_string)

['low',
 ' low',
 ' low',
 ' low',
 ' low',
 ' lower',
 ' lower',
 ' widest',
 ' widest',
 ' widest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest',
 ' newest']

### 2.5 Experimenting with BPE Tokenizer Training

In [3]:
def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time

    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

def update_token_freqs(
    token_freqs: dict[TOKEN, int],
    text_segment: str,
):
    matches = re.finditer(PAT, text_segment)
    for m in matches:
        token = m.group()
        token_bytes = tuple(bytes([b]) for b in token.encode("utf-8"))
        token_freqs[token_bytes] += 1 

def process_chunk(args):
    input_path, start, end, special_tokens, chunk_id = args
    # Read chunk
    with open(input_path, 'rb') as f:
        f.seek(start)
        chunk_bytes = f.read(end - start)
    
    # Decode
    chunk_str = chunk_bytes.decode("utf-8", errors="ignore")
    
    # Logic from original train_bpe
    special_pat = "|".join(re.escape(st) for st in special_tokens)
    segments = re.split(special_pat, chunk_str)
    
    token_freqs = defaultdict(int)
    # Use position=chunk_id+1 so 0 is left for the main bar
    for seg in segments:
        update_token_freqs(token_freqs, seg)
    
    return token_freqs

def get_pairs(
    token: TOKEN
) -> list[PAIR]:
    if len(token) < 2:
        return []
    return [(token[i], token[i+1]) for i in range(len(token)-1)]

def update_pair_freqs(
    pair_freqs: dict[PAIR, int],
    token_freqs: dict[TOKEN, int],
):
    for token, freq in token_freqs.items():
        if len(token) < 2:
            continue
        pairs = get_pairs(token)
        for p in pairs:
            pair_freqs[p] += freq

def update_pair2idx(
    pair_to_idx: dict[PAIR, set[TOKEN]],
    token_freqs: dict[TOKEN, int],
    token_to_idx: dict[TOKEN, int],
):
    for token, freq in token_freqs.items():
        if len(token) < 2:
            continue
        pairs = get_pairs(token)
        for p in pairs:
            pair_to_idx[p].add(token_to_idx[token])


def get_most_frequent_pair(
    pair_freqs: dict[PAIR, int],
) -> PAIR:
    return max(pair_freqs.keys(), key=lambda k: (pair_freqs[k], k))

def update_vocab(
    vocab: dict[int, bytes],
    new_id: int,
    best_pair: PAIR,
):
    new_vocab = best_pair[0] + best_pair[1]
    vocab[new_id] = new_vocab

def update_vocab_inverse(
    vocab_inverse: dict[bytes, int],
    new_id: int,
    best_pair: PAIR,
):
    new_vocab = best_pair[0] + best_pair[1]
    vocab_inverse[new_vocab] = new_id

def update_merges(
    merges: list[PAIR],
    best_pair: PAIR,
):
    merges.append(best_pair)

def update_all(
    pair_to_idx: dict[PAIR, set[int]],
    pair_freqs: dict[PAIR, int],
    token_freqs: dict[TOKEN, int],
    idx_to_token: dict[int, TOKEN],
    best_pair: PAIR,
):
    affected_idxs = list(pair_to_idx[best_pair])
    merged_bytes = best_pair[0] + best_pair[1]

    for idx in affected_idxs:
        # get new token
        i=0
        token = idx_to_token[idx]
        new_token = []
        while(i<(len(token))):
            if (i < len(token) - 1) and (token[i] == best_pair[0]) and (token[i+1] == best_pair[1]):
                new_token.append(merged_bytes)
                i = i + 2
            else:
                new_token.append(token[i])
                i = i + 1
        new_token = tuple(new_token)


        ## update pair_to_idx
        new_pairs = get_pairs(new_token)
        affected_pairs = get_pairs(token)
        for p in set(new_pairs)-set(affected_pairs):
            pair_to_idx[p].add(idx)
        for p in set(affected_pairs)-set(new_pairs):
            pair_to_idx[p].discard(idx)


        ## update pair_freqs
        for pair in affected_pairs:
            pair_freqs[pair] -= token_freqs[token]
        for pair in new_pairs:
            pair_freqs[pair] += token_freqs[token]


        ## update token_freqs
        origin_freq = token_freqs.pop(token)
        token_freqs[new_token] = origin_freq

        ## update idx_to_token
        idx_to_token[idx] = new_token
        

def train_bpe(
    input_path: str,
    vocab_size: int,
    special_tokens: list[str],
    num_processes: int = 4
):
    print("Start Training BPE...")
    start_total = time.time()

    print("Initialize Vocab and Merges...")
    vocab = {i:bytes([i]) for i in range(256)}
    vocab_inverse = {v:k for k,v in vocab.items()}
    # Add special tokens
    for st in special_tokens:
        new_id = len(vocab)
        st_bytes = st.encode("utf-8")
        vocab[new_id] = st_bytes
        vocab_inverse[st_bytes] = new_id
    
    merges = []

    print("Calculating chunk boundaries...")
    t0 = time.time()
    with open(input_path, 'rb') as f:
        # Assuming first special token is the split token
        split_token = special_tokens[0].encode("utf-8") if special_tokens else b"<|endoftext|>"
        boundaries = find_chunk_boundaries(f, num_processes, split_token) # More chunks than processes for load balancing
    print(f"Boundaries calculated in {time.time() - t0:.2f}s")
    
    token_freqs = defaultdict(int)

    ## Initiate token_freqs
    print("Update Token Freq (Parallel)...")
    t0 = time.time()
    
    tasks = []
    for i in range(len(boundaries) - 1):
        start = boundaries[i]
        end = boundaries[i+1]
        tasks.append((input_path, start, end, special_tokens, i))
    
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        results = list(tqdm(executor.map(process_chunk, tasks), total=len(tasks), desc="Token Freqs (Chunks)", position=0))
        
        for local_freqs in results:
            for token, count in local_freqs.items():
                token_freqs[token] += count
                
    idx_to_token = {i: k for i, k in enumerate(token_freqs.keys())}
    token_to_idx = {k: i for i, k in enumerate(token_freqs.keys())}
    print(f"Update Token Freq took {time.time() - t0:.2f}s")


    ## Initiate pair_freqs, pair_to_token, and token_to_pair
    print("Update Pair Freq...")
    t0 = time.time()
    pair_freqs = defaultdict(int)
    update_pair_freqs(pair_freqs, token_freqs)
    print(f"Update Pair Freq took {time.time() - t0:.2f}s")

    print("Update Pair to idx...")
    t0 = time.time()
    pair_to_idx = defaultdict(set)
    update_pair2idx(pair_to_idx, token_freqs, token_to_idx)
    print(f"Update Pair to Token took {time.time() - t0:.2f}s")

    print("Start Merging...")
    t0 = time.time()
    num_merges = vocab_size - 256 - len(special_tokens)
    for i in tqdm(range(num_merges), desc="Merging"):
        ### find most frequent pair
        if not pair_freqs:
            break
        best_pair = get_most_frequent_pair(pair_freqs)

        new_id = len(vocab)

        ## update vocab
        update_vocab(vocab, new_id, best_pair)

        ## update vocab_inverse
        update_vocab_inverse(vocab_inverse, new_id, best_pair)

        ## update merges
        update_merges(merges, best_pair)

        update_all(pair_to_idx, pair_freqs, token_freqs, idx_to_token, best_pair)
    print(f"Merging took {time.time() - t0:.2f}s")
    print(f"Total Training took {time.time() - start_total:.2f}s")

    return vocab, merges

#### Prepare

In [None]:
test_string = "abere ererea<|endoftext|>When and where is not as important as who and what. Hi, I am the the Ivan.<|endoftext|> aaa"
# test_string = "abere ererea"
input_path = ""
vocab_size = 260
special_tokens = [END_TOKEN]

In [5]:
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, test_string)

text_segment = segments[0]
text_segment

'abere ererea'

#### Initialize Vocab and Merges

In [6]:
vocab = {i:bytes([i]) for i in range(256)}
vocab_inverse = {v:k for k,v in vocab.items()}
# Add special tokens
for st in special_tokens:
    new_id = len(vocab)
    st_bytes = st.encode("utf-8")
    vocab[new_id] = st_bytes
    vocab_inverse[st_bytes] = new_id

merges = []

#### Update Token Freq

In [7]:
token_freqs = defaultdict(int)
update_token_freqs(token_freqs,text_segment)
token_freqs

defaultdict(int,
            {(b'a', b'b', b'e', b'r', b'e'): 1,
             (b' ', b'e', b'r', b'e', b'r', b'e', b'a'): 1})

#### Update idx freq

In [16]:
idx_to_token = {i: k for i, k in enumerate(token_freqs.keys())}
token_to_idx = {k: i for i, k in enumerate(token_freqs.keys())}
idx_freqs = {token_to_idx[token]: freq for token, freq in token_freqs.items()}

#### Update Pair Freq

In [9]:
pair_freqs = defaultdict(int)
update_pair_freqs(pair_freqs,token_freqs)
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 3,
             (b'r', b'e'): 3,
             (b' ', b'e'): 1,
             (b'e', b'a'): 1})

#### Update Pair to idx

In [10]:
pair_to_idx = defaultdict(set)
update_pair2idx(pair_to_idx, token_freqs, token_to_idx)
pair_to_idx

defaultdict(set,
            {(b'a', b'b'): {0},
             (b'b', b'e'): {0},
             (b'e', b'r'): {0, 1},
             (b'r', b'e'): {0, 1},
             (b' ', b'e'): {1},
             (b'e', b'a'): {1}})

#### Merging

In [11]:
best_pair = get_most_frequent_pair(pair_freqs)
best_pair

(b'r', b'e')

In [12]:
new_id = len(vocab)

## update vocab
update_vocab(vocab, new_id, best_pair)

## update vocab_inverse
update_vocab_inverse(vocab_inverse, new_id, best_pair)

## update merges
update_merges(merges, best_pair)

In [17]:
affected_idxs = list(pair_to_idx[best_pair])
merged_bytes = best_pair[0] + best_pair[1]
for idx in affected_idxs:
    token = idx_to_token[idx]

    i=0
    new_token = []
    while(i<(len(token))):
        if (i < len(token) - 1) and (token[i] == best_pair[0]) and (token[i+1] == best_pair[1]):
            new_token.append(merged_bytes)
            i = i + 2
        else:
            new_token.append(token[i])
            i = i + 1
    new_token = tuple(new_token)

    new_pairs = get_pairs(new_token)
    affected_pairs = get_pairs(token)
    for p in set(new_pairs)-set(affected_pairs):
        pair_to_idx[p].add(idx)
    for p in set(affected_pairs)-set(new_pairs):
        pair_to_idx[p].discard(idx)


    for pair in affected_pairs:
        pair_freqs[pair] -= idx_freqs[idx]
    for pair in new_pairs:
        pair_freqs[pair] += idx_freqs[idx]

    # origin_freq = token_freqs.pop(token)
    # token_freqs[new_token] = origin_freq

    idx_to_token[idx] = new_token

In [18]:
pair_to_idx

defaultdict(set,
            {(b'a', b'b'): {0},
             (b'b', b'e'): {0},
             (b'e', b'r'): set(),
             (b'r', b'e'): set(),
             (b' ', b'e'): {1},
             (b'e', b'a'): set(),
             (b'e', b're'): {0, 1},
             (b're', b'a'): {1},
             (b're', b're'): {1}})

In [19]:
pair_freqs

defaultdict(int,
            {(b'a', b'b'): 1,
             (b'b', b'e'): 1,
             (b'e', b'r'): 0,
             (b'r', b'e'): 0,
             (b' ', b'e'): 1,
             (b'e', b'a'): 0,
             (b'e', b're'): 2,
             (b're', b're'): 1,
             (b're', b'a'): 1})

In [86]:
idx_freqs

{0: 1, 1: 1}

In [46]:
idx_to_token

{0: (b'a', b'b', b'e', b're'), 1: (b' ', b'e', b're', b're', b'a')}

### 2.6 BPE Tokenizer: Encoding and Decoding

##### 2.6.1 Encoding text

##### 2.6.2 Decoding text

In [None]:
class Tokenizer():
    def __init__(
        self, 
        vocab: dict[int, bytes], 
        merges: list[PAIR],
        special_tokens: list[str] | None = None,
    ):
        self.vocab = vocab
        self.merges = merges
        
        if special_tokens:
            new_index = len(vocab)
            for st in special_tokens:
                if st != END_TOKEN:
                    vocab[new_index] = st
                    new_index += 1
            self.special_tokens =  sorted(list(set([END_TOKEN]+special_tokens)), key=len, reverse=True)
        else:
            self.special_tokens = [END_TOKEN]

        self.vocab_inverse = {v: k for k, v in vocab.items()}
        self.merges_to_rank = {m: i for i, m in enumerate(merges)}

    @classmethod
    def from_files(
        cls, 
        vocab_filepath: str, 
        merges_filepath: str, 
        special_tokens: list[str] | None = None,
    ) :
        with open(vocab_filepath, "rb") as f:
            vocab = pickle.load(f)
        with open(merges_filepath, "rb") as f:
            merges = pickle.load(f)
        return cls(vocab, merges, special_tokens=special_tokens)

    def merge_tokens(self, token_bytes: list[bytes]) -> list[bytes]:
        if len(token_bytes) <= 1:
            return token_bytes
        if token_bytes in self.special_tokens:
            return token_bytes
        while 1:
            merge_position = -1
            smallest_rank = len(self.merges)
            for i in range(len(token_bytes)-1):
                current_pair = (token_bytes[i], token_bytes[i+1])
                rank = self.merges_to_rank.get(current_pair, -1)
                if rank == -1:
                    continue
                if rank < smallest_rank:
                    smallest_rank = rank
                    merge_position = i
            if merge_position == -1:
                break
            token_bytes = token_bytes[:(merge_position)] + [token_bytes[merge_position]+token_bytes[merge_position+1]]+token_bytes[(merge_position+2):]
        return token_bytes

    def encode(self, text: str) -> list[int]:
        special_pat = "("+"|".join(re.escape(st) for st in self.special_tokens)+")"
        segments = re.split(special_pat, text)
        ids = []
        for segment in segments:
            if not segment:
                continue
            if segment in self.special_tokens:
                token_bytes = self.transfer_text2bytes(segment)
                token_bytes = self.merge_tokens(token_bytes)
                encoded_token = []
                for i in token_bytes:
                    encoded_token.append(self.vocab_inverse[i])
                ids.extend(encoded_token)
                continue
            matches = re.finditer(PAT, segment)
            for m in matches:
                token = m.group()
                token_bytes = self.transfer_text2bytes(token)
                token_bytes = self.merge_tokens(token_bytes)
                encoded_token = []
                for i in token_bytes:
                    encoded_token.append(self.vocab_inverse[i])
                ids.extend(encoded_token)
        return ids

    def _encode_batch(self, texts: list[str]) -> list[list[int]]:
        return [self.encode(text) for text in texts]

    def encode_iterable(self, iterable: Iterable[str], num_processes: int = 4, batch_size: int = 1000) -> Iterator[int]:
        with ProcessPoolExecutor(max_workers=num_processes) as executor:
            def batch_generator():
                current_batch = []
                for text in iterable:
                    current_batch.append(text)
                    if len(current_batch) >= batch_size:
                        yield current_batch
                        current_batch = []
                if current_batch:
                    yield current_batch
            # 直接在 encode_iterable 这一层产出结果
            for batch_results in executor.map(self._encode_batch, batch_generator()):
                for seq in batch_results:
                    yield from seq

    def decode(self, ids: list[int]) -> str:
        unk_bytes = '\ufffd'.encode('utf-8')
        bytes_list = [self.vocab.get(i, unk_bytes) for i in ids]
        return b"".join(bytes_list).decode("utf-8", errors="replace")

    def transfer_text2bytes(self, segment: str) -> list[bytes]:
        if segment in self.special_tokens:
            return [segment.encode("utf-8")]
        token_bytes = list(bytes([b]) for b in segment.encode("utf-8"))
        return token_bytes

In [5]:
import tiktoken
from tests.adapters import get_tokenizer
from tests.common import FIXTURES_PATH, gpt2_bytes_to_unicode

VOCAB_PATH = FIXTURES_PATH / "gpt2_vocab.json"
MERGES_PATH = FIXTURES_PATH / "gpt2_merges.txt"

from tests.test_tokenizer import get_tokenizer_from_vocab_merges_path

### 2.7 Experiments

In [12]:
from cs336_basics.bpe_tokenize import Tokenizer,find_chunk_boundaries,END_TOKEN

In [None]:
input_path = f"{project_root}/data/TinyStoriesV2-GPT4-valid.txt"
special_tokens = [END_TOKEN]

data_group = "train"
vocab_size = 10000
file_name = f"TinyStoriesV2-GPT4-{data_group}.txt"
vocab_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-vocab-{vocab_size}.pkl"
merges_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-merge-{vocab_size}.pkl"


num_processes = 16


tokenizer = Tokenizer.from_files(vocab_filepath,merges_filepath,special_tokens)

with open(input_path, 'rb') as f:
    split_token = special_tokens[0].encode("utf-8") if special_tokens else b""
    boundaries = find_chunk_boundaries(f, num_processes, split_token) 

start = boundaries[0]
end = boundaries[1]

with open(input_path, 'rb') as f:
    f.seek(start)
    chunk_bytes = f.read(end - start)
chunk_str = chunk_bytes.decode("utf-8", errors="ignore")
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, chunk_str)

ran_ints = random.sample(range(len(segments)), 10)
for i in ran_ints:
    ids = tokenizer.encode(segments[i])
    print(len(ids)/len(segments[ran_int]))

In [43]:
special_tokens = [END_TOKEN]

data_group = "train"
vocab_size = 10000
file_name = f"TinyStoriesV2-GPT4-{data_group}.txt"
input_path = f"{project_root}/data/{file_name}"
vocab_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-vocab-{vocab_size}.pkl"
merges_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-merge-{vocab_size}.pkl"


num_processes = 16


tokenizer = Tokenizer.from_files(vocab_filepath,merges_filepath,special_tokens)

with open(input_path, 'rb') as f:
    split_token = special_tokens[0].encode("utf-8") if special_tokens else b""
    boundaries = find_chunk_boundaries(f, num_processes, split_token) 

start = boundaries[0]
end = boundaries[1]

with open(input_path, 'rb') as f:
    f.seek(start)
    chunk_bytes = f.read(end - start)
chunk_str = chunk_bytes.decode("utf-8", errors="ignore")
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, chunk_str)

ran_ints = random.sample(range(len(segments)), 10)
for i in ran_ints:
    ids = tokenizer.encode(segments[i])
    print(len(ids)/len(segments[ran_int]))

0.2624356775300172
0.2933104631217839
0.6792452830188679
0.3567753001715266
0.18181818181818182
0.32246998284734135
0.2692967409948542
0.274442538593482
0.3687821612349914
0.2109777015437393


In [45]:
special_tokens = [END_TOKEN]

data_group = "train"
vocab_size = 32000
file_name = f"owt_{data_group}.txt"
input_path = f"{project_root}/data/{file_name}"
vocab_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-vocab-{vocab_size}.pkl"
merges_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-merge-{vocab_size}.pkl"


num_processes = 16


tokenizer = Tokenizer.from_files(vocab_filepath,merges_filepath,special_tokens)

with open(input_path, 'rb') as f:
    split_token = special_tokens[0].encode("utf-8") if special_tokens else b""
    boundaries = find_chunk_boundaries(f, num_processes, split_token) 

start = boundaries[0]
end = boundaries[1]

with open(input_path, 'rb') as f:
    f.seek(start)
    chunk_bytes = f.read(end - start)
chunk_str = chunk_bytes.decode("utf-8", errors="ignore")
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, chunk_str)

ran_ints = random.sample(range(len(segments)), 10)
for i in ran_ints:
    ids = tokenizer.encode(segments[i])
    print(len(ids)/len(segments[ran_int]))

1.5901759530791788
0.2631964809384164
0.748533724340176
1.0425219941348973
0.9919354838709677
0.6708211143695014
0.2749266862170088
1.5879765395894427
0.41642228739002934
0.39222873900293254


In [46]:
special_tokens = [END_TOKEN]

data_group = "train"
vocab_size = 32000
file_name = f"owt_{data_group}.txt"
input_path = f"{project_root}/data/TinyStoriesV2-GPT4-train.txt"
vocab_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-vocab-{vocab_size}.pkl"
merges_filepath = f"{project_root}/outputs/{file_name.split(".")[0]}-merge-{vocab_size}.pkl"


num_processes = 16


tokenizer = Tokenizer.from_files(vocab_filepath,merges_filepath,special_tokens)

with open(input_path, 'rb') as f:
    split_token = special_tokens[0].encode("utf-8") if special_tokens else b""
    boundaries = find_chunk_boundaries(f, num_processes, split_token) 

start = boundaries[0]
end = boundaries[1]

with open(input_path, 'rb') as f:
    f.seek(start)
    chunk_bytes = f.read(end - start)
chunk_str = chunk_bytes.decode("utf-8", errors="ignore")
special_pat = "|".join(re.escape(st) for st in special_tokens)
segments = re.split(special_pat, chunk_str)

ran_ints = random.sample(range(len(segments)), 10)
for i in ran_ints:
    ids = tokenizer.encode(segments[i])
    print(len(ids)/len(segments[ran_int]))

0.2641509433962264
0.27101200686106347
0.3584905660377358
0.3516295025728988
0.3584905660377358
0.3156089193825043
0.8576329331046312
0.5728987993138936
0.27958833619210977
0.4922813036020583


In [31]:
ran_int = ran_ints[0]

In [34]:
ids = tokenizer.encode(segments[ran_int])

0.24390243902439024

## 3 Transformer Language Model Architecture

### 3.1 Transformer LM

#### 3.1.1 Token Embeddings

#### 3.1.2 Pre-norm Transformer Block

### 3.2 Output Normalization and Embedding

### 3.3 Remark: Batching, Einsum and Eﬀicient Computation

In [6]:
images = torch.randn(64, 128, 128, 3) # (batch, height, width, channel)
dim_by = torch.linspace(start=0.0, end=1.0, steps=10)

In [13]:
dim_value = rearrange(dim_by, "dim_value -> 1 dim_value 1 1 1")
images_rearr = rearrange(images, "b height width channel -> b 1 height width channel")

In [14]:
dimmed_images = images_rearr * dim_value

In [15]:
dimmed_images = einsum(
    images, dim_by,
    "batch height width channel, dim_value -> batch dim_value height width channel"
)

In [17]:
channels_last = torch.randn(64, 32, 32, 3) # (batch, height, width, channel)
B = torch.randn(32*32, 32*32)

In [None]:
## Rearrange an image tensor for mixing across all pixels
channels_last_flat = channels_last.view(
    -1, channels_last.size(1) * channels_last.size(2), channels_last.size(3)
)
channels_first_flat = channels_last_flat.transpose(1, 2)
channels_first_flat_transformed = channels_first_flat @ B.T
channels_last_flat_transformed = channels_first_flat_transformed.transpose(1, 2)
channels_last_transformed = channels_last_flat_transformed.view(*channels_last.shape)

In [19]:
height = width = 32
## Rearrange replaces clunky torch view + transpose
channels_first = rearrange(
    channels_last,
    "batch height width channel -> batch channel (height width)"
)
channels_first_transformed = einsum(
    channels_first, B,
    "batch channel pixel_in, pixel_out pixel_in -> batch channel pixel_out"
)
channels_last_transformed = rearrange(
    channels_first_transformed,
    "batch channel (height width) -> batch height width channel",
    height=height, width=width
)

In [21]:
height = width = 32
channels_last_transformed = einx.dot(
    "batch row_in col_in channel, (row_out col_out) (row_in col_in)"
    "-> batch row_out col_out channel",
    channels_last, B,
    col_in=width, col_out=width
)

NameError: name 'einx' is not defined

#### 3.3.1 Mathematical Notation and Memory Ordering

### 3.4 Basic Building Blocks: Linear and Embedding Modules

#### 3.4.1 Parameter Initialization

#### 3.4.2 Linear Module

In [None]:
class Linear(nn.Module):
    def __init__(
        self, 
        in_features: int, 
        out_features: int, 
        device: torch.device | None = None, 
        dtype: torch.dtype | None = None
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device = device
        self.dtype = dtype

        self.weights = nn.Parameter(
            torch.empty(out_features, in_features, device=self.device, dtype=self.dtype)
        )
        std = np.sqrt(2/(in_features+out_features))
        nn.init.trunc_normal_(
            self.weights,
            mean = 0,
            std=std,
            a=-3*std,
            b=3*std
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Apply the linear transformation to the input
        '''
        output = einsum(
            x, self.weights,
            "... in_features, out_features in_features -> ... out_features"
        )
        return output
        

In [26]:
in_features = 4
out_features = 2

In [31]:
weights = nn.Parameter(
            torch.empty(out_features, in_features)
        )

In [None]:
std = np.sqrt(2/(in_features+out_features))


Parameter containing:
tensor([[-0.3163,  0.2574, -0.6433, -1.2062],
        [-0.2289,  0.3611,  0.4565, -0.2161]], requires_grad=True)

#### 3.4.3 Embedding Module

In [None]:
class Embedding(nn.Module):
    def __init__(
        self, 
        num_embeddings: int, 
        embedding_dim: int, 
        device: torch.device | None = None, 
        dtype: torch.dtype | None = None
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.device = device
        self.dtype = dtype

        self.weights = nn.Parameter(
            torch.empty(num_embeddings, embedding_dim, device=self.device, dtype=self.dtype)
        )
        nn.init.trunc_normal_(
            self.weights,
            mean = 0,
            std=1,
            a=-3,
            b=3
        )

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        output = self.weights[token_ids]
        return output

In [6]:
num_embeddings = 4
embedding_dim = 2

In [None]:
weights = nn.Parameter(
            torch.empty(num_embeddings, embedding_dim)
        )
nn.init.trunc_normal_(
            weights,
            mean = 0,
            std=1,
            a=-3,
            b=3
        )
embedding

Parameter containing:
tensor([[ 1.1913,  0.2204],
        [ 0.8022, -0.8639],
        [-0.1520,  0.4158],
        [ 2.5936,  1.7208]], requires_grad=True)

In [None]:
token_ids = torch.tensor([1,3,2], dtype=torch.int)
weights[token_ids]

tensor([[ 0.8022, -0.8639],
        [ 2.5936,  1.7208],
        [-0.1520,  0.4158]], grad_fn=<IndexBackward0>)

### 3.5 Pre-Norm Transformer Block

#### 3.5.1 Root Mean Square Layer Normalization

In [None]:
class RMSNorm(nn.Module):
    def __init__(
        self, 
        d_model: int, 
        eps: float = 1e-5, 
        device: torch.device | None = None, 
        dtype: torch.dtype | None = None
    ):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        self.device = device
        self.dtype = dtype

        self.weights = nn.Parameter(
            torch.ones(d_model, device=self.device, dtype=self.dtype)
        )


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        x = x.to(torch.float32)
        ms = x.pow(2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(ms+self.eps)
        result = x/rms * self.weights
        return result.to(in_dtype)

#### 3.5.2 Position-Wise Feed-Forward Network

In [None]:
class SwiGLU(nn.Module):
    def __init__(
        self, 
        d_model: int, 
        d_ff: int, 
        device: torch.device | None = None, 
        dtype: torch.dtype | None = None
    ):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.device = device
        self.dtype = dtype


        self.w1_weight = Linear(self.d_model,self.d_ff,self.device,self.dtype)
        self.w2_weight = Linear(self.d_ff,self.d_model,self.device,self.dtype)
        self.w3_weight = Linear(self.d_model,self.d_ff,self.device,self.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def SiLU(x: torch.Tensor) -> torch.Tensor:
            return x*torch.sigmoid(x)
        
        x1 = self.w1_weight(x)
        x1_silu = SiLU(x1)
        x3 = self.w3_weight(x)
        x1_silu_x3 = x1_silu*x3
        result = self.w2_weight(x1_silu_x3)
        return result

#### 3.5.3 Relative Positional Embeddings

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    def __init__(
        self, 
        theta: float, 
        d_k: int, 
        max_seq_len: int,
        device: torch.device | None = None, 
    ):
        super().__init__()
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        self.device = device

        dim_index = torch.arange(self.d_k // 2, device=self.device, dtype=torch.float32)
        position_index = torch.arange(self.max_seq_len, device=self.device, dtype=torch.float32)
        theta_inv_index = self.theta**(-2*dim_index/d_k)
        theta_ik = einsum(
            position_index, theta_inv_index,
            "s, d -> s d"
        )


        sin = torch.sin(theta_ik)
        cos = torch.cos(theta_ik)
        
        self.register_buffer("sin", sin, persistent=False)
        self.register_buffer("cos", cos, persistent=False)
    
    def forward(
        self, x: torch.Tensor,
        toke_position: torch.Tensor,
    ) -> torch.Tensor:
        x_even = x[...,::2]
        x_odd = x[...,1::2]

        sin_expend = sin[position]
        cos_expend = cos[position]

        x_even_new = x_even*cos_expend-x_odd*sin_expend
        x_odd_new = x_even*sin_expend+x_odd*cos_expend

        x_rope = rearrange(
            torch.stack([x_even_new,x_odd_new], dim=-1),
            '... seq_len d_k two -> ... seq_len (d_k two)',
        )
        return x_rope

In [25]:
max_seq_len = 5
d_k = 4
theta = 2
device = 'mps'
dim_index = torch.arange(d_k // 2)
position_index = torch.arange(max_seq_len)
theta_inv_index = theta**(-2*dim_index/d_k)
theta_ik = einsum(
    position_index, theta_inv_index,
    "s, d -> s d"
)


sin = torch.sin(theta_ik)
cos = torch.cos(theta_ik)

In [26]:
dim1 = 1
seq_len = 3
d_k = 4
x = torch.arange(dim1*seq_len*d_k).reshape(dim1,seq_len,d_k)
position = torch.arange(seq_len).unsqueeze(0).expand(dim1,seq_len)
x

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]])

In [27]:
x_even = x[...,::2]
x_odd = x[...,1::2]
x_even

tensor([[[ 0,  2],
         [ 4,  6],
         [ 8, 10]]])

In [28]:
sin_expend = sin[position]
cos_expend = cos[position]
sin_expend

tensor([[[0.0000, 0.0000],
         [0.8415, 0.6496],
         [0.9093, 0.9878]]])

In [40]:
x_even_new = x_even*cos_expend-x_odd*sin_expend
x_odd_new = x_even*sin_expend+x_odd*cos_expend
x_even_new

tensor([[[  0.0000,   2.0000],
         [ -2.0461,   0.0140],
         [-11.5129,  -9.3060]]])

In [41]:
x_odd_new

tensor([[[ 1.0000,  3.0000],
         [ 6.0674,  9.2195],
         [ 3.5291, 11.5930]]])

In [38]:
temp = torch.stack([x_even_new,x_odd_new], dim=-1)
temp

tensor([[[[  0.0000,   1.0000],
          [  2.0000,   3.0000]],

         [[ -2.0461,   6.0674],
          [  0.0140,   9.2195]],

         [[-11.5129,   3.5291],
          [ -9.3060,  11.5930]]]])

In [None]:
rearrange(
    temp,
    '... seq_len d_k two -> ... seq_len (d_k two)',
)

tensor([[[  0.0000,   2.0000,   1.0000,   3.0000],
         [ -2.0461,   0.0140,   6.0674,   9.2195],
         [-11.5129,  -9.3060,   3.5291,  11.5930]]])

#### 3.5.4 Scaled Dot-Product Attention

##### softmax

In [54]:
dimension = 2
dim1 = 2
dim2 = 4
dim3 = 3
x = torch.arange(dim1*dim2*dim3).reshape(dim1,dim2,dim3)
x

tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]],

        [[12, 13, 14],
         [15, 16, 17],
         [18, 19, 20],
         [21, 22, 23]]])

In [55]:
values, indices = torch.max(x, dim=dimension, keepdim=True)
values.shape

torch.Size([2, 4, 1])

In [56]:
x-values

tensor([[[-2, -1,  0],
         [-2, -1,  0],
         [-2, -1,  0],
         [-2, -1,  0]],

        [[-2, -1,  0],
         [-2, -1,  0],
         [-2, -1,  0],
         [-2, -1,  0]]])

In [61]:
x_exp = torch.exp(x-values)
x_exp

tensor([[[0.1353, 0.3679, 1.0000],
         [0.1353, 0.3679, 1.0000],
         [0.1353, 0.3679, 1.0000],
         [0.1353, 0.3679, 1.0000]],

        [[0.1353, 0.3679, 1.0000],
         [0.1353, 0.3679, 1.0000],
         [0.1353, 0.3679, 1.0000],
         [0.1353, 0.3679, 1.0000]]])

In [62]:
x_exp/torch.sum(x_exp, dim=dimension, keepdim=True)

tensor([[[0.0900, 0.2447, 0.6652],
         [0.0900, 0.2447, 0.6652],
         [0.0900, 0.2447, 0.6652],
         [0.0900, 0.2447, 0.6652]],

        [[0.0900, 0.2447, 0.6652],
         [0.0900, 0.2447, 0.6652],
         [0.0900, 0.2447, 0.6652],
         [0.0900, 0.2447, 0.6652]]])

In [None]:
def softmax(
    x: torch.Tensor,
    dimension: int,    
):
    max_values, _ = torch.max(x, dim=dimension, keepdim=True)
    x_exp = torch.exp(x-max_values)
    x_rxp_sum = torch.sum(x_exp, dim=dimension, keepdim=True)
    return x_exp/x_rxp_sum

##### scaled_dot_product_attention

In [83]:
from cs336_basics.transformer import softmax

In [102]:
batch_size = 2
dim1 = 1
seq_len = 3
d_k = 4
d_v = 8
Q = torch.rand(batch_size*dim1*seq_len*d_k).reshape(batch_size,dim1,seq_len,d_k)
K = torch.rand(batch_size*dim1*seq_len*d_k).reshape(batch_size,dim1,seq_len,d_k)
V = torch.rand(batch_size*dim1*seq_len*d_v).reshape(batch_size,dim1,seq_len,d_v)
mask = (torch.rand(seq_len*seq_len) > 0.5).reshape(seq_len,seq_len)

In [103]:
QK= einsum(
    Q, K,
    "batch ... seq_n d_k,  batch ... seq_m d_k -> batch ... seq_n seq_m"
)
QK_scaled = QK/torch.tensor(d_k).sqrt()

In [109]:
mask

tensor([[False, False,  True],
        [ True, False,  True],
        [False, False,  True]])

In [110]:
M = torch.where(mask, torch.tensor(0.0), torch.tensor(float('-inf')))
M

tensor([[-inf, -inf, 0.],
        [0., -inf, 0.],
        [-inf, -inf, 0.]])

In [107]:
QK_soft_max = softmax(QK, -1)
result = einsum(
    QK_soft_max, V,
    "batch ... seq_n seq_m, batch ... seq_n d_v -> batch ... seq_m d_v"
)

In [None]:
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor
):
    d_k = Q.shape[-1]
    QK= einsum(
        Q, K,
        "batch ... seq_n d_k,  batch ... seq_m d_k -> batch ... seq_n seq_m"
    )/torch.tensor(Q.shape[-1])
    QK_soft_max = softmax(QK, -1)
    result = einsum(
        QK_soft_max, V,
        "batch ... seq_n seq_m, batch ... seq_n d_v -> batch ... seq_m d_v"
    )
    return result

In [97]:
torch.tensor(Q.shape[-1])

tensor(4)

In [None]:
Q.dim()-1

4

#### 3.5.5 Causal Multi-Head Self-Attention

In [None]:
class MultiheadSelfAttention(nn.Module):
    def __init__(
        self, 
        d_model: int, 
        num_heads: int, 
        theta: float|None=None,
        max_seq_len:int|None=None,
        device: torch.device | None = None, 
        dtype: torch.dtype | None = None
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.theta = theta
        self.max_seq_len = max_seq_len
        self.device = device
        self.dtype = dtype

        self.rope = None
        self.d_k = int(d_model/num_heads)
        self.d_v = int(d_model/num_heads)

        W_Q = Linear(d_model, num_heads*self.d_k,device,dtype)
        W_K = Linear(d_model, num_heads*self.d_k,device,dtype)
        W_V = Linear(d_model, num_heads*self.d_v,device,dtype)
        W_O = Linear(num_heads*self.d_v, d_model,device,dtype)

        if (theta is not None) and (max_seq_len is not None):
            self.rope = RotaryPositionalEmbedding(theta, self.d_k, max_seq_len)

    def forward(
        self, x: torch.Tensor,
        token_position: torch.Tensor|None=None,
    ) -> torch.Tensor:
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        Q = rearrange(
            Q,
            "batch seq (num_heads d_k) -> batch num_heads seq d_k",
            num_heads=self.num_heads
        )
        K = rearrange(
            K,
            "batch seq (num_heads d_k) -> batch num_heads seq d_k",
            num_heads=self.num_heads
        )
        V = rearrange(
            V,
            "batch seq (num_heads d_v) -> batch num_heads seq d_v",
            num_heads=self.num_heads
        )

        seq_len = Q.shape[-2]

        # token_position = repeat(
        #     torch.arange(seq_len),
        #     "seq -> batch num_heads seq",
        #     batch=batch_size,
        #     num_heads=num_heads
        # )
        
        if (self.rope is not None) and (token_position is not None):
            Q = self.rope(Q, token_position)
            K = self.rope(K, token_position)

        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).bool()
        QKV = scaled_dot_product_attention(Q,K,V,mask)
        QKV_reshape = rearrange(
            QKV,
            "batch num_heads seq d_v -> batch seq (num_heads d_v)"
        )
        result = self.W_O(QKV_reshape)
        return result


In [5]:
d_model = 16
num_heads = 4

theta = 2
max_seq_len = 100

In [6]:
from cs336_basics.transformer import Linear, RotaryPositionalEmbedding,scaled_dot_product_attention



batch_size = 2
seq_len = 3
d_k = d_v = int(d_model/num_heads)
x = torch.randn(batch_size, seq_len, d_model)
print(x)

tensor([[[ 0.4357,  0.5501,  0.3195, -0.7769,  0.0510, -0.4412, -0.2595,
          -1.5785,  1.4489,  0.0391, -0.2899,  0.1485,  1.4783, -0.8320,
           0.4811,  0.7672],
         [-0.4532, -1.0519, -2.1271, -0.9563, -0.5471, -1.0961, -1.0668,
          -0.6359, -2.0603, -0.1888,  0.9379, -1.0534, -0.2707,  0.7150,
           0.2585,  0.3702],
         [ 1.5414,  0.1439, -0.3604,  1.2360, -0.0041, -0.4724, -0.8149,
          -0.3018, -1.3162, -0.5365,  1.7381, -1.5125, -2.0367,  0.0171,
           0.4785, -0.0124]],

        [[ 0.8033,  0.2236, -0.1552, -0.1383, -1.8450,  0.8227, -2.4821,
           1.6230, -0.5547,  0.1371, -0.0919, -1.5927, -0.0959, -2.1357,
          -0.3561, -1.1158],
         [-2.3072, -0.1718,  1.5170, -0.3445,  0.8949,  0.0707,  1.0895,
           1.0088, -0.1453, -0.8405, -0.4352,  0.4360, -0.5356, -0.4802,
           0.0862,  0.3063],
         [ 2.6335,  0.8828, -0.6398, -0.3546, -0.4464,  1.1500,  0.5257,
           0.6860, -0.5160,  0.2141,  1.9333,  0.0

In [13]:
W_Q = Linear(d_model, num_heads*d_k)
W_K = Linear(d_model, num_heads*d_k)
W_V = Linear(d_model, num_heads*d_v)
W_O = Linear(num_heads*d_v, d_model)

Q = W_Q(x)
K = W_K(x)
V = W_V(x)

Q = rearrange(
    Q,
    "batch seq (num_heads d_k) -> batch num_heads seq d_k",
    num_heads=num_heads
)
K = rearrange(
    K,
    "batch seq (num_heads d_k) -> batch num_heads seq d_k",
    num_heads=num_heads
)
V = rearrange(
    V,
    "batch seq (num_heads d_v) -> batch num_heads seq d_v",
    num_heads=num_heads
)

rope = RotaryPositionalEmbedding(theta, d_k, max_seq_len)
token_position = repeat(
    torch.arange(seq_len),
    "seq -> batch num_heads seq",
    batch=batch_size,
    num_heads=num_heads
)
Q_rope = rope(Q, token_position)
K_rope = rope(K, token_position)

mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).bool()
QKV = scaled_dot_product_attention(Q_rope,K_rope,V,mask)
QKV_reshape = rearrange(
    QKV,
    "batch num_heads seq d_v -> batch seq (num_heads d_v)"
)
result = W_O(QKV_reshape)

In [9]:
QKV.shape

torch.Size([2, 4, 3, 4])

In [11]:
W_O.shape

AttributeError: 'Linear' object has no attribute 'shape'

In [12]:
[num_heads*d_v, d_model]

[16, 16]

**有两个大问题。一是对于QK的softmax方向，二是对于上下三角。**

### 3.6 The Full Transformer LM

##### transformer_block

In [None]:
from cs336_basics.transformer import (
    MultiheadSelfAttention,
    RMSNorm,
    SwiGLU
)


In [None]:
class TransformerBlock(nn.Module):
    def __init__(
        self, 
        d_model: int, 
        num_heads: int, 
        d_ff: int,
        eps: float = 1e-5, 
        theta: float|None=None,
        max_seq_len:int|None=None,
        device: torch.device | None = None, 
        dtype: torch.dtype | None = None
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.eps = eps
        self.theta = theta
        self.max_seq_len = max_seq_len
        self.device = device
        self.dtype = dtype

        self.multihead_attention = MultiheadSelfAttention(
            d_model,num_heads,theta,max_seq_len,device,dtype
        )
        self.rms_norm1 = RMSNorm(d_model,eps,device,dtype)
        self.rms_norm2 = RMSNorm(d_model,eps,device,dtype)
        self.swi_glu = SwiGLU(d_model,d_ff,device,dtype)

    def forward(
        self, x: torch.Tensor,
        token_position: torch.Tensor|None=None,
    ) -> torch.Tensor:
        x_norm = self.rms_norm1(x)
        x_attention = self.multihead_attention(x_norm, token_position)
        x2 = x + x_attention
        x2_norm = self.rms_norm2(x2)
        x2_glu = self.swi_glu(x2_norm)
        x3 = x2 + x2_glu
        return x3



##### transformer_lm

In [6]:
from cs336_basics.transformer import (
    Embedding,
    TransformerBlock,
    RMSNorm,
    Linear,
    softmax
)


In [None]:
class TransformerLM(nn.Module):
    def __init__(
        self, 
        vocab_size: int,
        context_length: int,
        num_layers: int,
        d_model: int, 
        num_heads: int, 
        d_ff: int,
        eps: float = 1e-5, 
        theta: float|None=None,
        device: torch.device | None = None, 
        dtype: torch.dtype | None = None
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.eps = eps
        self.theta = theta
        self.device = device
        self.dtype = dtype

        self.embedding = Embedding(vocab_size,d_ff,device,dtype)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, rope_theta, context_length, device, dtype)
            for _ in range(num_layers)
        ])
        self.rms_norm = RMSNorm(d_model,eps,device,dtype)
        self.linear = Linear(d_model,vocab_size,device,dtype)

    def forward(
        self, x: torch.Tensor
    ) -> torch.Tensor:
        x_embedded = self.embedding(x)
        for block in self.transformer_blocks:
            x_embedded = block(x_embedded)
        x_norm = self.rms_norm(x_embedded)
        x_linear = self.linear(x_norm)
        result = softmax(x_linear, dim=-1)
        return result


Suppose:
1. vocab_size = v
2. context_length = s
3. number_layer = N
4. d_model = d
5. n_head = h
6. d_ff = f

| Layer1      | Layer2    | Layer3 | Input_dim | weights_dim | Flop | N_para |
| :-----      | :-----    | :----- | :-------- | :---------- | :--- | :---   |
| Input       |           |        | 0         | s           | 0    | 0      |
| Embedding   |           |        | s         | d*v         | 0    | dv     |
| Transformer | Attention | get QKV| d*s       | d*d         | 6d^2s| 3d^2   |
|             |           | ROPE   | d*s       |             | 2ds  | 0      |
|             |           | QK     | s\*d/h\*h | s\*d/h\*h   | 2ds^2| 0      |
|             |           | softmax|           |             | O(hs^2)| 0    |
|             |           | KV     | s\*d/h\*h | s\*d/h\*h   | 2ds^2| 0      |
|             |           | Output | d*s       | d*d         | 2d^2s| d^2    |
|             | SwiGLU    | W1     | d*s       | d*f         | 2dsf | df     |
|             |           | W2     | f*s       | d*f         | 2dsf | df     |
|             |           | W3     | d*s       | d*f         | 2dsf | df     |
|             | RMS       |        |           |             | O(ds)| 2d     |
| RMS         |           |        |           |             |      | d      |
| Linear      |           |        | d*s       | d*v         | 2dsv | dv     |

N*(8d^2s+4ds^2+6dsf)+2dsv

In [21]:
def analyze_transformer_to_df(v, s, N, d, f, h):
    # 定义每一行的数据逻辑 [Layer1, Layer2, Layer3, Input, Weights, Flop, N_para]
    data = [
        ["Input", "", "", "0", "s", 0, 0],
        ["Embedding", "", "", "s", "d*v", 0, d * v],
        
        # Transformer Block 内部 (这里计算单层)
        ["Transformer", "Attention", "get QKV", "d*s", "3*d*d", 6 * (d**2) * s, 3 * (d**2)],
        ["", "", "ROPE", "d*s", "0", 2 * d * s, 0],
        ["", "", "QK", "s*d", "s*d", 2 * d * (s**2), 0],
        ["", "", "softmax", "h*s*s", "0", 5 * h * (s**2), 0],
        ["", "", "KV", "s*s", "s*d", 2 * d * (s**2), 0],
        ["", "", "Output", "d*s", "d*d", 2 * (d**2) * s, d**2],
        
        ["", "SwiGLU", "W1", "d*s", "d*f", 2 * d * s * f, d * f],
        ["", "", "W3", "d*s", "d*f", 2 * d * s * f, d * f],
        ["", "", "W2", "f*s", "d*f", 2 * d * s * f, d * f],
        
        ["", "RMS", "x2", "d*s", "2*d", 10 * d * s, 2 * d],
        
        # Final layers
        ["RMS", "Final", "", "d*s", "d", 5 * d * s, d],
        ["Linear", "Output", "", "d*s", "d*v", 2 * d * s * v, d * v]
    ]

    columns = ["Layer1", "Layer2", "Layer3", "Input_dim", "Weights_dim", "Flop", "N_para"]
    df = pd.DataFrame(data, columns=columns)

    # 计算 Block 内部的合计数（用于乘以 N 层）
    block_mask = df["Layer1"].isin(["Transformer", ""]) & ~df["Layer2"].isin(["Output", "Final"])
    block_flops = df.loc[block_mask, "Flop"].sum()
    block_params = df.loc[block_mask, "N_para"].sum()

    # 计算总计 (注意：Embedding 和 Linear Output 不随 N 变化)
    total_params = (block_params * N) + (d * v * 2) + d # 词表双向+最后Norm
    total_flops = (block_flops * N) + (2 * d * s * v)

    # 打印 DataFrame (美化显示)
    pd.options.display.float_format = '{:,.0f}'.format
    print("\nDetailed Layer-by-Layer Analysis (Single Block/Layer Stats):")
    print(df)

    print(f"\n{'='*40}")
    print(f"Summary for N = {N} layers:")
    print(f"- Total Parameters: {total_params / 1e9:.3f} Billion")
    print(f"- Total FLOPs:      {total_flops / 1e12:.3f} TFLOPs")
    print(f"- Memory (FP32):    {total_params * 4 / (1024**3):.2f} GB")
    print(f"{'='*40}")

    return df

# 设置参数
params = {
    "v": 50257,
    "s": 1024,
    "N": 48,
    "d": 1600,
    "f": 6400,
    "h": 25
}

df_result = analyze_transformer_to_df(**params)


Detailed Layer-by-Layer Analysis (Single Block/Layer Stats):
         Layer1     Layer2   Layer3 Input_dim Weights_dim          Flop  \
0         Input                             0           s             0   
1     Embedding                             s         d*v             0   
2   Transformer  Attention  get QKV       d*s       3*d*d   15728640000   
3                              ROPE       d*s           0       3276800   
4                                QK       s*d         s*d    3355443200   
5                           softmax     h*s*s           0     131072000   
6                                KV       s*s         s*d    3355443200   
7                            Output       d*s         d*d    5242880000   
8                   SwiGLU       W1       d*s         d*f   20971520000   
9                                W3       d*s         d*f   20971520000   
10                               W2       f*s         d*f   20971520000   
11                     RMS       x2   

In [22]:
params = {
    "v": 50257,
    "s": 1024,
    "N": 12,
    "d": 768,
    "f": 6400,
    "h": 12
}

df_result = analyze_transformer_to_df(**params)


Detailed Layer-by-Layer Analysis (Single Block/Layer Stats):
         Layer1     Layer2   Layer3 Input_dim Weights_dim         Flop  \
0         Input                             0           s            0   
1     Embedding                             s         d*v            0   
2   Transformer  Attention  get QKV       d*s       3*d*d   3623878656   
3                              ROPE       d*s           0      1572864   
4                                QK       s*d         s*d   1610612736   
5                           softmax     h*s*s           0     62914560   
6                                KV       s*s         s*d   1610612736   
7                            Output       d*s         d*d   1207959552   
8                   SwiGLU       W1       d*s         d*f  10066329600   
9                                W3       d*s         d*f  10066329600   
10                               W2       f*s         d*f  10066329600   
11                     RMS       x2       d*s     

In [23]:
params = {
    "v": 50257,
    "s": 1024,
    "N": 24,
    "d": 1024,
    "f": 6400,
    "h": 16
}

df_result = analyze_transformer_to_df(**params)


Detailed Layer-by-Layer Analysis (Single Block/Layer Stats):
         Layer1     Layer2   Layer3 Input_dim Weights_dim          Flop  \
0         Input                             0           s             0   
1     Embedding                             s         d*v             0   
2   Transformer  Attention  get QKV       d*s       3*d*d    6442450944   
3                              ROPE       d*s           0       2097152   
4                                QK       s*d         s*d    2147483648   
5                           softmax     h*s*s           0      83886080   
6                                KV       s*s         s*d    2147483648   
7                            Output       d*s         d*d    2147483648   
8                   SwiGLU       W1       d*s         d*f   13421772800   
9                                W3       d*s         d*f   13421772800   
10                               W2       f*s         d*f   13421772800   
11                     RMS       x2   

In [24]:
params = {
    "v": 50257,
    "s": 1024,
    "N": 36,
    "d": 1280,
    "f": 6400,
    "h": 20
}

df_result = analyze_transformer_to_df(**params)


Detailed Layer-by-Layer Analysis (Single Block/Layer Stats):
         Layer1     Layer2   Layer3 Input_dim Weights_dim          Flop  \
0         Input                             0           s             0   
1     Embedding                             s         d*v             0   
2   Transformer  Attention  get QKV       d*s       3*d*d   10066329600   
3                              ROPE       d*s           0       2621440   
4                                QK       s*d         s*d    2684354560   
5                           softmax     h*s*s           0     104857600   
6                                KV       s*s         s*d    2684354560   
7                            Output       d*s         d*d    3355443200   
8                   SwiGLU       W1       d*s         d*f   16777216000   
9                                W3       d*s         d*f   16777216000   
10                               W2       f*s         d*f   16777216000   
11                     RMS       x2   

In [25]:
params = {
    "v": 50257,
    "s": 16384,
    "N": 48,
    "d": 1600,
    "f": 6400,
    "h": 25
}

df_result = analyze_transformer_to_df(**params)


Detailed Layer-by-Layer Analysis (Single Block/Layer Stats):
         Layer1     Layer2   Layer3 Input_dim Weights_dim           Flop  \
0         Input                             0           s              0   
1     Embedding                             s         d*v              0   
2   Transformer  Attention  get QKV       d*s       3*d*d   251658240000   
3                              ROPE       d*s           0       52428800   
4                                QK       s*d         s*d   858993459200   
5                           softmax     h*s*s           0    33554432000   
6                                KV       s*s         s*d   858993459200   
7                            Output       d*s         d*d    83886080000   
8                   SwiGLU       W1       d*s         d*f   335544320000   
9                                W3       d*s         d*f   335544320000   
10                               W2       f*s         d*f   335544320000   
11                     RMS

## 4 Training a Transformer LM

### 4.1 Cross-entropy loss

In [None]:
def cross_entropy(
    prediction: torch.Tensor,
    target: torch.Tensor
):
    max_values, _ = torch.max(prediction, dim = -1, keepdim=True)
    prediction_scaled = prediction - max_values
    prediction_scaled_exp = torch.exp(prediction_scaled)
    prediction_scaled_exp_sum = torch.sum(prediction_scaled_exp, dim = -1)
    prediction_scaled_exp_sum_log = torch.log(prediction_scaled_exp_sum)
    target_expend = rearrange(
        target,
        "... -> ... 1"
    )
    target_logits = torch.gather(prediction_scaled, -1, target_expend)
    result = - target_logits + prediction_scaled_exp_sum_log
    return result

In [29]:
batch = 2
vocab_size = 4

prediction = torch.arange(batch*vocab_size).reshape(batch,vocab_size)
prediction_sum = torch.sum(prediction, dim = -1, keepdim=True)
prediction = prediction/prediction_sum
target = torch.randint(0,vocab_size,(batch,))
prediction



tensor([[0.0000, 0.1667, 0.3333, 0.5000],
        [0.1818, 0.2273, 0.2727, 0.3182]])

In [30]:
max_values, _ = torch.max(prediction, dim = -1, keepdim=True)
prediction_scaled = prediction - max_values
prediction_scaled

tensor([[-0.5000, -0.3333, -0.1667,  0.0000],
        [-0.1364, -0.0909, -0.0455,  0.0000]])

In [31]:
prediction_scaled_exp = torch.exp(prediction_scaled)
prediction_scaled_exp_sum = torch.sum(prediction_scaled_exp, dim = -1)
prediction_scaled_exp_sum_log = torch.log(prediction_scaled_exp_sum)
prediction_scaled_exp_sum_log

tensor([1.1536, 1.3194])

In [38]:
target_expend = rearrange(
    target,
    "... -> ... 1"
)

target_logits = torch.gather(prediction_scaled, -1, target_expend)
target_logits = rearrange(
    target_logits,
    "... 1 -> ..."
)
target_logits

tensor([-0.5000,  0.0000])

In [39]:
result = - target_logits + prediction_scaled_exp_sum_log
result

tensor([1.6536, 1.3194])

### 4.2 The SGD Optimizer

#### 4.2.1 Implementing SGD in PyTorch

In [None]:
class SGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        super().__init__(params, defaults)

    def step(self, closure: Optional[Callable] = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"] # Get the learning rate.
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p] # Get state associated with p.
                t = state.get("t", 0) # Get iteration number from the state, or initial value.
                grad = p.grad.data # Get the gradient of loss with respect to p.
                p.data -= lr / math.sqrt(t + 1) * grad # Update weight tensor in-place.
                state["t"] = t + 1
        return loss

In [42]:
weights = torch.nn.Parameter(5 * torch.randn((10, 10)))
opt = SGD([weights], lr=1)
for t in range(100):
    opt.zero_grad() # Reset the gradients for all learnable parameters.
    loss = (weights**2).mean() # Compute a scalar loss value.
    print(loss.cpu().item())
    loss.backward() # Run backward pass, which computes gradients.
    opt.step() # Run optimizer step.

20.437528610229492
19.628202438354492
19.07695770263672
18.638938903808594
18.268024444580078
17.94269561767578
17.650890350341797
17.385042190551758
17.14004898071289
16.912277221679688
16.69902992248535
16.49823760986328
16.30828285217285
16.12786293029785
15.955907821655273
15.791542053222656
15.634020805358887
15.482717514038086
15.337088584899902
15.19666862487793
15.061049461364746
14.929871559143066
14.80282211303711
14.679614067077637
14.560001373291016
14.443753242492676
14.330670356750488
14.220563888549805
14.113271713256836
14.008635520935059
13.906517028808594
13.806788444519043
13.709332466125488
13.614039421081543
13.520807266235352
13.429544448852539
13.340163230895996
13.252582550048828
13.166728973388672
13.082529067993164
12.999919891357422
12.918835639953613
12.839221954345703
12.761022567749023
12.684186935424805
12.608667373657227
12.534415245056152
12.461386680603027
12.389545440673828
12.31885051727295
12.249261856079102
12.180747985839844
12.113275527954102
12.

### 4.3 AdamW

In [47]:
class AdamW(torch.optim.Optimizer):
    def __init__(
        self, params,
        lr=1e-3,
        eps=1e-8,
        betas=(0.9, 0.999),
        weight_decay=0
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        defaults = {
            'lr': lr,
            'betas': betas,
            'eps': eps,
            'weight_decay': weight_decay,
        }
        super().__init__(params, defaults)
    
    def step(self, closure: Optional[Callable] = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"]
            beta1 = group["betas"][0]
            beta2 = group["betas"][1]
            eps = group["eps"] 
            weight_decay = group["weight_decay"] 
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                g = p.grad.data


                m = state.get('m',torch.zeros_like(p.data))
                m = beta1*m+(1-beta1)*g
                state['m'] = m

                v = state.get('v',torch.zeros_like(p.data))
                v = beta2*v+(1-beta2)*g**2
                state['v'] = v

                t = state.get("t", 1)
                state['t'] = t+1

                lrt = lr*math.sqrt(1-beta2**t)/(1-beta1**t)

                p.data -= lrt*m/(torch.sqrt(v)+eps)
                if weight_decay != 0:
                    p.data = p.data-lr*weight_decay*p.data
        return loss

### 4.4 Learning rate scheduling

In [None]:
def learning_rate_schedule(
    t: int,
    alpha_max: float,
    alpha_min: float,
    T_w: int,
    T_c: int
):
    if t<T_w:
        alpha_t = t*alpha_max/T_w
    elif t<= T_c:
        alpha_t = alpha_min + (alpha_max-alpha_min)*(1+math.cos((t-T_w)*math.pi/(T_c-T_w)))/2 
    else:
        alpha_t = alpha_min
    return alpha_t 