In [None]:
# !pip install tiktoken

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from typing import Any, Dict, Iterable, Optional
import math
import os
import regex as re
import tiktoken
import warnings
warnings.filterwarnings('ignore')

print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


True
Using device: cuda


In [None]:
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

list_of_tokens = re.findall(gpt2pat, "hello world! how are you? I'm fine, thank you!😊")
print("List of tokens:", list_of_tokens)

List of tokens: ['hello', ' world', '!', ' how', ' are', ' you', '?', ' I', "'m", ' fine', ',', ' thank', ' you', '!😊']


In [None]:
# Get the GPT-2 encoding --> GPT-2 does not merge spaces, so it's bad for coding
enc_gpt2 = tiktoken.get_encoding("gpt2")
enc_gpt4 = tiktoken.get_encoding("cl100k_base")

list_of_tokens = enc_gpt2.encode("hello world! how are you? I'm fine, thank you!😊")
print("Encoded tokens:", list_of_tokens)

list_of_tokens = enc_gpt4.encode("hello world! how are you? I'm fine, thank you!😊")
print("Encoded tokens:", list_of_tokens)

# gpt-4 merges spaces, so it's better for coding
list_of_tokens = enc_gpt4.encode("how are you?")
print("Encoded tokens from example:", list_of_tokens)
print(len(list_of_tokens))


Encoded tokens: [31373, 995, 0, 703, 389, 345, 30, 314, 1101, 3734, 11, 5875, 345, 0, 47249, 232]
Encoded tokens: [15339, 1917, 0, 1268, 527, 499, 30, 358, 2846, 7060, 11, 9901, 499, 0, 76460, 232]
Encoded tokens from example: [5269, 527, 499, 30]
4


In [None]:
# # special tokens
# special_tokens = {'<|endoftext|>', 50256}
# # minbpe - GPT-4
# # SentencePiece tokenizer

In [None]:
with open(r"tinystories_sample_5M.txt", 'r', encoding='utf-8') as file:
    sample_text_data = file.read()

In [None]:
# sample_text_data[1:1000]
# text_split = sample_text_data.split("<|endoftext|>")
# len(text_split)
# print(text_split)
print(sample_text_data[:1000])

u don't have to be scared of the loud dog, I'll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.
<|endoftext|>
Once upon a time, in a warm and sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.
Tom asked his friend, Sam, to help him search for the ball. They looked high and low, but they could not find the ball. Tom said, "I think my ball fell into the pit."
Sam and Tom went close to the pit. They were scared, but they wanted to find the red ball. They looked into the pit, but it was too dark to see. Tom said, "We must go in and search for my ball."
They went into the pit to search. It was dark and scary. They could not find the ball. They tried to get out, but the pit was too deep. Tom and Sam were stuck in the pit. They called for help, but no one could hear t

In [None]:
# train a BPE
# pre-tokenising sample_text_data
# before pre-tokenising we will split the text around all the special tokens
text_split = sample_text_data.split("<|endoftext|>")
print(len(text_split))

gpt2pre_tok = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

tokens = []

for i in range(len(text_split)):

    text_data = text_split[i]
    pre_tokens = re.findall(gpt2pre_tok, text_data)

    for l in range(len(pre_tokens)):
        pre_tokens[l] = pre_tokens[l].encode('utf-8')

# # for i in range(len(pre_tok)):
# #     print(list(pre_tok[i]))

    for l in range(len(pre_tokens)):
        for j in range(len(pre_tokens[l])):
            tokens.append(pre_tokens[l][j])

    # print(len(tokens))

# text_to_bytes = sample_text_data.encode('utf-8')
# list(text_to_bytes)
# print(len(text_to_bytes))
# print(len(sample_text_data))

# len(text_to_bytes)
# list(text_to_bytes)

len(tokens)


6458


5158939

In [None]:
# list(tokens)
max_token = -1
list_of_tokens = list(map(int, tokens))

for i in range(len(list_of_tokens)):
    if list_of_tokens[i] > max_token:
        max_token = list_of_tokens[i]

max_token

226

In [None]:
# BPE implementation

def get_frequency_of_pairs(tokens):

    counts = {}

    for i in range(len(tokens) - 1):
        pair = (tokens[i], tokens[i + 1])
        counts[pair] = counts.get(pair, 0) + 1

    return counts

def get_max_frequent_pair(counts):

    sorted_pairs = sorted(counts.items(), key=lambda x: x[1], reverse=True)
    ch1, ch2 = sorted_pairs[0][0]

    if ch2 == None:
        for i in range(len(counts)):
            if sorted_pairs[i][0][1] != None:
                return sorted_pairs[i][0], sorted_pairs[i][1]
    else:
        get_pair = sorted_pairs[0][0]
        get_the_frequency = sorted_pairs[0][1]

        return get_pair, get_the_frequency


counts = get_frequency_of_pairs(list_of_tokens)
get_pair, cnts = get_max_frequent_pair(counts)
x1, x2 = get_pair

def merge_tokens(ids, pair, idx):

    new_tokens = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            new_tokens.append(idx)  # Merging the pair into a single token
            i += 2
        else:
            new_tokens.append(ids[i])
            i += 1

    return new_tokens

merges = {}

for i in range(1000):

    counts = get_frequency_of_pairs(list_of_tokens)
    get_pair, cnts = get_max_frequent_pair(counts)

    if cnts == 1:
        print(f"iter = {i}")
        break
    else:
        # print(f"Iteration {i + 1}: Pair {get_pair} with frequency {cnts}")
        list_of_tokens = merge_tokens(list_of_tokens, get_pair, i + 256)  # Using 256 as the starting index for new tokens
        merges[get_pair] = i + 256
        # print(f"Iteration {i + 1}: Pair {get_pair} with merged token = {i + 256}")


print("Number of tokens:", len(list_of_tokens))
# print("Final tokens:", tokens)



Number of tokens: 1389977


In [None]:
compression_ratio = len(tokens) / len(list_of_tokens)
compression_ratio

3.711528320252781

In [None]:
# list(tokens)
max_token = -1
# list_of_tokens = list(map(int, tokens))

for i in range(len(list_of_tokens)):
    if list_of_tokens[i] > max_token:
        max_token = list_of_tokens[i]

# max_token
# for <|endoftext|>
# list_of_tokens.append(415)
print(merges)

{(101, 32): 256, (100, 32): 257, (116, 104): 258, (32, 97): 259, (46, 32): 260, (116, 32): 261, (121, 32): 262, (115, 32): 263, (110, 257): 264, (116, 111): 265, (101, 114): 266, (101, 257): 267, (258, 256): 268, (44, 32): 269, (119, 97): 270, (105, 110): 271, (104, 256): 272, (265, 32): 273, (111, 117): 274, (259, 264): 275, (101, 110): 276, (104, 97): 277, (260, 84): 278, (259, 32): 279, (111, 109): 280, (115, 97): 281, (97, 114): 282, (32, 268): 283, (111, 110): 284, (104, 101): 285, (46, 10): 286, (105, 109): 287, (108, 108): 288, (103, 32): 289, (270, 263): 290, (97, 110): 291, (111, 114): 292, (266, 32): 293, (105, 116): 294, (97, 121): 295, (105, 100): 296, (105, 114): 297, (105, 263): 298, (114, 101): 299, (112, 108): 300, (105, 108): 301, (267, 273): 302, (119, 105): 303, (97, 109): 304, (258, 101): 305, (108, 111): 306, (115, 116): 307, (114, 105): 308, (97, 32): 309, (97, 264): 310, (260, 72): 311, (260, 83): 312, (111, 32): 313, (285, 262): 314, (311, 256): 315, (32, 104): 

In [None]:
len(merges)

1000

In [None]:
def encode(text):

    # pre-tokenising sample_text_data
    # before pre-tokenising we will split the text around all the special tokens
    text_split = text.split("<|endoftext|>")
    # print(f"len of text split = {len(text_split)}")

    # gpt2pre_tok = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    tokens = []

    for i in range(len(text_split)):

        text_data = text_split[i]
        pre_tokens = re.findall(gpt2pre_tok, text_data)

        for l in range(len(pre_tokens)):
            pre_tokens[l] = pre_tokens[l].encode('utf-8')

    # # for i in range(len(pre_tok)):
    # #     print(list(pre_tok[i]))

        for l in range(len(pre_tokens)):
            for j in range(len(pre_tokens[l])):
                tokens.append(pre_tokens[l][j])

    while len(tokens) >= 2:

        counts = get_frequency_of_pairs(tokens)
        pair, cnts = get_max_frequent_pair(counts)

        if pair not in merges:
            break

        tokens = merge_tokens(tokens, pair, merges[pair])

    return tokens




In [None]:
vocab_size = max_token + 1
print("Vocabulary size = ", vocab_size)

Vocabulary size =  1256


In [None]:
# encoder = encode("My name is <|endoftext|> Yash Kumar")
# len(encoder)
# # list(bytes([255]))
s = "My name is <|endoftext|> Yash Kumar"
words = re.findall(gpt2pre_tok, s)
print(words)

['My', ' name', ' is', ' <|', 'endoftext', '|>', ' Yash', ' Kumar']


In [None]:
# decoding part
vocab = {idx: bytes([idx]) for idx in range(256)}
# vocab[2]
for (p1, p2), idx in merges.items():
    # print(f"Pair {p1}, {p2} is encoded as token {idx}")
    vocab[idx] = vocab[p1] + vocab[p2]

# print("Vocabulary size:", len(vocab))

def decode(ids):
    tokens = b"".join(vocab[idx] for idx in ids)
    text = tokens.decode('utf-8', errors='replace')
    return text

# print("Decoded text:", decode(tokens[:10]))
decode(encode("My name is <|endoftext|> Yash Kumar"))

'My name is  Yash Kumar'

In [None]:
batch_size = 8
context_length = 64
num_heads = 8
d_model = 128 # embedding dimension
n_layers = 6
vocab_size = vocab_size # 356
max_steps = 20000
warmup_steps = 100
total_cycle_steps = 800
max_lr = 9e-4
min_lr = 9e-5
weight_decay = 0.1
grad_clip = 1.0
eps = 1e-5

In [None]:
print(f"vocab size = {vocab_size}")

vocab size = 1256


In [None]:
class Linear(nn.Module):

    def __init__(self, features_in: int, features_out: int, bias = False):

        super().__init__()
        self.weight = nn.Parameter((torch.randn((features_in, features_out)) / (features_in ** 0.5)).to(device))

    def forward(self, x):
        self.out = x @ self.weight.transpose(-2, -1)
        return self.out

    def parameters(self):
        return [self.weight]



In [None]:
class Embedding(nn.Module):

    def __init__(self, num_embedding, embedding_dim):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(num_embedding, embedding_dim).to(device))

    def forward(self, idx):
        self.out = self.weights[idx]
        return self.out

    def parameters(self):
        return [self.weights]


In [None]:
class RMSNorm(nn.Module):

    def __init__(self, d_model: int, eps: float = 1e-5):

        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model).to(device))
        self.eps = eps

    def forward(self, x):

        in_dtype = x.dtype
        x = x.to(torch.float32)
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        normalized_rmsx = x / rms
        result = normalized_rmsx * self.gamma
        return result.to(in_dtype)


In [None]:
class SwiGLU_FeedForward(nn.Module):

    def __init__(self, d_model: int):
        super().__init__()

        d_fff = (8 / 3) * d_model
        self.d_ff = int(d_fff)
        self.d_model = d_model

        self.w1 = Linear(self.d_ff, self.d_model)
        self.w3 = Linear(self.d_ff, self.d_model)
        self.w2 = Linear(self.d_model, self.d_ff)

    def forward(self, x):

        # x --> [d_model] --> [1, d_model]
        # w1.T -------------> [d_model, d_ff]
        silu = self.w1(x) * torch.sigmoid(self.w1(x))
        # print(f"silu shape = {silu.shape}")
        intermidiate = silu * self.w3(x)
        # print(f"intermidiate shape = {intermidiate.shape}")
        # inter --> [4, 32, d_ff]
        # w2.T ---> [d_ff, d_model]
        result = self.w2(intermidiate)
        # print(f"result shape = {result.shape}")

        return result




In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """
    Implements Rotary Positional Embedding (RoPE) as described in https://arxiv.org/abs/2104.09864.

    This module pre-computes the cosine and sine frequency tables and applies the rotation
    to the input tensor during the forward pass.
    """
    def __init__(self, theta: float, d_k: int, max_seq_len: int):
        """
        Constructs the RoPE module.

        Args:
            theta (float): The base period for the sinusoidal embeddings. A common value is 10000.0.
            d_k (int): The dimension of the query and key vectors. Must be an even number.
            max_seq_len (int): The maximum sequence length that this module will process.
            device (torch.device | None, optional): The device to store the buffers on. Defaults to None.
        """
        super().__init__()

        # Validate that the dimension d_k is an even number, as RoPE works on pairs of dimensions.
        if d_k % 2 != 0:
            raise ValueError(f"Dimension d_k ({d_k}) must be an even number for RoPE.")

        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        self.device = device

        # Pre-compute the frequency terms (theta_j in the paper)
        # The shape of freqs will be (d_k / 2)
        freqs = 1.0 / (self.theta ** (torch.arange(0, self.d_k, 2).float() / self.d_k))

        # Pre-compute the positions 'm'
        # The shape of m will be (max_seq_len)
        m = torch.arange(self.max_seq_len)

        # Create the full frequency map by taking the outer product of m and freqs
        # This results in a tensor of shape (max_seq_len, d_k / 2)
        # Each element (i, j) corresponds to the angle m_i * theta_j
        freqs_cis = torch.outer(m, freqs)

        # Compute the cosine and sine values for all positions and frequencies
        cos_cached = torch.cos(freqs_cis)
        sin_cached = torch.sin(freqs_cis)

        # Register the computed values as buffers. Buffers are part of the module's state
        # but are not considered parameters to be trained. They are moved to the correct
        # device when the model is moved (e.g., model.to(device)).
        self.register_buffer("cos_cached", cos_cached.to(device))
        self.register_buffer("sin_cached", sin_cached.to(device))

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """
        Applies RoPE to an input tensor.

        Args:
            x (torch.Tensor): The input tensor of queries or keys.
                              Shape: (..., seq_len, d_k)
            token_positions (torch.Tensor): A tensor specifying the absolute positions
                                             of tokens in x along the sequence dimension.
                                             Shape: (..., seq_len)

        Returns:
            torch.Tensor: The tensor with rotary positional embeddings applied.
                          Shape: (..., seq_len, d_k)
        """
        # Retrieve the pre-computed cosine and sine values using the token positions.
        # The indexing operation will fetch the correct (cos, sin) pair for each token's position.
        # The resulting shape will be (..., seq_len, d_k / 2)
        cos = self.cos_cached[token_positions]
        sin = self.sin_cached[token_positions]

        # Reshape x to split the last dimension into pairs for rotation.
        # x_orig shape: (..., seq_len, d_k)
        # x_reshaped shape: (..., seq_len, d_k / 2, 2)
        x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2)

        # Split into the real (x_r) and imaginary (x_i) parts, corresponding to
        # even and odd dimensions.
        x_r, x_i = x_reshaped.unbind(-1)

        # The RoPE transformation is equivalent to complex number multiplication:
        # (x_r + i*x_i) * (cos + i*sin) = (x_r*cos - x_i*sin) + i*(x_r*sin + x_i*cos)

        # Apply the rotation to the even-indexed dimensions (real part)
        # Shape: (..., seq_len, d_k / 2)
        rotated_r = x_r * cos - x_i * sin

        # Apply the rotation to the odd-indexed dimensions (imaginary part)
        # Shape: (..., seq_len, d_k / 2)
        rotated_i = x_r * sin + x_i * cos

        # Stack the rotated parts back together
        # Shape: (..., seq_len, d_k / 2, 2)
        rotated_pairs = torch.stack((rotated_r, rotated_i), dim=-1)

        # Reshape back to the original input shape
        # Shape: (..., seq_len, d_k)
        rotated_x = rotated_pairs.flatten(start_dim=-2)

        return rotated_x.type_as(x)

In [None]:
# m = nn.Softmax(dim=1)
# input = torch.randn(2, 3)
# print(input)
# output = m(input)
# print(f"output = {output}")

In [None]:
def Softmax(dim: int, input: torch.Tensor):

    max_values, _ = torch.max(input, dim=dim, keepdim=True)
    # print(f"max value = {max_values}")
    final_inp = input - max_values
    # print(f"final input = {final_inp}")
    sum_val = torch.sum(torch.exp(final_inp), dim=dim, keepdim=True)
    # print(f"sum value = {sum_val}")
    result = torch.exp(final_inp) / sum_val

    return result



In [None]:
x = torch.tensor([[10, 2, 8],
                  [5, 15, 9],
                  [1, 6, 12]], dtype=torch.float32)

y = Softmax(-1, x)
y

# m = Softmax(-2, x)
# m


tensor([[8.8054e-01, 2.9539e-04, 1.1917e-01],
        [4.5286e-05, 9.9748e-01, 2.4725e-03],
        [1.6660e-05, 2.4726e-03, 9.9751e-01]])

In [None]:
class Head(nn.Module):

    def __init__(self, d_model: int, head_size: int):

        super().__init__()

        self.d_model = d_model
        self.head_size = head_size

        self.W_q = Linear(head_size, d_model)
        self.W_k = Linear(head_size, d_model)
        self.W_v = Linear(head_size, d_model)

    def forward(self, x: torch.Tensor):

        _, seq_len, _ = x.shape
        #W -->            (d_model, head_size)
        #x --> (batch_size, seq_len, d_model)
        # print(f"x shape = {x.shape}")
        # print(f"W_q shape = {self.W_q.weight.shape}")
        q = self.W_q(x) # (batch_size, seq_len, d_k)
        # print(f"q shape = {q.shape}")
        k = self.W_k(x)
        # print(f"k shape = {k.shape}")
        v = self.W_v(x)
        # print(f"v shape = {v.shape}")

        # to apply RoPE Embeddings to --> q and k
        self.rope = RotaryPositionalEmbedding(
            theta = 10000.0,
            d_k = self.head_size,
            max_seq_len = seq_len
        )

        token_positions = torch.arange(seq_len, device=device)

        query = self.rope(q, token_positions)
        # print(f"query shape after RoPE = {query.shape}")
        keys = self.rope(k, token_positions)
        # print(f"keys shape after RoPE = {keys.shape}")
        values = v

        return self.scaled_dot_product_attention(query, keys, values, attn_mask=None, scale=None, is_causal=True).to(device)

    def scaled_dot_product_attention(self, query: torch.Tensor, keys: torch.Tensor, values: torch.Tensor,
                        attn_mask: torch.Tensor, scale = float, is_causal = bool) -> torch.Tensor:
        # """
        # Given key (K), query (Q), and value (V) tensors, return
        # the output of your scaled dot product attention implementation.

        # Args:
        #     Q (Float[Tensor, " ... queries d_k"]): Query tensor
        #     K (Float[Tensor, " ... keys d_k"]): Key tensor
        #     V (Float[Tensor, " ... values d_v"]): Values tensor
        #     mask (Float[Tensor, " ... queries keys"] | None): Mask tensor
        # Returns:
        #     Float[Tensor, " ... queries d_v"]: Output of SDPA
        # """
        # print(f"query shape = {query.shape}")
        # print(f"keys shape = {keys.shape}")
        # print(f"values shape = {values.shape}")
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        L = query.size(-2)
        S = keys.size(-2)
        attn_bias = torch.zeros(L, S, dtype = query.dtype).to(device)

        if attn_mask is not None:
            attn_mask.to(device)
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        if is_causal:
            assert attn_mask is None
            temp_mask = torch.ones(L, S, dtype = torch.bool).tril(diagonal=0).to(device)
            attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
            attn_bias.to(query.dtype)

        attn_wei = ((query @ keys.transpose(-2, -1)) * scale_factor).to(device)
        # print(f"attn_wei shape = {attn_wei.shape}")
        # print(f"attn_bias shape = {attn_bias.shape}")
        attn_wei += attn_bias.to(attn_wei.device)
        # print(f"attn_wei after adding bias shape = {attn_wei.shape}")
        softmax_attn_wei = Softmax(dim=-1, input=attn_wei).to(attn_wei.device)
        # print(f"softmax_attn_wei shape = {softmax_attn_wei.shape}")
        res = softmax_attn_wei @ values
        # print(f"res shape = {res.shape}")
        return res.to(device)


In [None]:
class MultiheadAttention(nn.Module):

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_size = d_model // num_heads # --> d_k or d_v

        self.heads = nn.ModuleList([Head(d_model, self.head_size) for _ in range(num_heads)])
        self.wo = Linear(self.d_model, self.d_model)

    def forward(self, x: torch.Tensor):

        out = torch.cat([h(x) for h in self.heads], dim=-1)
        # print(f"Concatenated heads shape = {out.shape}")
        res = self.wo(out)
        # print(f"MultiheadAttention output shape = {res.shape}")
        return res


In [None]:
class transformer_block(nn.Module):

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()

        self.mha = MultiheadAttention(d_model, num_heads)
        self.ffn = SwiGLU_FeedForward(d_model)
        self.rmsnorm = RMSNorm(d_model, eps=1e-5)

    def forward(self, x: torch.Tensor):

        y = x + self.mha(self.rmsnorm(x))
        # print(f"y shape after MHA = {y.shape}")
        y = y + self.ffn(self.rmsnorm(y))
        return y


In [None]:
class transformer_lm(nn.Module):

    def __init__(self, context_length: int, d_model: int, num_layers: int):
        super().__init__()

        self.num_layers = num_layers
        self.context_length = context_length
        self.d_model = d_model

        self.token_embedding_table = Embedding(vocab_size, d_model)
        self.position_embedding_table = Embedding(context_length, d_model)

        self.blocks = nn.Sequential(*[transformer_block(d_model, num_heads) for _ in range(num_layers)])
        self.ln_f = RMSNorm(d_model, eps=1e-5)
        self.lm_head = Linear(vocab_size, d_model)

    def forward(self, predicts: torch.Tensor):

        B, T = predicts.shape

        token_emb = self.token_embedding_table(predicts)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))

        x = token_emb + pos_emb #(B, T, d_model)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x) #(B, T, vocab_size)

        return logits

    def generate(self, index_vectors, max_tokens):

        for _ in range(max_tokens):
            index_vectors = index_vectors[:, -self.context_length:]
            logits = self(index_vectors)
            logits = logits[:, -1, :]  # Focus on the last time step
            probs = Softmax(dim=-1, input=logits)
            next_token = torch.multinomial(probs, num_samples=1)  # Sample the next token
            index_vectors = torch.cat((index_vectors, next_token), dim=1)

        return index_vectors


In [None]:
def cross_entropy(logits: torch.Tensor, targets: torch.Tensor):

    # logits --> (batch_size, num_classes)
    # targets -> (batch_size,)
    batch_size, _ = logits.shape
    # 1. We are doing this step to stabilize to not tend to inf
    max_values, _ = torch.max(logits, dim=-1, keepdim=True)
    stabilized_values = logits - max_values

    # 2. Now we will calculate log-softmax
    stabilized_exp = torch.exp(stabilized_values)
    stabilized_exp_sum = torch.sum(stabilized_exp, dim=-1)

    #. taking log
    log_stabilized = torch.log(stabilized_exp_sum) + max_values.squeeze(-1)

    # 3. Extract values for true classes
    row_indices = torch.arange(batch_size) #batch_size
    true_logits = logits[row_indices, targets]

    loss_per_sample = log_stabilized - true_logits

    loss = torch.mean(loss_per_sample)

    return loss


In [None]:
class AdamW(optim.Optimizer):
    """
    Implements the AdamW algorithm (Decoupled Weight Decay Regularization).

    Parameters:
        params (iterable): An iterable of torch.Tensor or dicts.
        lr (float): Learning rate (default: 1e-3).
        betas (Tuple[float, float]): Coefficients for running averages (default: (0.9, 0.999)).
        eps (float): Term added to denominator for numerical stability (default: 1e-8).
        weight_decay (float): Weight decay (L2 penalty) coefficient (default: 1e-2).
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2
    ):
        # Validate inputs
        if lr < 0.0:
            raise ValueError(f"Learning rate should be >= 0: {lr}")
        if not (0.0 <= betas[0] < 1.0):
            raise ValueError(f"beta1 should be in [0, 1): {betas[0]}")
        if not (0.0 <= betas[1] < 1.0):
            raise ValueError(f"beta2 should be in [0, 1): {betas[1]}")
        if eps < 0.0:
            raise ValueError(f"Epsilon should be > 0: {eps}")
        if weight_decay < 0.0:
            raise ValueError(f"Weight decay should be >= 0: {weight_decay}")

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state: Any) -> None:
        """Support for pickling."""
        super(AdamW, self).__setstate__(state)
        # Ensure all parameter groups have the 'weight_decay' key in defaults
        for group in self.param_groups:
            group.setdefault('weight_decay', 1e-2)

    @torch.no_grad()
    def step(self, closure: Optional[callable] = None) -> Optional[torch.Tensor]:
        """
        Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.

        Returns:
            Optional loss if closure is provided.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']
            eps = group['eps']
            lr = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')

                # Get or initialize state
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                state['step'] += 1
                step = state['step']

                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']

                # Decay the first and second moment running averages
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Inside step()
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                # Bias-corrected moments
                hat_exp_avg = exp_avg / bias_correction1
                hat_exp_avg_sq = exp_avg_sq / bias_correction2

                # Denominator: sqrt(v_hat) + eps
                denom = hat_exp_avg_sq.sqrt().add_(eps)

                # Apply weight decay (AdamW: decoupled)
                if weight_decay != 0:
                    p.data.mul_(1 - lr * weight_decay)

                # Update: θ ← θ - η * m_hat / sqrt(v_hat)
                p.addcdiv_(hat_exp_avg, denom, value=-lr)

        return loss

In [None]:
def learning_rate_scheduling(current_step: int, alpha_min: float, alpha_max: float, Tw: int, Tc: int):

    t = current_step
    final_alpha = 0.0

    if t < Tw:
        final_alpha = (t / Tw) * alpha_max
    elif t >= Tw and t <= Tc:
        x = (t - Tw) / (Tc - Tw)
        final_alpha = alpha_min + (0.5 * (1 + math.cos(x * math.pi)) * (alpha_max - alpha_min))
    else:
        final_alpha = alpha_min

    return final_alpha


In [None]:
def gradient_clipping(parameters, M):

    l2_norm = 0.0

    for p in parameters:
        if p.grad is not None:
            p_grad = p.grad
            l2_norm += (p_grad ** 2).sum()

    l2_norm = l2_norm.sqrt().item()
    fact = M / (l2_norm + 1e-6)

    if l2_norm > M:
        for p in parameters:
            if p.grad is not None:
                p.grad.data.mul_(fact)


In [None]:
# def data_loading(x: torch.Tensor, batch_size: int, context_len: int, device: torch.device):

#     total_tokens_needed = batch_size * context_len + 1

#     x_trimmed = x[:total_tokens_needed]

#     input_seq = x_trimmed[:-1].view(batch_size, context_len)
#     target_seq = x_trimmed[1:].view(batch_size, context_len)

#     return input_seq.to(device), target_seq.to(device)


class MemMapDataSet(Dataset):

    def __init__(self, file_name: str, context_length: int):

        self.data = np.memmap(file_name, dtype=np.int32, mode='r')
        self.context_length = context_length

    def __len__(self):
        return max(1, len(self.data) - self.context_length)

    def __getitem__(self, idx):

        x_trimmed = self.data[idx : idx + self.context_length + 1]

        input_seq = torch.from_numpy(x_trimmed[:-1]).long()
        target_seq = torch.from_numpy(x_trimmed[1:]).long()

        return input_seq, target_seq


In [None]:
def save_checkpoint(model: torch.nn.Module, optimizer: torch.optim.Optimizer, loss: float, step: int, filepath: str):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'step': step,
        'loss': loss
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved at step {step} with loss {loss:.4f} to {filepath}")

In [None]:
def load_checkpoint(model: torch.nn.Module, optimizer: torch.optim.Optimizer, filepath: str):
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    step = checkpoint['step']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded from {filepath} at step {step} with loss {loss:.4f}")
    return step, loss

In [None]:
with open("tinystories_sample.txt", 'r', encoding='utf-8') as f:
    text_data = f.read()

In [None]:
text_data[:100]

'\nOnce upon a time there was a little boy named Ben. Ben loved to explore the world around him. He sa'

In [None]:
tokens = encode(text_data)

In [None]:
len_tokens = len(tokens)
# tokens[:100]
train_tokens_len = int(0.9 * len_tokens)
train_tokens = tokens[:train_tokens_len]
val_tokens = tokens[train_tokens_len:]

In [None]:
# val_tokens
len(train_tokens), len(val_tokens)

(3240, 361)

In [None]:
train_data = np.array(train_tokens, dtype=np.int32)
val_data = np.array(val_tokens, dtype=np.int32)
train_data[:10], val_data[:10]

(array([ 10,  79, 110,  99, 256, 117, 112, 111, 110,  32], dtype=int32),
 array([ 98, 111, 117, 116,  32, 116, 104, 256, 115, 112], dtype=int32))

In [None]:
train_data.tofile("train.bin")
print("train_dat is written on train.bin file successfully")

train_dat is written on train.bin file successfully


In [None]:
val_data.tofile("val.bin")
print("val_data is written on val.bin file successfully")

val_data is written on val.bin file successfully


In [None]:
val_dataset = MemMapDataSet("val.bin", context_length=context_length)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers = 2, shuffle=True)

In [None]:
def evaluate_model(model: torch.nn.Module, dataloader: DataLoader, device: torch.device):

    model.eval()
    total_loss = 0.0
    counts = 0

    with torch.no_grad():
        for x, y in dataloader:
            xb, yb = x.to(device), y.to(device)
            logits = model(xb)
            B, T, C = logits.shape
            loss = cross_entropy(logits.view(B * T, C), yb.view(B * T))
            total_loss += loss.item()
            counts += 1

            if counts >= 1000:
              break

    # print(f"Total evaluation batches: {total_batches}")
    avg_loss = total_loss / counts if counts > 0 else float('inf')
    model.train()

    return avg_loss

In [None]:
# data = np.memmap('train.bin', dtype=np.int32, mode="r")
# idx = torch.from_numpy(data[:100])  # input + target
# x = idx[:-1]  # input
# y = idx[1:]  # target

In [None]:
# print("x:", x)
# print("y:", y)

In [None]:
lm = transformer_lm(context_length, d_model, n_layers)

print(sum(p.numel() for p in lm.parameters()) / 1e6, "M parameters")

optimizer = AdamW(lm.parameters(), lr=max_lr, betas=(0.9, 0.95), eps=eps, weight_decay=weight_decay)

train_dataset = MemMapDataSet("train.bin", context_length)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=2, shuffle=True)

# for x, y in train_data_loader:
#     print(f"x shape = {x.shape}, y shape = {y.shape}")
#     print(x)
#     print(y)
#     break

lm.train()
steps = 0

for x, y in train_data_loader:

    if steps >= max_steps:
        break

    steps += 1

    x = x.to(device)
    y = y.to(device)

    # forward pass
    logits = lm(x)
    loss = cross_entropy(logits.view(-1, vocab_size), y.view(-1))

    optimizer.zero_grad()

    # backward pass
    loss.backward()

    # gradient clipping and learning rate scheduling
    gradient_clipping(lm.parameters(), grad_clip)

    lr = learning_rate_scheduling(steps, min_lr, max_lr, warmup_steps, total_cycle_steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # update parameters
    optimizer.step()

    if steps % 1000 == 0:
        print(f"Step {steps}, Loss: {loss.item():.4f}")
        save_checkpoint(lm, optimizer, loss.item(), steps, "model_checkpoint.txt")
        val_loss = evaluate_model(lm, val_dataloader, device)
        print(f"Validation Loss after step {steps}: {val_loss:.4f}")



1.509504 M parameters


In [None]:
story = lm.generate(index_vectors=torch.tensor([encode("Once upon a time,")], device=device), max_tokens=1000)
print(decode(story[0].tolist()))

reom that day on, e e wbcwrrame 's onece bunnnenoeeencp. neennenzrrruennnnnneenneeunnn
