In [50]:
import torch
import random
from dataclasses import dataclass

# ============ HYPERPARAMS ============
DEVICE = "mps"  # for macbook training
BATCH_SIZE = 64
BLOCK_SIZE = 256  # max sequence length for training
# =====================================

# ============ MODEL HYPERPARAMETERS ============
N_EMBD = 384  # embedding dimension
N_HEADS = 6  # number of attention heads
N_LAYERS = 6  # number of transformer blocks
DROPOUT = 0.2  # dropout rate
# ================================================

# ============ TRAINING HYPERPARAMETERS ============
LEARNING_RATE = 3e-4
TRAINING_STEPS = 10_000
EVAL_ITER_PERIOD = 500
EVAL_ITERS = 100  # How many batches to average for stable loss estimate
# ==================================================

In [51]:
# ============ TOKENIZER (character-level) ============
with open("input.txt", "r") as f:
    text = f.read()

# Build vocabulary from characters
chars = sorted(set(text))
vocab_size = len(chars) + 3  # +3 for special tokens: PAD, BOS, EOS

# Special token indices
PAD = 0
BOS = 1  # Beginning of sequence (used as initial state for insertion)
EOS = 2  # End of sequence

# Character mappings (offset by 3 for special tokens)
char_to_idx = {ch: i + 3 for i, ch in enumerate(chars)}
idx_to_char = {i + 3: ch for i, ch in enumerate(chars)}
idx_to_char[PAD] = "<PAD>"
idx_to_char[BOS] = "<BOS>"
idx_to_char[EOS] = "<EOS>"


def encode(s: str) -> list[int]:
    """Encode string to list of token indices"""
    return [char_to_idx[c] for c in s]


def decode(tokens: list[int]) -> str:
    """Decode token indices to string, filtering special tokens"""
    return "".join(idx_to_char.get(t, "?") for t in tokens if t >= 3)


print(f"Vocabulary size: {vocab_size}")
print(f"Sample encoding: 'Hello' -> {encode('Hello')}")
print(f"Sample decoding: {encode('Hello')} -> '{decode(encode('Hello'))}'")

Vocabulary size: 68
Sample encoding: 'Hello' -> [23, 46, 53, 53, 56]
Sample decoding: [23, 46, 53, 53, 56] -> 'Hello'


In [52]:
# ============ INSERTION ORACLE ============
# Core algorithm: given a partial sequence (candidate) and target (reference),
# compute which tokens can be validly inserted at each position.


def get_optimal_inserts(cand: list[int], ref: list[int]) -> list[set[int]]:
    """
    For a candidate sequence that is a subsequence of reference,
    compute valid insertions at each position.

    Returns: list of sets, one per position in [0, len(cand)+1)
             Each set contains tokens that can be inserted at that position
             while keeping cand as a subsequence of ref.

    Example:
        ref  = [A, B, C, D]
        cand = [B, D]

        Position 0 (before B): can insert A (to get [A,B,D] which is subseq of ref)
        Position 1 (between B and D): can insert C
        Position 2 (after D): nothing (already at end)

        Returns: [{A}, {C}, {}]
    """
    # Find where each cand token appears in ref (leftmost match)
    # starts[i] = position in ref after matching cand[:i]
    starts = [0]
    ref_iter = iter(enumerate(ref))
    for cand_item in cand:
        for ref_pos, ref_item in ref_iter:
            if ref_item == cand_item:
                starts.append(ref_pos + 1)
                break
        else:
            raise ValueError("cand must be a subsequence of ref")

    # Find rightmost matches going backwards
    # ends[i] = position in ref before matching cand[i:]
    ends = [len(ref)]
    reverse_ref_iter = iter(reversed(list(enumerate(ref))))
    for cand_item in reversed(cand):
        for ref_pos, ref_item in reverse_ref_iter:
            if ref_item == cand_item:
                ends.append(ref_pos)
                break
        else:
            raise ValueError("cand must be a subsequence of ref")
    ends = ends[::-1]

    # Valid inserts at position i are tokens in ref[starts[i]:ends[i]]
    inserts = []
    for i, j in zip(starts, ends):
        inserts.append(set(ref[i:j]))
    return inserts


# Test the oracle
ref = encode("ABCD")
cand = encode("BD")
print(f"Reference: {decode(ref)}")
print(f"Candidate: {decode(cand)}")
inserts = get_optimal_inserts(cand, ref)
print(f"Valid inserts per position:")
for i, s in enumerate(inserts):
    print(f"  Position {i}: {{{', '.join(decode([t]) for t in s)}}}")

Reference: ABCD
Candidate: BD
Valid inserts per position:
  Position 0: {A}
  Position 1: {C}
  Position 2: {}


In [53]:
# ============ TRAJECTORY GENERATION ============
# Generate training samples by simulating insertion trajectories from empty -> reference


@dataclass
class InsertionSample:
    """A single training sample for the insertion transformer"""

    hypo: list[int]  # Current partial sequence (input)
    ref_inserts: list[set[int]]  # Valid insertions at each position (for loss)
    chosen_pos: int  # Position where we inserted (target)
    chosen_token: int  # Token we inserted (target)


def generate_trajectory(ref: list[int], mode: str = "random") -> list[InsertionSample]:
    """
    Generate a full trajectory from empty sequence to reference.

    Args:
        ref: Target sequence (list of token indices)
        mode: "random" for random order, "l2r" for left-to-right

    Returns:
        List of InsertionSample, one per insertion step
    """
    samples = []
    hypo = []

    while True:
        inserts = get_optimal_inserts(hypo, ref)

        # Flatten to list of (position, token) pairs
        flat_inserts = [
            (pos, tok) for pos, tokens in enumerate(inserts) for tok in tokens
        ]

        if not flat_inserts:
            # Trajectory complete - hypo == ref
            break

        # Choose next insertion
        if mode == "random":
            chosen_pos, chosen_tok = random.choice(flat_inserts)
        elif mode == "l2r":
            # Left-to-right: always insert at position len(hypo), token ref[len(hypo)]
            chosen_pos, chosen_tok = len(hypo), ref[len(hypo)]
        else:
            raise ValueError(f"Unknown mode: {mode}")

        samples.append(
            InsertionSample(
                hypo=list(hypo),
                ref_inserts=inserts,
                chosen_pos=chosen_pos,
                chosen_token=chosen_tok,
            )
        )

        # Apply the insertion
        hypo.insert(chosen_pos, chosen_tok)

    return samples


# Demonstrate trajectory generation
ref = encode("Hello")
print(f"Generating trajectory for: '{decode(ref)}'")
print(f"Reference tokens: {ref}\n")

trajectory = generate_trajectory(ref, mode="random")
for i, sample in enumerate(trajectory):
    print(f"Step {i}:")
    print(f"  Hypothesis: '{decode(sample.hypo)}' (len={len(sample.hypo)})")
    print(f"  Insert '{decode([sample.chosen_token])}' at position {sample.chosen_pos}")
    valid_at_pos = sample.ref_inserts[sample.chosen_pos]
    print(
        f"  Valid tokens at that position: {{{', '.join(decode([t]) for t in valid_at_pos)}}}"
    )
    print()

Generating trajectory for: 'Hello'
Reference tokens: [23, 46, 53, 53, 56]

Step 0:
  Hypothesis: '' (len=0)
  Insert 'l' at position 0
  Valid tokens at that position: {o, l, e, H}

Step 1:
  Hypothesis: 'l' (len=1)
  Insert 'e' at position 0
  Valid tokens at that position: {l, e, H}

Step 2:
  Hypothesis: 'el' (len=2)
  Insert 'o' at position 2
  Valid tokens at that position: {o, l}

Step 3:
  Hypothesis: 'elo' (len=3)
  Insert 'H' at position 0
  Valid tokens at that position: {H}

Step 4:
  Hypothesis: 'Helo' (len=4)
  Insert 'l' at position 2
  Valid tokens at that position: {l}



In [54]:
# ============ TRAINING DATA PREPARATION ============
# Convert text into training samples for the insertion transformer

# Split data
data = torch.tensor(encode(text), dtype=torch.long)
train_size = int(len(data) * 0.9)
train_data = data[:train_size]
val_data = data[train_size:]

print(f"Total tokens: {len(data):,}")
print(f"Train tokens: {len(train_data):,}")
print(f"Val tokens: {len(val_data):,}")

Total tokens: 1,115,394
Train tokens: 1,003,854
Val tokens: 111,540


In [55]:
# ============ EOS HANDLING ============
# When the hypothesis equals the reference, we need to predict "FINISH"
# The model outputs logits for (position, token) pairs PLUS a "finish" logit


def generate_trajectory_with_eos(
    ref: list[int], mode: str = "random"
) -> list[InsertionSample]:
    """Generate trajectory including the final EOS step."""
    samples = generate_trajectory(ref, mode=mode)

    # Add final step where hypo == ref and we should predict EOS
    if ref:
        eos_inserts = [{EOS} for _ in range(len(ref) + 1)]
        chosen_pos = random.randint(0, len(ref))
        samples.append(
            InsertionSample(
                hypo=list(ref),
                ref_inserts=eos_inserts,
                chosen_pos=chosen_pos,
                chosen_token=EOS,
            )
        )

    return samples


# Test
ref = encode("Hi")
trajectory = generate_trajectory_with_eos(ref)
print(f"Trajectory for 'Hi' (with EOS):")
for i, s in enumerate(trajectory):
    tok = "<EOS>" if s.chosen_token == EOS else decode([s.chosen_token])
    print(f"  Step {i}: '{decode(s.hypo)}' -> insert '{tok}' at pos {s.chosen_pos}")

Trajectory for 'Hi' (with EOS):
  Step 0: '' -> insert 'H' at pos 0
  Step 1: 'H' -> insert 'i' at pos 1
  Step 2: 'Hi' -> insert '<EOS>' at pos 1


# Summary: Training Data Interface

We now have the core training data generation for the Insertion Transformer:

## Key Components:

1. **`get_optimal_inserts(cand, ref)`** - Given a partial sequence (candidate) and target (reference), computes which tokens can be validly inserted at each position.

2. **`generate_trajectory(ref)`** - Simulates the full insertion process from empty â†’ reference, returning a list of training samples.

3. **`InsertionSample`** - A single training example containing:
   - `hypo`: Current partial sequence (model input)
   - `ref_inserts`: Valid (position, token) pairs for loss
   - `chosen_pos`, `chosen_token`: The action taken (target)

4. **`get_batch(split)`** - Returns a batched `InsertionBatch` ready for training.

## Model Requirements (Next Steps):

The decoder-only insertion transformer needs to:
1. Take a partial sequence as input
2. Output logits for `(position, token)` pairs: `[batch, num_positions, vocab_size]`
3. Plus a "finish" logit for EOS prediction

The loss will use `valid_mask` to allow credit for ANY valid insertion, not just the one we sampled.

In [56]:
# ============ BATCH GENERATION ============
# For the insertion transformer, each training example is:
#   - Input: partial sequence (hypothesis)
#   - Target: (position, token) to insert
#   - For loss: all valid (position, token) pairs at this step


@dataclass
class InsertionBatch:
    """A batch of training samples for the insertion transformer"""

    hypo: torch.Tensor  # [B, max_hypo_len] - padded hypotheses
    hypo_len: torch.Tensor  # [B] - actual lengths
    target_pos: torch.Tensor  # [B] - position to insert
    target_token: torch.Tensor  # [B] - token to insert
    # For loss computation: valid insertions as sparse tensor
    valid_mask: torch.Tensor  # [B, max_hypo_len+1, vocab_size] - 1 where valid


def collate_samples(samples: list[InsertionSample]) -> InsertionBatch:
    """Convert list of samples into a padded batch"""
    B = len(samples)
    max_len = max(len(s.hypo) for s in samples) if samples else 0

    # Pad hypotheses
    hypo = torch.full((B, max_len), PAD, dtype=torch.long)
    hypo_len = torch.zeros(B, dtype=torch.long)
    target_pos = torch.zeros(B, dtype=torch.long)
    target_token = torch.zeros(B, dtype=torch.long)

    # Valid mask: [B, max_len+1, vocab_size]
    # +1 because we can insert at positions 0 to len(hypo) inclusive
    valid_mask = torch.zeros((B, max_len + 1, vocab_size), dtype=torch.bool)

    for i, s in enumerate(samples):
        L = len(s.hypo)
        if L > 0:
            hypo[i, :L] = torch.tensor(s.hypo)
        hypo_len[i] = L
        target_pos[i] = s.chosen_pos
        target_token[i] = s.chosen_token

        # Fill valid mask
        for pos, valid_tokens in enumerate(s.ref_inserts):
            for tok in valid_tokens:
                valid_mask[i, pos, tok] = True

    return InsertionBatch(
        hypo=hypo,
        hypo_len=hypo_len,
        target_pos=target_pos,
        target_token=target_token,
        valid_mask=valid_mask,
    )


def get_batch(
    split: str,
    batch_size: int = BATCH_SIZE,
    block_size: int = BLOCK_SIZE,
    include_eos: bool = True,
) -> InsertionBatch:
    """
    Get a random batch of training samples.

    Strategy:
    1. Sample random chunks of text as target sequences
    2. For each target, sample ONE random step from its insertion trajectory
    3. Include EOS samples so model learns when to stop
    """
    data_source = train_data if split == "train" else val_data

    samples = []
    for _ in range(batch_size):
        # Random starting position
        start = random.randint(0, len(data_source) - block_size - 1)
        ref = [int(t) for t in data_source[start : start + block_size].tolist()]

        # Generate trajectory (with EOS) and pick one random sample
        if include_eos:
            trajectory = generate_trajectory_with_eos(ref, mode="random")
        else:
            trajectory = generate_trajectory(ref, mode="random")

        if trajectory:
            sample = random.choice(trajectory)
            samples.append(sample)

    return collate_samples(samples)


# Test batch generation
batch = get_batch("train", batch_size=4, block_size=16)
print(f"Batch shapes:")
print(f"  hypo: {batch.hypo.shape}")
print(f"  hypo_len: {batch.hypo_len}")
print(f"  target_pos: {batch.target_pos}")
print(f"  target_token: {batch.target_token}")
print(f"  valid_mask: {batch.valid_mask.shape}")

print(f"\nExample from batch:")
idx = 0
hypo_tokens = [int(t) for t in batch.hypo[idx].tolist()]
print(f"  Hypothesis: '{decode(hypo_tokens)}'")
print(
    f"  Insert '{decode([int(batch.target_token[idx].item())])}' at position {batch.target_pos[idx].item()}"
)

Batch shapes:
  hypo: torch.Size([4, 15])
  hypo_len: tensor([ 2, 11, 12, 15])
  target_pos: tensor([ 0,  0, 10,  1])
  target_token: tensor([60, 60, 57, 46])
  valid_mask: torch.Size([4, 16, 68])

Example from batch:
  Hypothesis: 'nI'
  Insert 's' at position 0


In [57]:
decode(batch.hypo[0].tolist())

'nI'

In [58]:
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with BIDIRECTIONAL attention (no causal mask).

    Unlike GPT which uses causal masking, the Insertion Transformer uses
    bidirectional attention since we need to see all tokens to decide
    where to insert.
    """

    def __init__(self, n_embd: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert n_embd % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = n_embd // n_heads

        # Combined QKV projection for efficiency
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Args:
            x: [B, T, C] input tensor
            mask: [B, T] boolean mask, True for valid positions, False for padding
        Returns:
            [B, T, C] output tensor
        """
        B, T, C = x.shape

        # Compute Q, K, V
        qkv = self.qkv(x)  # [B, T, 3*C]
        qkv = qkv.reshape(B, T, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, n_heads, T, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention scores
        scale = 1.0 / math.sqrt(self.head_dim)
        attn = (q @ k.transpose(-2, -1)) * scale  # [B, n_heads, T, T]

        # Apply padding mask if provided
        if mask is not None:
            # mask: [B, T] -> [B, 1, 1, T] for broadcasting
            attn_mask = mask[:, None, None, :]  # attend TO these positions
            attn = attn.masked_fill(~attn_mask, float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum
        out = attn @ v  # [B, n_heads, T, head_dim]
        out = out.transpose(1, 2).reshape(B, T, C)  # [B, T, C]
        out = self.proj(out)

        return out

In [59]:
class MLP(nn.Module):
    """Simple feed-forward network applied position-wise."""

    def __init__(self, n_embd: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(n_embd, 4 * n_embd)
        self.fc2 = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """Single transformer block with pre-norm architecture."""

    def __init__(self, n_embd: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, n_heads, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd, dropout)

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.mlp(self.ln2(x))
        return x

In [60]:
class InsertionTransformer(nn.Module):
    """
    Decoder-only Insertion Transformer for language modeling.

    Key differences from standard GPT:
    1. Bidirectional attention (no causal mask) - we see all tokens to decide where to insert
    2. Outputs logits for (position, token) pairs
    3. Prepends a special "slot" token to allow insertion at position 0

    Architecture:
    - Input: partial sequence [t1, t2, ..., tn]
    - Prepend slot: [SLOT, t1, t2, ..., tn] -> n+1 positions
    - Each position i outputs:
        - position_logit: how likely to insert HERE (before token i)
        - token_logits: what token to insert

    Output interpretation:
    - Position i in output corresponds to inserting BEFORE position i in input
    - Position 0 = insert at the beginning
    - Position n = insert at the end (after last token)
    """

    def __init__(
        self,
        vocab_size: int,
        n_embd: int = N_EMBD,
        n_heads: int = N_HEADS,
        n_layers: int = N_LAYERS,
        max_seq_len: int = BLOCK_SIZE + 1,
        dropout: float = DROPOUT,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.max_seq_len = max_seq_len

        # Token embedding (includes special tokens: PAD=0, BOS=1, EOS=2)
        self.tok_emb = nn.Embedding(vocab_size, n_embd)

        # Learnable "slot" embedding prepended to represent insertion at position 0
        self.slot_emb = nn.Parameter(torch.randn(1, 1, n_embd) * 0.02)

        # Positional embedding
        self.pos_emb = nn.Embedding(max_seq_len, n_embd)

        # Transformer blocks
        self.blocks = nn.ModuleList(
            [TransformerBlock(n_embd, n_heads, dropout) for _ in range(n_layers)]
        )

        # Final layer norm
        self.ln_f = nn.LayerNorm(n_embd)

        # Output heads:
        # - token_head: predicts which token to insert [n_embd -> vocab_size]
        # - position_head: predicts insertion position weight [n_embd -> 1]
        self.token_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.position_head = nn.Linear(n_embd, 1, bias=False)

        # Weight tying: share token embeddings with output
        self.token_head.weight = self.tok_emb.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        hypo: torch.Tensor,  # [B, T] token indices
        hypo_len: torch.Tensor | None = None,  # [B] actual lengths
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            hypo: [B, T] partial sequence (padded with PAD token)
            hypo_len: [B] actual lengths (optional, inferred if not provided)

        Returns:
            dict with:
                - 'position_logits': [B, T+1] logits for each insertion position
                - 'token_logits': [B, T+1, vocab_size] logits for each token at each position
                - 'insert_logp': [B, T+1, vocab_size] log P(insert token t at position p)
                - 'finish_logp': [B] log P(finish/EOS)
        """
        B, T = hypo.shape
        device = hypo.device

        # Infer lengths from padding if not provided
        if hypo_len is None:
            hypo_len = (hypo != PAD).sum(dim=1)

        # Create attention mask [B, T+1] - True for valid positions (including slot)
        # The slot (position 0) is always valid
        positions = torch.arange(T, device=device).unsqueeze(0)  # [1, T]
        token_mask = positions < hypo_len.unsqueeze(1)  # [B, T]
        # Prepend True for slot position
        mask = torch.cat(
            [torch.ones(B, 1, dtype=torch.bool, device=device), token_mask], dim=1
        )  # [B, T+1]

        # Token embeddings
        tok_emb = self.tok_emb(hypo)  # [B, T, n_embd]

        # Prepend slot embedding
        slot = self.slot_emb.expand(B, -1, -1)  # [B, 1, n_embd]
        x = torch.cat([slot, tok_emb], dim=1)  # [B, T+1, n_embd]

        # Add positional embeddings
        pos = torch.arange(T + 1, device=device)
        x = x + self.pos_emb(pos)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        x = self.ln_f(x)  # [B, T+1, n_embd]

        # Compute logits
        position_logits = self.position_head(x).squeeze(-1)  # [B, T+1]
        token_logits = self.token_head(x)  # [B, T+1, vocab_size]

        # Mask invalid positions (after sequence end + 1 for final insertion point)
        # Valid positions: 0 to hypo_len (inclusive), so T+1 positions total for full sequence
        pos_mask = torch.arange(T + 1, device=device).unsqueeze(
            0
        ) <= hypo_len.unsqueeze(1)  # [B, T+1]
        position_logits = position_logits.masked_fill(~pos_mask, float("-inf"))

        # Compute log probabilities
        # P(insert at pos p) comes from position_logits (softmax over positions)
        # P(insert token t | pos p) comes from token_logits (softmax over vocab)
        position_logp = F.log_softmax(position_logits, dim=-1)  # [B, T+1]
        token_logp = F.log_softmax(token_logits, dim=-1)  # [B, T+1, vocab_size]

        # Combined: log P(insert token t at position p) = log P(pos) + log P(token|pos)
        insert_logp = position_logp.unsqueeze(-1) + token_logp  # [B, T+1, vocab_size]

        # Finish probability: inserting EOS at the last valid position
        # Actually, in the paper, "finish" is predicted when we choose to insert EOS anywhere
        # We'll define finish_logp as the logsumexp of inserting EOS at any position
        eos_logp = insert_logp[
            :, :, EOS
        ]  # [B, T+1] - log prob of inserting EOS at each position
        finish_logp = torch.logsumexp(
            eos_logp.masked_fill(~pos_mask, float("-inf")), dim=1
        )  # [B]

        return {
            "position_logits": position_logits,
            "token_logits": token_logits,
            "insert_logp": insert_logp,
            "finish_logp": finish_logp,
        }

    def get_num_params(self) -> int:
        return sum(p.numel() for p in self.parameters())


# Test the model
model = InsertionTransformer(vocab_size=vocab_size)
model = model.to(DEVICE)
print(f"InsertionTransformer: {model.get_num_params() / 1e6:.2f}M parameters")

# Test forward pass
batch = get_batch("train", batch_size=4, block_size=16)
batch_hypo = batch.hypo.to(DEVICE)
batch_hypo_len = batch.hypo_len.to(DEVICE)

with torch.no_grad():
    out = model(batch_hypo, batch_hypo_len)

print(f"\nOutput shapes:")
print(f"  position_logits: {out['position_logits'].shape}")
print(f"  token_logits: {out['token_logits'].shape}")
print(f"  insert_logp: {out['insert_logp'].shape}")
print(f"  finish_logp: {out['finish_logp'].shape}")

InsertionTransformer: 10.76M parameters

Output shapes:
  position_logits: torch.Size([4, 17])
  token_logits: torch.Size([4, 17, 68])
  insert_logp: torch.Size([4, 17, 68])
  finish_logp: torch.Size([4])


In [61]:
def compute_loss(
    model: InsertionTransformer,
    batch: InsertionBatch,
    device: str = DEVICE,
) -> tuple[torch.Tensor, dict]:
    """
    Compute the insertion transformer loss.

    Key insight: We give credit for ANY valid insertion, not just the sampled one.
    This is done via logsumexp over all valid (position, token) pairs.

    Loss = -log P(any valid action)
         = -logsumexp_{(p,t) in valid} log P(insert t at p)

    For EOS steps (when hypo == ref), the valid action is finishing.
    """
    hypo = batch.hypo.to(device)
    hypo_len = batch.hypo_len.to(device)
    valid_mask = batch.valid_mask.to(device)  # [B, T+1, vocab_size]
    target_token = batch.target_token.to(device)  # [B]

    B, T = hypo.shape

    # Forward pass
    out = model(hypo, hypo_len)
    insert_logp = out["insert_logp"]  # [B, T+1, vocab_size]

    # Check if this is an EOS step (target is EOS)
    is_eos_step = target_token == EOS  # [B]

    # For non-EOS steps: loss = -logsumexp over valid insertions
    # Mask invalid insertions with -inf before logsumexp
    # valid_mask might be smaller than insert_logp if hypo is padded differently
    # We need to align them

    # Ensure valid_mask matches insert_logp shape
    if valid_mask.shape[1] < insert_logp.shape[1]:
        # Pad valid_mask
        pad_size = insert_logp.shape[1] - valid_mask.shape[1]
        valid_mask = F.pad(valid_mask, (0, 0, 0, pad_size), value=False)
    elif valid_mask.shape[1] > insert_logp.shape[1]:
        # Truncate valid_mask
        valid_mask = valid_mask[:, : insert_logp.shape[1], :]

    # Mask out invalid insertions
    masked_logp = insert_logp.masked_fill(~valid_mask, float("-inf"))

    # logsumexp over all valid (position, token) pairs
    # Flatten positions and tokens, then logsumexp
    logp_any_valid = torch.logsumexp(masked_logp.view(B, -1), dim=-1)  # [B]

    # For EOS steps, use finish_logp instead
    # Actually, for EOS steps, valid_mask marks EOS as valid at all positions
    # So logp_any_valid should already be correct

    # Loss is negative log probability
    loss = -logp_any_valid.mean()

    # Compute some metrics
    with torch.no_grad():
        # Accuracy: is the argmax a valid action?
        pred_flat = insert_logp.view(B, -1).argmax(dim=-1)  # [B]
        valid_flat = valid_mask.view(B, -1)  # [B, T*vocab]
        acc = valid_flat.gather(1, pred_flat.unsqueeze(1)).float().mean()

        # Perplexity-like metric
        ppl = torch.exp(loss)

    metrics = {
        "loss": loss.item(),
        "acc": acc.item(),
        "ppl": ppl.item(),
    }

    return loss, metrics


# Test loss computation
batch = get_batch("train", batch_size=8, block_size=16)
loss, metrics = compute_loss(model, batch)
print(f"Test loss computation:")
print(f"  Loss: {metrics['loss']:.4f}")
print(f"  Accuracy: {metrics['acc']:.4f}")
print(f"  Perplexity: {metrics['ppl']:.2f}")

Test loss computation:
  Loss: 4.2346
  Accuracy: 0.0000
  Perplexity: 69.03


In [62]:
@torch.no_grad()
def estimate_loss(model: InsertionTransformer) -> dict[str, float]:
    """Estimate loss over multiple batches for more stable metrics."""
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(EVAL_ITERS)
        accs = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            batch = get_batch(split, batch_size=BATCH_SIZE, block_size=BLOCK_SIZE)
            loss, metrics = compute_loss(model, batch)
            losses[k] = metrics["loss"]
            accs[k] = metrics["acc"]
        out[f"{split}_loss"] = losses.mean().item()
        out[f"{split}_acc"] = accs.mean().item()
    model.train()
    return out

In [63]:
def train(model: InsertionTransformer):
    """Train the insertion transformer."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    for step in range(TRAINING_STEPS):
        # Periodic evaluation
        if step % EVAL_ITER_PERIOD == 0 or step == TRAINING_STEPS - 1:
            metrics = estimate_loss(model)
            print(
                f"Step {step:05d} | "
                f"Train loss: {metrics['train_loss']:.4f} | "
                f"Val loss: {metrics['val_loss']:.4f} | "
                f"Train acc: {metrics['train_acc']:.2%} | "
                f"Val acc: {metrics['val_acc']:.2%}"
            )
            # Generate a sample
            sample = generate(model, max_len=50, temperature=0.8)
            print(f"  Sample: '{decode(sample)[:60]}...'")
            print()

        # Training step
        batch = get_batch("train", batch_size=BATCH_SIZE, block_size=BLOCK_SIZE)
        loss, _ = compute_loss(model, batch)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    print("Training complete!")

In [64]:
@torch.no_grad()
def generate(
    model: InsertionTransformer,
    max_len: int = 50,
    temperature: float = 1.0,
    device: str = DEVICE,
) -> list[int]:
    """
    Generate a sequence using the insertion transformer.

    Starts from empty sequence and iteratively inserts tokens
    until EOS is predicted or max_len is reached.
    """
    model.eval()

    # Start with empty sequence
    hypo = []

    for _ in range(max_len):
        if len(hypo) == 0:
            # Empty sequence: create dummy tensor
            hypo_tensor = torch.zeros(1, 1, dtype=torch.long, device=device)
            hypo_len = torch.tensor([0], device=device)
        else:
            hypo_tensor = torch.tensor([hypo], dtype=torch.long, device=device)
            hypo_len = torch.tensor([len(hypo)], device=device)

        # Forward pass
        out = model(hypo_tensor, hypo_len)
        insert_logp = out["insert_logp"][0]  # [T+1, vocab_size]

        # Apply temperature
        if temperature != 1.0:
            insert_logp = insert_logp / temperature

        # Sample from the distribution
        # Flatten, sample, then unflatten
        num_positions = len(hypo) + 1
        logp_flat = insert_logp[:num_positions].reshape(
            -1
        )  # [num_positions * vocab_size]
        probs = F.softmax(logp_flat, dim=0)
        idx = torch.multinomial(probs, 1).item()

        # Decode position and token
        pos = int(idx // model.vocab_size)
        tok = int(idx % model.vocab_size)

        # Check for EOS
        if tok == EOS:
            break

        # Insert token
        hypo.insert(pos, tok)

    model.train()
    return hypo


# Test generation (with untrained model - will be random)
print("Generation test (untrained model - random output):")
generated = generate(model, max_len=30, temperature=1.0)
print(f"  Generated {len(generated)} tokens: '{decode(generated)}'")

Generation test (untrained model - random output):
  Generated 25 tokens: 'yCm,dbB3I-sw;RtNgFXcqthI'


In [65]:
# Create a fresh model and train it
model = InsertionTransformer(vocab_size=vocab_size)
model = model.to(DEVICE)
print(f"InsertionTransformer: {model.get_num_params() / 1e6:.2f}M parameters")
print(f"Training on {len(train_data):,} tokens\n")

train(model)

InsertionTransformer: 10.76M parameters
Training on 1,003,854 tokens

Step 00000 | Train loss: 4.2597 | Val loss: 4.2283 | Train acc: 2.25% | Val acc: 2.83%
  Sample: 'qbjMQjbPx...'



KeyboardInterrupt: 

In [None]:
# Generate some samples from the trained model
print("=== Generated Samples ===\n")
for i in range(5):
    sample = generate(model, max_len=100, temperature=0.8)
    print(f"Sample {i + 1}:")
    print(decode(sample))
    print("-" * 40)

=== Generated Samples ===

Sample 1:
 t the  the thee te he be
e thetheee te whe t t the th t e wathe
----------------------------------------
Sample 2:
 t hhe he th athe  thee the e the hat te t zte
hee t see thhheee
----------------------------------------
Sample 3:
 he thhe the te the the e se thhoe t the hee dhe hee  the tt the
----------------------------------------
Sample 4:
  t hee the the the be te t Ge wthhen tthe at the t the he  thhe
----------------------------------------
Sample 5:
e tos see be t the she be te the be t thathhe tht e t thee atoe 
----------------------------------------
