# Adding XL Recurrence to Transformers

This part of the code demonstrates how Transformer-XL, an extension of the Transformer architecture, introduces segment-level recurrence so that context carries over across long sequences.
The provided code defines two classes, XLAttention and KNN_XLAttention, which extend standard self-attention with mechanisms for handling long-range dependencies.

### Libraries and Setup

- torch: This is the PyTorch library, used for building and training machine learning models, especially neural networks.
- nn: PyTorch's sub-library for defining neural network layers.
- F: Contains functional operations like activation functions (e.g., ReLU) and other tensor manipulations.
- einops: A library for readable tensor manipulation; here it supplies rearrange (reshaping and reordering dimensions) and einsum (tensor contractions written with named axes).

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
!pip install einops
from einops import rearrange, repeat, pack, unpack, einsum



### Recurrence with Transformer-XL

All technical details of Transformer-XL recurrence are in the [paper](https://arxiv.org/pdf/1901.02860.pdf).

"We introduce the notion of recurrence into our deep self-attention network. In particular, instead of computing the hidden states from scratch for
each new segment, we reuse the hidden states obtained in previous segments. The reused hidden states serve as memory for the current segment,
which builds up a recurrent connection between the segments. As a result, modeling very long-term dependency becomes possible because information can be propagated through the recurrent connections. Meanwhile, passing information from the previous segment can also resolve
the problem of context fragmentation."

### The Core Concept of Transformer-XL
The key idea of Transformer-XL is adding recurrence to the Transformer model. In a basic Transformer, each fixed-length segment is processed independently, so no information flows from one segment to the next. But in Transformer-XL:
- The hidden states (or memory) from previous segments of data are reused.
- This allows the model to remember long-term dependencies and make use of past information (for example, in long texts or time-series data).

This memory helps build a recurrent connection between segments of data. It is like a rolling window of context that the model uses for each new input.

In [None]:

# 1st segment: compute current kv projections [kv_1] and perform attention
# 2nd segment: concatenate old kv projections with current kv projections [kv1 + kv2] and perform attention
# 3rd segment: concatenate old kv projections with current kv projections [kv2 + kv3] and perform attention
# 4th segment: concatenate old kv projections with current kv projections [kv3 + kv4] and perform attention
# ...

# 1st segment:
seg_one_kv = [seg_1_layer_1_kv,
            seg_1_layer_2_kv,
            seg_1_layer_3_kv,
              ...]

# 2nd segment:
seg_two_kv = [concatenate(seg_1_layer_1_kv, seg_2_layer_1_kv),
            concatenate(seg_1_layer_2_kv, seg_2_layer_2_kv),
            concatenate(seg_1_layer_3_kv, seg_2_layer_3_kv),
                ...]

# 3rd segment:
seg_three_kv = [concatenate(seg_2_layer_1_kv, seg_3_layer_1_kv),
            concatenate(seg_2_layer_2_kv, seg_3_layer_2_kv),
            concatenate(seg_2_layer_3_kv, seg_3_layer_3_kv),
                ...]
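
The schedule in the comments above can be made concrete with a few lines of PyTorch. The following is a minimal sketch under assumptions: the toy sizes and the helper `fake_layer_kv` are hypothetical, random tensors stand in for real key/value projections, and the `detach()` call mirrors the stop-gradient used in the Transformer-XL paper rather than anything shown in this notebook.

```python
import torch

torch.manual_seed(0)

num_layers, batch, seg_len, dim = 3, 2, 4, 8  # hypothetical toy sizes

def fake_layer_kv():
    # Stand-in for one layer's stacked key/value projections: (batch, seg_len, 2, dim)
    return torch.randn(batch, seg_len, 2, dim)

memory = None          # no cache exists before the first segment
attended_lengths = []  # number of key/value positions layer 1 attends to, per segment

for segment in range(3):
    current = [fake_layer_kv() for _ in range(num_layers)]
    if memory is None:
        attended = current
    else:
        # Prepend the previous segment's cache along the sequence dimension
        attended = [torch.cat((old, new), dim=1) for old, new in zip(memory, current)]
    attended_lengths.append(attended[0].shape[1])
    # Cache only the current segment; detach so no gradient flows into the past
    memory = [kv.detach() for kv in current]

print(attended_lengths)  # [4, 8, 8]
```

The first segment attends to its own 4 positions, and every later segment attends to 8: the 4 cached positions plus its own 4. The cache never grows beyond one segment because only the current segment's entries are kept.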

### Preparing the Inputs

- batch_size: Number of sequences processed at once (16 sequences).
- seq_len: Length of each sequence (512 tokens).
- head_dimension: The size of each attention head (10).
- number_heads: The number of attention heads (8).
- embedding_dimension: Size of the input feature vectors (13).

In [2]:
batch_size = 16
seq_len = 512
head_dimension = 10
number_heads = 8
embedding_dimension = 13
scaling_factor = 1

This generates fake data (random values) to simulate input. It mimics a batch of sequences where each token has a feature vector of length 13.

In [3]:
# Create fake training batch
input_data = torch.randn((batch_size, seq_len, embedding_dimension))
input_data.shape

torch.Size([16, 512, 13])

### Projection Matrices

These linear layers project the input embeddings into the query, key, and value spaces used by the attention mechanism. For multi-head attention, a single layer produces the projections for all heads at once, which is why its output size is number_heads * head_dimension; the result is split into heads later. The output matrix maps the concatenated heads back to embedding_dimension.

In [4]:
# Initialize projection matrices
query_matrix = nn.Linear(embedding_dimension, number_heads * head_dimension)
key_matrix = nn.Linear(embedding_dimension, number_heads * head_dimension)
value_matrix = nn.Linear(embedding_dimension, number_heads * head_dimension)
output_matrix = nn.Linear(number_heads * head_dimension, embedding_dimension)

### Creating Keys, Values, and Queries

These are components of the self-attention mechanism. Each input is transformed into these three vectors. They help the model decide how much focus each token should have on every other token in the sequence.

In [5]:
# Create KQV matrices with input data
queries = query_matrix(input_data)
keys = key_matrix(input_data)
values = value_matrix(input_data)
values.shape

torch.Size([16, 512, 80])

### Cached Memory (Recurrent Connection)

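We create a fake cached memory (xl_memory) of shape (batch_size, seq_len, 2, number_heads * head_dimension), where the dimension of size 2 stacks keys and values. It stands in for the keys and values cached from the previous segment. The model combines it with the current segment's keys and values to maintain context.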


In [6]:
# Create a fake cached XL recurrence
xl_memory = torch.randn(batch_size, seq_len,2,number_heads*head_dimension)
xl_memory.shape

torch.Size([16, 512, 2, 80])

### Merging Old and New Keys and Values

- xl_keys and xl_values: These come from the old segment (i.e., the memory from previous steps).
- torch.cat(): This operation concatenates the keys and values from the new and old sequences, so the model can use both old and new information.

In [7]:
# unbind() splits xl_memory along dim=-2 (the second-to-last dimension),
# giving one tensor for the cached keys and one for the cached values
xl_keys, xl_values = xl_memory.unbind(dim=-2)
xl_keys.shape

torch.Size([16, 512, 80])
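
As a quick check of how the memory layout works, `torch.stack` and `unbind` along the same dimension are inverses of each other. The sketch below uses hypothetical toy sizes and random tensors in place of real keys and values.

```python
import torch

batch, seq, dim = 2, 5, 6  # hypothetical toy sizes
k = torch.randn(batch, seq, dim)
v = torch.randn(batch, seq, dim)

memory = torch.stack((k, v), dim=-2)    # (batch, seq, 2, dim)
k_back, v_back = memory.unbind(dim=-2)  # two tensors of shape (batch, seq, dim)

print(memory.shape)            # torch.Size([2, 5, 2, 6])
print(torch.equal(k, k_back))  # True
print(torch.equal(v, v_back))  # True
```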

In [8]:
keys = torch.cat((xl_keys, keys), dim=-2)
values = torch.cat((xl_values, values), dim=-2)
values.shape

torch.Size([16, 1024, 80])

In [9]:
queries.shape

torch.Size([16, 512, 80])

### Attention Mechanism (QK)

- rearrange: Splits the last dimension into heads and moves the head axis forward, turning (batch, tokens, heads * head_dim) into (batch, heads, tokens, head_dim), the layout the attention computation expects.
- einsum: Computes the dot product between every query and every key within each head. The result has shape (batch, heads, queries, keys) and holds the raw attention scores, that is, how strongly each query token attends to each key token.

In [10]:
queries = rearrange(queries, 'b t (h d) -> b h t d', h = number_heads)
keys    = rearrange(keys, 'b t (h d) -> b h t d', h = number_heads)
qk      = einsum(queries, keys, 'b h i d, b h j d -> b h i j')

print ("queries:", queries.shape)
print ("keys:", keys.shape)
print ("qk:", qk.shape)

queries: torch.Size([16, 8, 512, 10])
keys: torch.Size([16, 8, 1024, 10])
qk: torch.Size([16, 8, 512, 1024])


In [11]:
# Regular Self Attention QK (4,4)
#
# [    1., -1000., -1000., -1000.]
# [    1.,     1., -1000., -1000.]
# [    1.,     1.,     1., -1000.]
# [    1.,     1.,     1.,     1.]



# Transformer XL Self Attention QK (4,8)
#
# [    1.,     1.,     1.,     1.,     1., -1000., -1000., -1000.]
# [    1.,     1.,     1.,     1.,     1.,     1., -1000., -1000.]
# [    1.,     1.,     1.,     1.,     1.,     1.,     1., -1000.]
# [    1.,     1.,     1.,     1.,     1.,     1.,     1.,     1.]

In [12]:
i, j = qk.shape[-2:]
j

1024

### Masking (To Prevent Attention to Future Tokens)

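In tasks like language modeling, we don't want the model to attend to future tokens. This mask ensures that each query attends only to the current and previous tokens (not future ones). Because the keys now include the prepended memory (j = 1024 columns for i = 512 queries), the diagonal offset j - i + 1 leaves every memory position visible and applies the usual causal pattern to the current segment only.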

In [13]:
# Create mask
mask = torch.ones((i,j), dtype = torch.bool).triu(j-i+1)
mask.shape

torch.Size([512, 1024])
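
The effect of the `j-i+1` offset is easier to see at a small size. The sketch below rebuilds the mask for 4 queries and 8 keys (4 memory positions plus 4 current positions, sizes chosen only for illustration); a 1 marks a position that will be masked out.

```python
import torch

i, j = 4, 8  # 4 query positions, 8 key positions (4 memory + 4 current)
mask = torch.ones((i, j), dtype=torch.bool).triu(j - i + 1)

for row in mask.int().tolist():
    print(row)
# [0, 0, 0, 0, 0, 1, 1, 1]
# [0, 0, 0, 0, 0, 0, 1, 1]
# [0, 0, 0, 0, 0, 0, 0, 1]
# [0, 0, 0, 0, 0, 0, 0, 0]

# The memory columns are never masked
print(mask[:, :4].any().item())  # False

# The current-segment columns carry the ordinary causal mask
plain = torch.ones((4, 4), dtype=torch.bool).triu(1)
print(torch.equal(mask[:, 4:], plain))  # True
```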

In [None]:
qk = qk.masked_fill(mask, float('-inf'))

### Applying Softmax and Attention

- Softmax: This step turns the attention scores into probabilities. The model then uses these probabilities to decide how much influence each token should have on the others.
- Matrix Multiplication (@): After computing the attention weights, we multiply them by the values to get the attended values.

In [14]:
# Apply softmax
qk = F.softmax(qk, dim=-1)

In [15]:
qk[0][0][0].sum()

tensor(1.0000, grad_fn=<SumBackward0>)
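
The interaction between the `-inf` fill and softmax can be checked on a tiny example. In this sketch the scores and mask are made-up values: masked positions receive exactly zero weight, and each row still sums to 1.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[0.5, 1.0, -0.3],
                       [0.2, 0.1,  0.9]])
mask = torch.tensor([[False, True,  True],
                     [False, False, True]])

weights = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)

print(weights[0].tolist())            # [1.0, 0.0, 0.0]
print((weights[mask] == 0).all().item())  # True
print(torch.allclose(weights.sum(dim=-1), torch.ones(2)))  # True
```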

In [16]:
# Separate values tensor into heads for multi-head attention and move dimensions for @ with qk
values = rearrange(values, 'b t (h d) -> b h t d', h=number_heads)
print ("qk:", qk.shape)
print ("values:", values.shape)

qk: torch.Size([16, 8, 512, 1024])
values: torch.Size([16, 8, 1024, 10])


In [17]:
qkv = qk@values
qkv.shape

torch.Size([16, 8, 512, 10])

- Rearranging the output: After applying attention, we use rearrange to merge the heads back into a single dimension, giving shape (batch, tokens, number_heads * head_dimension).
- output_matrix: The merged result is passed through a linear layer that projects it back to the original embedding size.

In [18]:
# Reassemble all heads
qkv = rearrange(qkv, 'b h t d -> b t (h d)')
qkv.shape

torch.Size([16, 512, 80])

In [19]:
output_matrix

Linear(in_features=80, out_features=13, bias=True)

In [20]:
out = output_matrix(qkv)
out.shape

torch.Size([16, 512, 13])

### **XLAttention Class**

This class implements **Transformer-XL**-like behavior, with a focus on introducing **recurrence** via the `xl_memory` argument, which allows the model to use memory from previous sequences (past attention states) to improve long-range dependencies.

- **Query, Key, Value Matrices**: The class first computes the **queries**, **keys**, and **values** for self-attention. These are generated using linear layers that project the input embedding into a space suited for multi-head attention.
  
- **XL Memory (Recurrent Memory)**: If `xl_memory` is provided, it prepends the **old memory** (previous keys and values) to the new keys and values, effectively allowing the model to "remember" past sequences. This memory is passed as part of the attention mechanism to maintain long-term context.

- **Self-Attention**: The attention scores are computed between the queries and keys, followed by a **masking** step to ensure the model doesn’t attend to future tokens. The attention scores are then passed through a **softmax** function, and the weighted sum of values is computed.

- **Output**: The attended values are passed through a final linear layer that projects them back to the embedding dimension. The key-value pairs of the current segment are also returned so they can serve as `xl_memory` for the next segment.

In [21]:
class XLAttention(nn.Module):
    def __init__(
        self,
        embedding_dimension,
        heads = 8,
        head_dimension = 32,
    ):
        super().__init__()
        self.heads = heads
        self.scale = head_dimension ** -0.5

        self.query_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.output_matrix = nn.Linear(heads * head_dimension, embedding_dimension)


    def forward(
        self,
        x, # batch_size, sequence_length, embedding_dimension
        xl_memory = None
    ):
        batch_size, sequence_length = x.shape[:2]
        queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)

        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim=-2) # unstack
            keys = torch.cat((k_xl, keys), dim = -2) # prepend XL memory
            values = torch.cat((v_xl, values), dim = -2) # prepend XL memory
            xl_sequence_length = k_xl.shape[1]

        queries = rearrange(queries, 'b t (h d) -> b h t d', h = self.heads)
        keys    = rearrange(keys, 'b t (h d) -> b h t d', h = self.heads)
        qk      = einsum(queries, keys, 'b h i d, b h j d -> b h i j')

        qk = qk * self.scale

        ############
        # TODO
        # qk = relative_position_values + qk
        ############

        i, j = qk.shape[-2:]
        # build the mask on the same device as the scores so this also runs on GPU
        mask = torch.ones((i,j), dtype = torch.bool, device = qk.device).triu(j-i+1)
        qk = qk.masked_fill(mask, float('-inf'))

        qk = F.softmax(qk, dim=-1)

        values = rearrange(values, 'b t (h d) -> b h t d', h=self.heads)
        qkv = qk@values
        qkv = rearrange(qkv, 'b h t d -> b t (h d)')

        #### Return XL Memories

        keys = rearrange(keys, 'b h t d -> b t (h d)', h = self.heads)
        values = rearrange(values, 'b h t d -> b t (h d)', h=self.heads)
        kv_memories = torch.stack((keys, values), dim=-2) # (batch, sequence_len, 2, dimension)

        if xl_memory is not None:
            # keep only the current segment's entries (the last sequence_length positions);
            # slicing by the memory length would be wrong when memory and segment lengths differ
            kv_to_add_xl = kv_memories[:, -sequence_length:]
        else:
            kv_to_add_xl = kv_memories


        out = self.output_matrix(qkv)



        return out, kv_to_add_xl
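
The recurrence described above can be exercised end to end with a short driver loop. The sketch below is an illustration under assumptions: `TinyXLAttention` is a hypothetical single-head stand-in written only to keep the example self-contained, and the segment length of 4 is arbitrary. Any module with the same `(x, xl_memory) -> (out, new_memory)` interface could take its place.

```python
import torch
from torch import nn
import torch.nn.functional as F

class TinyXLAttention(nn.Module):
    """Hypothetical single-head stand-in, kept short for illustration."""
    def __init__(self, dim):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x, xl_memory=None):
        q, k, v = self.q(x), self.k(x), self.v(x)
        new_memory = torch.stack((k, v), dim=-2)  # cache only the current segment
        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim=-2)
            k = torch.cat((k_xl, k), dim=1)       # prepend cached keys
            v = torch.cat((v_xl, v), dim=1)       # prepend cached values
        scores = q @ k.transpose(-1, -2) * self.scale
        i, j = scores.shape[-2:]
        mask = torch.ones((i, j), dtype=torch.bool, device=x.device).triu(j - i + 1)
        weights = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
        return weights @ v, new_memory

torch.manual_seed(0)
layer = TinyXLAttention(dim=16)
document = torch.randn(2, 12, 16)    # (batch, tokens, dim)
segments = document.split(4, dim=1)  # three segments of 4 tokens each

memory, outputs = None, []
for segment in segments:
    out, memory = layer(segment, xl_memory=memory)
    outputs.append(out)
    memory = memory.detach()         # treat the cache as a constant for the next segment

full = torch.cat(outputs, dim=1)
print(full.shape)    # torch.Size([2, 12, 16])
print(memory.shape)  # torch.Size([2, 4, 2, 16])
```

Each call returns the memory for the next call, so the caller owns the cache and decides when to reset it, for example at document boundaries.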


### **KNN_XLAttention Class**

This class adds a **KNN (k-nearest neighbor)** retrieval mechanism to the XLAttention model. It combines local attention (traditional self-attention) with attention over a set of "retrieved" memories using KNN.

- **KNN Memory Retrieval**: The model uses a KNN search to find the **top-k nearest memories** from a previously stored set of memories (`knn.search`). These are added to the attention computation. By retrieving similar memories, the model can leverage past experiences to enhance its attention process.

- **Memory and Gate Bias**: A **gate bias** is applied to control the contribution of the retrieved memories vs the current query-key-value attention. This bias helps blend the two attention mechanisms (local and KNN-retrieved) effectively.

- **Combined Attention**: The attention values from the local attention (`qkv`) and the KNN-retrieved attention (`mem_qkv`) are combined using the gate bias. This mixture of both types of attentions is then passed through the output matrix to produce the final result.

- **Memory Update**: After performing attention, the new key-value pairs are stored in the KNN memory, so they can be used in subsequent computations.
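
The blend in the Combined Attention step is an elementwise mix with one gate value per head. The sketch below illustrates the broadcasting involved, using random stand-ins for the two attention outputs and hand-picked gate values (both are assumptions made for illustration). Both inputs need the layout (b, h, t, d) so that a gate of shape (h, 1, 1) lines up with the head axis.

```python
import torch

b, h, t, d = 2, 4, 5, 3  # hypothetical toy sizes
local_out = torch.randn(b, h, t, d)   # stand-in for the local attention output
memory_out = torch.randn(b, h, t, d)  # stand-in for the KNN attention output

# One gate value per head, broadcast over batch, time, and feature axes
gate = torch.tensor([0.0, 0.25, 0.5, 1.0]).view(h, 1, 1)
combined = memory_out * gate + local_out * (1 - gate)

print(combined.shape)                                  # torch.Size([2, 4, 5, 3])
print(torch.equal(combined[:, 0], local_out[:, 0]))    # True: gate 0 keeps only local attention
print(torch.equal(combined[:, 3], memory_out[:, 3]))   # True: gate 1 keeps only retrieved memories
```

A raw parameter initialised with `torch.randn` is not restricted to the range [0, 1]. Passing it through `torch.sigmoid` before blending would be one way to keep the mix a proper interpolation; that is a suggestion, not something the class below does.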

In [22]:
class KNN_XLAttention(nn.Module):
    def __init__(
        self,
        embedding_dimension,
        heads = 8,
        head_dimension = 32,
        topk_retrieved_memories = 3,
    ):
        super().__init__()
        self.heads = heads
        self.scale = head_dimension ** -0.5

        self.query_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.output_matrix = nn.Linear(heads * head_dimension, embedding_dimension)

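        self.gate_bias = nn.Parameter(torch.randn(self.heads, 1, 1))
        self.topk_retrieved_memories = topk_retrieved_memories
        # forward() applies self.dropout to the retrieved-memory attention weights;
        # the rate is not given in the original, so 0.0 (a no-op) is assumed here
        self.dropout = nn.Dropout(0.0)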

    def forward(
        self,
        x, # batch_size, sequence_length, embedding_dimension
        knn,
        xl_memory = None
    ):
        batch_size, sequence_length = x.shape[:2]
        queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)

        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim = -2) # unstack
            keys = torch.cat((k_xl, keys), dim = -2) # prepend XL memory
            values = torch.cat((v_xl, values), dim = -2) # prepend XL memory
            xl_sequence_length = k_xl.shape[1]

        ### LOCAL ATTENTION

        queries = rearrange(queries, 'b t (h d) -> b h t d', h = self.heads)
        keys    = rearrange(keys, 'b t (h d) -> b h t d', h = self.heads)
        qk      = einsum(queries, keys, 'b h i d, b h j d -> b h i j')

        qk = qk * self.scale

        ############
        # TODO
        # qk = relative_position_values + qk
        ############

        i, j = qk.shape[-2:]
        # build the mask on the same device as the scores so this also runs on GPU
        mask = torch.ones((i,j), dtype = torch.bool, device = qk.device).triu(j-i+1)
        qk = qk.masked_fill(mask, float('-inf'))

        qk = F.softmax(qk, dim=-1)

        values = rearrange(values, 'b t (h d) -> b h t d', h=self.heads)
        # keep qkv as (b h t d): the per-head gate below needs the head axis,
        # and the heads are merged only after the two attentions are combined
        qkv = qk@values

        ### KNN ATTENTION

        # Convert queries to search form
        queries = rearrange(queries, 'b h t d -> b t (h d)')
        mem_kv = knn.search(queries, topk = self.topk_retrieved_memories) # returns b t k 2 d
        mem_k, mem_v = mem_kv.unbind(dim = -2)
        mem_k = rearrange(mem_k, 'b t k (h d) -> b h t k d', h=self.heads)
        mem_v = rearrange(mem_v, 'b t k (h d) -> b h t k d', h=self.heads)

        # Convert queries to attention form
        queries = rearrange(queries, 'b t (h d) -> b h t d', h = self.heads)
        # einops.einsum takes the tensors first and the pattern last
        mem_qk = einsum(queries, mem_k, 'b h t d, b h t k d -> b h t k')
        mem_qk = mem_qk * self.scale

        mem_qk = F.softmax(mem_qk, dim=-1)
        mem_qk = self.dropout(mem_qk)
        mem_qkv = einsum(mem_qk, mem_v, 'b h t k, b h t k d -> b h t d')

        # Combined attentions

        combined_qkv = mem_qkv * self.gate_bias + qkv * (1 - self.gate_bias)
        combined_qkv = rearrange(combined_qkv, 'b h t d -> b t (h d)')
        out = self.output_matrix(combined_qkv)

        # New XL memories
        keys = rearrange(keys, 'b h t d -> b t (h d)', h = self.heads)
        values = rearrange(values, 'b h t d -> b t (h d)', h=self.heads)
        kv_memories = torch.stack((keys, values), dim=-2) # (batch, sequence_len, 2, dimension)

        if xl_memory is not None:
            # if we're on a middle/end segment of a document (there are previous XL memories),
            # keep only the current segment's entries (the last sequence_length positions)
            current_kv = kv_memories[:, -sequence_length:]
        else:
            # if we're at the first segment
            current_kv = kv_memories

        knn.add(current_kv)

        return out, current_kv
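
The `knn` object passed to `forward` is not defined in this section; the class only relies on it having `search` and `add` methods. The following is a hypothetical brute-force stand-in that matches the shapes noted in the code comments (`search` returns `b t k 2 d`, `add` receives `b t 2 d`). The class name, the dot-product similarity, and the zero-filled result for an empty store are all assumptions made for illustration, not the notebook's actual memory implementation.

```python
import torch

class BruteForceKNNMemory:
    """Hypothetical exhaustive-search store exposing add() and search()."""
    def __init__(self):
        self.store = None  # (batch, n_memories, 2, dim)

    def add(self, kv):
        kv = kv.detach()   # stored memories are not trained through
        self.store = kv if self.store is None else torch.cat((self.store, kv), dim=1)

    def search(self, queries, topk):
        b, t, dim = queries.shape
        if self.store is None or self.store.shape[1] < topk:
            # Not enough memories yet: return zeros so downstream shapes still line up
            return torch.zeros(b, t, topk, 2, dim, dtype=queries.dtype, device=queries.device)
        mem_keys = self.store[:, :, 0]                 # (b, n, dim)
        scores = queries @ mem_keys.transpose(-1, -2)  # (b, t, n) dot-product similarity
        idx = scores.topk(topk, dim=-1).indices        # (b, t, topk)
        batch_idx = torch.arange(b, device=queries.device).view(b, 1, 1)
        return self.store[batch_idx, idx]              # (b, t, topk, 2, dim)

# Usage: four orthogonal keys, each stored with a distinct value row
keys = torch.eye(4).unsqueeze(0)               # (1, 4, 4)
values = torch.arange(16.).view(1, 4, 4)       # (1, 4, 4)
knn = BruteForceKNNMemory()
knn.add(torch.stack((keys, values), dim=-2))   # (1, 4, 2, 4)

queries = keys[:, [2, 0]]                      # query with key 2, then key 0
found = knn.search(queries, topk=1)
print(found.shape)                 # torch.Size([1, 2, 1, 2, 4])
print(found[0, 0, 0, 1].tolist())  # [8.0, 9.0, 10.0, 11.0]
print(found[0, 1, 0, 1].tolist())  # [0.0, 1.0, 2.0, 3.0]
```

Exhaustive search costs time proportional to the number of stored memories for every query, so it only suits small experiments. Larger memories would normally rely on an approximate nearest-neighbour index instead.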

### Key Terminology
- b: The batch size.
- t: The time or sequence length (e.g., number of tokens).
- i, j: The query and key positions in the attention score matrix.
- k: The number of memories retrieved per query (`topk_retrieved_memories`).
- h: The number of attention heads (a hyperparameter of multi-head attention).
- d: The dimensionality of each attention head.

### Key Concepts

- Transformer-XL: Introduces recurrence to the transformer architecture by reusing memory from previous segments.
- Self-attention: Each token in a sequence learns to focus on other tokens based on similarity (query-key matching).
- Multi-head attention: Splits the attention process into multiple "heads," allowing the model to focus on different parts of the sequence at once.
- Masked self-attention: Ensures that tokens only attend to previous ones, preventing the model from cheating in autoregressive tasks.

### Key Features of These Classes:
1. **Recurrent Memory**: By utilizing `xl_memory`, these classes maintain state across input segments, which is particularly useful for processing long sequences, such as documents or time-series data.
   
2. **KNN Integration**: In `KNN_XLAttention`, the KNN memory retrieval allows the model to effectively use past "experiences" (memories) that are similar to the current input, which can improve its performance on tasks like language modeling and question answering where context and past information are important.

3. **Scalability**: Both classes extend the usable context beyond a single segment. `XLAttention` does so through recurrent memory (Transformer-XL), and `KNN_XLAttention` additionally retrieves older key-value pairs through KNN search.

### Applications
These modifications to the attention mechanism are useful in tasks like language modeling, document processing, or time-series forecasting, where maintaining long-term dependencies or leveraging past experiences is crucial. Transformer-XL-like recurrence extends the effective context by reusing cached keys and values instead of recomputing earlier segments, while KNN provides a way to retrieve relevant past experiences dynamically.
