# Build a tiny pipeline: token IDs -> embeddings -> mean pooling -> sentence vectors

In [1]:
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Config
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

EMBED_DIM = 32      # small on purpose (try 16/64 later)
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"

In [3]:
# Tiny toy corpus + vocab
TOY_SENTENCES = [
    "hello world",
    "hello there",
    "spam ham eggs",
    "ham and eggs",
    "representation learning is fun",
    "embeddings with mean pooling",
    "we love simple baselines",
]


In [4]:
def build_vocab(sentences: List[str],
                min_freq: int = 1) -> Dict[str, int]:
    """
    Build a tiny word-level vocab: {token -> id}.
    Includes PAD and UNK.
    """
    freq = {}
    for s in sentences:
        for w in s.strip().split():
            freq[w] = freq.get(w, 0) + 1
    # sort by frequency (desc), then token (asc) for determinism
    words = [w for w, c in sorted(freq.items(), key=lambda x: (-x[1], x[0])) if c >= min_freq]
    vocab = {PAD_TOKEN: 0, UNK_TOKEN: 1}
    for w in words:
        vocab[w] = len(vocab)
    return vocab

VOCAB = build_vocab(TOY_SENTENCES)
ID2TOK = {i: t for t, i in VOCAB.items()}
ID2TOK

{0: '<pad>',
 1: '<unk>',
 2: 'eggs',
 3: 'ham',
 4: 'hello',
 5: 'and',
 6: 'baselines',
 7: 'embeddings',
 8: 'fun',
 9: 'is',
 10: 'learning',
 11: 'love',
 12: 'mean',
 13: 'pooling',
 14: 'representation',
 15: 'simple',
 16: 'spam',
 17: 'there',
 18: 'we',
 19: 'with',
 20: 'world'}

In [5]:
# Tokenization helpers
def tokenize(sentence: str) -> List[str]:
    return sentence.strip().split()

def numericalize(tokens: List[str], vocab: Dict[str, int]) -> List[int]:
    return [vocab.get(t, vocab[UNK_TOKEN]) for t in tokens]

def pad_batch(batch_ids: List[List[int]], pad_id: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pad to max length in batch.
    Returns:
        ids: LongTensor [B, T]
        mask: BoolTensor [B, T] where True = valid token (not PAD)
    """
    max_len = max(len(x) for x in batch_ids) if batch_ids else 0
    padded = []
    mask = []
    for seq in batch_ids:
        pad_len = max_len - len(seq)
        padded.append(seq + [pad_id] * pad_len)
        mask.append([1] * len(seq) + [0] * pad_len)
    return torch.tensor(padded, dtype=torch.long), torch.tensor(mask, dtype=torch.bool)


In [6]:
# Embedding + Mean Pooling
class MeanPooler(nn.Module):
    """
    Mean-pool token embeddings over non-pad tokens.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        x:    [B, T, D] embeddings
        mask: [B, T]    True for real tokens
        returns:
            sent_emb: [B, D]
        """
        mask_f = mask.float().unsqueeze(-1)     # [B, T, 1]
        summed = (x * mask_f).sum(dim=1)        # [B, D]
        counts = mask_f.sum(dim=1).clamp(min=1) # [B, 1]
        return summed / counts

In [7]:
@dataclass
class EmbeddingModelConfig:
    vocab_size: int
    embed_dim: int
    pad_id: int

In [8]:
class TinyEmbeddingModel(nn.Module):
    """
    A minimal embedding -> mean pooling encoder.
    """
    def __init__(self, cfg: EmbeddingModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(cfg.vocab_size, cfg.embed_dim, padding_idx=cfg.pad_id)
        self.pool = MeanPooler()

    def forward(self, token_ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # token_ids: [B, T]
        emb = self.embedding(token_ids)         # [B, T, D]
        sent_vec = self.pool(emb, mask)         # [B, D]
        return sent_vec

In [9]:
# Cosine similarity helper
def cosine_sim_matrix(x: torch.Tensor) -> torch.Tensor:
    """
    x: [B, D]
    returns:
        cos_sim: [B, B] cosine similarity
    """
    x = F.normalize(x, p=2, dim=-1)
    return x @ x.transpose(0, 1)

In [10]:
def demo_mean_pooling(sentences: List[str]) -> None:
    print(f"Vocab size: {len(VOCAB)} | Embed dim: {EMBED_DIM}")
    print(f"Vocab tokens: {list(VOCAB.keys())}\n")

    # Build inputs
    tokenized = [tokenize(s) for s in sentences]
    ids = [numericalize(tok, VOCAB) for tok in tokenized]
    ids_tensor, mask = pad_batch(ids, pad_id=VOCAB[PAD_TOKEN])

    # Model
    cfg = EmbeddingModelConfig(vocab_size=len(VOCAB), embed_dim=EMBED_DIM, pad_id=VOCAB[PAD_TOKEN])
    model = TinyEmbeddingModel(cfg)

    # Forward
    with torch.no_grad():
        sent_vecs = model(ids_tensor, mask)   # [B, D]

    # Show shapes
    print(f"Input IDs shape:     {ids_tensor.shape}  (B,T)")
    print(f"Embeddings shape:    {[len(sentences), max(len(x) for x in ids)]} -> pooled to (B,D) = {sent_vecs.shape}")

    # Inspect a couple vectors
    for i, s in enumerate(sentences[:3]):
        print(f"\nSentence[{i}]: \"{s}\"")
        print(f"Vector (first 8 dims): {sent_vecs[i, :8].tolist()}")

    # Optional: cosine similarity between all sentence vectors
    cos = cosine_sim_matrix(sent_vecs)
    print("\nCosine similarity matrix (rounded to 3 dp):")
    with torch.no_grad():
        print(torch.round(cos * 1000) / 1000)

    # Map IDs back to tokens for the first batch item (debugging/learning)
    print("\nExample token reconstruction for sample[0]:")
    first_ids = ids_tensor[0].tolist()
    print([ID2TOK[i] for i in first_ids])

In [26]:
demo_mean_pooling(TOY_SENTENCES)

Vocab size: 21 | Embed dim: 32
Vocab tokens: ['<pad>', '<unk>', 'eggs', 'ham', 'hello', 'and', 'baselines', 'embeddings', 'fun', 'is', 'learning', 'love', 'mean', 'pooling', 'representation', 'simple', 'spam', 'there', 'we', 'with', 'world']

Input IDs shape:     torch.Size([7, 4])  (B,T)
Embeddings shape:    [7, 4] -> pooled to (B,D) = torch.Size([7, 32])

Sentence[0]: "hello world"
Vector (first 8 dims): [0.3262319564819336, 0.48671817779541016, 0.23869723081588745, -0.39600759744644165, -0.005489230155944824, 0.4366290867328644, -0.2679099440574646, 0.6758059859275818]

Sentence[1]: "hello there"
Vector (first 8 dims): [0.09677910804748535, 0.44325578212738037, -1.4011105298995972, -0.009056806564331055, -0.3794155716896057, 0.4231144189834595, 0.24345803260803223, 1.1096596717834473]

Sentence[2]: "spam ham eggs"
Vector (first 8 dims): [-0.8093934059143066, 0.5124570727348328, 0.7974427342414856, 0.2383490651845932, 0.8325245976448059, 0.19940675795078278, -0.23762141168117523, -0.

In [40]:
def compare_one_hot_projection(sentences):
    """
    - Create a one-hot representation for each token (size = vocab_size).
    - Learn a linear projection W: [vocab_size -> EMBED_DIM].
    - Show that one-hot @ W is mathematically equivalent to nn.Embedding (without padding handling).
    Hints:
      - Build an identity matrix I[vocab_size, vocab_size].
      - Pick rows with token IDs -> gives one-hot rows. [B,T,V]
      - W=model.embedding.weight[V,D]
    """
    tokenized = [tokenize(s) for s in sentences]
    ids = [numericalize(tok, VOCAB) for tok in tokenized]
    ids_tensor, mask = pad_batch(ids, pad_id=VOCAB[PAD_TOKEN])
    V, D, pad_id = len(VOCAB), EMBED_DIM, VOCAB[PAD_TOKEN]

    # 1) Build an Embedding model and grab its weight matrix W_embed (V,D)
    cfg = EmbeddingModelConfig(vocab_size=V, embed_dim=D, pad_id=pad_id)
    emb_model = TinyEmbeddingModel(cfg)
    W_embed = emb_model.embedding.weight.detach().clone()  # [V, D]
    # detach(): stops autograd from tracking the matmul with one-hot. clone(): makes a snapshot copy.

    # 2) Make one-hot for the same ids: (B,T,V)
    identity_matrix = torch.eye(len(VOCAB))
    one_hot = torch.stack([identity_matrix[ids_tensor[s],:] for s in range(len(ids_tensor))])
    # one_hot = F.one_hot(ids_tensor, num_classes=len(VOCAB)).float()

    # Project with W as a linear map: (B,T,V) @ (V,D) -> (B,T,D)
    proj_vecs = torch.matmul(one_hot, W_embed)  # [B,T,D]
    proj_vecs = proj_vecs.masked_fill(ids_tensor.eq(pad_id).unsqueeze(-1), 0.0)

    # Compare with nn.Embedding lookup:
    with torch.no_grad():
        lookup_vecs = emb_model.embedding(ids_tensor)  # [B,T,D]



    same = torch.allclose(proj_vecs, lookup_vecs, atol=1e-6)
    print(f"One-hot @ W equals Embedding lookup? {same}")
    


In [None]:
compare_one_hot_projection(TOY_SENTENCES)

One-hot @ W equals Embedding lookup? True


In [36]:
NEW_TOY_SENTENCES = [
    "eggs and ham",
    "representation learning is not fun",
    "we like simple baselines",
]

In [37]:
def get_padded_tensor(sentences):
    tokenized = [tokenize(s) for s in sentences]
    ids = [numericalize(tok, VOCAB) for tok in tokenized]
    ids_tensor, mask = pad_batch(ids, pad_id=VOCAB[PAD_TOKEN])
    return ids_tensor, mask

In [43]:
def sentence_similarity_probe(sentences):
    """
    - Add 2–3 new sentences that are paraphrases/near-duplicates of existing ones.
    - Re-run embeddings + mean pooling and look at cosine similarities for those pairs.
    - Reflect: mean pooling is order-invariant; what does that imply?
    """
    ids_tensor, mask = get_padded_tensor(sentences)

    # Model
    cfg = EmbeddingModelConfig(vocab_size=len(VOCAB), embed_dim=EMBED_DIM, pad_id=VOCAB[PAD_TOKEN])
    model = TinyEmbeddingModel(cfg)

    # Forward
    with torch.no_grad():
        sent_vecs = model(ids_tensor, mask)   # [B, D]

    # Get cosine similarity between all sentence vectors
    cos = cosine_sim_matrix(sent_vecs)
    # Set diagnal to 0.0
    cos.fill_diagonal_(0.)
    farthest, closest = torch.argmin(cos, dim=1), torch.argmax(cos, dim=1)
    for i in range(len(sentences)):
        print(f"Sentence: {sentences[i]} | Closest match: {sentences[closest[i]]} | Farthest: {sentences[farthest[i]]}")
    
    return sent_vecs, cos
    

In [44]:
sent_vecs, cos = sentence_similarity_probe(TOY_SENTENCES+NEW_TOY_SENTENCES)

Sentence: hello world | Closest match: hello there | Farthest: we like simple baselines
Sentence: hello there | Closest match: hello world | Farthest: we like simple baselines
Sentence: spam ham eggs | Closest match: ham and eggs | Farthest: we like simple baselines
Sentence: ham and eggs | Closest match: eggs and ham | Farthest: embeddings with mean pooling
Sentence: representation learning is fun | Closest match: representation learning is not fun | Farthest: we like simple baselines
Sentence: embeddings with mean pooling | Closest match: we love simple baselines | Farthest: representation learning is not fun
Sentence: we love simple baselines | Closest match: we like simple baselines | Farthest: hello world
Sentence: eggs and ham | Closest match: ham and eggs | Farthest: embeddings with mean pooling
Sentence: representation learning is not fun | Closest match: representation learning is fun | Farthest: embeddings with mean pooling
Sentence: we like simple baselines | Closest match: 

In [42]:
from sentence_transformers import SentenceTransformer
st = SentenceTransformer('sentence-transformers/stsb-bert-large')

In [47]:
st_vecs = st.encode(sentences=TOY_SENTENCES+NEW_TOY_SENTENCES, batch_size=64, convert_to_tensor=True,
    show_progress_bar=True,
    normalize_embeddings=True)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [48]:
type(sent_vecs), type(st_vecs)

(torch.Tensor, torch.Tensor)

In [91]:
def compare_cosine_similarity(sent_vecs):
    # Get cosine similarity between all sentence vectors
    cos = cosine_sim_matrix(sent_vecs)
    # Set diagnal to 0.0
    cos.fill_diagonal_(0.)
    farthest, closest = torch.argmin(cos, dim=1), torch.argmax(cos, dim=1)
    
    return farthest, closest
    

def compare_two_embeddings(sentences, sent_vecs1, sent_vecs2):
    farthest1, closest1 = compare_cosine_similarity(sent_vecs1)
    farthest2, closest2 = compare_cosine_similarity(sent_vecs2)

    for i in range(len(sentences)):
        print(f"""Sentence: {sentences[i]}
              Embedding 1 Closest : {sentences[closest1[i]]} | Farthest: {sentences[farthest1[i]]}, 
              Embedding 2 Closest : {sentences[closest2[i]]} | Farthest: {sentences[farthest2[i]]}""")
    

In [None]:
compare_two_embeddings(TOY_SENTENCES+NEW_TOY_SENTENCES, sent_vecs, st_vecs)

Sentence: hello world
              Embedding 1 Closest : hello there | Farthest: we like simple baselines, 
              Embedding 2 Closest : hello there | Farthest: representation learning is not fun
Sentence: hello there
              Embedding 1 Closest : hello world | Farthest: we like simple baselines, 
              Embedding 2 Closest : hello world | Farthest: representation learning is not fun
Sentence: spam ham eggs
              Embedding 1 Closest : ham and eggs | Farthest: we like simple baselines, 
              Embedding 2 Closest : eggs and ham | Farthest: representation learning is fun
Sentence: ham and eggs
              Embedding 1 Closest : eggs and ham | Farthest: embeddings with mean pooling, 
              Embedding 2 Closest : eggs and ham | Farthest: representation learning is not fun
Sentence: representation learning is fun
              Embedding 1 Closest : representation learning is not fun | Farthest: we like simple baselines, 
              Embedding 2 

Observation: <br>
Lexical overlap dominates. <br>
Mean pooling dilutes “not”. <br>
SBERT variants are optimized for semantic relatedness/paraphrase rather than logical polarity. <br>
Negation forms (“not”, “n’t”) are frequent but often get weak, context-agnostic vectors; models underrepresent negation scope.

In [88]:
class WeightedMeanPooler(nn.Module):
    """
    IDF-style Weighted-mean-pool token embeddings over non-pad tokens.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, mask: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        x:    [B, T, D] embeddings
        mask: [B, T]    True for real tokens
        weight: [B, T]  IDF weight for each token
        returns:
            sent_emb: [B, D]
        """
        mask_f = mask.float().unsqueeze(-1)     # [B, T, 1]
        weight = weight.float().unsqueeze(-1)
        summed = (x * mask_f * weight).sum(dim=1)        # [B, D]
        counts = (mask_f * weight).sum(dim=1).clamp(min=1) # [B, 1]
        return summed / counts

In [93]:
TOY_SENTENCES2 = [
    "the cat sat on the mat",
    "the dog sat on the mat",
    "the cat chased the mouse",
    "the dog chased the ball",
    "representation learning is fun",
    "representation learning is not fun",
    "deep learning with embeddings"
]

In [103]:
VOCAB = build_vocab(TOY_SENTENCES2)
ID2TOK = {i: t for t, i in VOCAB.items()}

In [104]:
ids_tensor, mask = get_padded_tensor(TOY_SENTENCES2)
ids_tensor

tensor([[ 2,  4, 12, 10,  2,  9],
        [ 2,  6, 12, 10,  2,  9],
        [ 2,  4,  5,  2, 16,  0],
        [ 2,  6,  5,  2, 13,  0],
        [11,  3,  8,  7,  0,  0],
        [11,  3,  8, 17,  7,  0],
        [14,  3, 18, 15,  0,  0]])

In [105]:
output, counts = torch.unique(ids_tensor, return_counts=True)
output, counts

(tensor([ 0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18]),
 tensor([7, 8, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1]))

In [106]:
max_id = torch.max(ids_tensor).item()
idf_vec = torch.zeros(max_id + 1)
for k, v in zip(output, counts):
    idf_vec[k] = 1/v
idf_vec


tensor([0.1429, 0.0000, 0.1250, 0.3333, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])

In [107]:
weights = idf_vec[ids_tensor]
weights

tensor([[0.1250, 0.5000, 0.5000, 0.5000, 0.1250, 0.5000],
        [0.1250, 0.5000, 0.5000, 0.5000, 0.1250, 0.5000],
        [0.1250, 0.5000, 0.5000, 0.1250, 1.0000, 0.1429],
        [0.1250, 0.5000, 0.5000, 0.1250, 1.0000, 0.1429],
        [0.5000, 0.3333, 0.5000, 0.5000, 0.1429, 0.1429],
        [0.5000, 0.3333, 0.5000, 1.0000, 0.5000, 0.1429],
        [1.0000, 0.3333, 1.0000, 1.0000, 0.1429, 0.1429]])

In [89]:
class TinyWEmbeddingModel(nn.Module):
    """
    A minimal embedding -> weighted mean pooling encoder.
    """
    def __init__(self, cfg: EmbeddingModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(cfg.vocab_size, cfg.embed_dim, padding_idx=cfg.pad_id)
        self.pool = WeightedMeanPooler()

    def forward(self, token_ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # token_ids: [B, T]
        output, counts = torch.unique(token_ids, return_counts=True)
        max_id = torch.max(token_ids).item()
        idf_vec = torch.zeros(max_id + 1)
        for k, v in zip(output, counts):
            idf_vec[k] = 1/v
        weights = idf_vec[token_ids]
        emb = self.embedding(token_ids)         # [B, T, D]
        sent_vec = self.pool(emb, mask, weights)         # [B, D]
        return sent_vec

In [86]:
def compare_weighted_mean_pooling(sentences):
    """
    - Implement simple IDF-style weights (fake IDF from this tiny corpus is fine).
    - Use weights in pooling: sum(e_i * w_i) / sum(w_i).
    - Compare similarities vs. plain mean pooling.
    """
    ids_tensor, mask = get_padded_tensor(sentences)

    # Model
    cfg = EmbeddingModelConfig(vocab_size=len(VOCAB), embed_dim=EMBED_DIM, pad_id=VOCAB[PAD_TOKEN])
    model = TinyEmbeddingModel(cfg)
    model_w = TinyWEmbeddingModel(cfg)

    # Forward
    with torch.no_grad():
        sent_vecs = model(ids_tensor, mask)   # [B, D]
        sent_vecs_w = model_w(ids_tensor, mask)

    compare_two_embeddings(sentences, sent_vecs, sent_vecs_w)

In [108]:
compare_weighted_mean_pooling(TOY_SENTENCES2)

Sentence: the cat sat on the mat
              Embedding 1 Closest : the dog sat on the mat | Farthest: representation learning is not fun, 
              Embedding 2 Closest : the dog sat on the mat | Farthest: deep learning with embeddings
Sentence: the dog sat on the mat
              Embedding 1 Closest : the cat sat on the mat | Farthest: representation learning is not fun, 
              Embedding 2 Closest : the cat sat on the mat | Farthest: deep learning with embeddings
Sentence: the cat chased the mouse
              Embedding 1 Closest : the cat sat on the mat | Farthest: representation learning is not fun, 
              Embedding 2 Closest : the dog chased the ball | Farthest: the cat chased the mouse
Sentence: the dog chased the ball
              Embedding 1 Closest : the cat chased the mouse | Farthest: representation learning is not fun, 
              Embedding 2 Closest : the cat chased the mouse | Farthest: the dog chased the ball
Sentence: representation learning i

In [109]:
# Class 0 = greetings
greeting_sentences = [
    "hello there",
    "good morning",
    "hi friend",
    "good evening",
    "hey buddy",
    "how are you"
]

# Class 1 = food
food_sentences = [
    "i love pizza",
    "pasta is tasty",
    "eating an apple",
    "the sandwich is good",
    "fresh salad",
    "i like sushi"
]

TOY_SUPERVISED = greeting_sentences + food_sentences
TOY_LABELS = [0]*len(greeting_sentences) + [1]*len(food_sentences)

In [111]:
VOCAB = build_vocab(TOY_SUPERVISED)
ID2TOK = {i: t for t, i in VOCAB.items()}
ID2TOK

{0: '<pad>',
 1: '<unk>',
 2: 'good',
 3: 'i',
 4: 'is',
 5: 'an',
 6: 'apple',
 7: 'are',
 8: 'buddy',
 9: 'eating',
 10: 'evening',
 11: 'fresh',
 12: 'friend',
 13: 'hello',
 14: 'hey',
 15: 'hi',
 16: 'how',
 17: 'like',
 18: 'love',
 19: 'morning',
 20: 'pasta',
 21: 'pizza',
 22: 'salad',
 23: 'sandwich',
 24: 'sushi',
 25: 'tasty',
 26: 'the',
 27: 'there',
 28: 'you'}

In [113]:
torch.tensor(TOY_LABELS)

tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])

In [121]:
from transformers import AutoTokenizer

name = "sentence-transformers/all-MiniLM-L6-v2"  # example
st1 = SentenceTransformer(name)
tokenizer = AutoTokenizer.from_pretrained(name)
pad_id = tokenizer.pad_token_id
pad_id

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

0

In [130]:
word_emb = st1[0].auto_model.get_input_embeddings()  # nn.Embedding
pretrained_weights = word_emb.weight.detach().clone()  # [V, D]
pretrained_weights.shape[1]

384

In [150]:
class TinyEmbeddingMLP(nn.Module):
    def __init__(self, vocab_size, embed_dim, pad_id, trainable_embed=True, pretrained_weights=None):
        super().__init__()
        cfg = EmbeddingModelConfig(vocab_size=vocab_size, embed_dim=embed_dim, pad_id=pad_id)
        self.embed_model = TinyEmbeddingModel(cfg)
        if pretrained_weights is not None:
            self.embed_model.embedding = nn.Embedding.from_pretrained(
                embeddings=pretrained_weights,  # [V, D]
                freeze=True,                    # freeze the matrix
                padding_idx=pad_id              # PAD outputs zeros & gets no grads
            )
            embed_dim = pretrained_weights.shape[1]
        self.embed_model.embedding.weight.requires_grad = trainable_embed
        self.ln = nn.LayerNorm(embed_dim)  # learnable scale/shift
        self.final = nn.Linear(embed_dim, 1)

    def forward(self, x, mask=None):
        # x: [B,T] of token ids (padded with pad_id)
        if mask is None:
            # build mask on the fly if not provided
            pad_id = self.embed_model.embedding.padding_idx
            mask = x.ne(pad_id)  # True where token is real
        emb = self.embed_model(x, mask)
        emb = self.ln(emb)
        y = self.final(emb).squeeze(-1)       # [B]
        return y

In [157]:
def train_only(X, mask, y, epochs, model, lr=1e-3, weight_decay=0.0, patience=5, min_delta=1e-3):
    """
    Full-batch training for the tiny dataset.
    X:   [B,T] token ids (padded)
    mask:[B,T] bool mask (True for real tokens)
    y:   [B]   labels in {0,1}
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr, weight_decay=weight_decay)
    history = {"train_loss": [], "train_acc": []}
    best_loss = float("inf")

    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    model.to(device)
    X = X.to(device)
    mask = mask.to(device)
    y = y.to(device).float()
    wait = 0

    for e in range(epochs):
        model.train()
        optimizer.zero_grad()
        logits = model(X, mask)
        loss = criterion(logits, y) # labels: [B] in {0,1}
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            probs = torch.sigmoid(logits)
            preds = (probs >= 0.5).float()
            acc = (preds==y).float().mean().item()
            avg_loss = loss.item()
        
        # early stopping
        if best_loss == float("inf"):
            best_loss = avg_loss
        else:
            if abs(avg_loss-best_loss)<min_delta:
                wait += 1
            else:
                wait = 0
            best_loss = min(best_loss, avg_loss)

            if wait >= patience:
                print(f"Early stop at epoch {e+1} | best_loss={best_loss:.4f}")
                wait = 0
                break
        if wait>0 or e % 5 ==0 or e==epochs-1:
            print(f"Epoch {e}: train loss = {avg_loss:>6f}, train acc = {acc:>4f}")
        history["train_loss"].append(avg_loss)
        history["train_acc"].append(acc)

    return model, history


In [141]:
pad_id = VOCAB[PAD_TOKEN]
ids_tensor, mask = get_padded_tensor(TOY_SUPERVISED)
labels = torch.tensor(TOY_LABELS, dtype=torch.float32)

In [None]:
ids_tensor

tensor([[13, 27,  0,  0],
        [ 2, 19,  0,  0],
        [15, 12,  0,  0],
        [ 2, 10,  0,  0],
        [14,  8,  0,  0],
        [16,  7, 28,  0],
        [ 3, 18, 21,  0],
        [20,  4, 25,  0],
        [ 9,  5,  6,  0],
        [26, 23,  4,  2],
        [11, 22,  0,  0],
        [ 3, 17, 24,  0]])

In [158]:
model_f = TinyEmbeddingMLP(vocab_size=len(VOCAB), embed_dim=EMBED_DIM, pad_id=pad_id, trainable_embed=False, pretrained_weights=pretrained_weights)
model_ff, _ = train_only(ids_tensor, mask, labels, 50, model_f, lr=1e-4, weight_decay=0.01, patience=5, min_delta=1e-3)

Epoch 0: train loss = 0.693727, train acc = 0.500000
Epoch 5: train loss = 0.682599, train acc = 0.416667
Epoch 10: train loss = 0.674563, train acc = 0.500000
Epoch 15: train loss = 0.668308, train acc = 0.833333
Epoch 20: train loss = 0.662627, train acc = 0.833333
Epoch 25: train loss = 0.656915, train acc = 0.833333
Epoch 30: train loss = 0.651115, train acc = 0.833333
Epoch 35: train loss = 0.645354, train acc = 0.833333
Epoch 40: train loss = 0.639691, train acc = 0.833333
Epoch 45: train loss = 0.634095, train acc = 0.833333
Epoch 49: train loss = 0.629640, train acc = 0.833333


In [147]:
model_ff.to(torch.device("cpu"))
ids_tensorX = ids_tensor.to(torch.device("cpu"))
mask = mask.to(torch.device("cpu"))

In [149]:
with torch.no_grad():
    # 1) Are sentence vectors non-constant?
    sent = model_ff.embed_model(ids_tensor, mask)    # [B, D]
    print("sent mean abs:", sent.abs().mean().item())
    print("sent var mean:", sent.var(dim=0).mean().item())
    print("unique rows:", torch.unique(sent, dim=0).size(0), "/", sent.size(0))

    # 2) OOV rate (if you have <unk>)
    unk_id = VOCAB.get("<unk>")
    if unk_id is not None:
        oov_frac = (ids_tensor == unk_id).float().mean().item()
        print("OOV fraction:", round(oov_frac, 3))

    # 3) PAD sanity
    pad_id = VOCAB.get("<pad>")
    print("contains PAD?", bool((ids_tensor == pad_id).any()))
    # verify mask logic: True must mean "real token"
    print("mask true frac:", mask.float().mean().item())

    # 4) Head has grads?
    for n, p in model_ff.named_parameters():
        if "final" in n:
            print(n, "requires_grad=", p.requires_grad)

sent mean abs: 0.01762998104095459
sent var mean: 3.860611468553543e-05
unique rows: 12 / 12
OOV fraction: 0.0
contains PAD? True
mask true frac: 0.6458333134651184
final.weight requires_grad= True
final.bias requires_grad= True


In [170]:
def freeze_vs_train(sentences, labels, pretrained_weights, epochs=50, lr=1e-3, weight_decay=0.0, patience=5, min_delta=1e-3):
    """
    - Create a tiny supervised objective (e.g., classify {greeting vs. food}).
    - Compare:
        (A) Frozen embeddings + linear head
        (B) Trainable embeddings end-to-end
    - Observe which converges faster / achieves higher training accuracy.
    """
    pad_id = VOCAB[PAD_TOKEN]
    ids_tensor, mask = get_padded_tensor(sentences)
    labels = torch.tensor(labels, dtype=torch.float32)

    # freeze
    print("Train Frozen embeddings + linear head")
    model_f = TinyEmbeddingMLP(vocab_size=len(VOCAB), embed_dim=EMBED_DIM, pad_id=pad_id, trainable_embed=False, pretrained_weights=pretrained_weights)
    model_f, _ = train_only(ids_tensor, mask, labels, epochs, model_f, lr=lr, weight_decay=weight_decay, patience=patience, min_delta=min_delta)

    # trainable
    print("Trainable embeddings end-to-end")
    model_t = TinyEmbeddingMLP(vocab_size=len(VOCAB), embed_dim=EMBED_DIM, pad_id=pad_id, trainable_embed=True)
    model_t, _ = train_only(ids_tensor, mask, labels, epochs, model_t, lr=lr, weight_decay=weight_decay, patience=patience, min_delta=min_delta)

    return model_f, model_t
    

In [177]:
trained_f, trained_t = freeze_vs_train(TOY_SUPERVISED, TOY_LABELS, pretrained_weights, epochs=30)

Train Frozen embeddings + linear head
Epoch 0: train loss = 0.721469, train acc = 0.416667
Epoch 5: train loss = 0.660002, train acc = 0.750000
Epoch 10: train loss = 0.603310, train acc = 0.916667
Epoch 15: train loss = 0.550458, train acc = 0.916667
Epoch 20: train loss = 0.501928, train acc = 1.000000
Epoch 25: train loss = 0.457929, train acc = 1.000000
Epoch 29: train loss = 0.425927, train acc = 1.000000
Trainable embeddings end-to-end
Epoch 0: train loss = 0.620102, train acc = 0.500000
Epoch 5: train loss = 0.591766, train acc = 0.833333
Epoch 10: train loss = 0.564369, train acc = 0.833333
Epoch 15: train loss = 0.537933, train acc = 0.833333
Epoch 20: train loss = 0.512431, train acc = 0.916667
Epoch 25: train loss = 0.487797, train acc = 0.916667
Epoch 29: train loss = 0.468660, train acc = 0.916667


In [178]:
trained_t.to(torch.device("cpu"))
trained_f.to(torch.device("cpu"))

TinyEmbeddingMLP(
  (embed_model): TinyEmbeddingModel(
    (embedding): Embedding(30522, 384, padding_idx=0)
    (pool): MeanPooler()
  )
  (ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
  (final): Linear(in_features=384, out_features=1, bias=True)
)

In [179]:
with torch.no_grad():
    logits = trained_t(ids_tensor, mask)
    probs = torch.sigmoid(logits)
    acc = ((probs >= 0.5)==labels).float().mean().item()
    brier = (probs - labels.float()).pow(2).mean().item()   # calibration/CE proxy
    margin = (probs - 0.5).abs().mean().item()         # confidence

    logits2 = trained_f(ids_tensor, mask)
    probs2 = torch.sigmoid(logits2)
    acc2 = ((probs2 >= 0.5)==labels).float().mean().item()
    brier2 = (probs2 - labels.float()).pow(2).mean().item()   # calibration/CE proxy
    margin2 = (probs2 - 0.5).abs().mean().item()         # confidence
print(f"Trainable: acc={acc:.3f} | brier={brier:.4f} | avg_margin={margin:.3f}")
print(f"Frozen Embed: acc={acc2:.3f} | brier={brier2:.4f} | avg_margin={margin2:.3f}")

Trainable: acc=0.917 | brier=0.1405 | avg_margin=0.135
Frozen Embed: acc=1.000 | brier=0.1181 | avg_margin=0.160


In [180]:
probs

tensor([0.2518, 0.4355, 0.2279, 0.3937, 0.4797, 0.3974, 0.6241, 0.7083, 0.6723,
        0.6700, 0.6255, 0.4947])

In [181]:
probs2

tensor([0.3298, 0.2663, 0.3687, 0.3251, 0.3341, 0.4291, 0.6884, 0.6979, 0.6976,
        0.5441, 0.6839, 0.6646])

### Try 2-stage warm start training

In [185]:
def train_warm_start(X, mask, y, model, epochs):
    criterion = nn.BCEWithLogitsLoss()
    history = {"train_loss": [], "train_acc": []}
    best_loss = float("inf")

    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    model.to(device)
    X = X.to(device)
    mask = mask.to(device)
    y = y.to(device).float()
    wait = 0


    # Stage 1: freeze embeddings, train head only
    for p in model.embed_model.embedding.parameters(): p.requires_grad = False
    optimizer = torch.optim.Adam(model.final.parameters(), lr=5e-3)

    for e in range(10):
        model.train()
        optimizer.zero_grad()
        logits = model(X, mask)
        loss = criterion(logits, y) # labels: [B] in {0,1}
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            probs = torch.sigmoid(logits)
            preds = (probs >= 0.5).float()
            acc = (preds==y).float().mean().item()
            avg_loss = loss.item()
        
        print(f"Epoch {e}: train loss = {avg_loss:>6f}, train acc = {acc:>4f}")
        history["train_loss"].append(avg_loss)
        history["train_acc"].append(acc)

    # Stage 2: unfreeze with smaller LR for emb, larger for head
    for p in model.embed_model.embedding.parameters(): p.requires_grad = True
    optimizer = torch.optim.Adam([
        {"params": model.embed_model.embedding.parameters(), "lr": 5e-4},
        {"params": list(model.final.parameters()) + list(model.ln.parameters()), "lr": 5e-3},
    ], weight_decay=1e-4)

    for e in range(10, epochs):
        optimizer.zero_grad()
        logits = model(X, mask)
        loss = criterion(logits, y) # labels: [B] in {0,1}
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            probs = torch.sigmoid(logits)
            preds = (probs >= 0.5).float()
            acc = (preds==y).float().mean().item()
            avg_loss = loss.item()
        
        # early stopping
        if best_loss == float("inf"):
            best_loss = avg_loss
        else:
            if abs(avg_loss-best_loss)<1e-3:
                wait += 1
            else:
                wait = 0
            best_loss = min(best_loss, avg_loss)

            if wait >= 5:
                print(f"Early stop at epoch {e+1} | best_loss={best_loss:.4f}")
                wait = 0
                break
        if wait>0 or e % 5 ==0 or e==epochs-1:
            print(f"Epoch {e}: train loss = {avg_loss:>6f}, train acc = {acc:>4f}")
        history["train_loss"].append(avg_loss)
        history["train_acc"].append(acc)
    
    return model

In [186]:
model_t_warm = TinyEmbeddingMLP(vocab_size=len(VOCAB), embed_dim=EMBED_DIM, pad_id=pad_id, trainable_embed=True)
model_t_warm = train_warm_start(ids_tensor, mask, labels, model_t_warm, 50)

Epoch 0: train loss = 0.736225, train acc = 0.333333
Epoch 1: train loss = 0.713766, train acc = 0.416667
Epoch 2: train loss = 0.691947, train acc = 0.416667
Epoch 3: train loss = 0.670756, train acc = 0.666667
Epoch 4: train loss = 0.650195, train acc = 0.666667
Epoch 5: train loss = 0.630271, train acc = 0.666667
Epoch 6: train loss = 0.610988, train acc = 0.666667
Epoch 7: train loss = 0.592346, train acc = 0.666667
Epoch 8: train loss = 0.574341, train acc = 0.666667
Epoch 9: train loss = 0.556966, train acc = 0.666667
Epoch 10: train loss = 0.540208, train acc = 0.750000
Epoch 15: train loss = 0.451410, train acc = 0.916667
Epoch 20: train loss = 0.376267, train acc = 0.916667
Epoch 25: train loss = 0.312388, train acc = 1.000000
Epoch 30: train loss = 0.258114, train acc = 1.000000
Epoch 35: train loss = 0.212386, train acc = 1.000000
Epoch 40: train loss = 0.174256, train acc = 1.000000
Epoch 45: train loss = 0.142873, train acc = 1.000000
Epoch 49: train loss = 0.121967, train

In [187]:
model_t_warm.to(torch.device("cpu"))
with torch.no_grad():
    logits = model_t_warm(ids_tensor, mask)
    probs = torch.sigmoid(logits)
    acc = ((probs >= 0.5)==labels).float().mean().item()
    brier = (probs - labels.float()).pow(2).mean().item()   # calibration/CE proxy
    margin = (probs - 0.5).abs().mean().item()         # confidence

print(f"Trainable: acc={acc:.3f} | brier={brier:.4f} | avg_margin={margin:.3f}")

Trainable: acc=1.000 | brier=0.0142 | avg_margin=0.391
