# Tiny Transformer
This is a self-driven study note of building a tiny transformer that predicts next word given the text.

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Part 1: Build a transformer model
### Step 1: Token Embedding + Sinusoidal Positional Encoding
- We map each word (token index) to a dense vector using an nn.Embedding.
- We add positional encoding to give the model information about word order.
- Each position `pos` in the sequence gets a deterministic embedding using sine and cosine functions of different frequencies.

In [28]:
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_seq_len=512):
        super().__init__()

        # Create a matrix of shape (max_seq_len, embed_dim)
        pe = torch.zeros(max_seq_len, embed_dim)
        # Position indices: (max_seq_len, 1)
        position = torch.arange(0, max_seq_len).unsqueeze(1)

        # Compute the denominator term for each dimension (even indices only)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))

        # Apply sine to even indices (0, 2, 4, ...)
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices (1, 3, 5, ...)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add a batch dimension: (1, max_seq_len, embed_dim)
        pe = pe.unsqueeze(0)

        # Register as a buffer to exclude from model parameters (not trainable)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: (batch_size, seq_len, embed_dim)
        seq_len = x.size(1)
        # Add the positional encoding (broadcasted over batch)
        return x + self.pe[:, :seq_len, :]

#### A common layer of token embedding + psitional encoding

In [29]:
class TokenEmbeddingWithPE(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_seq_len):
        super().__init__()
        # Token embedding layer (maps word indices to dense vectors)
        self.token_embed = nn.Embedding(vocab_size, embed_dim)

        # Sinusoidal positional encoding (not learned)
        self.pos_encoder = SinusoidalPositionalEncoding(embed_dim, max_seq_len)

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        token_embeddings = self.token_embed(x)  # -> (batch_size, seq_len, embed_dim)
        return self.pos_encoder(token_embeddings)  # Add positional info

### Step 2: Multi-Head Self-Attention layer
Multi-Head Self-Attention (MHSA) allows the model to focus on different parts of the sequence in parallel using multiple attention "heads". For each head, we compute attention using separate linear projections of the input:
  - Queries (Q), Keys (K), and Values (V) are all derived from the input `x`.
  - Attention(Q, K, V) = `softmax(QKᵀ / sqrt(d_k)) * V`
    - Q, K, V are matrices of shape (batch_size, num_heads, seq_len, head_dim)
    - d_k = head_dim, which is typically embed_dim / num_heads

After computing attention output for all heads, we:
  1. Concatenate the outputs from all heads.
  2. Apply a final linear projection to get back to the original embedding dimension.

Input/Output shape:
  - Input: (batch_size, seq_len, embed_dim)
  - Output: (batch_size, seq_len, embed_dim)

In [30]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear layers for Q, K, V projections (all heads at once)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # x shape: (batch_size, seq_len, embed_dim)
        B, T, C = x.shape

        # Project to Q, K, V
        q = self.q_proj(x)  # (B, T, C)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Split heads: (B, T, num_heads, head_dim) → (B, num_heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, num_heads, T, T)
        attn_weights = F.softmax(attn_scores, dim=-1)  # (B, num_heads, T, T)
        attn_output = torch.matmul(attn_weights, v)  # (B, num_heads, T, head_dim)

        # Concatenate heads: (B, num_heads, T, head_dim) → (B, T, C)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)

        # Final linear projection
        return self.out_proj(attn_output)  # (B, T, C)


### Step 3: Transformer block (LayerNorm + Residual + FeedForward)
A Transformer block contains two main sub-layers:
  1. **Multi-Head Self-Attention (MHSA)** with residual connection and LayerNorm.
  2. **FeedForward Network (FFN)** with residual connection and LayerNorm.

**LayerNorm** helps stabilize training and is applied **before** each sub-layer ("Pre-LN" style).

**Residual connections** (x + sublayer(x)) allow gradient flow and faster training.

**FeedForward Network** is applied position-wise and consists of:
  - Linear -> ReLU -> Linear

Shapes:
  - Input/Output: (batch_size, seq_len, embed_dim)

In [31]:
class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),  # GELU also commonly used
            nn.Linear(ff_dim, embed_dim)
        )

    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ff = FeedForward(embed_dim, ff_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Pre-norm and residual connection for MHSA
        x = x + self.attn(self.norm1(x))
        # Pre-norm and residual connection for FFN
        x = x + self.ff(self.norm2(x))
        return x

### Step 4: Full Transformer Model
This is a **decoder-only Transformer** for autoregressive next-word prediction. It consists of:
  1. **Embedding layer**: token embeddings + positional encoding
  2. **Stacked Transformer blocks**
  3. **Final linear layer**: maps hidden states to vocabulary logits

Shape:
- Input: token indices (batch_size, seq_len)  
- Output: logits (batch_size, seq_len, vocab_size)


In [32]:
class TransformerLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_seq_len, num_heads, ff_dim, num_layers):
        super().__init__()
        self.max_seq_len = max_seq_len

        # Token embedding + sinusoidal positional encoding
        self.embed = TokenEmbeddingWithPE(vocab_size, embed_dim, max_seq_len)

        # Stack of Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ])

        # Final layer normalization before output
        self.norm = nn.LayerNorm(embed_dim)

        # Output projection to vocab size for predicting next word
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """
        x: (batch_size, seq_len) - input token indices
        Returns: (batch_size, seq_len, vocab_size) - output logits for each position
        """
        # Embed tokens + add sinusoidal positional encoding
        x = self.embed(x)  # shape: (B, T, embed_dim)
        # Apply each Transformer block sequentially
        for block in self.blocks:
            x = block(x)  # shape remains (B, T, embed_dim)
        # Normalize final output
        x = self.norm(x)  # shape: (B, T, embed_dim)

        # Project to logits over vocabulary
        logits = self.head(x)  # shape: (B, T, vocab_size)

        return logits


## Part 2: Train model
### Step 1: Toy Corpus and Tokenization
We use a toy dataset with a minimal vocabulary for training

In [33]:
corpus = """
Once upon a time there was a small cat who loved to chase mice and sleep in the sun.
The cat lived in a cozy house with a kind old woman who gave it food and warm milk.
Sometimes the cat would jump onto the windowsill and watch the birds fly by.
At night, the cat curled up on a soft pillow and dreamed of magical forests.
One day, the cat met a clever fox who told stories about the world beyond the forest.
They became friends and went on small adventures together under the moonlight.
"""

In [34]:
def use_toy_corpus_data():
    # Split into words and build vocab
    words = corpus.split()
    vocab = sorted(set(words))

    # Mapping from word ↔ index
    stoi = {w: i for i, w in enumerate(vocab)}
    itos = {i: w for w, i in stoi.items()}

    # Encode full corpus as list of token indices
    data = [stoi[w] for w in words]

    return data, vocab, stoi, itos

This is an alternative dataset, tiny shakespeare, which is 1MB size and suitable for small size decoder model.

In [35]:
def use_tiny_shakespeare_data(level="char"):
    # Load the tiny Shakespeare dataset
    with open('data_tiny_shakespeare.txt', 'r', encoding='utf-8') as f:
        text = f.read()

    if level == "word":
        tokens = text.split()
    elif level == "char":
        tokens = list(text)
    else:
        raise ValueError("level must be 'word' or 'char'")

    # Create vocab
    vocab = sorted(set(tokens))
    stoi = {w: i for i, w in enumerate(vocab)}
    itos = {i: w for w, i in stoi.items()}

    # Convert words to token indices
    data = [stoi[w] for w in tokens]
    
    return data, vocab, stoi, itos

#### Batch generation function and define the sequence length
We’ll use a fixed context window (e.g., seq_len = 4) and generate training samples as:
- Input: first n tokens
- Target: next n tokens (shifted by 1)

In [36]:
def get_batch(data, batch_size, seq_len, device='cpu'):
    """
    Randomly sample batches of (input, target) pairs from the corpus.

    Returns:
        x: (batch_size, seq_len) - input token indices
        y: (batch_size, seq_len) - next-token targets
    """
    ix = torch.randint(0, len(data) - seq_len - 1, (batch_size,))
    x = torch.stack([torch.tensor(data[i:i+seq_len]) for i in ix])
    y = torch.stack([torch.tensor(data[i+1:i+1+seq_len]) for i in ix])
    return x.to(device), y.to(device)

### Step 2: Training loop
The basic training loop that includes
1. Define the hyper-parameters
2. Define the loss function and optimizer
3. Train model in a loop

In [37]:
# Hyperparameters
embed_dim = 128
ff_dim = 512
num_heads = 4
num_layers = 4
max_seq_len = 128
batch_size = 16
num_iters = 10000
eval_interval = 100
learning_rate = 1e-3

In [38]:
# Add MPS device support for MacOS M-series GPUs
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Use data
# data, vocab, stoi, itos = use_toy_corpus_data()
data, vocab, stoi, itos = use_tiny_shakespeare_data()
vocab_size = len(vocab)

# Model
model = TransformerLanguageModel(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    ff_dim=ff_dim,
    num_layers=num_layers
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# Training loop
for step in range(num_iters):
    model.train()

    x, y = get_batch(data, batch_size, max_seq_len, device=device)
    logits = model(x)  # (B, T, vocab_size)

    # Reshape for loss: flatten batch and sequence dimensions
    loss = loss_fn(logits.view(-1, vocab_size), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % eval_interval == 0:
        print(f"Step {step}, loss: {loss.item():.4f}")


Step 0, loss: 4.3224
Step 100, loss: 2.4791
Step 200, loss: 2.2092
Step 300, loss: 0.4823
Step 400, loss: 0.0959
Step 500, loss: 0.0458
Step 600, loss: 0.0303
Step 700, loss: 0.0291
Step 800, loss: 0.0225
Step 900, loss: 0.0200
Step 1000, loss: 0.0206
Step 1100, loss: 0.0166
Step 1200, loss: 0.0173
Step 1300, loss: 0.0173
Step 1400, loss: 0.0224
Step 1500, loss: 0.0159
Step 1600, loss: 0.0229
Step 1700, loss: 0.0211
Step 1800, loss: 0.0142
Step 1900, loss: 0.0207
Step 2000, loss: 0.0159
Step 2100, loss: 0.0163
Step 2200, loss: 0.0243
Step 2300, loss: 0.0135
Step 2400, loss: 0.0172
Step 2500, loss: 0.0174
Step 2600, loss: 0.0236
Step 2700, loss: 0.0271
Step 2800, loss: 0.0256
Step 2900, loss: 0.0203
Step 3000, loss: 0.0183
Step 3100, loss: 0.0205
Step 3200, loss: 0.0152
Step 3300, loss: 0.0209
Step 3400, loss: 0.0197
Step 3500, loss: 0.0230
Step 3600, loss: 0.0194
Step 3700, loss: 0.0219
Step 3800, loss: 0.0159
Step 3900, loss: 0.0242
Step 4000, loss: 0.0168
Step 4100, loss: 0.0250
Step

### Step 3: generating text
TO BE COMPLETE

In [39]:
def generate_text(model, prompt, max_new_tokens, stoi, itos, prediction_type="char",
                  method="greedy", temperature=1.2, top_k=40):
    """
    Modular text generation function supporting greedy, temperature, and top-k sampling.

    Args:
        model: Trained language model.
        prompt: String, starting text to generate from.
        max_new_tokens: Number of tokens to generate.
        stoi: Dict[str, int], word to index mapping.
        itos: Dict[int, str], index to word mapping.
        method: "greedy", "temperature", or "top-k"
        temperature: Sampling temperature for 'temperature' and 'top-k' methods.
        top_k: Integer, how many top tokens to keep for 'top-k' sampling.

    Returns:
        List of words (including start prompt + generated words)
    """
    model.eval()
    words = prompt.split()
    if prediction_type == "char":
        start_tokens = [stoi.get(ch, stoi.get("<unk>", 0)) for ch in prompt]
    else:
        start_tokens = [stoi.get(w, stoi.get("<unk>", 0)) for w in words]
    tokens = start_tokens[:]

    for _ in range(max_new_tokens):
        x = torch.tensor([tokens[-max_seq_len:]], dtype=torch.long).to(device)

        logits = model(x)  # (1, seq_len, vocab_size)
        logits_last = logits[0, -1]  # get logits for last token

        if method == "greedy":
            next_token = torch.argmax(logits_last, dim=-1).item()

        elif method == "temperature":
            scaled_logits = logits_last / temperature
            probs = F.softmax(scaled_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

        elif method == "top-k":
            # Keep only top_k logits
            scaled_logits = logits_last / temperature
            top_k_logits, top_k_indices = torch.topk(scaled_logits, k=top_k)

            # Sample from top-k softmax
            top_k_probs = F.softmax(top_k_logits, dim=-1)
            sample_idx = torch.multinomial(top_k_probs, num_samples=1).item()
            next_token = top_k_indices[sample_idx].item()

        else:
            raise ValueError(f"Unknown sampling method: {method}")

        tokens.append(next_token)

    generated = [itos[t] for t in tokens[len(start_tokens):]]
    return ' '.join(generated) if prediction_type == "word" else ''.join(generated)



In [44]:
prompt = """
ROMEO: 
"""
method = "top-k"  # Change to "greedy" or "temperature" for different methods
generated_words = generate_text(model, prompt, max_new_tokens=200, stoi=stoi, itos=itos, method=method, temperature=1.2, top_k=40)
print("Prompt:", prompt)
print("Generated:", generated_words)

Prompt: 
ROMEO: 

Generated: 
















FFFFFF
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEEEEDDYJIIOOOBEEEOOxEEEEEOxEEEEEEEEEEEEEEEEEEEEEEEMEDENES:EEEEEETEO:
My hingrild lo owredeupespedred cind.

GORWAVEce.
Yomane: hoit
