# Transformers are RNNs Demo

This is a colab to accompany our [project page](https://linear-transformers.com/)
and to explore the transformer formulation developed in our paper [Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention](https://arxiv.org/abs/2006.16236).

The colab is structured in three sections.
1. First, we install the `fast_transformers` library and create an example transformer.
2. Second, we load the models we trained for autoregressive image prediction and reproduce parts of the experiment in Section 4.2 of the paper.
3. Finally, we benchmark the computational requirements of linear attention against those of softmax attention.

If you want to run the CUDA part of the demo, make sure that you enable GPU acceleration via **Runtime -> Change runtime type** or **Edit -> Notebook settings**.

# Installation and First Steps

We install the library directly from PyPI. This takes several minutes because it compiles several custom CUDA kernels, not only the ones for linear autoregressive attention. Maybe grab a coffee (if you are into these things).

In [None]:
!pip install -v pytorch-fast-transformers


Let's validate our freshly installed package by creating a small transformer encoder and running it on dummy data.

In [None]:
from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.masking import LengthMask, TriangularCausalMask
import torch

model = TransformerEncoderBuilder.from_kwargs(
    n_layers=4,
    n_heads=4,
    feed_forward_dimensions=128,
    query_dimensions=32,
    value_dimensions=32,
    attention_type="full" # this means normal softmax attention
).get()

x = torch.rand(
    10,  # batch size 
    100, # sequence length
    128  # feature dimensions
)
y = model(x) # calling without masks which means attend to everything
y = model(
    x,
    attn_mask=TriangularCausalMask(100),   # causal masking
    length_mask=LengthMask(torch.tensor([
        100, 70, 60, 30, 80, 100,          # The sequence length for every
        50, 40, 10, 20                     # sample in the batch
    ]))
)
print("If you reached here, everything works", y.shape)
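
To make the two masks used above concrete, here is a minimal pure-Python sketch of what they encode. It is not how `fast_transformers` implements them; `causal_mask` and `length_mask` are hypothetical helper names used only for illustration.

```python
def causal_mask(n):
    """Boolean n x n matrix: entry [i][j] is True when query i may attend to key j."""
    return [[j <= i for j in range(n)] for i in range(n)]


def length_mask(lengths, max_len):
    """Boolean batch x max_len matrix: True marks real tokens, False marks padding."""
    return [[t < length for t in range(max_len)] for length in lengths]


# Lower-triangular pattern: each position sees itself and everything before it.
causal = causal_mask(4)
for row in causal:
    print("".join("x" if allowed else "." for allowed in row))

# Two sequences padded to length 4; the second one has only 2 real tokens.
valid = length_mask([4, 2], max_len=4)
print(valid)
```

In the cell above, the causal mask restricts each position to earlier positions, and the length mask marks how many of the 100 positions in each sample hold real data.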

# Autoregressive Image Generation

Let us first define two PyTorch modules for autoregressive image generation. The first uses a recurrent formulation that accepts one input at a time. The second, like the default PyTorch transformer implementation, accepts the whole sequence at once.

In both cases the model is simply wrapping a transformer with an input embedding layer and a prediction layer. This becomes apparent upon reading the simple `forward()` method.

In [None]:
import math
from fast_transformers.builders import RecurrentEncoderBuilder

class RecurrentGenerator(torch.nn.Module):
    class PositionalEncoding(torch.nn.Module):
        def __init__(self, d_model, dropout=0.0, max_len=5000):
            super(RecurrentGenerator.PositionalEncoding, self).__init__()
            self.dropout = torch.nn.Dropout(p=dropout)
            self.d_model = d_model
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            pe = pe.unsqueeze(0)
            self.register_buffer('pe', pe)

        def forward(self, x, i):
            pos_embedding =  self.pe[0, i:i+1]
            x = torch.cat(
                [x, pos_embedding.expand_as(x)],
                dim=1
            )
            return self.dropout(x)

    def __init__(self, d_model, sequence_length, mixtures,
                 attention_type="full", n_layers=4, n_heads=4,
                 d_query=32, dropout=0.1, softmax_temp=None,
                 attention_dropout=0.1):
        super(RecurrentGenerator, self).__init__()

        self.pos_embedding = self.PositionalEncoding(
            d_model//2,
            max_len=sequence_length
        )
        self.value_embedding = torch.nn.Embedding(
            256,
            d_model//2
        )
        self.transformer = RecurrentEncoderBuilder.from_kwargs(
            attention_type=attention_type,
            n_layers=n_layers,
            n_heads=n_heads,
            feed_forward_dimensions=n_heads*d_query*4,
            query_dimensions=d_query,
            value_dimensions=d_query,
            dropout=dropout,
            softmax_temp=softmax_temp,
            attention_dropout=attention_dropout
        ).get()
        self.predictor = torch.nn.Linear(
            d_model,
            mixtures * 3
        )

    def forward(self, x, i=0, memory=None):
        x = x.view(x.shape[0])
        x = self.value_embedding(x)
        x = self.pos_embedding(x, i)
        y_hat, memory = self.transformer(x, memory)
        y_hat = self.predictor(y_hat)

        return y_hat, memory


And the non-recurrent one:

In [None]:
class Generator(torch.nn.Module):
    class PositionalEncoding(torch.nn.Module):
        def __init__(self, d_model, dropout=0.0, max_len=5000):
            super(Generator.PositionalEncoding, self).__init__()
            self.dropout = torch.nn.Dropout(p=dropout)
            self.d_model = d_model
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            pe = pe.unsqueeze(0)
            self.register_buffer('pe', pe)

        def forward(self, x):
            pos_embedding =  self.pe[:, :x.size(1), :]
            pos_embedding = torch.repeat_interleave(pos_embedding, x.shape[0], dim=0)
            x =  torch.cat([x, pos_embedding], dim=2)
            return self.dropout(x)

    def __init__(self, d_model, sequence_length, mixtures,
                 attention_type="full", n_layers=4, n_heads=4,
                 d_query=32, dropout=0.1, softmax_temp=None,
                 attention_dropout=0.1):
        super(Generator, self).__init__()

        self.pos_embedding = self.PositionalEncoding(
            d_model//2,
            max_len=sequence_length
        )
        self.value_embedding = torch.nn.Embedding(
            256,
            d_model//2
        )

        self.transformer = TransformerEncoderBuilder.from_kwargs(
            attention_type=attention_type,
            n_layers=n_layers,
            n_heads=n_heads,
            feed_forward_dimensions=n_heads*d_query*4,
            query_dimensions=d_query,
            value_dimensions=d_query,
            dropout=dropout,
            softmax_temp=softmax_temp,
            attention_dropout=attention_dropout
        ).get()

        hidden_size = n_heads*d_query
        self.predictor = torch.nn.Linear(
            hidden_size,
            mixtures * 3
        )

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = self.value_embedding(x)
        x = self.pos_embedding(x)
        triangular_mask = TriangularCausalMask(x.shape[1], device=x.device)
        y_hat = self.transformer(x, attn_mask=triangular_mask)
        y_hat = self.predictor(y_hat)

        return y_hat
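
Both `PositionalEncoding` classes above build the same sinusoidal table and concatenate it to the pixel embedding instead of adding it. The following pure-Python sketch reproduces the formula for a single position, assuming an even `d_model`; `sinusoidal_encoding` and the toy `value_embedding` are hypothetical names and values chosen for illustration.

```python
import math


def sinusoidal_encoding(position, d_model):
    """Return the d_model-dimensional sinusoidal encoding for one position."""
    encoding = [0.0] * d_model
    for k in range(0, d_model, 2):
        # Same frequencies as div_term in the modules above.
        frequency = math.exp(-math.log(10000.0) * k / d_model)
        encoding[k] = math.sin(position * frequency)      # even indices: sine
        encoding[k + 1] = math.cos(position * frequency)  # odd indices: cosine
    return encoding


# Hypothetical 4-dimensional value embedding for one pixel, concatenated
# (not added) with a 4-dimensional positional code, as in forward() above.
value_embedding = [0.1, -0.2, 0.3, 0.4]
features = value_embedding + sinusoidal_encoding(position=5, d_model=4)
print(len(features))
```

This is why both generators create the embedding and the positional encoding with `d_model//2` dimensions each: the concatenation restores the full `d_model` width expected by the transformer.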

## Helper functions

We also need some helper functions to perform the sampling from the mixture of logistics as well as the looping to generate the images.

In [None]:
def sample_mol(y_hat, num_classes=256):
    """Sample from mixture of logistics.

    y_hat: NxC where C is 3*number of logistics
    """
    assert len(y_hat.shape) == 2

    N = y_hat.size(0)
    nr_mix = y_hat.size(1) // 3

    probs = torch.softmax(y_hat[:, :nr_mix], dim=-1)
    means = y_hat[:, nr_mix:2 * nr_mix]
    scales = torch.nn.functional.elu(y_hat[:, 2*nr_mix:3*nr_mix]) + 1.0001

    indices = torch.multinomial(probs, 1).squeeze()
    batch_indices = torch.arange(N, device=probs.device)
    mu = means[batch_indices, indices]
    s = scales[batch_indices, indices]
    u = torch.rand(N, device=probs.device)
    preds = mu + s*(torch.log(u) - torch.log(1-u))

    return torch.min(
        torch.max(
            torch.round((preds+1)/2*(num_classes-1)),
            preds.new_zeros(1),
        ),
        preds.new_ones(1)*(num_classes-1)
    ).long().view(N, 1)


def predict_with_recurrent(model, images, n):
    memory = None
    y_hat = []
    x_hat = []

    with torch.no_grad():
        for i in range(n):
            x_hat.append(images[:, i:i+1])
            yi, memory = model(x_hat[-1], i=i, memory=memory)
            y_hat.append(yi)

        for i in range(n, images.shape[1]):
            x_hat.append(sample_mol(y_hat[-1], 256))
            yi, memory = model(x_hat[-1], i=i, memory=memory)
            y_hat.append(yi)

        x_hat.append(sample_mol(y_hat[-1], 256))
        x_hat = torch.stack(x_hat, dim=1)

    return x_hat


def predict(model, images, n):
    N, L = images.shape
    x_hat = images.new_zeros(N, L+1, dtype=torch.long)
    x_hat[:, :n] = images[:, :n]
    with torch.no_grad():
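        # Position i is sampled from the prediction made after seeing the
        # first i inputs; the last iteration (i == L) produces the extra pixel,
        # matching predict_with_recurrent().
        for i in range(n, L+1):
            y_hat = model(x_hat[:, :i])
            x_hat[:, i:i+1] = sample_mol(y_hat[:,-1,:], 256)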
    return x_hat
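
The core of `sample_mol` is inverse-CDF sampling from a logistic distribution followed by discretization to pixel intensities. Below is a minimal scalar sketch of those two steps using only the standard library. `sample_logistic` and `to_pixel` are hypothetical names, and the clamping of `u` is an added safeguard that the tensor version above does not include.

```python
import math
import random


def sample_logistic(mu, s, rng):
    """Draw from Logistic(mu, s) by pushing a uniform sample through the inverse CDF."""
    u = rng.random()
    # Added safeguard: keep u away from 0 and 1, where the logs diverge.
    u = min(max(u, 1e-12), 1.0 - 1e-12)
    return mu + s * (math.log(u) - math.log(1.0 - u))


def to_pixel(value, num_classes=256):
    """Map a value in [-1, 1] to an integer intensity in [0, num_classes - 1]."""
    pixel = round((value + 1.0) / 2.0 * (num_classes - 1))
    return min(max(pixel, 0), num_classes - 1)  # clamp out-of-range samples


rng = random.Random(0)
samples = [sample_logistic(mu=0.0, s=0.1, rng=rng) for _ in range(10000)]
print(sum(samples) / len(samples))  # sample mean, expected to be close to mu
print(to_pixel(-1.0), to_pixel(1.0), to_pixel(3.7))
```

In `sample_mol`, the mixture component is first chosen with `torch.multinomial`, and the same transformation is then applied with that component's mean and scale.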

## Loading pretrained models

After defining our modules we can now download our pretrained models from Google Drive.

In [None]:
import io
import requests

LINEAR_MODEL = "https://drive.google.com/uc?export=download&id=17fc94TzytTdAwNMVCE7qOg75-CWLGi_p"
SOFTMAX_MODEL = "https://drive.google.com/uc?export=download&id=1L47Ode6GxCMQbVMK33_ANjCu2iA4rf8l"

linear_weights = torch.load(io.BytesIO(requests.get(LINEAR_MODEL).content))
softmax_weights = torch.load(io.BytesIO(requests.get(SOFTMAX_MODEL).content))

Now, let's create the models and generate some images. Note that we also create a recurrent model for softmax attention. It caches all previously computed keys and values so that they are not recomputed at every step, which is not easy to do in every transformer implementation (Reformer, for instance).

On the other hand, for linear attention the state has fixed size and it is natural to implement it as a recurrent model (see section 3.4 in [our paper](https://arxiv.org/pdf/2006.16236.pdf)).

In [None]:
linear = RecurrentGenerator(256, 783, 10, "linear", 8, 8)
linear.load_state_dict(linear_weights)
linear.eval()
full = RecurrentGenerator(256, 783, 10, "full", 8, 8)
full.load_state_dict(softmax_weights)
full.eval()

images_linear = predict_with_recurrent(linear, torch.zeros(1, 783, dtype=torch.int64), 1)
images_full = predict_with_recurrent(full, torch.zeros(1, 783, dtype=torch.int64), 1)

import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2)
ax[0].set_title("Linear")
ax[0].imshow(images_linear[0].cpu().numpy().reshape(28, 28))
ax[1].set_title("Softmax")
ax[1].imshow(images_full[0].cpu().numpy().reshape(28, 28))
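
The fixed-size state of linear attention mentioned above can be illustrated with a small pure-Python sketch. It compares causal linear attention computed over all previous positions with the recurrent form described in Section 3.4 of the paper, using the feature map elu(x) + 1. This is an illustration on plain lists for a single head, not the library's implementation, and all function names are hypothetical.

```python
import math


def phi(vector):
    """Feature map elu(x) + 1, which keeps every feature positive."""
    return [x + 1.0 if x > 0 else math.exp(x) for x in vector]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def causal_linear_attention(queries, keys, values):
    """Parallel form: position i attends to all positions j <= i."""
    outputs = []
    for i, q in enumerate(queries):
        weights = [dot(phi(q), phi(keys[j])) for j in range(i + 1)]
        total = sum(weights)
        outputs.append([
            sum(w * values[j][c] for j, w in enumerate(weights)) / total
            for c in range(len(values[0]))
        ])
    return outputs


def recurrent_linear_attention(queries, keys, values):
    """Recurrent form: carry a fixed-size state (S, z) from step to step."""
    d_k, d_v = len(keys[0]), len(values[0])
    S = [[0.0] * d_v for _ in range(d_k)]  # running sum of phi(k) v^T
    z = [0.0] * d_k                        # running sum of phi(k)
    outputs = []
    for q, k, v in zip(queries, keys, values):
        fk = phi(k)
        for a in range(d_k):
            z[a] += fk[a]
            for c in range(d_v):
                S[a][c] += fk[a] * v[c]
        fq = phi(q)
        norm = dot(fq, z)
        outputs.append([
            sum(fq[a] * S[a][c] for a in range(d_k)) / norm
            for c in range(d_v)
        ])
    return outputs


queries = [[0.5, -1.0], [1.5, 0.2], [-0.3, 0.8]]
keys = [[0.1, 0.4], [-0.7, 1.2], [0.9, -0.5]]
values = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0], [3.0, 1.0, 0.5]]

parallel = causal_linear_attention(queries, keys, values)
recurrent = recurrent_linear_attention(queries, keys, values)
print(parallel[-1])
print(recurrent[-1])
```

Both forms give the same outputs up to floating-point error, but the recurrent one never revisits earlier keys or values: its state `S` and `z` has the same size at every step.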

## Time measurements

After validating that our models work and generate proper images, let's do some time measurements. We measure how long each method takes to generate 100 images (linear takes about 30 seconds and softmax about 200, so please be patient).


In [None]:
import time

start = time.time()
images_linear = predict_with_recurrent(linear, torch.zeros(100, 783, dtype=torch.int64), 1)
end = time.time()
print("Linear took", round(end-start, 2), "s")

start = time.time()
images_full = predict_with_recurrent(full, torch.zeros(100, 783, dtype=torch.int64), 1)
end = time.time()
print("Stateful-softmax took", round(end-start, 2), "s")
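
A single `time.time()` measurement can be noisy on a shared CPU. As a minimal sketch of a slightly more robust alternative, the hypothetical helper `best_of` below repeats a call and keeps the shortest duration, using `time.perf_counter()` from the standard library. The workload here is a stand-in, not one of the models.

```python
import time


def best_of(fn, repeats=3):
    """Run fn() several times and return the shortest wall-clock duration in seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return min(durations)


# Hypothetical stand-in workload; in the notebook this would be a call such as
# lambda: predict_with_recurrent(linear, images, 1).
elapsed = best_of(lambda: sum(i * i for i in range(100_000)))
print("took", round(elapsed, 4), "s")
```

Given the runtimes quoted above, repeating the full generation several times is expensive, so a single run as in the cell above is a reasonable trade-off for this demo.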

Note that all of these computations ran on the Colab CPU. Let's run the same experiment on the GPU instead.

Our linear model uses constant memory throughout the prediction, whereas the stateful softmax model's cache of keys and values grows with every generated pixel. As a result, the speedup only grows as we increase the sequence length or the batch size.

In [None]:
# Transfer the models to the GPU
linear.cuda()
full.cuda()

# Do some warmup before taking the measurements
predict_with_recurrent(linear, torch.zeros(1, 783, dtype=torch.int64, device="cuda"), 1)
predict_with_recurrent(full, torch.zeros(1, 783, dtype=torch.int64, device="cuda"), 1)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
images_linear = predict_with_recurrent(linear, torch.zeros(500, 783, dtype=torch.int64, device="cuda"), 1)
end.record()
torch.cuda.synchronize()
del images_linear
print("Linear took", round(start.elapsed_time(end)/1000, 2), "s")

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
images_full = predict_with_recurrent(full, torch.zeros(500, 783, dtype=torch.int64, device="cuda"), 1)
end.record()
torch.cuda.synchronize()
del images_full
print("Stateful-softmax took", round(start.elapsed_time(end)/1000, 2), "s")
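
To see why the gap widens with sequence length, a back-of-the-envelope count of the per-sample recurrent state is helpful. The sketch below assumes the stateful softmax model keeps one key and one value vector per position, head, and layer, and that the linear model keeps one matrix plus one normalizer per head and layer. It ignores activations and any library-specific overhead, so treat the numbers as estimates, not measurements. Both function names are hypothetical.

```python
def softmax_state_floats(step, n_layers, n_heads, d_head):
    """Cached keys and values for every position seen so far."""
    return 2 * n_layers * n_heads * step * d_head


def linear_state_floats(n_layers, n_heads, d_head):
    """One d_head x d_head matrix plus one d_head normalizer per head."""
    return n_layers * n_heads * (d_head * d_head + d_head)


# Settings of the pretrained models above: 8 layers, 8 heads, 32 dims per head.
for step in (1, 100, 784):
    print(step,
          softmax_state_floats(step, 8, 8, 32),
          linear_state_floats(8, 8, 32))
```

Under these assumptions the softmax cache is smaller for the first few pixels, overtakes the linear state from step 17 onward, and is about 47.5 times larger by the last pixel of a 28 x 28 image.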

# Micro-benchmark

Let's finish up this demonstration with a micro-benchmark for the training time of the transformers.

First, we create some helper functions for the benchmark. `bench_one()` times a single forward/backward pass after one warmup pass, and `bench()` repeats this for a variety of sequence lengths and batch sizes, reporting the time in seconds per sample.

In [None]:
def bench_one(model, x):
  # warmup the caches
  y = model(x)
  y.sum().backward()

  start = torch.cuda.Event(enable_timing=True)
  end = torch.cuda.Event(enable_timing=True)
  start.record()
  y = model(x)
  y.sum().backward()
  end.record()
  torch.cuda.synchronize()

  return start.elapsed_time(end)/1000

def bench(model, batches, sequence_lengths):
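  times = []  # named `times` to avoid shadowing the `time` module
  for b, s in zip(batches, sequence_lengths):
    x = torch.rand(b, s, 768).cuda()
    # Divide by the batch size to report seconds per sample
    times.append(bench_one(model, x)/b)
  return times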

Now, we can create the transformers to be tested.

In [None]:
builder = TransformerEncoderBuilder.from_kwargs(
    n_layers=1,
    n_heads=12,
    query_dimensions=64,
    value_dimensions=64,
    feed_forward_dimensions=768
)

builder.attention_type = "full"
full = builder.get().cuda()
builder.attention_type = "linear"
linear = builder.get().cuda()

Finally, we select batch sizes that avoid GPU out-of-memory errors and measure the time for a forward/backward pass for increasing sequence lengths.

From the plot, we observe that linear attention indeed scales linearly with the sequence length, while full softmax attention scales quadratically.

In [None]:
sequence_lengths = [2**i for i in range(7, 13)]
linear_batches = [1000, 500, 250, 125, 62, 31]
softmax_batches = [1000, 300, 100, 25, 5, 1]

linear_time = bench(linear, linear_batches, sequence_lengths)
softmax_time = bench(full, softmax_batches, sequence_lengths)

plt.plot(sequence_lengths, softmax_time, label="softmax")
plt.plot(sequence_lengths, linear_time, label="linear")
plt.xlabel("Sequence Length")
plt.ylabel("Seconds per sample")
plt.legend()
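
One way to check the scaling claim numerically, instead of judging it by eye, is to fit the slope of the timings on log-log axes: a slope near 1 indicates linear scaling and a slope near 2 indicates quadratic scaling. The sketch below uses only the standard library. `loglog_slope` is a hypothetical helper, and the timings are synthetic values, not measurements.

```python
import math


def loglog_slope(xs, ys):
    """Least-squares slope of log(y) against log(x)."""
    lx = [math.log(x) for x in xs]
    ly = [math.log(y) for y in ys]
    mean_x = sum(lx) / len(lx)
    mean_y = sum(ly) / len(ly)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(lx, ly))
    var = sum((a - mean_x) ** 2 for a in lx)
    return cov / var


# Synthetic timings, NOT measurements: one exactly linear, one exactly quadratic.
lengths = [2 ** i for i in range(7, 13)]
fake_linear = [1e-6 * n for n in lengths]
fake_quadratic = [1e-9 * n * n for n in lengths]
print(round(loglog_slope(lengths, fake_linear), 3))
print(round(loglog_slope(lengths, fake_quadratic), 3))
```

To apply it to the benchmark, pass `sequence_lengths` together with `linear_time` or `softmax_time`. Measured slopes for softmax attention will likely fall below 2 at short sequence lengths, because the feed-forward part of each layer scales linearly.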