# Adding KNN memory to transformers

This notebook demonstrates how to enhance a transformer with k-nearest-neighbor (kNN) memory. The idea is to attach an external memory to the transformer so that it can look up relevant (key, value) pairs from earlier steps, which can improve performance on tasks that require long-range context. Let's go through the code step by step.

One of the transformer layers near the top of the stack is a kNN-augmented attention layer, which
combines two forms of attention. Like all of the other layers, it uses standard dense self-attention on
the local context, which is the input subsequence for the current training step. Unlike the other layers,
however, it also does an approximate k-nearest-neighbor search into the external memory.

The same queries are used for both the local context and the external memory. The keys and
values also belong to the same distribution; after each training step, the (key, value) pairs in the local
context are appended to the end of the external memory. If the document is very long, old (key, value)
pairs will be dropped from the memory to make room for new ones. Thus, for each head, the external
memory keeps a cache of the prior M (key, value) pairs, where M is the memory size.
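
The drop-oldest behaviour can be sketched for a single head with a Python `collections.deque` whose `maxlen` is M. This is only a stand-in: the notebook itself stores memories in a Faiss index plus a memory-mapped array. The `append_step` helper is a name introduced for this example.

```python
from collections import deque

import numpy as np

M = 6      # memory size: number of (key, value) pairs kept for one head
dim = 4    # head dimension
memory = deque(maxlen=M)  # once full, appending drops the oldest entry on the left


def append_step(memory, keys, values):
    """Hypothetical helper: append one training step's (key, value) pairs in order."""
    for k, v in zip(keys, values):
        memory.append((k, v))


rng = np.random.default_rng(0)
for step in range(3):
    # Each step contributes a local context of 4 tokens
    keys = rng.standard_normal((4, dim)).astype(np.float32)
    values = rng.standard_normal((4, dim)).astype(np.float32)
    append_step(memory, keys, values)
    last_keys = keys

print(len(memory))  # 6
```

After three steps of 4 tokens, 12 pairs have been appended. Only the last 6 remain: the final 2 from the second step and all 4 from the third.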

The kNN lookup will return a set of retrieved memories, which consist of the top-k (key, value) pairs
that kNN search returns for each query (i.e. each token) in the input subsequence. As with standard
dense attention, we first construct an attention matrix by computing the dot product of each query
against the retrieved keys, then apply softmax, and finally return a weighted sum of the retrieved
values. Unlike standard dense attention, the retrieved memories contain a different set of (key, value)
pairs for each query.
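
Because every query has its own retrieved set, the key and value tensors carry an extra top-k axis. The sketch below assumes one head, a NumPy softmax, and the same `1/sqrt(d)` scaling the attention code later applies through `self.scale`. The function name `memory_attention` is introduced for this example.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def memory_attention(queries, mem_keys, mem_values):
    """
    queries:    (t, d)     one query per token
    mem_keys:   (t, k, d)  top-k retrieved keys, a different set for each query
    mem_values: (t, k, d)  the matching retrieved values
    returns:    (t, d)
    """
    scale = queries.shape[-1] ** -0.5
    # Dot product of each query with its own k retrieved keys -> (t, k)
    scores = np.einsum('td,tkd->tk', queries, mem_keys) * scale
    weights = softmax(scores, axis=-1)
    # Weighted sum of that query's retrieved values -> (t, d)
    return np.einsum('tk,tkd->td', weights, mem_values)


rng = np.random.default_rng(0)
t, k, d = 5, 3, 8
q = rng.standard_normal((t, d))
mk = rng.standard_normal((t, k, d))
mv = rng.standard_normal((t, k, d))
out = memory_attention(q, mk, mv)
print(out.shape)  # (5, 8)
```

A quick sanity check: if all k retrieved keys of a query are identical, the softmax weights are uniform and the output is the mean of that query's retrieved values.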

Attention over the local context is performed in the usual way. The results of kNN-attention and local
attention are then combined using a learned gate:

`gate = sigmoid(bias_parameter)`

`combined_attention = (local_attention * gate) + (external_attention * (1 - gate))`

where `bias_parameter` is a learned (differentiable) scalar, one per head. Because the sigmoid keeps the gate between 0 and 1, this parameter learns how to balance local and external attention.
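
Here is a minimal NumPy sketch of the gate, assuming outputs shaped `(heads, t, d)` for one batch element. The per-head bias is shaped `(heads, 1, 1)` so that it broadcasts over tokens and features. A bias of 0 gives an even 50/50 mix, and a large positive bias makes the output almost purely local.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


heads, t, d = 4, 6, 8
rng = np.random.default_rng(0)

# One scalar per head, shaped to broadcast over (heads, t, d)
bias_parameter = np.zeros((heads, 1, 1))
gate = sigmoid(bias_parameter)  # 0.5 for every head here

local_attention = rng.standard_normal((heads, t, d))
external_attention = rng.standard_normal((heads, t, d))
combined_attention = local_attention * gate + external_attention * (1 - gate)
print(combined_attention.shape)  # (4, 6, 8)
```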

### Installing Libraries:

- faiss: A library for efficient similarity search, used here for k-nearest-neighbor (kNN) search over large sets of vectors. It is installed (as the `faiss-gpu` package) and imported below.
- torch: PyTorch, the library used to build the attention modules.
- numpy: A library for arrays and numerical computation, also used for the memory-mapped database.
- einops: A library for readable tensor reshaping (`rearrange`) and `einsum`, installed in the second cell.

In [1]:
!pip install faiss-gpu
import faiss
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
Installing collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [2]:
!pip install einops
from einops import rearrange, repeat, pack, unpack, einsum



### Setting Up kNN Index

- dim: Defines the dimensionality of the vectors in the kNN index (64 in this case).
- faiss.IndexFlatL2(dim): Creates a flat index that compares each query against every stored vector (exact, brute-force search) using L2 (Euclidean) distance. Its search results report squared L2 distances. This index will be used to store and retrieve vector embeddings (e.g., token representations).

In [3]:
dim = 64
index = faiss.IndexFlatL2(dim)

### Creating and Adding Data to kNN Index

- vector_data: A random dataset with 10,000 vectors of size dim=64. These vectors will be used for the kNN search.
- index.add(): This method adds the vectors to the kNN index for future lookups.

In [4]:
vector_data = np.random.random((10000, dim)).astype('float32')

In [5]:
index.add(vector_data)

In [6]:
index.ntotal

10000

In [7]:
index.remove_ids(np.arange(10))

10

In [8]:
index.ntotal

9990
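
`remove_ids` returns the number of vectors it removed (10 here). This matters for the memory design later in the notebook. In a flat index, removal compacts the stored vectors, so every vector after a removed one gets a smaller id. A separate database addressed by those ids must therefore be compacted in the same way, or the two will fall out of sync. The sketch below imitates the compaction with `np.delete` on a plain array. It does not call Faiss.

```python
import numpy as np

dim = 4
rng = np.random.default_rng(0)
stored = rng.random((20, dim)).astype(np.float32)  # in a flat index, row position == id

# Imitate removing ids 0..9: the remaining rows shift down to fill the gap
removed = np.arange(10)
compacted = np.delete(stored, removed, axis=0)

print(compacted.shape)  # (10, 4)
# The vector that used to have id 10 now has id 0
print(np.array_equal(compacted[0], stored[10]))  # True
```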

### Querying the kNN Index

- query_data: A set of 10 random vectors to query against the kNN index.
- top_k = 2: We want to retrieve the top 2 closest neighbors for each query.
- index.search(): This method performs the kNN search. It returns:
    - distance: The squared Euclidean (L2) distance between each query vector and each retrieved neighbor, sorted from nearest to farthest.
    - ids: The indices of the closest neighbors found in the index.
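
`IndexFlatL2` performs exact, brute-force search, so its results can be reproduced with plain NumPy. The `flat_l2_search` helper below is a name introduced for this example and is not part of Faiss. It computes squared L2 distances to every stored vector and keeps the `top_k` smallest. As a sanity check on the outputs below: for uniform random vectors in 64 dimensions, the expected squared distance between two points is 64/6 ≈ 10.7. Nearest-neighbor values around 5 to 6 are therefore consistent with squared distances, not plain Euclidean distances (which would be around 3).

```python
import numpy as np


def flat_l2_search(data, queries, top_k):
    """Exact k-NN by squared L2 distance, mirroring what IndexFlatL2.search returns."""
    # ||q - x||^2 = ||q||^2 - 2 q.x + ||x||^2 for every (query, stored vector) pair
    d2 = (
        (queries ** 2).sum(axis=1, keepdims=True)
        - 2.0 * queries @ data.T
        + (data ** 2).sum(axis=1)
    )
    d2 = np.maximum(d2, 0.0)  # clip tiny negatives caused by rounding
    ids = np.argsort(d2, axis=1)[:, :top_k]
    distances = np.take_along_axis(d2, ids, axis=1)
    return distances, ids


rng = np.random.default_rng(0)
data = rng.random((1000, 64)).astype(np.float32)
queries = data[:5] + 1e-3  # queries very close to the first five stored vectors
distances, ids = flat_l2_search(data, queries, top_k=2)
print(ids[:, 0])  # [0 1 2 3 4]
```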

In [9]:
query_data = np.random.random((10, dim)).astype('float32')
top_k = 2
distance, ids = index.search(query_data, top_k)

In [10]:
distance

array([[5.620409 , 5.8326654],
       [5.0630836, 5.379696 ],
       [5.636686 , 5.7546487],
       [6.5205274, 6.8329067],
       [5.2549725, 5.375698 ],
       [6.1668134, 6.6554174],
       [5.416364 , 6.0980406],
       [5.612576 , 5.872774 ],
       [5.0625024, 5.1286287],
       [5.090478 , 6.277242 ]], dtype=float32)

In [11]:
ids

array([[7123, 2909],
       [5498, 9134],
       [2782, 6815],
       [1810, 9503],
       [4520, 1717],
       [ 905,  715],
       [1120, 3618],
       [8247, 5126],
       [4457, 5163],
       [ 665, 2275]])

Faiss's `search_and_reconstruct` method not only finds the nearest neighbors of each query vector but also returns the vectors stored in the index for those neighbors. This is useful when you need both the neighbor ids and the corresponding vectors.

- Search:
Like the search method, search_and_reconstruct finds the top k nearest neighbors of each query vector according to the index's metric (squared L2 distance for IndexFlatL2).

- Reconstruct:
Once the nearest neighbors are identified, it reconstructs the stored vectors that correspond to them. For a flat index such as IndexFlatL2, these are exact copies of the added vectors. For compressed indexes, the reconstruction is only approximate. This is useful when you need the raw data associated with the nearest neighbors rather than just their ids.
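
For a flat index, "reconstruct" amounts to looking up the stored rows by id. The NumPy sketch below assumes brute-force search and returns the same three arrays as `Out[12]`: distances, ids, and vectors of shape `(n_queries, top_k, dim)`. The function name only mirrors the Faiss method and is not Faiss itself.

```python
import numpy as np


def search_and_reconstruct(data, queries, top_k):
    """Illustrative stand-in: returns (distances, ids, vectors) for a flat L2 index."""
    # Squared L2 distance from every query to every stored vector (brute force)
    d2 = ((queries[:, None, :] - data[None, :, :]) ** 2).sum(axis=-1)
    ids = np.argsort(d2, axis=1)[:, :top_k]
    distances = np.take_along_axis(d2, ids, axis=1)
    # "Reconstruct": fetch the stored vectors for each returned id
    vectors = data[ids]  # (n_queries, top_k, dim)
    return distances, ids, vectors


rng = np.random.default_rng(0)
data = rng.random((500, 64)).astype(np.float32)
queries = rng.random((10, 64)).astype(np.float32)
distances, ids, vectors = search_and_reconstruct(data, queries, top_k=2)
print(vectors.shape)  # (10, 2, 64)
```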

In [12]:
index.search_and_reconstruct(query_data, top_k)

(array([[5.620409 , 5.8326654],
        [5.0630836, 5.379696 ],
        [5.636686 , 5.7546487],
        [6.5205274, 6.8329067],
        [5.2549725, 5.375698 ],
        [6.1668134, 6.6554174],
        [5.416364 , 6.0980406],
        [5.612576 , 5.872774 ],
        [5.0625024, 5.1286287],
        [5.090478 , 6.277242 ]], dtype=float32),
 array([[7123, 2909],
        [5498, 9134],
        [2782, 6815],
        [1810, 9503],
        [4520, 1717],
        [ 905,  715],
        [1120, 3618],
        [8247, 5126],
        [4457, 5163],
        [ 665, 2275]]),
 array([[[0.8076478 , 0.67531157, 0.22166291, ..., 0.6731062 ,
          0.6935085 , 0.03726348],
         [0.9184147 , 0.6368605 , 0.22319569, ..., 0.76023334,
          0.08740996, 0.8329671 ]],
 
        [[0.4025765 , 0.51826936, 0.22416422, ..., 0.25313255,
          0.5544181 , 0.50589734],
         [0.21780738, 0.34675562, 0.03360439, ..., 0.17692722,
          0.41294774, 0.7699829 ]],
 
        [[0.00225048, 0.52846795, 0.9816772

### Memory Management with np.memmap

- np.memmap(): Creates a memory-mapped array, allowing you to efficiently store and access large arrays without loading everything into RAM at once. This is important for managing the "external memory" used by the transformer.
- max_memories = 10000: Defines the maximum number of memory slots (10,000 key-value pairs) that can be stored in this memory.
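
The sketch below shows that data written to a memmap survives closing and reopening the file. It writes to a throwaway temporary file rather than `./memory.memmap`. The file is raw bytes with no header, so the dtype and shape must be supplied again when reopening. The size on disk is `max_memories * 2 * dim * 4` bytes for float32.

```python
import os
import tempfile

import numpy as np

dim = 64
max_memories = 1000
path = os.path.join(tempfile.mkdtemp(), "memory.memmap")  # temporary file for the demo

# mode='w+' creates (or overwrites) the file; a new file reads back as zeros
db = np.memmap(path, mode="w+", dtype=np.float32, shape=(max_memories, 2, dim))
rng = np.random.default_rng(0)
kv = rng.random((3, 2, dim)).astype(np.float32)  # three (key, value) pairs
db[:3] = kv
db.flush()  # push changes to disk
del db      # close the mapping

# Reopen read-only, passing dtype and shape again
db_ro = np.memmap(path, mode="r", dtype=np.float32, shape=(max_memories, 2, dim))
print(np.array_equal(db_ro[:3], kv))  # True
print(os.path.getsize(path))          # 512000 (1000 * 2 * 64 * 4 bytes)
```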

In [13]:
db_filepath = "./memory.memmap"
max_memories = 10000
shape = (max_memories, 2, dim)
db = np.memmap(db_filepath, mode = 'w+', dtype = np.float32, shape = shape)

In [14]:
db

memmap([[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)

### Adding Key-Value Pairs to Memory

Writes random key-value pairs into slots of the memory (db). A slice such as db[1:2] expects an array of shape (number_of_pairs, 2, dim), where each slot holds one key and one value.

In [15]:
db[1:2] = np.random.rand(1,2,dim)

In [16]:
db

memmap([[[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        [[0.48390037, 0.33972484, 0.4520095 , ..., 0.8959501 ,
          0.1619047 , 0.4994492 ],
         [0.9100345 , 0.23197599, 0.04336949, ..., 0.3611623 ,
          0.19715595, 0.91898817]],

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        ...,

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]],

        [[0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.        , 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]

In [17]:
db[0] = torch.randn(1,2,dim)

In [18]:
type(db[0])

numpy.memmap

In [19]:
db[1].shape

(2, 64)

Purpose and Use

- Faiss Index:

Handles nearest-neighbor searches efficiently.

- Memory-Mapped Database:

Stores vectors persistently while minimizing RAM usage.

- Workflow:

Add keys to the Faiss index and the full (key, value) pairs to the memory-mapped database.
Perform nearest-neighbor searches using the Faiss index.
Use the returned ids to retrieve the matching (key, value) pairs from the database.
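
The three workflow steps can be sketched end to end in NumPy. Plain arrays stand in for both the Faiss index (keys only) and the memory-mapped database (key and value per slot). The approach relies on ids being assigned in insertion order, so an id doubles as a database row number as long as the index and the database start empty and are always written together.

```python
import numpy as np

dim, max_memories, top_k = 10, 1000, 3
rng = np.random.default_rng(0)

# Stand-ins: index_keys plays the Faiss index, db plays the memory-mapped database
index_keys = np.empty((0, dim), dtype=np.float32)
db = np.zeros((max_memories, 2, dim), dtype=np.float32)
db_offset = 0

# 1. Add: keys go to the index, full (key, value) pairs go to the database
kv = rng.random((200, 2, dim)).astype(np.float32)
index_keys = np.concatenate([index_keys, kv[:, 0, :]])
db[db_offset:db_offset + len(kv)] = kv
db_offset += len(kv)

# 2. Search: exact squared-L2 k-NN over the stored keys
queries = rng.random((5, dim)).astype(np.float32)
d2 = ((queries[:, None, :] - index_keys[None, :, :]) ** 2).sum(axis=-1)
ids = np.argsort(d2, axis=1)[:, :top_k]

# 3. Retrieve: index ids are also database row numbers
retrieved = db[ids]  # (5, top_k, 2, dim)
print(retrieved.shape)  # (5, 3, 2, 10)
```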

In [20]:
dim = 10
max_memories = 10000
batch_size = 16
top_k = 3

db_filepath = "./memory.memmap"
shape = (max_memories, 2, dim) # each memory slot will store two vectors (e.g., key and value) of size dim.


# create index
index = faiss.IndexFlatL2(dim) # Flat indices store all vectors in memory and perform brute-force search, which is simple and fast for small datasets.

# create database
db = np.memmap(db_filepath, mode = 'w+', dtype = np.float32, shape = shape)

In [None]:
# KNN DATABASE CLASS

# add to index
# add to database
# query the index
# retrieve from the database
# remove/clear from index and database

### Adding


In [21]:
kv = np.random.rand(batch_size, 512, 2, dim).astype('float32') # b t 2 (hd)
kv = kv.reshape(-1, 2, dim)
kv.shape

(8192, 2, 10)

In [22]:
k = kv[:,0,:]
k.shape

(8192, 10)

Why Use np.ascontiguousarray?
- np.ascontiguousarray returns k as a single contiguous (C-order) block of memory, copying it only if it is not already contiguous.
- Faiss's optimized C++ routines read the raw memory buffer directly, so they need a contiguous float32 array. Depending on the Faiss version, the Python wrapper may or may not do this conversion for you, so doing it explicitly is the safe choice.
- Slicing such as kv[:,0,:] produces a strided view that is not contiguous. np.ascontiguousarray copies it into a contiguous array with the same values.
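
The NumPy flags make the contiguity issue visible. The array sizes below are small ones chosen for the demo. Taking `kv[:, 0, :]` skips every value row, so the resulting view is not C-contiguous until it is copied.

```python
import numpy as np

batch_tokens, dim = 8, 10
kv = np.random.rand(batch_tokens, 2, dim).astype("float32")

k = kv[:, 0, :]  # strided view: consecutive keys are separated by a value row
print(k.flags["C_CONTIGUOUS"])  # False

k_contig = np.ascontiguousarray(k)  # dense copy
print(k_contig.flags["C_CONTIGUOUS"])  # True
print(np.array_equal(k, k_contig))     # True
```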

In [23]:
index.add(np.ascontiguousarray(k))

In [24]:
db_offset = 0
kv_len = kv.shape[0]
ids = (np.arange(kv_len) + db_offset)
db_offset += kv_len
db[ids] = kv

### Query and Retrieve

In [25]:
query = np.random.rand(batch_size, 512, dim).astype('float32') # b t (hd)
query = query.reshape(-1, dim)
query.shape

(8192, 10)

In [26]:
distance, ids = index.search(query, top_k)
ids.shape

(8192, 3)

In [27]:
ids

array([[2842, 7248, 3524],
       [7323, 4428,  819],
       [7501, 3558, 1737],
       ...,
       [6382, 6064, 1774],
       [2072, 8127, 8118],
       [ 738, 6222, 4904]])

In [28]:
retrieved_kvs = db[ids]
retrieved_kvs.shape

(8192, 3, 2, 10)

In [29]:
retrieved_kvs = retrieved_kvs.reshape(16, 512, 3, 2, 10)
retrieved_kvs.shape

(16, 512, 3, 2, 10)

### Remove / Clear / Database management

In [None]:
# 5120
# 10 segments of 512 tokens
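
This cell is only a placeholder. One possible policy, sketched here under our own assumptions, is a ring buffer: once the database is full, new pairs overwrite the oldest slots by wrapping the write position with a modulo. The `write_ring` helper is a name introduced for this example. Keeping a flat Faiss index in step with a ring buffer is harder, because overwriting a slot means removing and re-adding that key, and removal renumbers later ids. The simpler alternative is to call `clear()` on the `KNN` class below when memory fills up.

```python
import numpy as np


def write_ring(db, offset, new_data):
    """Hypothetical helper: write rows into a fixed-size buffer, overwriting the oldest.

    `offset` is a running count of rows written; the slot used is offset % capacity.
    Returns the updated count.
    """
    capacity = db.shape[0]
    slots = (offset + np.arange(len(new_data))) % capacity
    db[slots] = new_data
    return offset + len(new_data)


dim = 4
db = np.zeros((10, 2, dim), dtype=np.float32)  # room for 10 (key, value) pairs
offset = 0
for step in range(3):
    # 4 new pairs per step, filled with the step number so they are easy to spot
    batch = np.full((4, 2, dim), step + 1, dtype=np.float32)
    offset = write_ring(db, offset, batch)

print(db[:, 0, 0])  # [3. 3. 1. 1. 2. 2. 2. 2. 3. 3.]
```

The third step's pairs wrapped around into slots 0 and 1, replacing the two oldest pairs from the first step.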

### KNN Class

- This class is designed to manage the kNN memory for the transformer.
- It initializes with dim (dimension of each vector) and max_memories (maximum number of key-value pairs).
- A memory-mapped file db is created to store the key-value pairs, and faiss.IndexFlatL2 is used for the kNN search index.    

Adding New Data and Updating Memory
- add() method flattens the input data, adds it to both the memory (db) and the kNN index (index).
- Only the "keys" (i.e., the first part of each key-value pair) are added to the kNN index, as they are used to perform the search.

Searching and Retrieving from Memory
- search() method flattens the query vectors, performs a kNN search, and retrieves the most relevant key-value pairs from memory.
- It reshapes the results back to their original dimensions for easy use in the transformer model.

Clearing Memory and Index
- The clear() method resets the kNN index and clears the memory, allowing for a fresh start.


In [30]:

class KNN():
    def __init__(
        self,
        dim,
        max_memories,
        db_filepath = "./memory.memmap",
        ):
        self.dim = dim
        self.max_memories = max_memories
        self.shape = (max_memories, 2, dim)
        self.db_offset = 0
        self.db_filepath = db_filepath
        # Row i of the memmap holds the (key, value) pair whose key has Faiss id i
        self.db = np.memmap(self.db_filepath, mode = 'w+', dtype = np.float32, shape = self.shape)
        # Exact (brute-force) L2 index; only keys are stored here
        self.index = faiss.IndexFlatL2(dim)


    def add_to_db(self, new_data):
        new_data_len = new_data.shape[0]
        if self.db_offset + new_data_len > self.max_memories:
            raise ValueError("kNN memory is full: call clear() or increase max_memories")
        ids = (np.arange(new_data_len) + self.db_offset)
        self.db[ids] = new_data
        self.db_offset += new_data_len
        # Write to file
        self.db.flush()


    def search_and_retrieve(self, query_vecs, topk):
        distances, indices = self.index.search(query_vecs, topk)
        kvs = self.db[indices]  # fancy indexing returns a copy, shape (n, topk, 2, dim)
        # Faiss returns id -1 when fewer than topk vectors are stored; db[-1] would
        # silently read the last row, so zero those slots instead
        kvs[indices == -1] = 0
        return kvs

    def add(self, new_data):
        # Input is b n 2 d, flatten to (b n) 2 d; memories never carry gradients
        new_data = new_data.detach().cpu().flatten(0,1)
        # Add to db
        self.add_to_db(new_data.numpy())
        # Only keys are used in knn index
        keys, vals = new_data.unbind(dim=-2)
        # Faiss expects a contiguous float32 (b n) d array
        keys = np.ascontiguousarray(keys.numpy(), dtype = np.float32)
        # Add to index
        self.index.add(keys)

    def search(self, query_vecs, topk):
        # can override topk
        query_batch_size, query_seq_len = query_vecs.shape[0], query_vecs.shape[1]
        # Input is b n d, flatten to (b n) d
        query_vecs = query_vecs.detach().cpu().flatten(0,1)
        query_vecs = np.ascontiguousarray(query_vecs.numpy(), dtype = np.float32)
        kvs = self.search_and_retrieve(query_vecs, topk)
        # kvs are (b n) k 2 d, unflatten to b n k 2 d
        kvs = torch.from_numpy(np.asarray(kvs))
        kvs = torch.unflatten(kvs, 0, (query_batch_size, query_seq_len))
        return kvs

    def clear(self):
        self.index.reset()
        self.db[:] = 0
        self.db_offset = 0


In [31]:
batch_size = 16
dim = 10
segments = 10
seq_len = 512
max_memories = batch_size * seq_len * segments

knn = KNN(dim=dim, max_memories=max_memories)

In [32]:
kv = torch.randn(batch_size, seq_len, 2, dim) # b t 2 (hd)
query = torch.randn(batch_size, seq_len, dim) # b t (hd)

In [33]:
knn.add(kv)

In [34]:
knn.index.ntotal

8192

In [35]:
knn.db[8192]

memmap([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

In [37]:
retrieved_kvs = knn.search(query, 3)

In [38]:
retrieved_kvs.shape

torch.Size([16, 512, 3, 2, 10])

The code below defines a multi-head attention module (MHAttention) and then extends it with k-nearest-neighbor (kNN) memory. The extension lets the model search and retrieve past memory entries during the forward pass, which can help on tasks that benefit from long-range context.

### KNN Attention Class

### Breakdown of the Key Components:

1. **Multi-Head Attention (MHAttention):**
   - **Embedding Transformation:** The input embeddings (`x`) are transformed into queries, keys, and values through linear layers (`query_matrix`, `key_matrix`, `value_matrix`). Each of these is then reshaped into multiple heads for parallel computation.
   - **Attention Mechanism:** The queries and keys are multiplied (`qk = q @ k`) and scaled by `1/sqrt(head_dimension)` to compute attention scores. These scores are then masked and passed through a softmax to produce normalized attention weights.
   - **Weighted Sum:** The attention weights are multiplied with the values (`qkv = qk @ v`), giving a context-weighted mixture of values for each position.

2. **Incorporating kNN Memory:**
   - **Memory Representation:** The idea is to maintain a set of key-value pairs (from past computations) that can be accessed based on the current query (`q`). In the forward pass, a kNN search would retrieve the top `k` nearest neighbors of the query, providing additional context.
   - **Reshaping for kNN Search:** The queries and memory keys/values are appropriately reshaped and permuted to allow for efficient computation of the attention mechanism with memory.
   - **Memory Query:** The retrieved memories are attended to in the same way, and the result (`mem_qkv`) is combined with the local attention output (`qkv`). This lets the model draw on earlier context dynamically.

3. **Further Details:**
   - **Masking and Softmax:** After computing the initial attention scores (`qk`), a triangular mask is applied to prevent attention to future positions (common in autoregressive tasks). The softmax function then ensures the weights are normalized.
   - **Attention Gates:** A learned per-head gate (`sigmoid(gate_bias)`) combines the standard attention and memory-based attention results, allowing the model to balance the two.
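
The masking rule `torch.ones((i, j)).triu(j - i + 1)` used in the attention code marks positions above the (shifted) diagonal as future tokens. The sketch below reproduces it with NumPy's `np.triu` and a NumPy softmax for a square 4×4 case. The masked scores become `-inf`, so their softmax weights are exactly 0. The first token can only attend to itself. The `j - i + 1` offset shifts the diagonal when there are more keys than queries.

```python
import numpy as np

i = j = 4  # number of queries, number of keys
# True marks future positions that must not be attended to
mask = np.triu(np.ones((i, j), dtype=bool), k=j - i + 1)

rng = np.random.default_rng(0)
scores = rng.standard_normal((i, j))
scores = np.where(mask, -np.inf, scores)

# Row-wise softmax; exp(-inf) == 0 removes the masked positions
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(np.round(weights[0], 3))  # [1. 0. 0. 0.]
```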
  
### Memory Mechanism:
The memory mechanism allows the model to "remember" useful past information and use it during the current attention computation. This is especially useful where long-range dependencies matter (e.g., language models with long contexts). The memory is queried with `mem_qk = einsum(queries, mem_k, 'b h t d, b h t k d -> b h t k')`, which computes attention scores between each current query and its own retrieved memory keys.

### Key Challenges and TODOs:
- **Relative Position Encoding:** The code mentions an intended feature to integrate relative position encoding into the attention scores (`qk = relative_position_values + qk`), which would help in capturing the position of tokens relative to each other in the sequence.
- **KNN Integration:** The actual kNN memory retrieval (`mem_qkv`) is marked as a TODO, implying that the memory query and retrieval from a kNN model should be implemented and connected to the existing attention mechanism.

In [39]:
class MHAttention(nn.Module):
    def __init__(
        self,
        embedding_dimension,
        heads = 8,
        head_dimension = 32,
    ):
        super().__init__()
        self.heads = heads
        self.scale = head_dimension ** -0.5

        self.query_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.output_matrix = nn.Linear(heads * head_dimension, embedding_dimension)


    def forward(
        self,
        x, # batch_size, sequence_length, embedding_dimension
    ):
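        batch_size, sequence_length = x.shape[:2]
        queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)
        # Per-head width, recovered from the projection width (heads * head_dimension)
        head_dimension = queries.shape[-1] // self.heads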

        # Separate  into heads for multi-head attention
        k = keys.reshape(batch_size, sequence_length, self.heads, head_dimension)
        q = queries.reshape(batch_size, sequence_length, self.heads, head_dimension)
        v = values.reshape(batch_size, sequence_length, self.heads, head_dimension)

        # Swap head and sequence length dimensions
        q = q.transpose(1,2)
        k = k.transpose(1,2)
        v = v.transpose(1,2)
        # Rearrange keys to prepare for matrix multiplication q@k
        k = k.transpose(2,3)

        # QK
        qk = q@k

        qk = qk * self.scale

        ############
        # TODO
        # qk = relative_position_values + qk
        ############

        i, j = qk.shape[-2:]
        mask = torch.ones((i,j), dtype = torch.bool, device = qk.device).triu(j-i+1)
        qk = qk.masked_fill(mask, float('-inf'))

        qk = F.softmax(qk, dim=-1)

        qkv = qk@v
        qkv = qkv.transpose(1,2)
        qkv = qkv.reshape(batch_size, sequence_length, self.heads * head_dimension)

        ############
        # TODO
        # KNN Memory
        ############

        out = self.output_matrix(qkv)

        return out

In [None]:
# make sure q is (b t (hd)) for searching in knn (reshape and transpose)
# knn returns (b t k 2 (hd))
# split to key and value each size (b t k (hd)) (unbind)
# convert k and v to (b t k h d) (reshape)
# change q to (b h t d) (transpose)
# change k to (b h t d k) (multiple transpose)
# change v to (b h t k d) (multiple transpose)
# get qk of (b h t 1 d) @ (b h t d k) -> (b h t 1 k), then squeeze -> (b h t k)
# get qkv of (b h t 1 k) @ (b h t k d) -> (b h t 1 d), then squeeze -> (b h t d)
# .....
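
The plan above can be followed literally with reshapes, transposes and batched matmuls, and then checked against an einsum. This sketch uses NumPy with small sizes chosen for the example. `mem_kv` is random data standing in for what `knn.search` returns. Adding a length-1 axis to `q` is what makes the `(b h t d) @ (b h t d k)` step a valid batched matmul.

```python
import numpy as np

b, t, h, d, k = 2, 5, 4, 3, 2
rng = np.random.default_rng(0)

q = rng.standard_normal((b, t, h * d))             # queries in search form (b t (hd))
mem_kv = rng.standard_normal((b, t, k, 2, h * d))  # stand-in for knn.search output (b t k 2 (hd))

mem_k, mem_v = mem_kv[..., 0, :], mem_kv[..., 1, :]            # each (b t k (hd))
mem_k = mem_k.reshape(b, t, k, h, d).transpose(0, 3, 1, 2, 4)  # (b h t k d)
mem_v = mem_v.reshape(b, t, k, h, d).transpose(0, 3, 1, 2, 4)  # (b h t k d)
q = q.reshape(b, t, h, d).transpose(0, 2, 1, 3)                # (b h t d)

# (b h t 1 d) @ (b h t d k) -> (b h t 1 k) -> (b h t k)
qk = (q[..., None, :] @ mem_k.swapaxes(-1, -2))[..., 0, :]
# Softmax over the k retrieved memories
w = np.exp(qk - qk.max(axis=-1, keepdims=True))
w /= w.sum(axis=-1, keepdims=True)
# (b h t 1 k) @ (b h t k d) -> (b h t 1 d) -> (b h t d)
qkv = (w[..., None, :] @ mem_v)[..., 0, :]

# The same scores via einsum, for comparison
qk_e = np.einsum('bhtd,bhtkd->bhtk', q, mem_k)
print(np.allclose(qk, qk_e))  # True
print(qkv.shape)              # (2, 4, 5, 3)
```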


In [40]:
number_heads = 8
head_dimension = 10
q = torch.randn(batch_size, seq_len, number_heads * head_dimension)
k = torch.randn(batch_size, seq_len, number_heads * head_dimension)
k.shape

torch.Size([16, 512, 80])

In [41]:
# Manually

# Separate queries matrix into heads for multi-head attention
q = q.reshape(batch_size, seq_len, number_heads, head_dimension)
# Rearrange indices to prepare for matrix multiplication q@k
q = q.transpose(1,2)
# Separate keys matrix into heads for multi-head attention
k = k.reshape(batch_size, seq_len, number_heads, head_dimension)
# Rearrange indices to prepare for matrix multiplication q@k
k = k.permute(0,2,3,1)

manual_qk = q@k

print ("queries:", q.shape)
print ("keys:", k.shape)
print ("qk:", manual_qk.shape)

queries: torch.Size([16, 8, 512, 10])
keys: torch.Size([16, 8, 10, 512])
qk: torch.Size([16, 8, 512, 512])


### Einops




In [43]:
q = torch.randn(batch_size, seq_len, number_heads * head_dimension)
k = torch.randn(batch_size, seq_len, number_heads * head_dimension)
k.shape

torch.Size([16, 512, 80])

In [44]:
# With einsum
q =  rearrange(q, 'b t (h d) -> b h t d', h = number_heads)
k =  rearrange(k, 'b t (h d) -> b h t d', h = number_heads)
qk = einsum(q, k, 'b h i d, b h j d -> b h i j')

print ("queries:", q.shape)
print ("keys:", k.shape)
print ("qk:", qk.shape)

queries: torch.Size([16, 8, 512, 10])
keys: torch.Size([16, 8, 512, 10])
qk: torch.Size([16, 8, 512, 512])


### Key Techniques Used:
- **`einsum`:** Efficiently computes tensor contractions (like matrix multiplications) and is used here to compute the attention scores between queries and keys, as well as the weighted sums.
- **`rearrange` from Einops:** This function reshapes tensors to facilitate operations like multi-head attention where dimensions need to be adjusted for head-based parallelism.

In [None]:
# (3,4) (3,4) -> x@y.T -> (3,3)
# shared letter b (length 4) is multiplied and summed away
x = torch.randn(3, 4)
y = torch.randn(3, 4)
einsum(x, y, 'a b, c b -> a c')

In [None]:
# (3,4) (3,4) -> x.T@y -> (4,4)
# shared letter a (length 3) is multiplied and summed away
einsum(x, y, 'a b, a d -> b d')

In [None]:
# A letter shared by several inputs lines those axes up and multiplies them elementwise
# The letters after '->' define the output axes, in that order
# Any letter left out of the output is summed over
# So the output dimensions can be arranged in any order you want
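
These rules can be checked with NumPy's `np.einsum`. It writes the pattern as a compact string (`'ab,cb->ac'`) passed first, whereas einops uses a space-separated pattern passed last, but the semantics are the same. The random matrices are only for the demo.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4))
y = rng.standard_normal((3, 4))

# Shared letter b, absent from the output: multiplied and summed -> x @ y.T
print(np.allclose(np.einsum('ab,cb->ac', x, y), x @ y.T))  # True
# Shared letter a summed away, output order b d -> x.T @ y
print(np.allclose(np.einsum('ab,ad->bd', x, y), x.T @ y))  # True
# Reversing the output letters transposes the result
print(np.allclose(np.einsum('ab,ad->db', x, y), (x.T @ y).T))  # True
# A letter left out of the output is summed: row sums of x
print(np.allclose(np.einsum('ab->a', x), x.sum(axis=1)))  # True
```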

In [None]:
x = torch.randn(24, 10, 15)
rearrange(x, '(a b) c (d e) -> (e c) a b d', a=6, e=5).shape  # torch.Size([50, 6, 4, 3]): b = 24/6 = 4, d = 15/5 = 3



In [47]:
class KNNAttention(nn.Module):
    def __init__(
        self,
        embedding_dimension,
        heads = 8,
        head_dimension = 32,
    ):
        super().__init__()
        self.heads = heads
        self.scale = head_dimension ** -0.5

        self.query_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.output_matrix = nn.Linear(heads * head_dimension, embedding_dimension)


    def forward(
        self,
        x, # batch_size, sequence_length, embedding_dimension
    ):
        batch_size, sequence_length = x.shape[:2]
        queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)

        queries = rearrange(queries, 'b t (h d) -> b h t d', h = self.heads)
        keys    = rearrange(keys, 'b t (h d) -> b h t d', h = self.heads)
        qk      = einsum(queries, keys, 'b h i d, b h j d -> b h i j')

        qk = qk * self.scale

        ############
        # TODO
        # qk = relative_position_values + qk
        ############

        i, j = qk.shape[-2:]
        mask = torch.ones((i,j), dtype = torch.bool, device = qk.device).triu(j-i+1)
        qk = qk.masked_fill(mask, float('-inf'))

        qk = F.softmax(qk, dim=-1)

        values = rearrange(values, 'b t (h d) -> b h t d', h=self.heads)
        qkv = qk@values
        qkv = rearrange(qkv, 'b h t d -> b t (h d)')

        ############
        # TODO
        # KNN Memory
        ############

        out = self.output_matrix(qkv)

        return out


In [48]:
queries = torch.randn(batch_size, number_heads, seq_len, head_dimension)
mem_kv = torch.randn(batch_size, seq_len, 3, 2, number_heads*head_dimension)
scale = 1

In [49]:
queries = rearrange(queries, 'b h t d -> b t (h d)')
queries.shape

torch.Size([16, 512, 80])

In [50]:
# mem_kv = knn.search(queries, topk)
mem_k, mem_v = mem_kv.unbind(dim = -2)
mem_k = rearrange(mem_k, 'b t k (h d) -> b h t k d', h=number_heads)
mem_v = rearrange(mem_v, 'b t k (h d) -> b h t k d', h=number_heads)
mem_v.shape

torch.Size([16, 8, 512, 3, 10])

In [51]:
queries = rearrange(queries, 'b t (h d) -> b h t d', h=number_heads)
mem_qk = einsum(queries, mem_k, 'b h t d, b h t k d -> b h t k') # d dimension
mem_qk.shape

torch.Size([16, 8, 512, 3])

In [52]:
mem_qk = mem_qk * scale

In [53]:
mem_qk = F.softmax(mem_qk, dim=-1)
mem_qkv = einsum(mem_qk, mem_v, 'b h t k, b h t k d -> b h t d') # k dimension
mem_qkv.shape

torch.Size([16, 8, 512, 10])

In [None]:
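# Learnable gate bias, one scalar per head; sigmoid keeps the gate in (0, 1)
gate_bias = nn.Parameter(torch.randn(number_heads, 1, 1))
gate = torch.sigmoid(gate_bias)
qkv = torch.randn(batch_size, number_heads, seq_len, head_dimension)  # stand-in local output, (b h t d)
combined_qkv = (qkv * gate) + (mem_qkv * (1 - gate))
combined_qkv = rearrange(combined_qkv, 'b h t d -> b t (h d)')
output_matrix = nn.Linear(number_heads * head_dimension, number_heads * head_dimension)
out = output_matrix(combined_qkv)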

The `KNNAttention` class in this notebook incorporates a **k-Nearest Neighbor (kNN) memory** component into the traditional transformer **multi-head attention** mechanism. This design adds a memory retrieval step to attention, allowing the model to access stored past context when processing the current input.

Here’s an overview of the design and operations involved:

### Key Components:
1. **Local Attention:**
   - The input query (`q`), key (`k`), and value (`v`) matrices are transformed using linear layers (`query_matrix`, `key_matrix`, `value_matrix`), then reshaped into multi-head attention format using `rearrange`.
   - The attention scores (`qk`) are computed using a matrix multiplication between queries and keys, followed by softmax normalization and a weighted sum of values (`qkv`).

2. **KNN Attention:**
   - After computing the local attention (`qkv`), kNN is used to retrieve `topk_retrieved_memories` for each query from a memory bank (`mem_kv`).
   - The kNN search step involves converting the queries into a search-friendly format and retrieving the top-k nearest memory keys and values.
   - These memory keys and values are used to compute attention scores (`mem_qk`), similar to the local attention calculation but with memory-specific keys and values.

3. **Combining Local and Memory-based Attention:**
   - A per-head **gate bias** parameter (`self.gate_bias`), passed through a sigmoid, combines the local attention and memory-based attention in a learnable way.
   - The final output is a weighted combination of the local attention (`qkv`) and memory-based attention (`mem_qkv`), passed through the output layer (`output_matrix`).

### Operations in KNN Memory:
- **Add:** After each step, the current (key, value) pairs are appended to the memory: keys to the Faiss index, full pairs to the database.
- **Store:** Memory is stored in a form that allows fast search (e.g., vectors stored in an indexed form).
- **Search:** A kNN search finds the most relevant stored memories, either exactly (as `IndexFlatL2` does here) or with an approximate nearest-neighbor index for larger memories.
- **Remove:** In certain memory systems, removing outdated or irrelevant memories is necessary.
- **Reconstruct:** If using lossy memory (e.g., compression), the original state could be reconstructed from the memory.

### Considerations for Index/Chip Paradigm:
1. **Required Operations:**
   - **Add** and **Search** operations are crucial. The kNN mechanism allows for efficient retrieval of memories during the forward pass. Depending on the task, **Remove** and **Reconstruct** operations could be required to manage memory efficiently.

2. **Frequency of Operations:**
   - **Search** is likely the most frequent operation, since it runs on every forward pass. **Add** happens about as often, once per training step, when the local (key, value) pairs are appended. **Remove** and **Reconstruct** are rarer and are typically needed only for memory management.

3. **Accuracy vs. Speed vs. Memory Footprint:**
   - **Accuracy:** The accuracy of the attention mechanism is enhanced by incorporating relevant past memories, especially for tasks requiring long-term dependencies.
   - **Speed:** The use of kNN can slow down the model, depending on the search method used. Approximate nearest neighbors (ANN) methods can help speed this up, but there's a trade-off between speed and accuracy.
   - **Memory Footprint:** Storing a large memory bank increases memory requirements. The number of vectors stored (i.e., memory size) and the number of neighbors (`topk_retrieved_memories`) directly affect memory usage.

4. **Size of Index/Query:**
   - The **index size** (number of vectors stored) is dependent on the model’s memory capacity. This could range from hundreds of vectors to millions, depending on the scale.
   - The **query size** is determined by the batch size and sequence length, with each query corresponding to a query vector that interacts with the memory.

5. **GPU vs. CPU:**
   - **GPU** is preferred for handling large-scale kNN memory retrieval and attention computation, especially for parallelized tasks.

6. **Retraining the Index:**
   - Flat indexes such as `IndexFlatL2` need no training, so new vectors can be added at any time. Indexes that learn structure from the data (e.g., IVF clustering or product quantization) must be trained first. If the stored vectors drift over time, as keys do while the model's weights change during training, such an index may need to be retrained or rebuilt to stay accurate.
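
To make the size considerations concrete, the arithmetic below uses the configuration from earlier in this notebook (batch 16, sequence length 512, 10 segments, 8 heads of width 10). It assumes memories are stored at the full width `heads * head_dimension = 80` that `KNNAttention` searches with, stored as float32, in a flat index that keeps every key uncompressed.

```python
import numpy as np

batch_size, seq_len, segments = 16, 512, 10
heads, head_dimension = 8, 10
dim = heads * head_dimension                      # width of each stored key or value
max_memories = batch_size * seq_len * segments
bytes_per_float = np.dtype(np.float32).itemsize  # 4

db_bytes = max_memories * 2 * dim * bytes_per_float  # keys + values in the memmap
index_bytes = max_memories * dim * bytes_per_float   # keys only, in a flat index
queries_per_search = batch_size * seq_len            # one query per token
multiply_adds = queries_per_search * max_memories * dim  # one exact search over a full memory

print(max_memories)         # 81920
print(db_bytes / 2**20)     # 50.0 (MiB)
print(index_bytes / 2**20)  # 25.0 (MiB)
print(queries_per_search)   # 8192
print(multiply_adds)        # 53687091200, roughly 5.4e10
```

The memory cost grows linearly with the number of stored memories. An exact search grows with queries × memories, which is why larger memories usually push toward GPU search or approximate indexes.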

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class KNNAttention(nn.Module):
    def __init__(
        self,
        embedding_dimension,
        heads = 8,
        head_dimension = 32,
        # kNN memory settings
        topk_retrieved_memories = 3,
        dropout = 0.0,
    ):
        super().__init__()
        self.heads = heads
        self.scale = head_dimension ** -0.5

        self.query_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.output_matrix = nn.Linear(heads * head_dimension, embedding_dimension)

        # Dropout on the memory attention weights (used in forward, so it must be defined)
        self.dropout = nn.Dropout(dropout)

        # One learned gate bias per head; sigmoid(gate_bias) mixes memory vs. local attention
        self.gate_bias = nn.Parameter(torch.randn(self.heads, 1, 1))
        self.topk_retrieved_memories = topk_retrieved_memories

    def forward(
        self,
        x,    # (batch_size, sequence_length, embedding_dimension)
        knn,  # external memory object exposing .search(queries, topk)
    ):
        queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)

        ### LOCAL ATTENTION

        queries = rearrange(queries, 'b t (h d) -> b h t d', h = self.heads)
        keys    = rearrange(keys, 'b t (h d) -> b h t d', h = self.heads)
        values  = rearrange(values, 'b t (h d) -> b h t d', h = self.heads)

        qk = torch.einsum('b h i d, b h j d -> b h i j', queries, keys) * self.scale

        # TODO: add relative position biases
        # qk = relative_position_values + qk

        # Causal mask: position i may not attend to positions after i
        i, j = qk.shape[-2:]
        mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
        qk = qk.masked_fill(mask, float('-inf'))

        qk = F.softmax(qk, dim = -1)
        qkv = qk @ values  # (b, h, t, d); heads stay separate until the gate is applied

        ### KNN ATTENTION

        # Merge heads so each token's query can be searched in the memory
        search_queries = rearrange(queries, 'b h t d -> b t (h d)')
        mem_kv = knn.search(search_queries, topk = self.topk_retrieved_memories)  # (b, t, k, 2, h*d)
        mem_k, mem_v = mem_kv.unbind(dim = -2)
        mem_k = rearrange(mem_k, 'b t k (h d) -> b h t k d', h = self.heads)
        mem_v = rearrange(mem_v, 'b t k (h d) -> b h t k d', h = self.heads)

        # Each query attends only to its own k retrieved keys
        mem_qk = torch.einsum('b h t d, b h t k d -> b h t k', queries, mem_k) * self.scale
        mem_qk = F.softmax(mem_qk, dim = -1)
        mem_qk = self.dropout(mem_qk)
        mem_qkv = torch.einsum('b h t k, b h t k d -> b h t d', mem_qk, mem_v)

        ### COMBINE WITH A LEARNED GATE

        gate = torch.sigmoid(self.gate_bias)  # (h, 1, 1), broadcasts over (b, h, t, d)
        combined_qkv = mem_qkv * gate + qkv * (1 - gate)
        combined_qkv = rearrange(combined_qkv, 'b h t d -> b t (h d)')
        out = self.output_matrix(combined_qkv)

        return out
```
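The forward pass of `KNNAttention` is easier to follow without the multi-head reshapes. The NumPy sketch below reproduces its three steps for a single head and a single sequence: top-k retrieval, attention over each query's own retrieved keys, and the sigmoid-gated mix with causal local attention. Exact brute-force inner-product search stands in for the faiss index, and all function names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def knn_memory_attention(queries, mem_keys, mem_values, topk, scale):
    """queries: (t, d); mem_keys, mem_values: (M, d). Returns (t, d)."""
    # Exact inner-product search stands in for the approximate faiss index
    all_scores = queries @ mem_keys.T                     # (t, M)
    idx = np.argsort(-all_scores, axis=-1)[:, :topk]      # top-k memory ids per query
    retrieved_k = mem_keys[idx]                           # (t, topk, d): a different set per query
    retrieved_v = mem_values[idx]                         # (t, topk, d)
    scores = np.einsum('td,tkd->tk', queries, retrieved_k) * scale
    weights = softmax(scores, axis=-1)                    # softmax over each query's own memories
    return np.einsum('tk,tkd->td', weights, retrieved_v)

def causal_local_attention(queries, keys, values, scale):
    """Standard dense causal attention over the local context. All inputs (t, d)."""
    t = queries.shape[0]
    scores = (queries @ keys.T) * scale                   # (t, t)
    future = np.triu(np.ones((t, t), dtype=bool), k=1)    # True above the diagonal
    scores = np.where(future, -np.inf, scores)
    return softmax(scores, axis=-1) @ values

def gated_combine(mem_out, local_out, gate_bias):
    gate = 1.0 / (1.0 + np.exp(-gate_bias))               # gate = sigmoid(bias_parameter)
    return mem_out * gate + local_out * (1.0 - gate)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    t, d, M = 4, 8, 100                                   # toy sizes
    q, k, v = (rng.standard_normal((t, d)) for _ in range(3))
    mem_k, mem_v = rng.standard_normal((M, d)), rng.standard_normal((M, d))
    scale = d ** -0.5
    out = gated_combine(
        knn_memory_attention(q, mem_k, mem_v, topk=3, scale=scale),
        causal_local_attention(q, k, v, scale=scale),
        gate_bias=0.0,
    )
    print(out.shape)  # (4, 8)
```

With `gate_bias=0.0` the gate is 0.5, so memory and local attention contribute equally; training moves each head's bias toward whichever source is more useful.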

### Key Concepts:
- **kNN (k-Nearest Neighbors):** A method for finding the data points closest to a query point under some distance or similarity measure (e.g., Euclidean distance or inner product). In this code, `faiss` performs this search efficiently on high-dimensional data.
- **Memory Management:** `np.memmap` lets the model keep a large memory (e.g., past key-value pairs) on disk without loading all of it into RAM at once.
- **External Memory:** The transformer is augmented with an external memory that stores previous key-value pairs. The model queries this memory during training to retrieve relevant information, which can improve its performance on long-context tasks.
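The external memory described earlier keeps the most recent M (key, value) pairs and drops the oldest ones. Below is a minimal sketch of such a FIFO ring buffer backed by `np.memmap`, with exact search in place of faiss. The class name, file layout, and methods are hypothetical and are not the notebook's own memory class.

```python
import os
import tempfile
import numpy as np

class MemmapFIFOMemory:
    """Hypothetical fixed-size (key, value) memory stored in an np.memmap file.
    When full, the oldest pair is overwritten first (FIFO). Search is exact
    brute-force inner product, standing in for a faiss index."""

    def __init__(self, path, max_memories, dim):
        self.max_memories = max_memories
        # Layout (M, 2, dim): [:, 0] holds keys and [:, 1] holds values
        self.db = np.memmap(path, mode='w+', dtype=np.float32,
                            shape=(max_memories, 2, dim))
        self.count = 0       # number of valid slots
        self.next_slot = 0   # ring-buffer write position

    def add(self, keys, values):
        # Write one pair at a time so that, if more pairs than slots arrive,
        # later pairs deterministically overwrite earlier ones
        for key, value in zip(keys, values):
            self.db[self.next_slot, 0] = key
            self.db[self.next_slot, 1] = value
            self.next_slot = (self.next_slot + 1) % self.max_memories
            self.count = min(self.count + 1, self.max_memories)

    def search(self, queries, topk):
        """queries: (n, dim) -> retrieved pairs, shape (n, min(topk, count), 2, dim)."""
        stored = np.asarray(self.db[:self.count])   # (count, 2, dim)
        scores = queries @ stored[:, 0].T           # (n, count)
        idx = np.argsort(-scores, axis=-1)[:, :topk]
        return stored[idx]

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        mem = MemmapFIFOMemory(os.path.join(tmp, "memory.dat"), max_memories=4, dim=2)
        keys = np.arange(12, dtype=np.float32).reshape(6, 2)   # 6 pairs into 4 slots
        mem.add(keys, -keys)
        print(mem.search(np.array([[1.0, 0.0]]), topk=1)[0, 0, 0])  # [10. 11.]
        del mem  # release the mapped file before the directory is removed
```

Because the file is memory-mapped, only the slots touched by a write or search need to be paged into RAM; the same layout scales to millions of pairs on disk.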

### Why is this important?
Adding kNN memory to transformers allows the model to:
- Remember past information and use it for current tasks, which is particularly useful for tasks with long sequences or documents.
- Improve performance on tasks like language modeling, translation, and other sequential tasks where earlier context is important.

### Optimization Considerations:
- **Approximate Nearest Neighbor (ANN):** To reduce search time, you can replace exact search with an ANN index, such as an **HNSW (Hierarchical Navigable Small World)** graph or the inverted-file and quantized indexes in **FAISS** (Facebook AI Similarity Search), which also provides an HNSW index.
- **Memory Management:** Memory management techniques like **LRU (Least Recently Used)**, or more advanced methods such as **episodic memory**, can be employed to control which memories are kept or discarded based on relevance.
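The notebook's memory drops the oldest pairs first (FIFO). An LRU policy instead evicts the entry that has gone longest without being inserted or retrieved. Here is a minimal sketch using the standard library's `OrderedDict`; the class, the integer entry ids, and the rule that a retrieval counts as a "use" are assumptions made for illustration.

```python
from collections import OrderedDict

class LRUMemory:
    """Hypothetical (key, value) store that evicts the least recently *used* entry,
    where "used" means inserted or retrieved. Vectors are plain lists of floats."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.entries = OrderedDict()   # entry_id -> (key, value), least recently used first

    def add(self, entry_id, key, value):
        self.entries[entry_id] = (key, value)
        self.entries.move_to_end(entry_id)          # newest entry becomes most recent
        if len(self.entries) > self.capacity:
            self.entries.popitem(last=False)        # evict the least recently used entry

    def search(self, query, topk):
        """Exact top-k by inner product; every retrieved entry is marked as recently used."""
        def dot(a, b):
            return sum(x * y for x, y in zip(a, b))
        ranked = sorted(self.entries.items(),
                        key=lambda item: dot(query, item[1][0]),
                        reverse=True)[:topk]
        for entry_id, _ in ranked:
            self.entries.move_to_end(entry_id)
        return [pair for _, pair in ranked]

if __name__ == "__main__":
    mem = LRUMemory(capacity=2)
    mem.add(0, [1.0, 0.0], [10.0])
    mem.add(1, [0.0, 1.0], [20.0])
    mem.search([1.0, 0.0], topk=1)   # retrieves entry 0, so entry 1 is now least recently used
    mem.add(2, [1.0, 1.0], [30.0])   # evicts entry 1; a FIFO memory would have evicted entry 0
    print(list(mem.entries))         # [0, 2]
```

The trade-off is bookkeeping: LRU must update recency on every retrieval, while FIFO only tracks a single write position.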


### Summary:
The code implements a **kNN-based attention mechanism** for transformers. It blends standard multi-head attention over the local context with attention over a kNN-searchable external memory, allowing the transformer to access long-term context more effectively. This hybrid architecture is relevant to **memory-augmented networks**, **neural architectures for continual learning**, and **long-range dependency modeling**, where remembering past states or contexts is crucial.