## Replication of Grokking experiments 
(checking the relation between attention logits)

# Datasets to test on
- WMT en-fr
- WMT en-gr
- WikiText-103

In [1]:
pip install -r requirements.txt

Collecting scikit-learn (from -r requirements.txt (line 9))
  Using cached scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl.metadata (11 kB)
Collecting scipy>=1.10.0 (from scikit-learn->-r requirements.txt (line 9))
  Using cached scipy-1.17.0-cp312-cp312-macosx_14_0_arm64.whl.metadata (62 kB)
Collecting threadpoolctl>=3.2.0 (from scikit-learn->-r requirements.txt (line 9))
  Using cached threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Using cached scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl (8.1 MB)
Using cached scipy-1.17.0-cp312-cp312-macosx_14_0_arm64.whl (20.1 MB)
Using cached threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Installing collected packages: threadpoolctl, scipy, scikit-learn
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3/3[0m [scikit-learn][0m [scikit-learn]
[1A[2KSuccessfully installed scikit-learn-1.8.0 scipy-1.17.0 threadpoolctl-3.6.0
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math
from torch import nn
from torch.utils.data import DataLoader, Subset
from torch.nn.utils.rnn import pad_sequence
from tokenizers import Tokenizer
from nltk.translate.bleu_score import corpus_bleu
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [3]:
from datasets import load_dataset

dataset = load_dataset('wmt14', 'fr-en')

train_data = dataset['train'].select(range(1000000))
test_data = dataset['test']
print(train_data[0])

def extract_text(dataset, src="en", tgt="fr"):
    for example in dataset:
        yield example["translation"][src]
        yield example["translation"][tgt]


  from .autonotebook import tqdm as notebook_tqdm


{'translation': {'en': 'Resumption of the session', 'fr': 'Reprise de la session'}}


In [4]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

trainer = BpeTrainer(
    vocab_size=32000,
    special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"]
)

tokenizer.train_from_iterator(
    extract_text(train_data),
    trainer=trainer
)







In [5]:
PAD_ID = tokenizer.token_to_id("[PAD]")
BOS_ID = tokenizer.token_to_id("[BOS]")
EOS_ID = tokenizer.token_to_id("[EOS]")

def tokenize_sentence(sentence, add_special_tokens=True):
    encoding = tokenizer.encode(sentence)
    token_ids = encoding.ids

    if add_special_tokens:
        token_ids = [BOS_ID] + token_ids + [EOS_ID]

    return token_ids

In [6]:
from sklearn.metrics import f1_score
from nltk.translate.bleu_score import corpus_bleu
import numpy as np

def decode_to_tokens(ids_list, tokenizer):
    """
    Converts a list of token IDs into a list of string tokens,
    skipping special tokens like [PAD], [BOS], [EOS].
    """
    # .decode() handles the skipping of special tokens automatically
    # .split() turns the sentence string into a list of tokens for BLEU
    decoded_text = tokenizer.decode(ids_list, skip_special_tokens=True)
    return decoded_text.split()

In [7]:
from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
    """
    batch: list of dicts from wmt14 
    Example item: {'translation': {'en': 'Hello', 'fr': 'Bonjour'}}
    """
    src_list = []
    tgt_list = []
    
    for item in batch:
        en_text = item['translation']['en']
        fr_text = item['translation']['fr']
        
        src_ids = tokenize_sentence(en_text)[:128] 
        tgt_ids = tokenize_sentence(fr_text)[:128]
        
        src_list.append(torch.tensor(src_ids, dtype=torch.long))
        tgt_list.append(torch.tensor(tgt_ids, dtype=torch.long))

    src_padded = pad_sequence(
        src_list,
        batch_first=True,
        padding_value=PAD_ID
    )

    tgt_padded = pad_sequence(
        tgt_list,
        batch_first=True,
        padding_value=PAD_ID
    )

    return src_padded, tgt_padded

In [8]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_data,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_data,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn
)

In [8]:
#def generate_compact_dataset(num_samples=10000):
#    data_tokens = torch.arange(1, 11) 
#    
#    all_inputs = []
#    all_targets = []
#
#    for _ in range(num_samples):
#        perm = torch.randperm(10)
#        sample_data = data_tokens[perm[:6]]
#        
#        is_relational = torch.rand(1) > 0.5
#        
#        if is_relational:
#            cmd = torch.tensor([12])
#            # Pick a key from the first 5 (so there is a neighbor at +1)
#            key_idx = torch.randint(0, 5, (1,)).item()
#            query = sample_data[key_idx].view(1)
#            target = sample_data[key_idx + 1]
#        else:
#            # POSITIONAL: Input[7] is an Index (1-6); Target is data at that index
#            cmd = torch.tensor([11])
#            idx_to_pull = torch.randint(0, 6, (1,)).item()
#            query = torch.tensor([idx_to_pull + 1])
#            target = sample_data[idx_to_pull]
#
#        full_input = torch.cat([sample_data, cmd, query])
#        
#        all_inputs.append(full_input)
#        all_targets.append(target)
#
#    return torch.stack(all_inputs), torch.stack(all_targets)
#
## Generate the 10,000 samples
#inputs, targets = generate_compact_dataset(10000)
#
#print(f"Dataset Shape: {inputs.shape}") # [10000, 8]
#print(f"Sample 0 (Input): {inputs[0].tolist()} -> Target: {targets[0].item()}")

In [9]:
def get_pe(seq_len, d_model):
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(1000000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)

In [10]:
class embeddings(nn.Module):
    def __init__(self, d, vocab_size=32000, max_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d)
        self.register_buffer('pe', get_pe(max_len, d))

    def forward(self, x):
        te = self.token_emb(x)
        return te + self.pe[:, :x.size(1), :]

In [11]:
#indices = torch.randperm(len(inputs))
#
#train_size = int(0.5*len(inputs))
#
#train_idx = indices[:train_size]
#test_idx = indices[train_size:]
#
#train_inputs, train_targets = inputs[train_idx], targets[train_idx]
#test_inputs,  test_targets  = inputs[test_idx],  targets[test_idx]

In [11]:
class AttentionModule(nn.Module):
    def __init__(self, d_k=512, dropout=0.1):
        super().__init__()
        self.query_v1 = nn.Linear(d_k, d_k, bias=False)
        self.key_v1 = nn.Linear(d_k, d_k, bias=False)
        self.value_v1 = nn.Linear(d_k, d_k, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.d_k = d_k

    def forward(self, x, mask=None):
        Q1 = self.query_v1(x)
        K1 = self.key_v1(x)
        V1 = self.value_v1(x)
        att1 = torch.matmul(Q1, K1.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            att1 = att1.masked_fill(mask == 0, -1e9)

        att_soft1 = torch.softmax(att1, dim=-1)
        att_soft1 = self.dropout(att_soft1)

        out1 = torch.matmul(att_soft1, V1)
        return out1

In [12]:
class ModelArchitecture(nn.Module):
    def __init__(self, vocab_size=32000, d_k=512, n_ff=2048, dropout=0.1):
        super().__init__()
        self.embedding = embeddings(512)
        self.attention = AttentionModule(d_k, dropout)
        self.norm1 = nn.LayerNorm(d_k)
        self.norm2 = nn.LayerNorm(d_k)
        self.dropout = nn.Dropout(dropout)
        self.mlp = nn.Sequential(
            nn.Linear(d_k, n_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(n_ff, d_k)
        )

        self.unembed = nn.Linear(d_k, vocab_size, bias=False)

    def forward(self, x, mask=None):
        x_emb = self.embedding(x) * math.sqrt(512)
        att_out = self.attention(x_emb, mask)
        x = self.norm1(x_emb + self.dropout(att_out))
        mlp_out = self.mlp(x)
        x = self.norm2(x + self.dropout(mlp_out))
        
        logits = self.unembed(x)
        return logits

In [13]:
epochs = 5
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = ModelArchitecture().to(DEVICE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=PAD_ID)

# Test without scheduled LR

In [14]:
#def noam_lambda(step, d_model=512, warmup=4000):
#    # The scheduler passes the current_step (starting at 0)
#    # Adding 1 prevents division by zero
#    step += 1 
#    return (d_model ** -0.5) * min(step ** -0.5, step * (warmup ** -1.5))
#
## Use 1.0 as the base LR so the lambda controls the absolute value
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.98), eps=1e-9, weight_decay = 0.01)
#scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)

In [19]:
model.to(DEVICE)
global_step = 0
log_interval = 1000
accumulation_steps = 4 

for epoch in range(epochs):
    model.train()
    optimizer.zero_grad() 
    
    for i, (src, tgt) in enumerate(train_loader):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        
        decoder_input  = tgt[:, :-1]  
        decoder_target = tgt[:, 1:] 
        logits = model(decoder_input) 

        # 1. Standard Loss Calculation
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            decoder_target.reshape(-1)
        )
        
        # Scale loss so gradients are averaged across the full effective batch
        loss_scaled = loss / accumulation_steps 
        loss_scaled.backward()

        # Train Accuracy (for the current mini-batch)
        with torch.no_grad():
            pred_tokens = logits.argmax(-1)
            mask = decoder_target != PAD_ID
            train_acc = ((pred_tokens == decoder_target) & mask).float().sum() / mask.sum()

        # 2. Only update weights and LOG every 'accumulation_steps'
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            # scheduler.step() # Uncomment if using Noam
            optimizer.zero_grad()
            
            global_step += 1
            
            # --- EVALUATION BLOCK: Now strictly inside the update check ---
            if global_step % log_interval == 0:
                model.eval()
                total_test_loss = 0
                all_preds_ids, all_trues_ids = [], []
                bleu_hypotheses, bleu_references = [], []
                
                with torch.no_grad():
                    for test_src, test_tgt in test_loader:
                        test_src, test_tgt = test_src.to(DEVICE), test_tgt.to(DEVICE)
                        dec_in_test, dec_tgt_test = test_tgt[:, :-1], test_tgt[:, 1:]
                        
                        test_logits = model(dec_in_test)
                        t_loss = criterion(test_logits.reshape(-1, test_logits.size(-1)), dec_tgt_test.reshape(-1))
                        total_test_loss += t_loss.item()

                        test_pred_ids = test_logits.argmax(-1)
                        test_mask = (dec_tgt_test != PAD_ID)
                        
                        all_preds_ids.extend(test_pred_ids[test_mask].cpu().numpy())
                        all_trues_ids.extend(dec_tgt_test[test_mask].cpu().numpy())

                        for b in range(test_pred_ids.size(0)):
                            hyp_tokens = decode_to_tokens(test_pred_ids[b].tolist(), tokenizer)
                            ref_tokens = decode_to_tokens(dec_tgt_test[b].tolist(), tokenizer)
                            bleu_hypotheses.append(hyp_tokens)
                            bleu_references.append([ref_tokens])

                # Metric Calculations
                avg_test_loss = total_test_loss / len(test_loader)
                test_acc = (np.array(all_preds_ids) == np.array(all_trues_ids)).mean()
                test_f1 = f1_score(all_trues_ids, all_preds_ids, average='weighted', zero_division=0)
                test_bleu = corpus_bleu(bleu_references, bleu_hypotheses)
                current_lr = optimizer.param_groups[0]['lr']
                
                output_str = (
                    f"| Step: {global_step:7d} | Epoch: {epoch:3d} | LR: {current_lr:.8f} |\n"
                    f"| Train Acc: {train_acc.item():.4f} | Train Loss: {loss.item():.4f} |\n"
                    f"| Test Acc:  {test_acc:.4f} | Test Loss:  {avg_test_loss:.4f} |\n"
                    f"| Test F1:   {test_f1:.4f} | Test BLEU:  {test_bleu:.4f} |\n"
                    f"{'-'*75}"
                )
                print(output_str)
                with open("step_results_logging.log", "a") as f:
                    f.write(output_str + "\n")
                
                model.train() # Switch back to training mode

| Step:    1000 | Epoch:   0 | LR: 0.00030000 |
| Train Acc: 0.6793 | Train Loss: 3.6288 |
| Test Acc:  0.5919 | Test Loss:  4.6960 |
| Test F1:   0.5121 | Test BLEU:  0.2429 |
---------------------------------------------------------------------------
| Step:    2000 | Epoch:   0 | LR: 0.00030000 |
| Train Acc: 0.6858 | Train Loss: 3.6420 |
| Test Acc:  0.5904 | Test Loss:  4.7003 |
| Test F1:   0.5265 | Test BLEU:  0.2416 |
---------------------------------------------------------------------------
| Step:    3000 | Epoch:   0 | LR: 0.00030000 |
| Train Acc: 0.6955 | Train Loss: 3.5681 |
| Test Acc:  0.5952 | Test Loss:  4.6223 |
| Test F1:   0.5411 | Test BLEU:  0.1208 |
---------------------------------------------------------------------------
| Step:    4000 | Epoch:   0 | LR: 0.00030000 |
| Train Acc: 0.7446 | Train Loss: 3.3325 |
| Test Acc:  0.5976 | Test Loss:  4.6008 |
| Test F1:   0.5446 | Test BLEU:  0.1308 |
----------------------------------------------------------------

In [None]:
torch.save(model.state_dict(), 'model_weights.pth')