<a href="https://colab.research.google.com/github/rachit2005/Large-Language-Model/blob/main/transformer_from_complete_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import regex as re

class Tokenization:
  def __init__(self):
    self.vocab_size = 0
    self.verbose = False
    self.merges = {} # (int,int) -> (int)
    self.vocab = {} # Maps token ID → bytes/string
    self.vocab_inv = {} # Maps bytes/string → token ID
    self.re_pattern = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" , flags=re.IGNORECASE)

  def pair_freq(self , tokens):
    freq = {}
    for pair in zip(tokens , tokens[1:]):
      freq[pair] = freq.get(pair , 0) + 1
    return freq

  def merge_pairs(self , tokens,max_pair , idx_replaces_with):
    new_tokens = []
    i = 0
    while i < len(tokens):
      if i < len(tokens)-1 and tokens[i] == max_pair[0] and tokens[i+1] == max_pair[1]:
        new_tokens.append(idx_replaces_with)
        i += 2
      else:
        new_tokens.append(tokens[i])
        i += 1
    return new_tokens


  def train(self , text , vocab_size , verbose=False):
    self.vocab_size = vocab_size
    self.verbose = verbose
    words = re.findall(self.re_pattern , text)
    tokens = list(''.join(words).encode("utf-8"))


    # training loop
    num_merges = self.vocab_size - 256
    nxt_idx = 256
    self.vocab = {idx : bytes([idx]) for idx in range(nxt_idx)}

    for i in range(num_merges):
      stats = self.pair_freq(tokens)
      if not stats:
        break
      max_pair = max(stats , key=stats.get)
      nxt_idx += 1
      self.merges[max_pair] = nxt_idx

      tokens = self.merge_pairs(tokens , max_pair , nxt_idx)

      if verbose:
        print(f"merging {max_pair} --> {nxt_idx}")

      self.vocab[nxt_idx] = self.vocab[max_pair[0]] + self.vocab[max_pair[1]]
    self.vocab_inv = {v:k for k,v in self.vocab.items()}

  def encode(self , text):
    words = re.findall(self.re_pattern , text)
    tokens = list(''.join(words).encode("utf-8"))

    while True:
      stats = self.pair_freq(tokens)
      if not stats:
        break

      """From all pairs in stats, select the one with lowest merge index (i.e., which was merged earliest during training)."""
      pair = min(stats , key=lambda p : self.merges.get(p , float("inf")))
      if pair not in self.merges:
        break

      idx = self.merges[pair]
      tokens = self.merge_pairs(tokens , pair , idx)
    return tokens

  def decode(self , tokens):
    return b"".join(self.vocab[token] for token in tokens).decode("utf-8" , errors="ignore")

In [None]:
!wget "https://raw.githubusercontent.com/karpathy/minbpe/refs/heads/master/tests/taylorswift.txt"

--2025-06-13 04:36:40--  https://raw.githubusercontent.com/karpathy/minbpe/refs/heads/master/tests/taylorswift.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 185768 (181K) [text/plain]
Saving to: ‘taylorswift.txt’


2025-06-13 04:36:40 (12.3 MB/s) - ‘taylorswift.txt’ saved [185768/185768]



In [None]:
tokenizer = Tokenization()
text = open("taylorswift.txt" , "r").read()

tokenizer.train(text , 50257)

In [None]:
# testing it

tokens = tokenizer.encode("hello world!!!? (안녕하세요!) lol123 😉")
print(tokens)

decoded_text = tokenizer.decode(tokens)
print(decoded_text)

[104, 532, 370, 1449, 33, 33, 33, 63, 401, 236, 149, 136, 235, 133, 149, 237, 149, 152, 236, 132, 184, 236, 154, 148, 33, 896, 108, 408, 49, 424, 32, 240, 159, 152, 137]
hello world!!!? (안녕하세요!) lol123 😉


In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttentionBlock(nn.Module):
    def __init__(self, emb_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ff = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim)
        )
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x, mask):
        attn_output, _ = self.attn(x, x, x, attn_mask=mask)
        x = self.ln1(x + attn_output)
        x = self.ln2(x + self.ff(x))
        return x

class TransformerLM(nn.Module):
    def __init__(self, vocab_size, emb_dim=128, num_heads=4, num_layers=4, block_size=128):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, emb_dim)
        self.pos_embedding = nn.Embedding(block_size, emb_dim)
        self.blocks = nn.ModuleList([
            SelfAttentionBlock(emb_dim, num_heads)
            for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(emb_dim)
        self.fc_out = nn.Linear(emb_dim, vocab_size)
        self.block_size = block_size
        self.emb_dim = emb_dim

    def forward(self, x):
        B, T = x.size()
        tok_emb = self.token_embedding(x)
        pos = torch.arange(0, T, device=x.device)
        pos_emb = self.pos_embedding(pos)[None, :, :]

        x = tok_emb + pos_emb  # (B, T, emb_dim)

        # causal mask: (T, T) → upper triangular with -inf
        mask = torch.triu(torch.full((T, T), float('-inf'), device=x.device), diagonal=1)

        for block in self.blocks:
            x = block(x, mask)

        x = self.ln_f(x)
        logits = self.fc_out(x)
        return logits


In [None]:
from torch.utils.data import Dataset , DataLoader

class BPEDatasetLM(Dataset):
    def __init__(self, texts, tokenizer, seq_len=32):
        self.seq_len = seq_len
        tokens = []
        for text in texts:
            tokens.extend(tokenizer.encode(text))
        self.data = []
        for i in range(0, len(tokens) - seq_len):
            x = tokens[i:i + seq_len]
            y = tokens[i + 1:i + 1 + seq_len]
            self.data.append((x, y))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x, y = self.data[idx]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

dataset = BPEDatasetLM(text, tokenizer, seq_len=32)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


In [None]:
x, y = next(iter(dataloader))
print(x.shape)
print(y.shape)

torch.Size([32, 32])
torch.Size([32, 32])


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [21]:
model = TransformerLM(vocab_size=tokenizer.vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters() , lr=3e-4)
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
for epoch in range(10):
    model.train()
    total_loss = 0
    for xb, yb in dataloader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} Loss: {total_loss / len(dataloader):.4f}")

In [None]:
print(logits.shape)

torch.Size([32, 32, 50257])


In [None]:
def sample(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8, top_k=40, device='cuda'):
    model.eval()
    tokens = tokenizer.encode(prompt)
    x = torch.tensor(tokens, dtype=torch.long, device=device)[None, :]  # (1, T)

    for _ in range(max_new_tokens):
        if x.size(1) > model.block_size:
            x = x[:, -model.block_size:]

        logits = model(x)[:, -1, :] / temperature  # (1, vocab)
        if top_k is not None:
            values, _ = torch.topk(logits, top_k)
            logits[logits < values[:, -1, None]] = -float('inf')

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        x = torch.cat([x, next_token], dim=1)

    return tokenizer.decode(x[0].tolist())

print(sample(model, tokenizer, "hello" , 50))