# **The Reformer - Pushing the limits of language modeling**

***How the Reformer uses less than 8GB of RAM to train on sequences of half a million tokens***

The Reformer model as introduced by [Kitaev, Kaiser et al. (2020)](https://arxiv.org/pdf/2001.04451.pdf) is one of the most memory-efficient transformer models for long sequence modeling as of today.

Recently, long sequence modeling has experienced a surge of interest, as can be seen from the many submissions from this year alone - [Beltagy et al. (2020)](https://arxiv.org/abs/2004.05150), [Roy et al. (2020)](https://arxiv.org/abs/2003.05997), [Tay et al. (2020)](https://arxiv.org/abs/2002.11296), and [Wang et al. (2020)](https://arxiv.org/abs/2006.04768), to name a few.
The motivation behind long sequence modeling is that many tasks in NLP, *e.g.* summarization, question answering, require the model to process longer input sequences than models, such as BERT, are able to handle. In tasks that require the model to process a large input sequence, long sequence models do not have to cut the input sequence to avoid memory overflow and thus have been shown to outperform standard "BERT"-like models *cf.* [Beltagy et al. (2020)](https://arxiv.org/abs/2004.05150). 

The Reformer pushes the limit of long sequence modeling by its ability to process up to half a million tokens at once, as shown in this [demo](https://github.com/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb). As a comparison, a conventional `bert-base-uncased` model limits the input length to only 512 tokens. In Reformer, each part of the standard transformer architecture is re-engineered to minimize memory requirements without a significant drop in performance.

The memory improvements can be attributed to **4** features which the Reformer authors introduced to the transformer world:

1.   **Reformer Self-Attention Layer** - *How to efficiently implement self-attention without being restricted to a local context?* => see [this colab](https://colab.research.google.com/drive/15oP52_7W5dRcAnbgX3tYADsu4R3cjMIf?usp=sharing)
2.  **Chunked Feed Forward Layers** - *How to get a better time-memory trade-off for large feed forward layers?* => see [this colab](https://colab.research.google.com/drive/1xKK32Yhda-iYgtoA3eCrnCVuy_lraQR9?usp=sharing)
3.   **Reversible Residual Layers**  - *How to drastically reduce memory consumption in training by a smart residual architecture?* => see [this colab](https://colab.research.google.com/drive/1BLffcRt9LXmM7nKU2UXhtm0PqAG0UE7J#scrollTo=mk1ETLlfEMGA)
4.   **Axial Positional Encodings** - *How to make positional encodings usable for extremely large input sequences?*

The goal of this blog post is to give the reader an **in-depth** understanding of each of the four Reformer features mentioned above. While the explanations are focussed on the Reformer, the reader should get a better intuition under which circumstances each of the four features can be effective for other transformer models as well. 
The four sections are only loosely connected, so they can very well be read individually.

Reformer is part of the 🤗Transformers library. For all users of the Reformer, it is advised to go through this very detailed blog post to better understand how the model works and how to correctly set its configuration. All equations are accompanied by their equivalent name for the Reformer config, *e.g.* `config.<param_name>`, so that the reader can quickly relate to the official docs and configuration file.

**Note**: *Axial Positional Encodings* are not explained in the official Reformer paper, but are extensively used in the official codebase. This blog post gives the first in-depth explanation of Axial Positional Encodings.

## **4. Axial Positional Encodings**

Reformer makes it possible to process huge input sequences. However, for such long input sequences, a standard positional encoding weight matrix alone would need more than 1GB to store its weights.
To avoid such large positional encoding matrices, the official Reformer code makes use of *Axial Positional Encodings*.

**Important:** *Axial Positional Encodings were not explained in the official paper, but can be well understood by looking into the code and talking to the authors.*


### **Axial Positional Encodings in Reformer**

Transformers need positional encodings to account for the order of words in the input because self-attention layers have *no notion of order*. 
Positional encodings are usually defined by a simple look-up matrix $\mathbf{E} = \left[\mathbf{e}_1, \ldots, \mathbf{e}_{n_\text{max}}\right]$. The positional encoding vector $\mathbf{e}_i$ is then simply added to the *i-th* input vector, giving $\mathbf{x}_i + \mathbf{e}_i$, so that the model can distinguish whether an input vector (*a.k.a.* token) is at position $i$ or $j$.
For every input position, the model needs to be able to look up the corresponding positional encoding vector. The dimension of $\mathbf{E}$ is therefore defined by the maximum length of input sequences the model can process, `config.max_position_embeddings` (*i.e.* $n_\text{max}$), and by the `config.hidden_size` of the input vectors (*i.e.* $d_h$).

Assuming $d_h=4$ and $n_\text{max}=49$, such a positional encoding matrix can be visualized as follows:

![alt text](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/reformer_benchmark/positional_encodings_default.png)

Here, we showcase only the positional encodings $\mathbf{e}_1$, $\mathbf{e}_2$, and $\mathbf{e}_{49}$, each of dimension (*a.k.a.* height) 4.

Let's imagine, we want to train a Reformer model on sequences of a length of up to 0.5M tokens and an input vector `config.hidden_size` of 1024 (see notebook [here](https://github.com/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb)). The corresponding positional embeddings have a size of $0.5M \times 1024 \sim 512M$ parameters, which corresponds to a size of 2GB.

Such positional encodings would use an unnecessarily large amount of memory both when loading the model in memory and when saving the model on a hard drive.
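As a quick sanity check of the 2GB figure, the arithmetic can be reproduced in a few lines. This is only a sketch: it assumes that "0.5M" stands for $2^{19} = 524288$ positions and that every weight is stored as a 4-byte float32. `embedding_table_bytes` is a hypothetical helper, not part of any library.

```python
def embedding_table_bytes(n_max, d_h, bytes_per_param=4):
    """Size in bytes of a dense (n_max x d_h) lookup table (float32 assumed)."""
    return n_max * d_h * bytes_per_param

n_max = 2 ** 19   # assumption: "0.5M" tokens means 524288 positions
d_h = 1024        # config.hidden_size from the example above

num_params = n_max * d_h
size_gib = embedding_table_bytes(n_max, d_h) / 2 ** 30

print(f"{num_params:,} parameters, {size_gib:.1f} GiB")
# 536,870,912 parameters, 2.0 GiB
```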

The Reformer authors managed to drastically shrink the positional encodings in size by cutting the `config.hidden_size` dimension in two and smartly factorizing the $n_\text{max}$ dimension. In 🤗Transformers, the user can decide into which shape $n_\text{max}$ is factorized by setting `config.axial_pos_shape` to an appropriate list of two values $n_\text{max}^1$ and $n_\text{max}^2$ so that $n_\text{max}^1 \times n_\text{max}^2 = n_\text{max}$. By setting `config.axial_pos_embds_dim` to an appropriate list of two values $d_h^1$ and $d_h^2$ so that $d_h^1 + d_h^2 = d_h$, the user can decide how the hidden size dimension should be cut.
Now, let's visualize and explain more intuitively.

One can think of factorizing $n_\text{max}$ as folding the dimension into a third axis, which is shown in the following for the factorization `config.axial_pos_shape = [7, 7]`:

![alt text](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/reformer_benchmark/3d_positional_encoding.png)

Each of the three standing rectangular prisms corresponds to one of the encoding vectors $\mathbf{e}_1, \mathbf{e}_2, \mathbf{e}_{49}$, but we can see that the 49 encoding vectors are divided into 7 rows of 7 vectors each.
Now the idea is to use only one row of 7 encoding vectors and expand those vectors to the other 6 rows, essentially reusing their values. 
Because different positions should not end up with identical encoding vectors, each vector of dimension (*a.k.a.* height) `config.hidden_size=4` is cut into a lower encoding vector $\mathbf{e}_\text{down}$ of size $1$ and an upper encoding vector $\mathbf{e}_\text{up}$ of size $3$, so that the lower part can be expanded along the row dimension and the upper part can be expanded along the column dimension.
Let's visualize for more clarity.

![alt text](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/reformer_benchmark/3d_positional_encoding_cut.png)

We can see that we have cut the embedding vectors into $\mathbf{e}_\text{down}$ (*in blue*) and $\mathbf{e}_\text{up}$ (*in yellow*).
Now for the "sub"-vectors $\mathbf{E}_\text{down} = \left[\mathbf{e}_{\text{down},1}, \ldots, \mathbf{e}_{\text{down},49}\right]$ only the first row, *a.k.a.* the width in the graphic, of $7$ is kept and expanded along the column dimension, *a.k.a.* the depth of the graphic. Inversely, for the "sub"-vectors $\mathbf{E}_\text{up} = \left[\mathbf{e}_{\text{up},1}, \ldots, \mathbf{e}_{\text{up},49}\right]$ only the first column of $7$ is kept and expanded along the row dimension.
The resulting embedding vectors $\mathbf{e'}_i$ then correspond to

  \begin{align}
    \mathbf{e'}_i &= \begin{bmatrix}
           \mathbf{e}_{\text{down}, \, i \% n_\text{max}^1} \\
           \mathbf{e}_{\text{up}, \, \left\lfloor \frac{i}{n_\text{max}^1} \right\rfloor}
         \end{bmatrix}
  \end{align}
where $n_\text{max}^1 = 7$ and $n_\text{max}^2 = 7$ in our example.
These new encodings $\mathbf{E'} = \left[\mathbf{e'}_1, \ldots, \mathbf{e'}_{n_\text{max}}\right]$ are called **Axial Position Encodings**. 
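The construction can be sketched in plain Python for the running example (`config.axial_pos_shape = [7, 7]`, lower part of size 1, upper part of size 3). This is an illustration under stated assumptions, not the library's code: positions are zero-based, the weights are random placeholders, the helper name `axial_position_encodings` is made up, and the lower part is selected with $i \bmod n_\text{max}^1$ while the upper part is selected with $\lfloor i / n_\text{max}^1 \rfloor$ (both factors are 7 here).

```python
import random

def axial_position_encodings(e_down, e_up):
    """Combine two small tables into one encoding vector per position.

    e_down: n1 vectors of size d1 (lower part, repeats every n1 positions)
    e_up:   n2 vectors of size d2 (upper part, changes every n1 positions)
    Returns n1 * n2 vectors of size d1 + d2.
    """
    n1, n2 = len(e_down), len(e_up)
    return [e_down[i % n1] + e_up[i // n1] for i in range(n1 * n2)]

random.seed(0)
n1, n2 = 7, 7    # config.axial_pos_shape
d1, d2 = 1, 3    # config.axial_pos_embds_dim, d1 + d2 = hidden size 4

# Random placeholder weights; a real model would learn these.
e_down = [[random.random() for _ in range(d1)] for _ in range(n1)]
e_up = [[random.random() for _ in range(d2)] for _ in range(n2)]

encodings = axial_position_encodings(e_down, e_up)

stored = n1 * d1 + n2 * d2       # parameters actually kept
dense = n1 * n2 * (d1 + d2)      # parameters of a full lookup table
all_distinct = len({tuple(e) for e in encodings}) == len(encodings)

print(f"{len(encodings)} vectors of size {len(encodings[0])}")
print(f"stored parameters: {stored}, dense table: {dense}")
print(f"all vectors distinct: {all_distinct}")
```

Every position receives a different pair of indices, so the 49 vectors differ from each other even though only 28 values are stored instead of 196.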

In the following, these axial position encodings are illustrated in more detail for our example.

![alt text](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/reformer_benchmark/axial_pos_encoding.png)

Now it should be more understandable how the final positional encoding vectors $\mathbf{E'}$ are calculated only from $\mathbf{E}_{\text{down}}$ of dimension $d_h^1 \times n_\text{max}^1$ and $\mathbf{E}_{\text{up}}$ of dimension $d_h^2 \times n_\text{max}^2$.

The crucial aspect to see here is that Axial Positional Encodings make sure that none of the vectors $\left[\mathbf{e'}_1, \ldots, \mathbf{e'}_{n_\text{max}}\right]$ are equal to each other by design and that the overall size of the encoding matrix is reduced from $n_\text{max} \times d_h$ to $n_\text{max}^1 \times d_h^1 + n_\text{max}^2 \times d_h^2$.
By allowing each axial positional encoding vector to be different by design, the model is given much more flexibility to learn efficient positional representations when the axial positional encodings are learned by the model.

To demonstrate the drastic reduction in size, let's assume we set `config.axial_pos_shape = [1024, 512]` and `config.axial_pos_embds_dim = [512, 512]` for a Reformer model that can process inputs up to a length of 0.5M tokens. The resulting axial positional encoding matrices would have a size of only $1024 \times 512 + 512 \times 512 \sim 800K$ parameters, which corresponds to roughly 3MB. This is a drastic reduction from the 2GB a standard positional encoding matrix would require in this case.
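The numbers for this configuration can be checked with a short calculation. The sketch below assumes float32 weights (4 bytes each) and that 0.5M means $2^{19}$ positions. `axial_param_count` is a hypothetical helper that also checks the two constraints on the configuration values.

```python
def axial_param_count(n_max, d_h, axial_pos_shape, axial_pos_embds_dim):
    """Parameters stored by axial positional encodings (hypothetical helper)."""
    n1, n2 = axial_pos_shape
    d1, d2 = axial_pos_embds_dim
    # The factors must multiply to n_max and the parts must add up to d_h.
    assert n1 * n2 == n_max, "axial_pos_shape must factorize n_max"
    assert d1 + d2 == d_h, "axial_pos_embds_dim must sum to hidden_size"
    return n1 * d1 + n2 * d2

n_max, d_h = 2 ** 19, 1024   # assumption: 0.5M tokens = 524288 positions
axial = axial_param_count(n_max, d_h, (1024, 512), (512, 512))
standard = n_max * d_h

axial_mib = axial * 4 / 2 ** 20   # float32 assumed
print(f"axial: {axial:,} parameters ({axial_mib:.1f} MiB)")
print(f"standard: {standard:,} parameters, {standard / axial:.0f}x larger")
# axial: 786,432 parameters (3.0 MiB)
# standard: 536,870,912 parameters, 683x larger
```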

For a more condensed and math-heavy explanation please refer to the 🤗Transformers docs [here](https://huggingface.co/transformers/model_doc/reformer.html#axial-positional-encodings).

### **Benchmark**

Lastly, let's also compare the peak memory consumption of conventional positional embeddings to *axial positional embeddings*.

```python
#@title Installs and Imports
# pip installs
!pip -qq install git+https://github.com/huggingface/transformers.git
!pip install -qq py3nvml

from transformers import ReformerConfig, PyTorchBenchmark, PyTorchBenchmarkArguments, ReformerModel
```


Positional embeddings depend only on two configuration parameters: The maximum allowed length of input sequences `config.max_position_embeddings` and `config.hidden_size`. Let's use a model that pushes the maximum allowed length of input sequences to half a million tokens, called `google/reformer-crime-and-punishment`, to see the effect of using axial positional embeddings.

To begin with, we will compare the shape of axial position encodings with standard positional encodings and the number of parameters in the model.

```python
config_no_pos_axial_embeds = ReformerConfig.from_pretrained("google/reformer-crime-and-punishment", axial_pos_embds=False)  # disable axial positional embeddings
config_pos_axial_embeds = ReformerConfig.from_pretrained("google/reformer-crime-and-punishment", axial_pos_embds=True, axial_pos_embds_dim=(64, 192), axial_pos_shape=(512, 1024))  # enable axial positional embeddings

print("Default Positional Encodings")
print(20 * '-')
model = ReformerModel(config_no_pos_axial_embeds)
print(f"Positional embeddings shape: {model.embeddings.position_embeddings}")
print(f"Num parameters of model: {model.num_parameters()}")
print(20 * '-' + '\n\n')

print("Axial Positional Encodings")
print(20 * '-')
model = ReformerModel(config_pos_axial_embeds)
print(f"Positional embeddings shape: {model.embeddings.position_embeddings}")
print(f"Num parameters of model: {model.num_parameters()}")
print(20 * '-' + '\n\n')
```

```
Default Positional Encodings
--------------------
Positional embeddings shape: PositionEmbeddings(
  (embedding): Embedding(524288, 256)
)
Num parameters of model: 136572416
--------------------


Axial Positional Encodings
--------------------
Positional embeddings shape: AxialPositionEmbeddings(
  (weights): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 512x1x64]
      (1): Parameter containing: [torch.FloatTensor of size 1x1024x192]
  )
)
Num parameters of model: 2584064
--------------------
```

Having read the theory, the shape of the axial positional encoding weights should not be a surprise to the reader.

Regarding the results, it can be seen that for models capable of processing such long input sequences, it is not practical to use default positional encodings.
In the case of `google/reformer-crime-and-punishment`, standard positional encodings alone contain more than 100M parameters. 
Axial positional encodings reduce this number to just over 200K.
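These figures can be cross-checked against the printed output with simple arithmetic. The sketch uses only the shapes and totals shown above; the variable names are illustrative.

```python
# Shapes and totals copied from the printed output.
standard_pos_params = 524288 * 256                 # Embedding(524288, 256)
axial_pos_params = 512 * 1 * 64 + 1 * 1024 * 192   # the two axial weight tensors

total_standard = 136572416   # model with standard positional encodings
total_axial = 2584064        # model with axial positional encodings

print(f"standard: {standard_pos_params:,}")   # standard: 134,217,728
print(f"axial:    {axial_pos_params:,}")      # axial:    229,376

# The two configs differ only in their positional encodings, so the gap in
# total parameters should equal the gap between the two encoding sizes.
assert total_standard - total_axial == standard_pos_params - axial_pos_params
```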

Lastly, let's also compare the required memory at inference time.

```python
benchmark_args = PyTorchBenchmarkArguments(sequence_lengths=[512], batch_sizes=[8], models=["Reformer-No-Axial-Pos-Embeddings", "Reformer-Axial-Pos-Embeddings"], no_speed=True, no_env_print=True)
benchmark = PyTorchBenchmark(configs=[config_no_pos_axial_embeds, config_pos_axial_embeds], args=benchmark_args)
result = benchmark.run()
```

```
1 / 2
2 / 2

--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length    Memory in MB 
--------------------------------------------------------------------------------
Reformer-No-Axial-Pos-Embeddin       8              512             959      
Reformer-Axial-Pos-Embeddings        8              512             447      
--------------------------------------------------------------------------------
```

It can be seen that using axial positional embeddings reduces the memory requirement to approximately half in the case of `google/reformer-crime-and-punishment`.
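A rough back-of-the-envelope check suggests where the saving comes from. Assuming float32 weights and treating the reported numbers as MiB (both are assumptions), the standard position table alone accounts for about as much memory as the measured gap.

```python
bytes_per_param = 4  # assumption: float32 weights

standard_table_mib = 524288 * 256 * bytes_per_param / 2 ** 20
axial_tables_mib = (512 * 64 + 1024 * 192) * bytes_per_param / 2 ** 20

measured_gap_mb = 959 - 447  # from the benchmark table

print(f"standard table: {standard_table_mib:.1f} MiB")   # standard table: 512.0 MiB
print(f"axial tables:   {axial_tables_mib:.3f} MiB")     # axial tables:   0.875 MiB
print(f"measured gap:   {measured_gap_mb} MB")           # measured gap:   512 MB
```

The benchmark reports rounded values for total memory, so the agreement is indicative rather than exact.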