## Associative memory

This notebook experiments with ideas from [Kohonen 1973](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=5009138).

### The setup

We have $n$ key-value pairs $\{ (k_i, v_i) \}$, where each key $k_i \in \mathbb{R}^{d_k}$ and each value $v_i \in \mathbb{R}^{d_v}$.

Define $K$ to be the matrix whose columns are the keys $k_i$, and $V$ to be the matrix whose columns are the values $v_i$.

We want to find a matrix $W$ such that $v_i \approx Wk_i$ for all $i \in \{1, 2, \ldots, n\}$, i.e. $V \approx WK$.

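This problem can be framed as a least squares problem: we want to find the $W$ that minimizes $\|V - WK\|_F^2$, the squared Frobenius norm of the residual (i.e. the sum of the squared errors $\|v_i - Wk_i\|^2$ over all pairs).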

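When $KK^T$ is invertible (i.e. $K$ has full row rank, which requires $n \geq d_k$), the solution is $\hat{W} = VK^T(KK^T)^{-1}$. More generally, a minimizer is $\hat{W} = VK^+$, where $K^+$ is the Moore–Penrose pseudoinverse of $K$; this coincides with the formula above whenever $KK^T$ is invertible.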

In [1]:
%%capture
%pip install torch plotly nbformat tqdm 

In [2]:
import torch
import plotly.graph_objs as go
import tqdm

In [3]:
# Experiment 1: fix the dimension d and sweep the number of stored pairs n.
# For every n we fit W = V K^+ and measure how well W @ K reproduces V.
log_d = 8
log_n_min = 2
log_n_max = 16  # exclusive bound of the range() below: the largest n is 2**(log_n_max - 1)

d = 2**log_d

# Seed the global torch RNG so that re-running this cell reproduces the same curves.
torch.manual_seed(0)

n_values = []
cos_sim_values = []
euclidean_distance_values = []

# first find a baseline cosine similarity for random matrices
# (100 pairs of independent Gaussian columns of dimension d; dim=0 reduces over
# the feature axis, so each metric yields one number per column before .mean()).
m1 = torch.randn((d, 100))
m2 = torch.randn((d, 100))
# .item() turns the 0-d tensors into plain floats, matching what is stored in
# euclidean_distance_values and what the plotting cell replicates into lists.
baseline_cos_sim = torch.nn.functional.cosine_similarity(m1, m2, dim=0).mean().item()
baseline_euclidean_distance = torch.norm(m1 - m2, dim=0).mean().item()

for log_n in tqdm.tqdm(range(log_n_min, log_n_max, 1)):
    n = 2**log_n

    # Columns of K are the keys, columns of V are the values (both d-dimensional here).
    K = torch.randn((d, n))
    V = torch.randn((d, n))

    # Least-squares associative memory: W = V K^+.
    W = V @ torch.linalg.pinv(K)

    # Query the memory with every stored key.
    recovered_V = W @ K

    # Per-pair Euclidean error, averaged over the n pairs.
    resid = V - recovered_V
    resid_distance_mean = torch.norm(resid, dim=0).mean()

    # Per-pair cosine similarity between stored and recalled value, averaged over the n pairs.
    cos_sims = torch.nn.functional.cosine_similarity(V, recovered_V, dim=0)
    cos_sim_mean = cos_sims.mean()

    n_values.append(n)
    # Store Python floats for both metrics (previously the cosine list held 0-d tensors).
    cos_sim_values.append(cos_sim_mean.item())
    euclidean_distance_values.append(resid_distance_mean.item())

100%|██████████| 14/14 [00:01<00:00, 10.39it/s]


In [4]:
# Plot the single-d sweep: one figure per metric, each with the measured curve
# and a dashed horizontal line at the random-vector baseline.
x_axis_label = "Number of key-value pairs (log scale)"
n_points = len(n_values)

# Cosine similarity between stored and recalled values.
fig_cos = (
    go.Figure()
    .add_trace(go.Scatter(x=n_values, y=cos_sim_values, mode='lines+markers', name='Recovered values'))
    .add_trace(go.Scatter(x=n_values, y=[baseline_cos_sim] * n_points, mode='lines', name='Random baseline', line={'dash': 'dash'}))
    .update_xaxes(title_text=x_axis_label, type='log')
    .update_yaxes(title_text="Cosine similarity (avg)")
    .update_layout(
        title=f"Cosine similarity between recovered values and original values, d={d}",
        height=400,
        width=1000,
    )
)
fig_cos.show()

# Euclidean distance between stored and recalled values.
fig_dist = (
    go.Figure()
    .add_trace(go.Scatter(x=n_values, y=euclidean_distance_values, mode='lines+markers', name='Recovered values'))
    .add_trace(go.Scatter(x=n_values, y=[baseline_euclidean_distance] * n_points, mode='lines', name='Random baseline', line={'dash': 'dash'}))
    .update_xaxes(title_text=x_axis_label, type='log')
    .update_yaxes(title_text="Euclidean distance (avg)")
    .update_layout(
        title=f"Euclidean distance between recovered values and original values, d={d}",
        height=400,
        width=1000,
    )
)
fig_dist.show()

In [5]:
# Experiment 2: repeat the n-sweep for several dimensions d.
# Results are stored per dimension: {d: [metric at each n, in increasing n]}.
log_n_min = 2
log_n_max = 16  # exclusive bound of the range() below: the largest n is 2**(log_n_max - 1)

log_d_values = [6, 7, 8, 9, 10]

# Seed the global torch RNG so that re-running this cell reproduces the same curves.
torch.manual_seed(0)

# The n values visited by every sweep below. Previously this was reset to an
# empty list and never filled, which left the notebook-wide `n_values` empty
# after this cell ran.
n_values = [2 ** log_n for log_n in range(log_n_min, log_n_max)]
cos_sim_values_dict = {}
euclidean_distance_values_dict = {}

for log_d in log_d_values:
    d = 2 ** log_d

    current_cos_sim_values = []
    current_euclidean_distance_values = []

    # Label each progress bar with its dimension so the bars can be told apart.
    for log_n in tqdm.tqdm(range(log_n_min, log_n_max, 1), desc=f"d={d}"):
        n = 2 ** log_n

        # Columns of K are the keys, columns of V are the values (both d-dimensional here).
        K = torch.randn((d, n))
        V = torch.randn((d, n))

        # Least-squares associative memory W = V K^+, then recall with every stored key.
        W = V @ torch.linalg.pinv(K)
        recovered_V = W @ K

        # Per-pair Euclidean error (norm over the feature axis), averaged over the n pairs.
        resid = V - recovered_V
        resid_distance_mean = torch.norm(resid, dim=0).mean().item()

        # Per-pair cosine similarity, averaged over the n pairs.
        cos_sims = torch.nn.functional.cosine_similarity(V, recovered_V, dim=0)
        cos_sim_mean = cos_sims.mean().item()

        current_cos_sim_values.append(cos_sim_mean)
        current_euclidean_distance_values.append(resid_distance_mean)

    cos_sim_values_dict[d] = current_cos_sim_values
    euclidean_distance_values_dict[d] = current_euclidean_distance_values


100%|██████████| 14/14 [00:00<00:00, 56.14it/s]
100%|██████████| 14/14 [00:00<00:00, 21.34it/s]
100%|██████████| 14/14 [00:01<00:00, 11.03it/s]
100%|██████████| 14/14 [00:03<00:00,  4.51it/s]
100%|██████████| 14/14 [00:08<00:00,  1.67it/s]


In [6]:
# Plot the multi-d sweep: one figure per metric, one trace per dimension d.
# The x-axis is rebuilt from the sweep bounds so it lines up with the stored curves.
n_values = [2 ** exponent for exponent in range(log_n_min, log_n_max)]
x_axis_label = "Number of key-value pairs (log scale)"

# Cosine similarity, one curve per d.
fig_cos = go.Figure()
for d, cos_sims in cos_sim_values_dict.items():
    fig_cos.add_trace(
        go.Scatter(x=n_values, y=cos_sims, mode='lines+markers', name=f'd={d}')
    )
(
    fig_cos
    .update_xaxes(title_text=x_axis_label, type='log')
    .update_yaxes(title_text="Cosine similarity (avg)")
    .update_layout(
        title=f"Cosine similarity between recovered values and original values for multiple d values",
        height=400,
        width=1000,
    )
)
fig_cos.show()

# Euclidean distance, one curve per d.
fig_dist = go.Figure()
for d, distances in euclidean_distance_values_dict.items():
    fig_dist.add_trace(
        go.Scatter(x=n_values, y=distances, mode='lines+markers', name=f'd={d}')
    )
(
    fig_dist
    .update_xaxes(title_text=x_axis_label, type='log')
    .update_yaxes(title_text="Euclidean distance (avg)")
    .update_layout(
        title=f"Euclidean distance between recovered values and original values for multiple d values",
        height=400,
        width=1000,
    )
)
fig_dist.show()