In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [154]:
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_iters = 4000 # the number of training iterations
eval_interval = 200 # how often to evaluate the loss
learning_rate = 1e-3 # learning rate
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200 # number of iterations for evaluation, from random samples
embed_size = 64 # embedding size
split_ratio = 0.9 # training, val split ratio
num_heads = 4
num_blocks = 2
# ------------

In [4]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2026-01-15 03:32:21--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2026-01-15 03:32:21 (19.0 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [5]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [6]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string


In [7]:
class modeldata:
    def __init__(self, data, tag):
        self.data =data
        self.tag = tag

In [8]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(split_ratio*len(data)) # first 90% will be train, rest val
train_data = modeldata(data[:n], "train")
val_data = modeldata(data[n:], "val")

In [9]:
class dataloaderlite:
    def __init__(self, data, block_size, batch_size, shuffle=True, tag=None):
        self.data = data
        self.L = block_size
        self.B= batch_size
        self.current_pos = 0
        self.shuffle = shuffle
        self.N = len(data)
        self.tag = tag

    def __iter__(self):
        return self

    def __next__(self):
        if self.shuffle:
            # shuffle mode: sample random batches
            idx = torch.randint(0,self.N - self.L - 1, (self.B,))
        else:
            # sequential mode: get batches in order, mainly for eval
            if self.current_pos + self.B * self.L >= self.N:
                self.current_pos = 0  # reset pointer if we reach the end
                raise StopIteration
            idx = torch.arange(self.current_pos, self.current_pos + self.B * self.L, self.L)
            self.current_pos += self.B * self.L
        # x is from i -> i + L token index
        # y is from 1+1 -> i+ 1 + L token index    
        x = torch.stack([self.data[i:i+self.L] for i in idx])
        y = torch.stack([self.data[i+1:i+self.L+1] for i in idx])    
        return x.to(device), y.to(device)

In [134]:
class BaselineModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embed_table = nn.Embedding(vocab_size, embed_size)
        self.FFN = nn.Linear(embed_size, vocab_size)
    
    def forward(self, x_id_matrix, y_id_matrix=None):
        # id_matrix: (B, L) is the input token indices for B batches, each batch has sequence length L
        # d: the embed size
        # v: the vocab size
        B, L = x_id_matrix.shape
        token_emb = self.embed_table(x_id_matrix) # (B, L, d)
        logits = self.FFN(token_emb) # (B, L, vocab_size)
        if y_id_matrix is None:
            loss = None
        else:
            B, L, v = logits.shape
            # reshape the logits and targets to compute the cross-entropy loss
            logits = logits.view(B*L, v)
            targets = y_id_matrix.view(B*L)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, x_id_matrix, max_new_tokens):
        B, L = x_id_matrix.shape
        for _ in range(max_new_tokens):
            logits, _ = self(x_id_matrix)
            # focus only on the last time step
            logits = logits[:, -1, :] # (B, vocab_size)
            probs = F.softmax(logits, dim=-1) # (B, vocab_size)
            # sample from the distribution
            next_id = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            x_id_matrix= torch.cat((x_id_matrix, next_id), dim=1) # (B, L+1)
        return x_id_matrix
       

In [65]:
train_loader = dataloaderlite(train_data.data, block_size, batch_size, shuffle=True, tag="train")
val_loader = dataloaderlite(val_data.data, block_size, batch_size, shuffle=False, tag="val")

In [148]:
model = BaselineModel(vocab_size, embed_size*10)
model = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [149]:
@torch.no_grad()
def loss_estimation(model):
    output = {}
    model.eval()
    train_data_loader_estimation = dataloaderlite(train_data.data, block_size, batch_size, shuffle=True, tag="train")
    val_data_loader_estimation = dataloaderlite(val_data.data, block_size, batch_size, shuffle=True, tag="val")
    for data_obj in [train_data_loader_estimation, val_data_loader_estimation]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = data_obj.__next__()
            logits, loss = model(x, y)
            losses[k] = loss.item()
        output[data_obj.tag] = losses.mean().item()
    model.train()
    return output

In [150]:
for iter in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = loss_estimation(model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = train_loader.__next__()

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.2918, val loss 4.2892
step 200: train loss 2.5473, val loss 2.5620
step 400: train loss 2.5312, val loss 2.5513
step 600: train loss 2.5156, val loss 2.5495
step 800: train loss 2.5242, val loss 2.5537
step 1000: train loss 2.5208, val loss 2.5606
step 1200: train loss 2.5205, val loss 2.5494
step 1400: train loss 2.5030, val loss 2.5575
step 1600: train loss 2.5156, val loss 2.5444
step 1800: train loss 2.5074, val loss 2.5353
step 2000: train loss 2.5052, val loss 2.5242
step 2200: train loss 2.5157, val loss 2.5446
step 2400: train loss 2.5185, val loss 2.5273
step 2600: train loss 2.5041, val loss 2.5381
step 2800: train loss 2.5068, val loss 2.5328
step 3000: train loss 2.5003, val loss 2.5486
step 3200: train loss 2.5117, val loss 2.5363
step 3400: train loss 2.5032, val loss 2.5392
step 3600: train loss 2.4977, val loss 2.5387
step 3800: train loss 2.4933, val loss 2.5294


In [151]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_context_ids = model.generate(context, max_new_tokens=1000)
print(decode(generated_context_ids[0].tolist()))



I'd payou ead me
BAThoweve,
Tof s, qurarore gerilyrsondghys R blere ar urss bld u d wachaigureap anuneilf an ppirear a g th nof nd wicowhy:
Byowindiveatirw'su quenglle V: wha pe ghat wsod fut'd f aphay:
Lowilichesor us s CO t:--lerand oof giearelos s'd alal me was
Thome n forind Talingnen at n ubu sse pr s whee tate's st fumy t s, se wahfongisow,
ARD:
Coomo, geind,
Fire
ARCImowir

ARAng;
Yond aleed y,
ARDo wnd by beake,
THangaithilond y, thenct t he, wize isard oures:
MEESerien
GLusel ty, rak wead IFOf hichilt plearondet ond we, ty,
Frre w! ar bar tis
Nowind nors w mifu blld.
War bl cis Tones en I cind ILAn hy, hory ly, mactr w blt wat ave te ETrs:
Forise.
ARKis-myhu ff nonouns f omath'l g CimbeEsos pule baurloyo muronderalite CHAHeril nonct t t, yot lodghermesesend ha VEvest ausalararse fe pe weld
Ind?
Andeararenofo G testes D:
m wis sd sers pe, g heptissof forsouglit:
Fovereden s bilshet izeis.
Why,'Fr, balowarard ink uourd
S: led im'd heas thonseso waran windonoflinn IRIUSTh s! ly.

In [155]:
class GPT(nn.Module):
    def __init__(self, vocab_size, embed_size, block_size, n_layers, num_heads):
        super().__init__()
        self.token_emb_table = nn.Embedding(vocab_size, embed_size)
        self.pos_emb_table = nn.Embedding(block_size, embed_size)
        self.attention_blocks = nn.ModuleList(MultiHeadAttention(num_heads, embed_size, block_size, batch_size)
                                               for _ in range(n_layers))
        self.FFN = nn.Linear(embed_size, 4*embed_size)
        self.LM = nn.Linear(4*embed_size, vocab_size)
        self.ln = nn.LayerNorm(embed_size)
        self.block_size = block_size
        self.n_layers = n_layers
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_heads = num_heads
    
    def forward(self, x_id_matrix, y_id_matrix=None):
        B, L = x_id_matrix.shape
        token_emb = self.token_emb_table(x_id_matrix)
        # it's very important to use the current sequence length L here
        pos_emb = self.pos_emb_table(torch.arange(L))
        x = token_emb + pos_emb
        for block in self.attention_blocks:
          x = x + block(x)
        # LM head
        logits = self.LM(self.FFN(self.ln(x)))
        if y_id_matrix == None:
            loss = 0.0 
        else:      
            targets = y_id_matrix.view(B*L)
            logits = logits.view(B*L, self.vocab_size)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, x_id_matrix, max_new_tokens):
        B, L = x_id_matrix.shape
        out = x_id_matrix.clone()
        for _ in range(max_new_tokens):
            if x_id_matrix.shape[1] > self.block_size:
                x_id_matrix = x_id_matrix[:, -self.block_size:] 
            
            logits, _ = self(x_id_matrix)
            # focus only on the last time step
            logits = logits[:, -1, :] # (B, vocab_size)
            top_k_logits, top_k_indices = torch.topk(logits, k=5)
            top_k_probs = F.softmax(top_k_logits, dim=-1) # (B, vocab_size)
            # sample from the distribution
            sampled = torch.multinomial(top_k_probs, num_samples=1) # (B, 1)
            next_id = torch.gather(top_k_indices, -1, sampled)  # (B, 1)
            # append sampled index to the running sequence
            x_id_matrix= torch.cat((x_id_matrix, next_id), dim=1) # (B, L+1)
            out = torch.cat((out, next_id), dim=1)
        return out
    

In [156]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, embed_size, block_size, batch_size):
        super().__init__()
        self.num_heads = num_heads
        self.attention_heads = nn.ModuleList(MaskedAttention(num_heads, embed_size, block_size, batch_size) 
                                             for _ in range(num_heads)) 
        
    
    def forward(self, x):
        chuncks = []
        for attention_head in self.attention_heads: 
            out = attention_head(x) # (B, L, d/H)
            chuncks.append(out)
        final_out = torch.cat(chuncks,dim=-1)
        return final_out # (B, L, d)
        

In [157]:
class MaskedAttention(nn.Module):
    def __init__(self, num_heads, embed_size, block_size, batch_size):
        super().__init__()
        self.head_size = int(embed_size/num_heads)
        self.q_proj = nn.Linear(embed_size, self.head_size, bias=False)
        self.k_proj = nn.Linear(embed_size, self.head_size, bias=False)
        self.v_proj = nn.Linear(embed_size, self.head_size, bias=False)
        self.ff1 = nn.Linear(self.head_size, 4*self.head_size)
        self.ff2 = nn.Linear(4*self.head_size, self.head_size)
        self.ln1 = nn.LayerNorm(embed_size)
        self.ln2 = nn.LayerNorm(self.head_size)
        self.relu = nn.ReLU()
        #self.register_buffer('mask',torch.tril(torch.ones(batch_size, block_size, block_size, device=device,requires_grad=False)))

    def forward(self, x):
        # (B, L, d/H)
        B, L, _ = x.shape
        q, k ,v = self.q_proj(self.ln1(x)), self.k_proj(self.ln1(x)), self.v_proj(self.ln1(x))
        kt = k.transpose(-1,-2)
        attention_matrix = F.softmax(q@kt, dim=-1)
        # (B, L, L)
        #attention_matrix[self.mask == 0] = -torch.inf
        mask = torch.tril(torch.ones(B, L, L, device=device,requires_grad=False))
        attention_matrix = attention_matrix.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(attention_matrix, dim=-1)
        attention_output = weights@v
        #residual connection
        attention_output = q + attention_output
        # second layer norm
        attention_output = self.ln2(attention_output)
        output = self.ff2(self.relu(self.ff1(attention_output)))
        return output 

In [158]:
model_gpt = GPT(vocab_size, embed_size, block_size, num_blocks, num_heads)
model_gpt = model_gpt.to(device)
# create a PyTorch optimizer
optimizer_gpt = torch.optim.AdamW(model_gpt.parameters(), lr=learning_rate)

In [159]:
train_loader = dataloaderlite(train_data.data, block_size, batch_size, shuffle=True, tag="train")
val_loader = dataloaderlite(val_data.data, block_size, batch_size, shuffle=True, tag="val")

In [161]:
for iter in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = loss_estimation(model_gpt)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = train_loader.__next__()

    # evaluate the loss
    logits_gpt, loss_gpt = model_gpt(xb, yb)
    optimizer_gpt.zero_grad(set_to_none=True)
    loss_gpt.backward()
    optimizer_gpt.step()

step 0: train loss 4.2207, val loss 4.2222
step 200: train loss 2.4596, val loss 2.4739
step 400: train loss 2.2751, val loss 2.2951
step 600: train loss 2.1291, val loss 2.1488
step 800: train loss 1.8383, val loss 1.8673
step 1000: train loss 1.4783, val loss 1.5308
step 1200: train loss 1.1388, val loss 1.1748
step 1400: train loss 0.9344, val loss 0.9632
step 1600: train loss 0.8468, val loss 0.8788
step 1800: train loss 0.7988, val loss 0.8303
step 2000: train loss 0.7571, val loss 0.7814
step 2200: train loss 0.7285, val loss 0.7551
step 2400: train loss 0.6971, val loss 0.7278
step 2600: train loss 0.6968, val loss 0.7159
step 2800: train loss 0.6797, val loss 0.7001
step 3000: train loss 0.6655, val loss 0.6845
step 3200: train loss 0.6562, val loss 0.6682
step 3400: train loss 0.6411, val loss 0.6580
step 3600: train loss 0.6424, val loss 0.6664
step 3800: train loss 0.6531, val loss 0.6824


In [162]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_context_ids = model_gpt.generate(context, max_new_tokens=1000)
print(decode(generated_context_ids[0].tolist()))


Acqppppithise buth so but wow wand ser whor to wing wird han, bus me mers shave ta shy, the tath sor sto meant shatt sand so to wno what shin wir son myour then mus thy sat my som whattingh o that mon bund what bun bes, wart she ward the then thatth wily with wer tord thand shot must me sardont what wir,
Mas thert this shas stis start at the thas,

LAN LLAR:
Thor sove we sors munt we me sus, thy sord werd, whath wer thear bur therd,
To thy sutt ma sto wis sours

The mats shats my my hard,
Hur gurt the warts mant thers the that thester, the wald mas sang thes beastst of gruth mang mond tand by mous mans tand,
Att mentste,
Me wom bat, thand buth matt wirs stand thes she may sart berst ward hill band stast bes my wond war mast stoust will st win will, whe wors thengs sove wort som to be shand.
Sest thand but the thir bus wird than will me gret town well wh wond that whor bunth,

Mut sprand ther stise tour bert, thand mur thy the mat the thas this somert to sum mant, ses, wird thars, and 

In [147]:
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_trainable)

8385
