In [1]:
# Transformer From Scratch — Full Pipeline (Toy Dataset)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader
import random

# --- Building Blocks ---

class ScaledDotProductAttention(nn.Module):
    """Compute softmax(Q K^T / sqrt(dk)) V with an optional visibility mask.

    Args:
        dk: Key dimension used to scale the raw scores by ``1 / sqrt(dk)``.
    """

    def __init__(self, dk):
        super().__init__()
        self.dk = dk

    def forward(self, Q, K, V, mask=None):
        """Attend from queries to keys and return the weighted values.

        Args:
            Q: Query tensor; its last two dims are (query_len, dk).
            K: Key tensor; its last two dims are (key_len, dk).
            V: Value tensor; its second-to-last dim is key_len.
            mask: Optional tensor broadcastable to the score tensor. Positions
                where ``mask == 0`` are hidden from the softmax.

        Returns:
            Tuple ``(output, attn)``: the attention-weighted values and the
            attention weights (each row sums to 1 over the key axis).
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dk)
        if mask is not None:
            # Fill with the most negative finite value instead of -inf.
            # With -inf, a row whose keys are all masked turns into NaN after
            # softmax and poisons every later layer. With a finite fill such a
            # row degrades to uniform weights, while partially masked rows
            # still give exactly zero weight to the hidden positions.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)
        return output, attn

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``ScaledDotProductAttention``.

    Inputs are linearly projected, split into ``num_heads`` heads of width
    ``d_model // num_heads``, attended independently, re-joined, and passed
    through a final output projection.

    Args:
        d_model: Feature width of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Fail at construction time: otherwise the mismatch only surfaces as
        # an opaque shape error from ``view`` during the first forward pass.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.dk = d_model // num_heads
        self.num_heads = num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attn = ScaledDotProductAttention(self.dk)

    def forward(self, Q, K, V, mask=None):
        """Run multi-head attention.

        Args:
            Q: Queries of shape (B, tgt_len, d_model).
            K: Keys of shape (B, src_len, d_model).
            V: Values; reshaped with the same src_len as ``K``.
            mask: Optional mask forwarded unchanged to the inner attention;
                it is applied to scores of shape
                (B, num_heads, tgt_len, src_len).

        Returns:
            Tensor of shape (B, tgt_len, d_model).
        """
        B, tgt_len, _ = Q.size()
        src_len = K.size(1)

        # (B, len, d_model) -> (B, num_heads, len, dk)
        Q = self.W_q(Q).view(B, tgt_len, self.num_heads, self.dk).transpose(1, 2)
        K = self.W_k(K).view(B, src_len, self.num_heads, self.dk).transpose(1, 2)
        V = self.W_v(V).view(B, src_len, self.num_heads, self.dk).transpose(1, 2)

        # The attention weights are not needed here, only the mixed values.
        output, _ = self.attn(Q, K, V, mask)
        # (B, num_heads, tgt_len, dk) -> (B, tgt_len, d_model); the transpose
        # leaves the tensor non-contiguous, so make it contiguous before view.
        output = output.transpose(1, 2).contiguous().view(B, tgt_len, self.num_heads * self.dk)
        return self.W_o(output)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a sequence of embeddings.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``. The table is precomputed once for ``max_len``
    positions and stored as a non-trainable buffer named ``pe``.

    Args:
        d_model: Feature width of the embeddings (even or odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column,
        # so trim the angles to fit; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # leading batch dim for broadcasting
        # A buffer follows the module across devices and into the state dict
        # without being treated as a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).
        """
        return x + self.pe[:, :x.size(1), :]

class FeedForward(nn.Module):
    """Two-layer MLP: ``d_model -> d_ff -> d_model`` with a ReLU in between.

    Args:
        d_model: Input and output feature width.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.linear2 = nn.Linear(in_features=d_ff, out_features=d_model)

    def forward(self, x):
        """Apply the expand / ReLU / project sequence to the last dim of ``x``."""
        hidden = self.linear1(x)
        activated = F.relu(hidden)
        return self.linear2(activated)

class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then feed-forward.

    Each sub-layer output goes through dropout, is added back to its input,
    and the sum is layer-normalised.

    Args:
        d_model: Feature width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Transform ``x``; ``mask`` is handed to the self-attention call."""
        # Sub-layer 1: self-attention (queries, keys and values are all x).
        x = self.norm1(x + self.dropout(self.attn(x, x, x, mask)))
        # Sub-layer 2: feed-forward.
        return self.norm2(x + self.dropout(self.ff(x)))

class DecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder attention, feed-forward.

    Each sub-layer output goes through dropout, is added back to its input,
    and the sum is layer-normalised.

    Args:
        d_model: Feature width.
        num_heads: Number of attention heads in both attention sub-layers.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.enc_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, tgt_mask=None, src_mask=None):
        """Transform decoder states ``x`` using the encoder output.

        Args:
            x: Decoder-side input states.
            enc_out: Encoder output, used as keys and values in the second
                attention sub-layer.
            tgt_mask: Mask passed to the self-attention sub-layer.
            src_mask: Mask passed to the encoder-attention sub-layer.
        """
        # Sub-layer 1: self-attention over the decoder states.
        self_attended = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attended))

        # Sub-layer 2: queries from the decoder, keys/values from the encoder.
        cross_attended = self.enc_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(x + self.dropout(cross_attended))

        # Sub-layer 3: feed-forward.
        transformed = self.ff(x)
        x = self.norm3(x + self.dropout(transformed))
        return x

class Encoder(nn.Module):
    """Token embedding + positional encoding followed by ``N`` encoder layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Feature width.
        N: Number of stacked ``EncoderLayer`` blocks.
        num_heads: Attention heads per layer.
        d_ff: Feed-forward hidden width per layer.
        max_len: Number of positions precomputed by the positional encoding.
        dropout: Dropout probability for the embeddings and every layer.
    """

    def __init__(self, vocab_size, d_model, N, num_heads, d_ff, max_len=512, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        blocks = [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(N)]
        self.layers = nn.ModuleList(blocks)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, mask=None):
        """Encode token ids ``src``; ``mask`` is passed to every layer."""
        hidden = self.embed(src)
        hidden = self.pos_enc(hidden)
        hidden = self.dropout(hidden)
        for block in self.layers:
            hidden = block(hidden, mask)
        return hidden

class Decoder(nn.Module):
    """Token embedding + positional encoding followed by ``N`` decoder layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Feature width.
        N: Number of stacked ``DecoderLayer`` blocks.
        num_heads: Attention heads per layer.
        d_ff: Feed-forward hidden width per layer.
        max_len: Number of positions precomputed by the positional encoding.
        dropout: Dropout probability for the embeddings and every layer.
    """

    def __init__(self, vocab_size, d_model, N, num_heads, d_ff, max_len=512, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        blocks = [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(N)]
        self.layers = nn.ModuleList(blocks)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, enc_out, tgt_mask=None, src_mask=None):
        """Decode token ids ``tgt`` against the encoder output ``enc_out``.

        Both masks are forwarded unchanged to every decoder layer.
        """
        hidden = self.embed(tgt)
        hidden = self.pos_enc(hidden)
        hidden = self.dropout(hidden)
        for block in self.layers:
            hidden = block(hidden, enc_out, tgt_mask, src_mask)
        return hidden

class Transformer(nn.Module):
    """Encoder-decoder model that maps source token ids to target-vocab logits.

    Args:
        src_vocab_size: Size of the source vocabulary.
        tgt_vocab_size: Size of the target vocabulary (also the logit width).
        d_model: Feature width shared by encoder and decoder.
        N: Number of layers in each of the encoder and decoder.
        num_heads: Attention heads per layer.
        d_ff: Feed-forward hidden width per layer.
        max_len: Number of positions precomputed by the positional encodings.
        dropout: Dropout probability forwarded to the encoder and decoder.
            The default of 0.1 equals the value they used implicitly before
            this parameter existed, so existing callers are unaffected.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=128, N=2, num_heads=4, d_ff=512,
                 max_len=100, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, d_model, N, num_heads, d_ff, max_len, dropout)
        self.decoder = Decoder(tgt_vocab_size, d_model, N, num_heads, d_ff, max_len, dropout)
        self.out_proj = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Return logits over the target vocabulary for each target position.

        Args:
            src: Source token ids.
            tgt: Target token ids fed to the decoder.
            src_mask: Mask used by encoder self-attention and by the decoder's
                attention over the encoder output.
            tgt_mask: Mask used by decoder self-attention.
        """
        enc_out = self.encoder(src, src_mask)
        # Note the argument order: the decoder takes tgt_mask before src_mask.
        dec_out = self.decoder(tgt, enc_out, tgt_mask, src_mask)
        return self.out_proj(dec_out)

# --- Utils ---

def generate_subsequent_mask(size):
    """Build a lower-triangular mask of shape (1, 1, size, size).

    Entry ``[0, 0, i, j]`` is 1 when ``j <= i`` and 0 otherwise, so position
    ``i`` may only look at itself and earlier positions.
    """
    lower_triangular = torch.ones(size, size).tril()
    # Two leading singleton dims so the mask broadcasts over batch and heads.
    return lower_triangular[None, None]

# --- Toy Dataset ---

# Token -> id table. Ids are assigned by position in the list below, so the
# order is significant: special tokens first, then words, then emoji targets.
word2idx = {
    token: index
    for index, token in enumerate([
        "<pad>", "<sos>", "<eos>",
        "i", "am", "happy", "sad", "you",
        "are", "amazing", "love", "pizza", "angry", "good", "night",
        "😊", "😢", "💯", "🍕", "😡", "🌙",
    ])
}
# Reverse lookup: id -> token.
idx2word = {index: token for token, index in word2idx.items()}
vocab_size = len(word2idx)

# (source sentence, target emoji) pairs; sources and targets share one vocabulary.
toy_data = [
    ("i am happy", "😊"),
    ("i am sad", "😢"),
    ("you are amazing", "💯"),
    ("i love pizza", "🍕"),
    ("i am angry", "😡"),
    ("good night", "🌙"),
]

def encode(sentence):
    """Convert a whitespace-separated sentence into a list of token ids.

    The result is wrapped in ``<sos>`` ... ``<eos>``. A word that is missing
    from ``word2idx`` raises ``KeyError``.
    """
    token_ids = [word2idx["<sos>"]]
    token_ids.extend(word2idx[token] for token in sentence.split())
    token_ids.append(word2idx["<eos>"])
    return token_ids

# Each example is a (source_ids, target_ids) pair of id lists.
dataset = [tuple(encode(text) for text in pair) for pair in toy_data]

def pad_sequence(seq, max_len):
    """Return a new list: ``seq`` right-padded with ``<pad>`` ids to ``max_len``.

    If ``seq`` already has ``max_len`` or more items, an unpadded copy is
    returned (nothing is truncated).
    """
    missing = max_len - len(seq)
    filler = [word2idx["<pad>"]] * missing  # empty when missing <= 0
    return seq + filler

def collate_batch(batch):
    """Turn a list of (src_ids, tgt_ids) pairs into two padded tensors.

    Sources and targets are each padded to the longest sequence on their own
    side of the batch.

    Returns:
        Tuple ``(src_batch, tgt_batch)`` of 2-D tensors, one row per example.
    """
    sources, targets = zip(*batch)

    def _to_padded_tensor(seqs):
        # Pad every sequence to the longest one in this group.
        width = max(map(len, seqs))
        return torch.tensor([pad_sequence(ids, width) for ids in seqs])

    return _to_padded_tensor(sources), _to_padded_tensor(targets)

# --- Train Loop ---

# Model, optimiser and loss setup. The loss ignores <pad> targets.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_batch)
model = Transformer(src_vocab_size=vocab_size, tgt_vocab_size=vocab_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=word2idx["<pad>"])

for epoch in range(20):
    model.train()
    total_loss = 0  # sum of per-batch losses for this epoch
    for src, tgt in dataloader:
        src, tgt = src.to(device), tgt.to(device)
        # Teacher forcing: the decoder sees tokens [0, T-1) and must predict
        # tokens [1, T).
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        # Source sentences differ in length, so shorter ones are right-padded
        # within a batch. Hide those <pad> positions from encoder
        # self-attention and decoder cross-attention.
        # Shape (B, 1, 1, src_len) broadcasts over heads and query positions.
        src_mask = (src != word2idx["<pad>"]).unsqueeze(1).unsqueeze(2)
        # The causal mask alone is enough on the target side: padding is on
        # the right, so a real token never has a <pad> at or before its own
        # position.
        tgt_mask = generate_subsequent_mask(tgt_input.size(1)).to(device)
        logits = model(src, tgt_input, src_mask=src_mask, tgt_mask=tgt_mask)
        # Flatten (B, T, vocab) and (B, T) so the loss sees one row per token.
        loss = criterion(logits.view(-1, logits.size(-1)), tgt_output.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1} | Loss: {total_loss:.4f}")


Epoch 1 | Loss: 8.3728
Epoch 2 | Loss: 5.1485
Epoch 3 | Loss: 4.2367
Epoch 4 | Loss: 3.6103
Epoch 5 | Loss: 2.6878
Epoch 6 | Loss: 2.0315
Epoch 7 | Loss: 1.6475
Epoch 8 | Loss: 1.3120
Epoch 9 | Loss: 0.8595
Epoch 10 | Loss: 0.5693
Epoch 11 | Loss: 0.4440
Epoch 12 | Loss: 0.2120
Epoch 13 | Loss: 0.2166
Epoch 14 | Loss: 0.1516
Epoch 15 | Loss: 0.1431
Epoch 16 | Loss: 0.1253
Epoch 17 | Loss: 0.1062
Epoch 18 | Loss: 0.0792
Epoch 19 | Loss: 0.0738
Epoch 20 | Loss: 0.0634
