## Replication of Grokking experiments 
(checking the relation between attention logits)

In [1]:
pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [1]:
import torch
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
from torch import nn

In [None]:
#torch.tensor(math.sin(math.exp(2*i*math.log(10,000)/d_k)))

In [3]:
def generate_compact_dataset(num_samples=10000):
    data_tokens = torch.arange(1, 11) 
    
    all_inputs = []
    all_targets = []

    for _ in range(num_samples):
        perm = torch.randperm(10)
        sample_data = data_tokens[perm[:6]]
        
        is_relational = torch.rand(1) > 0.5
        
        if is_relational:
            cmd = torch.tensor([12])
            # Pick a key from the first 5 (so there is a neighbor at +1)
            key_idx = torch.randint(0, 5, (1,)).item()
            query = sample_data[key_idx].view(1)
            target = sample_data[key_idx + 1]
        else:
            # POSITIONAL: Input[7] is an Index (1-6); Target is data at that index
            cmd = torch.tensor([11])
            idx_to_pull = torch.randint(0, 6, (1,)).item()
            query = torch.tensor([idx_to_pull + 1])
            target = sample_data[idx_to_pull]

        full_input = torch.cat([sample_data, cmd, query])
        
        all_inputs.append(full_input)
        all_targets.append(target)

    return torch.stack(all_inputs), torch.stack(all_targets)

# Generate the 10,000 samples
inputs, targets = generate_compact_dataset(10000)

print(f"Dataset Shape: {inputs.shape}") # [10000, 8]
print(f"Sample 0 (Input): {inputs[0].tolist()} -> Target: {targets[0].item()}")

Dataset Shape: torch.Size([10000, 8])
Sample 0 (Input): [2, 6, 3, 10, 5, 8, 11, 3] -> Target: 3


In [4]:
pe = torch.zeros([8, 64])
for j in range(3):
    for i in range(0, 64, 2):
        div_term = math.exp(-(math.log(10000.0) * i) / 64)
        pe[j, i] = math.sin(j * div_term)
        pe[j, i + 1] = math.cos(j * div_term)

In [5]:
class embeddings(nn.Module):
    def __init__(self, vocab_size:int = 13, d :int = 64):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d)

    def forward(self, vocab: list, pe):
        te = self.token_emb(vocab)
        return te+pe.to(te.dtype)

In [6]:
indices = torch.randperm(len(inputs))

train_size = int(0.3*len(inputs))

train_idx = indices[:train_size]
test_idx = indices[train_size:]

train_inputs, train_targets = inputs[train_idx], targets[train_idx]
test_inputs,  test_targets  = inputs[test_idx],  targets[test_idx]

In [12]:
class AttentionModule(nn.Module):
    def __init__(self, d_k = 64):
        super().__init__()
        self.query_v1 = nn.Parameter(torch.randn(d_k,d_k))
        self.key_v1 =  nn.Parameter(torch.randn(d_k,d_k))
        self.value_v1 = nn.Parameter(torch.randn(d_k,d_k))

    def forward(self, x):
        Q1 = x @ self.query_v1
        K1 = x @ self.key_v1
        V1 = x @ self.value_v1
        att1 = Q1@K1.transpose(-2, -1)/ math.sqrt(32)


        att_soft1 = torch.softmax(att1, dim = -1)

        out1 = att_soft1 @ V1

        return att1, out1

In [13]:
class ModelArchitecture(nn.Module):
    def __init__(self, n :int, d_k: int, attention: AttentionModule, embedding: embeddings):
        super().__init__()
        self.attention = attention
        self.embedding = embedding
        self.mlp = nn.Sequential(
            nn.Linear(d_k, n),
            nn.ReLU(),
            nn.Linear(n, d_k)
        )
        self.unembed = nn.Linear(d_k, 13, bias = False)

    def forward(self, x):
        att, out = self.attention(self.embedding(x, pe))
        output = self.mlp(out)
        logits = self.unembed(output)
        return att, logits

        

In [14]:
epochs = 20000

attention = AttentionModule()
embedding = embeddings()
model = ModelArchitecture(n = 32, d_k = 64, attention = attention, embedding = embedding)

optimizer = torch.optim.AdamW(model.parameters(), lr = 0.001, weight_decay = 1)#0.3, 0.5, 1, 3, 5, 8
criterion = nn.CrossEntropyLoss()

In [15]:

values = []
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()

    att, logits = model(train_inputs)
    predicted = logits[:, -1, :]
    loss = criterion(predicted, train_targets)
    train_acc = (predicted.argmax(-1) == train_targets).float().mean()

    loss.backward()
    optimizer.step()

    if epoch%100==0:
        model.eval()
        with torch.no_grad():
            att_test, test_logits = model(test_inputs)
            test_predicted = test_logits[:, -1, :]

            test_loss = criterion(test_predicted, test_targets)
            test_acc = (test_predicted.argmax(-1) == test_targets).float().mean()

        temp = {
            "train_att": att,
            "test_att": att_test,
            "train_loss": loss,
            "train_acc": train_acc,
            "test_loss": test_loss,
            "test_acc": test_acc
        }
        o = f"| Epoch: {epoch} | Train Accuracy: {train_acc} | Train Loss: {loss} | Test Accuracy: {test_acc} | Test Loss: {test_loss} |"
        print(o)
        if epoch%500==0:
            values.append(temp)
        with open('epoch_results_logging.log', 'a') as f:
            f.write(f"{o}\n")
        

| Epoch: 0 | Train Accuracy: 0.10166666656732559 | Train Loss: 3.1971819400787354 | Test Accuracy: 0.10614285618066788 | Test Loss: 2.9676084518432617 |
| Epoch: 100 | Train Accuracy: 0.1913333386182785 | Train Loss: 2.2216250896453857 | Test Accuracy: 0.16785714030265808 | Test Loss: 2.2733142375946045 |
| Epoch: 200 | Train Accuracy: 0.19599999487400055 | Train Loss: 2.2014172077178955 | Test Accuracy: 0.17014285922050476 | Test Loss: 2.2738969326019287 |
| Epoch: 300 | Train Accuracy: 0.20200000703334808 | Train Loss: 2.1853559017181396 | Test Accuracy: 0.167142853140831 | Test Loss: 2.272737979888916 |
| Epoch: 400 | Train Accuracy: 0.20999999344348907 | Train Loss: 2.168225049972534 | Test Accuracy: 0.1735714226961136 | Test Loss: 2.2743842601776123 |
| Epoch: 500 | Train Accuracy: 0.21533332765102386 | Train Loss: 2.1487319469451904 | Test Accuracy: 0.1744285672903061 | Test Loss: 2.2710790634155273 |
| Epoch: 600 | Train Accuracy: 0.21899999678134918 | Train Loss: 2.137470483779

In [None]:
torch.save(model.state_dict(), 'model_weights.pth')

In [17]:
import json
def tensor_handler(obj):
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()

with open("results_1.json", 'w') as f:
    json.dump(values, f, indent=4, default=tensor_handler)

In [18]:
attention = AttentionModule()
embedding = embeddings()
model = ModelArchitecture(n=512, d_k=128, attention=attention, embedding=embedding)

model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

ModelArchitecture(
  (attention): AttentionModule()
  (embedding): embeddings(
    (token_emb): Embedding(114, 128)
  )
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=128, bias=True)
  )
  (unembed): Linear(in_features=128, out_features=114, bias=False)
)