In [None]:
# toy_attention_pointer.py
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Pick the compute device: Apple-silicon GPU (MPS) when usable, otherwise CPU.
# `torch.backends.mps.is_available()` is the documented availability check;
# `torch.mps.is_available()` is not present in every PyTorch release, so using
# it here could make this cell fail with AttributeError on a fresh kernel.
_mps_backend = getattr(torch.backends, "mps", None)
device = "mps" if _mps_backend is not None and _mps_backend.is_available() else "cpu"

# Seed both RNGs used in this notebook so that Restart & Run All is reproducible.
torch.manual_seed(0)
random.seed(0)

#
# -----------------------------
# Data: "copy at index" task
# -----------------------------
def make_batch(batch_size, seq_len, vocab_tokens, device):
    """
    Sample a random batch for the "copy at index" task.

    Every row is laid out as [x0 x1 ... x_{L-1} Qk]: `seq_len` content tokens
    drawn uniformly from 0..vocab_tokens-1, followed by one query token with
    id `vocab_tokens + k`, meaning "return the content token at position k".

    Returns:
      inp: [B, L+1] token ids (content tokens followed by the query token)
      tgt: [B]      the content token x_k that each row asks for
      k:   [B]      the queried position of each row
    """
    # Content is drawn before the positions; keeping this order means a seeded
    # run produces the same batches.
    content = torch.randint(0, vocab_tokens, (batch_size, seq_len), device=device)
    positions = torch.randint(0, seq_len, (batch_size,), device=device)

    query_col = (positions + vocab_tokens).unsqueeze(1)      # [B, 1]
    sequence = torch.cat([content, query_col], dim=1)        # [B, L+1]

    # For each row, pick the content token sitting at its queried position.
    target = content.gather(1, positions.unsqueeze(1)).squeeze(1)  # [B]
    return sequence, target, positions

def encode_example(x, k, V, L, device):
    """
    Turn one hand-written example into tensors the model can consume.

    x: Python list holding the L content tokens
    k: position being asked for
    V: number of content tokens; the query token id is V + k
    L: expected length of x (checked with an assert)

    Returns:
      inp: [1, L+1] content tokens followed by the query token
      tgt: [1]      the expected answer x[k]
    """
    assert len(x) == L
    query_token = V + k
    row = x + [query_token]
    inp = torch.tensor([row], device=device)      # [1, L+1]
    tgt = torch.tensor([x[k]], device=device)     # [1]
    return inp, tgt


In [None]:
# -----------------------------
# Model: 1-head self-attention
# -----------------------------
class OneHeadPointer(nn.Module):
    """
    Minimal single-head self-attention model for the "copy at index" task.

    Token and learned positional embeddings are summed, passed through one
    unmasked attention head, and the attention output at the last (query)
    position is classified over the content-token vocabulary.
    """

    def __init__(self, vocab_tokens, seq_len, d_model=32):
        super().__init__()
        self.V, self.L = vocab_tokens, seq_len
        self.T = seq_len + 1  # content positions plus the trailing query slot

        # Parameters are created in the order emb, pos, Wq, Wk, Wv, out so that
        # a seeded initialisation is reproducible.

        # One table shared by content tokens (0..V-1) and query tokens Q0..Q_{L-1}.
        self.emb = nn.Embedding(vocab_tokens + seq_len, d_model)

        # Learned positional embeddings, small random init.
        self.pos = nn.Parameter(0.02 * torch.randn(self.T, d_model))

        # Bias-free query / key / value projections of the single head.
        self.Wq, self.Wk, self.Wv = (
            nn.Linear(d_model, d_model, bias=False) for _ in range(3)
        )

        # Read-out applied only to the query position; predicts a content token.
        self.out = nn.Linear(d_model, vocab_tokens)

    def forward(self, input):
        """
        input: [B, T] token ids

        returns:
          logits: [B, V]     scores over content tokens, read from the last position
          attn:   [B, T, T]  attention weights (each row sums to 1 over the keys)
        """
        _batch, _steps = input.shape  # unpacking also rejects inputs that are not 2-D

        hidden = self.emb(input) + self.pos.unsqueeze(0)  # [B, T, D]
        queries, keys, values = self.Wq(hidden), self.Wk(hidden), self.Wv(hidden)

        # Scaled dot-product attention; no mask, every position sees every position.
        scale = math.sqrt(hidden.size(-1))
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale  # [B, T, T]
        attn = scores.softmax(dim=-1)
        mixed = torch.matmul(attn, values)  # [B, T, D]

        # Only the final (query) position is used for the prediction.
        logits = self.out(mixed[:, -1, :])  # [B, V]
        return logits, attn



In [None]:
# --- task size ---
vocab_tokens, seq_len = 12, 8   # content tokens 0..11; positions 0..7 with query tokens Q0..Q7

# --- model ---
d_model = 48

# --- optimisation ---
batch_size, steps = 256, 200
lr = 3e-3


In [None]:
# Draw a single example and print it, to show what one training row looks like.
fixed_inp, fixed_tgt, fixed_k = make_batch(1, seq_len, vocab_tokens, device)

preview_tokens = fixed_inp[0].tolist()   # content tokens followed by the query token
preview_pos = fixed_k.item()             # position the query token asks for

print(f"input tokens: {preview_tokens} (query token {preview_tokens[-1]} asks for position {preview_pos})")
print(f"target token: {fixed_tgt.item()} (token at position {preview_pos})")

In [None]:
def visualise_attention(query_attn, fixed_inp_cpu, vocab_tokens, seq_len):
    """
    Plot the attention row of the query position as a one-row heat map.

    query_attn:    1-D tensor of attention weights, one per input position
    fixed_inp_cpu: token-id tensor; row 0 supplies the x-axis labels
    vocab_tokens:  number of content tokens, used to turn a query id into Qk
    seq_len:       number of content positions (the plot has seq_len + 1 columns)
    """
    n_cols = seq_len + 1

    # Content positions are labelled "x<i>:<token>"; the query slot is "Q<k>".
    labels = [
        f"x{i}:{tok}" if i < seq_len else f"Q{tok - vocab_tokens}"
        for i, tok in enumerate(fixed_inp_cpu[0].tolist())
    ]

    plt.figure(figsize=(8, 2.2))
    plt.imshow(query_attn.unsqueeze(0), aspect="auto")  # [1, T] image
    plt.title("Attention weights from query position")
    plt.xticks(range(n_cols), labels, rotation=45, ha="right")
    plt.yticks([0], ["query"])
    plt.colorbar()
    plt.tight_layout()
    plt.show()


In [None]:
# -----------------------------
# Train + visualize attention
# -----------------------------
model = OneHeadPointer(vocab_tokens, seq_len, d_model=d_model).to(device)
opt = torch.optim.Adam(model.parameters(), lr=lr)

# One example, drawn once, whose attention pattern is tracked during training.
fixed_inp, fixed_tgt, fixed_k = make_batch(1, seq_len, vocab_tokens, device)
fixed_inp_cpu = fixed_inp.detach().cpu()
fk, ftrue = fixed_k.item(), fixed_tgt.item()  # constant for the whole run

for step in range(1, steps + 1):
    # --- one optimisation step on a fresh random batch ---
    inp, tgt, _ = make_batch(batch_size, seq_len, vocab_tokens, device)
    logits, _ = model(inp)
    loss = F.cross_entropy(logits, tgt)

    opt.zero_grad()
    loss.backward()
    opt.step()

    # --- report only on the first step and on every 20th step ---
    if step != 1 and step % 20 != 0:
        continue

    with torch.no_grad():
        # Batch accuracy uses the logits computed before this step's update.
        pred = logits.argmax(dim=-1)
        acc = (pred == tgt).float().mean().item()

        # Re-evaluate the tracked example with the updated weights.
        flogits, fattn = model(fixed_inp)
        fpred = flogits.argmax(dim=-1).item()

    print(f"step {step:4d} | loss {loss.item():.4f} | acc {acc*100:5.1f}% "
          f"| fixed query k={fk} true={ftrue} pred={fpred}")

    # Plot only during the first 100 steps.
    if step < 100:
        visualise_attention(fattn[0, -1].detach().cpu(), fixed_inp_cpu, vocab_tokens, seq_len)


In [None]:
# Probe the trained model with one hand-written sequence, asking for each
# position in turn and plotting where the query position attends.
x = [3, 7, 1, 11, 0, 0, 9, 2]  # length L=8; encode_example asserts len(x) == seq_len

for k in range(seq_len):
    fixed_inp, fixed_tgt = encode_example(x, k, vocab_tokens, seq_len, device)

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        probe_attn = model(fixed_inp)[1]  # [1, T, T]

    # Label the plot with the sequence that was actually fed to the model.
    # The previous code passed `fixed_inp_cpu` from the training cell, so the
    # x-axis showed a different example's tokens and query.
    visualise_attention(
        probe_attn[0, -1].detach().cpu(),
        fixed_inp.detach().cpu(),
        vocab_tokens,
        seq_len,
    )


In [None]:
# 4x4 boolean mask that is True strictly above the diagonal,
# i.e. wherever the column index is greater than the row index.
idx = torch.arange(4)
mask = idx.unsqueeze(0) > idx.unsqueeze(1)
mask