### Imports

In [1]:
import torch
import math
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


### Data

In [2]:
dataset = load_dataset("alpindale/light-novels")

In [3]:
text_dataset = "\n".join(dataset["train"]["text"])

### Tokenizer

In [4]:
# char based tokenization and not BPE (for simplicity)
vocab = list(set(text_dataset))
encoder_map = {v: i for i, v in enumerate(vocab)}
decoder_map = {i: v for i, v in enumerate(vocab)}

encode = lambda sentence: [encoder_map[character] for character in sentence]
decode = lambda lofints: "".join([decoder_map[i] for i in lofints])

### Model

In [5]:
embedding = torch.randn((100, 32), dtype=torch.float32)

In [6]:
embedding[5]

tensor([-0.7096, -0.1055,  1.2257,  0.1381,  0.0286, -0.7578, -0.1929,  0.2747,
         0.2825, -0.5817,  0.2672,  0.7243, -0.2106, -1.9916,  1.7362,  1.8098,
        -1.4643,  0.6058,  1.2720,  0.0305, -0.0819, -1.4450,  0.6982, -1.7524,
        -0.3701, -0.7441,  1.8367, -0.9371, -0.1402,  0.6811, -0.8513,  1.9684])

In [175]:
class BasicTransformerNetwork:

    def __init__(self, vocab_size, model_dim, nb_heads):
        
        # parameters and activations tracker
        self.parameters = {"nb_heads": nb_heads}
        self.cache = {"dropout_p": 0.2, "model_dim": model_dim, "nb_heads": nb_heads}
        self.activations = {}

        # model vector sizes
        self.vocab_size = vocab_size
        self.model_dim = model_dim
        self.nb_heads = nb_heads

        # layers
        self.attn_projs = {}
        attn_size = model_dim // nb_heads
        for n in range(self.nb_heads):
            self.attn_projs[f"k_proj_{n}"] = self.init_weights((model_dim, attn_size)) #torch.randn((model_dim, attn_size), dtype=torch.float32) * 0.01
            self.attn_projs[f"q_proj_{n}"] = self.init_weights((model_dim, attn_size)) #torch.randn((model_dim, attn_size), dtype=torch.float32) * 0.01
            self.attn_projs[f"v_proj_{n}"] = self.init_weights((model_dim, attn_size)) #torch.randn((model_dim, attn_size), dtype=torch.float32) * 0.01
        self.embedding = torch.randn((vocab_size, model_dim), dtype=torch.float32) * 0.01
        self.parameters["embedding"] = self.embedding
        self.attn_proj_back = self.init_weights((attn_size * nb_heads, model_dim)) #torch.randn((attn_size * nb_heads, model_dim), dtype=torch.float32)
        self.ff_layer = self.init_weights((model_dim, model_dim)) #torch.randn((model_dim, model_dim), dtype=torch.float32) * 0.01
        self.parameters["W1"] = self.ff_layer
        self.ff_layer2 = self.init_weights((model_dim, model_dim)) #torch.randn((model_dim, model_dim), dtype=torch.float32) * 0.01
        self.parameters["W2"] = self.ff_layer2
        self.ff_bias = torch.randn((1, model_dim), dtype=torch.float32)
        self.parameters["b1"] = self.ff_bias
        self.ff_bias2 = torch.randn((1, model_dim), dtype=torch.float32)
        self.parameters["b2"] = self.ff_bias2
        self.output = torch.randn((model_dim, vocab_size), dtype=torch.float32)

    def positional_encoding(self, x_pos):
        """sinus non-trainable positional encoding to add with input"""
        batch_size, seq_len = x_pos.shape
        div_term = torch.exp(
            torch.arange(0, self.model_dim, 2, dtype=torch.float32) * (-torch.log(torch.tensor(10000.0)) / self.model_dim)
        )

        pe = torch.zeros(batch_size, seq_len, self.model_dim, dtype=torch.float32)

        pe[:, :, 0::2] = torch.sin(x_pos.unsqueeze(-1) * div_term)  # sin for even indices
        pe[:, :, 1::2] = torch.cos(x_pos.unsqueeze(-1) * div_term)  # cos for odd indices
        return pe * 0.005
    
    def residual_connexion(self, current_layer, dragged_layer):
        """generic residual connexion operation"""
        return current_layer + dragged_layer

    def input_block(self, x, x_pos):
        """embedding input and positional encoding input"""
        input_embed = self.embedding[x]
        position_embed = self.positional_encoding(x_pos)
        self.position_embed_norm = torch.norm(position_embed)
        # print(f"Positional Encoding Magnitude: {torch.norm(position_embed)}")
        total_input = input_embed + position_embed
        self.cache["input_block_input"] = total_input
        total_input, total_input_mean, total_input_std = self.layernorm(total_input)
        self.cache["input_block_normalized_input"] = total_input
        self.cache["input_block_mean"] = total_input_mean
        self.cache["input_block_std"] = total_input_std
        total_input = self.dropout(total_input, 0.2)
        return total_input
    
    def attention_block(self, x):
        """full attention block, batch-compatible"""
        all_attn_outputs = []
        batch_size, seq_len, model_dim = x.shape
        head_size = model_dim // self.nb_heads

        self.cache["attn_X"] = x

        # Create lower triangular mask once and broadcast across batch
        mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.float32))
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [batch_size, seq_len, seq_len]

        for n in range(self.nb_heads):
            # Project input to K, Q, V for this head
            K = torch.einsum("bsm,mh->bsh", x, self.attn_projs[f"k_proj_{n}"])
            Q = torch.einsum("bsm,mh->bsh", x, self.attn_projs[f"q_proj_{n}"])
            V = torch.einsum("bsm,mh->bsh", x, self.attn_projs[f"v_proj_{n}"])

            self.cache[f"K{n}"] = K
            self.cache[f"Q{n}"] = Q
            self.cache[f"V{n}"] = V

            # Compute attention scores: Q @ K.T (batched attention matmul)
            pre_attention = torch.einsum("bqh,bkh->bqk", Q, K)  # [batch, seq, seq]

            # Scale
            scale = pre_attention / torch.sqrt(torch.tensor(head_size, dtype=torch.float32))

            # Apply mask and handle -inf for softmax
            scale = scale.masked_fill(mask == 0, float("-inf"))

            # Softmax over the last dimension (attention scores over the sequence)
            attn_softm = torch.softmax(scale, dim=-1)

            attn_softm = self.dropout(attn_softm, 0.2)
            self.cache[f"attn_probs{n}"] = attn_softm

            # Compute attention output (weighted sum of values)
            attn_output = torch.einsum("bqk,bkh->bqh", attn_softm, V)  # [batch, seq, head_size]
            all_attn_outputs.append(attn_output)

        # Concatenate all head outputs along last dimension (feature dimension)
        all_attn_outputs = torch.cat(all_attn_outputs, dim=-1)  # [batch, seq, model_dim]

        # Final projection back to model dimension
        all_attn_outputs = torch.einsum("bsm,mh->bsh", all_attn_outputs, self.attn_proj_back)

        all_attn_outputs = self.dropout(all_attn_outputs, 0.2)
        return all_attn_outputs

    def linear_block(self, x):
        """linear layer block before output"""
        self.cache["linear_block_input"] = x
        x, x_mean, x_std = self.layernorm(x)
        self.cache["linear_block_normalized_input"] = x
        self.cache["linear_block_mean"] = x_mean
        self.cache["linear_block_std"] = x_std
        self.activations["A0"] = x
        ff_outout = x @ self.ff_layer + self.ff_bias # a0 * W1 + b1
        ff_outout = self.dropout(ff_outout, 0.2) # 
        ff_activation = self.relu_activation(ff_outout) # a1 = activation(W1 * x + b1)
        self.activations["A1"] = ff_activation
        ff_output2 = ff_activation @ self.ff_layer2 + self.ff_bias2 # a1 * W2  + b2
        logits = ff_output2 @ self.embedding.T
        return logits
    
    def dropout(self, x, p):
        mask = (torch.rand_like(x) > p).float()
        return x * mask
    
    def init_weights(self, shape):
        return torch.randn(shape, dtype=torch.float32) * 0.01

    def layernorm(self, x):
        """generic normalization layer"""
        x_mean = torch.mean(x, dim=-1, keepdim=True)
        x_std = torch.std(x, dim=-1, keepdim=True)
        x_std = torch.clamp(x_std, min=1e-6)
        return (x - x_mean) / x_std, x_mean, x_std
    
    def relu_activation(self, x):
        """generic rectified linear unit layer"""
        return torch.relu(x)

    def forward(self, x, x_idx):
        """full transformer forward pass"""
        total_input = self.input_block(x, x_idx)
        attention = self.attention_block(total_input)
        residual1 = self.residual_connexion(attention, total_input * (0.5**0.5))
        logits = self.linear_block(residual1)
        return logits

In [169]:
def compute_loss(logits, y):
    log_softmax = torch.log_softmax(logits, dim=-1)
    log_softmax_mean = log_softmax.mean()
    log_softmax -= log_softmax_mean
    target_log_probs = log_softmax.gather(2, y.unsqueeze(-1)).squeeze(-1)
    loss = -target_log_probs.mean()
    return log_softmax, loss

In [170]:
def layernorm_backprop(dout, x, mean, std):

    xmu = x - mean
    xsgma = 1 / (std + 1e-6)
    dnorm = dout * xsgma

    dx = dnorm
    dx -= dnorm.mean(dim=1, keepdims=True)
    dx -= (xmu / (std + 1e-6) ** 2) * (dnorm * xmu).mean(dim=1, keepdims=True)
    return dx

def full_backprop(x, y, output_softmax, parameters, activations, cache, attn_projs, attn_proj_back):
    grads = {}

    batch_size, seq_len = x.shape

    # ======= Linear Block Backprop (Logits to FFN) =======

    dZ2 = output_softmax.clone()
    dZ2[torch.arange(batch_size).unsqueeze(1), torch.arange(seq_len), y] -= 1
    dZ2 /= (batch_size * seq_len)  # Normalize across batch and sequence

    # Final projection (logits = ff_output2 @ embedding.T)
    grads["dEmbedding"] = torch.einsum("bsi,bsj->ij", dZ2, activations["A1"])

    # Backprop through FFN (W2, b2, ReLU, W1, b1)
    dA1 = torch.einsum("bsi,ij->bsj", dZ2, parameters["embedding"])

    grads["dW2"] = torch.einsum("bsi,bsj->ij", activations["A1"], dA1)
    grads["db2"] = dA1.sum(dim=(0, 1), keepdim=False).unsqueeze(0)

    dZ1 = dA1 * (activations["A0"] > 0).float()

    grads["dW1"] = torch.einsum("bsi,bsj->ij", activations["A0"], dZ1)
    grads["db1"] = dZ1.sum(dim=(0, 1), keepdim=False).unsqueeze(0)

    dA0 = torch.einsum("bsj,ij->bsi", dZ1, parameters["W1"])

    # Backprop through layernorm after attention block
    grads["dA0"] = layernorm_backprop(
        dA0,
        cache["linear_block_input"],
        cache["linear_block_mean"],
        cache["linear_block_std"],
    )

    # ======= Attention Block Backprop (Multi-head Attention) =======

    head_size = cache["model_dim"] // cache["nb_heads"]

    # Fix: Project back and reshape into heads directly
    combined_heads = torch.einsum("bsi,ij->bsj", grads["dA0"], attn_proj_back)
    dHeads = combined_heads.view(batch_size, seq_len, cache["nb_heads"], head_size).unbind(dim=2)

    dX = torch.zeros_like(cache["attn_X"])

    for h in range(cache["nb_heads"]):
        Q = cache[f"Q{h}"]
        K = cache[f"K{h}"]
        V = cache[f"V{h}"]
        attn_probs = cache[f"attn_probs{h}"]

        dHead = dHeads[h]  # [batch, seq, head_size]

        # Backprop into V
        dV = torch.einsum("bqk,bqh->bkh", attn_probs, dHead)
        grads[f"dW_v{h}"] = torch.einsum("bsm,bsh->mh", cache["attn_X"], dV)

        # Backprop into attention scores
        dAttn = torch.einsum("bqh,bkh->bqk", dHead, V)

        # Backprop through softmax (Jacobian trick)
        dScores = attn_probs * (dAttn - (dAttn * attn_probs).sum(dim=-1, keepdim=True))

        # Backprop into Q and K
        dQ = torch.einsum("bqk,bkh->bqh", dScores, K)
        dK = torch.einsum("bqk,bqh->bkh", dScores, Q)

        grads[f"dW_q{h}"] = torch.einsum("bsm,bsh->mh", cache["attn_X"], dQ)
        grads[f"dW_k{h}"] = torch.einsum("bsm,bsh->mh", cache["attn_X"], dK)

        # Backprop into input X
        dX += (
            torch.einsum("bsh,mh->bsm", dQ, attn_projs[f"q_proj_{h}"]) +
            torch.einsum("bsh,mh->bsm", dK, attn_projs[f"k_proj_{h}"]) +
            torch.einsum("bsh,mh->bsm", dV, attn_projs[f"v_proj_{h}"])
        )

    # ======= Residual Connection (Add attention & FFN gradients) =======
    dResidual = grads["dA0"] + dX

    # ======= Input Block Backprop (Embedding + Positional Encoding + Layernorm) =======
    dResidual /= (1 - cache["dropout_p"])  # Backprop through dropout

    assert dResidual.shape == dA0.shape
    
    dTotalInput = layernorm_backprop(
        dResidual,
        cache["input_block_input"],
        cache["input_block_mean"],
        cache["input_block_std"],
    )

    dInputEmbed = dTotalInput.clone()

    # Combine gradients into the embedding table
    if "dEmbedding" not in grads:
        grads["dEmbedding"] = torch.zeros_like(parameters["embedding"])

    for b in range(batch_size):
        for t in range(seq_len):
            token_id = x[b, t]
            grads["dEmbedding"][token_id] += dInputEmbed[b, t]
    
    # Optional positional encoding backprop (unused unless trainable)
    grads["dPositionalEncoding"] = dTotalInput.clone()

    return grads

In [171]:
def optimizer_step(grads, parameters, optim_states, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.001):
    
    m = optim_states["m"]
    v = optim_states["v"]
    t = optim_states["t"]

    # Map parameters to their corresponding gradient names
    param_grad_map = {
        "W1": "dW1",
        "W2": "dW2",
        "b1": "db1",
        "b2": "db2",
        "embedding": "dEmbedding"
    }

    for h in range(parameters["nb_heads"]):  # Use passed-in `nb_heads`, don't read from parameters
        param_grad_map[f"q_proj_{h}"] = f"dW_q{h}"
        param_grad_map[f"k_proj_{h}"] = f"dW_k{h}"
        param_grad_map[f"v_proj_{h}"] = f"dW_v{h}"

    for param_name, grad_name in param_grad_map.items():
        if param_name not in parameters or grad_name not in grads:
            continue

        param = parameters[param_name]
        grad = grads[grad_name]

        if param_name not in m:
            m[param_name] = torch.zeros_like(param)
            v[param_name] = torch.zeros_like(param)

        # AdamW moment updates
        m[param_name] = beta1 * m[param_name] + (1 - beta1) * grad
        v[param_name] = beta2 * v[param_name] + (1 - beta2) * (grad ** 2)

        # Bias correction
        m_hat = m[param_name] / (1 - beta1 ** t)
        v_hat = v[param_name] / (1 - beta2 ** t)

        # Weight update (decoupled weight decay for non-embedding params)
        param_update = lr * m_hat / (torch.sqrt(v_hat) + eps)

        if param_name in ["embedding", "b1", "b2"]:
            param -= param_update  # No weight decay on embeddings
        else:
            param -= param_update + weight_decay * param  # AdamW-style weight decay

        parameters[param_name] = param  # Update parameter in-place

    t += 1

    optim_states["m"] = m
    optim_states["v"] = v
    optim_states["t"] = t

    return parameters, optim_states

In [160]:
def global_grad_norm(grads):
    """Computes the global gradient norm across all parameters."""
    total_norm = 0.0
    for g in grads.values():
        if g is not None:
            total_norm += g.norm().item() ** 2
    return total_norm ** 0.5

In [161]:
def gradient_clipping(grads, max_norm=1.0):
    """Apply gradient clipping to all gradients."""
    total_norm = global_grad_norm(grads)

    if math.isnan(total_norm) or math.isinf(total_norm):  # <- Use math instead of torch
        raise ValueError(f"NaN or Inf detected in global grad norm!")

    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for g in grads.values():
            g *= clip_coef

    return grads

In [162]:
def check_for_nan_inf(grads):
    """Check for NaN/Inf gradients in any parameter."""
    for name, grad in grads.items():
        if torch.isnan(grad).any() or torch.isinf(grad).any():
            print(f"🚨 Gradient explosion detected in {name}")
            raise ValueError(f"Gradient in {name} contains NaN/Inf")

In [148]:
model_dim, nb_heads = 64, 1
transformer = BasicTransformerNetwork(vocab_size=len(vocab), model_dim=model_dim, nb_heads=nb_heads)

In [274]:
example = "This is an example sentence"
input_ids = encode(example)
pos_ids = list(range(len(input_ids)))

input_ids = torch.tensor(input_ids, dtype=torch.long)
pos_ids = torch.tensor(pos_ids, dtype=torch.long)

In [275]:
logits = transformer.forward(input_ids, pos_ids)

In [276]:
y = torch.randint(0, 63, (1, 27), dtype=torch.long)

In [277]:
softmax, loss = compute_loss(logits, y)

In [278]:
grads = full_backprop(
    input_ids,
    y,
    softmax,
    transformer.parameters,
    transformer.activations,
    transformer.cache,
    transformer.attn_projs,
    transformer.attn_proj_back,
)

In [None]:
transformer.parameters, _ = optimizer_step(grads, transformer.parameters)

## Training

In [None]:
max_window = 30
batch_size = 32
nb_epochs = 5
optim_states = {
    "m": {},
    "v": {},
    "t": 1,
}
model_dim, nb_heads = 512, 4
transformer = BasicTransformerNetwork(vocab_size=len(vocab), model_dim=model_dim, nb_heads=nb_heads)

for iter in range(len(text_dataset) // batch_size):
    
    batch_inputs = []
    batch_pos = []
    batch_outputs = []

    for _ in range(batch_size):

        start_idx = int(torch.randint(0, len(text_dataset) - max_window - 1, (1,)))
        input_text = text_dataset[start_idx:start_idx + max_window] # x
        output_text = text_dataset[start_idx+1:start_idx + max_window + 1] # y
        input_ids = encode(input_text)
        pos_ids = list(range(len(input_ids)))
        output_ids = encode(output_text)

        batch_inputs.append(input_ids)
        batch_pos.append(pos_ids)
        batch_outputs.append(output_ids)

    batch_inputs = torch.tensor(batch_inputs, dtype=torch.long)  # [batch_size, seq_len]
    batch_pos = torch.tensor(batch_pos, dtype=torch.long)
    batch_outputs = torch.tensor(batch_outputs, dtype=torch.long)

    logits = transformer.forward(batch_inputs, batch_pos)

    logits = torch.clamp(logits, -5, 5)
    softmax, loss = compute_loss(logits, batch_outputs)

    grads = full_backprop(
        batch_inputs,
        batch_outputs,
        softmax,
        transformer.parameters,
        transformer.activations,
        transformer.cache,
        transformer.attn_projs,
        transformer.attn_proj_back,
    )

    check_for_nan_inf(grads)

    # Clip gradients
    grads = gradient_clipping(grads, max_norm=1.0)

    # Log global gradient norm
    global_norm = global_grad_norm(grads)

    transformer.parameters, optim_states = optimizer_step(
        grads, transformer.parameters, optim_states, lr=5e-7
    )

    if not iter % 200:
        print(f"Positional Encoding Magnitude: {transformer.position_embed_norm}")
        print(f"Logits Range: {logits.min()} to {logits.max()}")
        print(f"Softmax Range: {softmax.min()} to {softmax.max()}")
        print(f"Iteration {iter}: Global Grad Norm = {global_norm:.4f}")
        print()
        print("training loss: ", float(loss))
        print()

Positional Encoding Magnitude: 2.479029655456543
Logits Range: -0.8865989446640015 to 0.8521111607551575
Softmax Range: -0.8887062072753906 to 0.8502516746520996
Iteration 0: Global Grad Norm = 1.0000

training loss:  0.06995103508234024

Positional Encoding Magnitude: 2.479029655456543
Logits Range: -0.8629919290542603 to 0.8388312458992004
Softmax Range: -0.8642644882202148 to 0.8376955986022949
Iteration 200: Global Grad Norm = 1.0000

training loss:  0.06415937840938568

Positional Encoding Magnitude: 2.479029655456543
Logits Range: -0.853061854839325 to 0.8305718302726746
Softmax Range: -0.8548002243041992 to 0.8290166854858398
Iteration 400: Global Grad Norm = 1.0000

training loss:  0.05233614146709442

Positional Encoding Magnitude: 2.479029655456543
Logits Range: -0.8379678130149841 to 0.8167802095413208
Softmax Range: -0.839900016784668 to 0.8152532577514648
Iteration 600: Global Grad Norm = 1.0000

training loss:  0.05499274656176567

Positional Encoding Magnitude: 2.4790296

KeyboardInterrupt: 