# Goal: 
Implement the chunked cross-attention proposed in `Improving language models by retrieving from trillions of tokens` (Sebastian Borgeaud et al.).
# Idea:

Given an input to the decoder in the encoder-decoder setting, split the input into `l` chunks of length `m`. For each chunk, retrieve its `k` nearest neighbours from a vector database. Then perform cross-attention between each chunk and its `k` nearest neighbours.


# Chunked Cross-attention

In practice the decoder's input is shifted by `m-1` tokens before attention, so the first attending chunk consists of the last token of the first chunk followed by the first `m-1` tokens of the second chunk, and so forth. See the image below:

<img src='https://miro.medium.com/v2/resize:fit:834/format:webp/1*PW1kX80dwX6mjbZZq4_QGQ.png'/>

In [8]:
"""
Below is a PyTorch port of the JAX implementation of chunked cross-attention
given in the appendix of the original paper.

"""
import torch
import torch.nn as nn

In [6]:
n = 128 # Sequence length (number of tokens in the decoder input)
m = 16 # Chunk length (number of tokens per chunk)
r = 32 # Retrieval length (number of tokens in each retrieved neighbour)
k = 4 # Number of neighbours retrieved per chunk
d = 16 # Embedding size
l = n // m # Number of chunks; floor division, so n is expected to be a multiple of m

In [14]:
# Parameters: (d, d) projection matrices that cross_attention applies to
# produce the queries (Q), keys (K) and values (V).
# NOTE(review): all three are zero-initialised, so every logit and value
# computed from them is zero — presumably placeholders to be replaced by
# trained or randomly initialised weights; confirm.
Q = nn.Parameter(torch.zeros(d, d))
K = nn.Parameter(torch.zeros(d, d))
V = nn.Parameter(torch.zeros(d, d))

In [15]:
def relative_positional_encodings(attending_length, attended_length):
    """Placeholder for classical relative positional encodings.

    Not implemented: both arguments are ignored and ``None`` is returned.

    Args:
        attending_length: length of the attending (query) sequence.
        attended_length: length of the attended (key) sequence.

    Returns:
        None.
    """
    return None

In [16]:
def cross_attention(chunk, neighbour):
    """Compute unnormalised attention logits and values for one neighbour.

    The softmax is deliberately not applied here so that the caller can
    normalise jointly over several neighbours.

    Args:
        chunk: 2-D tensor of attending tokens, shape (m, d).
        neighbour: 2-D tensor of attended tokens, shape (r, d).

    Returns:
        A ``(logits, values)`` tuple: ``logits`` is ``(chunk @ Q) @ (neighbour @ K).T``
        with shape (m, r) and ``values`` is ``neighbour @ V`` with shape (r, d).

    Raises:
        ValueError: if either argument is not 2-D (the shape unpacking fails).
    """
    # The unpacking is kept purely as a rank check on both inputs.
    _chunk_len, _chunk_dim = chunk.shape
    _neighbour_len, _neighbour_dim = neighbour.shape

    scores = (chunk @ Q) @ (neighbour @ K).T
    projected_values = neighbour @ V
    return scores, projected_values

In [60]:
def multi_neighbour_cross_attention(chunk, neighbours):
    """Attend one chunk jointly over all of its retrieved neighbours.

    Logits and values are computed per neighbour with ``cross_attention``,
    then flattened so that a single softmax normalises each query position
    over the ``r * k`` tokens of all neighbours at once.

    Args:
        chunk: tensor of attending tokens, shape (m, d).
        neighbours: tensor of retrieved neighbours, shape (k, r, d).

    Returns:
        Tensor of shape (m, d): for every token of ``chunk``, the
        attention-weighted sum of the values of all neighbour tokens.

    Note:
        No relative positional bias is added to the logits;
        ``relative_positional_encodings`` is only a stub in this file.
    """
    m, d = chunk.shape
    k, r, d = neighbours.shape

    attended_chunk = [cross_attention(chunk, neighbour) for neighbour in neighbours]
    # Split the (logits, values) tuples into two stacked tensors.
    logits = torch.stack([attended_item[0] for attended_item in attended_chunk])
    values = torch.stack([attended_item[1] for attended_item in attended_chunk])
    assert logits.shape == (k, m, r)
    assert values.shape == (k, r, d)

    # Bring the query axis to the front BEFORE flattening. Reshaping the
    # (k, m, r) tensor directly to (m, r * k) would mix logits belonging to
    # different query positions into the same row.
    logits = logits.permute(1, 2, 0).reshape(m, r * k)
    # Flatten the values in the same (r, k) order so that column j of the
    # logits lines up with row j of the values.
    values = values.permute(1, 0, 2).reshape(r * k, d)

    # Normalise over all neighbour tokens; `dim` is explicit because the
    # implicit choice is deprecated.
    return nn.functional.softmax(logits, dim=-1) @ values

In [75]:
def multi_chunk_cross_attention(observation, neighbours):
    """Apply chunked cross-attention to a whole sequence.

    The input is shifted left by ``m - 1`` tokens so that attending chunk
    ``i`` holds the last token of chunk ``i`` followed by the first
    ``m - 1`` tokens of chunk ``i + 1``. Each attending chunk then attends
    to the neighbours of chunk ``i``, and the result is shifted back so
    that every output row lines up with its input token.

    Relies on the module-level sizes ``n``, ``m``, ``d`` and ``l``.

    Args:
        observation: decoder input, shape (n, d). It is not modified.
        neighbours: retrieved neighbours per chunk, shape (l, k, r, d).

    Returns:
        Tensor of shape (n, d). The first ``m - 1`` rows are zero.
    """
    # Shift left by m - 1 tokens and zero-pad the tail back to n rows.
    # `pad` builds a new tensor, so the caller's `observation` is untouched.
    shifted = nn.functional.pad(observation[m - 1:], (0, 0, 0, m - 1))
    attending_chunks = shifted.reshape(l, m, d)

    # Each attending chunk attends to the neighbours retrieved for its chunk.
    chunked_output = torch.stack([
        multi_neighbour_cross_attention(attending_chunk, chunk_neighbours)
        for attending_chunk, chunk_neighbours in zip(attending_chunks, neighbours)
    ])
    assert chunked_output.shape == (l, m, d)

    # Undo the shift: prepend m - 1 zero rows and drop the overflow at the end.
    output = nn.functional.pad(chunked_output.reshape(n, d), (0, 0, m - 1, 0))[:n]
    return output

In [74]:
# Smoke test: run the full pipeline on all-zero tensors and check only the
# shape of the result.
observation = torch.zeros(n, d)  # decoder input, one row per token
neighbours = torch.zeros(l, k, r, d)  # k neighbours of r tokens for each of the l chunks
h = multi_chunk_cross_attention(observation, neighbours)
assert tuple(h.shape) == (n, d)  # one output row per input token