## Replication of Grokking experiments 
(checking the relation between attention logits)

In [1]:
pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math
from torch import nn
from torch.utils.data import DataLoader, Subset
from torch.nn.utils.rnn import pad_sequence
from tokenizers import Tokenizer
from nltk.translate.bleu_score import corpus_bleu
# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

In [3]:
from datasets import load_dataset

# Load the WMT14 French-English configuration.
dataset = load_dataset('wmt14', 'fr-en')

# Train on the first 10,000 sentence pairs only; evaluate on the full test split.
first_10k = range(10000)
train_data = dataset['train'].select(first_10k)
test_data = dataset['test']

# Show one raw example to confirm the record layout.
print(train_data[0])

def extract_text(dataset, src="en", tgt="fr"):
    """Yield every sentence of a translation dataset, one string at a time.

    For each example the ``src`` sentence is yielded first, immediately
    followed by the ``tgt`` sentence, so both languages end up in one stream.

    Args:
        dataset: iterable of records shaped like
            ``{"translation": {<lang>: <sentence>, ...}}``.
        src: language key yielded first for each example.
        tgt: language key yielded second for each example.

    Yields:
        str: the sentences, alternating ``src`` / ``tgt``.
    """
    for example in dataset:
        pair = example["translation"]
        for lang in (src, tgt):
            yield pair[lang]


  from .autonotebook import tqdm as notebook_tqdm


{'translation': {'en': 'Resumption of the session', 'fr': 'Reprise de la session'}}


In [4]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

# One shared BPE tokenizer for both languages; unknown pieces map to [UNK]
# and text is pre-split with the Whitespace pre-tokenizer.
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

# Trainer with a 37,000-entry vocabulary budget and the four special tokens.
trainer = BpeTrainer(vocab_size=37000, special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"])

# Fit the vocabulary on the English and French sentences of the training subset.
tokenizer.train_from_iterator(extract_text(train_data), trainer=trainer)







In [5]:
# Look up the integer ids of the padding / begin / end special tokens.
PAD_ID, BOS_ID, EOS_ID = (
    tokenizer.token_to_id(token) for token in ("[PAD]", "[BOS]", "[EOS]")
)

def tokenize_sentence(sentence, add_special_tokens=True):
    """Encode a sentence into a list of token ids with the shared tokenizer.

    Args:
        sentence: text to encode.
        add_special_tokens: when true, wrap the ids as
            ``[BOS_ID] + ids + [EOS_ID]``.

    Returns:
        list[int]: the token ids, optionally wrapped in BOS/EOS.
    """
    ids = tokenizer.encode(sentence).ids
    return [BOS_ID, *ids, EOS_ID] if add_special_tokens else ids

In [6]:
from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
    """Turn a list of WMT14 records into padded source/target id tensors.

    Each item looks like ``{'translation': {'en': 'Hello', 'fr': 'Bonjour'}}``.
    English is the source, French the target. Every sentence is tokenized
    with BOS/EOS and then cut to at most 128 ids; because the cut happens
    after the special tokens are added, a sentence longer than that loses
    its trailing EOS.

    Args:
        batch: list of dataset records as described above.

    Returns:
        tuple[Tensor, Tensor]: ``(src_padded, tgt_padded)``, both int64 and
        batch-first, right-padded with ``PAD_ID`` to the longest sequence
        in the batch.
    """
    def _encode(text):
        # Tokenize, truncate to 128 ids, and wrap as a LongTensor.
        return torch.tensor(tokenize_sentence(text)[:128], dtype=torch.long)

    src_list = [_encode(item['translation']['en']) for item in batch]
    tgt_list = [_encode(item['translation']['fr']) for item in batch]

    src_padded = pad_sequence(src_list, batch_first=True, padding_value=PAD_ID)
    tgt_padded = pad_sequence(tgt_list, batch_first=True, padding_value=PAD_ID)
    return src_padded, tgt_padded

In [7]:
from torch.utils.data import DataLoader

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_data,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn
)

# Evaluation batches keep a fixed order: the reported test metrics are means
# over per-batch values, so reshuffling would change them from epoch to epoch
# even for an unchanged model.
test_loader = DataLoader(
    test_data,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn
)

In [8]:
#def generate_compact_dataset(num_samples=10000):
#    data_tokens = torch.arange(1, 11) 
#    
#    all_inputs = []
#    all_targets = []
#
#    for _ in range(num_samples):
#        perm = torch.randperm(10)
#        sample_data = data_tokens[perm[:6]]
#        
#        is_relational = torch.rand(1) > 0.5
#        
#        if is_relational:
#            cmd = torch.tensor([12])
#            # Pick a key from the first 5 (so there is a neighbor at +1)
#            key_idx = torch.randint(0, 5, (1,)).item()
#            query = sample_data[key_idx].view(1)
#            target = sample_data[key_idx + 1]
#        else:
#            # POSITIONAL: Input[7] is an Index (1-6); Target is data at that index
#            cmd = torch.tensor([11])
#            idx_to_pull = torch.randint(0, 6, (1,)).item()
#            query = torch.tensor([idx_to_pull + 1])
#            target = sample_data[idx_to_pull]
#
#        full_input = torch.cat([sample_data, cmd, query])
#        
#        all_inputs.append(full_input)
#        all_targets.append(target)
#
#    return torch.stack(all_inputs), torch.stack(all_targets)
#
## Generate the 10,000 samples
#inputs, targets = generate_compact_dataset(10000)
#
#print(f"Dataset Shape: {inputs.shape}") # [10000, 8]
#print(f"Sample 0 (Input): {inputs[0].tolist()} -> Target: {targets[0].item()}")

In [9]:
def get_pe(seq_len, d_model):
    """Build a sinusoidal positional-encoding table.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)`` with ``w_i = exp(-2i * ln(10000) / d_model)``.

    Args:
        seq_len: number of positions (rows) to generate.
        d_model: feature dimension; may be even or odd.

    Returns:
        Tensor of shape ``(1, seq_len, d_model)``.
    """
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    angles = position * div_term
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer odd column than even columns, so
    # drop the surplus frequency instead of failing on a shape mismatch.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pe.unsqueeze(0)

In [10]:
class embeddings(nn.Module):
    """Token embedding plus fixed sinusoidal positional encoding.

    Args:
        d: embedding dimension.
        vocab_size: number of rows in the token embedding table.
        max_len: number of positions precomputed in the positional table.
    """

    def __init__(self, d, vocab_size=37000, max_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d)
        # Stored as a buffer: it follows the module across devices and is
        # not a trainable parameter.
        positional_table = get_pe(max_len, d)
        self.register_buffer('pe', positional_table)

    def forward(self, x):
        """Embed token ids ``x`` and add the encoding of each position.

        The sequence axis is dimension 1 of ``x``; the positional table is
        sliced to that length before being added.
        """
        seq_len = x.size(1)
        positions = self.pe[:, :seq_len, :]
        return self.token_emb(x) + positions

In [11]:
#indices = torch.randperm(len(inputs))
#
#train_size = int(0.5*len(inputs))
#
#train_idx = indices[:train_size]
#test_idx = indices[train_size:]
#
#train_inputs, train_targets = inputs[train_idx], targets[train_idx]
#test_inputs,  test_targets  = inputs[test_idx],  targets[test_idx]

In [12]:
class AttentionModule(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values come from three square ``(d_k, d_k)`` weight
    matrices (no bias), each initialised from a standard normal.

    NOTE(review): no causal or padding mask is applied, so every position
    attends to every other position, including later ones and padding.

    Args:
        d_k: feature dimension of the input and of Q/K/V.
    """

    def __init__(self, d_k = 512):
        super().__init__()
        self.query_v1 = nn.Parameter(torch.randn(d_k,d_k))
        self.key_v1 =  nn.Parameter(torch.randn(d_k,d_k))
        self.value_v1 = nn.Parameter(torch.randn(d_k,d_k))

    def forward(self, x):
        """Apply self-attention to ``x`` (last dimension must be ``d_k``).

        Returns a tensor with the same shape as ``x``.
        """
        Q1 = x @ self.query_v1
        K1 = x @ self.key_v1
        V1 = x @ self.value_v1

        # Scale by the actual key dimension rather than a fixed 512, so the
        # logits are normalised correctly for any d_k.
        att1 = Q1 @ K1.transpose(-2, -1) / math.sqrt(K1.size(-1))

        att_soft1 = torch.softmax(att1, dim = -1)

        out1 = att_soft1 @ V1

        return out1

In [13]:
class ModelArchitecture(nn.Module):
    """Embedding -> attention -> two-layer MLP -> vocabulary logits.

    Args:
        n: hidden width of the MLP.
        d_k: model dimension (input/output width of the MLP and input
            width of the unembedding).
        attention: attention module applied to the embedded tokens.
        embedding: module mapping token ids to ``d_k``-dimensional vectors.
        vocab_size: number of output logits per position. Defaults to
            37000, the value previously hard-coded here.

    NOTE(review): there are no residual connections or normalisation
    layers; each stage simply feeds the next.
    """

    def __init__(self, n :int, d_k: int, attention: AttentionModule, embedding: embeddings,
                 vocab_size: int = 37000):
        super().__init__()
        self.attention = attention
        self.embedding = embedding
        self.mlp = nn.Sequential(
            nn.Linear(d_k, n),
            nn.ReLU(),
            nn.Linear(n, d_k)
        )
        # Bias-free projection from model space to one logit per vocabulary entry.
        self.unembed = nn.Linear(d_k, vocab_size, bias = False)

    def forward(self, x):
        """Return per-position vocabulary logits for token ids ``x``."""
        out = self.attention(self.embedding(x))
        output = self.mlp(out)
        logits = self.unembed(output)
        return logits

        

In [15]:
# Number of full passes over the training loader.
epochs = 20

# Build the model: 512-dim embeddings and attention, 1024-wide MLP.
attention = AttentionModule()
embedding = embeddings(512)
model = ModelArchitecture(n = 1024, d_k = 512, attention = attention, embedding = embedding)

optimizer = torch.optim.AdamW(model.parameters(), lr = 0.001)#0.3, 0.5, 1, 3, 5, 8
# Padded target positions must not contribute to the loss; this matches the
# criterion used by the training loop.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

In [17]:

# Train for `epochs` passes over train_loader; after each pass evaluate on
# test_loader, print one summary line and append it to a log file.

# Loss that skips target positions equal to PAD_ID.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)
model.to(DEVICE)

for i in range(epochs):
    model.train()
    # Running sums of per-batch loss / accuracy for this epoch.
    total_train_loss = 0
    total_train_acc = 0
    for src, tgt in train_loader:
        # NOTE(review): src is moved to the device but never passed to the
        # model; only the target-side sequence is used below.
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        
        optimizer.zero_grad()

        # Shift by one: the model sees tokens [0, T-1) and is scored against
        # tokens [1, T), i.e. it predicts the next token at every position.
        # NOTE(review): AttentionModule applies no causal mask, so position t
        # can attend to the token it is asked to predict — confirm intended.
        decoder_input  = tgt[:, :-1]  
        decoder_target = tgt[:, 1:] 
        logits = model(decoder_input) 

        # Flatten all leading dimensions so the loss sees (N, vocab) vs (N,).
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            decoder_target.reshape(-1)
        )

        # Token accuracy over non-padding target positions in this batch.
        pred_tokens = logits.argmax(-1)
        mask = decoder_target != PAD_ID
        train_acc = ((pred_tokens == decoder_target) & mask).float().sum() / mask.sum()

        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
        total_train_acc += train_acc.item()

    # NOTE(review): calculate_bleu, all_preds, all_targets and bleu_computed
    # are assigned here but never read in this cell; BLEU is not computed.
    calculate_bleu = (i % 5 == 0) or (i == epochs - 1)

    model.eval()
    total_test_loss = 0
    total_test_acc = 0

    all_preds = []
    all_targets = []
    bleu_computed = False


    # Evaluation pass: same shift / loss / accuracy as training, without
    # gradient tracking or parameter updates.
    with torch.no_grad():
        for batch_idx, (test_src, test_tgt) in enumerate(test_loader):
            test_src, test_tgt = test_src.to(DEVICE), test_tgt.to(DEVICE)

            dec_in_test  = test_tgt[:, :-1]
            dec_tgt_test = test_tgt[:, 1:]

            test_logits = model(dec_in_test)

            t_loss = criterion(
                test_logits.view(-1, test_logits.size(-1)),
                dec_tgt_test.reshape(-1)
            )
            total_test_loss += t_loss.item()

            test_pred = test_logits.argmax(-1)
            test_mask = dec_tgt_test != PAD_ID
            batch_acc = ((test_pred == dec_tgt_test) & test_mask).float().sum() / test_mask.sum()
            total_test_acc += batch_acc.item()


    # Epoch metrics are means of the per-batch values (each batch weighted
    # equally, regardless of how many non-padding tokens it holds).
    avg_train_loss = total_train_loss / len(train_loader)
    avg_train_acc  = total_train_acc / len(train_loader)

    avg_test_loss = total_test_loss / len(test_loader)
    avg_test_acc  = total_test_acc / len(test_loader)

    output_str = (
        f"| i: {i} "
        f"| Train Acc: {avg_train_acc:.4f} "
        f"| Train Loss: {avg_train_loss:.4f} "
        f"| Test Acc: {avg_test_acc:.4f} "
        f"| Test Loss: {avg_test_loss:.4f} "
    )

    print(output_str)

    # Append mode: results from earlier runs stay in the same log file.
    with open("epoch_results_logging.log", "a") as f:
        f.write(output_str + "\n")

| i: 0 | Train Acc: 0.1156 | Train Loss: 5.7436 | Test Acc: 0.0808 | Test Loss: 8.4606 


KeyboardInterrupt: 

In [None]:
# Persist only the model's state dict (parameters and buffers) to the working directory.
torch.save(model.state_dict(), 'model_weights.pth')