# In [None]:
# Minimal GPT-style Transformer (decoder-only) in PyTorch
# ------------------------------------------------------
# - Learned token + positional embeddings
# - Masked multi-head self-attention (causal)
# - Feedforward (MLP) + residual connections + LayerNorm (pre-LN)
# - Tied output head (weights shared with token embeddings)
# - Sampling-based text generation helper (temperature + optional top-k)
#
# NOTE: This is a compact demonstration that can be extended for real training.

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


# -------------------------
# Configuration
# -------------------------
@dataclass
class GPTConfig:
    """Hyperparameters shared by every MiniGPT component.

    Attributes:
        vocab_size: Number of distinct token ids (256 = one per byte value).
        n_layer: Number of stacked transformer blocks.
        n_head: Attention heads per block; must divide ``n_embd``.
        n_embd: Width of the residual stream / embedding dimension.
        block_size: Maximum context length (positions the model can embed).
        dropout: Probability used by every dropout layer in the model.
    """

    vocab_size: int = 256   # byte-level toy tokenizer: one id per byte value
    n_layer: int = 4        # depth (number of transformer blocks)
    n_head: int = 4         # attention heads per block
    n_embd: int = 256       # embedding / hidden width
    block_size: int = 128   # longest sequence the model accepts
    dropout: float = 0.1    # shared dropout probability


# -------------------------
# Building blocks
# -------------------------
class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention where position t sees only positions <= t.

    Input and output are (B, T, n_embd); T may not exceed ``config.block_size``
    because the precomputed causal mask is only that large.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
        width = config.n_embd
        self.n_head = config.n_head
        self.head_dim = width // config.n_head
        # Standard 1/sqrt(head_dim) factor of scaled dot-product attention.
        self.scale = 1.0 / math.sqrt(self.head_dim)

        # Bias-free projections for queries, keys, values and the merged output.
        # Creation order (q, k, v, out) is kept stable: it fixes state_dict keys
        # and the order in which parameters are initialized.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.out_proj = nn.Linear(width, width, bias=False)

        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)

        # Lower-triangular ones (1 = may attend, 0 = future position), shaped
        # [1, 1, L, L] so it broadcasts over batch and heads.
        limit = config.block_size
        lower = torch.tril(torch.ones(limit, limit))
        self.register_buffer("causal_mask", lower.view(1, 1, limit, limit))

    def _split_heads(self, t: torch.Tensor, B: int, T: int) -> torch.Tensor:
        """Reshape a (B, T, n_embd) projection into (B, n_head, T, head_dim)."""
        return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        q, k, v = (
            self._split_heads(proj(x), B, T)
            for proj in (self.q_proj, self.k_proj, self.v_proj)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, nh, T, T)

        # Future positions get -inf so softmax assigns them zero weight.
        future = self.causal_mask[:, :, :T, :T] == 0
        weights = scores.masked_fill(future, float("-inf")).softmax(dim=-1)
        weights = self.attn_drop(weights)

        # Weighted sum of values, then merge heads back to (B, T, C).
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(B, T, C)
        return self.resid_drop(self.out_proj(merged))


class MLP(nn.Module):
    """Position-wise feedforward: expand to 4*n_embd, GELU, project back, dropout."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        expanded = 4 * config.n_embd
        self.fc1 = nn.Linear(config.n_embd, expanded)
        self.fc2 = nn.Linear(expanded, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.gelu(self.fc1(x))
        return self.drop(self.fc2(activated))


class Block(nn.Module):
    """One pre-LN transformer layer: x += Attn(LN1(x)); then x += MLP(LN2(x))."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        # Registration order (ln1, attn, ln2, mlp) is kept stable for state_dict/init.
        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each sub-layer reads a normalized copy and adds its output back onto
        # the residual stream; attention runs before the MLP.
        for norm, sublayer in ((self.ln1, self.attn), (self.ln2, self.mlp)):
            x = x + sublayer(norm(x))
        return x


# -------------------------
# The mini GPT model
# -------------------------
class MiniGPT(nn.Module):
    """Decoder-only GPT language model.

    Token + learned positional embeddings feed a stack of pre-LN ``Block``s,
    followed by a final LayerNorm and a linear LM head whose weight is tied to
    the token embedding matrix.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)

        # Language modeling head (tied with token embeddings): both modules now
        # share a single (vocab_size, n_embd) Parameter.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # weight tying

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Init Linear/Embedding weights ~ N(0, 0.02), Linear biases to 0, LayerNorm to identity."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has no weight/bias to initialize.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """
        idx: (B, T) int token ids, T <= config.block_size
        targets: (B, T) next-token ids for LM loss (optional)
        returns: logits (B, T, vocab_size), and the mean cross-entropy loss
                 (None when targets is not given)
        raises: ValueError if T exceeds config.block_size
        """
        B, T = idx.shape
        if T > self.config.block_size:
            raise ValueError(f"Sequence length {T} > block_size {self.config.block_size}")

        # embeddings; positions are created on the same device as the input ids
        tok = self.tok_emb(idx)                                  # (B, T, C)
        pos = self.pos_emb(torch.arange(T, device=idx.device))   # (T, C)
        x = self.drop(tok + pos.unsqueeze(0))                    # (B, T, C)

        # transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab)

        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. slices, also work
            loss = F.cross_entropy(logits.reshape(B * T, -1), targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Autoregressive sampling. idx is (B, T) with existing context.

        Each step feeds at most the last block_size tokens, divides the final
        position's logits by ``temperature`` (floored at 1e-6, so values near 0
        approach greedy decoding), optionally keeps only the ``top_k`` highest
        logits (clamped to the vocabulary size), and samples one token per row.
        Does not change train/eval mode; call ``model.eval()`` first to disable
        dropout.

        returns: (B, T + max_new_tokens) token ids
        raises: ValueError if top_k is given and < 1
        """
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be >= 1 or None, got {top_k}")

        for _ in range(max_new_tokens):
            # Crop to the last block_size tokens
            idx_cond = idx[:, -self.config.block_size:]

            logits, _ = self(idx_cond)                           # (B, T, vocab)
            logits = logits[:, -1, :] / max(1e-6, temperature)   # (B, vocab)

            if top_k is not None:
                # Top-k filtering; torch.topk errors if k exceeds the vocab size.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k=k, dim=-1)
                thresh = v[:, [-1]]                              # (B, 1) k-th largest
                logits = logits.masked_fill(logits < thresh, float("-inf"))

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat([idx, next_token], dim=1)             # (B, T+1)
        return idx


# -------------------------
# Tiny demo (toy byte-level)
# -------------------------
if __name__ == "__main__":
    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Toy dataset (byte-level over a short string, for shape-check & sanity)
    text = (
        "To build a tiny Transformer, we need embeddings, attention, feedforward layers, "
        "residuals, and layer norms. This demo is small but faithful to the core ideas."
    )
    data = torch.tensor(list(text.encode("utf-8")), dtype=torch.long)

    # Train/val split. With this short text the val split (last 10%) holds far
    # fewer than block_size + 1 tokens, so get_batch shrinks the window for it.
    n = int(0.9 * len(data))
    train_data, val_data = data[:n], data[n:]

    cfg = GPTConfig(
        vocab_size=256,  # byte-level
        n_layer=4,
        n_head=4,
        n_embd=256,
        block_size=64,
        dropout=0.1,
    )
    model = MiniGPT(cfg).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    def get_batch(split: str, batch_size: int = 32):
        """Sample ``batch_size`` random (input, next-token target) windows from a split.

        The window length is block_size, or len(split) - 1 when the split is too
        short for a full window (MiniGPT accepts any T <= block_size).
        """
        source = train_data if split == "train" else val_data
        ctx = min(cfg.block_size, len(source) - 1)
        if ctx < 1:
            raise ValueError(f"split {split!r} has only {len(source)} tokens; need at least 2")
        # randint's high is exclusive: the largest start i satisfies
        # i + ctx + 1 == len(source), so the final token is usable as a target.
        ix = torch.randint(0, len(source) - ctx, (batch_size,))
        x = torch.stack([source[i : i + ctx] for i in ix])
        y = torch.stack([source[i + 1 : i + ctx + 1] for i in ix])
        return x.to(device), y.to(device)

    # A few quick training steps (for demonstration only)
    model.train()
    for step in range(200):  # increase for better learning
        xb, yb = get_batch("train")
        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if (step + 1) % 50 == 0:
            # Measure val loss with dropout disabled, then resume training mode.
            model.eval()
            with torch.no_grad():
                xval, yval = get_batch("val")
                _, vloss = model(xval, yval)
            model.train()
            print(f"step {step+1:4d} | train loss {loss.item():.4f} | val loss {vloss.item():.4f}")

    # Generate a few bytes of text from a short prompt
    model.eval()
    prompt = b"Transformers are"
    idx0 = torch.tensor([list(prompt)], dtype=torch.long, device=device)
    out = model.generate(idx0, max_new_tokens=100, temperature=0.9, top_k=50)
    generated = bytes(out[0].tolist()).decode("utf-8", errors="ignore")
    print("\n--- Generated sample ---")
    print(generated)
