# Mini GPT Transformer

A character-level GPT-style language model trained on a child-speech text corpus, evaluated on a held-out test set and compared against a unigram baseline.

In [None]:
import logging
import math
import os
import time
from collections import Counter
from datetime import datetime

import torch
import torch.nn as nn
from torch.nn import functional as F

## Configuration & Hyperparameters

In [None]:
# Dataset and model paths
DATASET_NAME = "input_childSpeech_trainingSet.txt"  # training corpus, read as UTF-8 text
MODEL_NAME = os.path.join("models", "model_checkpoint.pt")  # checkpoint path for save/load
LOAD_MODEL = False  # True: load MODEL_NAME and skip training; False: train and save

# Hyperparameters
batch_size = 64  # sequences per training/evaluation batch
block_size = 128  # context length: tokens per sequence window
max_iters = 5000  # number of optimizer steps
eval_interval = 500  # evaluate train/val loss every this many steps
learning_rate = 3e-4  # AdamW learning rate
device = "cuda" if torch.cuda.is_available() else "cpu"
eval_iters = 200  # batches averaged per split by estimate_loss()
n_embd = 128  # embedding width of the residual stream
n_head = 4  # attention heads per block (head size = n_embd // n_head)
n_layer = 4  # number of transformer blocks
dropout = 0.2  # dropout probability used throughout the model

# Ablation switches (saved into, and restored from, the checkpoint)
USE_BIAS_IN_ATTENTION = False  # bias on the key/query/value projections
USE_SKIP_CONNECTIONS = True  # residual connections around attention and feed-forward

# Seed torch's RNGs (batch sampling, weight init, dropout, text sampling)
seed = 1337
torch.manual_seed(seed)

print(f"Device: {device}")

## Setup Logging & Directories

In [None]:
os.makedirs("models", exist_ok=True)
os.makedirs("logs", exist_ok=True)

# Log file name combines the checkpoint stem and the dataset file name.
model_base_name = os.path.splitext(os.path.basename(MODEL_NAME))[0]
log_filename = os.path.join("logs", f"{model_base_name}_{DATASET_NAME}.log")

# Log to the file (overwritten each run) and to the console.
# force=True removes handlers already attached to the root logger (from a
# previous run of this cell, or installed by the notebook kernel). Without it,
# basicConfig silently does nothing and the new log file is never attached.
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    handlers=[
        logging.FileHandler(log_filename, mode="w", encoding="utf-8"),
        logging.StreamHandler(),
    ],
    force=True,
)

logger = logging.getLogger(__name__)
logger.info(f"Log file: {log_filename}\n")

# Wall-clock start, used for the total execution time reported at the end
start_time = time.time()

## Load & Explore Dataset

In [None]:
logger.info(f"Loading dataset: {DATASET_NAME}")

with open(DATASET_NAME, "r", encoding="utf-8") as f:
    text = f.read()

logger.info(f"Total characters in dataset: {len(text):,}")
logger.info(f"First 200 characters preview:\n{text[:200]}\n")

# Vocabulary: the sorted set of distinct characters; each becomes one token id
chars = sorted(set(text))
vocab_size = len(chars)
logger.info(f"Vocabulary size: {vocab_size}")
# repr() keeps whitespace such as "\n" and "\t" visible instead of breaking
# the log line
logger.info(f"Unique characters: {''.join(chars)!r}\n")

In [None]:
# Character frequencies. repr() shows whitespace characters (" ", "\n")
# readably in the log instead of emitting them raw.
char_counts = Counter(text)
logger.info("Top 5 Character frequencies:")
for char, count in char_counts.most_common(5):
    logger.info(f"{char!r}: {count}")

# Most used words (str.split() splits on whitespace only, so punctuation
# stays attached to words)
words = text.split()
word_counts = Counter(words)
logger.info("Most common 5 words:")
for word, count in word_counts.most_common(5):
    logger.info(f"'{word}': {count}")

unique_words = len(word_counts)
logger.info(f"Total unique words in dataset: {unique_words}\n")

## Encoding & Data Splits

In [None]:
# Character-to-integer mapping (token id = index in the sorted vocabulary)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map a string to a list of token ids (KeyError on unknown characters).

    Looks up the module-level ``stoi`` at call time, so a vocabulary restored
    from a checkpoint later is picked up automatically.
    """
    return [stoi[c] for c in s]


def decode(ids):
    """Map a sequence of token ids back to a string via the module-level ``itos``."""
    return "".join(itos[i] for i in ids)


# Train/val split: first 90% of the corpus for training, last 10% for validation
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

logger.info(f"Training set size: {len(train_data):,} characters")
logger.info(f"Validation set size: {len(val_data):,} characters")

## Optionally Load Checkpoint

In [None]:
# When LOAD_MODEL is set, the architecture hyperparameters and the vocabulary
# come from the checkpoint (overriding the values above), so the model built
# later matches the saved weights. train_data / val_data were already encoded
# with the dataset's own mapping. They are only used for training, which is
# skipped when LOAD_MODEL is set.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = None
if LOAD_MODEL:
    logger.info(f"\nLoading existing model from: {MODEL_NAME}")
    try:
        checkpoint = torch.load(MODEL_NAME, map_location=device)

        vocab_size = checkpoint["vocab_size"]
        n_embd = checkpoint["n_embd"]
        n_head = checkpoint["n_head"]
        n_layer = checkpoint["n_layer"]
        block_size = checkpoint["block_size"]
        dropout = checkpoint["dropout"]
        USE_BIAS_IN_ATTENTION = checkpoint["use_bias"]
        USE_SKIP_CONNECTIONS = checkpoint["use_skip"]
        # encode/decode look these globals up at call time, so they switch
        # to the checkpoint's vocabulary as well
        stoi = checkpoint["stoi"]
        itos = checkpoint["itos"]

        logger.info("Loaded model configuration:")
        logger.info(f"  Vocabulary size: {vocab_size}")
        logger.info(f"  Embedding dimension: {n_embd}")
        logger.info(f"  Attention heads: {n_head}")
        logger.info(f"  Layers: {n_layer}")
        logger.info(f"  Block size: {block_size}")
        logger.info(f"  Previous train loss: {checkpoint['train_loss']:.4f}")
        logger.info(f"  Previous val loss: {checkpoint['val_loss']:.4f}")

    except FileNotFoundError:
        # Failures are logged at ERROR level and re-raised to stop the run
        logger.error(f"Error: Model file '{MODEL_NAME}' not found!")
        raise
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise

## Helper Functions

In [None]:
def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    ``split == "train"`` draws from ``train_data``; any other value draws from
    ``val_data``. Returns two (batch_size, block_size) tensors on ``device``.
    For each row, ``y[:, t]`` is the token that follows ``x[:, t]`` in the data.
    """
    source = val_data if split != "train" else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s : s + block_size] for s in starts])
    targets = torch.stack([source[s + 1 : s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)


@torch.no_grad()
def estimate_loss():
    """Estimate the mean train and val loss over ``eval_iters`` random batches each.

    Runs the global ``model`` in eval mode (dropout off) without gradient
    tracking, then restores whichever train/eval mode it was in before.

    Returns:
        dict mapping "train" and "val" to 0-dim tensors holding the mean loss.
    """
    was_training = model.training
    model.eval()
    try:
        out = {}
        for split in ("train", "val"):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode
        model.train(was_training)
    return out

## Model Definition

In [None]:
class Head(nn.Module):
    """A single causal (masked) self-attention head.

    Projects the input to keys, queries and values of width ``head_size``.
    Each position attends only to itself and earlier positions, and the
    output is the attention-weighted sum of the values.
    """

    def __init__(self, head_size):
        super().__init__()
        use_bias = USE_BIAS_IN_ATTENTION
        self.key = nn.Linear(n_embd, head_size, bias=use_bias)
        self.query = nn.Linear(n_embd, head_size, bias=use_bias)
        self.value = nn.Linear(n_embd, head_size, bias=use_bias)
        # Lower-triangular ones: entry (i, j) is 1 exactly when j <= i.
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("tril", causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, T, n_embd) -> (B, T, head_size)
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        scale = keys.shape[-1] ** -0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale  # (B, T, T)
        visible = self.tril[:seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float("-inf"))
        attn = self.dropout(F.softmax(scores, dim=-1))
        values = self.value(x)
        return attn @ values


class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, then merged.

    Head outputs (each ``head_size`` wide) are concatenated on the feature
    axis, projected back to ``n_embd``, and passed through dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        heads = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))


class FeedFoward(nn.Module):
    """Position-wise feed-forward network.

    Linear (n_embd -> 4*n_embd), ReLU, Linear (4*n_embd -> n_embd), then
    dropout, applied independently at every sequence position.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class Block(nn.Module):
    """One pre-LayerNorm transformer block: self-attention, then feed-forward.

    With USE_SKIP_CONNECTIONS, each sub-layer's output is added to its input
    (a residual). Without it, each sub-layer's output replaces the stream.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        attn_out = self.sa(self.ln1(x))
        x = x + attn_out if USE_SKIP_CONNECTIONS else attn_out
        ffwd_out = self.ffwd(self.ln2(x))
        return x + ffwd_out if USE_SKIP_CONNECTIONS else ffwd_out


class GPTLanguageModel(nn.Module):
    """Decoder-only character-level transformer language model.

    Token embeddings plus learned position embeddings feed a stack of
    ``n_layer`` Blocks, a final LayerNorm, and a linear head that produces
    next-token logits over the vocabulary. Hyperparameters (vocab_size,
    n_embd, block_size, n_head, n_layer) are read from module-level globals
    at construction time.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head=n_head) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids. T must not exceed block_size,
                the number of rows in the position embedding table.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            (logits, loss). Without targets, logits is (B, T, vocab_size) and
            loss is None. With targets, logits is flattened to
            (B*T, vocab_size) and loss is the mean cross-entropy.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        # Create positions on the input's device rather than the global
        # `device`, so the model works wherever it (and idx) lives.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        x = self.blocks(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to idx.

        idx is a (B, T) tensor of context token ids. At each step only the
        last block_size tokens are fed to the model. Runs without gradient
        tracking (sampling never needs gradients). The caller controls
        train/eval mode, and therefore whether dropout is active.

        Returns the (B, T + max_new_tokens) tensor of context followed by samples.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # (B, C): logits for the last position
            probs = F.softmax(logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

## Instantiate Model

In [None]:
# Log the effective configuration (after any checkpoint overrides)
logger.info("Model Configuration:")
logger.info(f"Embedding dimension (n_embd): {n_embd}")
logger.info(f"Number of attention heads (n_head): {n_head}")
logger.info(f"Number of layers (n_layer): {n_layer}")
logger.info(f"Block size (context length): {block_size}")
logger.info(f"Batch size: {batch_size}")
logger.info(f"Dropout: {dropout}")
logger.info(f"Use bias in attention: {USE_BIAS_IN_ATTENTION}")
logger.info(f"Use skip connections: {USE_SKIP_CONNECTIONS}")
logger.info(f"Device: {device}")

# Build the model from the module-level hyperparameters and move it to `device`.
# nn.Module.to() moves the module in place and returns it, so `m` and `model`
# refer to the same object.
model = GPTLanguageModel()
m = model.to(device)

# Restore saved weights. The architecture globals were set from the same
# checkpoint, so the parameter shapes match.
if LOAD_MODEL and checkpoint is not None:
    m.load_state_dict(checkpoint["model_state_dict"])
    logger.info("Model weights loaded successfully!")

num_params = sum(p.numel() for p in m.parameters())
logger.info(f"\nTotal parameters: {num_params:,} ({num_params/1e6:.2f}M)")

## Training

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

if LOAD_MODEL:
    # Weights came from the checkpoint. Reuse its recorded losses (stored as
    # plain floats) for the summary instead of training again.
    logger.info("\nSkipping training (LOAD_MODEL is set)")
    logger.info("Proceeding directly to evaluation...\n")
    losses = (
        {"train": checkpoint["train_loss"], "val": checkpoint["val_loss"]}
        if checkpoint
        else {}
    )
else:
    logger.info("Starting training...")
    training_start_time = time.time()

    losses = {}
    # `step` rather than `iter`, so the builtin iter() is not shadowed at
    # notebook module scope
    for step in range(max_iters):
        # Periodic evaluation, plus a final one on the last step
        if step % eval_interval == 0 or step == max_iters - 1:
            losses = estimate_loss()
            elapsed = time.time() - training_start_time
            logger.info(
                f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f} "
                f"(elapsed: {elapsed:.1f}s)"
            )

        xb, yb = get_batch("train")
        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    training_end_time = time.time()
    training_time = training_end_time - training_start_time
    logger.info(
        f"Training completed in {training_time:.2f} seconds ({training_time/60:.2f} minutes)!"
    )

## Save Model

In [None]:
# Persist the weights, optimizer state, architecture hyperparameters and
# vocabulary, so a later run with LOAD_MODEL = True can rebuild the same model.
# `losses` holds the 0-dim tensors from the last estimate_loss() call, hence
# .item(). NOTE(review): this raises KeyError if max_iters == 0, because
# losses is then empty.
if not LOAD_MODEL:
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "vocab_size": vocab_size,
            "n_embd": n_embd,
            "n_head": n_head,
            "n_layer": n_layer,
            "block_size": block_size,
            "dropout": dropout,
            "use_bias": USE_BIAS_IN_ATTENTION,
            "use_skip": USE_SKIP_CONNECTIONS,
            "stoi": stoi,
            "itos": itos,
            "train_loss": losses["train"].item(),
            "val_loss": losses["val"].item(),
        },
        MODEL_NAME,
    )
    logger.info(f"Model saved to {MODEL_NAME}\n")

## Generate Text Sample

In [None]:
logger.info("Generated text sample:")

# Sample in eval mode (dropout off) and without autograd bookkeeping. The
# model is otherwise in train mode here (the nn.Module default, and what
# training leaves it in), so remember the mode and restore it afterwards.
was_training = m.training
m.eval()
with torch.no_grad():
    # Generation starts from a single context token with id 0
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    generated_text = decode(m.generate(context, max_new_tokens=500)[0].tolist())
m.train(was_training)
logger.info(generated_text)

## Evaluate on Test Sets

In [None]:
@torch.no_grad()
def evaluate_on_test_file(
    model, filename, stoi, block_size, device, batch_size, eval_iters=200
):
    """
    Load a test file and return the model's average cross-entropy loss on it.

    Characters missing from ``stoi`` are dropped before encoding, and a
    warning reports them. The loss is averaged over
    ``min(eval_iters, len(encoded) // block_size)`` random batches of
    ``batch_size`` windows, with the model in eval mode. The model's previous
    train/eval mode is restored afterwards, even if evaluation fails.

    Returns:
        The average loss as a float, or None if the file is missing, the
        encoded text has no more than ``block_size`` tokens, no batches were
        run, or any other error occurs (the error is logged).
    """
    was_training = model.training
    try:
        logger.info(f"\nEvaluating on: {filename}")
        with open(filename, "r", encoding="utf-8") as f:
            test_text = f.read()

        logger.info(f"Test file length: {len(test_text):,} characters")

        test_encoded = []
        skipped = Counter()  # out-of-vocabulary character -> occurrence count
        for c in test_text:
            if c in stoi:
                test_encoded.append(stoi[c])
            else:
                skipped[c] += 1

        if skipped:
            logger.warning(
                f"Warning: Skipped {sum(skipped.values())} occurrences of "
                f"{len(skipped)} characters not in vocabulary: {set(skipped)}"
            )

        test_data = torch.tensor(test_encoded, dtype=torch.long)
        logger.info(f"Encoded test data length: {len(test_data):,} tokens")

        # A target window needs block_size + 1 consecutive tokens
        if len(test_data) <= block_size:
            logger.warning("Warning: Test data too short for evaluation")
            return None

        model.eval()
        losses = []

        num_batches = min(eval_iters, len(test_data) // block_size)
        for _ in range(num_batches):
            ix = torch.randint(len(test_data) - block_size, (batch_size,))
            x = torch.stack([test_data[i : i + block_size] for i in ix])
            y = torch.stack([test_data[i + 1 : i + block_size + 1] for i in ix])
            x, y = x.to(device), y.to(device)

            _, loss = model(x, y)
            losses.append(loss.item())

        if not losses:
            # Only reachable when eval_iters <= 0
            return None

        avg_loss = sum(losses) / len(losses)
        logger.info(f"Average loss on {filename}: {avg_loss:.4f}")
        return avg_loss

    except FileNotFoundError:
        logger.error(f"Error: File {filename} not found")
        return None
    except Exception as e:
        # Best-effort evaluation: report the failure and let the caller continue
        logger.error(f"Error evaluating on {filename}: {e}")
        return None
    finally:
        # Restore the caller's mode on every path, including errors mid-loop
        model.train(was_training)

In [None]:
logger.info("Evaluating on Test Sets")

# Held-out child-speech test file. evaluate_on_test_file returns the average
# loss, or None if the file is missing, too short, or evaluation fails.
test_loss_child = evaluate_on_test_file(
    model, "input_childSpeech_testSet.txt", stoi, block_size, device, batch_size, eval_iters=200
)

## Baseline Comparison

In [None]:
def calculate_baseline_loss(text, stoi):
    """
    Return the unigram (character-frequency) entropy of ``text``, in nats.

    This is the cross-entropy loss of a context-free model that always
    predicts each character with its empirical frequency in ``text``. It is a
    stronger baseline than uniform random guessing, whose loss would be
    ``log(len(stoi))``. Characters not present in ``stoi`` are ignored.
    Returns 0.0 when no character of ``text`` is in the vocabulary.
    """
    # Counter preserves first-occurrence order, so the summation order (and
    # therefore the float result) matches a plain dict-based count.
    counts = Counter(c for c in text if c in stoi)
    total = sum(counts.values())

    entropy = 0.0
    for count in counts.values():
        prob = count / total
        entropy -= prob * math.log(prob)

    return entropy


logger.info("Baseline Comparison:")
# calculate_baseline_loss returns the unigram (character-frequency) entropy,
# not the uniform-guess loss. Log both, with accurate labels.
baseline_train = calculate_baseline_loss(text, stoi)
logger.info(f"Baseline loss (unigram character frequencies): {baseline_train:.4f}")
logger.info(f"Baseline loss (uniform distribution): {math.log(vocab_size):.4f}")
logger.info(f"Training loss: {losses['train']:.4f}")
logger.info(f"Validation loss: {losses['val']:.4f}")
# Compare against None so that a legitimate 0.0 loss would still be reported.
# Only the child-speech test set is evaluated in this notebook. There is no
# Shakespeare test loss to report (referencing one raised NameError).
if test_loss_child is not None:
    logger.info(f"Test loss (Child Speech): {test_loss_child:.4f}")

end_time = time.time()
total_time = end_time - start_time
logger.info(
    f"Total execution time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)"
)