# Implement scaled dot-product attention for attention-pooling

In [1]:
import math, random
from dataclasses import dataclass
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Setup
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

EMBED_DIM = 32
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"

# Tiny corpus from Day 8 + a few negations
TOY_SENTENCES = [
    "hello there",
    "good morning",
    "hi friend",
    "good evening",
    "hey buddy",
    "how are you",
    "i love pizza",
    "pasta is tasty",
    "eating an apple",
    "the sandwich is good",
    "fresh salad",
    "i like sushi",
    "representation learning is fun",
    "representation learning is not fun",
]

In [3]:
def build_vocab(sentences: List[str], min_freq: int = 1) -> Dict[str, int]:
    freq = {}
    for s in sentences:
        for w in s.split():
            freq[w] = freq.get(w, 0) + 1
    words = [w for w, c in sorted(freq.items(), key=lambda x: (-x[1], x[0])) if c >= min_freq]
    vocab = {PAD_TOKEN: 0, UNK_TOKEN: 1}
    for w in words:
        vocab[w] = len(vocab)
    return vocab

def tokenize(s: str) -> List[str]: return s.strip().split()

def numericalize(tokens: List[str], vocab: Dict[str, int]) -> List[int]:
    return [vocab.get(t, vocab[UNK_TOKEN]) for t in tokens]

def pad_batch(batch_ids: List[List[int]], pad_id: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
    T = max(len(x) for x in batch_ids)
    padded, mask = [], []
    for ids in batch_ids:
        pad_len = T - len(ids)
        padded.append(ids + [pad_id] * pad_len)
        mask.append([1] * len(ids) + [0] * pad_len)
    X = torch.tensor(padded, dtype=torch.long)
    M = torch.tensor(mask, dtype=torch.bool)
    return X, M

In [4]:
VOCAB = build_vocab(TOY_SENTENCES)
ID2TOK = {i: t for t, i in VOCAB.items()}

In [5]:
# Mean Pool Encoder
class MeanPooler(nn.Module):
    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        maskf = mask.float().unsqueeze(-1)   # [B,T,1]
        summed = (x * maskf).sum(1)          # [B,D]
        count = maskf.sum(1).clamp(min=1.0)  # [B,1]
        return summed / count

@dataclass
class EmbeddingModelConfig:
    vocab_size: int
    embed_dim: int
    pad_id: int

class TinyEmbeddingModel(nn.Module):
    def __init__(self, cfg: EmbeddingModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(cfg.vocab_size, cfg.embed_dim, padding_idx=cfg.pad_id)
        self.pool = MeanPooler()
    def forward(self, token_ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        emb = self.embedding(token_ids)      # [B,T,D]
        return self.pool(emb, mask)          # [B,D]

## Scaled Dot-Product Attention Pooler

In [None]:
#    Uses one learnable global query vector q.
class ScaledDotAttentionPooler(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.q = nn.Parameter(torch.randn(dim))   # global query [D]
        self.Wk = nn.Linear(dim, dim, bias=False) # maps token embeddings into a key space attuned to what the query will score
        self.Wv = nn.Linear(dim, dim, bias=False) # maps token embeddings into a value space that will be mixed according to those scores.
        # These are standard linear projections in attention; 
        # they let the model learn “what makes a token match the query” (via keys) 
        # and “what information to collect” (via values).
        self.scale = math.sqrt(dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        x:    [B,T,D] token embeddings
        mask: [B,T]   True for real tokens
        returns: [B,D] attended sentence vectors
        """
        K = self.Wk(x)                             # [B,T,D]
        V = self.Wv(x)                             # [B,T,D]
        # Broadcast query to batch: Q = [B,1,D]
        Q = self.q.expand(x.size(0), 1, -1)        # [B,1,D]
        # Scores: [B,1,T]
        scores = torch.matmul(Q, K.transpose(1, 2)) / self.scale # [B,1,T]
        # Mask PADs by setting to -inf before softmax
        mask_ = mask.unsqueeze(1)                  # [B,T] -> [B,1,T]
        scores = scores.masked_fill(~mask_, float("-inf")) # ~mask_ is True exactly at PAD positions
        attn = torch.softmax(scores, dim=-1)       # [B,1,T], gives probability 0 to PAD positions
        # Weighted sum over tokens
        out = torch.matmul(attn, V).squeeze(1)     # [B,1,D] -> [B,D]
        return out, attn.squeeze(1)                # return weights for inspection

Contrast with self-attention in a Transformer encoder: there, Q, K, V are all projections of the tokens themselves. Here, Q is a single learned vector, while K and V come from tokens—so it’s attention pooling, not full self-attention.

In [7]:
# Cosine Head vs Linear Head
class LinearHead(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.fc = nn.Linear(dim, 1)
    def forward(self, x):                          # x: [B,D]
        return self.fc(x).squeeze(-1)

class CosineHead(nn.Module):
    def __init__(self, dim: int, scale: float = 20.0, bias: bool = True):
        super().__init__()
        self.w = nn.Parameter(torch.randn(1, dim))  # single-class weight (binary)
        self.b = nn.Parameter(torch.zeros(1)) if bias else None
        self.scale = nn.Parameter(torch.tensor(float(scale)))  # learnable temperature
    def forward(self, x):                           # x: [B,D]
        x = F.normalize(x, dim=-1)
        w = F.normalize(self.w, dim=-1)
        logit = self.scale * (x @ w.t()).squeeze(-1)
        if self.b is not None:
            logit = logit + self.b
        return logit

In [29]:
# Sentence encoder with selectable pooling + head
class TinySentenceClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, pad_id,
                 pool: str = "attention", head: str = "cosine",
                 trainable_embed: bool = True):
        super().__init__()
        cfg = EmbeddingModelConfig(vocab_size=vocab_size, embed_dim=embed_dim, pad_id=pad_id)
        self.embedding = nn.Embedding(cfg.vocab_size, cfg.embed_dim, padding_idx=cfg.pad_id)
        self.embedding.weight.requires_grad=trainable_embed
        self.attn = True if pool =='attention' else False
        if pool=="mean":
            self.pool = MeanPooler()
        elif pool=="attention":
            self.pool = ScaledDotAttentionPooler(dim=embed_dim)
        
        self.norm = nn.LayerNorm(embed_dim)
        if head == "linear":
            self.head = LinearHead(embed_dim)
        elif head=="cosine":
            self.head = CosineHead(embed_dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        embed = self.embedding(x)
        if self.attn:
            embed, attn_w = self.pool(embed, mask)
        else:
            embed = self.pool(embed, mask)
            attn_w = None
        embed = self.norm(embed)
        y = self.head(embed).squeeze(-1)
        return y, attn_w


In [11]:
def make_batch(sentences: List[str], vocab: Dict[str,int], pad_id: int):
    tok = [tokenize(s) for s in sentences]
    ids = [numericalize(t, vocab) for t in tok]
    return pad_batch(ids, pad_id)

In [16]:
len(VOCAB), EMBED_DIM, VOCAB[PAD_TOKEN]

(33, 32, 0)

In [23]:
# Inspect attention weights on a sample
pad_id = VOCAB[PAD_TOKEN]
X, M = make_batch(["representation learning is not fun"], VOCAB, pad_id)
probe = TinySentenceClassifier(vocab_size=len(VOCAB), embed_dim=EMBED_DIM,
                                pad_id=pad_id, pool="attention", head="cosine",
                                trainable_embed=True).to(DEVICE)
probe.eval()

TinySentenceClassifier(
  (embedding): Embedding(33, 32, padding_idx=0)
  (pool): ScaledDotAttentionPooler(
    (Wk): Linear(in_features=32, out_features=32, bias=False)
    (Wv): Linear(in_features=32, out_features=32, bias=False)
  )
  (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (head): CosineHead()
)

In [30]:
with torch.no_grad():
    logits, attn = probe(X.to(DEVICE), M.to(DEVICE))
    print("Attention weights over tokens:")
    toks = [ID2TOK[i] for i in X[0].tolist()]
    print(logits)
    for t, w in zip(toks, attn[0].cpu().tolist()):
        print(f"{t:>12s} : {w:.3f}")

Attention weights over tokens:
tensor([-1.1906], device='mps:0')
representation : 0.203
    learning : 0.413
          is : 0.194
         not : 0.077
         fun : 0.113


In [None]:
# Tiny supervised setup (greetings vs food)
greeting = [
    "hello there","good morning","hi friend",
    "good evening","hey buddy","how are you"
]
food = [
    "i love pizza","pasta is tasty","eating an apple",
    "the sandwich is good","fresh salad","i like sushi"
]
SUP = greeting + food
LAB = [0]*len(greeting) + [1]*len(food)

X_all, M_all = make_batch(SUP, VOCAB, VOCAB[PAD_TOKEN])
y_all = torch.tensor(LAB, dtype=torch.float)

In [53]:
# Train loop (full-batch for simplicity)
def evaluate(model, X, M, y, thr=0.5):
    # loss, acc, brier, margin
    model.eval()
    logits, _ = model(X, M)
    loss = nn.BCEWithLogitsLoss()(logits, y)
    probs = torch.sigmoid(logits)
    acc = ((probs>=thr).float()==y).float().mean().item()
    brier = (probs - y).pow(2).mean().item()
    margin = (probs - 0.5).abs().mean().item()
    return loss.item(), acc, brier, margin


def train_full_batch(model, X, M, y, epochs=60, lr_head=3e-3, lr_emb=3e-4, warm_start_epochs=8, freeze_embeddings=False):
    X, M, y = X.to(DEVICE), M.to(DEVICE), y.to(DEVICE)
    model.to(DEVICE)
    embed_parameters = model.embedding.parameters()

    # Stage 1: train head only
    for p in embed_parameters: p.requires_grad = False
    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr_head, weight_decay=1e-4)
    for e in range(warm_start_epochs):
        model.train()
        opt.zero_grad()
        logits, _ = model(X, M)
        loss = nn.BCEWithLogitsLoss()(logits, y)
        loss.backward()
        opt.step()
        L, A, B, G = evaluate(model, X, M, y)
        print(f"[Warm start epoch {e}] loss={L:.4f} acc={A:.3f} brier={B:.4f} margin={G:.3f}")
    
    # Stage 2: train all params
    if freeze_embeddings:
        opt = torch.optim.Adam([p for p in model.parameters() if p not in embed_parameters], lr_head, weight_decay=1e-4)
    else:
        for p in embed_parameters: p.requires_grad = True
        opt = torch.optim.Adam([{"params": [p for p in embed_parameters], "lr": lr_emb},
                                {"params": [p for p in model.parameters() if p not in embed_parameters], "lr": lr_head}]
                            , weight_decay=1e-4)
    
    for e in range(warm_start_epochs, epochs):
        model.train()
        opt.zero_grad()
        logits, _ = model(X, M)
        loss = nn.BCEWithLogitsLoss()(logits, y)
        loss.backward()
        opt.step()
        if e %5==0 or e==epochs-1:
            L, A, B, G = evaluate(model, X, M, y)
            print(f"[Train epoch {e}] loss={L:.4f} acc={A:.3f} brier={B:.4f} margin={G:.3f}")

    return model

## Try a cosine head vs a linear head for tiny classification

In [41]:
X_all, y_all

(tensor([[16, 31,  0,  0],
         [ 3, 22,  0,  0],
         [18, 15,  0,  0],
         [ 3, 13,  0,  0],
         [17, 11,  0,  0],
         [19, 10, 32,  0],
         [ 5, 21, 25,  0],
         [24,  2, 29,  0],
         [12,  8,  9,  0],
         [30, 27,  2,  3],
         [14, 26,  0,  0],
         [ 5, 20, 28,  0]]),
 tensor([0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.]))

In [38]:
# Train on the tiny supervised set
model_cos = TinySentenceClassifier(vocab_size=len(VOCAB), embed_dim=EMBED_DIM,
                            pad_id=pad_id, pool="attention", head="cosine",
                            trainable_embed=True)
model_cos = train_full_batch(model_cos, X_all, M_all, y_all)

[Warm start epoch 0] loss=0.2989 acc=0.833 brier=0.1019 margin=0.412
[Warm start epoch 1] loss=0.0834 acc=1.000 brier=0.0199 margin=0.431
[Warm start epoch 2] loss=0.0272 acc=1.000 brier=0.0017 margin=0.474
[Warm start epoch 3] loss=0.0208 acc=1.000 brier=0.0012 margin=0.480
[Warm start epoch 4] loss=0.0247 acc=1.000 brier=0.0021 margin=0.476
[Warm start epoch 5] loss=0.0316 acc=1.000 brier=0.0040 margin=0.471
[Warm start epoch 6] loss=0.0369 acc=1.000 brier=0.0057 margin=0.466
[Warm start epoch 7] loss=0.0368 acc=1.000 brier=0.0057 margin=0.466
[Train epoch 10] loss=0.0031 acc=1.000 brier=0.0000 margin=0.497
[Train epoch 15] loss=0.0016 acc=1.000 brier=0.0000 margin=0.498
[Train epoch 20] loss=0.0004 acc=1.000 brier=0.0000 margin=0.500
[Train epoch 25] loss=0.0001 acc=1.000 brier=0.0000 margin=0.500
[Train epoch 30] loss=0.0001 acc=1.000 brier=0.0000 margin=0.500
[Train epoch 35] loss=0.0001 acc=1.000 brier=0.0000 margin=0.500
[Train epoch 40] loss=0.0001 acc=1.000 brier=0.0000 margin

In [42]:
model_lin = TinySentenceClassifier(vocab_size=len(VOCAB), embed_dim=EMBED_DIM,
                            pad_id=pad_id, pool="attention", head="linear",
                            trainable_embed=True)
model_lin = train_full_batch(model_lin, X_all, M_all, y_all)

[Warm start epoch 0] loss=0.5954 acc=0.667 brier=0.2042 margin=0.155
[Warm start epoch 1] loss=0.5091 acc=0.750 brier=0.1656 margin=0.149
[Warm start epoch 2] loss=0.4319 acc=0.917 brier=0.1306 margin=0.166
[Warm start epoch 3] loss=0.3618 acc=1.000 brier=0.0998 margin=0.206
[Warm start epoch 4] loss=0.2998 acc=1.000 brier=0.0746 margin=0.249
[Warm start epoch 5] loss=0.2477 acc=1.000 brier=0.0558 margin=0.288
[Warm start epoch 6] loss=0.2062 acc=1.000 brier=0.0425 margin=0.321
[Warm start epoch 7] loss=0.1733 acc=1.000 brier=0.0329 margin=0.347
[Train epoch 10] loss=0.0875 acc=1.000 brier=0.0089 margin=0.417
[Train epoch 15] loss=0.0322 acc=1.000 brier=0.0012 margin=0.468
[Train epoch 20] loss=0.0149 acc=1.000 brier=0.0002 margin=0.485
[Train epoch 25] loss=0.0084 acc=1.000 brier=0.0001 margin=0.492
[Train epoch 30] loss=0.0055 acc=1.000 brier=0.0000 margin=0.494
[Train epoch 35] loss=0.0040 acc=1.000 brier=0.0000 margin=0.496
[Train epoch 40] loss=0.0031 acc=1.000 brier=0.0000 margin

In [43]:
model_cos0 = TinySentenceClassifier(vocab_size=len(VOCAB), embed_dim=EMBED_DIM,
                            pad_id=pad_id, pool="mean", head="cosine",
                            trainable_embed=True)
model_cos0 = train_full_batch(model_cos0, X_all, M_all, y_all)

[Warm start epoch 0] loss=1.9397 acc=0.750 brier=0.2839 margin=0.380
[Warm start epoch 1] loss=1.9388 acc=0.750 brier=0.2835 margin=0.380
[Warm start epoch 2] loss=1.9387 acc=0.750 brier=0.2830 margin=0.381
[Warm start epoch 3] loss=1.9390 acc=0.750 brier=0.2825 margin=0.381
[Warm start epoch 4] loss=1.9394 acc=0.750 brier=0.2820 margin=0.382
[Warm start epoch 5] loss=1.9399 acc=0.750 brier=0.2815 margin=0.382
[Warm start epoch 6] loss=1.9406 acc=0.750 brier=0.2810 margin=0.383
[Warm start epoch 7] loss=1.9412 acc=0.750 brier=0.2805 margin=0.384
[Train epoch 10] loss=1.9279 acc=0.750 brier=0.2806 margin=0.384
[Train epoch 15] loss=1.9155 acc=0.750 brier=0.2788 margin=0.386
[Train epoch 20] loss=1.9051 acc=0.750 brier=0.2772 margin=0.388
[Train epoch 25] loss=1.8953 acc=0.750 brier=0.2759 margin=0.389
[Train epoch 30] loss=1.8860 acc=0.750 brier=0.2747 margin=0.391
[Train epoch 35] loss=1.8766 acc=0.750 brier=0.2738 margin=0.393
[Train epoch 40] loss=1.8666 acc=0.750 brier=0.2730 margin

In [44]:
model_lin0 = TinySentenceClassifier(vocab_size=len(VOCAB), embed_dim=EMBED_DIM,
                            pad_id=pad_id, pool="mean", head="linear",
                            trainable_embed=True)
model_lin0 = train_full_batch(model_lin0, X_all, M_all, y_all)

[Warm start epoch 0] loss=0.7442 acc=0.667 brier=0.2725 margin=0.148
[Warm start epoch 1] loss=0.7353 acc=0.667 brier=0.2685 margin=0.149
[Warm start epoch 2] loss=0.7267 acc=0.667 brier=0.2645 margin=0.150
[Warm start epoch 3] loss=0.7182 acc=0.667 brier=0.2606 margin=0.151
[Warm start epoch 4] loss=0.7099 acc=0.667 brier=0.2568 margin=0.152
[Warm start epoch 5] loss=0.7017 acc=0.667 brier=0.2530 margin=0.153
[Warm start epoch 6] loss=0.6937 acc=0.667 brier=0.2493 margin=0.154
[Warm start epoch 7] loss=0.6859 acc=0.667 brier=0.2456 margin=0.155
[Train epoch 10] loss=0.6621 acc=0.667 brier=0.2345 margin=0.157
[Train epoch 15] loss=0.6260 acc=0.667 brier=0.2177 margin=0.160
[Train epoch 20] loss=0.5926 acc=0.667 brier=0.2023 margin=0.162
[Train epoch 25] loss=0.5616 acc=0.750 brier=0.1882 margin=0.164
[Train epoch 30] loss=0.5328 acc=0.750 brier=0.1751 margin=0.167
[Train epoch 35] loss=0.5060 acc=0.750 brier=0.1629 margin=0.170
[Train epoch 40] loss=0.4811 acc=0.833 brier=0.1516 margin

In [None]:
model_cos_no_embed = TinySentenceClassifier(vocab_size=len(VOCAB), embed_dim=EMBED_DIM,
                            pad_id=pad_id, pool="attention", head="cosine",
                            trainable_embed=False)


('embedding.weight',
 Parameter containing:
 tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3877,  1.5380,  1.3613,  ..., -0.3128, -0.9121, -0.5456],
         [ 1.4923, -0.6918,  2.1975,  ..., -1.5608, -0.9082, -0.7317],
         ...,
         [-1.3483, -0.9353, -0.5343,  ..., -0.4552,  0.6070, -0.0176],
         [-0.1786, -0.1068,  0.5567,  ..., -0.7250, -1.7054, -2.2150],
         [ 0.0138,  0.8519,  1.1709,  ...,  0.6909,  2.5133,  1.1747]]))

In [54]:
model_cos_no_embed = train_full_batch(model_cos_no_embed, X_all, M_all, y_all, freeze_embeddings=True)

[Warm start epoch 0] loss=0.7461 acc=0.750 brier=0.2464 margin=0.315
[Warm start epoch 1] loss=0.3952 acc=0.750 brier=0.1375 margin=0.326
[Warm start epoch 2] loss=0.2006 acc=1.000 brier=0.0508 margin=0.334
[Warm start epoch 3] loss=0.1113 acc=1.000 brier=0.0190 margin=0.400
[Warm start epoch 4] loss=0.0683 acc=1.000 brier=0.0085 margin=0.436
[Warm start epoch 5] loss=0.0442 acc=1.000 brier=0.0042 margin=0.458
[Warm start epoch 6] loss=0.0292 acc=1.000 brier=0.0021 margin=0.472
[Warm start epoch 7] loss=0.0197 acc=1.000 brier=0.0010 margin=0.481
[Train epoch 10] loss=0.0022 acc=1.000 brier=0.0000 margin=0.498
[Train epoch 15] loss=0.0003 acc=1.000 brier=0.0000 margin=0.500
[Train epoch 20] loss=0.0001 acc=1.000 brier=0.0000 margin=0.500
[Train epoch 25] loss=0.0000 acc=1.000 brier=0.0000 margin=0.500
[Train epoch 30] loss=0.0000 acc=1.000 brier=0.0000 margin=0.500
[Train epoch 35] loss=0.0000 acc=1.000 brier=0.0000 margin=0.500
[Train epoch 40] loss=0.0000 acc=1.000 brier=0.0000 margin