## Associative memory

This notebook experiments with ideas from [Kohonen 1973](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=5009138).

### The setup

We have $n$ key-value pairs $\{ (k_i, v_i) \}$, where each key $k_i \in \mathbb{R}^{d_k}$ and each value $v_i \in \mathbb{R}^{d_v}$.

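Define $K \in \mathbb{R}^{d_k \times n}$ to be the matrix whose columns are the keys $k_i$, and $V \in \mathbb{R}^{d_v \times n}$ to be the matrix whose columns are the values $v_i$. (In the experiments below, $d_k = d_v = d$.)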

We want to find a matrix $W \in \mathbb{R}^{d_v \times d_k}$ such that $v_i \approx Wk_i$ for all $i \in \{1, 2, \ldots, n\}$, i.e. $V \approx WK$.

This is a least-squares problem: we want the $W$ that minimizes the squared Frobenius norm $\|V - WK\|_F^2 = \sum_i \|v_i - Wk_i\|^2$.

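When $KK^T$ is invertible (i.e. $K$ has full row rank, which requires $n \ge d_k$), the solution is $\hat{W} = VK^T(KK^T)^{-1}$. In general, $\hat{W} = VK^+$, where $K^+$ is the Moore–Penrose pseudoinverse of $K$, is the minimum-norm least-squares solution; it coincides with the previous formula whenever $KK^T$ is invertible.

Recalling all stored values at once gives $\hat{W}K = VK^+K$. Suppose $n \le d_k$ and $K$ has full column rank, which holds almost surely for random Gaussian keys. Then $K^+K = I$ and every value is recovered exactly. For $n > d_k$, $K^+K$ is only a projection onto a (at most) $d_k$-dimensional subspace of $\mathbb{R}^n$, so recall degrades as $n$ grows.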

In [1]:
%%capture
# Install the notebook's dependencies; %%capture hides pip's output from the rendered cell.
%pip install torch plotly nbformat tqdm jaxtyping

In [2]:
import torch
import plotly.graph_objs as go
import tqdm

from torch import Tensor
from jaxtyping import Int, Float

# Seed torch's global RNG so the torch.randn draws in the cells below are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x11c512850>

In [3]:
def get_cos_sim_mean(actual: Float[Tensor, 'd n'], estimated: Float[Tensor, 'd n']) -> float:
    """Mean cosine similarity between corresponding columns of two matrices.

    Each column is treated as one vector (similarity is taken along dim 0), so for
    value matrices whose columns are the stored values this averages over the values.

    Args:
        actual: Reference vectors, one per column.
        estimated: Estimates of ``actual``, same shape.

    Returns:
        The average column-wise cosine similarity as a Python float.
    """
    cos_sims = torch.nn.functional.cosine_similarity(actual, estimated, dim=0)
    # .item() so the result matches the `float` annotation instead of being a 0-d tensor.
    return cos_sims.mean().item()

def get_distance_mean(actual: Float[Tensor, 'd n'], estimated: Float[Tensor, 'd n']) -> float:
    """Mean Euclidean distance between corresponding columns of two matrices.

    Each column is treated as one vector (the norm is taken along dim 0).

    Args:
        actual: Reference vectors, one per column.
        estimated: Estimates of ``actual``, same shape.

    Returns:
        The average column-wise L2 distance as a Python float.
    """
    resid = actual - estimated
    distances = torch.linalg.vector_norm(resid, dim=0)
    # .item() so the result matches the `float` annotation instead of being a 0-d tensor.
    return distances.mean().item()

In [4]:
# Experiment 1: fix d = 2**log_d and measure how well the pseudoinverse memory
# W = V K^+ recalls the stored values as the number of pairs n grows.
log_d = 5
log_n_min = 2
log_n_max = 15  # exclusive upper bound passed to range()
n_baseline_samples = 1000  # number of random vector pairs used for the chance-level baseline

d = 2**log_d

n_values = []
cos_sim_values = []
euclidean_distance_values = []

# First find a chance-level baseline: the same metrics between unrelated random vectors.
# Using the shared helpers guarantees the baseline is measured exactly like the experiment.
m1 = torch.randn((d, n_baseline_samples))
m2 = torch.randn((d, n_baseline_samples))
baseline_cos_sim = get_cos_sim_mean(m1, m2)
baseline_euclidean_distance = get_distance_mean(m1, m2)

for log_n in tqdm.tqdm(range(log_n_min, log_n_max)):
    n = 2**log_n
    n_values.append(n)

    # Random Gaussian keys and values, one pair per column.
    K = torch.randn((d, n))
    V = torch.randn((d, n))

    # Recall every stored value at once: W K with W = V K^+. K^+ K projects onto the
    # row space of K, so recall is exact (up to float error) while n <= d and degrades beyond.
    recovered_V = V @ torch.linalg.pinv(K) @ K

    cos_sim_values.append(get_cos_sim_mean(V, recovered_V))
    euclidean_distance_values.append(get_distance_mean(V, recovered_V))

100%|██████████| 13/13 [00:00<00:00, 205.41it/s]


In [5]:
def _plot_metric_vs_n(n_values, values, baseline, metric_name, d):
    """Plot a recall metric against the number of stored pairs, with a chance-level baseline.

    Args:
        n_values: Numbers of key-value pairs (x axis, shown on a log scale).
        values: Metric value for each entry of ``n_values``.
        baseline: Chance-level metric value (float or 0-d tensor), drawn as a dashed line.
        metric_name: Human-readable metric name used in the y-axis label and the title.
        d: Key/value dimension, shown in the title.

    Returns:
        The constructed ``go.Figure``.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=n_values, y=values, mode='lines+markers', name='Recovered values'))
    fig.add_trace(go.Scatter(x=n_values, y=[float(baseline)] * len(n_values), mode='lines', name='Random baseline', line=dict(dash='dash')))
    fig.update_xaxes(title_text="Number of key-value pairs (log scale)", type='log')
    fig.update_yaxes(title_text=f"{metric_name} (avg)")
    fig.update_layout(title=f"{metric_name} between recovered values and original values, d={d}", height=400, width=1000)
    return fig


# Both baselines are the values measured between unrelated random vectors; the cosine
# baseline was previously hard-coded to 0 even though baseline_cos_sim had been computed.
fig_cos = _plot_metric_vs_n(n_values, cos_sim_values, baseline_cos_sim, "Cosine similarity", d)
fig_cos.show()

fig_dist = _plot_metric_vs_n(n_values, euclidean_distance_values, baseline_euclidean_distance, "Euclidean distance", d)
fig_dist.show()

In [6]:
# Experiment 2: repeat the pseudoinverse recall experiment for several dimensions d.
log_n_min = 2
log_n_max = 16  # exclusive upper bound passed to range()

log_d_values = [4, 5, 6, 7, 8, 9, 10]

# Maps each dimension d to its list of mean cosine similarities, one per n = 2**log_n.
cos_sim_values_dict = {}

for log_d in log_d_values:
    d = 2 ** log_d

    current_cos_sim_values = []

    for log_n in tqdm.tqdm(range(log_n_min, log_n_max)):
        n = 2 ** log_n

        K = torch.randn((d, n))
        V = torch.randn((d, n))

        # Recall all stored values through W = V K^+ (see the setup section).
        recovered_V = V @ torch.linalg.pinv(K) @ K

        current_cos_sim_values.append(get_cos_sim_mean(V, recovered_V))

    cos_sim_values_dict[d] = current_cos_sim_values

100%|██████████| 14/14 [00:00<00:00, 228.62it/s]
100%|██████████| 14/14 [00:00<00:00, 126.59it/s]
100%|██████████| 14/14 [00:00<00:00, 59.49it/s] 
100%|██████████| 14/14 [00:00<00:00, 24.22it/s]
100%|██████████| 14/14 [00:01<00:00, 11.16it/s]
100%|██████████| 14/14 [00:02<00:00,  4.74it/s]
100%|██████████| 14/14 [00:08<00:00,  1.72it/s]


In [7]:
# x values must match the (log_n_min, log_n_max) range used to fill cos_sim_values_dict.
# A separate name avoids overwriting `n_values` from the single-d experiment.
n_values_multi_d = [2 ** i for i in range(log_n_min, log_n_max)]
fig_cos = go.Figure()
# `dim` rather than `d` so plotting does not clobber the notebook-level `d`.
for dim, cos_sims in cos_sim_values_dict.items():
    fig_cos.add_trace(go.Scatter(x=n_values_multi_d, y=cos_sims, mode='lines+markers', name=f'd={dim}'))
fig_cos.update_xaxes(title_text="Number of key-value pairs (log scale)", type='log')
fig_cos.update_yaxes(title_text="Cosine similarity (avg)")
fig_cos.update_layout(title="Cosine similarity between recovered values and original values for multiple d values", height=400, width=1000)
fig_cos.show()

### Comparing other methods
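- Pseudoinverse:
  - $v_i \approx V K^+ k_i$
- Softmax (normalized over the $n$ stored keys):
  - $v_i \approx V \cdot \text{softmax}(K^T k_i)$
- ReLU:
  - $v_i \approx V \cdot \text{ReLU}(K^T k_i)$
- Vanilla dot product:
  - $v_i \approx V K^T k_i$

The code below evaluates all queries at once by replacing $k_i$ with $K$. For example, it computes $V K^+ K$, and $V \cdot \text{softmax}(K^T K)$ with the softmax applied to each column.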

In [8]:
# Experiment 3: at fixed d, compare four ways of reading values back out of the memory.
log_d = 6
log_n_min = 2
log_n_max = 15  # exclusive upper bound passed to range()
d = 2**log_d

n_values = []
cos_sims_pseudo = []
cos_sims_softmax = []
cos_sims_relu = []
cos_sims_dot_prod = []

for log_n in range(log_n_min, log_n_max):
    n = 2**log_n
    n_values.append(n)

    K = torch.randn((d, n))
    V = torch.randn((d, n))

    # Column j of `scores` is K^T k_j: the dot products of query k_j with every stored key.
    # Computed once and shared by the softmax, ReLU and dot-product readouts.
    scores = K.T @ K

    # Pseudoinverse
    recovered_V_from_W = V @ torch.linalg.pinv(K) @ K
    cos_sims_pseudo.append(get_cos_sim_mean(V, recovered_V_from_W))

    # Softmax over keys: dim=0 normalizes each column, i.e. over i for a fixed query k_j.
    # No 1/sqrt(d) temperature is applied to the scores.
    recovered_V_from_softmax = V @ torch.softmax(scores, dim=0)
    cos_sims_softmax.append(get_cos_sim_mean(V, recovered_V_from_softmax))

    # ReLU
    recovered_V_from_relu = V @ torch.relu(scores)
    cos_sims_relu.append(get_cos_sim_mean(V, recovered_V_from_relu))

    # Dot product
    recovered_V_from_dot_prod = V @ scores
    cos_sims_dot_prod.append(get_cos_sim_mean(V, recovered_V_from_dot_prod))


In [9]:
# One (legend label, curve) pair per readout method, in the order they were computed.
readout_curves = [
    ('Pseudoinverse', cos_sims_pseudo),
    ('Softmax', cos_sims_softmax),
    ('ReLU', cos_sims_relu),
    ('Dot product', cos_sims_dot_prod),
]

fig = go.Figure()
for label, curve in readout_curves:
    fig.add_trace(go.Scatter(x=n_values, y=curve, mode='lines+markers', name=label))

fig.update_xaxes(title_text="Number of key-value pairs (log scale)", type='log')
fig.update_yaxes(title_text="Cos sim (avg)")
fig.update_layout(title=f"Cos sim between recovered values and original values, d={d}", height=400, width=1000)
fig.show()