In [171]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
import time
import random
import math

PATH = "models/grokking_97.bin"
LOAD_MODEL = False

# Seed both RNGs: torch drives parameter init and batch sampling, while Python's
# `random` drives the train/val shuffle below. Without seeding `random`, every
# kernel restart produces a different split, so a model reloaded from PATH could
# be "validated" on pairs it was trained on.
torch.manual_seed(1)
random.seed(1)

inf = torch.inf

context_length = 4
P = 97  # modulus; token ids 0..P-1 are the residues
model_dim = 128 # dimension of the model -> residual stream
n_layers = 1 # no of layers
vocab_size = P + 2  # P residues + 2 separator tokens
n_heads = 4
learning_rate = 3e-4
max_iters = 30000
eval_iters = 1
batch_size = 64
weight_decay = 1e-3
training_data_percentage = 0.5

# Separator tokens. An example is the sequence `a <special_token_1> b <special_token_2>`
# and its target is (a + b) % P, so the two tokens play the roles of "+" and "=".
special_token_1 = P      # == P-1 + 1
special_token_2 = P + 1  # == P-1 + 2

# Every ordered pair (a, b) with 0 <= a, b < P, shuffled, then split into train/val.
total_data = [[i, j] for i in range(P) for j in range(P)]
random.shuffle(total_data)

total_data = torch.tensor(total_data)
total_datapoints = total_data.shape[0]
n_train = int(total_datapoints * training_data_percentage)
train_data = total_data[:n_train]
val_data = total_data[n_train:]


def sample_data(split: str = "train"):
    """Sample one batch of addition problems, with replacement, from a split.

    Each sampled pair (a, b) becomes the input tokens
    ``[a, special_token_1, b, special_token_2]`` and the next-token targets
    ``[special_token_1, b, special_token_2, (a + b) % P]``.

    Args:
        split: ``"train"`` samples from ``train_data``; any other value samples
            from ``val_data``.

    Returns:
        (X, Y): ``long`` tensors of shape
        ``(min(len(data), batch_size), context_length)``. Columns beyond the
        first four, if any, are left as 0.
    """
    data = train_data if split == "train" else val_data
    n_samples = min(data.shape[0], batch_size)

    # Each row of `data` is an independent (a, b) example, so every row index is
    # a valid draw. Subtracting `context_length` from the bound (as in LM-style
    # window sampling) would silently exclude the last rows of the split.
    ix = torch.randint(data.shape[0], (n_samples,))
    a = data[ix, 0].long()
    b = data[ix, 1].long()
    c = (a + b) % P
    sep1 = torch.full_like(a, special_token_1)
    sep2 = torch.full_like(a, special_token_2)

    X = torch.zeros(n_samples, context_length, dtype=torch.long)
    Y = torch.zeros(n_samples, context_length, dtype=torch.long)
    X[:, :4] = torch.stack([a, sep1, b, sep2], dim=1)
    Y[:, :4] = torch.stack([sep1, b, sep2, c], dim=1)
    return X, Y

@torch.no_grad()
def estimate_loss():
    """Estimate the global ``model``'s mean loss on the train and val splits.

    Draws ``eval_iters`` batches per split with ``sample_data`` and averages the
    loss returned by ``model(X, Y)``. Runs without autograd and in eval mode.

    Returns:
        (train_loss, val_loss): 0-dim float tensors.
    """
    was_training = model.training
    model.eval()
    try:
        out = {}
        for split in ("train", "val"):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = sample_data(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the caller's mode (e.g. eval after loading a checkpoint) rather
        # than unconditionally switching to train mode, even if evaluation fails.
        model.train(was_training)

    return out["train"], out["val"]


class Layer(nn.Module):
    """Empty placeholder module: registers no parameters and defines no forward.

    NOTE(review): not referenced by any visible cell — presumably a leftover stub.
    """

    def __init__(self) -> None:
        nn.Module.__init__(self)
        

class AttentionHead(nn.Module):
    """A single causal self-attention head.

    Projects the input to key/query/value vectors of size
    ``head_dim = model_dim // n_heads``, lets each position attend to itself and
    earlier positions only, and maps the result back to ``model_dim`` with this
    head's own ``proj`` layer.
    """

    def __init__(self):
        super().__init__()
        self.head_dim = model_dim//n_heads
        self.key = nn.Linear(model_dim, self.head_dim)
        self.query = nn.Linear(model_dim, self.head_dim)
        self.value = nn.Linear(model_dim, self.head_dim)
        self.proj = nn.Linear(self.head_dim, model_dim)

    def forward(self, idx):
        """Apply causal attention.

        Args:
            idx: residual stream, ``(batch, seq_len, model_dim)``.

        Returns:
            ``(batch, seq_len, model_dim)`` head output (not yet added to ``idx``).
        """
        key = self.key(idx)      # (batch, seq_len, head_dim)
        query = self.query(idx)  # (batch, seq_len, head_dim)
        value = self.value(idx)  # (batch, seq_len, head_dim)

        scores = (query @ key.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (batch, seq_len, seq_len)

        # Mask strictly-future positions by *index*. Masking by value
        # (`scores == 0` after `tril`) would also hide legitimate scores that are
        # exactly 0, and a row of all-zero scores would softmax to NaN.
        seq_len = scores.shape[-1]
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device).triu(diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))

        attention = F.softmax(scores, dim=-1)  # each row sums to 1 over allowed positions

        return self.proj(attention @ value)  # (batch, seq_len, model_dim)
    

class MultiHeadAttention(nn.Module):
    """``n_heads`` attention heads whose outputs are added onto the residual stream.

    Every head reads the same (unmodified) input; their outputs are accumulated
    onto that input one after another.
    """

    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList(AttentionHead() for _ in range(n_heads))

    def forward(self, idx):
        result = idx
        for attention_head in self.heads:
            result = result + attention_head(idx)
        return result
    

class MLP(nn.Module):
    """Residual feed-forward block: ``x + W2 · ReLU(W1 · x)`` with a 4x hidden width."""

    def __init__(self):
        super().__init__()
        hidden_dim = 4 * model_dim
        self.layers = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, model_dim),
        )

    def forward(self, idx):
        return idx + self.layers(idx)

class Transformer(nn.Module):
    """Decoder-only transformer: token + learned position embeddings, then
    ``n_layers`` (attention, MLP) blocks, then a linear unembedding.

    Prints the parameter count (in thousands) on construction.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        self.pos_embedding = nn.Embedding(context_length, model_dim)
        self.attention_layers = nn.ModuleList([MultiHeadAttention() for i in range(n_layers)])
        self.mlp_layers = nn.ModuleList([MLP() for i in range(n_layers)])
        self.unembed_layer = nn.Linear(model_dim, vocab_size)

        self.total_parameters = sum(p.numel() for p in self.parameters())
        print(f"Model has {self.total_parameters//1000}k params")

    def forward(self, idx, targets=None):
        """Run the model.

        Args:
            idx: token ids, ``(batch, seq_len)`` with ``seq_len <= context_length``.
            targets: optional next-token ids with the same shape as ``idx``; only
                the last position is used for the loss.

        Returns:
            If ``targets`` is None: logits ``(batch, seq_len, vocab_size)``.
            Otherwise: ``(last_logits, loss)`` where ``last_logits`` is
            ``(batch, vocab_size)`` for the final position and ``loss`` is the
            cross-entropy against ``targets[:, -1]``.
        """
        seq_len = idx.shape[-1]
        # Build position ids on the same device as the input tokens.
        positions = torch.arange(seq_len, device=idx.device)

        residual_stream = self.token_embedding(idx) + self.pos_embedding(positions)  # (batch, seq_len, model_dim)

        for attention, mlp in zip(self.attention_layers, self.mlp_layers):
            residual_stream = attention(residual_stream)
            residual_stream = mlp(residual_stream)

        logits = self.unembed_layer(residual_stream)  # (batch, seq_len, vocab_size)
        if targets is None:
            return logits

        # Only the answer position (the last one) is trained on.
        last_logits = logits[:, -1, :]
        loss = F.cross_entropy(last_logits, targets[:, -1])
        return last_logits, loss
    
# Build the model once and, if requested, restore trained weights from PATH.
model = Transformer()

# NOTE(review): this overrides the `LOAD_MODEL = False` set at the top of the
# notebook, so that config flag has no effect; delete this line to start from
# freshly initialised weights instead.
LOAD_MODEL = True
if LOAD_MODEL:
    # No `.to(device)` call appears anywhere, so the model lives on the CPU;
    # load the checkpoint tensors there too.
    model.load_state_dict(torch.load(PATH, map_location="cpu"))

print(model.eval())




Model has 224k params
Model has 224k params
Transformer(
  (token_embedding): Embedding(99, 128)
  (pos_embedding): Embedding(4, 128)
  (attention_layers): ModuleList(
    (0): MultiHeadAttention(
      (heads): ModuleList(
        (0): AttentionHead(
          (key): Linear(in_features=128, out_features=32, bias=True)
          (query): Linear(in_features=128, out_features=32, bias=True)
          (value): Linear(in_features=128, out_features=32, bias=True)
          (proj): Linear(in_features=32, out_features=128, bias=True)
        )
        (1): AttentionHead(
          (key): Linear(in_features=128, out_features=32, bias=True)
          (query): Linear(in_features=128, out_features=32, bias=True)
          (value): Linear(in_features=128, out_features=32, bias=True)
          (proj): Linear(in_features=32, out_features=128, bias=True)
        )
        (2): AttentionHead(
          (key): Linear(in_features=128, out_features=32, bias=True)
          (query): Linear(in_features=128

In [192]:
import torch.fft as fft

# Rows 0..P-1 of the embedding are the residue tokens; the last two rows are the
# separator tokens. Only the residues form the cyclic group Z_P, so the Fourier
# transform is taken over those rows only.
embed = model.token_embedding.weight.detach()[:P]  # (P, model_dim)
f = fft.fft(embed, dim=0, norm="ortho")

def plot_freqs(f, drop_dc: bool = True):
    """Bar-plot the L2 norm of each Fourier component of ``f`` with plotly.

    Args:
        f: FFT output whose dim 0 is the frequency axis (e.g. ``fft.fft(x, dim=0)``).
            Any trailing dimensions are flattened and reduced with an L2 norm.
            ``f`` is read only and is not modified.
        drop_dc: if True (default), the frequency-0 (DC) bar is drawn as 0 so the
            constant component does not dominate the scale.
    """
    import plotly.graph_objects as go

    # For a real-valued signal of length N, frequencies above N//2 mirror the lower
    # ones, so the unique frequencies are 0 .. N//2 (inclusive).
    n_freqs = f.shape[0] // 2 + 1
    # `.abs()` allocates a new tensor, so zeroing the DC entry never touches `f`.
    magnitudes = f[:n_freqs].reshape(n_freqs, -1).abs().norm(dim=1)
    if drop_dc:
        magnitudes[0] = 0.0

    fig = go.Figure()
    fig.add_trace(go.Bar(x=list(range(n_freqs)), y=magnitudes.tolist(), name='FFT Magnitudes'))
    fig.update_layout(
        title="Norm of each Fourier component",
        xaxis_title="frequency k",
        yaxis_title="L2 norm",
    )
    fig.show()

plot_freqs(f=f)



In [193]:
import torch.fft as fft  # imported here too so this cell does not depend on the previous one

# W_U @ W_out: column j is the logit-space direction written by MLP hidden unit j
# (output bias ignored).
mlp_out = model.mlp_layers[0].layers[-1].weight.detach()  # (model_dim, 4*model_dim)
unembed = model.unembed_layer.weight.detach()             # (vocab_size, model_dim)

combine = unembed @ mlp_out  # (vocab_size, 4*model_dim)
# Transform over the P residue output tokens only; the last two vocab rows are
# the separator tokens and are not part of Z_P.
f = fft.fft(combine[:P], dim=0, norm="ortho")

plot_freqs(f)


torch.Size([128, 512])

In [132]:
# NOTE(review): `a` is not defined in any visible cell, so this cell fails on a fresh kernel.
# Shift `a` in place so that a[0] becomes 0. `a[0]` is a view into `a`, so it is
# cloned before the in-place subtraction reads it.
a.sub_(a[0].clone());

In [133]:
# Display `a` after the in-place shift in the previous cell (its first entry should be 0).
# NOTE(review): relies on `a` from a cell that is not present in this notebook.
a

tensor([ 0.0000, -0.4998, -0.8212, -0.3718, -0.6604])