<a href="https://colab.research.google.com/github/sadisticaudio/Modular_Attention/blob/main/Modular_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer Modular Addition Through A Signal Processing Lens

this analysis is based on previous work by Neel Nanda et al, "Progress Measures for Grokking via Mechanistic Interpretability".\
in that work, a model learns to implement an algorithm for modular addition that generalizes to unseen data and exhibits grokking behaviour.\
the hypothesis here is that the model is not using multiplication of trig terms to perform a rotation to perform the task,\
but that symmetries required by the algorithm are formed in the attention layer that are responsible for the model's ability to generalize.

In [None]:
import os
import sys
import shutil

# Remember the working directory; the clone target and the data paths below are built from it.
cwd = os.getcwd()

# Clone the checkpaint repository into ./full_checkpaint and list its source directory.
!git clone https://github.com/sadisticaudio/checkpaint.git full_checkpaint
!ls full_checkpaint/template/src
# Append the cloned source directory to sys.path once (presumably so the `checkpaint` imports below resolve).
if cwd + '/full_checkpaint/template/src' not in sys.path: sys.path.append(cwd + '/full_checkpaint/template/src')

if 'google.colab' in sys.modules:
  # !pip install --upgrade ipywidgets
  # !pip install "ipywidgets>=7,<8"
  !pip install "ipywidgets=7.8.5"
  from google.colab import output
  output.enable_custom_widget_manager()
  !pip install transformer_lens pylinalg pythreejs


import torch
from torch.nn import functional as F
import numpy as np
import einops
from transformer_lens import HookedTransformer, HookedTransformerConfig

import checkpaint
from checkpaint.utils import *
from checkpaint.c_hooks import *
from checkpaint import line_plot as lp

# Candidate checkpoint files; only the first entry is active (the alternatives are commented out).
model_paths = (cwd + '/full_checkpaint/end_run_data.pth', )#project_dir + '/full_run_data.pth',
               # project_dir + '/grokking_demo_new.pth')

# Index of the checkpoint to load, the modulus p, and the model sizes
# (the same values are passed to HookedTransformerConfig later in this notebook).
model_num, p, d_model, d_mlp, n_heads, d_head = 0, 113, 128, 512, 4, 32
print("model_paths", model_paths, "model_paths[model_num]", model_paths[model_num])
# Use the GPU when one is available; the checkpoint is mapped onto this device while loading.
device = "cuda" if torch.cuda.is_available() else "cpu"
full_run_data = torch.load(model_paths[model_num], map_location=torch.device(device))
tprint("full_run_data keys", full_run_data.keys())

# prange enumerates 0..p-1 and pprange enumerates 0..p*p-1, both on `device`.
prange, pprange = torch.arange(p).to(device), torch.arange(p*p).to(device)
# The train/test selectors depend on which checkpoint is loaded (model_num 0 uses the "old" split helper).
is_train, is_test = get_old_indices(device=device) if model_num < 1 else get_new_indices(device=device)
train_indices, test_indices = pprange[is_train], pprange[is_test]
# NOTE(review): presumably `dataset` holds every (a, b, "=") token triple and `labels` the
# (a + b) % p targets; get_data is defined in the checkpaint package, confirm there.
dataset, labels = get_data(device=device)

printCudaMemUsage()

In [None]:

# Print the metadata (including version) of the installed ipywidgets package.
!pip show ipywidgets

In [None]:
def showVector(x, **kwargs):
    """Pass ``x`` (a tensor or a list of tensors) through ``inputs_last`` and plot it.

    All keyword arguments are forwarded unchanged to ``lp.draw_vector``.
    """
    if isinstance(x, list):
        prepared = [inputs_last(item) for item in x]
    else:
        prepared = inputs_last(x)
    lp.draw_vector(prepared, **kwargs)

# Quick smoke test of the plotting widget on random data.
showVector(torch.randn([12,23,p]), full_mode=False)


In [None]:
# Start from the last saved state dict of the run, wrapped by the checkpaint helpers.
cache = squeeze_cache(CacheDict(full_run_data["state_dicts"][-1]))

# One-layer HookedTransformer: p + 1 input tokens, p output logits, 3 positions, ReLU MLP, no normalization.
cfg = HookedTransformerConfig(
    n_layers = 1, n_heads = 4, d_model = 128, d_head = 32, d_mlp = 512, act_fn = "relu", normalization_type=None,
    d_vocab=p+1, d_vocab_out=p, n_ctx=3, init_weights=True, device=device, seed = 598,
)
model = HookedTransformer(cfg)
hooked_state_dict = model.state_dict()

# Reconcile the loaded tensors with the names and shapes HookedTransformer expects.
# The checks run in this order for every parameter, so later lines see earlier rewrites.
for name, x in hooked_state_dict.items():
    # bias missing from the checkpoint -> zeros of the model's shape
    if not name in cache and "b_" in name: cache[name] = torch.zeros_like(x)
    # drop the last column of the stored unembedding
    if name == "unembed.W_U": cache[name] = cache[name][:,:-1]
    # mask entries always come from the freshly built model
    if "mask" in name: cache[name] = x
    # on a shape mismatch, swap the last two axes of the stored tensor
    if name in cache and cache[name].shape != x.shape: cache[name] = cache[name].transpose(-2,-1)
    # W_O is reshaped to the model's layout
    if "W_O" in name: cache[name] = cache[name].reshape(x.shape)
    # "IGNORE" entries missing from the checkpoint are copied from the model
    if not name in cache and "IGNORE" in name: cache[name] = x

model.load_state_dict(cache)
# Run the whole dataset once; keep the logits and merge every cached activation into `cache`.
cache["hook_logits"], hooked_cache = model.run_with_cache(dataset)
cache.update(hooked_cache)
# NOTE(review): presumably the 5 dominant embedding frequencies, sorted ascending; confirm get_top_k_freqs semantics.
key_freqs = get_top_k_freqs(cache["W_E"].transpose(-2,-1), 5, -1, sumlist=[0])[0].sort()[0]
key_harmonics, key_subharmonics = get_second_harmonics(key_freqs), get_second_subharmonics(key_freqs)
tprint("key_freqs", key_freqs, "key_harmonics", key_harmonics, "key_subharmonics", key_subharmonics)
printCudaMemUsage()

In [None]:
# Cross-entropy of the cached logits on the train rows and on the test rows.
print("train loss", cross_entropy_high_precision(cache["hook_logits"][train_indices], labels[train_indices]))
print("test loss", cross_entropy_high_precision(cache["hook_logits"][test_indices], labels[test_indices]))

In [None]:
cache = squeeze_cache(cache)
# ex_list: a copy of the dataset whose column 2 is overwritten with the label,
# plus a fourth column holding each row's original index (0..p*p-1).
ex_list = dataset.clone()
ex_list[:,2] = labels
ex_list = torch.cat((ex_list, pprange.unsqueeze(1).to(ex_list.device)), 1)

def get_idx_by_pos(pos):
    """Return the original row indices of ``ex_list`` sorted by the given columns.

    Args:
        pos: a single label or a list/tuple of labels. 'a', 'b' and 'c' select
            columns 0, 1 and 2 of ``ex_list``; any other label selects column 3
            (the original row index). With several labels the first one is the
            primary sort key, the second breaks ties, and so on.

    Returns:
        A long tensor on ``device`` holding column 3 of the sorted rows
        (squeezed, so a single-row result is 0-dimensional).
    """
    labels_in_order = list(pos) if isinstance(pos, (list, tuple)) else [pos]
    column_of = {'a': 0, 'b': 1, 'c': 2}
    cols = [column_of.get(label, 3) for label in labels_in_order]
    # One stable sort on a composite key gives the same order as stable-sorting
    # once per label from the last label to the first.
    rows = sorted(ex_list.tolist(), key=lambda row: tuple(row[col] for col in cols))
    return torch.tensor(rows, dtype=torch.long, device=device)[:,3].squeeze()

# Row orderings of the dataset sorted by one column, or by two columns (first label is the primary key).
a_idx, b_idx, c_idx = get_idx_by_pos('a'), get_idx_by_pos('b'), get_idx_by_pos('c')
ac_idx, bc_idx, cb_idx = get_idx_by_pos(['a','c']), get_idx_by_pos(['b','c']), get_idx_by_pos(['c','b'])

def a_hook(x):
    """Return ``inputs_last`` of an activation whose first axis has p*p rows.

    ``x`` is either a key into ``cache`` or the activation itself. If the
    first axis is not p*p long, a warning is printed and an empty tensor is
    returned.
    """
    activation = x if not isinstance(x, str) else cache[x]
    if activation.size(0) != p*p:
        print("doing a_hook on an inputs_last activation, take a look")
        return torch.empty(0)
    return inputs_last(activation)

def inv_idx(idx):
    """Return the indices that sort ``idx`` (the inverse of ``idx`` when it is a permutation).

    Args:
        idx: a 1-D tensor or a plain sequence supporting ``len`` and indexing.

    Returns:
        A long tensor on ``device`` when ``idx`` is a tensor, otherwise a list
        of ints. Ties keep their original relative order.
    """
    if isinstance(idx, torch.Tensor):
        # Sort inside torch rather than comparing tensor elements one by one in Python.
        return torch.argsort(idx, stable=True).to(device=device, dtype=torch.long)
    return sorted(range(len(idx)), key=idx.__getitem__)
# Indices that undo the c-sorted ordering; used by c_hook_inv.
c_idx_inv = inv_idx(c_idx)

def c_hook(x):
    """Reorder a p*p-row activation by ``c_idx`` and return its ``inputs_last`` view.

    ``x`` is either a key into ``cache`` or the activation itself. If the
    first axis is not p*p long, a warning is printed and an empty tensor is
    returned.
    """
    activation = x if not isinstance(x, str) else cache[x]
    if activation.size(0) != p*p:
        print("doing c_hook on an inputs_last activation, take a look")
        return torch.empty(0)
    return inputs_last(activation[c_idx])

def c_hook_inv(hook_x):
    """Merge the last two axes, move the merged axis to the front and reorder it by ``c_idx_inv``."""
    merged = hook_x.flatten(-2, -1)
    rows_first = merged.movedim(-1, 0)
    return rows_first[c_idx_inv]

# Memory report after building the index tensors (helper from the checkpaint star-imports; output format not visible here).
printCudaMemUsage()

### Basics

the problem is defined as modular addition of the tokens, in the form of (a + b) % p (113) = c\
the tokens are integer indices (0-112) and when they are embedded, they index a separate length-p sinusoid for every residual stream dimension.\
through training, the spectra of these sinusoids become sparse, with most of the magnitude being attributed to a set of key frequencies.\
the axes in this plot are [ d_model, pos ]\
use the slider to scroll through the residual stream dimensions or click the "spacetime" button to toggle between spatial and fourier modes

In [None]:
# Stack embed.W_E from every saved checkpoint once and drop the last entry of
# the final axis; the same tensor feeds both the plot and the size report.
all_embed_W_E = torch.stack([sd["embed.W_E"] for sd in full_run_data["state_dicts"]])[...,:-1]
showVector(all_embed_W_E.squeeze())

# Tally the bytes held in `cache`, split into hook activations and everything else.
# element_size() gives the real bytes per element instead of assuming 4-byte floats.
total_cache, total_model, total_hook = 0, 0, 0
for name in cache.keys():
    tensor = cache[name]
    nbytes = tensor.numel() * tensor.element_size()
    total_cache += nbytes
    if "hook" in name:
        total_hook += nbytes
    else:
        total_model += nbytes

print("embed size", to_human_readable(all_embed_W_E.numel() * all_embed_W_E.element_size()))
print("total size of cache", to_human_readable(total_cache))
print("total size of model", to_human_readable(total_model))
print("total size of hooks", to_human_readable(total_hook))
printCudaMemUsage()

the key here is that for each dimension in the model, the pos axis indexes to a spot on a sinusoidal periodic function

## Attention

all transformations prior to the attention output are linear.\
prior to this, positions a & b are accessing the same waves but indexing phase locations linearly by a & b.\
there are more details on linear maps below, but for now, just know that any linear transformation preserves this structure.\
the structure of the waves changes when the a & b value vectors are weighted and summed into the final "=" token position.\
at this point, the waves in [cosa, sina, cosb, sinb] are combined into a single cosine wave for each answer "c".\
\
in the previous work the entire input dataset [p * p] was typically reshaped to [p,p] with axes ordered as [a,b] for analysis\
here the dataset is also reshaped to [p,p] but with inputs ordered as [c,a], which will reveal symmetries in the activations\
\
here are dataset examples in the form [a, b, c] for c = 0

In [None]:
# Print the first five rows of the c-sorted example list, without the trailing original-index column.
for i in range(5): tprint("c = 0: example", i, ex_list[c_idx][i][:-1])

notice that for modular addition, as a increases, b decreases.\
this has the implication that the sinusoidal waves of each example are being indexed in reverse order of each other.\
the same applies to other values of c, but with an offset applied to the "b" token.

In [None]:
# Print ten consecutive c-sorted rows starting at offset p * 8 (assumes each value of c occupies p consecutive rows).
for i in range(10): tprint("c = 8: example", i, ex_list[c_idx][p * 8 + i][:-1])

notice that in example 4 above, a and b are equal, because 4 + 4 = 8.\
also notice that in examples 3 & 5, a & b are swapped, as are examples 2 & 6, 1 & 7, etc.\
this spot is found at c/2, for all values of "c".\
since the value vectors are just a linear combination of the embeddings and are different spots on the same wave,\
when summed, a new wave is formed for each "c", and it is a sinusoid made entirely of cosines centered at c/2.\
these cosines can have a negative polarity, but they are always an even, symmetric function about c/2.\
\
below is the "=" token position of the final residual stream prior to unembedding.\
if you scroll through the c axis, note the symmetry present at c/2. axes are [d_model,c,a]

In [None]:
# Plot the last entry (along the first axis) of the inputs_last view of resid_post, with rows reordered by a_idx.
showVector(inputs_last(cache["resid_post"][a_idx])[-1], start_play_axis=1, full_mode=False)

In [None]:
def construct_resid_post_model():
    """CONSTRUCT A MODEL OF THE FINAL RESIDUAL STREAM

    Builds four tensors named "cos_wave", "pos a", "pos b" and "pos c". Along
    the first axis, index f < len(key_freqs) holds key frequency f and the
    last index holds the sum over frequencies, divided at the end by the
    largest absolute value of the summed "cos_wave". The tensors are returned
    as a list in that name order.
    """
    names = ["cos_wave", "pos a", "pos b", "pos c"]
    # One zero tensor per name: the first len(key_freqs)+1 rows of the last-position
    # resid_post view, with an extra trailing axis of length p.
    r_model = { n: torch.zeros_like(inputs_last(cache["resid_post"])[-1,:len(key_freqs)+1])[...,None].repeat(1,1,1,p) for n in names }
    # p x p boolean identity, used to address the diagonal of a [p, p] slice.
    p_diag = torch.eye(p).bool().to(device)
    for c in range(p):
        # Three indices on the last axis derived from c//2 (the third depends on the parity of c).
        symmetry_points = [c//2, (c//2 + p//2 + 1) % p]
        symmetry_points.append(c//2 + 1 if c % 2 == 1 else (c//2 + p//2) % p)
        # Column vector holding (c - row) mod p for every row.
        index = (p + c - prange[...,None]) % p

        for f, freq in enumerate(list(key_freqs)):
            # Cosine of frequency `freq` centred at c/2, broadcast over the [p, p] slice.
            r_model["cos_wave"][f][c] = torch.cos((prange - c/2) * freq * 2 * torch.pi / p)
            # "pos c": keep the cosine only at the symmetry-point columns.
            r_model["pos c"][f][c,...,symmetry_points] = r_model["cos_wave"][f][c,...,symmetry_points]
            # "pos a": keep the cosine only on the diagonal.
            r_model["pos a"][f][c][p_diag] = r_model["cos_wave"][f][c][p_diag]
            # "pos b": in every row, write the cosine evaluated at (c - row) mod p into that column.
            b_wave = torch.cos(((p + c - prange) % p - c/2) * freq * 2 * torch.pi / p)[...,None]
            r_model["pos b"][f][c].scatter_(1, index, b_wave)
            # Accumulate this frequency into the last (sum) row of every tensor.
            for name in r_model: r_model[name][-1][c] += r_model[name][f][c]

    # Normalise every sum row by the peak magnitude of the summed cosine.
    sum_max = r_model["cos_wave"][-1].abs().max()
    for name in r_model: r_model[name][-1] /= sum_max

    return list(r_model.values())

r_model = construct_resid_post_model()

def get_cos_points(indices):
    """Return the marker positions [a, b, c/2, c/2 + p/2] and their labels.

    ``c`` is read from ``indices[-3]``, ``a`` from ``indices[-2]`` and
    ``b`` is computed as ``(c - a) mod p``.
    """
    c = indices[-3]
    a = indices[-2]
    b = (indices[-3] - indices[-2] + p) % p
    positions = [a, b, c/2, c/2 + p/2]
    labels = ["a", "b", "c/2", "c/2 + p/2"]
    return positions, labels
def get_cos_message(indices):
    """Format the current (c, a) selection as the text "a + b % p = c"."""
    c = indices[-3]
    a = indices[-2]
    b = (indices[-3] - indices[-2] + p) % p
    return f"{a!s} + {b!s} % p = {c!s}"

def get_custom_points(indices):
    """Return the two marker positions c/2 and c/2 + p/2 (c = indices[-2]) and their labels."""
    half = indices[-2]/2
    return [half, indices[-2]/2 + p/2], ["c/2", "c/2 + p/2"]
def get_custom_message(indices):
    """Describe c (= indices[-2]) together with c/2 and c/2 + p/2 as text."""
    c = indices[-2]
    parts = ["c = ", str(c), ", c/2 = ", str(c/2), ", c/2 + p/2 = ", str(c/2 + p/2)]
    return "".join(parts)

below is a handcrafted function of length p with a strange shape with little sinusoidal content (top).\
the middle is a reversed version of the same function and the bottom is their sum.\
to separate them on the graph, a bias has been applied afterward to the top and bottom functions.\
if you scroll through axis 0, the reversed version of the function is shifted prior to the summation.\
notice that there are actually two symmetry points. one is at shift/2 and the other at shift/2 + p/2.\
this is the case for all activations in the model post-attention.

In [None]:
# Plot the handcrafted wave demo; the callbacks supply the c/2 and c/2 + p/2 marker positions and the caption.
# NOTE(review): make_custom_wave / get_reversed_shifted_waves come from the checkpaint star-imports; output layout not visible here.
showVector(get_reversed_shifted_waves(make_custom_wave(device)), get_points=get_custom_points, get_message=get_custom_message)

below is an illustration of the structure of the activations simplified to unit magnitude cosines of the key frequencies.\
the shape is [frequency,c,a,b]. frequency includes a normalized sum in the last row.\
notice that if you scroll the c axis, a and b remain symmetrically placed around c/2.\
scrolling fully through the a axis, a and b wrap around, maintaining symmetry with respect to both c/2 and c/2 + p/2 in a circular fashion

In [None]:
# Plot the four synthetic tensors from construct_resid_post_model, marking a, b, c/2 and c/2 + p/2.
showVector(r_model, start_play_axis=1, get_points=get_cos_points, get_message=get_cos_message)



to get this done, the model forms cosine waves over the inputs centered at c/2\
there are a number (~5) of key frequencies (key_freqs) for which these waves form\
each residual stream dimension forms its own composite wave made up of len(key_freqs) cosines with different amplitudes,\
but they are all symmetric about c/2.\
given that cosine is an even function and a and b are both equidistant from c/2, the model forms two symmetry points\
one is at c/2 and the other is at c/2 + p/2, allowing the modular addition to wrap around the inputs in a periodic fashion

there is more later that shows pattern weights for a & b are always equal with (1 - their sum) going to the "=".


## Transformations of Waves

most of the computations done by the model are linear transformations.\
in this work and the previous work, we are analyzing activations of the entire dataset.\
it is important to understand how sinusoids indexed by token position are affected by linear transformations.\

a linear map, in this case, takes a number of different (but typically similar) waves and weights them.\
each individual weight can only do two operations. it can scale the input wave and/or negate the input wave.\
after the weights are applied all these scaled/negated waves are summed into a single output dimension.\
this process is done for every dimension of the output.\
\
this gives the model the ability to manipulate the shape of these waves and adjust the magnitude and phase of the spectra in certain ways.\
for instance, the model can increase a certain frequency by applying large, positive weights to the waves that include that frequency\
and are in-phase with each other while negating the weights for waves that are out-of-phase.\
frequencies can be attenuated with this mechanism by only negating waves that are in-phase.\
in this sense, a linear map can be seen as a set of multi-channel linear-phase filters, each with one coefficient that sum at their output.\
\
this is on full display in the MLP. the W_in weights with respect to each neuron effectively act as a narrowband filter.\
since each weight can only scale and/or negate, there is obvious learned coordination to filter these neurons by frequency.\
for each neuron, there is one input weight for each residual stream dimension. to filter selectively for a desired frequency and phase,\
the neuron assigns a weight that is a cross correlation measure between the desired wave and the content at that residual stream dimension.\

together, these weights supply each neuron with a wave of a single frequency (one of the key_freqs) and phase.\
this gives the model a palette of waves for each of the 5 key frequencies, with a select span of phases to draw from for the MLP output.

here you can view a simplified synthetic model of the residual stream. the axes are [ frequency, c, a, b ]\
the d_model axis is reduced and the first axis splits the key_freqs up with the last element being a normalized sum\
if you scroll through the c or a axes, a and b remain located at circularly equidistant positions about both c/2 and c/2 + p/2\
also, the cosine wave remains centered at c/2 at any indexing

this same scheme can be plotted for any activation starting with the attention output and the symmetries will remain

the only difference between the above model and reality, is that for each "c", the final residual stream \
has a weight and bias applied to each of the cosine waves, but the symmetry and phase remain identical\
\
here is the actual final residual stream. the axes are [d_model,c,a]\
scrolling through the c axis, you'll see that as c increases, symmetries remain at c/2 and c/2 + p/2\
the x axis represents "a", when you scroll "c", element 0 is always the same, mirrored with element at index c, since 0 + c = c\
this means that c = b, in this case

In [None]:
def get_top_freqs_and_indices(hook_x, dim=-1, *, sumlist=None, freqs_allowed=key_freqs):
    """Find the single top frequency per entry of ``hook_x`` and group entries by it.

    Args:
        hook_x: activation forwarded to ``get_top_k_freqs``.
        dim: axis forwarded to ``get_top_k_freqs``.
        sumlist: axes forwarded to ``get_top_k_freqs``; defaults to ``[-2]``.
            A fresh list is created per call so the default is never shared.
        freqs_allowed: frequencies the search is restricted to.

    Returns:
        ``(freqs, mags, idx_dict, freq_idx_sorted)``: the two outputs of
        ``get_top_k_freqs``, a dict mapping each distinct frequency to the
        positions in ``freqs`` that carry it, and the argsort of ``freqs``.
    """
    if sumlist is None:
        sumlist = [-2]
    freqs, mags = get_top_k_freqs(hook_x, 1, dim, sumlist=sumlist, freqs_allowed=freqs_allowed, squeeze=True)
    dim_range = torch.arange(freqs.numel(), device=freqs.device)
    idx_dict = {freq: dim_range[freqs == freq] for freq in freqs.unique().tolist()}
    freq_idx_sorted = freqs.sort()[1]
    return freqs, mags, idx_dict, freq_idx_sorted

# Top key frequency of every MLP pre-activation neuron (last entry of the inputs_last view), plus index groupings.
neuron_freqs, _, neuron_freq_idx, neuron_freq_idx_sorted = get_top_freqs_and_indices(inputs_last(cache["pre"])[-1])

# Focus on the first key frequency and the neurons assigned to it.
test_freq = key_freqs[0].item()
num_neur = len(neuron_freq_idx[test_freq])
freq_pre = inputs_last(cache["pre"])[-1,neuron_freq_idx[test_freq]]

# Phase of each selected neuron at test_freq, taken from the FFT of index 0 of its second-to-last axis.
neuron_phases = torch.angle(torch.fft.fft(freq_pre[...,0,:])[...,test_freq])
_, ordered_phase_idx = torch.sort(neuron_phases)#, descending=True)

# freq_pre = freq_pre[ordered_phase_idx]

# NOTE(review): mock_phases (evenly spaced from -pi) is computed here but never used;
# the loop below builds the mock waves from the measured neuron_phases instead.
mock_freq_pre, mock_phases = torch.zeros([num_neur,p,p], device=device), torch.arange(num_neur, device=device) * 2 * torch.pi/num_neur - torch.pi
# Angular frequency per index step for test_freq.
wk = test_freq * 2 * torch.pi / p
for d in range(num_neur):
    # Sum of two unit cosines, one varying along each of the two axes, sharing the neuron's phase.
    mock_freq_pre[d] = torch.cos(neuron_phases[d] + prange[...,None] * wk) + torch.cos(neuron_phases[d] + prange[None] * wk)

# mock_freq_pre = mock_freq_pre[inv_idx(ordered_phase_idx)]

def running_sum_with_and_without_relu(pre, weights=cache["W_out"][neuron_freq_idx[test_freq],0]):
    """Return running-sum plots of ``pre`` with and without a ReLU, offset for display.

    ``running_sum`` is evaluated on ``pre`` and on ``relu(pre)``; the first
    two entries of each result are kept. The ReLU pair is shifted up and the
    plain pair shifted down by the peak magnitude of the first ReLU entry.
    The default ``weights`` are bound once, when the function is defined.
    """
    plain_pair = running_sum(pre,0, weights=weights, create_normalized_list=True)
    rectified_pair = running_sum(F.relu(pre),0, weights=weights, create_normalized_list=True)
    offset = rectified_pair[0].abs().max()
    shifted_up = [rectified_pair[k] + offset for k in (0, 1)]
    shifted_down = [plain_pair[k] - offset for k in (0, 1)]
    return shifted_up + shifted_down


here we have the pre-activation for every neuron whose highest magnitude frequency is the first of the key_freqs.\
notice the movement of the waves as you scroll through the a axis.\
each neuron is approximately: $\cos(\omega_{neur} a + \phi_{neur}) + \cos(\omega_{neur} b + \phi_{neur})$  - shape: [neuron, a, b]

In [None]:
# Plot the pre-activations of the neurons assigned to test_freq.
showVector(freq_pre, start_play_axis=1)

here we have a synthetic tensor made to reflect the neurons above.  it is an idealized set of unit 2D waves,\
equally spread across all phases ordered and spaced from $-\pi$ to $\pi$.  - shape: [neuron, a, b]

In [None]:
# Plot the synthetic two-cosine waves built for the same neurons.
showVector(mock_freq_pre)

below we use the handcrafted model above to demonstrate the effect of the ReLU on the neurons for this frequency.\
on the top of the graph are the ReLU-activated mock neurons, the bottom are the mock neurons without the ReLU.\
the first axis scrolls through each mock neuron (orange = pre-relu, magenta = post-relu).\
the cyan and red are a running sum (cyan = pre-relu, red = post-relu).\
the last neuron element will illustrate the sum of all neurons.\
notice that the relu cuts off the bottom of the pre-activated mock neuron,\
allowing the rectified wave to travel all the way from left to right from a = 0 to a = 112.\
this is what gives the model the ability to dot with the unembed matrix from any a or b position.  - shape: [neuron, a, b]

In [None]:
# Running sums of the mock neurons with and without the ReLU; weights=None overrides the default W_out weights.
showVector(running_sum_with_and_without_relu(mock_freq_pre, weights=None), start_indices=[43], start_play_axis=1)

below are the actual neuron activations for this frequency.  - shape: [neuron, a, b]

In [None]:
# Same plot for the real pre-activations, using the default weights (column 0 of W_out for these neurons).
showVector(running_sum_with_and_without_relu(freq_pre), start_indices=[43], start_play_axis=1)

In [None]:
# Sorted values of W_out column 0 for the neurons assigned to test_freq.
tprint(cache["W_out"][neuron_freq_idx[test_freq],0].sort()[0])

In [None]:
# showVector(running_sum_with_and_without_relu(inputs_last(cache["pre"])[-1,neuron_freq_idx[test_freq]] * cache["W_out"][neuron_freq_idx[test_freq],4][...,None,None] + cache["b_out"][4,None,None]), start_indices=[43], start_play_axis=1)
# MLP output contributed only by the test_freq neurons: relu(pre) for those neurons times their W_out rows, plus b_out.
showVector(einops.einsum(inputs_last(F.relu(cache["pre"]))[-1,neuron_freq_idx[test_freq]],cache["W_out"][neuron_freq_idx[test_freq]], "neur posa posb, neur d_model -> d_model posa posb") + cache["b_out"][...,None,None], start_indices=[43], start_play_axis=1)

next is the last position of the final residual stream with all frequencies zeroed except this one.  - shape: [d_model, a, b]

In [None]:
# Plot mlp_out (last entry of the inputs_last view) passed through pull_out_freqs for test_freq over the last two axes.
# NOTE(review): presumably this keeps only test_freq and zeroes the rest; confirm in checkpaint.utils.
showVector(pull_out_freqs(inputs_last(cache["mlp_out"])[-1], test_freq, dims=[-2,-1]), start_indices=[1], start_play_axis=2)

In [None]:
# Report the installed PyTorch version and the CUDA version it was built against.
print(torch.__version__)
print(torch.version.cuda)

In [None]:



# check for magnitude of second order frequencies vs key frequencies
def get_second_order_freq_proportion(input):
    """Ratio of key-harmonic magnitude to key-frequency magnitude, averaged over the last two axes.

    The FFT of ``inputs_last(input)`` is taken along the last axis and along
    the second-to-last axis. For each, magnitudes at ``key_harmonics`` and at
    ``key_freqs`` are summed over every other axis and divided; the two
    ratios are then averaged.
    """
    x = inputs_last(input)
    spectrum_a = torch.fft.fft(x)
    spectrum_b = torch.fft.fft(x, x.size(-2), -2)
    reduce_a = tuple(axis for axis in range(x.ndim) if axis != x.ndim - 1)
    reduce_b = tuple(axis for axis in range(x.ndim) if axis != x.ndim - 2)
    fundamental_a = spectrum_a[...,key_freqs].abs().sum(reduce_a)
    harmonic_a = spectrum_a[...,key_harmonics].abs().sum(reduce_a)
    fundamental_b = spectrum_b[...,key_freqs,:].abs().sum(reduce_b)
    harmonic_b = spectrum_b[...,key_harmonics,:].abs().sum(reduce_b)
    return (harmonic_a / fundamental_a + harmonic_b / fundamental_b)/2

def get_dims_by_second_order_freq_proportion(input, freq_idx):
    """Rank the leading-axis entries of ``input`` by harmonic-to-fundamental ratio.

    For key frequency number ``freq_idx`` the FFT of ``inputs_last(input)``
    is taken along the last axis and along the second-to-last axis. For each,
    the magnitude at the harmonic is divided by the magnitude at the
    frequency (both summed over the remaining last axis), and the two ratios
    are averaged.

    Returns:
        Indices along axis 0 sorted by that averaged ratio, largest first.
    """
    freq, harmonic = key_freqs[freq_idx], key_harmonics[freq_idx]
    x = inputs_last(input)
    XA, XB = torch.fft.fft(x), torch.fft.fft(x, x.size(-2), -2)
    mags1a, mags2a = XA[...,freq].abs(), XA[...,harmonic].abs()
    mags1b, mags2b = XB[...,freq,:].abs(), XB[...,harmonic,:].abs()
    # Average the ratio along both axes; the second term previously repeated the
    # first, so the second-to-last-axis spectrum was computed but never used.
    proportion = (mags2a.sum(-1) / mags1a.sum(-1) + mags2b.sum(-1) / mags1b.sum(-1))/2
    tprint("freq", freq, "harmonic", harmonic, "proportion", proportion.shape)
    mags, idx = proportion.sort(0, True)
    tprint("mags[:10]", mags[:10], "idx[:10]", idx[:10])
    return idx

tprint("key_freqs", key_freqs)
tprint("key_harms", key_harmonics)

# for hook_name in ["resid_pre", "pattern", "resid_mid", "mlp_out", "resid_post"]:
#     tprint(hook_name, get_second_order_freq_proportion(cache[hook_name], hook_name))

# Harmonic-to-fundamental proportion at successive points of the forward pass.
# resid_pre uses the first two entries of the leading axis; the others use the last entry.
tprint("resid_pre", get_second_order_freq_proportion(inputs_last(cache["resid_pre"])[:2]))
tprint("attn_out", get_second_order_freq_proportion(inputs_last(cache["attn_out"])[-1]))
tprint("resid_mid", get_second_order_freq_proportion(inputs_last(cache["resid_mid"])[-1]))
tprint("pre", get_second_order_freq_proportion(inputs_last(cache["pre"])[-1]))
# tprint("post", get_second_order_freq_proportion(inputs_last(cache["post"])[-1]))
tprint("mlp_out", get_second_order_freq_proportion(inputs_last(cache["mlp_out"])[-1]))
tprint("resid_post", get_second_order_freq_proportion(inputs_last(cache["resid_post"])[-1]))
# NOTE(review): frequency 35 is hard-coded here while the second argument selects key frequency index 1;
# confirm that key_freqs[1] is 35 for the loaded checkpoint.
second_idx = get_dims_by_second_order_freq_proportion(inputs_last(cache["post"])[-1,neuron_freqs == 35], 1)


In [None]:


# Per-frequency MLP output buffer; only the last slot is filled here (with b_out),
# because the loop that would fill the per-frequency slots is commented out.
mlp_out_freqs = torch.zeros([p*p,len(key_freqs) + 1,d_model], device=device)
mlp_out_freqs[:,-1] += cache["b_out"][None]
# for f, freq in enumerate(key_freqs):
#     mlp_out_freq = einops.einsum(cache["post"][:,-1,neuron_freqs == freq], cache["W_out"][neuron_freqs == freq], "b m, m d -> b d")
#     mlp_out_freqs[:,f] = mlp_out_freq
#     mlp_out_freqs[:,-1] += mlp_out_freq

# Running sums over neurons (ordered by their top frequency) of post- and pre-activation
# times W_out column 0, at the last position, with rows reordered by a_idx.
rslist = running_sum(cache["post"][a_idx][:,-1,neuron_freq_idx_sorted] * cache["W_out"][neuron_freq_idx_sorted,0],-1, create_normalized_list=True)
rslistpre = running_sum(cache["pre"][a_idx][:,-1,neuron_freq_idx_sorted] * cache["W_out"][neuron_freq_idx_sorted,0],-1, create_normalized_list=True)
tprint("neuron_freqs", neuron_freqs[:10], "neuron_freq_idx_sorted", neuron_freq_idx_sorted[:10])
# Vertical offsets so the post (+5) and pre (-5) curves do not overlap when plotted.
rslist[0] += 5
rslist[1] += 5
rslistpre[0] -= 5
rslistpre[1] -= 5
# showVector(c_hook("post"))
# showVector(rslist + rslistpre)
# showVector(inputs_last(cache["mlp_out"])[-1])
# showVector(inputs_last(mlp_out_freqs))
# analyzing freq 14 neuron phases

# Post-activation neurons whose top frequency is 14, viewed in c-sorted order.
freq = 14
hook_post = inputs_last(cache["post"][c_idx])[-1,neuron_freqs == freq]
# Magnitude and phase of each neuron's FFT bin at `freq` (the FFT is evaluated twice on this line).
hook_post_mags, hook_post_phases = torch.abs(torch.fft.fft(hook_post)[...,freq]), torch.angle(torch.fft.fft(hook_post)[...,freq])
# NOTE(review): clast is reassigned before it is read below, and cmin / cmax are not used in the visible code.
clast, cmin, cmax = torch.zeros_like(hook_post_phases[:10,0]), torch.zeros_like(hook_post_phases[:10,0]), torch.zeros_like(hook_post_phases[:10,0])
tprint("clast", clast.shape)
def get_phase_distance(p1, p2):
    """Elementwise distance between two phase tensors, taking the wrapped path when the gap exceeds pi."""
    lower = torch.minimum(p1, p2)
    upper = torch.maximum(p1, p2)
    direct = upper - lower
    wrapped = lower + 2 * torch.pi - upper
    return torch.where(direct > torch.pi, wrapped, direct)


# Walk along index 1..99 of the second axis, comparing each phase column with the previous one.
# NOTE(review): every value computed inside the loop (W_out_polarity, round_phases, counts, dist,
# the_mode) is discarded; only the commented-out tprint calls would report them.
clast = hook_post_phases[:,0]
for c in range(1,100):
    # for d in range(hook_post.size(0)):
    W_out_polarity = cache["W_out"][neuron_freqs == freq,0] > 0
    round_phases, counts = torch.round(hook_post_phases[:,c], decimals=1).unique(return_counts=True)
    # tprint(c, "rounded phases", round_phases, counts)
    dist = get_phase_distance(clast, hook_post_phases[:,c])
    clast = hook_post_phases[:,c]
    # tprint("c", c, "dist", torch.nonzero(dist > .6), "phases", hook_post_phases[torch.nonzero(dist > 0.6),c])
    the_mode = torch.mode(torch.round(hook_post_phases[:,c], decimals=1))
    # tprint("c", c, "polarity/phases", torch.cat((W_out_polarity[:10,None],hook_post_phases[:10,c:c+1]), -1))

# hook_pre = einops.einsum(cache["resid_mid"], cache["W_in"], "batch pos d_model, d_model d_mlp -> batch pos d_mlp") + cache["b_in"]
# tprint("pre close", torch.allclose(hook_pre, cache["pre"]))
# hook_post = F.relu(hook_pre)
# non_relu_mlp_out = einops.einsum(cache["pre"], cache["W_out"], "batch pos d_mlp, d_mlp d_model -> batch pos d_model") + cache["b_out"]
# non_relu_mlp_out2[non_relu_mlp_out2 > 0] = 0
# Recompute the post-activation as relu(pre) on a copy (this rebinds hook_post from the frequency-14 subset above)
# and keep the last entry of its inputs_last view.
hook_post = F.relu(cache["pre"].clone())
c_post = inputs_last(hook_post)[-1]
# for d in range((neuron_freqs == freq).size(0)):
#     neuron_outputs = cache["W_out"][neuron_freqs == freq]
#     tprint(d, "phase", hook_post_phases[d][0], "outputs", neuron_outputs[d,0])
# mlp_out = einops.einsum(hook_post, cache["W_out"], "batch pos d_mlp, d_mlp d_model -> batch pos d_model") + cache["b_out"]
# non_relu_mlp_out2 = einops.einsum(cache["pre"].clone(), cache["W_out"], "batch pos d_mlp, d_mlp d_model -> batch pos d_model") + cache["b_out"]
# mlp_out = inputs_last(mlp_out[a_idx])[-1]
# non_relu_mlp_out2 = inputs_last(non_relu_mlp_out2[a_idx])[-1]
# # non_relu_mlp_out = pull_out_freqs(non_relu_mlp_out, key_freqs)
# # showVector([mlp_out, non_relu_mlp_out2])
# # showVector(inputs_last(cache["mlp_out"][a_idx])[-1])

# fake_mlp = torch.zeros_like(non_relu_mlp_out[0])
# fake_mlp = F.relu(torch.cos(10 + (prange[...,None] + prange[None]) * 14 * 2 * torch.pi / p) + torch.cos(10 +(prange[...,None] - prange[None]) * 14 * 2 * torch.pi / p))
# fake_mlp += F.relu(torch.cos(11 + (prange[...,None] + prange[None]) * 14 * 2 * torch.pi / p) + torch.cos(11 +(prange[...,None] - prange[None]) * 14 * 2 * torch.pi / p))
# fake_mlp += F.relu(torch.cos(12 + (prange[...,None] + prange[None]) * 14 * 2 * torch.pi / p) + torch.cos(12 +(prange[...,None] - prange[None]) * 14 * 2 * torch.pi / p))
# fake_mlp[fake_mlp < 0] = 0
# fake_mlp = F.relu(fake_mlp)
# showVector(fake_mlp)

# hook_post_mags, hook_post_mag_idx = hook_post_mags.sort(-1, True)

# tprint("hook 14 neuron phases", hook_post_phases.sort()[0])
# ccc = hook_post_phases
# cccmags = hook_post_mags
# showVector([ccc[hook_post_mag_idx][None], cccmags[None]/30])
# tprint("central phase", get_central_phase(hook_post[...,0,:], freq))
# showVector(cccmags[None])

# tprint("b_in min max", cache["b_in"].min(), cache["b_in"].max())
# tprint("b_out min max", cache["b_out"].min(), cache["b_out"].max())


In [None]:
# Plot mlp_out as-is, then again after zero_out_freqs with the listed frequencies.
# NOTE(review): presumably the list names the frequencies to remove; confirm zero_out_freqs semantics.
showVector(inputs_last(cache["mlp_out"])[-1])
showVector(zero_out_freqs(inputs_last(cache["mlp_out"])[-1], [2,3,4,5,6,7,8,9,10,11,12,13,27,28,29,30,31,32,33,34,36,37,38,39,40,43,44,45,55,56]))

All linear maps can be viewed through the lens of FIR (Finite Impulse Response) filters.  Specifically, the weight at each output dimension of a linear map can be seen as a filter that takes in a multi-channel input signal and applies a separate one-point FIR filter to each channel before summing these channels to produce its output.  Though the length of this filter is one, the width (number of channels) of the filter is the dimensionality of the vector space.\
\
Through this abstract lens, the input signals (one for each dimension) are not a function of time, but a function of the index into the input axis of the embedding matrix's vector space.  For GPT2, this is less relevant, but for this model, the input index ordering is very meaningful.\
\
For each model dimension of this model (d_model, d_mlp, d_head...), the activations for each input represent the phase of a wave, each of length $2\pi$.  During the forward pass, these waves are seen in the activations; they originate in the embeddings and were learned during training through back-propagation.\
\
The

the rest of this analysis will focus on how the attention mechanism produces these symmetries...

to understand how the model performs modular addition, the best place to start is with the value vectors,\
which is the original residual stream with a "W_V" map applied.\
if you toggle the first (position) axis below and look carefully, you'll notice that positions "a" and "b" are reversed versions of each other.\
if you scroll through the "c" axis, you'll see the same thing only circularly shifted by c\
shape is [pos,head,d_head,c,a]

In [None]:
# Value vectors with dataset rows in c-sorted order, then in a-sorted order.
showVector(inputs_last(cache["v"][c_idx]))
showVector(inputs_last(cache["v"][a_idx]))

In [None]:
# List every cache entry with its shape.
for name in cache: tprint(name, cache[name].shape)

In [None]:
# showVector(cache["unembed.W_U"])
# Broadcast W_U along a new trailing axis to [d_model, p, p], then scale it to unit variance
# so it can be plotted next to normalized activations.
wu = cache["unembed.W_U"][...,None].expand(d_model,p,p)
wu = wu / wu.var().sqrt()
tprint("wu", wu.shape, "var", wu.var())

def show_wu_act(name):
    """Plot the unit-variance activation ``cache[name]`` (a-sorted, last slice) next to the normalized W_U."""
    activation = inputs_last(cache[name][a_idx])
    activation = activation / activation.var().sqrt()
    tprint(name, activation.shape, "var", activation.var())
    wu_broadcast = wu[None].expand_as(activation)
    showVector([activation[-1], wu_broadcast[-1]])

def show_attn_act(name):
    """Display a full variance-normalized cached activation.

    name: cache key of the activation; it is read as ``cache[name][a_idx]``.
    """
    act = inputs_last(cache[name][a_idx])
    # divide by the overall standard deviation so the activation has unit variance
    act = act / act.var().sqrt()
    tprint(name, act.shape, "var", act.var())
    showVector(act)


# Show the normalized activations at four points of the block: resid_mid, attn_out, z and resid_post.
show_wu_act("resid_mid")
show_wu_act("attn_out")
show_attn_act("z")
show_wu_act("resid_post")

# resid_mid = inputs_last(cache["resid_mid"])
# # resid_mid = c_hook("resid_mid")
# resid_mid = resid_mid / resid_mid.var().sqrt()
# tprint("resid_mid", resid_mid.shape, "var", resid_mid.var())
# showVector([wu[None].expand_as(resid_mid),resid_mid])

# resid_post = inputs_last(cache["resid_post"])
# # resid_post = c_hook("resid_post")
# resid_post = resid_post / resid_post.var().sqrt()
# tprint("resid_post", resid_post.shape, "var", resid_post.var())
# showVector([wu[None].expand_as(resid_mid),resid_post])

In [None]:
# Recompute the attention layer and the MLP by hand from cached tensors.
x = cache["resid_pre"]
W_Q, W_K, W_V = cache["W_Q"], cache["W_K"], cache["W_V"]
# q/k/v are taken straight from the cache.  Projecting x through W_Q/W_K/W_V here would be dead work,
# because only the cached activations are used below.
q, k, v = cache["q"], cache["k"], cache["v"]# [B T nh hs]
q = q.transpose(1,2)# [B T nh hs] -> [B nh T hs]
k = k.transpose(1,2).transpose(-2, -1)# [B T nh hs] -> [B nh T hs] -> [B nh hs T]
v = v.transpose(1,2)# [B T nh hs] -> [B nh T hs]

attn_scores = torch.matmul(q, k)/np.sqrt(d_head)# [B nh T hs] @ [B nh hs T] = [B nh T T]
# causal mask: -inf strictly above the diagonal, 0 on and below it
attn_scores_masked = attn_scores + torch.full([1, 3, 3], -torch.inf).triu(1).to(k.device)# [B nh T T]
pattern = F.softmax(attn_scores_masked, -1)# [B nh T T]
# weighted sum of value vectors over key positions, then back to [batch, pos, head, d_head] layout
z = einops.einsum(v, pattern, "batch mhead k_pos d_head, batch mhead q_pos k_pos -> batch mhead q_pos d_head")
hook_z = einops.rearrange(z, "batch head q_pos d_head -> batch q_pos head d_head", head=n_heads)
out = einops.einsum(hook_z, cache["W_O"], "batch pos head d_head, head d_head d_model -> batch pos d_model")

# residual connection followed by the ReLU MLP
resid_mid = out + cache["resid_pre"]
W_in, W_out, b_in, b_out = cache["W_in"], cache["W_out"], cache["b_in"], cache["b_out"]
hook_pre = einops.einsum(resid_mid, W_in, "batch pos d_model, d_model d_mlp -> batch pos d_mlp") + b_in
hook_post = F.relu(hook_pre)
mlp_out = einops.einsum(hook_post, W_out, "batch pos d_mlp, d_mlp d_model -> batch pos d_model") + b_out
resid_post = resid_mid + mlp_out


In [None]:
# Spectrum (over the last axis) of the selected slice of the MLP "post" activations.
# presumably c_post is [d_mlp, p]: the row indexing below assumes d_mlp rows — TODO confirm
c_post = c_hook("post")[-1,:,0]
C_POST = torch.fft.fft(c_post)/p
C_MAGS = C_POST[...,:p//2+1].abs()
C_MAG_SUM = C_MAGS.sum(0)
tprint("C_MAG_SUM top 20", torch.cat((C_MAG_SUM.topk(20)[1][...,None], C_MAG_SUM.topk(20)[0][...,None]), -1))
post_dc = (2 * C_MAGS[:,:1]).clone()
C_MAGS[...,0] = 0  # zero the DC bin so topk only selects oscillating components
post_mags, post_freqs = C_MAGS.topk(2)
tprint("post_freqs shape", post_freqs.shape)
row_indices = torch.arange(d_mlp, device=device).unsqueeze(1).expand(-1, 2)
post_phases = torch.angle(C_POST[row_indices, post_freqs])
post_amps = C_POST[row_indices, post_freqs].real

# Tally, per frequency bin, how many MLP dimensions have it as their strongest bin (best_freq_dict),
# how many have it among their top two (freq_dict), and the corresponding summed magnitudes.
freq_dict, best_freq_dict, mag_dict, best_mag_dict = {}, {}, {}, {}
best_mags, all_mags = torch.zeros([p], device=device), torch.zeros([p], device=device)
for d in range(d_mlp):
    (best_freq, second_freq), (best_mag, second_mag) = post_freqs[d].tolist(), post_mags[d].tolist()
    # dict.get keeps the two tallies independent: a bin first seen as a runner-up must not have
    # its freq_dict count reset when it later shows up as some dimension's strongest bin
    best_freq_dict[best_freq] = best_freq_dict.get(best_freq, 0) + 1
    freq_dict[best_freq] = freq_dict.get(best_freq, 0) + 1
    freq_dict[second_freq] = freq_dict.get(second_freq, 0) + 1
    best_mags[best_freq] += best_mag
    all_mags[best_freq] += best_mag
    all_mags[second_freq] += second_mag

# for f in freq_dict: tprint("freq_dict freq", f, "total", freq_dict[f])
# for f in best_freq_dict: tprint("best_freq_dict freq", f, "total", best_freq_dict[f], "best mag", best_mags[f], "total mag", all_mags[f], "mag accum sum", C_MAG_SUM[f])

cos_weights = torch.zeros([d_model,p//2+1], device=device)
row_idx = torch.arange(d_model, device=device).unsqueeze(1).expand(d_model, 2)

# for d in range(d_mlp):
#     if d < 5:
#         tprint("cos_weights", cos_weights.shape, "post_amps", post_amps.shape, "W_out", W_out.shape)
#         tprint("cos_weights indexed shape", cos_weights[row_idx, post_freqs[d]].shape)
#         tprint("post_freqs", post_freqs[d], "post_amps", post_amps[d], "first 5 dims", W_out[d,:5])
#         tprint("dim", d, "adding this to top 2 freqs", einops.einsum(post_amps[d], W_out[d], "n, d_model -> d_model n")[:5])
#     cos_weights[row_idx, post_freqs[d]] += einops.einsum(post_amps[d], W_out[d], "n, d_model -> d_model n")

# for d in range(5):
#     tprint("d", d, cos_weights[d,key_freqs])
#     tprint("d", d, torch.fft.fft(c_hook("mlp_out")[-1,:,0]).real[d,key_freqs])

# for d in range(20):
#     tprint(d, "freq", post_freqs[d,0], "mag", post_mags[d,0], "phase", post_phases[d,0], "amp", post_amps[d,0], "cos", post_cos[d,0], post_amps[d,0] == post_cos[d,0])
#     tprint(d, "freq", post_freqs[d,1], "mag", post_mags[d,1], "phase", post_phases[d,1], "amp", post_amps[d,1], "cos", post_cos[d,1], post_amps[d,1] == post_cos[d,1])

# tprint("mags", post_mags[:5])
# tprint("freqs", post_freqs[:5])

# Recompute the MLP forward pass from the hand-computed attention output `out`
# (the same statements appear at the end of the manual attention cell).
resid_mid = out + cache["resid_pre"]
W_in, W_out, b_in, b_out = cache["W_in"], cache["W_out"], cache["b_in"], cache["b_out"]
hook_pre = einops.einsum(resid_mid, W_in, "batch pos d_model, d_model d_mlp -> batch pos d_mlp") + b_in
hook_post = F.relu(hook_pre)
mlp_out = einops.einsum(hook_post, W_out, "batch pos d_mlp, d_mlp d_model -> batch pos d_model") + b_out

def do_spectral_matmul(x, W, freqs, dim):
    """Compare a plain matmul ``x @ W`` with a frequency-domain variant, display both, and return the plain result.

    NOTE(review): ``dim`` is never read in the body, and the commented-out call after this
    function passes only three arguments.
    """
    # example: x = hook_post [12769, 3, 512], W = W_out [512, 128]
    # take the last entry of c_hook(x) and move its leading axis to the end
    x = torch.moveaxis(c_hook(x)[-1], 0, -1)
    tprint("x", x.shape)
    M, N = x.size(-2), W.size(-1)  # NOTE(review): M and N are not used afterwards
    if x.size(-1) != W.size(-2):
        # the mismatch is only reported; the matmul below is still attempted
        tprint("you can't matmul these matrices", "x =", x.shape, "W =", W.shape, "...", x.size(-1), "!=", W.size(-2))

    # branch 1 = regular matmul
    ret1 = x @ W

    # branch 2 - spectral way
    # torch.fft.fft transforms the last axis by default, while [freqs] indexes the first axis
    # NOTE(review): confirm these are the intended axes
    X = torch.fft.fft(x)[freqs]
    ret2 = torch.fft.ifft(X @ torch.complex(W, torch.zeros_like(W))).real
    showVector(ret1)
    showVector(ret2)
    # only the plain-matmul result is returned; ret2 is shown for visual comparison
    return ret1
    # cos_weights = torch.zeros([W.size(-1),len(freqs)], device=device)
    # row_idx = torch.arange(W.size(-1), device=device).unsqueeze(1).expand(W.size(-1), len(freqs))
    # cos_weights[row_idx, freqs.expand_as(cos_weights)] += einops.einsum(X, W, "n, d_model -> d_model n")
    # c_post = c_hook("post")[-1,:,0]
    # C_POST = torch.fft.fft(c_post)/p
    # C_MAGS = C_POST[...,:p//2+1].abs()
    # C_MAG_SUM = C_MAGS.sum(0)
    # tprint("C_MAG_SUM top 20", torch.cat((C_MAG_SUM.topk(20)[1][...,None], C_MAG_SUM.topk(20)[0][...,None]), -1))
    # post_dc = (2 * C_MAGS[:,:1]).clone()
    # C_MAGS[...,0] = 0
    # post_mags, post_freqs = C_MAGS.topk(2)
    # tprint("post_freqs shape", post_freqs.shape)
    # row_indices = torch.arange(d_mlp, device=device).unsqueeze(1).expand(-1, 2)
    # post_phases = torch.angle(C_POST[row_indices, post_freqs])

# do_spectral_matmul(hook_post, W_out, key_freqs)

In [None]:
# Display the last entry along the first axis of c_hook("mlp_out"), then report CUDA memory usage.
showVector(c_hook("mlp_out")[-1])
printCudaMemUsage()

In [None]:
# NOTE(review): the assignment of hooks_post is commented out, so this cell relies on hooks_post
# having been defined elsewhere in the notebook.
# hooks_post = metrix.get_activations_slice(dataset, ["blocks.0.mlp.hook_post"], [-1,0])
tprint('hooks_post["blocks.0.mlp.hook_post"].shape', hooks_post["blocks.0.mlp.hook_post"].shape)
showVector(inputs_last(hooks_post["blocks.0.mlp.hook_post"].transpose(0,1)[c_idx]))

In [None]:
# Print the shapes of the hand-computed MLP pre-/post-activations and display the last entry of their a_hook views.
# c_post = zero_out_freqs(c_post, torch.arange(1, p//2 + 1).to(device), -1)
tprint("hook_pre", hook_pre.shape)
showVector(a_hook(hook_pre)[-1])
tprint("hook_post", hook_post.shape)
showVector(a_hook(hook_post)[-1])
# showVector(c_post)
# showVector(c_hook("resid_post")[-1])
# showVector(cache["W_U"])

In [None]:
# Display the last entry of the c_idx-ordered resid_post activations, then report CUDA memory usage.
showVector(inputs_last(cache["resid_post"][c_idx])[-1])
printCudaMemUsage()

as a increases (i.e. 0, 1, 2, 3...), for (a + b) % p to stay the same, b must decrement circularly (i.e. 0, 112, 111, 110...)\
this holds for all examples where c = 0... as you increase c, b's start index must increase as well.\
and since a and b are just indexing the same wave, the b wave becomes a reversed, shifted version of the a wave.\
an interesting thing happens when you add two reversed, shifted functions together.\
the resulting function is doubly symmetric with respect to two lines, one at shift/2 and one at shift/2 + length/2\
and if the function is sinusoidal, this leaves strictly cosine waves of double amplitude, centered at shift/2.\
sine, being an odd function, cancels itself out when a reversed copy is added.\
\
hook_z\
\
everything in the attention layer prior to hook_z, shares this reversal/shift scheme in how inputs index sinusoidal waves in the activations\
hook_z is where we first see the result of the interaction between the a and b inputs.\
hook_z is where we first see the symmetries that remain throughout all later activations\
this interaction and symmetry allows the model to generalize rather than memorize.

In [None]:
# Display the cached attention scores in c_idx order and print two individual score entries.
nuscores = inputs_last(cache["scores"][c_idx])
nsa,nsb = nuscores[0,-1,0,20,13], nuscores[0,-1,1,20,7]
tprint("nsa,nsb", nsa, nsb)
showVector(nuscores)
# tprint(nuscores)
# Reference cosines at frequencies 5 and 7 (cycles per p samples) and their pointwise product;
# they are only consumed by the commented-out showVector below.
w5 = torch.cos(torch.arange(p) * 5 * 2 * torch.pi / p)
w7 = torch.cos(torch.arange(p) * 7 * 2 * torch.pi / p)
wp = w5 * w7
# showVector(torch.stack((w5,w7,wp)))
# showVector(c_hook("z"))

glossary:\
p: 113 (modulus)\
k: key frequency index\
n_heads: number of heads(4)\
d_model: model dimensionality(128)\
d_head: head dimensionality(32)\
head: index for head\
d: index for model dimension\
dh: index for head dimension\
$freq_k$: key frequency\
$\omega_k$: angular frequency ($freq_k$ * 2$\pi$ / p)\
$\phi$: phase\
$\alpha$,$\beta$: amplitude\
|$\alpha$|: magnitude

hook_q: waves over inputs are all the same wave with phase ($\phi_k$) with head & d_head axes weighted by W_Q:\
here are positions (pos) a and b.  the axes are [pos,head,d_head,c,a]\
all waves in position b are reversed and shifted versions of waves in position a, shifted by -c\
q, k, and v activations all share this reversal/shift scheme between positions a and b.\
this will be key to how the attention mechanism inserts symmetries into the activations

here is a demonstration of how the attention scores are calculated\
one element of the d_head space will be selected from just one head to analyze\
the attention scores are calculated by q @ k.T

In [None]:
# NOTE(review): `qk` is not defined in this part of the notebook — confirm it is created in an earlier cell.
showVector([qk[0], qk[1]])
# pick a single element (index 0 on the two axes after position) of q and k for positions 0 and 1 of the c_hook view
q_a,q_b,k_a,k_b = c_hook("q")[0,0,0], c_hook("q")[1,0,0], c_hook("k")[0,0,0], c_hook("k")[1,0,0]
# NOTE(review): `score` merely aliases q_a; the score computation described in the text appears unfinished.
score = q_a
tprint("q_a", q_a.shape, "q_b", q_b.shape, "k_a", k_a.shape, "k_b", k_b.shape)

here is a demonstration of how the attention mechanism produces the required symmetries needed to predict "c"\
the same element of the d_head space will be used for analysis\
since for every c, the attention scores are equal, hook_v a and b are summed together\
the "=" position is also added but it is a constant and doesn't affect the frequency content of the result

In [None]:
# Report CUDA memory usage via the project helper.
printCudaMemUsage()

In [None]:
def get_hook_stats(name, act, freqs=key_freqs):
    """Return per-frequency statistics of an activation, stacked on a new last axis as (mag, phase, dc, dc/mag).

    name:  accepted for symmetry with the caller; not read in the body.
    act:   activation tensor; it is indexed with ``c_idx`` and passed through ``inputs_last`` before the FFT.
    freqs: frequency bins to report (a tensor, or anything ``torch.tensor`` accepts).
    """
    if not isinstance(freqs, torch.Tensor):
        freqs = torch.tensor(freqs, dtype=torch.long).to(act.device)
    spectrum = torch.fft.fft(inputs_last(act[c_idx]))
    n_freqs = len(freqs)
    # DC term copied once per requested frequency; the 5-argument repeat prepends an axis to a
    # 4-D input, which squeeze(0) then removes again
    dc = spectrum[..., 0:1].real.repeat(1, 1, 1, 1, n_freqs).squeeze(0) / p
    selected = spectrum[..., freqs]
    mag = torch.abs(selected) / (p / 2)
    phase = torch.angle(selected)
    return torch.stack((mag, phase, dc, dc / mag), -1)

def get_frequency_stats(freqs=key_freqs):
    """Collect ``get_hook_stats`` for every cached hook that has a ``d_model``- or ``d_head``-sized axis.

    freqs: frequency bins forwarded to ``get_hook_stats``.
    Returns a ``CacheDict`` keyed by the cache entry name.
    """
    stats = CacheDict()
    for name in cache.keys():
        if "hook" not in name:
            continue
        dims = list(cache[name].shape)
        if d_model not in dims and d_head not in dims:
            continue
        if name in stats.keys():
            continue
        stats[name] = get_hook_stats(name, cache[name], freqs)
    return stats

# Frequency statistics at the key frequencies and at their harmonics (key_harmonics).
stats, second_order_stats = get_frequency_stats(key_freqs), get_frequency_stats(key_harmonics)

In [None]:
# Index 1 on the last axis is the phase slot of the (mag, phase, dc, dc/mag) stack built by get_hook_stats.
tprint("hook_q base phases:", stats["q"][0,0,0,0,:,1])

the waves for the b position are the same but with phase negated (equivalent to either reversing the wave or negating the sine component)\
the waves can be calculated like this:\
a: $\alpha_{head,dh}$ * cos(a$\omega_k$) + $\beta_{head,dh}$ * sin(a$\omega_k$)\
b: $\alpha_{head,dh}$ * cos(b$\omega_k$) - $\beta_{head,dh}$ * sin(b$\omega_k$)\
$\alpha_{head,dh}$ and $\beta_{head,dh}$ are amplitudes (weights) applied to each head and head dimension by W_Q

when

In [None]:
# Print the name and shape of every entry in the cache.
for name in cache: tprint(name, cache[name].shape)

In [None]:
# Display the unembedding matrix next to a slice of resid_post, each divided by its own norm so the scales are comparable.
unembed, resid = cache["unembed.W_U"], inputs_last(cache["resid_post"])[-1,:,0]
showVector([unembed/unembed.norm(), resid/resid.norm()])

In [None]:
# NOTE(review): showVector is called without arguments here, unlike the other calls in this section — likely an unfinished cell.
showVector()

In [None]:
# Rebuild the hook_q waves from their frequency stats at key_harmonics and compare them with the real activations.
wk = key_harmonics * 2 * torch.pi / p
# the last axis of the stats tensor holds (mag, phase, dc, dc/mag), as stacked by get_hook_stats
magnitude, phase, dc = second_order_stats["q"][...,0], second_order_stats["q"][...,1], second_order_stats["q"][...,2]
wk = wk.expand_as(phase)
# +1 where the sign of the phase (stat index 1) matches that of the reference slice taken at index 0 of the inner axes, else -1
polarity = torch.where(torch.sign(second_order_stats["q"][:,...,0:1,0:1,1]) == torch.sign(second_order_stats["q"][:,0:1,0:1,0:1,0:1,1]), 1.0, -1.0)
tprint("polarity", polarity.shape, "magnitude", magnitude.shape)
polarity = polarity.expand_as(magnitude)
# signed amplitude: magnitude with the polarity applied
amplitude = polarity * magnitude
p_indices = torch.arange(p).to(device)[None,None,None,...,None]
p_indices = p_indices.expand_as(magnitude)
# printvars(dc, magnitude, phase, amplitude, polarity, p_indices, wk)

# use the phase at index 0 of the second and third axes as the shared base phase
base_phase = phase[:,0:1,0:1,0:1,:].expand_as(phase)
# sum of cosines over the frequency axis, plus the DC term
r_wave = (amplitude * torch.cos(p_indices * wk + base_phase)).sum(-1)
r_wave = dc[...,0] + r_wave
tprint("key_harmonics", key_harmonics)
q_corrs = get_hook_correlations(c_hook("q")[...,0,:], key_harmonics, name="hook q")
print_correlation_summary(q_corrs)
# tprint("q_corr keys", q_corrs.keys())
# tprint("q_corr", q_corrs)
# print_object_info(q_corrs)
# print("end of q_corr info, individual dict items below")
# for q_key in q_corrs.keys():
#     print("info for", q_key)
#     print_object_info(q_corrs[q_key])
# showVector(c_hook("q"))# - c_hook("q").mean(-1, True))
# showVector([c_hook("q")[:-1,...,0,:], r_wave[:-1]])
# showVector(r_wave)
# report how far the reconstruction is from the actual hook_q slice
tprint("r_wave diff", get_mse_and_worst_vector(r_wave, c_hook("q")[...,0,:]))
# showVector(pull_out_freqs(c_hook("pattern"), key_harmonics))
showVector(inputs_last(cache["q"][c_idx])[:2])

W_Q takes the residual stream embeddings and applies a weight and a bias (they are either both positive or both negative)
when they are negative, this is equivalent to adding pi to the phase.
the base phases are the result of each head's weighting of the waves in the embeddings during its linear transformation.
at c == 0, the phases at position a and b are always negated versions of each other
when the answer, c, increases from 0 to p - 1, the wave at the a position remains the same but the wave at the b position circularly shifts to the right one point at a time but otherwise remains identical, which is equivalent to subtracting $\omega_k$ ($freq_k$ * 2$\pi$ / p) from the phase of every frequency

a similar story goes for the k vectors, but each head focuses on one key_freq and has different starting phases (the relationship between q/k phases is random?)

the W_V operation seems entirely stochastic (you can see how random it seems by scrolling through the fourier transformed W_V), which is unintuitive.  given the theory that the attn output is just a linear combination of the value vectors weighted by the attn scores, stochasticity here seems pretty crazy, and thus the trigonometric information flow (each wave) is carried through the keys and queries (i am not at all sure of this)

In [None]:
# Token grids for the three ways of laying out the (p, p) input square.  tokens[frame][tok] is a
# [p, p] grid giving the value of token `tok` when the rows of the square are indexed by `frame`.
a_a = prange[...,None].expand(p,p)  # value = row index
a_b = prange[None].expand(p,p)      # value = column index

# rows index a, columns index b
a_c = (a_a + a_b) % p

# rows index b, columns index a
b_a = a_b
b_b = a_a
b_c = (a_a + a_b) % p

# rows index c, columns index a; b is recovered as (c - a) mod p
c_a = (a_b + p) % p
c_b = (prange[...,None] - a_b + p) % p
c_c = a_a

tokens = {
    "a": { "a":a_a, "b":a_b, "c":a_c },
    "b": { "a":b_a, "b":b_b, "c":b_c },
    "c": { "a":c_a, "b":c_b, "c":c_c },
}

In [None]:
# Display entry 2 along the first axis of c_hook("resid_post").
showVector(c_hook("resid_post")[2])

In [None]:
### ANALYTICALLY RECONSTRUCT RESIDUAL STREAM ###

def synthesize_symmetric_hook(act, freqs, *, alpha="c", seperate_freqs=False):
    """Synthesize an activation as its mean plus one centred cosine term per frequency.

    act:            activation to approximate; its mean and phasors come from ``get_phasors``.
    freqs:          frequency bins (a tensor, or anything ``torch.tensor`` accepts).
    alpha:          which ``tokens`` frame supplies the a/c token grids.
    seperate_freqs: if True, return a stack whose entry 0 is the full synthesis and whose
                    following entries are mean + the contribution of each single frequency.
    """
    if isinstance(freqs, torch.Tensor):
        freqs = freqs.to(act.device)
    else:
        freqs = torch.tensor(freqs, dtype=torch.long, device=act.device)
    mean, phasors = get_phasors(act, freqs)
    mags, phases = torch.abs(phasors), torch.angle(phasors)
    output = (torch.zeros_like(act) + mean).unsqueeze(0)

    frame = tokens[alpha]
    a, c = frame["a"], frame["c"]
    half_c = c/2

    for idx in range(freqs.shape[-1]):
        pick = slice(idx, idx + 1)
        wk = freqs[..., pick] * 2 * torch.pi / p

        # cosine centred at c/2, scaled by a c-dependent amplitude
        cos_waves = torch.cos((a - half_c) * wk)
        phasor = mags[..., pick] * torch.cos(phases[..., pick] + half_c * wk)
        contribution = cos_waves * phasor

        if seperate_freqs:
            # append a fresh slot holding this frequency's wave on top of the mean
            output = torch.cat((output, torch.zeros_like(output[0:1])), 0)
            output[-1] = contribution + mean
        output[0] += contribution

    if seperate_freqs:
        return output
    return output.squeeze(0)

# showVector(cache["unembed.W_U"] * 5)

# post_a = inputs_last(cache["resid_post"][a_idx])[2]
# new_post_a = torch.zeros_like(post_a[...,0,:])
# for i in range(p):
#     new_post_a += torch.roll(post_a[...,i,:], shifts=i, dims=-1)
# new_post_a /= p
# showVector(new_post_a / 5)

def split_dim(hook_x, freqs):
    """Return the real/imaginary parts of the selected FFT bins for the first 50 entries of ``hook_x``.

    The spectrum is taken over the last axis and divided by ``p``; the result of ``view_as_real``
    is permuted with (2, 3, 0, 1), so the trailing two axes end up in front.
    """
    hook_x = inputs_last(hook_x)[:50]
    tprint("hook_x", hook_x.shape)
    spectrum = (torch.fft.fft(hook_x)/p)[...,freqs]
    tprint("X", spectrum.shape)
    parts = torch.view_as_real(spectrum)
    tprint("X", parts.shape)
    parts = torch.permute(parts, (2,3,0,1))
    tprint("X", parts.shape)

    return parts

def shift_exponential():
    """Display a [2, p, p] stack of cosine rows and offset sine rows, then the real part of its inverse FFT.

    Row j of the second plane is the sine wave shifted by ``-1 + 2*j/p``; the two planes are then
    treated as real and imaginary parts of a complex spectrum.
    """
    phase_ramp = prange * 14 * 2 * torch.pi / p
    cos_row, sin_row = torch.cos(phase_ramp), torch.sin(phase_ramp)
    X = torch.zeros([2,p,p], device=device, dtype=torch.float)
    for row in range(p):
        X[0][row] = cos_row
        X[1][row] = sin_row - 1 + 2 * row / p
    tprint("X just cos and sine waves", X.shape)
    showVector(X)
    # move the (real, imag) axis last so view_as_complex can pair the planes up
    X = X.transpose(0,-1).contiguous()
    tprint("X before ifft", X.shape)
    x = torch.fft.ifft(torch.view_as_complex(X))
    tprint("x after ifft", x.shape)
    showVector(x.real)

# shift_exponential()

# showVector(inputs_last(cache["resid_post"][a_idx])[2])
# split = split_dim(inputs_last(cache["resid_post"][a_idx])[2], key_freqs)
# showVector([split[:,0,:], split[:,1,:]])

# Display entry 2 of the a_idx-ordered resid_post, then the same entry of its symmetric synthesis at key_harmonics.
showVector(inputs_last(cache["resid_post"][a_idx])[2])
# showVector(split_dim(inputs_last(cache["resid_post"][c_idx])[2]))
s_hook_x = synthesize_symmetric_hook(inputs_last(cache["resid_post"][a_idx]), key_harmonics, alpha="a")
showVector(s_hook_x[2])


In [None]:
def get_fourier_coefficient_for_specific_frequency(x, frequency, dim=-1, theSampleRate=p):
    """Return the Fourier coefficient of ``x`` at one (possibly non-integer) frequency along ``dim``.

    Computes ``2 / length * sum_n x[n] * exp(-2j * pi * frequency * n / theSampleRate)`` where
    ``length`` is the size of ``x`` along ``dim``; that axis is reduced away in the result.
    """
    axis = dim + x.ndim if dim < 0 else dim
    length = x.size(axis)
    samples = torch.arange(length, device=x.device)
    kernel = torch.exp(-2j * samples * torch.pi * frequency / theSampleRate)
    # shape the kernel to broadcast against x: singleton everywhere except along `axis`
    kernel_shape = [1] * x.ndim
    kernel_shape[axis] = length
    kernel = kernel.reshape(kernel_shape)

    # tprint("mag", torch.abs(coeff), "phase", torch.angle(coeff))
    return 2 * (kernel * x).sum(axis) / length

# A single off-bin cosine (14.5 cycles over p samples), broadcast to shape [2, p, 3].
two_freqs = torch.cos(torch.arange(p) * 14.5 * 2 * torch.pi / p)[None,...,None].expand(2,p,3)# + torch.cos(torch.arange(p) * 16 * 2 * torch.pi / p)
tprint("two_freqs", two_freqs.shape)
# Probe the coefficient at the wave's own frequency and at a few others; each label states the probed frequency.
tprint("14.5 Hz", get_fourier_coefficient_for_specific_frequency(two_freqs, 14.5, 1))
tprint("16 Hz", get_fourier_coefficient_for_specific_frequency(two_freqs, 16, 1))
tprint("16.5 Hz", get_fourier_coefficient_for_specific_frequency(two_freqs, 16.5, 1))
tprint("7 Hz", get_fourier_coefficient_for_specific_frequency(two_freqs, 7, 1))

# Check whether a cosine with a half-integer number of cycles per p samples (17.5) comes back
# negated over the second block of p samples.
vecplen = torch.cos(1.0 + torch.arange(p) * 17.5 * 2 * torch.pi / p)
vecp2len = torch.cos(1.0 + torch.arange(p * 2) * 17.5 * 2 * torch.pi / p)
tprint("negated close", torch.allclose(vecp2len[p:], -vecplen, rtol=1e-4, atol=1e-4))
tprint("vecp2len[p:][:10]", vecp2len[p:][:10])
tprint("-vecplen[:10]", -vecplen[:10])

In [None]:
def get_c_weights(hook_x, freqs):
    """Extract weight and bias Fourier coefficients from an activation after re-centring each row.

    hook_x: activation whose last two axes are indexed by c and by input position (both looped/rolled over p).
    freqs:  frequency bins; must be a tensor, since it is indexed with ``freqs[None,...,None]`` below.
    Returns ``(WEIGHT_COEFFS, BIAS_COEFFS, DIM_BIAS)``.
    """
    # mean over the last two axes, kept as singleton axes
    DIM_BIAS = torch.mean(hook_x, dim=(-2,-1), keepdim=True)
    # presumably DFT-based interpolation of the last axis to 2*p samples — TODO confirm against interpDFT
    hook_x2 = interpDFT(hook_x, p*2)
    centered_hook_x = torch.zeros_like(hook_x)
    for c in range(p):
        # roll row c of the 2x-upsampled signal by (2p - c) samples, then keep every other sample
        rolled = torch.roll(hook_x2[...,c,:], 2 * p - c, -1)
        centered_hook_x[...,c,:] = rolled[...,::2]
    CENTERED_HOOK_X = torch.fft.fft(centered_hook_x)/p

    # keep the DC bin plus the requested frequency bins
    centered_coeffs = CENTERED_HOOK_X[...,[0] + list(freqs)]
    # real part only, with the (c, bin) axes swapped so the last axis runs over c
    waves = torch.view_as_real(centered_coeffs)[...,0].transpose(-2,-1)

    # weight waves are the non-DC rows concatenated with themselves (length 2p); the bias wave is half the DC row
    weight_waves, bias_wave = torch.cat((waves[...,1:,:], waves[...,1:,:]), -1), 0.5 * waves[...,0,:]
    BIAS_COEFFS = torch.fft.fft(bias_wave)[...,[0] + list(freqs)]/p

    for f, freq in enumerate(freqs):
        if freq % 2 == 1:
            # for odd frequencies the second copy is negated
            weight_waves[...,f,p:] *= -1

    WEIGHT_SPECTRUM = torch.fft.fft(weight_waves)/(p*2)
    # for each frequency row, pick the spectrum entry at that frequency's own bin index
    WEIGHT_COEFFS = torch.gather(WEIGHT_SPECTRUM, -1, freqs[None,...,None].expand(WEIGHT_SPECTRUM[...,:1].shape)).squeeze()

    return WEIGHT_COEFFS, BIAS_COEFFS, DIM_BIAS

def synthesize_hook(weight_coeffs, bias_coeffs, dim_bias, freqs, alpha):
    """Synthesize an activation from weight/bias Fourier coefficients.

    weight_coeffs: complex coefficients, one per entry of ``freqs`` on the last axis.
    bias_coeffs:   complex coefficients for the DC bin followed by ``freqs`` on the last axis.
    dim_bias:      term subtracted from the synthesized result.
    freqs:         frequency bins (tensor).
    alpha:         which ``tokens`` frame supplies the a/c token grids.
    Returns the weighted centred cosines summed over frequencies, plus the bias wave, minus ``dim_bias``.
    """
    wk = freqs * 2 * torch.pi / p
    # insert two singleton axes before the frequency axis so the coefficients broadcast over the token grids
    mags, phases = torch.abs(weight_coeffs).unsqueeze(-2).unsqueeze(-2), torch.angle(weight_coeffs).unsqueeze(-2).unsqueeze(-2)
    bias_mags, bias_phases = torch.abs(bias_coeffs).unsqueeze(-2).unsqueeze(-2), torch.angle(bias_coeffs).unsqueeze(-2).unsqueeze(-2)
    a, c = tokens[alpha]["a"][None,...,None], tokens[alpha]["c"][None,...,None]
    # cosines centred at c/2
    psquare = a - c/2
    centered_waves = torch.cos(psquare * wk)
    weight_wave = mags * 4 * torch.cos(phases + c * wk/2)
    # the bias uses angular frequency 0 for its DC coefficient followed by wk for the rest
    bias_wave = (bias_mags * 4 * torch.cos(bias_phases + c * torch.cat((torch.zeros_like(wk[0:1]), wk), -1))).sum(-1)
    synth = (weight_wave * centered_waves).sum(-1) + bias_wave - dim_bias
    return synth

# Compare the real resid_post (entry 2 of the second axis) with its synthesis in both the a- and c-orderings.
hook_a = inputs_last(cache["resid_post"][a_idx,2])
hook_c = inputs_last(cache["resid_post"][c_idx,2])
full_freqs = key_freqs#torch.cat((key_freqs, key_harmonics), 0)
weight_coeffs, bias_coeffs, dim_bias = get_c_weights(hook_c, full_freqs)
tprint("weight_coeffs", weight_coeffs[0])
synth_hook_a = synthesize_hook(weight_coeffs, bias_coeffs, dim_bias, full_freqs, "a")
synth_hook_c = synthesize_hook(weight_coeffs, bias_coeffs, dim_bias, full_freqs, "c")
# convert the key frequencies once instead of calling tolist() for every candidate bin
key_freq_set = set(key_freqs.tolist())
non_key_freqs = [freq for freq in range(p//2+1) if freq not in key_freq_set]
tprint("non_key_freqs", non_key_freqs)
# remove the non-key frequencies along the last two axes before comparing the a-ordered pair
hook_a = zero_out_freqs(hook_a, non_key_freqs, [-2,-1])
synth_hook_a = zero_out_freqs(synth_hook_a, non_key_freqs, [-2,-1])
showVector([hook_a, synth_hook_a])
# hook_c = zero_out_freqs(hook_c, non_key_freqs, [-2])
# synth_hook_c = zero_out_freqs(synth_hook_c, non_key_freqs, [-2])
showVector([hook_c, synth_hook_c])

In [None]:
# showVector(unembed)
# showVector(cache["resid_post"][:,2])
# showVector(c_hook("resid_post")[2])
# Spectrum of the unembedding matrix over its last axis, with magnitudes and phases at all bins and at the key frequencies.
unembed = cache["unembed.W_U"]
uspec = torch.fft.fft(unembed)
allumags = torch.abs(uspec)/p
alluphases = torch.angle(uspec)
# NOTE(review): unlike allumags, umags is not divided by p — confirm this is intended
umags = torch.abs(uspec[...,key_freqs])
uphases = torch.angle(uspec[...,key_freqs])
tprint("unembed base phases", uphases.shape, uphases[:5])

def synthesize_resid_post(alpha, freqs, mags, phases):
    """Synthesize a resid_post-shaped tensor as a sum of cosines in (a + b) built from the given magnitudes and phases.

    alpha:  which ``tokens`` frame supplies the a/b/c token grids.
    freqs:  frequency bins (tensor).
    mags, phases: per-frequency magnitudes and phases on the last axis.
    Side effect: the intermediates wk, bwk, p5, p6, p7 and the result synth are written to module-level globals.
    """
    global wk,bwk,p5,p6,p7,synth
    synth = torch.zeros_like(inputs_last(cache["resid_post"]), dtype=torch.float)
    a, b, c = expand_all_left(synth, tokens[alpha]["a"], tokens[alpha]["b"], tokens[alpha]["c"])
    # add an axis before the frequency axis and expand mags/phases against the leading axes of synth
    mags, phases = mags.unsqueeze(-2).expand_as(synth[...,:mags.size(-1)]), phases.unsqueeze(-2).expand_as(synth[...,:phases.size(-1)])
    tprint("mags", mags.shape, "phases", phases.shape, "synth", synth.shape, "a", a.shape, "b", b.shape, "c", c.shape, "freqs", freqs.shape)
    freqs = expand_all_left(mags, freqs)[0].unsqueeze(-2)
    tprint("freqs", freqs.dtype, freqs.shape)

    wk = freqs * 2 * torch.pi / p
    # trailing axis so the token grids broadcast over the frequency axis
    a, b, c = a.unsqueeze(-1), b.unsqueeze(-1), c.unsqueeze(-1)
    tprint("wk", wk.shape, "\na", a.shape, a[-1,0,:5,:5,0], "\nb", b.shape, b[-1,0,:5,:5,0], "\nc", c.shape, c[-1,0,:5,:5,0])

    awk, bwk = a * wk, b * wk

    phases = phases.unsqueeze(-3)
    # exp(i(b*wk + phase)) * exp(i*a*wk) = exp(i((a + b)*wk + phase))
    p5 = torch.exp(1j * (bwk + phases)) * torch.exp(1j * (awk))
    mags = mags.unsqueeze(-3)
    tprint("awk", awk.shape, "bwk", bwk.shape, "phases", phases.shape, "p5", p5.shape, "mags", mags.shape)
    p6 = mags * p5
    # the real part is the cosine term mags * cos((a + b)*wk + phase)
    p7 = p6.real

    printvars(wk,bwk,p5,p6,p7,synth)
    # tprint("
    # synth[:] += (mags * torch.exp(1j * (b * freqs * 2 * torch.pi / p + phases))).real
    # sum the per-frequency cosines into the output
    synth += p7.sum(-1)
    return synth

# Synthesize resid_post from the unembedding magnitudes/phases at the key frequencies and display it next to the real one.
synth = synthesize_resid_post("a", key_freqs, umags[None], uphases[None])
showVector(synth[-1])
showVector(inputs_last(cache["resid_post"])[2])

showVector(inputs_last(inputs_first(synth)[c_idx]))
showVector(c_hook("resid_post"))

# synth_c = synthesize_resid_post("c", key_freqs, umags[None], uphases[None])
# showVector(inputs_last(inputs_first(synth_c)[c_idx])[-1])
# showVector(inputs_last(cache["resid_post"][c_idx])[2])

# For every index `pos` of the second-to-last axis, compare the unembedding phases with the resid_post
# phases at the key frequencies after rotating the latter by -pos * omega.
for pos in range(p):
    # note: this rebinds the global wk that synthesize_resid_post also writes
    wk = torch.exp(1j * -pos * key_freqs * 2 * torch.pi / p)
    resid = inputs_last(cache["resid_post"])[2,:,pos]
    rspec = torch.fft.fft(resid)
    rphases = torch.angle(rspec[...,key_freqs] * wk.unsqueeze(0))

    total = 0
    diff, absdiff = 0.0, 0.0
    for d in range(d_model):
        for f, freq in enumerate(key_freqs):
            u, r = uphases[d][f].item(), rphases[d][f].item()
            # unwrap: when the two phases sit on opposite sides of the +/-pi cut, lift the negative one by 2*pi
            r = r + 2 * torch.pi if u > torch.pi/2 and r < -torch.pi/2 else r
            u = u + 2 * torch.pi if r > torch.pi/2 and u < -torch.pi/2 else u
            absdiff += np.abs(u - r)
            diff += (u - r)
            total = total + 1
    tprint("pos", pos, "total", total, "absdiff", absdiff, "phase delta", absdiff/total, "diff", round(diff, 4), "rspec", rspec.shape, "wk", wk.shape)

In [None]:
# Toy example: p pairs of frequency-14 cosines, their pointwise product and their sum.
pc = 0.9
w = 14 * 2 * torch.pi / p
# cef = torch.polar(torch.tensor([1.0]), torch.tensor([pc]) - torch.arange(p) * w) * torch.exp(1j * w)
# cer = torch.polar(torch.tensor([1.0]), torch.tensor([-pc])) * torch.exp(1j * w)
# cep = cef * cer
# cf, cr, cp, sf, sr, sp = cef.real, cer.real, cep.real, cef.imag, cer.imag, cep.imag
wave_a, wave_b, wave_p, wave_s = [], [], [], []
for i in range(p):
    # wave_a: phase pc, reduced by i * w on each row
    wave_a.append((torch.polar(torch.tensor([1.0]), torch.tensor([pc]) - i * w) * torch.exp(1j * torch.arange(p) * w)).real)
    # wave_b: fixed phase -pc (identical on every row)
    wave_b.append((torch.polar(torch.tensor([1.0]), torch.tensor([-pc])) * torch.exp(1j * torch.arange(p) * w)).real)
    wave_p.append(wave_a[i] * wave_b[i])
    wave_s.append(1.0 * wave_a[i] + 1.0 * wave_b[i])
wave_a, wave_b, wave_p, wave_s = torch.stack(wave_a), torch.stack(wave_b), torch.stack(wave_p), torch.stack(wave_s)
waves = torch.stack((wave_a, wave_b, wave_p, wave_s))
# showVector(waves)
# Recompute hook_z by hand, then again with every attention pattern replaced by one fixed split.
x = cache["resid_pre"]
# x[...,50:] = 0.0
# x[...,0:49] = 0.0
# showVector(c_hook(x))
W_Q, W_K, W_V = cache["W_Q"], cache["W_K"], cache["W_V"]
# q/k/v are taken straight from the cache.  Projecting x through W_Q/W_K/W_V here would be dead work,
# because only the cached activations are used below.
q, k, v = cache["q"], cache["k"], cache["v"]# [B T nh hs]
q = q.transpose(1,2)# [B T nh hs] -> [B nh T hs]
k = k.transpose(1,2).transpose(-2, -1)# [B T nh hs] -> [B nh T hs] -> [B nh hs T]
v = v.transpose(1,2)# [B T nh hs] -> [B nh T hs]

attn_scores = torch.matmul(q, k)/np.sqrt(d_head)# [B nh T hs] @ [B nh hs T] = [B nh T T]
# causal mask: -inf strictly above the diagonal, 0 on and below it
attn_scores_masked = attn_scores + torch.full([1, 3, 3], -torch.inf).triu(1).to(k.device)# [B nh T T]
pattern = F.softmax(attn_scores_masked, -1)# [B nh T T]
z = einops.einsum(v, pattern, "batch mhead k_pos d_head, batch mhead q_pos k_pos -> batch mhead q_pos d_head")
hook_z = einops.rearrange(z, "batch head q_pos d_head -> batch q_pos head d_head", head=n_heads)
# showVector(c_hook(hook_z))
tprint(pattern[:10,0].shape, pattern[c_idx][:10,0])
# pattern[...,[0, 1]] = pattern[...,[1, 0]]
# overwrite the pattern in place: fixed weights for key positions 0 and 1, none for key position 2
pattern[...,0] = 0.463
pattern[...,1] = 0.537
pattern[...,2] = 0.0
tprint(pattern[:10,0].shape, pattern[c_idx][:10,0])
# tprint(pattern[:10])
z = einops.einsum(v, pattern, "batch mhead k_pos d_head, batch mhead q_pos k_pos -> batch mhead q_pos d_head")
nu_hook_z = einops.rearrange(z, "batch head q_pos d_head -> batch q_pos head d_head", head=n_heads)
# show the original and the fixed-pattern hook_z side by side
showVector(torch.stack((c_hook(hook_z), c_hook(nu_hook_z)))[:,-1])

In [None]:
# List cache entries in two groups: names without "hook" (weights), then names containing "hook" (activations).
for header, want_hook in (("weights\n", False), ("\nactivations\n", True)):
    print(header)
    for name, x in cache.items():
        if ("hook" in name) == want_hook:
            tprint(name, x.shape)

In [None]:
# Hook names in the order they are iterated by the symmetry analysis below
# (embeddings, attention, MLP, residual stream, logits).
hook_names_ordered = [
    "hook_embed",
    "hook_pos_embed",
    "blocks.0.hook_resid_pre",
    "blocks.0.attn.hook_q",
    "blocks.0.attn.hook_k",
    "blocks.0.attn.hook_v",
    "blocks.0.attn.hook_attn_scores",
    "blocks.0.attn.hook_pattern",
    "blocks.0.attn.hook_z",
    "blocks.0.attn.hook_result",
    "blocks.0.hook_attn_out",
    "blocks.0.hook_resid_mid",
    "blocks.0.mlp.hook_pre",
    "blocks.0.mlp.hook_post",
    "blocks.0.hook_mlp_out",
    "blocks.0.hook_resid_post",
    "hook_logits",
]

In [None]:
def get_full_cache_name(name):
    """Resolve a partial key to a full cache key by substring match.

    Returns the last cache key that contains ``name``, or "" when none does.
    Prints a warning when more than one key matches.
    """
    matches = [full_name for full_name in cache if name in full_name]
    if len(matches) > 1: print("the key given matched more than one full_name in the cache", name)
    return matches[-1] if matches else ""

def get_pos_to_analyze(hook_name):
    """Choose the position index used when analysing a hook.

    Hooks whose resolved cache name contains "embed" or "attn." (but not "hook_z") get
    ``slice(0,2)``; every other hook gets position ``2``.  The choice is printed.
    """
    name = get_full_cache_name(hook_name)
    uses_first_two = "hook_z" not in name and ("embed" in name or "attn." in name)
    retval = slice(0,2) if uses_first_two else 2
    print(hook_name, "pos return value", retval)
    return retval

def get_cpos_to_analyze(hook_name):
    """Choose the `cpos` argument used when analysing the given hook.

    The hook name is first resolved with `get_full_cache_name`. Hooks whose
    full name contains "embed" or "attn." (but not "hook_z") get `0`; every
    other hook, including an unresolved name, gets `slice(0,p)`.
    The choice is printed before it is returned.
    """
    full_name = get_full_cache_name(hook_name)
    is_embed_or_attn = "embed" in full_name or "attn." in full_name
    retval = 0 if is_embed_or_attn and "hook_z" not in full_name else slice(0,p)
    print(hook_name, "cpos return value", retval)
    return retval

In [None]:
# For each listed hook, compute its correlations against key_harmonics at the
# positions chosen by get_pos_to_analyze / get_cpos_to_analyze and print them.
for hook_name in hook_names_ordered:
    if "hook" not in hook_name:
        continue
    # Fetch the activation first so the helpers' prints keep their original order.
    _hook_act = c_hook(hook_name)
    _pos = get_pos_to_analyze(hook_name)
    _cpos = get_cpos_to_analyze(hook_name)
    sym = get_hook_correlations(_hook_act, key_harmonics, pos=_pos, cpos=_cpos)
    tprint("\n" + hook_name + "\nsymmetries\n")
    tprint(sym)

# hook_name = "hook_k"
# sym = get_hook_correlations(c_hook(hook_name), key_harmonics, pos=slice(0,2))
# tprint("\nsymmetries\n")
# tprint(sym)

hook_attn_out:

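Not a theory, just the definition: hook_attn_out = sum over heads of (attention pattern × hook_v), each projected through that head's output weights W_O, plus the output bias.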

In [None]:
# Keep only the first two entries along dim 1 of the attention output
# (presumably the two operand token positions - confirm).
hook_attn_out = c_hook("hook_attn_out")[:,:2]
print("hook_attn_out", hook_attn_out.shape)
# NOTE(review): pull_out_freqs presumably isolates the key_freqs components along
# the last dim - confirm against checkpaint.utils. The first two dims are swapped before display.
showVector(pull_out_freqs(hook_attn_out, key_freqs, dims=[-1]).transpose(0,1))

In [None]:
# MLP values read from the cache under the hook_pre / hook_post keys
# (presumably before and after the nonlinearity - confirm).
neurons_pre = cache["blocks.0.mlp.hook_pre"]
neuron_acts = cache["blocks.0.mlp.hook_post"]
# Center the neurons to remove the constant term
# (the einops pattern below only accepts a 2D `batch x neuron` tensor).
neuron_acts_centered = neuron_acts - einops.reduce(neuron_acts, 'batch neuron -> 1 neuron', 'mean')
# for name, t in cache.items(): print(name, t.shape)

In [None]:
# def extract_freq_2d(tensor, freq):
#     # Takes in a pxpx... or batch x ... tensor, returns a 3x3x... tensor of the
#     # Linear and quadratic terms of frequency freq
#     if tensor.shape[0]==p*p: tensor = einops.rearrange(tensor, '(x y) ... -> x y ...', x=p, y=p)
#     # Extracts the linear and quadratic terms corresponding to frequency freq
#     index_1d = [0, 2*freq-1, 2*freq]
#     # Some dumb manipulation to use fancy array indexing rules
#     # Gets the rows and columns in index_1d
#     print("index_1d", [[i]*3 for i in index_1d], [index_1d]*3)
#     rv = tensor[[[i]*3 for i in index_1d], [index_1d]*3]
#     return rv

# def get_best_freqs(x):
#     x = inputs_last(x)
#     print("x shhape", x.shape)
#     x = x - einops.reduce(x, '... a b -> ... 1 1', 'mean')
#     X = fft2d(inputs_last(x))
#     X = X.transpose(0,-1)
#     d_X = X.shape[-1]
#     freqs, second_freqs, frac_explained, second_frac_explained = [], [], [], []

#     for d in range(d_X):
#         best_frac_explained = -1e6
#         best_freq = -1
#         second_best_frac_explained = -1e6
#         second_best_freq = -1
#         for freq in range(1, p//2):
#             # We extract the linear and quadratic fourier terms of frequency freq,
#             # and look at how much of the variance of the full vector this explains
#             # If neurons specialise into specific frequencies, one frequency should
#             # have a large value
#             frac = (extract_freq_2d(X[:, :, d], freq).pow(2).sum()/
#                               X[:, :, d].pow(2).sum()).item()
#             if frac > best_frac_explained:
#                 second_best_freq = best_freq
#                 second_best_frac_explained = best_frac_explained
#                 best_freq = freq
#                 best_frac_explained = frac
#             elif frac > second_best_frac_explained:
#                 second_best_freq = freq
#                 second_best_frac_explained = frac
#         freqs.append(best_freq)
#         frac_explained.append(best_frac_explained)
#         second_freqs.append(second_best_freq)
#         second_frac_explained.append(second_best_frac_explained)
#     freqs = np.array(freqs)
#     frac_explained = np.array(frac_explained)
#     freqs[frac_explained < 0.25] = -1.

#     key_freqs_plus = np.concatenate([key_freqs, np.array([-1])])
#     for i in range(len(key_freqs_plus)):
#         print(f'Cluster {i}: freq {key_freqs_plus[i]}. {(freqs==key_freqs_plus[i]).sum()} neurons')
#     return freqs, frac_explained

# freqs_path = "best_freqs.pt"
# best_freqs = {}
# # neuron_acts = cache["blocks.0.mlp.hook_post"]
# # # Center the neurons to remove the constant term
# # neuron_acts_centered = neuron_acts - einops.reduce(neuron_acts, 'batch neuron -> 1 neuron', 'mean')

# if os.path.isfile(freqs_path):
#     best_freqs = torch.load(freqs_path)
# else:
#     for name in cache.keys():
#         if "hook" in name and cache[name].shape[-1] in [d_mlp, d_model]:
#             _, bf, _ = get_freq_indices(cache[name], 5)
#             best_freqs[name] = bf
#     print("best freqs", [name + str(tstr(b[:5])) for name, b in best_freqs.items()])
#     torch.save(best_freqs, freqs_path)

#############################################

# neuron_freqs, neuron_frac_explained = get_best_freqs(neuron_acts_centered)
# neuronsab = neuron_acts.unflatten(0, (p,p))

# freq_pairs = {}

# for i in range(d_mlp):
#     tta, ttav = get_top_k_freqs(neuronsab[..., i], 5, 0, [1])
#     freq_pair = (tta[1].item(), tta[2].item(), tta[3].item())
#     ratio = 666.0
#     if not tta[2].item() in [14,35,41,52]: ratio = (ttav[1]/ttav[2]).item()
#     freq_pairs[freq_pair] = freq_pairs.get(freq_pair, [])
#     freq_pairs[freq_pair].append(ratio)
#     # print("dim", i, "getting top k freqs for a", list(neuronsab[..., i].shape), tta[1:].tolist(), torch.round(ttav[1:]).tolist(), "ratio", ratio)
#     ttb, ttbv = get_top_k_freqs(neuronsab[..., i], 5, 1, [0])
#     # print("getting top k freqs for b", list(neuronsab[..., i].shape), ttb.tolist(), torch.round(ttbv, decimals=2).tolist())
#     if tta.tolist() != ttb.tolist():
#         print("not equal freqz")
#         print("getting top k freqs for a", list(neuronsab[..., i].shape), tta[1:].tolist(), torch.round(ttav[1:], decimals=2).tolist())
#         print("getting top k freqs for b", list(neuronsab[..., i].shape), ttb[1:].tolist(), torch.round(ttbv[1:], decimals=2).tolist())
#         print("neuron freq", neuron_freqs[i], "frac explained", neuron_frac_explained[i], "second freq", second_neuron_freqs[i], "second frac explained", second_neuron_frac_explained[i])

# for key in freq_pairs.keys():
#     ratios = freq_pairs[key]
#     rten = torch.tensor(ratios)
#     print("\nfreq pair", key, "count", len(ratios), "\nratios\n", ratios)
#     print("min", rten.min().item(), "max", rten.max().item(), "stddev", rten.std().item(), "mean", rten.mean().item())
# key_freqs, neuron_freq_counts = np.unique(neuron_freqs, return_counts=True)

In [None]:
def get_top_k_quadratic_terms(spec, k, max_freq=None):
    """Rank frequencies by their strongest quadratic and linear coefficients.

    The index convention is the one used by this notebook's axis labels
    (`get_label`): index 0 is the constant term and frequency ``f`` occupies
    indices ``2f-1`` and ``2f`` along each axis.

    Args:
        spec: 2D tensor of coefficients in that basis.
        k: number of frequencies to keep in each ranking.
        max_freq: highest frequency examined; defaults to ``p // 2``.

    Returns:
        ``(top, lintop)``. Each is a list of at most ``k`` ``(freq, magnitude)``
        tuples, sorted by descending magnitude, with the magnitude rounded to
        one decimal. ``top`` uses the largest absolute value in the 2x2 block
        ``spec[f, f]``; ``lintop`` uses the mean of the largest absolute values
        in the frequency's constant-row and constant-column entries.
    """
    import heapq

    if max_freq is None:
        max_freq = p // 2

    quadratic, linear = [], []
    for freq in range(1, max_freq + 1):
        # Frequency f lives at indices 2f-1 and 2f (index 0 is the constant term).
        band = slice(2 * freq - 1, 2 * freq + 1)
        mag = spec[band, band].abs().max().item()
        row_mag = spec[0:1, band].abs().max()
        col_mag = spec[band, 0:1].abs().max()
        linmag = (row_mag + col_mag).item() / 2
        quadratic.append((freq, mag))
        linear.append((freq, linmag))

    def strongest(entries):
        # Rank on the unrounded magnitudes; round only for the reported value.
        best = heapq.nlargest(k, entries, key=lambda entry: entry[1])
        return [(freq, round(mag, 1)) for freq, mag in best]

    return strongest(quadratic), strongest(linear)

def get_top_k_linear_terms(x, k, max_freq=None):
    """Return the k frequencies with the largest 1D FFT magnitude along either axis.

    For each of the last two axes the FFT is taken along that axis and the
    magnitudes are summed over the other axis, giving one value per
    frequency. The two per-frequency profiles are added, the DC bin is
    zeroed, and the k largest bins are returned.

    Args:
        x: 2D tensor whose two axes have the same length.
        k: number of frequencies to return.
        max_freq: highest frequency bin considered; defaults to ``p // 2``.

    Returns:
        List of ``(freq, magnitude)`` tuples sorted by descending magnitude.
    """
    if max_freq is None:
        max_freq = p // 2
    n_bins = max_freq + 1
    # Sum over the axis that was NOT transformed so the frequency axis survives.
    along_last = torch.fft.fft(x, dim=-1).abs().sum(-2)[..., :n_bins]
    along_first = torch.fft.fft(x, dim=-2).abs().sum(-1)[..., :n_bins]
    mags = along_last + along_first
    mags[..., 0] = 0  # drop the constant (DC) term
    top_mags, top_freqs = mags.topk(k)  # topk returns values in descending order
    return list(zip(top_freqs.tolist(), top_mags.tolist()))

def get_top_k_sums(x, k):
    """Accumulate per-frequency top-k magnitudes over the leading dim of `x`.

    `fft2d(x)` is computed once. For every index `d` along dim 0 the top-k
    quadratic and linear terms of `X[d]` are obtained from
    `get_top_k_quadratic_terms`, and their magnitudes are summed per
    frequency. The ranking for `d == 0` is printed.

    Returns:
        (topk, lintopk): two dicts mapping frequency -> summed magnitude,
        each ordered by descending total.
    """
    def accumulate(totals, ranked):
        for freq, mag in ranked:
            totals[freq] = totals[freq] + mag if freq in totals else mag

    def by_total_desc(totals):
        return dict(sorted(totals.items(), key=lambda item: item[1], reverse=True))

    X = fft2d(x)
    quad_totals, lin_totals = {}, {}
    for d in range(x.size(0)):
        top, lintop = get_top_k_quadratic_terms(X[d], k)
        if d == 0: print(top)
        accumulate(quad_totals, top)
        accumulate(lin_totals, lintop)
    return by_total_desc(quad_totals), by_total_desc(lin_totals)

def inspect_quad(spec):
    """Draw a stack of 2D spectra with `lp.draw_matrix`.

    If the leading dim of `spec` has length p*p it is first unflattened to
    (p, p). Axis ticks are labelled "0", "sin f" or "cos f", and each matrix
    is described by its top three quadratic frequencies.
    """
    if spec.size(0) == p*p:
        spec = spec.unflatten(0,(p,p))

    def describe(idx):
        return "Top 3 Freqs: " + str(get_top_k_quadratic_terms(spec[idx], 3))

    def get_label(x, start, end):
        # Index 0 is the constant term; odd indices are sines, even ones cosines.
        if x == 0:
            return str(x)
        if x % 2 == 1:
            return "sin " + str(x//2+1)
        return "cos " + str(x//2)

    lp.draw_matrix(spec.pow(2).sqrt(), xmap=get_label, ymap=get_label, descriptor=describe)

def inspect_quads(specs, title = "Tensors", names = None):
    """Draw several stacks of 2D spectra interleaved, for side-by-side comparison.

    Each entry of `specs` is reshaped the same way: a leading or trailing dim
    of length p*p is unflattened to (p, p), and if the last dim is then not a
    multiple of p it is moved to the front. The reshaped tensors are stacked
    so that matrix `i` of every spec appears consecutively, and the result is
    passed to `lp.draw_matrix`.

    Args:
        specs: sequence of tensors with matching shapes. It is not modified.
        title: text included in each matrix description.
        names: one label per spec; defaults to "A", "B", ...
    """
    N = len(specs)

    if not names: names = [chr(ord('A') + i) for i in range(N)]

    # Only the first two shapes are shown, so a single spec is also accepted.
    print("specs[0].shape", *(s.shape for s in specs[:2]))

    # Reshape into a local list so the caller's `specs` is left untouched.
    shaped = []
    for s in specs:
        if s.size(0) == p*p: s = s.unflatten(0,(p,p))
        if s.size(-1) == p*p: s = s.unflatten(-1,(p,p))
        if s.size(-1) % p != 0: s = s.permute(-1,0,1)
        shaped.append(s)

    # Interleave: index idx of the flattened stack is matrix idx//N of spec idx%N.
    spec = torch.stack(shaped, dim=1).flatten(0,1)

    print("specs[0].shape", shaped[0].shape, spec.shape)

    def describe(idx):
        n=idx%N
        return str(idx//N) + names[n] + " " + title + ": " + str(get_top_k_quadratic_terms(spec[idx], 5))
    get_label = lambda x,start,end : str(x) if x==0 else "sin " + str(x//2+1) if x%2==1 else "cos " + str(x//2)
    lp.draw_matrix(spec.pow(2).sqrt(), xmap=get_label, ymap=get_label, descriptor=describe)

# inspect_quad(fft2d(neuron_acts.unflatten(0,(p,p)).permute(-1,0,1)))
# inspect_quad(fft2d(neurons_pre.unflatten(0,(p,p)).permute(-1,0,1)))

def compare_spectrums(before, after, title = ""):
    """Print per-frequency top-10 magnitude totals for two tensors side by side.

    Both tensors go through `inputs_last` and `get_top_k_sums(..., 10)`. For
    every frequency that appears in either result, the "before" and "after"
    totals (rounded to one decimal, 0.0 when absent) are printed, first for
    the quadratic terms and then for the linear terms.
    """
    pre = inputs_last(before)
    post = inputs_last(after)
    print("pre", pre.shape, "sum", round(pre.abs().sum().item(), 2), "post", post.shape, "sum", round(post.abs().sum().item(), 2))
    top_pre = get_top_k_sums(pre, 10)
    top_post = get_top_k_sums(post, 10)

    # One merged table per term kind: index 0 is quadratic, index 1 is linear.
    combined = []
    for kind in range(2):
        table = {freq: (round(total, 1), 0.0) for freq, total in top_pre[kind].items()}
        for freq, total in top_post[kind].items():
            before_total = table[freq][0] if freq in table else 0.0
            table[freq] = (before_total, round(total, 1))
        combined.append(table)

    print(title)
    for label, table in zip(("quadratic freq", "linear freq"), combined):
        for freq, (mag_before, mag_after) in table.items():
            print(label, freq, "before", mag_before, "after", mag_after)


# Tensors compared below: residual stream entering the block vs. the attention output.
pre_attn = cache["blocks.0.hook_resid_pre"]
post_attn = cache["blocks.0.hook_attn_out"]
# Index 0 along dim 1 of the token embedding and of the positional embedding.
# NOTE(review): despite the pre_/post_ names these are two different cache
# entries (hook_embed and hook_pos_embed), not one tensor before/after a step - confirm intent.
pre_pos_embed = cache["hook_embed"][:,0,:]
post_pos_embed = cache["hook_pos_embed"][:,0,:]


# inspect_quads([fft2d(inputs_last(neurons_pre)), fft2d(inputs_last(neuron_acts))], "Neurons")
# inspect_quads([fft2d(inputs_last(pre_attn * 500000)), fft2d(inputs_last(post_attn))], "Attention")

# compare_spectrums(neurons_pre, neuron_acts)
# compare_spectrums(pre_attn * 500000, post_attn, "Pre/Post Attn")
# NOTE(review): the 500000 factor is presumably there so the magnitudes survive
# the one-decimal rounding in the report - confirm.
compare_spectrums(pre_pos_embed * 500000, post_pos_embed * 500000, "Pre/Post Pos Embed")
linears = [46, 49, 32, 56, 39, 48, 41, 52, 25, 43]


def _count_pair_sum_diff_freqs(freqs, modulus):
    """Count the sum and difference frequencies over all ordered pairs of `freqs`.

    A sum above ``modulus // 2`` is folded back to ``modulus - sum``, its
    mirror in the lower half of the spectrum. Differences are taken as
    absolute values. Returns ``(sum_counts, diff_counts)`` as plain dicts in
    first-seen order.
    """
    from collections import Counter

    half = modulus // 2
    sum_counts, diff_counts = Counter(), Counter()
    for first in freqs:
        for second in freqs:
            total = first + second
            if total > half:
                total = modulus - total
            sum_counts[total] += 1
            diff_counts[abs(first - second)] += 1
    return dict(sum_counts), dict(diff_counts)


# Done inside a helper so no notebook-level name shadows the builtin `sum`.
sums, diffs = _count_pair_sum_diff_freqs(linears, p)

print("sums", sorted(sums.items(), key=lambda x : x[1], reverse=True))
print("diffs", sorted(diffs.items(), key=lambda x : x[1], reverse=True))