In [29]:
import math
import requests

import torch
from torch import nn
import torch.nn.functional as F

from lightning.pytorch import LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor

In [2]:
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
text = requests.get(url).text

with open("../data/tiny_shakespeare.txt", "w") as f:
    f.write(text)

print(f"Saved tiny Shakespeare dataset ({len(text)} characters)")

Saved tiny Shakespeare dataset (1115394 characters)


In [8]:
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

data_ids = encode(text)
split = int(0.9 * len(data_ids))
TRAIN_IDS, VAL_IDS = data_ids[:split], data_ids[split:]

In [20]:
seed_everything(1337)
device = "cuda" if torch.cuda.is_available() else "cpu"

# data / model hyperparams
BLOCK_SIZE = 128     # context length
N_LAYERS = 4
N_HEADS = 4
N_EMBED = 256     # must be divisible by N_HEADS
DROPOUT = 0.1

# training hyperparams
BATCH_SIZE = 64
LR = 3e-4
WEIGHT_DECAY = 0.01
MAX_STEPS = 800  # bump for better quality
WARMUP_STEPS = 50
EVAL_EVERY_N = 100
SAMPLE_LEN = 400

Seed set to 1337


In [9]:
def encode(s: str) -> torch.Tensor:
    return torch.tensor([stoi[c] for c in s], dtype=torch.long)


def decode(t: torch.Tensor) -> str:
    return "".join(itos[int(i)] for i in t)

In [27]:
def get_batch(ids: torch.Tensor, batch_size: int,
              block_size: int) -> tuple[torch.Tensor, torch.Tensor]:

    ix = torch.randint(0, len(ids) - block_size - 1, (batch_size,))
    x = torch.stack([ids[i:i+block_size] for i in ix])
    y = torch.stack([ids[i+1:i+block_size+1] for i in ix])
    return x, y

In [12]:
class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention implemented from scratch."""
    def __init__(self, n_embed: int, n_heads: int, dropout: float):
        super().__init__()
        assert n_embed % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = n_embed // n_heads

        self.q_proj = nn.Linear(n_embed, n_embed, bias=False)
        self.k_proj = nn.Linear(n_embed, n_embed, bias=False)
        self.v_proj = nn.Linear(n_embed, n_embed, bias=False)
        self.o_proj = nn.Linear(n_embed, n_embed, bias=False)

        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)

        # precompute max-length causal mask
        self.register_buffer("mask", torch.triu(torch.ones(BLOCK_SIZE, BLOCK_SIZE), diagonal=1).bool())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        H, D = self.n_heads, self.head_dim

        q = self.q_proj(x).view(B, T, H, D).transpose(1, 2)  # (B,H,T,D)
        k = self.k_proj(x).view(B, T, H, D).transpose(1, 2)
        v = self.v_proj(x).view(B, T, H, D).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(D)       # (B,H,T,T)
        att = att.masked_fill(self.mask[:T, :T], float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)

        y = att @ v                                          # (B,H,T,D)
        y = y.transpose(1, 2).contiguous().view(B, T, H*D)   # (B,T,C)
        y = self.resid_drop(self.o_proj(y))
        return y

In [14]:
class MLP(nn.Module):
    def __init__(self, n_embed: int, dropout: float):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4*n_embed),
            nn.GELU(),
            nn.Linear(4*n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x): 
        return self.net(x)

In [15]:
class Block(nn.Module):
    def __init__(self, n_embed: int, n_heads: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embed)
        self.attn = CausalSelfAttention(n_embed, n_heads, dropout)
        self.ln2 = nn.LayerNorm(n_embed)
        self.mlp = MLP(n_embed, dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

In [21]:
class TinyGPT(LightningModule):
    def __init__(
        self,
        vocab_size: int,
        n_embed: int = N_EMBED,
        n_layers: int = N_LAYERS,
        n_heads: int = N_HEADS,
        dropout: float = DROPOUT,
        block_size: int = BLOCK_SIZE,
        lr: float = LR,
        weight_decay: float = WEIGHT_DECAY,
        warmup_steps: int = WARMUP_STEPS,
        sample_len: int = SAMPLE_LEN,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.token_emb = nn.Embedding(vocab_size, n_embed)
        self.pos_emb = nn.Embedding(block_size, n_embed)  # <-- learned positional encoding
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([Block(n_embed, n_heads, dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size, bias=False)

        # weight tying
        self.lm_head.weight = self.token_emb.weight

        self.apply(self._init_weights)

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)

        # simple warmup + cosine decay scheduler tied to max_steps (set in Trainer)
        def lr_lambda(step):
            if step < self.hparams.warmup_steps:
                return max(1e-8, step / max(1, self.hparams.warmup_steps))
            # cosine to zero over remaining steps
            total = max(1, self.trainer.max_steps - self.hparams.warmup_steps)
            step2 = min(step - self.hparams.warmup_steps, total)
            return 0.5 * (1 + math.cos(math.pi * step2 / total))
        sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
        return {"optimizer": opt, "lr_scheduler": {"scheduler": sch, "interval": "step"}}

    def training_step(self, batch, batch_idx):
        x, y = batch
        _, loss = self(x, y)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        _, loss = self(x, y)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

    def on_train_end(self):
        # print a sample when training finishes
        with torch.no_grad():
            start = torch.randint(low=0, high=self.hparams.vocab_size, size=(1,1), device=self.device)
            sample = self.generate(start, max_new_tokens=self.hparams.sample_len, temperature=1.0, top_k=0)[0].cpu()
            print("\n=== SAMPLE ===\n")
            print(decode(sample))

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor = None):
        B, T = idx.shape
        tok = self.token_emb(idx)                                # (B,T,C)
        pos_ids = torch.arange(T, device=idx.device)
        pos = self.pos_emb(pos_ids)[None, :, :].expand(B, T, -1) # (B,T,C)
        x = self.drop(tok + pos)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)                                 # (B,T,V)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0, top_k: int = 0):
        self.eval()
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.hparams.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-8)
            if top_k > 0:
                v, _ = torch.topk(logits, top_k)
                thresh = v[:, [-1]]
                logits = torch.where(logits < thresh, torch.full_like(logits, -float("inf")), logits)
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx

In [22]:
class CharDataModule:
    """Tiny helper (not a LightningDataModule to keep dependencies minimal)."""
    def __init__(self, train_ids, val_ids, batch_size=BATCH_SIZE, block_size=BLOCK_SIZE):
        self.train_ids = train_ids
        self.val_ids = val_ids
        self.batch_size = batch_size
        self.block_size = block_size

    def train_dataloader(self):
        # use a generator-like iterable to avoid holding big tensors on GPU
        def collate(_):
            x, y = get_batch(self.train_ids, self.batch_size, self.block_size)
            return x, y
        return torch.utils.data.DataLoader([0]*10_000_000, batch_size=None, collate_fn=collate)

    def val_dataloader(self):
        def collate(_):
            x, y = get_batch(self.val_ids, self.batch_size, self.block_size)
            return x, y
        return torch.utils.data.DataLoader([0]*10_000, batch_size=None, collate_fn=collate)

In [23]:
dm = CharDataModule(TRAIN_IDS, VAL_IDS, BATCH_SIZE, BLOCK_SIZE)

model = TinyGPT(
    vocab_size=VOCAB_SIZE,
    n_embed=N_EMBED,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    dropout=DROPOUT,
    block_size=BLOCK_SIZE,
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    warmup_steps=WARMUP_STEPS,
    sample_len=SAMPLE_LEN,
)

callbacks = [
    LearningRateMonitor(logging_interval="step"),
    ModelCheckpoint(save_top_k=1, monitor="val_loss", mode="min", filename="tinygpt-{step:06d}-{val_loss:.3f}")
]

trainer = Trainer(
    max_steps=MAX_STEPS,
    val_check_interval=EVAL_EVERY_N,   # run validation every N steps
    enable_checkpointing=True,
    callbacks=callbacks,
    gradient_clip_val=1.0,
    log_every_n_steps=10,
    accelerator="auto",
    devices=1,
    precision="16-mixed" if torch.cuda.is_available() else "32-true",
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [30]:
trainer.fit(model, dm.train_dataloader(), dm.val_dataloader())


  | Name      | Type       | Params | Mode
------------------------------------------------
0 | token_emb | Embedding  | 16.6 K | eval
1 | pos_emb   | Embedding  | 32.8 K | eval
2 | drop      | Dropout    | 0      | eval
3 | blocks    | ModuleList | 3.2 M  | eval
4 | ln_f      | LayerNorm  | 512    | eval
5 | lm_head   | Linear     | 16.6 K | eval
------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.819    Total estimated model params size (MB)
0         Modules in train mode
70        Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
