# RetNet

Paper: https://arxiv.org/pdf/2307.08621.pdf

<img src="./img/impossible_triangle.png" alt="Impossible Triangle" width="50%">

## Perplexity

Perplexity is an intrinsic evaluation metric for language models in Natural Language Processing.

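It measures, on average, how many equally probable words the model is choosing between when it predicts the next word. For example, a model that hesitates uniformly between $k$ words at every step has a perplexity of $k$.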

Lower perplexities represent better language models.
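Formally, perplexity is the exponential of the average negative log-likelihood per token, $\mathrm{PPL} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N}\log p(w_i \mid w_{<i})\right)$. The following is a minimal sketch. It assumes we already have the probability the model assigned to each observed token. The probabilities below are made up for illustration.

```python
import math

def perplexity(token_probs):
    """Perplexity from the probabilities p(w_i | w_<i) assigned to each observed token."""
    if not token_probs:
        raise ValueError("need at least one token probability")
    # Average negative log-likelihood per token
    nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(nll)

# A model that always spreads its probability uniformly over 4 candidate words
print(perplexity([0.25, 0.25, 0.25, 0.25]))  # ≈ 4
# A more confident model (on the correct tokens) gets a lower perplexity
print(perplexity([0.9, 0.8, 0.95, 0.7]))
```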

## Performance comparison with Transformers

From the paper:

> Retentive network (RetNet) achieves low-cost inference (i.e., GPU memory, throughput,
and latency), training parallelism, and favorable scaling curves compared with Transformer. Results
of inference cost are reported with 8k as input length. Figure 6 shows more results on different
sequence lengths.

<img src="./img/performances_comparison.png" alt="Performance comparison" width="80%">

## Parallelism

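Transformers use the self-attention mechanism, which is highly parallelisable. During training this is an advantage, because all positions of a sequence are processed at once. During autoregressive inference, however, tokens are generated one at a time, so this parallelism cannot be exploited, and each new token must attend to all the previous ones. As a result, the cost of each step grows with the sequence length.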

## Inference cost & memory complexity

- Inference cost per time step refers to GPU memory, throughput and latency.
- Memory complexity refers to how the memory footprint scales with the sequence length.

### RNNs

RNNs process the sequence one token at a time. At each step, they update a fixed-size hidden state through matrix multiplications.

- The inference cost per time step is constant, $O(1)$
- The memory complexity is linear in the sequence length, $O(N)$
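Here is a minimal sketch of a vanilla (Elman-style) RNN step. It illustrates why the per-step cost is constant: each step reads only the fixed-size hidden state `h` and the current input, never the earlier tokens. The sizes and weights are hypothetical and are not taken from the paper.

```python
import math

def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def rnn_step(h, x, W_h, W_x):
    """One Elman-style RNN step: h_t = tanh(W_h h_{t-1} + W_x x_t)."""
    a = matvec(W_h, h)
    b = matvec(W_x, x)
    return [math.tanh(ai + bi) for ai, bi in zip(a, b)]

# Hypothetical weights: hidden size 2, input size 3
W_h = [[0.5, -0.1], [0.2, 0.3]]
W_x = [[1.0, 0.0, -1.0], [0.0, 1.0, 0.5]]

h = [0.0, 0.0]
inputs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
for x in inputs:
    # Same amount of work at every step; the state never grows
    h = rnn_step(h, x, W_h, W_x)
print(h)
```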

### Transformers

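Transformers use the self-attention mechanism, usually in its Multi-Head Attention form. Over a sequence of length $N$, every token is compared with every other token, which produces an $N \times N$ attention matrix. During autoregressive inference, each new token must attend to the keys and values of all previous tokens.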

- The inference cost per time step is linear, $O(N)$
- The memory complexity is quadratic, $O(N^2)$
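Below is a minimal sketch of single-head softmax attention during decoding with a key/value cache. At step $t$, the new query is scored against all $t$ cached keys, so the work per step grows linearly. Over a full sequence, these scores form the $N \times N$ matrix. For brevity, the hypothetical token vectors are used directly as queries, keys and values, with no learned projections.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q, keys, values):
    """Single-head scaled dot-product attention of one query over all cached keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    weights = softmax([dot(q, k) * scale for k in keys])  # one score per cached token: O(t) work
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]

# Hypothetical 2-d token vectors
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
k_cache, v_cache = [], []
for t, x in enumerate(tokens, start=1):
    k_cache.append(x)  # the cache grows by one entry per generated token
    v_cache.append(x)
    out = attend(x, k_cache, v_cache)
    print(f"step {t}: attended over {len(k_cache)} cached tokens -> {out}")
```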

### Improvements

- Parallel training, as with self-attention in Transformers
- The self-attention mechanism is replaced by the so-called **retention mechanism**. It supports both parallel computation during training and **recurrent inference**.
- RetNet introduces a **Multi-Scale retention mechanism**, which replaces the **Multi-Head Attention mechanism**. This helps reduce the complexity at inference time.

RetNet supports three computation paradigms:

- **Parallel approach**: used during training
- **Recurrent approach**: used during inference to reduce the time complexity to constant and the memory complexity to linear
- **Chunkwise recurrent approach**: enables efficient modelling of long sequences. The sequence is split into local blocks (chunks). Each chunk is encoded in parallel, while information is carried from one chunk to the next in a recurrent fashion.
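The chunkwise idea can be sketched on a deliberately simplified retention. It uses a single head, vector queries and keys, and scalar values, and it has no projections, positional rotations or normalization. These simplifications, and the function names, are our own assumptions. Within each chunk the parallel form is used, and a running state `R` summarises all earlier chunks. The check at the end confirms that the chunkwise result matches the fully parallel one.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def parallel_retention(Q, K, V, gamma):
    """o_n = sum_{m<=n} gamma^(n-m) (q_n . k_m) v_m, all positions at once."""
    return [sum(gamma ** (n - m) * dot(Q[n], K[m]) * V[m] for m in range(n + 1))
            for n in range(len(Q))]

def chunkwise_retention(Q, K, V, gamma, chunk_size):
    """Parallel inside each chunk, recurrent state R carried across chunks."""
    R = [0.0] * len(K[0])  # decayed sum of k_m * v_m over all previous chunks
    out = []
    for start in range(0, len(Q), chunk_size):
        q = Q[start:start + chunk_size]
        k = K[start:start + chunk_size]
        v = V[start:start + chunk_size]
        B = len(q)
        for j in range(B):
            # Intra-chunk part, computed with the parallel form
            inner = sum(gamma ** (j - i) * dot(q[j], k[i]) * v[i] for i in range(j + 1))
            # Cross-chunk part: contribution of everything before this chunk
            cross = gamma ** (j + 1) * dot(q[j], R)
            out.append(inner + cross)
        # Decay the state over the whole chunk, then add this chunk's contributions
        R = [gamma ** B * r + sum(gamma ** (B - 1 - i) * k[i][c] * v[i] for i in range(B))
             for c, r in enumerate(R)]
    return out

Q = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [1.0, -1.0], [0.2, 0.3]]
K = [[0.3, 0.1], [1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-0.4, 0.2]]
V = [1.0, -2.0, 0.5, 3.0, 1.5]
a = parallel_retention(Q, K, V, gamma=0.9)
b = chunkwise_retention(Q, K, V, gamma=0.9, chunk_size=2)
print(all(abs(x - y) < 1e-9 for x, y in zip(a, b)))  # True
```

The chunk size trades off the two regimes. A chunk size of 1 is purely recurrent, while a chunk as long as the sequence is purely parallel.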

## Self-Attention

## SoftMax

RetNet replaces the SoftMax module with an element-wise [**Hadamard multiplication**](../math_recall/linear_algebra.ipynb) by a newly introduced **D-matrix**, followed by a [**Group Normalization**](https://arxiv.org/pdf/1803.08494.pdf) operation.
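The following is a minimal sketch of the parallel retention form $(QK^\top \odot D)V$ on tiny hand-written matrices. The raw scores $QK^\top$ are multiplied element-wise by the D-matrix instead of being passed through SoftMax. The learned projections, the positional rotations and the subsequent GroupNorm are omitted, and the helper names are our own.

```python
def matmul(A, B):
    """Plain list-of-lists matrix product A @ B."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def hadamard(A, B):
    """Element-wise product of two equally shaped matrices."""
    return [[a * b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def decay_mask(n, gamma):
    # D[i][j] = gamma^(i-j) for j <= i, else 0 (causal mask + exponential decay)
    return [[gamma ** (i - j) if j <= i else 0.0 for j in range(n)] for i in range(n)]

def parallel_retention(Q, K, V, gamma):
    """(Q K^T ⊙ D) V: the Hadamard product with D takes the place of SoftMax."""
    scores = matmul(Q, transpose(K))  # N x N raw scores, no SoftMax
    return matmul(hadamard(scores, decay_mask(len(Q), gamma)), V)

Q = K = V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = parallel_retention(Q, K, V, gamma=0.5)
print(out)  # [[1.0, 0.0], [0.0, 1.0], [2.25, 2.5]]
```

Unlike SoftMax, the weights in a row are not normalized to sum to 1. This is one reason a normalization step follows.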

## Normalization

Normalization typically consists of **centering** the data around the origin (subtracting the mean) and then **scaling** it (dividing by the standard deviation), so that every feature has zero mean and unit variance. Features on a comparable scale are easier to separate and to optimize, which generally improves performance.

If normalizing the input features helps classical classifiers, it should help neural networks too. The idea is to normalize at each layer so that the next layer receives a normalised signal.

In order to normalize, you need to calculate the mean and standard deviation.
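Concretely, each value is transformed as $\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$, where $\epsilon$ is a small constant added for numerical stability. The following is a minimal sketch of this step on a plain list of numbers. Normalization layers such as BatchNorm and GroupNorm then usually apply a learnable scale and shift, which are omitted here.

```python
import math

def standardize(xs, eps=1e-5):
    """Center values on their mean and scale them to (approximately) unit variance."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)  # population (biased) variance
    return [(x - mean) / math.sqrt(var + eps) for x in xs]

z = standardize([2.0, 4.0, 6.0, 8.0])
print(z)  # mean ≈ 0, variance ≈ 1
```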

<img src="./img/normalization.png" alt="Normalization" width="80%">

> Normalization methods. Each subplot shows a feature map tensor, with N as the batch axis, C as the channel axis, and (H, W)
as the spatial axes. The pixels in blue are normalized by the same mean and variance, computed by aggregating the values of these pixels.

### Batch Normalization

**Batch Normalization** computes the mean and standard deviation over a subset (batch) of the data. The bigger the batch, the better the estimate.

The problem with this mechanism is that inputs have become very large, especially with LLMs, so even a single batch can be hard to fit in memory. You could split the batch across several parallel units, but then the statistics are computed on very few datapoints, which inevitably gives poor estimates.

BatchNorm's error grows as the number of samples per batch decreases.

For each channel, **BatchNorm** computes its statistics across the $N$ datapoints of the batch (and across the spatial axes $H$, $W$).
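Here is a minimal sketch of BatchNorm on a 2-D batch of shape $(N, C)$, with no spatial axes and no learnable scale and shift. Each feature is normalized with statistics taken across the $N$ samples. As a result, the output for one sample depends on the other samples in the batch, and with a single sample the output collapses to zeros.

```python
import math

def batch_norm(batch, eps=1e-5):
    """Normalize each feature (column) using statistics computed across the N samples (rows)."""
    n = len(batch)
    out = [row[:] for row in batch]
    for c in range(len(batch[0])):
        col = [row[c] for row in batch]
        mean = sum(col) / n
        var = sum((x - mean) ** 2 for x in col) / n
        for i in range(n):
            out[i][c] = (batch[i][c] - mean) / math.sqrt(var + eps)
    return out

# The same sample [1.0, 10.0] is normalized differently depending on its batch-mates
print(batch_norm([[1.0, 10.0], [3.0, 20.0]])[0])
print(batch_norm([[1.0, 10.0], [5.0, 0.0]])[0])
# With a single sample there is no spread to estimate: everything becomes 0
print(batch_norm([[1.0, 10.0]]))  # [[0.0, 0.0]]
```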

### Group Normalization

**Group Normalization** is similar to batch normalization: it also centers and scales the data. The difference is that it does not rely on batch statistics.

It works on one datapoint at a time and computes its statistics over groups of channels. It combines the good intuitions of Layer Norm, which normalizes across all channels together, and Instance Norm, which normalizes each channel separately. The channels are split into groups, and each group is standardised on its own. The intuition is that channels in the same group tend to share a similar scale.
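The following is a minimal sketch of GroupNorm for a single datapoint stored as a list of $C$ channels, each holding its flattened $H \times W$ values. The learnable per-channel scale and shift are omitted. Setting `num_groups=1` gives Layer Norm-style normalization over all channels, and `num_groups=C` gives Instance Norm-style normalization per channel.

```python
import math

def group_norm(x, num_groups, eps=1e-5):
    """GroupNorm for ONE datapoint.
    x: list of C channels, each a list of flattened spatial values.
    Channels are split into num_groups groups; mean and variance are computed per group."""
    C = len(x)
    if C % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    size = C // num_groups
    out = []
    for g in range(num_groups):
        group = x[g * size:(g + 1) * size]
        values = [v for ch in group for v in ch]  # all values of all channels in the group
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        out.extend([[(v - mean) / math.sqrt(var + eps) for v in ch] for ch in group])
    return out

# 4 channels with 2 spatial values each; the two halves live on very different scales
x = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]]
y = group_norm(x, num_groups=2)        # GroupNorm: 2 groups of 2 channels
layer_like = group_norm(x, 1)          # one group with all channels
instance_like = group_norm(x, 4)       # one group per channel
print(y)
```

No batch is involved, so the result for this datapoint is the same whatever other data is being processed.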

GroupNorm is scale-invariant: multiplying its input by a scalar $\alpha$ changes neither its output nor the backward gradients, i.e. $\mathrm{GroupNorm}(\alpha \cdot \mathrm{head}_i) = \mathrm{GroupNorm}(\mathrm{head}_i)$. RetNet exploits this property to rescale intermediate values inside the retention layers. This improves numerical precision in both the forward and backward passes without changing the final results.
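The invariance follows directly from how the mean and standard deviation of each group scale. This sketch assumes $\alpha > 0$ and ignores both the small $\epsilon$ added to the variance and the learnable affine parameters:

$$
\begin{aligned}
\mu(\alpha x) &= \alpha\,\mu(x), \qquad \sigma(\alpha x) = \alpha\,\sigma(x) \\
\mathrm{GroupNorm}(\alpha x) &= \frac{\alpha x - \alpha\,\mu(x)}{\alpha\,\sigma(x)} = \frac{x - \mu(x)}{\sigma(x)} = \mathrm{GroupNorm}(x)
\end{aligned}
$$

With $\epsilon > 0$ the equality holds only approximately, and a negative $\alpha$ would flip the sign of the output.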

## Softmax, D-matrix and GroupNorm

The **SoftMax** function in Transformers serves two purposes:
- Assigns each token in the sequence a weight, i.e. its share of the contribution when encoding the current token.
- Introduces non-linearity

If SoftMax is removed, it must be replaced with something that provides these two properties.

The proposed **D-matrix** takes care of the "attending" part. It applies causal masking, which prevents look-ahead, and weights past tokens with an exponential decay. The assumption is that more recent tokens in the sequence (closer to the current token) are exponentially more important than distant ones.
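In the paper, the D-matrix is defined as $D_{nm} = \gamma^{n-m}$ if $n \ge m$, and $0$ otherwise, with a decay factor $0 < \gamma < 1$. The following is a minimal sketch that builds it, where the function name is our own. The upper triangle is zero (causal masking), and each step back in time multiplies the weight by $\gamma$. In multi-scale retention, each head uses a different $\gamma$.

```python
def d_matrix(n, gamma):
    """Causal decay mask: D[i][j] = gamma**(i - j) if j <= i, else 0."""
    if not 0.0 < gamma < 1.0:
        raise ValueError("gamma is expected to be in (0, 1)")
    return [[gamma ** (i - j) if j <= i else 0.0 for j in range(n)] for i in range(n)]

for row in d_matrix(4, 0.5):
    print(row)
# [1.0, 0.0, 0.0, 0.0]
# [0.5, 1.0, 0.0, 0.0]
# [0.25, 0.5, 1.0, 0.0]
# [0.125, 0.25, 0.5, 1.0]
```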

> **Causal masking** is used to prevent the model from attending to future tokens during training, so that it can only rely on information from past tokens to predict the next token.

In comparison, SoftMax is more flexible but expensive to compute. The D-matrix is less flexible, since the weighting of tokens is predefined, but it enables a more efficient $O(1)$ inference cost per step with $O(N)$ memory complexity.

The **GroupNorm** operation reintroduces non-linearity. The paper also adds a swish gate to increase the non-linearity of the retention layers.

## Retention Mechanism

The **retention mechanism** combines the benefits of RNNs and Transformers. Even its name combines the two: **RE**current + self-at**TENTION**.

## Achievement

- Parallel retention mechanism during training
- Recurrent retention mechanism during inference

The recurrent version of the retention mechanism is derived by rewriting the parallel version step by step. The resulting matrix operations are somewhat less intuitive, but the parallel and recurrent forms produce exactly the same output.
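The equivalence can be checked numerically on a simplified single-head retention. In the paper, the recurrent form keeps a state $S_n = \gamma S_{n-1} + K_n^\top V_n$ and outputs $Q_n S_n$. In the parallel form, each output is $\sum_{m \le n} \gamma^{n-m} (q_n \cdot k_m)\, v_m$, i.e. a row of $(QK^\top \odot D)V$. This sketch omits the projections, positional rotations, normalization and gating used in the paper, and the function names are our own.

```python
def recurrent_retention(Q, K, V, gamma):
    """S_n = gamma * S_{n-1} + k_n^T v_n ; o_n = q_n S_n (constant work per step)."""
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]  # fixed-size d_k x d_v state
    outputs = []
    for q, k, v in zip(Q, K, V):
        S = [[gamma * S[a][b] + k[a] * v[b] for b in range(d_v)] for a in range(d_k)]
        outputs.append([sum(q[a] * S[a][b] for a in range(d_k)) for b in range(d_v)])
    return outputs

def parallel_retention(Q, K, V, gamma):
    """o_n = sum_{m<=n} gamma^(n-m) (q_n . k_m) v_m, i.e. a row of (Q K^T ⊙ D) V."""
    d_v = len(V[0])
    out = []
    for n, q in enumerate(Q):
        weights = [gamma ** (n - m) * sum(qa * ka for qa, ka in zip(q, K[m])) for m in range(n + 1)]
        out.append([sum(w * V[m][b] for m, w in enumerate(weights)) for b in range(d_v)])
    return out

Q = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [1.0, -1.0]]
K = [[0.3, 0.1], [1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
V = [[1.0, 0.0, 2.0], [-2.0, 1.0, 0.0], [0.5, 0.5, 0.5], [3.0, -1.0, 1.0]]
P = parallel_retention(Q, K, V, gamma=0.9)
R = recurrent_retention(Q, K, V, gamma=0.9)
print(max(abs(a - b) for ra, rb in zip(P, R) for a, b in zip(ra, rb)))  # tiny (float rounding only)
```

The parallel form suits training, since all positions are computed at once. The recurrent form suits generation, since each new token only needs the previous state $S_{n-1}$.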