<a href="https://colab.research.google.com/github/sadisticaudio/Modular_Attention/blob/main/Modular_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer Modular Addition Through A Signal Processing Lens

this analysis builds on previous work by Neel Nanda et al., "Progress Measures for Grokking via Mechanistic Interpretability".\
in that work, a model learns to implement an algorithm for modular addition that generalizes to unseen data and exhibits grokking behaviour.\
the hypothesis here is that the model does not solve the task by multiplying trig terms to perform a rotation,\
but that the attention layer forms the symmetries the algorithm requires, and that these symmetries are responsible for the model's ability to generalize.

In [None]:
import os
import sys

cwd = os.getcwd()

!git clone https://github.com/sadisticaudio/checkpaint.git checkpaint_clone
if cwd + '/checkpaint_clone/template/src' not in sys.path: sys.path.append(cwd + '/checkpaint_clone/template/src')

if 'google.colab' in sys.modules:
  !pip install "ipywidgets==7.8.5"
  from google.colab import output
  output.enable_custom_widget_manager()
  !pip install transformer_lens pylinalg pythreejs

import torch
from torch.nn import functional as F
import numpy as np
import einops
from transformer_lens import HookedTransformer, HookedTransformerConfig

import checkpaint
from checkpaint.utils import *
from checkpaint.c_hooks import *
from checkpaint import line_plot as lp

p, d_model, d_mlp, n_heads, d_head = 113, 128, 512, 4, 32
### CPU IS BEST. IN COLAB FOR SURE. WE ARE NOT DOING ANYTHING HEAVY HERE AND TRANSFERRING TO/FROM GPU MAKES PLOTTING A BIT SLOWER ###
### CHANGE IF YOU'D LIKE ###
device = "cpu"
### THIS IS FROM NEEL NANDA'S ORIGINAL RUN, ONLY THE FINAL CHECKPOINT SAVED IN THIS FILE TO SAVE COLAB DISK USAGE
### THE ORIGINAL FILE WAS NAMED "full_run_data.pth"
end_run_data = torch.load(cwd + '/checkpaint_clone/end_run_data.pth', map_location=torch.device(device))

prange, pprange = torch.arange(p).to(device), torch.arange(p*p).to(device)
is_train, is_test = get_old_indices(device=device)
train_indices, test_indices = pprange[is_train], pprange[is_test]
dataset, labels = get_data(device=device)

In [None]:
def showVector(x,**kwargs):
    x = [inputs_last(t) for t in x] if isinstance(x, list) else inputs_last(x)
    lp.draw_vector(x, **kwargs)

In [None]:
### CREATE A CACHE STARTING WITH THE WEIGHTS OF THE FINAL MODEL IN THERE
cache = squeeze_cache(CacheDict(end_run_data["state_dicts"][-1]))

cfg = HookedTransformerConfig(
    n_layers = 1, n_heads = n_heads, d_model = d_model, d_head = d_head, d_mlp = d_mlp, act_fn = "relu", normalization_type=None,
    d_vocab=p+1, d_vocab_out=p, n_ctx=3, init_weights=True, device=device, seed = 598,
)
model = HookedTransformer(cfg)
hooked_state_dict = model.state_dict()

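for name, x in hooked_state_dict.items():
    # biases that the original checkpoint lacks are set to zero
    if name not in cache and "b_" in name: cache[name] = torch.zeros_like(x)
    # drop the last unembed column so W_U has p outputs (d_vocab_out = p)
    if name == "unembed.W_U": cache[name] = cache[name][:,:-1]
    # attention masks come from the freshly built HookedTransformer
    if "mask" in name: cache[name] = x
    # some original matrices are stored transposed relative to HookedTransformer's layout
    if name in cache and cache[name].shape != x.shape: cache[name] = cache[name].transpose(-2,-1)
    # reshape W_O into HookedTransformer's [n_heads, d_head, d_model] layout
    if "W_O" in name: cache[name] = cache[name].reshape(x.shape)
    # the IGNORE buffer also comes from the fresh model if the checkpoint lacks it
    if name not in cache and "IGNORE" in name: cache[name] = x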

model.load_state_dict(cache)
### RUN THE MODEL, ADD THE LOGITS TO THE CACHE
cache["hook_logits"], hooked_cache = model.run_with_cache(dataset)
### ADD ALL OF THE ACTIVATIONS OF THE FINAL MODEL (FULL DATASET) TO THE CACHE
cache.update(hooked_cache)
key_freqs = get_top_k_freqs(cache["W_E"].transpose(-2,-1), 5, -1, sumlist=[0])[0].sort()[0]
key_harmonics, key_subharmonics = get_harmonics(key_freqs), get_subharmonics(key_freqs)
tprint("key_freqs", key_freqs, "key_harmonics", key_harmonics, "key_subharmonics", key_subharmonics)

In [None]:
print("train loss", cross_entropy_high_precision(cache["hook_logits"][train_indices], labels[train_indices]))
print("test loss", cross_entropy_high_precision(cache["hook_logits"][test_indices], labels[test_indices]))

In [None]:
cache = squeeze_cache(cache)
ex_list = dataset.clone()
ex_list[:,2] = labels
ex_list = torch.cat((ex_list, pprange.unsqueeze(1).to(ex_list.device)), 1)

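def get_idx_by_pos(pos):
    """return dataset indices sorted by the given key(s) among 'a', 'b', 'c' (any other key sorts by dataset index)"""
    exs = ex_list.tolist()
    pos = pos if isinstance(pos, list) else [pos]
    column = {'a': 0, 'b': 1, 'c': 2}
    # python's sort is stable, so sorting by the last key first gives a lexicographic sort, e.g. ['c','a'] = by c, then a
    for ps in reversed(pos):
        exs = sorted(exs, key=lambda x: x[column.get(ps, 3)])
    return torch.tensor(exs, dtype=torch.long, device=device)[:,3].squeeze()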

a_idx, b_idx, c_idx = get_idx_by_pos('a'), get_idx_by_pos('b'), get_idx_by_pos('c')
ac_idx, bc_idx, cb_idx = get_idx_by_pos(['a','c']), get_idx_by_pos(['b','c']), get_idx_by_pos(['c','b'])

def a_hook(x):
    hook_x = cache[x] if isinstance(x, str) else x
    if hook_x.size(0) == p*p: return inputs_last(hook_x)
    else:
        print("doing a_hook on an inputs_last activation, take a look")
        return torch.empty(0)

def inv_idx(idx):
    s = sorted(range(len(idx)), key=idx.__getitem__)
    return torch.tensor(s, dtype=torch.long, device=device) if isinstance(idx, torch.Tensor) else s
c_idx_inv = inv_idx(c_idx)

def c_hook(x):
    hook_x = cache[x] if isinstance(x, str) else x
    if hook_x.size(0) == p*p: return inputs_last(hook_x[c_idx])
    else:
        print("doing c_hook on an inputs_last activation, take a look")
        return torch.empty(0)

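def c_hook_inv(hook_x):
    # undo c_hook: flatten the [c, a] axes, move them back to the front and restore the original dataset order
    hook_x = hook_x.flatten(-2,-1)
    hook_x = torch.moveaxis(hook_x, -1, 0)
    return hook_x[c_idx_inv]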

def make_wave(freq, phase=0): return torch.cos(prange*freq*2*torch.pi/p + phase)

def print_loss_splits(logits, name=""):
    print(name + " loss", cross_entropy_high_precision(logits, labels).item())
    print(name + " train loss", cross_entropy_high_precision(logits[is_train], labels[is_train]).item())
    print(name + " test loss", cross_entropy_high_precision(logits[is_test], labels[is_test]).item())

## Basics

the problem is modular addition of the two input tokens: (a + b) % p = c, with p = 113.\
the tokens are integer indices (0-112); when they are embedded, each one indexes a point on a length-p sinusoid, with a separate sinusoid for every residual stream dimension.\
through training, the spectra of these sinusoids become sparse, with most of the magnitude concentrated in a small set of key frequencies.\
below are the embeddings, one residual stream dimension at a time, with token value (pos) on the x-axis.  shape: [d_model, pos]\
use the slider to scroll through the residual stream dimensions, or click the "spacetime" button to toggle between spatial and fourier modes.

In [None]:
showVector(torch.stack([sd["embed.W_E"] for sd in end_run_data["state_dicts"]])[...,:-1].squeeze())

the key point here is that, for each dimension of the model, the pos axis indexes a point on a periodic, sinusoidal function.
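a minimal sketch of what a sparse spectrum looks like, assuming a synthetic stand-in for one embedding dimension: the frequencies, amplitudes and phases below are made up for illustration and are not read from the checkpoint. the FFT over the token axis is non-zero only at each chosen frequency k and at its mirror image p - k.

```python
import numpy as np

p = 113
tokens = np.arange(p)
# hypothetical key frequencies, amplitudes and phases (illustrative, not taken from the trained model)
freqs, amps, phases = [14, 35, 41], [1.0, 0.6, 0.3], [0.2, 1.1, -0.7]

# one synthetic "residual stream dimension": a sum of sinusoids indexed by token value
embed_dim = sum(amp * np.cos(2 * np.pi * k * tokens / p + ph) for k, amp, ph in zip(freqs, amps, phases))

spectrum = np.abs(np.fft.fft(embed_dim))
# bins with non-negligible magnitude: each key frequency k and its mirror p - k
active_bins = np.flatnonzero(spectrum > 1e-6 * spectrum.max())
print(active_bins)  # [14 35 41 72 78 99]
```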

## Attention

in the attention layer, the model combines the sines and cosines for a and b into a single cosine wave,\
centered midway between a and b.
\
apart from the softmax that sets the attention pattern, every transformation up to the attention output (the embedding and the value and output maps) is linear.\
before this point, positions a & b access the same waves, but index them at phase locations a & b.\
there are more details on linear maps below, but for now, just know that any linear transformation preserves this structure.\
the structure of the waves changes when the a & b value vectors are weighted and summed into the final "=" token position.\
at this point, the waves [cos a, sin a, cos b, sin b] are combined into a single cosine wave for each answer "c".
\
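a quick numerical check of this, assuming (as the notebook states later) that a and b get equal attention weight and pass through the same linear value map, reduced here to a single frequency. the frequency `k`, the answer `c` and the cos/sin weights `alpha` and `beta` are hypothetical. for a fixed c, the summed wave is exactly a cosine in (a - c/2), scaled by an amplitude that depends only on c.

```python
import numpy as np

p, k, c = 113, 14, 6            # hypothetical frequency and answer
alpha, beta = 0.8, -0.5         # hypothetical value-map weights on cos and sin
omega = 2 * np.pi * k / p
a = np.arange(p)
b = (c - a) % p                 # for each a, the unique b with (a + b) % p == c

def value(x):
    # a linear mix of cos and sin at one token: what a value vector carries for one frequency
    return alpha * np.cos(omega * x) + beta * np.sin(omega * x)

summed = value(a) + value(b)    # equal attention weights on a and b

# alpha*cos(x) + beta*sin(x) == r*cos(x + phi), then sum-to-product gives a cosine centered at c/2
r, phi = np.hypot(alpha, beta), np.arctan2(-beta, alpha)
predicted = 2 * r * np.cos(omega * c / 2 + phi) * np.cos(omega * (a - c / 2))
print(np.allclose(summed, predicted))  # True
```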
in the previous work, the entire input dataset [p * p] was typically reshaped to [p, p] with axes ordered as [a, b] for analysis.\
here the dataset is also reshaped to [p, p], but with the inputs ordered as [c, a], which reveals symmetries in the activations.\
\
here are dataset examples in the form [a, b, c] for c = 0:

In [None]:
for i in range(5): tprint("c = 0: example", i, ex_list[c_idx][i][:-1])

notice that for modular addition with c fixed, as a increases, b decreases.\
this means that the sinusoidal waves of each example are indexed in opposite directions for a and b.\
the same applies to other values of c, but with an offset applied to the "b" token.
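the ordering can be reproduced without the checkpoint, assuming examples with the same c are ordered by increasing a. a minimal plain-python sketch: b = (c - a) % p runs backwards and wraps from 0 around to p - 1, and the single point where a == b is c/2 for even c and (c + p)/2 for odd c.

```python
p, c = 113, 6
# every (a, b) pair with (a + b) % p == c, ordered by a
pairs = [(a, (c - a) % p) for a in range(p)]
print(pairs[:8])  # [(0, 6), (1, 5), (2, 4), (3, 3), (4, 2), (5, 1), (6, 0), (7, 112)]

# the integer a where a == b: c/2 for even c, (c + p)/2 for odd c
for c_test in (6, 7):
    print(c_test, [a for a in range(p) if (2 * a) % p == c_test])  # prints "6 [3]", then "7 [60]"
```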

In [None]:
for i in range(10): tprint("c = 6: example", i, ex_list[c_idx][p * 6 + i][:-1])

notice that in example 3 above, a and b are equal, because 3 + 3 = 6.\
also notice that in examples 2 & 4, a & b are swapped, as are examples 1 & 5, 0 & 6, etc.\
this point of symmetry sits at a = c/2 for every value of "c" (for odd c it falls between two integers, and the integer point where a = b is (c + p)/2).\
since the value vectors are just a linear combination of the embeddings, and a & b are different spots on the same wave,\
summing them forms a new wave for each "c": a sinusoid made entirely of cosines centered at c/2.\
these cosines can have a negative polarity, but they are always an even function, symmetric about c/2.\
\
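the shape of this summed wave follows from the sum-to-product identity. a short derivation for one key frequency $\omega = 2\pi k/p$, assuming the a and b value vectors share the same phase $\phi$ (equal attention weights, same linear map) and substituting $b \equiv c - a \pmod p$, which is allowed because $k$ is an integer:

$$\cos(\omega a + \phi) + \cos(\omega b + \phi) = 2\cos\left(\frac{\omega c}{2} + \phi\right)\cos\left(\omega\left(a - \frac{c}{2}\right)\right)$$

for fixed c, the first factor is a constant that can be negative (the negative polarity), while the second factor is even about a = c/2. since $\omega p/2 = \pi k$, we also have $\cos\left(\omega\left(a - \frac{c}{2} - \frac{p}{2}\right)\right) = (-1)^k \cos\left(\omega\left(a - \frac{c}{2}\right)\right)$, so the same wave is, up to sign, a cosine centered at c/2 + p/2 and is symmetric about that point too.\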
below is the "=" token position of the residual stream right after attention (resid_mid), before the MLP.\
if you scroll through the c axis, note the symmetry present at c/2. axes are [d_model, c, a]

In [None]:
showVector(inputs_last(cache["resid_mid"][c_idx])[-1], start_play_axis=1, full_mode=True)

In [None]:
def construct_resid_post_model():
    """CONSTRUCT A MODEL OF THE FINAL RESIDUAL STREAM"""
    names = ["cos_wave", "pos a", "pos b", "pos c"]
    r_model = { n: torch.zeros_like(inputs_last(cache["resid_post"])[-1,:len(key_freqs)+1])[...,None].repeat(1,1,1,p) for n in names }
    p_diag = torch.eye(p).bool().to(device)
    for c in range(p):
        symmetry_points = [c//2, (c//2 + p//2 + 1) % p]
        symmetry_points.append(c//2 + 1 if c % 2 == 1 else (c//2 + p//2) % p)
        index = (p + c - prange[...,None]) % p

        for f, freq in enumerate(list(key_freqs)):
            r_model["cos_wave"][f][c] = torch.cos((prange - c/2) * freq * 2 * torch.pi / p)
            r_model["pos c"][f][c,...,symmetry_points] = r_model["cos_wave"][f][c,...,symmetry_points]
            r_model["pos a"][f][c][p_diag] = r_model["cos_wave"][f][c][p_diag]
            b_wave = torch.cos(((p + c - prange) % p - c/2) * freq * 2 * torch.pi / p)[...,None]
            r_model["pos b"][f][c].scatter_(1, index, b_wave)
            for name in r_model: r_model[name][-1][c] += r_model[name][f][c]

    sum_max = r_model["cos_wave"][-1].abs().max()
    for name in r_model: r_model[name][-1] /= sum_max

    return list(r_model.values())

r_model = construct_resid_post_model()

def get_cos_points(indices):
    c, a, b = indices[-3], indices[-2], (indices[-3] - indices[-2] + p) % p
    return [a, b, c/2, c/2 + p/2], ["a", "b", "c/2", "c/2 + p/2"]
def get_cos_message(indices):
    c, a, b = indices[-3], indices[-2], (indices[-3] - indices[-2] + p) % p
    return str(a) + " + " + str(b) + " % p = " + str(c)

def get_custom_points(indices): return [indices[-2]/2, indices[-2]/2 + p/2], ["c/2", "c/2 + p/2"]
def get_custom_message(indices): return "c = " + str(indices[-2]) + ", c/2 = " + str(indices[-2]/2) + ", c/2 + p/2 = " + str(indices[-2]/2 + p/2)

below is a handcrafted function of length p with a strange shape and little sinusoidal content (top).\
the middle is a reversed copy of the same function and the bottom is their sum.\
to separate them on the graph, a bias has been added to the top and bottom functions after the fact.\
if you scroll through axis 0, the reversed copy is shifted before the summation.\
notice that there are actually two symmetry points: one at shift/2 and the other at shift/2 + p/2.\
this is the case for all activations in the model post-attention.

In [None]:
showVector(get_reversed_shifted_waves(make_custom_wave(device)), get_points=get_custom_points, get_message=get_custom_message)

below is an illustration of the structure of the activations, simplified to unit-magnitude cosines of the key frequencies.\
the shape is [frequency, c, a, b]; the frequency axis includes a normalized sum in its last row.\
notice that as you scroll the c axis, a and b remain symmetrically placed around c/2.\
scrolling fully through the a axis, a and b wrap around, maintaining symmetry with respect to both c/2 and c/2 + p/2 in a circular fashion.\
note: this graph is in "full_mode", so the slider/play axis can be changed. it might be slow in Colab; it runs much more smoothly locally.

In [None]:
showVector(r_model, start_play_axis=1, full_mode=True)



to get this done, the model forms cosine waves over the inputs, centered at c/2.\
there are a number (~5) of key frequencies (key_freqs) at which these waves form.\
each residual stream dimension forms its own composite wave, made up of len(key_freqs) cosines with different amplitudes,\
but they are all symmetric about c/2.\
given that cosine is an even function and a and b are both equidistant from c/2, the model forms two symmetry points:\
one at c/2 and the other at c/2 + p/2, allowing the modular addition to wrap around the inputs in a periodic fashion.
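a minimal sketch of why any function picks up both symmetry points, assuming an arbitrary random function `f` as a stand-in for the handcrafted wave above (nothing is loaded from the model). adding f to a reversed copy shifted by s gives g(x) = f(x) + f((s - x) % p). the map x -> (s - x) % p is a reflection of the circle of p positions, and the two points it reflects about are s/2 and s/2 + p/2, so checking that g is unchanged by it checks both symmetry points at once.

```python
import numpy as np

p = 113
rng = np.random.default_rng(0)
f = rng.normal(size=p)        # arbitrary length-p function with no sinusoidal structure
x = np.arange(p)

def reflect_sum(f, s):
    # add f to a reversed copy of itself shifted by s: g(x) = f(x) + f((s - x) % p)
    return f + f[(s - x) % p]

for s in (6, 7):              # one even and one odd shift
    g = reflect_sum(f, s)
    # g is unchanged by the reflection x -> (s - x) % p, i.e. symmetric about s/2 and s/2 + p/2
    print(s, np.allclose(g, g[(s - x) % p]))  # True for both shifts
```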

there is more later showing that the attention pattern weights for a & b are always equal, with (1 - their sum) going to the "=" token.


## Transformations of Waves

most of the computations done by the model are linear transformations.\
in this work and the previous work, we analyze activations over the entire dataset,\
so it is important to understand how sinusoids indexed by token value are affected by linear transformations.\
\
a linear map, in this case, takes a number of different (but typically similar) waves and weights them.\
each individual weight can only do two things: scale the input wave and/or negate it.\
after the weights are applied, all of these scaled/negated waves are summed into a single output dimension.\
this process is repeated for every dimension of the output.\
\
this gives the model the ability to manipulate the shape of these waves, adjusting the magnitude and phase of their spectra,\
but a linear map cannot create frequencies that are not already present in its inputs.\
for instance, the model can boost a certain frequency by applying large, positive weights to the waves that contain that frequency\
and are in phase with each other, while negating the weights for waves that are out of phase.\
frequencies can be attenuated with the same mechanism, by negating some of the waves that are in phase with each other so that they sum to a low absolute value.\
in this sense, a linear map can be seen as a set of multi-channel linear-phase filters, each with a single coefficient, whose outputs are summed.\
\
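a minimal sketch of the point that a linear map can reweight and re-phase frequencies but never create new ones. the input waves, the key frequencies (14 and 35) and the random weight matrix `W` are hypothetical stand-ins, not model weights; whatever the weights, every output's spectrum stays inside the input frequencies and their mirrors p - k.

```python
import numpy as np

p = 113
rng = np.random.default_rng(0)
a = np.arange(p)
key = [14, 35]                    # hypothetical key frequencies

# 16 input dimensions, each a wave at one key frequency with a random amplitude and phase
inputs = np.stack([rng.normal() * np.cos(2 * np.pi * key[i % 2] * a / p + rng.uniform(0, 2 * np.pi))
                   for i in range(16)])                      # [16, p]

W = rng.normal(size=(8, 16))      # hypothetical weights: each one scales (and possibly negates) one input wave
outputs = W @ inputs              # [8, p]: each output dimension is a weighted sum of the input waves

spec = np.abs(np.fft.fft(outputs, axis=-1))
allowed = np.zeros(p, dtype=bool)
allowed[[14, 35, p - 14, p - 35]] = True
# energy outside the input frequencies (and their mirrors p - k) is only floating point noise
print(spec[:, ~allowed].max() < 1e-9 * spec.max())  # True
```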
this is on full display in the MLP. the W_in weights of each neuron effectively act as a narrowband filter.\
since each weight can only scale and/or negate, there is clear learned coordination to filter these neurons by frequency.\
for each neuron, there is one input weight per residual stream dimension. to filter selectively for a desired frequency and phase,\
the neuron assigns each dimension a weight that measures the cross-correlation between the desired wave and the content of that residual stream dimension.\
\
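a minimal sketch of a neuron's input weights as cross-correlations, assuming a toy residual stream in which every dimension is a random mix of cos/sin at two hypothetical key frequencies (14 and 35). each weight is the zero-lag cross-correlation between the desired wave `target` and one residual dimension; summing the dimensions with those weights recovers a wave close to `target`, while the other frequency is strongly suppressed.

```python
import numpy as np

p, d_model = 113, 128
rng = np.random.default_rng(0)
a = np.arange(p)

# toy residual stream: each dimension is a random mix of cos/sin at two hypothetical key frequencies
basis = np.stack([f(2 * np.pi * k * a / p) for k in (14, 35) for f in (np.cos, np.sin)])  # [4, p]
resid = rng.normal(size=(d_model, 4)) @ basis                                               # [d_model, p]

# the wave this hypothetical neuron should respond to: frequency 14 at an arbitrary phase
target = np.cos(2 * np.pi * 14 * a / p + 0.9)

# each input weight is the zero-lag cross-correlation between the target and one residual dimension
w_in = resid @ target / p                   # [d_model]
pre = w_in @ resid                          # the neuron's pre-activation over token values, [p]

cos_sim = pre @ target / (np.linalg.norm(pre) * np.linalg.norm(target))
print(f"cosine similarity with target: {cos_sim:.3f}")
```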
a claim in the previous work was that the model was multiplying waves in the MLP.\
reality is a bit weirder... there is, in fact, a bump in harmonics of the dominant frequency of each neuron post-ReLU,\
but this is a consequence of how discontinuities in the time/spatial domain affect the frequency domain.\
if you multiply 2 waves together (frequencies here are cycles per p = 113 tokens, written as "Hz"), the result is a composite wave made up of their sum and difference frequencies.\
so for 2 waves (14Hz & 35Hz), their product is the sum of 2 waves (49Hz & 21Hz).\
if the 2 waves have the same frequency (14Hz), the result is the sum of 2 waves (28Hz from the sum, 0Hz from the difference). the 0Hz wave is just a bias.\
\
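a minimal numpy sketch of the sum/difference rule, using unit cosines over p = 113 samples (the `wave` and `active_freqs` helpers are illustrative, not part of checkpaint):

```python
import numpy as np

p = 113
a = np.arange(p)

def wave(k):
    # unit cosine completing k cycles over the p token values
    return np.cos(2 * np.pi * k * a / p)

def active_freqs(x, tol=1e-6):
    # frequency bins 0..p//2 holding non-negligible energy (illustrative helper)
    mag = np.abs(np.fft.fft(x))[: p // 2 + 1]
    return np.flatnonzero(mag > tol * mag.max()).tolist()

print(active_freqs(wave(14) * wave(35)))  # [21, 49]: difference and sum
print(active_freqs(wave(14) * wave(14)))  # [0, 28]: difference (a bias) and sum
```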
if you have a composite wave made up of 14Hz & 35Hz, multiplying it by itself yields these frequencies: (0, 21, 28, 43, 49).\
the 43Hz component is a consequence of the Nyquist-Shannon sampling theorem: frequencies above the Nyquist limit,\
which is half the sample rate (113/2 = 56.5), wrap back below the limit as a mirror image. since 70 (35+35) is 13.5 above Nyquist,\
the sampling process creates an alias of the 70Hz wave at 43Hz, which is 56.5 - 13.5 (equivalently, 113 - 70).\
the ReLU acquires sum and difference frequencies through a different mechanism.\
if you start with sinusoidal content and introduce a discontinuity, e.g. by zeroing out part of it or by clipping it with a ReLU (which creates a kink, a discontinuity in the slope),\
an infinite series of harmonics, plus sums and differences between all frequency components, creeps into the spectrum, with everything above Nyquist aliasing back below it.
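the composite case and the aliasing can be checked numerically. a minimal numpy sketch (the `active_freqs` helper is illustrative): squaring 14Hz + 35Hz lands on exactly the five bins listed above, with 70 folded to 43, and clipping a single 14Hz wave with a ReLU spreads energy over many bins even though no second wave is involved.

```python
import numpy as np

p = 113
a = np.arange(p)
wave14, wave35 = (np.cos(2 * np.pi * k * a / p) for k in (14, 35))

def active_freqs(x, tol=1e-6):
    # frequency bins 0..p//2 holding non-negligible energy (illustrative helper)
    mag = np.abs(np.fft.fft(x))[: p // 2 + 1]
    return np.flatnonzero(mag > tol * mag.max()).tolist()

composite = wave14 + wave35
# the square contains 0, 21, 28, 49 and 70; with p = 113 samples, 70 shows up at 113 - 70 = 43
print(active_freqs(composite * composite))  # [0, 21, 28, 43, 49]

# clipping a single wave with a relu (no second wave involved) spreads energy over many bins
relu14 = np.maximum(wave14, 0)
print(active_freqs(wave14), len(active_freqs(relu14)))
```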

below is a ReLU applied to a single 14Hz wave that starts fully above zero and is shifted progressively into negative territory, where the ReLU clips more and more of it.\
notice that while scrolling, a series of harmonics (integer multiples of 14, mirrored around the spectrum) grows steadily.\
the 14Hz + 35Hz and 14Hz * 35Hz waves are shown below as well.  - shape: [threshold step, a]

In [None]:
def relu_demo(x, **kwargs): showVector(F.relu(x - (x.max() - x.min()) * prange[...,None]/p - x.min()), **kwargs)

wave14, wave35 = make_wave(14), make_wave(35)
relu_demo(wave14, start_gui_type="fourier")
relu_demo(wave14 + wave35, start_gui_type="fourier")
relu_demo(wave14 * wave35, start_gui_type="fourier")

below the loss is computed from the neurons, starting with the unaltered neurons ("hook_post").\
next, the sum and difference frequencies are eliminated from the neurons post-ReLU and the model generalizes better.\
last, all frequencies aside from the key_freqs are eliminated and the loss is even better...
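the frequency removal in `remove_freqs_from_hook` zeroes each unwanted bin k together with its mirror p - k, which keeps the inverse FFT real. a standalone numpy sketch of the same idea on a toy signal (`remove_freqs` is a hypothetical helper, not the notebook's function); the `% p` keeps the mirror index in range should frequency 0 ever be passed.

```python
import numpy as np

p = 113
a = np.arange(p)
x = np.cos(2 * np.pi * 14 * a / p) + 0.5 * np.cos(2 * np.pi * 21 * a / p + 0.3)

def remove_freqs(x, freqs):
    # zero each frequency bin k together with its mirror p - k, so the signal stays real after the inverse fft
    X = np.fft.fft(x)
    freqs = np.asarray(freqs)
    X[freqs] = 0
    X[(p - freqs) % p] = 0
    return np.fft.ifft(X)

y = remove_freqs(x, [21])
print(np.abs(y.imag).max() < 1e-12, np.allclose(y.real, np.cos(2 * np.pi * 14 * a / p)))  # True True
```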

In [None]:
sums_and_diffs = get_sum_and_difference_frequencies(key_freqs)
tprint("sums_and_diffs", sums_and_diffs, "length =", len(sums_and_diffs))

def remove_freqs_from_hook(hook_x, freqs):
    HOOK_X = torch.fft.fft(inputs_last(hook_x))
    HOOK_X[...,freqs] = HOOK_X[..., p - freqs] = 0
    return inputs_first(torch.fft.ifft(HOOK_X).real)

def test_neurons(hook_post, name):
    mlp_out = einops.einsum(hook_post, cache["W_out"], "batch pos d_mlp, d_mlp d_model -> batch pos d_model") + cache["b_out"]
    resid_post = cache["resid_mid"] + mlp_out
    logits = einops.einsum(resid_post, cache["W_U"], "batch pos d_model, d_model d_vocab -> batch pos d_vocab")
    print_loss_splits(logits, name)

test_neurons(cache["post"], "regular")
test_neurons(remove_freqs_from_hook(cache["post"], sums_and_diffs), "without trig product terms")
test_neurons(remove_freqs_from_hook(cache["post"], torch.tensor([i for i in range(1, p//2 + 1) if i not in key_freqs], device=device)), "only key_freqs")

In [None]:
def get_top_freqs_and_indices(hook_x, dim=-1, *, sumlist=[-2], freqs_allowed=key_freqs):
    freqs, mags = get_top_k_freqs(hook_x, 1, dim, sumlist=sumlist, freqs_allowed=freqs_allowed, squeeze=True)
    idx_dict = {}
    dim_range = torch.arange(freqs.numel(), device=freqs.device)
    for freq in freqs.unique().tolist(): idx_dict[freq] = dim_range[freqs == freq]
    freq_idx_sorted = freqs.sort()[1]
    return freqs, mags, idx_dict, freq_idx_sorted

neuron_freqs, neuron_mags, neuron_freq_idx, neuron_freq_idx_sorted = get_top_freqs_and_indices(inputs_last(cache["pre"])[-1])

test_freq = key_freqs[0].item()
num_neur = len(neuron_freq_idx[test_freq])
freq_pre = inputs_last(cache["pre"])[-1,neuron_freq_idx[test_freq]]

freq_fft = torch.fft.fft(freq_pre[...,66,:])[...,test_freq]  # fourier coefficient at test_freq along b, with a fixed at 66
neuron_phases = torch.angle(freq_fft)
neuron_amps = freq_fft.abs()
neuron_phases_sorted, ordered_phase_idx = torch.sort(neuron_phases)
tprint("neuron_phases_sorted", neuron_phases_sorted.shape, "ordered_phase_idx", ordered_phase_idx.shape, "mags", neuron_mags.shape, "amps", neuron_amps.shape)
tprint("sorted phases, amps", torch.cat((neuron_phases_sorted[...,None], neuron_amps[ordered_phase_idx][...,None]), -1))

mock_freq_pre, mock_phases = torch.zeros([num_neur,p,p], device=device), torch.arange(num_neur, device=device) * 2 * torch.pi/num_neur - torch.pi
wk = test_freq * 2 * torch.pi / p
for d in range(num_neur):
    mock_freq_pre[d] = torch.cos(mock_phases[d] - prange[...,None] * wk) + torch.cos(mock_phases[d] + prange[None] * wk)

# mock_freq_pre = mock_freq_pre[inv_idx(ordered_phase_idx)]

def running_sum_with_and_without_relu(pre, weights=cache["W_out"][neuron_freq_idx[test_freq],0]):
    non_relu_list = running_sum(pre,0, weights=weights, create_normalized_list=True)
    post = F.relu(pre)
    relu_list = running_sum(post,0, weights=weights, create_normalized_list=True)
    max_mag = relu_list[0].abs().max()
    return [relu_list[0] + max_mag, relu_list[1] + max_mag, non_relu_list[0] - max_mag, non_relu_list[1] - max_mag]


here we have the pre-activation for every neuron whose highest magnitude frequency is the first of the key_freqs.\
notice the movement of the waves as you scroll through the a axis.\
each neuron is approximately $\cos(\omega_{neur} a + \phi_{neur}) + \cos(\omega_{neur} b + \phi_{neur})$  - shape: [neuron, a, b]

In [None]:
showVector(freq_pre[ordered_phase_idx], start_play_axis=0, full_mode=True)

here we have a synthetic tensor made to reflect the neurons above. it is an idealized set of unit 2D waves,\
with phases evenly spaced from -$\pi$ to $\pi$.  - shape: [neuron, a, b]

In [None]:
showVector(mock_freq_pre)

below we use the handcrafted model above to demonstrate the effect of the ReLU on the neurons for this frequency.\
at the top of the graph are the ReLU-activated mock neurons; at the bottom are the mock neurons without the ReLU.\
the first axis scrolls through each mock neuron (red = pre-ReLU, magenta = post-ReLU).\
cyan and orange show a running sum over neurons (cyan = post-ReLU, orange = pre-ReLU).\
since the plot starts at the last neuron index (43), cyan and orange are the sum over all neurons.\
notice that the ReLU cuts off the bottom of the pre-activated mock neuron (magenta),\
allowing the final a+b wave (cyan) to travel all the way from left to right as you index from 0 to 112.\
this is what gives the model the ability to dot with the same unembed weights from any a or b position.  - shape: [neuron, a, b]
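a minimal numpy sketch of why the ReLU matters here. it rebuilds mock neurons of the same form as `mock_freq_pre` (cos(φ - ωa) + cos(φ + ωb) with evenly spaced phases), but the frequency 14 and the count of 44 neurons are assumptions for illustration. without the ReLU, the evenly spaced phases cancel exactly when summed; after the ReLU they no longer cancel, and what survives is (almost) a function of c = (a + b) % p alone, which fits the point above about dotting with the same unembed weights.

```python
import numpy as np

p, k, n_neurons = 113, 14, 44            # hypothetical frequency and neuron count
omega = 2 * np.pi * k / p
a = np.arange(p)[:, None]                # rows index a
b = np.arange(p)[None, :]                # columns index b
phases = np.arange(n_neurons) * 2 * np.pi / n_neurons - np.pi   # evenly spaced, starting at -pi

# mock pre-activations with shape [neuron, a, b]
pre = np.stack([np.cos(ph - omega * a) + np.cos(ph + omega * b) for ph in phases])

linear_sum = pre.sum(axis=0)               # evenly spaced phases cancel
relu_sum = np.maximum(pre, 0).sum(axis=0)  # after the relu they no longer cancel

def value_range(v):
    return v.max() - v.min()

c = (a + b) % p                          # the answer for every (a, b) cell
# largest variation of relu_sum within any single answer c, relative to its overall range
within_c = max(value_range(relu_sum[c == ci]) for ci in range(p))
print(np.abs(linear_sum).max())            # ~0, floating point error only
print(within_c / value_range(relu_sum))    # small: relu_sum is (almost) constant for each c
```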

In [None]:
showVector(running_sum_with_and_without_relu(mock_freq_pre, weights=None), start_indices=[43], start_play_axis=1, full_mode=True)

below are the actual neuron activations for this frequency.\
the structure is similar, but there is noise that doesn't perfectly cancel the orange wave (w/o ReLU).  - shape: [neuron, a, b]

In [None]:
showVector(running_sum_with_and_without_relu(freq_pre), start_indices=[43], start_play_axis=1, full_mode=True)

here we have the sum of all neurons associated with the test_freq after being mapped back to the residual stream.  - shape: [d_model, a, b]

In [None]:
showVector(einops.einsum(F.relu(freq_pre),cache["W_out"][neuron_freq_idx[test_freq]], "neur posa posb, neur d_model -> d_model posa posb") + cache["b_out"][...,None,None], start_indices=[43], start_play_axis=1)

next is the last position of the final residual stream with all frequencies zeroed except this one.\
notice the movement from right to left. this movement allows the examples to align with the answer "c" in W_U.  - shape: [d_model, a, b]

In [None]:
showVector(pull_out_freqs(inputs_last(cache["resid_post"])[-1], test_freq, dims=[-2,-1])[1], start_play_axis=-2, full_mode=True)

for reference, here is that frequency pulled out before the MLP (resid_mid).  - shape: [d_model, a, b]

In [None]:
showVector(pull_out_freqs(inputs_last(cache["resid_mid"])[-1], test_freq, dims=[-2,-1])[1], start_play_axis=-2, full_mode=True)