In [186]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math

In [5]:
device

'cuda'

In [170]:
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 64 # what is the maximum context length for predictions?
max_iters = 4000 # the number of training iterations
eval_interval = 200 # how often to evaluate the loss
learning_rate = 1e-3 # learning rate
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200 # number of iterations for evaluation, from random samples
embed_size = 64 # embedding size
split_ratio = 0.9 # training, val split ratio
num_heads = 4
num_blocks = 12
dropout = 0.0
# ------------

In [6]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2026-01-15 17:49:43--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2026-01-15 17:49:43 (28.1 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [7]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [8]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string


In [136]:
len(text)

1115394

In [40]:
class modeldata:
    def __init__(self, data, tag, split_ratio):
        self.data =data
        self.tag = tag
        self.split_ratio = split_ratio
    
    def __getdata__(self):
        n = int(self.split_ratio * len(self.data))
        if self.tag == 'train':
            out = self.data[:n]
        if self.tag == 'val':
            out = self.data[n:]
        return out

In [42]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
train_data = modeldata(data, "train", split_ratio).__getdata__()
val_data = modeldata(data, "val", split_ratio).__getdata__()

In [43]:
train_data.device

device(type='cpu')

In [47]:
class dataloaderlite:
    def __init__(self, data, block_size, batch_size, device, shuffle=True, tag=None):
        self.data = data
        self.L = block_size
        self.B= batch_size
        self.current_pos = 0
        self.shuffle = shuffle
        self.device = device
        self.N = len(data)
        self.tag = tag

    def __iter__(self):
        return self

    def __next__(self):
        if self.shuffle:
            # shuffle mode: sample random batches
            idx = torch.randint(0,self.N - self.L - 1, (self.B,))
        else:
            # sequential mode: get batches in order, mainly for eval
            if self.current_pos + self.B * self.L >= self.N:
                self.current_pos = 0  # reset pointer if we reach the end
                raise StopIteration
            idx = torch.arange(self.current_pos, self.current_pos + self.B * self.L, self.L)
            self.current_pos += self.B * self.L
        # x is from i -> i + L token index
        # y is from 1+1 -> i+ 1 + L token index    
        x = torch.stack([self.data[i:i+self.L] for i in idx])
        y = torch.stack([self.data[i+1:i+self.L+1] for i in idx])    
        return x.to(self.device), y.to(self.device)

In [46]:
class BaselineModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embed_table = nn.Embedding(vocab_size, embed_size)
        self.FFN = nn.Linear(embed_size, vocab_size)
    
    def forward(self, x_id_matrix, y_id_matrix=None):
        # id_matrix: (B, L) is the input token indices for B batches, each batch has sequence length L
        # d: the embed size
        # v: the vocab size
        B, L = x_id_matrix.shape
        token_emb = self.embed_table(x_id_matrix) # (B, L, d)
        logits = self.FFN(token_emb) # (B, L, vocab_size)
        if y_id_matrix is None:
            loss = None
        else:
            B, L, v = logits.shape
            # reshape the logits and targets to compute the cross-entropy loss
            logits = logits.view(B*L, v)
            targets = y_id_matrix.view(B*L)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, x_id_matrix, max_new_tokens):
        B, L = x_id_matrix.shape
        for _ in range(max_new_tokens):
            logits, _ = self(x_id_matrix)
            # focus only on the last time step
            logits = logits[:, -1, :] # (B, vocab_size)
            probs = F.softmax(logits, dim=-1) # (B, vocab_size)
            # sample from the distribution
            next_id = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            x_id_matrix= torch.cat((x_id_matrix, next_id), dim=1) # (B, L+1)
        return x_id_matrix
       

In [153]:
train_loader = dataloaderlite(train_data, block_size, batch_size, device, shuffle=True, tag="train")
val_loader = dataloaderlite(val_data, block_size, batch_size, device, shuffle=False, tag="val")

In [116]:
model = BaselineModel(vocab_size, embed_size*10)
model = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [128]:
@torch.no_grad()
def loss_estimation(model, grad_norm=False):
    output = {}
    model.eval()
    train_data_loader_estimation = dataloaderlite(train_data, block_size, batch_size, device, shuffle=True, tag="train")
    val_data_loader_estimation = dataloaderlite(val_data, block_size, batch_size, device, shuffle=True, tag="val")
    for data_obj in [train_data_loader_estimation, val_data_loader_estimation]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = data_obj.__next__()
            logits, loss = model(x, y)
            losses[k] = loss.item()
        output[data_obj.tag] = losses.mean().item()
    if grad_norm:
        for name, p in model.named_parameters():
            if p.grad is not None:
                 print(f"{name}: {p.grad.norm().item():.4f}")
    model.train()
    return output

In [118]:
for iter in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = loss_estimation(model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = train_loader.__next__()

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.2737, val loss 4.2735
step 200: train loss 2.5432, val loss 2.5694
step 400: train loss 2.5192, val loss 2.5557
step 600: train loss 2.5133, val loss 2.5589
step 800: train loss 2.5133, val loss 2.5401
step 1000: train loss 2.5094, val loss 2.5315
step 1200: train loss 2.5106, val loss 2.5374
step 1400: train loss 2.5149, val loss 2.5333
step 1600: train loss 2.5144, val loss 2.5490
step 1800: train loss 2.5067, val loss 2.5517
step 2000: train loss 2.5160, val loss 2.5408
step 2200: train loss 2.4994, val loss 2.5355
step 2400: train loss 2.5150, val loss 2.5325
step 2600: train loss 2.4999, val loss 2.5330
step 2800: train loss 2.4942, val loss 2.5299
step 3000: train loss 2.5082, val loss 2.5401
step 3200: train loss 2.4930, val loss 2.5244
step 3400: train loss 2.5052, val loss 2.5422
step 3600: train loss 2.4990, val loss 2.5308
step 3800: train loss 2.5013, val loss 2.5284


In [151]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_context_ids = model.generate(context, max_new_tokens=1000)
print(decode(generated_context_ids[0].tolist()))



I'd payou ead me
BAThoweve,
Tof s, qurarore gerilyrsondghys R blere ar urss bld u d wachaigureap anuneilf an ppirear a g th nof nd wicowhy:
Byowindiveatirw'su quenglle V: wha pe ghat wsod fut'd f aphay:
Lowilichesor us s CO t:--lerand oof giearelos s'd alal me was
Thome n forind Talingnen at n ubu sse pr s whee tate's st fumy t s, se wahfongisow,
ARD:
Coomo, geind,
Fire
ARCImowir

ARAng;
Yond aleed y,
ARDo wnd by beake,
THangaithilond y, thenct t he, wize isard oures:
MEESerien
GLusel ty, rak wead IFOf hichilt plearondet ond we, ty,
Frre w! ar bar tis
Nowind nors w mifu blld.
War bl cis Tones en I cind ILAn hy, hory ly, mactr w blt wat ave te ETrs:
Forise.
ARKis-myhu ff nonouns f omath'l g CimbeEsos pule baurloyo muronderalite CHAHeril nonct t t, yot lodghermesesend ha VEvest ausalararse fe pe weld
Ind?
Andeararenofo G testes D:
m wis sd sers pe, g heptissof forsouglit:
Fovereden s bilshet izeis.
Why,'Fr, balowarard ink uourd
S: led im'd heas thonseso waran windonoflinn IRIUSTh s! ly.

In [200]:
class GPT(nn.Module):
    def __init__(self, 
                 vocab_size, 
                 embed_size, 
                 block_size, 
                 n_layers, 
                 num_heads, 
                 weight_init=False):
        super().__init__()
        self.token_emb_table = nn.Embedding(vocab_size, embed_size)
        self.pos_emb_table = nn.Embedding(block_size, embed_size)
        self.attention_blocks = nn.ModuleList(Block(num_heads, embed_size, block_size)
                                               for _ in range(n_layers))
        self.LM = nn.Linear(embed_size, vocab_size, bias=False)
        self.ln = nn.LayerNorm(embed_size)
        self.block_size = block_size
        self.n_layers = n_layers
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_heads = num_heads
        if weight_init:
            self.apply(self._init_weights)
            self._scale_residual_projections()
    
    def _init_weights(self, module: nn.Module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def _scale_residual_projections(self):
        # scale output projections in residual branches (proj layers)
        # This is a common trick used in GPT implementations.
        for name, p in self.named_parameters():
            if name.endswith("proj.weight") or name.endswith("fc2.weight"):
                p.data /= math.sqrt(2.0 * self.n_layers)
    
    def forward(self, x_id_matrix, y_id_matrix=None):
        B, L = x_id_matrix.shape
        token_emb = self.token_emb_table(x_id_matrix)
        # it's very important to use the current sequence length L here
        pos = torch.arange(L, device=x_id_matrix.device)
        pos_emb = self.pos_emb_table(pos)
        x = token_emb + pos_emb
        for block in self.attention_blocks:
          x = block(x)
        # LM head
        logits = self.LM(self.ln(x))
        if y_id_matrix == None:
            loss = 0.0 
        else:      
            targets = y_id_matrix.view(B*L)
            logits = logits.view(B*L, self.vocab_size)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, x_id_matrix, max_new_tokens, top_k):
        B, L = x_id_matrix.shape
        out = x_id_matrix.clone()
        for _ in range(max_new_tokens):
            if x_id_matrix.shape[1] > self.block_size:
                x_id_matrix = x_id_matrix[:, -self.block_size:] 
            
            logits, _ = self(x_id_matrix)
            # focus only on the last time step
            logits = logits[:, -1, :] # (B, vocab_size)
            top_k_logits, top_k_indices = torch.topk(logits, k=top_k)
            top_k_probs = F.softmax(top_k_logits, dim=-1) # (B, vocab_size)
            # sample from the distribution
            sampled = torch.multinomial(top_k_probs, num_samples=1) # (B, 1)
            next_id = torch.gather(top_k_indices, -1, sampled)  # (B, 1)
            # append sampled index to the running sequence
            x_id_matrix= torch.cat((x_id_matrix, next_id), dim=1) # (B, L+1)
            out = torch.cat((out, next_id), dim=1)
        return out
    

In [172]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, embed_size, block_size):
        super().__init__()
        self.num_heads = num_heads
        self.attention_heads = nn.ModuleList(MaskedAttention(num_heads, embed_size, block_size) 
                                             for _ in range(num_heads)) 
        self.proj = nn.Linear(embed_size, embed_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.attention_heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out # (B, L, d)
        

In [173]:
class FFN(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_size, 4*embed_size),
            nn.ReLU(),
            nn.Linear(4*embed_size, embed_size),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

In [174]:
class Block(nn.Module):
    def __init__(self, num_heads, embed_size, block_size):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads, embed_size, block_size)
        self.ln1 = nn.LayerNorm(embed_size)
        self.ffn = FFN(embed_size)
        self.ln2 = nn.LayerNorm(embed_size)
    
    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

In [175]:
class MaskedAttention(nn.Module):
    def __init__(self, num_heads, embed_size, block_size):
        super().__init__()
        self.head_size = embed_size // num_heads
        self.q_proj = nn.Linear(embed_size, self.head_size, bias=False)
        self.k_proj = nn.Linear(embed_size, self.head_size, bias=False)
        self.v_proj = nn.Linear(embed_size, self.head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        #self.register_buffer('mask',torch.tril(torch.ones(batch_size, block_size, block_size, device=device,requires_grad=False)))

    def forward(self, x):
        # (B, L, d/H)
        B, L, _ = x.shape
        q, k ,v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        kt = k.transpose(-1,-2)
        #attention_matrix = F.softmax(q@kt, dim=-1)
        # (B, L, L)
        #attention_matrix[self.mask == 0] = -torch.inf
        att = q@kt/ (self.head_size**0.5)
        # magic trick for double softmax      
        # att = F.softmax(att, dim=-1)
        mask = torch.tril(torch.ones(B, L, L, device=device,requires_grad=False))
        att = att.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(att, dim=-1)
        weights = self.dropout(weights)
        att_output = weights@v
        #residual connection
        return att_output 

In [201]:
model_gpt = GPT(vocab_size, embed_size, block_size, num_blocks, num_heads, weight_init=True )
model_gpt = model_gpt.to(device)
# create a PyTorch optimizer
optimizer_gpt = torch.optim.AdamW(model_gpt.parameters(), lr=learning_rate)

In [202]:
train_loader = dataloaderlite(train_data, block_size, batch_size, device, shuffle=True, tag="train")
val_loader = dataloaderlite(val_data, block_size, batch_size, device, shuffle=True, tag="val")

In [203]:
for iter in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = loss_estimation(model_gpt,grad_norm=False)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = train_loader.__next__()

    # evaluate the loss
    logits_gpt, loss_gpt = model_gpt(xb, yb)
    optimizer_gpt.zero_grad(set_to_none=True)
    loss_gpt.backward()
    optimizer_gpt.step()

step 0: train loss 4.1820, val loss 4.1835
step 200: train loss 2.3947, val loss 2.3976
step 400: train loss 2.2066, val loss 2.2304
step 600: train loss 2.0049, val loss 2.0582
step 800: train loss 1.8514, val loss 1.9431
step 1000: train loss 1.7509, val loss 1.8902
step 1200: train loss 1.6774, val loss 1.8422
step 1400: train loss 1.6167, val loss 1.7948
step 1600: train loss 1.5809, val loss 1.7576
step 1800: train loss 1.5534, val loss 1.7337
step 2000: train loss 1.5219, val loss 1.7062
step 2200: train loss 1.5000, val loss 1.7027
step 2400: train loss 1.4823, val loss 1.6830
step 2600: train loss 1.4722, val loss 1.6698
step 2800: train loss 1.4559, val loss 1.6659
step 3000: train loss 1.4487, val loss 1.6621
step 3200: train loss 1.4405, val loss 1.6497
step 3400: train loss 1.4341, val loss 1.6332
step 3600: train loss 1.4265, val loss 1.6299
step 3800: train loss 1.4132, val loss 1.6135


In [None]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_context_ids = model_gpt.generate(context, max_new_tokens=500, top_k=10)
print(decode(generated_context_ids[0].tolist()))

In [182]:
num_trainable = sum(p.numel() for p in model_gpt.parameters() if p.requires_grad)
print(num_trainable)

639233


In [185]:
for name, p in model_gpt.named_parameters():
    if p.requires_grad:
        size = p.numel()
        print(f"{name}: {size}")

token_emb_table.weight: 4160
pos_emb_table.weight: 4096
attention_blocks.0.mha.attention_heads.0.q_proj.weight: 1024
attention_blocks.0.mha.attention_heads.0.k_proj.weight: 1024
attention_blocks.0.mha.attention_heads.0.v_proj.weight: 1024
attention_blocks.0.mha.attention_heads.1.q_proj.weight: 1024
attention_blocks.0.mha.attention_heads.1.k_proj.weight: 1024
attention_blocks.0.mha.attention_heads.1.v_proj.weight: 1024
attention_blocks.0.mha.attention_heads.2.q_proj.weight: 1024
attention_blocks.0.mha.attention_heads.2.k_proj.weight: 1024
attention_blocks.0.mha.attention_heads.2.v_proj.weight: 1024
attention_blocks.0.mha.attention_heads.3.q_proj.weight: 1024
attention_blocks.0.mha.attention_heads.3.k_proj.weight: 1024
attention_blocks.0.mha.attention_heads.3.v_proj.weight: 1024
attention_blocks.0.mha.proj.weight: 4096
attention_blocks.0.mha.proj.bias: 64
attention_blocks.0.ln1.weight: 64
attention_blocks.0.ln1.bias: 64
attention_blocks.0.ffn.net.0.weight: 16384
attention_blocks.0.ffn.ne