`triton_call` cannot currently be vmapped, so we must write a multi-headed causal
self-attention kernel instead.

We can probably write a faster version of `FSingleHeadCausalSelfAttention.get_qKV` using
a Triton kernel:
  - Start from the Triton linear layer implementation at https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
  - Use Triton to write back to the cache, rather than `dynamic_update_slice`
  - Fuse the layer norm into the linear layer: compute the layer norm statistics in a single pass using
  https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/layer_norm.py, then
  pass the mean and variance into the linear layer implementation.

# Post-attention Kernel

Can we speed up the post-attention operations using a Triton kernel? Specifically the lines:
```
x = x + y
y = nn.LayerNorm()(x)
y = nn.Dense(features=4 * C)(y)
y = GELU(y)
y = nn.Dense(features=C)(y)
x = x + y
```

Compare possible implementations:
- Baseline using jitted jax/flax
- Fuse the layer norm, dense and GELU
- Fuse the entire block into a single kernel, using `atomic_add` to parallelize the second
  dense layer
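
One way to read the third option, sketched in plain Python rather than Triton: each program owns a chunk of the hidden vector (the reduction dimension of the second dense layer), multiplies it by the matching rows of the second weight matrix, and adds its partial result into a shared output buffer. The names `h`, `W2`, `b2` and `CHUNK` are hypothetical, and the split along the reduction dimension is an assumption about how `atomic_add` would be used here.

```python
import random

C, HIDDEN, CHUNK = 4, 16, 4  # toy sizes; HIDDEN plays the role of 4 * C

random.seed(0)
h = [random.gauss(0.0, 1.0) for _ in range(HIDDEN)]  # output of the GELU
W2 = [[random.gauss(0.0, 1.0) for _ in range(C)] for _ in range(HIDDEN)]  # second dense weights
b2 = [random.gauss(0.0, 1.0) for _ in range(C)]  # second dense bias

# Direct product: out[j] = b2[j] + sum_k h[k] * W2[k][j]
direct = [b2[j] + sum(h[k] * W2[k][j] for k in range(HIDDEN)) for j in range(C)]

# Emulated kernel: the output buffer starts as the bias and each "program"
# adds the contribution of its own chunk of h.
out = list(b2)
for pid in range(HIDDEN // CHUNK):
    rows = range(pid * CHUNK, (pid + 1) * CHUNK)
    partial = [sum(h[k] * W2[k][j] for k in rows) for j in range(C)]
    for j in range(C):
        out[j] += partial[j]  # on the GPU this accumulation would be an atomic add

max_err = max(abs(a - d) for a, d in zip(out, direct))
print(max_err)
```

Because several programs write to the same output elements, the additions must be atomic on the GPU, and the order in which they land is not fixed, so results can differ at the level of floating-point round-off between runs.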

## Setup

In [2]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import jax_triton as jt
import triton
import triton.language as tl


In [3]:
from nimblegpt import get_config_for, param_shapes

In [4]:
config = get_config_for('gpt2')
config.n_embd = 1024 # use a power-of-2 embedding to simplify the kernels
rng = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(rng)
x = jax.random.normal(k1, (config.n_embd,))
y = jax.random.normal(k2, (config.n_embd,))

## Overhead test - flax

In [5]:
class FlaxOverhead(nn.Module):
    
    @nn.compact
    def __call__(self, x, y):
        C, = x.shape

        x = x + y
        x = nn.Dense(features=4 * C)(x)
        # x = nn.Dense(features=C)(x)
        
        return x

In [6]:
params = FlaxOverhead().init(k1, x, y)

ohd_apply = jax.jit(FlaxOverhead().apply)
ohd_apply(params, x, y);

In [7]:
%%timeit

ohd_apply(params, x, y).block_until_ready();

59.4 µs ± 4.35 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## Reference Impl

In [8]:
def GELU(x):
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))

class Reference(nn.Module):
    
    @nn.compact
    def __call__(self, x, y):
        C, = x.shape

        x = x + y
        y = nn.LayerNorm()(x)
        y = nn.Dense(features=4 * C)(y)
        y = GELU(y)
        y = nn.Dense(features=C)(y)
        x = x + y

        return x

In [9]:
params = Reference().init(rng, x, y)

ref_apply = jax.jit(Reference().apply)
ref_apply(params, x, y);

In [10]:
%%timeit

ref_apply(params, x, y).block_until_ready();

221 µs ± 16 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Fused Layer Norm, Dense and GELU

### Layer Norm Trick

The fused dense layer computes a single element $y = \bm{y}_i$ of the output by taking a dot product
with the $i$th column $\bm{w} = \bm{W}_{:, i}$ of the weights matrix and adding a bias term $b = \bm{b}_i$. The full expression for $y$, including layer norm and activation, is:

$$
\begin{align}
    y = \text{GELU} \left( \bm{w} \cdot \left( \frac{\bm{x} - \text{mean}(\bm{x})\bm{I}_{n}}{\sqrt{\text{var}(\bm{x}) + \epsilon}} \right) + b \right)
\end{align}
$$

Where $n$ is the number of elements in $\bm{x}$. We use $\bm{I}_{n}$ to denote the length-$n$ vector of ones, the mathematical equivalent of broadcasting.

We would like to compute this incrementally, which means we won't have access to the mean
and variance of $\bm{x}$ until the end. We note that:

$$
\begin{align}
    \bm{w} \cdot \left( \frac{\bm{x} - \text{mean}(\bm{x})\bm{I}_n}{\sqrt{\text{var}(\bm{x}) + \epsilon}} \right) &= \frac{1}{\sqrt{\text{var}(\bm{x}) + \epsilon}} \left( \bm{w} \cdot \bm{x} - \bm{w} \cdot \text{mean}(\bm{x}) \bm{I}_n\right) \\
\end{align}
$$

Observe that $\bm{w} \cdot \text{mean}(\bm{x}) \bm{I}_n = \text{mean}(\bm{x}) \times \text{sum}(\bm{w})$, since the dot product of $\bm{w}$ with a vector of ones is the sum of its elements. Note also that $\text{var}(\bm{x}) = \text{mean}(\bm{x}^2) - \text{mean}(\bm{x})^2$. We can therefore compute the fused layer from the following pieces, all of which can be computed incrementally by streaming $\bm{w}$ and $\bm{x}$:

$$
\begin{align}
    \bm{w} \cdot \bm{x} && \text{sum}(\bm{w}) && \text{mean}(\bm{x}) && \text{mean}(\bm{x}^2) \\
\end{align}
$$

The equation for the output in terms of these pieces is:

$$
\begin{align}
    y = \text{GELU} \left( \frac{\bm{w} \cdot \bm{x} - \text{mean}(\bm{x}) \times \text{sum}(\bm{w})}{\sqrt{\text{mean}(\bm{x}^2) - \text{mean}(\bm{x})^2 + \epsilon}} + b \right)
\end{align}
$$
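
As a sanity check, the streaming form can be compared numerically with the usual two-pass computation. This is a minimal sketch in plain Python, not Triton: the helper names `reference` and `fused_stream` are hypothetical, the value of `EPS` is an assumption, and the layer norm is taken without a learned scale or offset. The mean term subtracted here is `mean(x) * sum(w)`.

```python
import math
import random

EPS = 1e-6  # assumed epsilon; any small positive value works for this check


def gelu(z):
    """Tanh approximation of GELU, as in the reference implementation."""
    return 0.5 * z * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (z + 0.044715 * z**3)))


def reference(w, x, b):
    """Two-pass version: normalise x first, then take the dot product."""
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n
    x_hat = [(xi - mean) / math.sqrt(var + EPS) for xi in x]
    return gelu(sum(wi * xi for wi, xi in zip(w, x_hat)) + b)


def fused_stream(w, x, b):
    """Single pass over w and x, accumulating the four running sums."""
    n = len(x)
    dot = sum_w = sum_x = sum_x2 = 0.0
    for wi, xi in zip(w, x):
        dot += wi * xi
        sum_w += wi
        sum_x += xi
        sum_x2 += xi * xi
    mean = sum_x / n
    var = sum_x2 / n - mean * mean
    return gelu((dot - mean * sum_w) / math.sqrt(var + EPS) + b)


random.seed(0)
w = [random.gauss(0.0, 1.0) for _ in range(64)]
x = [random.gauss(0.0, 1.0) for _ in range(64)]
print(abs(fused_stream(w, x, 0.1) - reference(w, x, 0.1)))
```

One caveat with this form: computing the variance as $\text{mean}(\bm{x}^2) - \text{mean}(\bm{x})^2$ can lose precision when the mean is large relative to the spread of $\bm{x}$, which matters more in float32 or float16 than in the double precision used here.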



In [11]:
import math

sqrt2pi = math.sqrt(2.0 / math.pi)

@triton.jit
def tanh(x):
    """Tanh activation function"""
    return tl.libdevice.tanh(x)

@triton.jit
def fast_gelu(x):
    """Fast approximation of the gelu function. May slightly decrease accuracy."""
    return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x)))
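
To get a feel for how much accuracy the tanh approximation gives up, it can be compared with the exact erf-based GELU on the CPU. This is a small plain-Python sketch; the function names `gelu_exact` and `gelu_tanh` and the range [-6, 6] are arbitrary choices made for illustration.

```python
import math


def gelu_exact(z):
    """Exact GELU: z times the standard normal CDF."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


def gelu_tanh(z):
    """Tanh approximation, matching the formula used in fast_gelu."""
    return 0.5 * z * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (z + 0.044715 * z**3)))


zs = [i / 100.0 for i in range(-600, 601)]
max_abs_err = max(abs(gelu_tanh(z) - gelu_exact(z)) for z in zs)
print(f"max |tanh - erf| on [-6, 6]: {max_abs_err:.2e}")
```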

In [12]:
config.n_embd

1024

In [33]:
@triton.jit
def fused_dense_kernel(x_ptr, W_ptr, b_ptr, out_ptr,
                       SUBEMBD_SIZE: tl.constexpr, BLOCK_ROWS: tl.constexpr,
                       N_EMBD: tl.constexpr):
    """
    Triton kernel implementing fused layer norm, dense layer and GELU activation.

    Kernel cell i computes out[i * SUBEMBD_SIZE: (i+1) * SUBEMBD_SIZE]. The multiplication
    is done iteratively on blocks of `BLOCK_ROWS` rows of W. No masking is applied, so
    N_EMBD must be a multiple of BLOCK_ROWS and 4 * N_EMBD a multiple of SUBEMBD_SIZE.

    Inputs
    ------
    x_ptr: [N_EMBD,] - current token embedding
    W_ptr: [N_EMBD, 4 * N_EMBD] - dense layer weights
    b_ptr: [4 * N_EMBD,] - dense layer bias

    Outputs
    -------
    out_ptr: [4 * N_EMBD,] - output embedding
    """
    subembd_start = tl.program_id(0) * SUBEMBD_SIZE

    x_idxs = tl.arange(0, N_EMBD)
    subembd_idxs = tl.arange(0, SUBEMBD_SIZE) + subembd_start

    # Layer norm statistics over the whole of x (var = mean(x^2) - mean(x)^2).
    # The learned scale and offset are not applied, which matches the reference
    # only while they are at their initial values of 1 and 0.
    x = tl.load(x_ptr + x_idxs)
    x_mean = tl.sum(x, axis=0) / N_EMBD
    x_mean_sq = tl.sum(x * x, axis=0) / N_EMBD
    # 1e-6 is the default epsilon of flax's nn.LayerNorm.
    x_rstd = 1.0 / tl.sqrt(x_mean_sq - x_mean * x_mean + 1e-6)

    # Initialise the accumulator with the dense layer bias.
    acc = tl.load(b_ptr + subembd_idxs)

    # W has N_EMBD rows, so the blocks run over N_EMBD (not 4 * N_EMBD).
    n_blocks = tl.cdiv(N_EMBD, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS
        Wi_idxs = row_idxs[:, None] * 4 * N_EMBD + subembd_idxs[None, :]

        # Triton tensors cannot be sliced, so reload this block of x and normalise it.
        xi = (tl.load(x_ptr + row_idxs) - x_mean) * x_rstd
        Wi = tl.load(W_ptr + Wi_idxs)
        acc += tl.sum(xi[:, None] * Wi, axis=0)

    acc = fast_gelu(acc)

    tl.store(out_ptr + subembd_idxs, acc)

In [34]:
def fused_dense(x, W, b, SUBEMBD_SIZE, BLOCK_ROWS):

    N_EMBD = x.shape[0]

    out_shape = jax.ShapeDtypeStruct((4 * N_EMBD, ), x.dtype)
    # One program per SUBEMBD_SIZE-wide slice of the 4 * N_EMBD outputs.
    grid = (4 * N_EMBD // SUBEMBD_SIZE, )

    return jt.triton_call(x,
                          W,
                          b,
                          kernel=fused_dense_kernel,
                          out_shape=out_shape,
                          grid=grid,
                          SUBEMBD_SIZE=SUBEMBD_SIZE,
                          BLOCK_ROWS=BLOCK_ROWS,
                          N_EMBD=N_EMBD)

In [35]:
class FusedDense(nn.Module):

    SUBEMBD_SIZE: int
    BLOCK_ROWS: int

    @nn.compact
    def __call__(self, x, y, W, b):

        C, = x.shape

        x = x + y
        y = fused_dense(x, W, b,
                        self.SUBEMBD_SIZE, self.BLOCK_ROWS)
        # Named "Dense_1" so that it picks up the second dense layer of the
        # `Reference` params rather than the first.
        y = nn.Dense(features=C, name="Dense_1")(y)
        x = x + y

        return x

In [38]:
FD = FusedDense(SUBEMBD_SIZE = 32, BLOCK_ROWS = 32)
W = params["params"]["Dense_0"]["kernel"]
b = params["params"]["Dense_0"]["bias"]

out = FD.apply(params, x, y, W, b)
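
The index arithmetic such a kernel needs can be checked on the CPU without Triton. The sketch below is plain Python with toy sizes; `emulate_program` is a hypothetical helper, `EPS` is an assumed epsilon, and `W_flat` is stored as a flat row-major list so that the indexing mirrors pointer arithmetic. The grid has to cover all `4 * N_EMBD` outputs, while the block loop runs over the `N_EMBD` rows of `W`.

```python
import math
import random

N_EMBD, SUBEMBD_SIZE, BLOCK_ROWS = 8, 4, 2  # toy sizes; the notebook uses 1024, 32, 32
N_OUT = 4 * N_EMBD
EPS = 1e-6  # assumed layer norm epsilon


def gelu(z):
    return 0.5 * z * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (z + 0.044715 * z**3)))


def emulate_program(pid, x, W_flat, b, out):
    """Mimic one kernel program: fill out[pid * SUBEMBD_SIZE : (pid + 1) * SUBEMBD_SIZE]."""
    cols = [pid * SUBEMBD_SIZE + j for j in range(SUBEMBD_SIZE)]
    mean = sum(x) / N_EMBD
    var = sum(v * v for v in x) / N_EMBD - mean * mean
    rstd = 1.0 / math.sqrt(var + EPS)

    acc = [b[c] for c in cols]  # accumulator starts at the bias
    for block_i in range(N_EMBD // BLOCK_ROWS):  # blocks run over the rows of W
        for r in range(block_i * BLOCK_ROWS, (block_i + 1) * BLOCK_ROWS):
            x_hat = (x[r] - mean) * rstd
            for j, c in enumerate(cols):
                acc[j] += x_hat * W_flat[r * N_OUT + c]  # row-major [N_EMBD, 4 * N_EMBD]
    for j, c in enumerate(cols):
        out[c] = gelu(acc[j])


random.seed(0)
x = [random.gauss(0.0, 1.0) for _ in range(N_EMBD)]
W_flat = [random.gauss(0.0, 1.0) for _ in range(N_EMBD * N_OUT)]
b = [random.gauss(0.0, 1.0) for _ in range(N_OUT)]

out = [None] * N_OUT
for pid in range(N_OUT // SUBEMBD_SIZE):  # one program per slice of the output
    emulate_program(pid, x, W_flat, b, out)

# Direct computation for comparison.
mean = sum(x) / N_EMBD
var = sum((v - mean) ** 2 for v in x) / N_EMBD
x_hat = [(v - mean) / math.sqrt(var + EPS) for v in x]
expected = [gelu(b[c] + sum(x_hat[r] * W_flat[r * N_OUT + c] for r in range(N_EMBD)))
            for c in range(N_OUT)]
max_err = max(abs(o - e) for o, e in zip(out, expected))
print(max_err)
```

If the grid covered only `N_EMBD // SUBEMBD_SIZE` programs, three quarters of `out` would stay unwritten, which the `None` initialisation makes easy to spot.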

# Layer norm and linear layer

We implement a fused layer norm + linear layer (the embedding to q, k, v projection) using Triton. Implement:

- baseline