<a href="https://colab.research.google.com/github/pdangi-web/Mini_transformer/blob/main/Mini_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
#@title Install dependencies
!pip -q install torch torchvision pyarrow datasets

import torch
import torch.nn as nn
import torch.nn.functional as F
import pyarrow.parquet as pq
import math
import time
import os
import re
from collections import Counter
from typing import List, Tuple, Dict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [7]:
#@title Upload parquet files (train/valid/test)
from google.colab import files

print("Upload train parquet:")
uploaded = files.upload()

print("Upload validation parquet:")
uploaded2 = files.upload()

print("Upload test parquet:")
uploaded3 = files.upload()


Upload train parquet:


Saving train-00000-of-00001.parquet to train-00000-of-00001 (1).parquet
Upload validation parquet:


Saving validation-00000-of-00001.parquet to validation-00000-of-00001 (1).parquet
Upload test parquet:


Saving test-00000-of-00001.parquet to test-00000-of-00001 (1).parquet


In [8]:
#@title Load parquet files and extract text column

def load_parquet_text(path):
    table = pq.read_table(path)
    # assume text column is named "text" or similar
    for col in table.column_names:
        if col.lower() in ["text", "content", "article"]:
            return "\n".join(table[col].to_pylist())
    # fallback: take first column
    return "\n".join(table[table.column_names[0]].to_pylist())

train_text = load_parquet_text("train-00000-of-00001.parquet")
valid_text = load_parquet_text("validation-00000-of-00001.parquet")
test_text  = load_parquet_text("test-00000-of-00001.parquet")

print("Train length:", len(train_text))
print("Valid length:", len(valid_text))
print("Test length:", len(test_text))


Train length: 10929707
Valid length: 1145909
Test length: 1289979


In [9]:
#@title BPE Tokenizer from scratch

def basic_clean(text):
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.strip()

class BPETokenizer:
    def __init__(self, vocab_size=5000):
        self.vocab_size = vocab_size
        self.merges = []
        self.token_to_id = {}
        self.id_to_token = {}

    def train(self, corpus_text: str, max_merges=3000):
        corpus_text = basic_clean(corpus_text)
        words = corpus_text.split()

        vocab = Counter()
        for w in words:
            vocab[tuple(w) + ("</w>",)] += 1

        for i in range(max_merges):
            pairs = Counter()
            for word, freq in vocab.items():
                for j in range(len(word)-1):
                    pairs[(word[j], word[j+1])] += freq

            if not pairs:
                break

            best_pair = pairs.most_common(1)[0][0]
            self.merges.append(best_pair)

            new_vocab = Counter()
            for word, freq in vocab.items():
                j = 0
                new_word = []
                while j < len(word):
                    if j < len(word)-1 and (word[j], word[j+1]) == best_pair:
                        new_word.append(word[j] + word[j+1])
                        j += 2
                    else:
                        new_word.append(word[j])
                        j += 1
                new_vocab[tuple(new_word)] += freq
            vocab = new_vocab

            if len(vocab) > self.vocab_size:
                break

        tokens = set()
        for word in vocab:
            for t in word:
                tokens.add(t)

        token_list = sorted(list(tokens))
        self.token_to_id = {tok: i for i, tok in enumerate(token_list)}
        self.id_to_token = {i: tok for tok, i in self.token_to_id.items()}

    def encode(self, text: str):
        text = basic_clean(text)
        words = [tuple(w) + ("</w>",) for w in text.split()]
        for pair in self.merges:
            new_words = []
            for word in words:
                new_word = []
                j = 0
                while j < len(word):
                    if j < len(word)-1 and (word[j], word[j+1]) == pair:
                        new_word.append(word[j] + word[j+1])
                        j += 2
                    else:
                        new_word.append(word[j])
                        j += 1
                new_words.append(tuple(new_word))
            words = new_words

        ids = []
        for w in words:
            for t in w:
                ids.append(self.token_to_id.get(t, 0))
        return ids

    def decode(self, ids: List[int]):
        tokens = [self.id_to_token[i] for i in ids]
        text = []
        word = ""
        for t in tokens:
            if t.endswith("</w>"):
                text.append(t.replace("</w>", ""))
                word = ""
            else:
                word += t
        return " ".join(text)


In [10]:
#@title Train BPE tokenizer
tokenizer = BPETokenizer(vocab_size=5000)
tokenizer.train(train_text, max_merges=3000)

vocab_size = len(tokenizer.token_to_id)
vocab_size


940

In [11]:
#@title Encode dataset to token IDs

train_ids = torch.tensor(tokenizer.encode(train_text), dtype=torch.long)
valid_ids = torch.tensor(tokenizer.encode(valid_text), dtype=torch.long)
test_ids  = torch.tensor(tokenizer.encode(test_text),  dtype=torch.long)

len(train_ids), len(valid_ids), len(test_ids)


(10529044, 1103541, 1242403)

In [12]:
#@title Mini Transformer (decoder-only)

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2,0,3,1,4)
        q, k, v = qkv

        attn = (q @ k.transpose(-2,-1)) * self.scale
        mask = torch.triu(torch.ones(T,T,device=x.device), diagonal=1).bool()
        attn = attn.masked_fill(mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        out = attn @ v

        out = out.transpose(1,2).contiguous().view(B,T,C)
        return self.out(out)


class FeedForward(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, 4*embed_dim),
            nn.GELU(),
            nn.Linear(4*embed_dim, embed_dim)
        )
    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = FeedForward(embed_dim)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class MiniTransformer(nn.Module):
    def __init__(self, vocab_size, block_size=256, embed_dim=256, n_layers=6, n_heads=4):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(block_size, embed_dim)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)
        self.block_size = block_size

    def forward(self, idx):
        B, T = idx.shape
        tok = self.token_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = tok + pos

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=100):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits = self(idx_cond)
            probs = F.softmax(logits[:, -1], dim=-1)
            next_id = torch.multinomial(probs, 1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx


In [13]:
#@title Training loop with checkpoints

model = MiniTransformer(vocab_size=vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

block_size = 256
batch_size = 32
max_steps = 2000

def get_batch(data):
    ix = torch.randint(len(data)-block_size-1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+1+block_size] for i in ix]).to(device)
    return x, y

save_path = "transformer_checkpoint.pth"

model.train()
for step in range(1, max_steps+1):
    xb, yb = get_batch(train_ids)
    logits = model(xb)
    loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"[{step}] loss = {loss.item():.4f}")
        torch.save(model.state_dict(), save_path)
        print("Saved checkpoint:", save_path)

print("Training complete.")


[100] loss = 2.5443
Saved checkpoint: transformer_checkpoint.pth
[200] loss = 2.4756
Saved checkpoint: transformer_checkpoint.pth
[300] loss = 2.4644
Saved checkpoint: transformer_checkpoint.pth
[400] loss = 2.4004
Saved checkpoint: transformer_checkpoint.pth
[500] loss = 2.3712
Saved checkpoint: transformer_checkpoint.pth
[600] loss = 2.3152
Saved checkpoint: transformer_checkpoint.pth
[700] loss = 2.2902
Saved checkpoint: transformer_checkpoint.pth
[800] loss = 2.1567
Saved checkpoint: transformer_checkpoint.pth
[900] loss = 2.0330
Saved checkpoint: transformer_checkpoint.pth
[1000] loss = 1.9905
Saved checkpoint: transformer_checkpoint.pth
[1100] loss = 1.9035
Saved checkpoint: transformer_checkpoint.pth
[1200] loss = 1.8533
Saved checkpoint: transformer_checkpoint.pth
[1300] loss = 1.7802
Saved checkpoint: transformer_checkpoint.pth
[1400] loss = 1.7679
Saved checkpoint: transformer_checkpoint.pth
[1500] loss = 1.7122
Saved checkpoint: transformer_checkpoint.pth
[1600] loss = 1.709

In [14]:
#@title Generate text from the model

model.eval()
prompt = "the history of machine learning"
ids = tokenizer.encode(prompt)
x = torch.tensor([ids], dtype=torch.long).to(device)

out = model.generate(x, max_new_tokens=150)
print(tokenizer.decode(out[0].tolist()))


e   e    e      e     e e       e      
