# Lola (Layer nOrm + Linear + Activation) Fusion

As part of [my ongoing work on a faster version of GPT-2 written in JAX](https://github.com/tristanheywood/nimbleGPT), I have been exploring the use of [OpenAI's Triton](https://github.com/openai/triton), along with [jax-triton](https://github.com/jax-ml/jax-triton) - a library for using Triton with JAX. Triton lets you write GPU kernels in Python, combining the performance of custom CUDA kernels with the ease of development of Python.



## Kernel Fusion

What's so great about 'custom CUDA kernels'? No matter which Python framework you're using for ML, your model implementation will consist of a series of 'operations', such as matrix multiplications, convolutions, activation functions and normalizations. For an operation like `torch.matmul`, PyTorch includes a CUDA kernel which is invoked to perform the operation. When executing a model, the model's weights and intermediate activations are stored in GPU DRAM. Executing an operation involves copying the required data from DRAM to GPU registers, performing the operation, and copying the results back to DRAM.

Unfortunately, recent advances in GPU architectures have seen computation speeds increase significantly, outpacing memory bandwidth. A significant fraction of the time spent executing a model is spent waiting for data to be copied from DRAM to GPU registers. The solution is to merge multiple operations into a single CUDA kernel, so that less data must be transferred for the same computation. This process is known as **Kernel Fusion** (also known as **Operator Fusion**), and is one of the primary use cases for Triton.

<img src="https://pbs.twimg.com/media/FjzAsy-UYAAeptC?format=jpg&name=large" height="300px">

<img src="https://pbs.twimg.com/media/FjzAvZBVsAAXiQP?format=jpg&name=large" height="300px">

(from https://twitter.com/cHHillee)

## Lola Operation Sequence

To speed up GPT-2, I have been exploring the use of Kernel Fusion using Triton. In this post I present a custom Triton kernel I developed to fuse together the sequence: Layer norm => linear layer => GELU activation, a sequence which appears just after each self-attention layer in GPT-2. My custom kernel behaves the same as the Flax module:

In [1]:
import flax.linen as nn
import jax


class FlaxLola(nn.Module):
    @nn.compact
    def __call__(self, x):
        *_, C = x.shape

        x = nn.LayerNorm(use_scale=False, use_bias=False)(x)
        x = nn.Dense(features=4 * C)(x)
        x = nn.activation.gelu(x)

        return x


flax_lola_apply = jax.jit(FlaxLola().apply)

Or alternatively the torch function:

In [2]:
import torch.nn.functional as F


def torch_lola(x, weights, bias):
    *_, C = x.shape

    x = F.layer_norm(x, (C, ))
    x = F.linear(x, weights, bias)
    x = F.gelu(x)

    return x
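To put rough numbers on what fusing this sequence could save, the sketch below counts the bytes moved to and from DRAM by three separate kernels versus one fused kernel. It is an idealized model under stated assumptions: each kernel reads every input and writes every output exactly once, and caches and tile re-reads are ignored. The function name `lola_dram_traffic` is hypothetical and used only for illustration.

```python
def lola_dram_traffic(n_batch, n_embd, bytes_per_elem=4):
    """Idealized DRAM traffic (in bytes) for unfused vs fused Lola.

    Assumption: each kernel reads each input and writes each output once.
    """
    n_out = 4 * n_embd
    x = n_batch * n_embd  # input embeddings
    w = n_embd * n_out    # linear layer weights
    b = n_out             # linear layer bias
    y = n_batch * n_out   # output activations

    # Unfused: three kernels, each round-tripping through DRAM.
    layer_norm = x + x      # read x, write normalized x
    linear = x + w + b + y  # read normalized x, W and b; write y
    gelu = y + y            # read y, write activated y
    unfused = layer_norm + linear + gelu

    # Fused: read the inputs once, write the final output once.
    fused = x + w + b + y

    return unfused * bytes_per_elem, fused * bytes_per_elem


unfused, fused = lola_dram_traffic(n_batch=512, n_embd=1024)
print(f"unfused: {unfused / 1e6:.1f} MB, fused: {fused / 1e6:.1f} MB")
```

A real fused kernel that works on tiles will re-read parts of `x` and `W`, so the actual saving depends on the tiling and on cache behaviour.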



# GPU Memory Hierarchy

The linear layer and activation function in Lola can be fused together pretty easily: we just need to write a matrix multiplication kernel and, just before we write the results back to DRAM, apply the activation function to each element. The layer norm is more difficult to fuse. The issue is that computing the mean and variance of each embedding in the batch requires reading in that full embedding from DRAM. Unfortunately, the architecture of GPUs makes this impractical.

## GPU Architecture

Analogous to how CPUs have multiple cores, GPUs have multiple 'Streaming Multiprocessors' (SMs). The A100 GPU for example, has 108 SMs. Executing a kernel on a GPU typically involves multiple SMs executing the same kernel simultaneously, but on different data. For example, you could write a matrix multiplication kernel where each SM was responsible for a single element of the output matrix. Each SM would therefore be running the same code, but on a unique row and column combination.

Each SM has its own dedicated 'on-chip' Shared Memory (SMEM), which is implemented as SRAM. When an SM executes a kernel, it first loads the required data from DRAM into its SRAM, then performs the computation. The amount of SRAM per SM is typically very small compared to the GPU's DRAM. The A100 has 40 GB of DRAM but only 192 KB of SRAM per SM. As a result, it is typically optimal for compute kernels to stream data from DRAM to SRAM as they need it, rather than loading the entire input into SRAM at the start of the kernel. For example, if an SM were computing the dot-product of two vectors (e.g. a row and a column of a matrix - as is required during matrix multiplication), the SM might load blocks of 128 elements at a time from DRAM into SRAM, perform the dot-product on those 128 elements, and then load the next 128 elements.
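The blocked dot-product described above can be sketched in plain Python. This is only an emulation: list slicing stands in for copying a block from DRAM into SRAM, and the name `streaming_dot` is hypothetical.

```python
import random


def streaming_dot(x, w, block_len=128):
    """Dot product computed one block at a time, keeping only a running sum."""
    assert len(x) == len(w)
    acc = 0.0
    for start in range(0, len(x), block_len):
        # Stand-in for copying one block from DRAM into on-chip memory.
        x_block = x[start:start + block_len]
        w_block = w[start:start + block_len]
        acc += sum(xi * wi for xi, wi in zip(x_block, w_block))
    return acc


random.seed(0)
x = [random.gauss(0.0, 1.0) for _ in range(1024)]
w = [random.gauss(0.0, 1.0) for _ in range(1024)]

full = sum(xi * wi for xi, wi in zip(x, w))
# The two results differ only by floating-point rounding.
print(abs(streaming_dot(x, w) - full))
```

At any point the loop holds only one block of each vector plus the accumulator, which is what lets the working set fit in a small on-chip memory.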

## Streaming the Layer Norm

Computing the layer norm using a single streaming pass over the input is impossible. We don't know the final results for each element until we know the mean and variance, and we don't know these until we've seen the entire input.

I have discovered however, that it **is** possible to compute a combined layer norm and matrix multiplication, in a single streaming pass! We can derive this result by algebraic manipulation of the Lola operation sequence.

### Re-arranging the Lola Operation Sequence

Let's examine the process of computing a single element of the Lola output. In reality, Lola operates on a batch $\bm{X}$ of input embeddings. Consider a single embedding $\bm{x} = \bm{X}_{i, :}$, in the batch, and a single element $y = \bm{y}_j$ of the output vector $\bm{y}$ for this embedding. $y$ is calculated by taking the dot product of $\bm{x}$ with the $j^{\text{th}}$ column $\bm{w} = \bm{W}_{:, j}$ of the weights matrix and adding a bias term $b = \bm{b}_j$. The full expression for $y$, including layer norm and activation, is:

$$
\begin{align}
    y = \text{GELU} \left( \bm{w} \cdot \left( \frac{\bm{x} - \text{mean}(\bm{x})\bm{I}_{n}}{\sqrt{\text{var}(\bm{x}) + \epsilon}} \right) + b \right)
\end{align}
$$

Where $n$ is the number of elements in $\bm{x}$. We use multiplication by $\bm{I}_{n}$ as the mathematical equivalent of broadcasting.

To compute this output by streaming, we need to find a way to express this computation that allows us to perform it by iterating over $\bm{x}$ and $\bm{w}$. This is not possible with the current expression, because we won't know $\text{mean}(\bm{x})$ until we have iterated over all of $\bm{x}$.

To this end, observe that:

$$
\begin{align}
    \bm{w} \cdot \left( \frac{\bm{x} - \text{mean}(\bm{x})\bm{I}_n}{\sqrt{\text{var}(\bm{x}) + \epsilon}} \right) &= \frac{1}{\sqrt{\text{var}(\bm{x}) + \epsilon}} \left( \bm{w} \cdot \bm{x} - \bm{w} \cdot \text{mean}(\bm{x}) \bm{I}_n\right) \\
\end{align}
$$

Now by the definition of the dot product: $\bm{w} \cdot \text{mean}(\bm{x}) \bm{I}_n = \sum_k w_k \cdot \text{mean}(\bm{x}) = \text{mean}(\bm{x}) \times \text{sum}(\bm{w})$. Note also that $\text{var}(\bm{x}) = \text{mean}(\bm{x}^2) - \text{mean}(\bm{x})^2$. We can therefore compute the required output from the following pieces, all of which can be computed incrementally by streaming $\bm{w}$ and $\bm{x}$!

$$
\begin{align}
    \bm{w} \cdot \bm{x} && \text{sum}(\bm{w}) && \text{mean}(\bm{x}) && \text{mean}(\bm{x}^2) \\
\end{align}
$$

The equation for the output in terms of these pieces is:

$$
\begin{align}
    y = \text{GELU} \left( \frac{\bm{w} \cdot \bm{x} - \text{mean}(\bm{x}) \times \text{sum}(\bm{w})}{\sqrt{\text{mean}(\bm{x}^2) - \text{mean}(\bm{x})^2 + \epsilon}} + b \right)
\end{align}
$$
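As a minimal sketch of what "streaming" means for a single output element, the pure-Python code below accumulates the four pieces in one pass over blocks of $\bm{x}$ and $\bm{w}$, and compares the result with a two-pass reference that normalizes $\bm{x}$ first. The function names, the block length and the random test data are assumptions made for illustration.

```python
import math
import random


def gelu(v):
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def lola_element_two_pass(x, w, b, eps=1e-5):
    """Reference: normalize x first (needs all of x), then take the dot product."""
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n
    x_norm = [(xi - mean) / math.sqrt(var + eps) for xi in x]
    return gelu(sum(wi * xi for wi, xi in zip(w, x_norm)) + b)


def lola_element_streaming(x, w, b, eps=1e-5, block_len=128):
    """Single pass over x and w, keeping only four running sums."""
    n = len(x)
    w_dot_x = w_sum = x_sum = x_sq_sum = 0.0
    for start in range(0, n, block_len):
        x_block = x[start:start + block_len]
        w_block = w[start:start + block_len]
        for xi, wi in zip(x_block, w_block):
            w_dot_x += wi * xi
            w_sum += wi
            x_sum += xi
            x_sq_sum += xi * xi
    mean = x_sum / n
    var = x_sq_sum / n - mean * mean
    return gelu((w_dot_x - mean * w_sum) / math.sqrt(var + eps) + b)


random.seed(0)
x = [random.gauss(0.0, 1.0) for _ in range(1024)]
w = [random.gauss(0.0, 0.03) for _ in range(1024)]

two_pass = lola_element_two_pass(x, w, b=0.1)
streamed = lola_element_streaming(x, w, b=0.1)
print(abs(two_pass - streamed))  # differs only by floating-point rounding
```

One caveat of this form is that `mean(x^2) - mean(x)^2` subtracts two similar numbers when the mean is large relative to the spread, which can lose precision in low-precision arithmetic.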



To check our working, we can implement this computation in code and compare its output to the Lola functions above:

In [3]:
import torch


def torch_lola_rearranged(x: torch.Tensor, weights: torch.Tensor,
                          bias: torch.Tensor):
    *_, C = x.shape

    w_dot_x = torch.matmul(x, weights)
    sum_w = torch.sum(weights, dim=0, keepdim=True)
    mean_x = torch.mean(x, dim=1, keepdim=True)
    mean_x_sq = torch.mean(x * x, dim=1, keepdim=True)

    numer = w_dot_x - mean_x * sum_w
    denom = torch.sqrt(mean_x_sq - mean_x * mean_x + 1e-5)

    y = F.gelu(numer / denom + bias)

    return y

In [4]:
import jax.numpy as jnp


@jax.jit
def jax_lola_rearranged(x: jax.Array, weights: jax.Array, bias: jax.Array):
    *_, C = x.shape

    w_dot_x = jnp.dot(x, weights)
    sum_w = jnp.sum(weights, axis=0, keepdims=True)
    mean_x = jnp.mean(x, axis=1, keepdims=True)
    mean_x_sq = jnp.mean(x * x, axis=1, keepdims=True)

    numer = w_dot_x - mean_x * sum_w
    denom = jnp.sqrt(mean_x_sq - mean_x * mean_x + 1e-5)

    y = nn.activation.gelu(numer / denom + bias)

    return y

In [5]:
from typing import Union

import jax
import jax.numpy as jnp
import numpy as np
import torch

### Some utilities to help compare different Lola implementations in JAX, PyTorch and Triton.


def make_tensors(n_embd: int,
                 n_batch: int,
                 jax_dtype=jnp.float32,
                 torch_dtype=torch.float32):
    """Create embedding and parameter tensors required by JAX, PyTorch and Triton.
    
    We use the same values for all frameworks so outputs can be compared."""

    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (n_batch, n_embd), dtype=jnp.float32)

    params = FlaxLola().init(key, x)

    weights = params["params"]["Dense_0"][
        "kernel"]  # Shape: [n_embd, 4 * n_embd]
    bias = params["params"]["Dense_0"]["bias"]

    tx = torch.tensor(np.array(x), device="cuda", dtype=torch_dtype)
    tweights = torch.tensor(np.array(weights),
                            device="cuda",
                            dtype=torch_dtype)
    tweights_sum = tweights.sum(dim=0)  # Shape: [4 * n_embd]
    # Shape: [4 * n_embd, n_embd]. Torch's linear layer expects the weights to be transposed.
    tweightsT = tweights.T.contiguous()
    tbias = torch.tensor(np.array(bias), device="cuda", dtype=torch_dtype)

    x = x.astype(jax_dtype)
    params = jax.tree_util.tree_map(lambda x: x.astype(jax_dtype), params)

    return (x, params), (tx, tweights, tweights_sum, tweightsT, tbias)


def max_abs_diff(t1: Union[torch.Tensor, jax.Array], t2: Union[torch.Tensor,
                                                               jax.Array]):
    """Compute the maximum absolute difference between two tensors (i.e. the l-infinity norm)"""
    def to_numpy(t):
        if isinstance(t, torch.Tensor):
            return t.cpu().numpy()
        elif isinstance(t, jnp.ndarray):
            return np.array(t)
        else:
            raise ValueError(f"Unknown tensor type: {type(t)}")

    return np.max(np.abs(to_numpy(t1) - to_numpy(t2)))

In [6]:
(x, params), (tx, tweights, tweights_sum, tweightsT,
              tbias) = make_tensors(n_embd=1024, n_batch=512)

flax_out = flax_lola_apply(params, x)
torch_out = torch_lola(tx, tweightsT, tbias)

jax_rearr_out = jax_lola_rearranged(x, params["params"]["Dense_0"]["kernel"],
                                    params["params"]["Dense_0"]["bias"])
torch_rearr_out = torch_lola_rearranged(tx, tweights, tbias)

flax_out.shape, torch_out.shape, jax_rearr_out.shape, torch_rearr_out.shape

((512, 4096), torch.Size([512, 4096]), (512, 4096), torch.Size([512, 4096]))

In [7]:
for name, (t1, t2) in [
    ("Torch <-> Flax", (torch_out, flax_out)),
    ("Torch <-> Torch (rearranged)", (torch_out, torch_rearr_out)),
    ("Torch <-> JAX (rearranged)", (torch_out, jax_rearr_out)),
    ("Flax <-> JAX (rearranged)", (flax_out, jax_rearr_out)),
]:
    print(f"{name:<30} {max_abs_diff(t1, t2):.6f}")

Torch <-> Flax                 0.001669
Torch <-> Torch (rearranged)   0.000008
Torch <-> JAX (rearranged)     0.001738
Flax <-> JAX (rearranged)      0.001493


# Iterating on the Triton Kernel

In this section, I develop a series of Triton kernels to compute the Lola expression derived above. By iteratively improving on the kernels instead of just jumping to the final kernel, I hope to demonstrate the motivations behind various aspects of the final kernel.

Fundamentally, we are writing kernels to implement the matrix multiplication shown below, between a batch of input embeddings $\bm{X}$ and a weights matrix $\bm{W}$. The diagram on the right shows the multiplication required to compute a single element of the output.





## Non-starters

The first Triton kernel we present is already quite sophisticated and does some non-obvious things. In this section, we explain why simpler solutions either wouldn't work or wouldn't perform optimally.



The most basic kernel design would just look like `torch_lola_rearranged` - i.e. load the entire batch of embeddings and the entire weights matrix and do the calculation. Such a kernel would crash immediately, because the full input and weights matrices do not fit in GPU shared memory. Even if they did fit, such a kernel would execute on a single SM only, leaving the other 107 SMs idle. This is clearly not optimal.

CUDA kernels (and therefore also Triton kernels) are written to be parallelizable. Here's some pseudocode for a kernel which computes the matrix multiplication $\bm{Y} = \bm{X} \bm{W}$:

```python
def mat_mul(X, W, Y):
    kernel_idx = get_kernel_idx()
    
    Y[kernel_idx, :] = X[kernel_idx, :] @ W
```

When executing this kernel, the GPU would launch as many instances of the kernel as there are rows of $\bm{X}$. When each individual instance calls `get_kernel_idx()`, it will get a different index, and therefore end up computing a different row of the output.
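The launch model can be emulated on the CPU to make it concrete. In the sketch below, a sequential loop stands in for the GPU running instances in parallel, and the index is passed in as an argument rather than fetched with `get_kernel_idx()`. The helper names `mat_mul_instance` and `launch` are hypothetical.

```python
import numpy as np


def mat_mul_instance(kernel_idx, X, W, Y):
    """Work done by one kernel instance: one row of the output."""
    Y[kernel_idx, :] = X[kernel_idx, :] @ W


def launch(kernel, n_instances, *args):
    """Sequential stand-in for the GPU running many instances in parallel."""
    for kernel_idx in range(n_instances):
        kernel(kernel_idx, *args)


rng = np.random.default_rng(0)
X = rng.standard_normal((8, 16))
W = rng.standard_normal((16, 32))
Y = np.zeros((8, 32))

# One instance per row of X.
launch(mat_mul_instance, X.shape[0], X, W, Y)
print(np.allclose(Y, X @ W))  # True
```

Because no instance reads anything another instance writes, the order in which they run does not matter, which is what makes the kernel parallelizable.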

Another simple idea for a Lola kernel would look similar to `mat_mul`: have each instance of the kernel compute the output for a single row of the input. This would parallelize a lot better than the first kernel, but would still overflow GPU shared memory, because each instance needs to load the full weights matrix. The solution is to stream the input embeddings and weights:

```python
def streaming_mat_mul(X, W, Y):
    kernel_idx = get_kernel_idx()

    accumulator = zeros(Y.shape[1])
    # Stream one row of W at a time: Y[k, :] = sum_i X[k, i] * W[i, :].
    for i in range(W.shape[0]):
        accumulator += X[kernel_idx, i] * W[i, :]

    Y[kernel_idx, :] = accumulator
```

This kernel would work, but would be quite slow. For one thing, each instance of the kernel must (over its lifetime) stream the entire weights matrix. A better way to parallelize this computation is to have each kernel instance handle multiple rows of the input $\bm{X}$, but only compute a subset of the corresponding columns of $\bm{Y}$. The computation performed by a single kernel instance is shown in the diagram below. The instance multiplies a subset of the rows of $\bm{X}$ with a subset of the columns of $\bm{W}$ to compute a block (a subset of the rows and of the columns) of $\bm{Y}$.



The advantage of this is that each element of the weights matrix is used multiple times - once for each row of $\bm{X}$ handled by the instance. We can therefore compute the output with fewer total memory accesses. An additional benefit is that each kernel instance is now doing matrix-matrix multiplications (as opposed to the vector-matrix multiplication in the `mat_mul` kernel). Recent GPUs have specialized hardware for matrix-matrix multiplications.
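The blocked scheme can be sketched with NumPy before moving to Triton. Here a pair of nested loops emulates a two-dimensional grid of kernel instances, and each instance streams over the shared dimension in blocks. The function names and the tiny block sizes are assumptions chosen for readability, not tuned values.

```python
import numpy as np


def blocked_mat_mul_instance(i, j, X, W, Y, n_brows, n_ocols, block_len):
    """Instance (i, j) computes one [n_brows, n_ocols] block of Y."""
    rows = slice(i * n_brows, (i + 1) * n_brows)
    cols = slice(j * n_ocols, (j + 1) * n_ocols)

    acc = np.zeros((n_brows, n_ocols))
    for k in range(0, X.shape[1], block_len):
        # Each loaded block of W is reused by all n_brows rows of X.
        acc += X[rows, k:k + block_len] @ W[k:k + block_len, cols]
    Y[rows, cols] = acc


def blocked_mat_mul(X, W, n_brows=4, n_ocols=8, block_len=4):
    n_batch, n_out = X.shape[0], W.shape[1]
    assert n_batch % n_brows == 0 and n_out % n_ocols == 0
    Y = np.zeros((n_batch, n_out))
    for i in range(n_batch // n_brows):    # first grid axis
        for j in range(n_out // n_ocols):  # second grid axis
            blocked_mat_mul_instance(i, j, X, W, Y, n_brows, n_ocols, block_len)
    return Y


rng = np.random.default_rng(0)
X = rng.standard_normal((8, 16))
W = rng.standard_normal((16, 32))
print(np.allclose(blocked_mat_mul(X, W), X @ W))  # True
```

With one row per instance, every column of `W` is streamed once per row of `X`. With `n_brows` rows per instance, it is streamed once per group of `n_brows` rows, so weight traffic drops by roughly that factor.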

## Lola Kernel v1

We first implement the GELU activation function, so it can be used from our Triton kernel.

In [8]:
import math

import triton
import triton.language as tl

sqrt2pi = math.sqrt(2.0 / math.pi)
sqrt2 = math.sqrt(2.0)


@triton.jit
def tanh(x):
    """Tanh activation function"""
    return tl.libdevice.tanh(x)


@triton.jit
def fast_gelu(x):
    """Fast approximation of the gelu function. May slightly decrease accuracy."""
    return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x)))


@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * 0.5 * (1.0 + tl.libdevice.erf(x / sqrt2))

In [9]:
print(sqrt2pi, sqrt2)

0.7978845608028654 1.4142135623730951


Triton kernels accept pointers to the input and output tensors as arguments. We can load elements from the input tensors by using `tl.load`, and write elements into the output tensor using `tl.store`.

In [10]:
@triton.jit
def kernel_lola_v1(x_ptr, W_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                   N_BROWS: tl.constexpr, BLOCK_LEN: tl.constexpr,
                   N_FEAT_IN: tl.constexpr, N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes:
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS].
    
    Iteration k of the inner loop computes:
    x[i * N_BROWS: (i+1) * N_BROWS, k * BLOCK_LEN: (k+1) * BLOCK_LEN]
        @ W[k * BLOCK_LEN: (k+1) * BLOCK_LEN, j * N_OCOLS: (j+1) * N_OCOLS]

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN] - input token embeddings.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [BATCH_SIZE, N_FEAT_OUT] - Lola output.
    
    Parameters
    ----------
    N_OCOLS - number of output columns computed per kernel instance.
    N_BROWS - number of batch elements (i.e. rows of `x`) computed per kernel instance.
    BLOCK_LEN - size of the block of `x` and `W` processed each iteration of the inner loop.
    N_FEAT_IN - number of input features.
    N_FEAT_OUT - number of output features.
    """
    # Each instance will process x[x_brow_start: x_brow_start + N_BROWS, :].
    x_brow_start = tl.program_id(0) * N_BROWS
    x_brow_idxs = tl.arange(0, N_BROWS) + x_brow_start

    # Each instance will compute out[<b-rows>, ocol_start: ocol_start + N_OCOLS].
    ocol_start = tl.program_id(1) * N_OCOLS
    ocol_idxs = tl.arange(0, N_OCOLS) + ocol_start

    # Initialize accumulators. We build up partial results while iterating over `x` and `W`.
    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=tl.float32)
    w_sum_acc = tl.zeros((N_OCOLS, ), dtype=tl.float32)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=tl.float32)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=tl.float32)

    # Iterate over N_FEAT_IN elements, in blocks of size BLOCK_LEN.
    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_LEN)
    for block_i in range(0, n_blocks):

        # Indices into the block dimension - columns of `x` and rows of `W`.
        block_idxs = tl.arange(0, BLOCK_LEN) + block_i * BLOCK_LEN

        # Load the current block of the input.
        x_block_idxs = x_brow_idxs[:, None] * N_FEAT_IN + block_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs)  # [N_BROWS, BLOCK_LEN]

        # Load the current block of the weights.
        W_block_idxs = block_idxs[:, None] * N_FEAT_OUT + ocol_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs)  # [BLOCK_LEN, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_LEN] @ [BLOCK_LEN, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block)
        w_sum_acc += tl.sum(W_block, axis=0)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

    bias = tl.load(b_ptr + ocol_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * w_sum_acc[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = gelu(numer / denom[:, None] + bias[None, :])

    out_idxs = x_brow_idxs[:, None] * N_FEAT_OUT + ocol_idxs[None, :]
    tl.store(out_ptr + out_idxs, out)

We typically write a wrapper function for each Triton kernel. The wrapper is responsible for calculating some of the constants such as `N_FEAT_IN` from the input tensors, and for deciding how many instances of the kernel to dispatch.

In [11]:
def dispatch_lola_v1(
    x,
    W,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_LEN: int,
    num_warps=4,
    num_stages=1,
):
    assert N_BROWS >= 16 and BLOCK_LEN >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    N_BATCH = x.shape[0]
    N_FEAT_IN, N_FEAT_OUT = W.shape

    # Each instance processes `N_BROWS` rows of `x` and `N_OCOLS` columns of `W`, so we need
    # `(N_BATCH // N_BROWS) * (N_FEAT_OUT // N_OCOLS)` instances.
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT), dtype=x.dtype, device="cuda")

    # Launch the kernel.
    kernel_lola_v1[grid](x,
                         W,
                         b,
                         out,
                         N_OCOLS=N_OCOLS,
                         N_BROWS=N_BROWS,
                         BLOCK_LEN=BLOCK_LEN,
                         N_FEAT_IN=N_FEAT_IN,
                         N_FEAT_OUT=N_FEAT_OUT,
                         num_warps=num_warps,
                         num_stages=num_stages)

    return out
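Note that the wrapper computes the grid with floor division and the kernel's loads are unmasked, so the block sizes need to divide the problem dimensions exactly. Otherwise trailing rows or columns of `out` would be left at zero, or a load would read past the end of a row. A small shape check along the following lines could guard against this. The helper `lola_grid` is hypothetical and not part of the code above.

```python
def lola_grid(n_batch, n_feat_in, n_feat_out, n_brows, n_ocols, block_len):
    """Return the launch grid, or raise if the blocks do not tile the problem."""
    for name, total, block in [
        ("N_BATCH / N_BROWS", n_batch, n_brows),
        ("N_FEAT_OUT / N_OCOLS", n_feat_out, n_ocols),
        ("N_FEAT_IN / BLOCK_LEN", n_feat_in, block_len),
    ]:
        if total % block != 0:
            raise ValueError(f"{name}: {total} is not a multiple of {block}")
    return (n_batch // n_brows, n_feat_out // n_ocols)


# Shapes used in this post: 512 embeddings of size 1024, 4096 output features.
print(lola_grid(512, 1024, 4096, n_brows=64, n_ocols=128, block_len=32))  # (8, 32)
```

An alternative would be to round the grid up and mask the loads and stores in the kernel, at the cost of some extra index arithmetic.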

### Output comparison

In [12]:
v1_out = dispatch_lola_v1(tx, tweights, tbias, 128, 64, 32)
v1_out.shape

torch.Size([512, 4096])

In [13]:
max_abs_diff(torch_out, v1_out)

0.003700018

### Performance tuning

Notice that our kernel has free parameters `N_OCOLS`, `N_BROWS` and `BLOCK_LEN`, which control the distribution of work across kernel instances. In general there is no principled way to choose these parameters, other than to test a bunch of different values and choose the ones which don't crash and give the fastest performance. Tuning these parameters is critical for kernel performance - **the difference between good and bad choices can be multiple orders of magnitude in speed**.

One saving grace is that block dimensions in Triton (the sizes passed to `tl.arange`) must be powers of two. This dramatically cuts down on the number of possible configurations.
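To get a feel for the size of the search space, the sketch below enumerates every power-of-two configuration between 16 and 1024 for the three parameters. It is independent of the `grid_search` helper used in the next cells, and the function names are hypothetical.

```python
from itertools import product


def powers_of_two(lo, hi):
    """All powers of two p with lo <= p <= hi."""
    p = 1
    while p < lo:
        p *= 2
    out = []
    while p <= hi:
        out.append(p)
        p *= 2
    return out


def candidate_configs(lo=16, hi=1024):
    sizes = powers_of_two(lo, hi)
    return [
        {"N_OCOLS": c, "N_BROWS": r, "BLOCK_LEN": k}
        for c, r, k in product(sizes, repeat=3)
    ]


configs = candidate_configs()
print(len(configs))  # 343, i.e. 7 choices for each of the three parameters
```

Many of these candidates would not run in practice, for example because the blocks do not fit in shared memory, so a real search also needs to catch and skip failing configurations.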

In [14]:
%load_ext autoreload
%autoreload 2

In [15]:
from functools import partial

from conch import grid_search, PTXAnalyser

In [16]:
%%script false --no-raise-error

v1_perf = grid_search(partial(dispatch_lola_v1, tx, tweights, tbias),
                      N_OCOLS=(16, 1024),
                      N_BROWS=(16, 1024),
                      BLOCK_LEN=(16, 1024),
                      min_val_prod=50_000,)
v1_perf[:5]


In [17]:
v1_pa = PTXAnalyser.FromKernel(kernel_lola_v1,
                               N_OCOLS=128,
                               N_BROWS=64,
                               BLOCK_LEN=32)

### Failed optimizations

During development of this kernel, I tried a bunch of things which ended up not working.

#### GELU vs fast GELU

The GELU activation function is defined as $\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2} \left(1 + \text{erf} \left(\frac{x}{\sqrt{2}}\right)\right)$ (where erf is [the error function](https://en.wikipedia.org/wiki/Error_function)). GELU is commonly approximated by $\text{GELU}(x) \approx 0.5 \cdot x \cdot \left(1 + \text{tanh}\left(\sqrt{\frac{2}{\pi}} \cdot \left(x + 0.044715 \cdot x^3\right)\right)\right)$. However, I found that using the approximation is actually slower than using the exact GELU function. Examining the PTX code generated for the kernels suggests that this is because the CUDA implementations of `erf` and `tanh` are similar in speed, so the extra arithmetic in the approximation makes it slower overall.
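For reference, the two formulas can be compared numerically with the standard library alone. This sketch only measures how far the tanh approximation is from the exact definition. It says nothing about speed on a GPU, and the range $[-6, 6]$ and step size are arbitrary choices.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)


def gelu_exact(x):
    """GELU via the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """Tanh-based approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x ** 3)))


xs = [i / 100.0 for i in range(-600, 601)]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
print(f"max |exact - tanh approximation| on [-6, 6]: {max_err:.2e}")
```

Since the approximation is both slightly less accurate and, per the measurements here, no faster, the exact form is the natural choice for this kernel.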

## Lola Kernel v2

Discussions on the Triton GitHub suggest that Triton kernels can perform poorly when multiple accumulators are used (see https://github.com/openai/triton/discussions/1186). We can remove one of our accumulators by computing the sum of weights ahead of time and passing it in as an argument to the kernel.

Before testing this, I expected that re-computing the weight sum in the kernel would be faster than pre-computing it and loading it in, since the re-computation has low arithmetic intensity. In practice, the pre-computation results in a significant performance increase.

Unfortunately we cannot apply the same trick to the embeddings `x`, since we are assuming that the kernel will be used in an inference setting, where the weights are fixed and the embeddings are changing.
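The idea can be sketched in NumPy, leaving out the activation for brevity: the column sums of `W` are computed once when the weights are loaded and then reused for every batch of embeddings. The function names, shapes and random data below are assumptions for illustration only.

```python
import numpy as np


def layer_norm_linear_reference(x, W, b, eps=1e-5):
    """Layer norm followed by a linear layer, computed the direct way."""
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return ((x - mean) / np.sqrt(var + eps)) @ W + b


def layer_norm_linear_presummed(x, W, w_sum, b, eps=1e-5):
    """Rearranged form; `w_sum` is passed in rather than recomputed."""
    mean = x.mean(axis=1, keepdims=True)
    mean_sq = (x * x).mean(axis=1, keepdims=True)
    numer = x @ W - mean * w_sum[None, :]
    return numer / np.sqrt(mean_sq - mean * mean + eps) + b


rng = np.random.default_rng(0)
W = rng.standard_normal((64, 256)) * 0.03
b = rng.standard_normal(256) * 0.01
w_sum = W.sum(axis=0)  # computed once, when the weights are loaded

# The same `w_sum` is reused for every new batch of embeddings.
for _ in range(3):
    x = rng.standard_normal((16, 64))
    assert np.allclose(layer_norm_linear_presummed(x, W, w_sum, b),
                       layer_norm_linear_reference(x, W, b))
```

The cost is one extra vector of `N_FEAT_OUT` values stored alongside the weights, which is negligible next to the weights matrix itself.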

In [18]:
@triton.jit
def kernel_lola_v2(x_ptr, W_ptr, Ws_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                   N_BROWS: tl.constexpr, BLOCK_LEN: tl.constexpr,
                   N_FEAT_IN: tl.constexpr, N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes:
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS].
    
    Iteration k of the inner loop computes:
    x[i * N_BROWS: (i+1) * N_BROWS, k * BLOCK_LEN: (k+1) * BLOCK_LEN]
        @ W[k * BLOCK_LEN: (k+1) * BLOCK_LEN, j * N_OCOLS: (j+1) * N_OCOLS]

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN] - input token embeddings.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    Ws_ptr: [N_FEAT_OUT,] - sum of the weights (in axis 0).
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [BATCH_SIZE, N_FEAT_OUT] - Lola output.
    
    Parameters
    ----------
    N_OCOLS - number of output columns computed per kernel instance.
    N_BROWS - number of batch elements (i.e. rows of `x`) computed per kernel instance.
    BLOCK_LEN - size of the block of `x` and `W` processed each iteration of the inner loop.
    N_FEAT_IN - number of input features.
    N_FEAT_OUT - number of output features.
    """
    # Each instance will process x[x_brow_start: x_brow_start + N_BROWS, :].
    x_brow_start = tl.program_id(0) * N_BROWS
    x_brow_idxs = tl.arange(0, N_BROWS) + x_brow_start

    # Each instance will compute out[<b-rows>, ocol_start: ocol_start + N_OCOLS].
    ocol_start = tl.program_id(1) * N_OCOLS
    ocol_idxs = tl.arange(0, N_OCOLS) + ocol_start

    # Initialize accumulators. We build up partial results while iterating over `x` and `W`.
    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=tl.float32)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=tl.float32)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=tl.float32)

    # Iterate over N_FEAT_IN elements, in blocks of size BLOCK_LEN.
    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_LEN)
    for block_i in range(0, n_blocks):

        # Indices into the block dimension - columns of `x` and rows of `W`.
        block_idxs = tl.arange(0, BLOCK_LEN) + block_i * BLOCK_LEN

        # Load the current block of the input.
        x_block_idxs = x_brow_idxs[:, None] * N_FEAT_IN + block_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs)  # [N_BROWS, BLOCK_LEN]

        # Load the current block of the weights.
        W_block_idxs = block_idxs[:, None] * N_FEAT_OUT + ocol_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs)  # [BLOCK_LEN, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_LEN] @ [BLOCK_LEN, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

    bias = tl.load(b_ptr + ocol_idxs)
    Wsum = tl.load(Ws_ptr + ocol_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * Wsum[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = gelu(numer / denom[:, None] + bias[None, :])

    out_idxs = x_brow_idxs[:, None] * N_FEAT_OUT + ocol_idxs[None, :]
    tl.store(out_ptr + out_idxs, out)

In [19]:
def dispatch_lola_v2(
    x,
    W,
    Wsum,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_LEN: int,
    num_warps=4,
    num_stages=1,
):
    assert N_BROWS >= 16 and BLOCK_LEN >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    N_BATCH = x.shape[0]
    N_FEAT_IN, N_FEAT_OUT = W.shape

    # Each instance processes `N_BROWS` rows of `x` and `N_OCOLS` columns of `W`, so we need
    # `(N_BATCH // N_BROWS) * (N_FEAT_OUT // N_OCOLS)` instances.
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT), dtype=x.dtype, device="cuda")

    # Launch the kernel.
    kernel_lola_v2[grid](x,
                         W,
                         Wsum,
                         b,
                         out,
                         N_OCOLS=N_OCOLS,
                         N_BROWS=N_BROWS,
                         BLOCK_LEN=BLOCK_LEN,
                         N_FEAT_IN=N_FEAT_IN,
                         N_FEAT_OUT=N_FEAT_OUT,
                         num_warps=num_warps,
                         num_stages=num_stages)

    return out

### Output comparison

In [20]:
v2_out = dispatch_lola_v2(tx, tweights, tweights_sum, tbias, 128, 64, 32)
v2_out.shape

torch.Size([512, 4096])

In [21]:
max_abs_diff(torch_out, v2_out)

0.003700018

### Performance tuning

In [22]:
%%script false --no-raise-error

v2_perf = grid_search(partial(dispatch_lola_v2, tx, tweights, tweights_sum,
                              tbias),
                      N_OCOLS=(16, 1024),
                      N_BROWS=(16, 1024),
                      BLOCK_LEN=(16, 1024),
                      min_val_prod=50_000,
)
v2_perf[:5]

- Fast GELU: `({'N_OCOLS': 128, 'N_BROWS': 64, 'BLOCK_LEN': 32}, 130.048006772995)`
- No activation: `({'N_OCOLS': 128, 'N_BROWS': 64, 'BLOCK_LEN': 32}, 121.85599654912949)`
- Full GELU: `({'N_OCOLS': 128, 'N_BROWS': 64, 'BLOCK_LEN': 32}, 128.00000607967377)`


## Lola Kernel v3

### Transposed Weights

In [23]:
w = torch.arange(24).reshape(4, 6)
w

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]])

In [24]:
w.T

tensor([[ 0,  6, 12, 18],
        [ 1,  7, 13, 19],
        [ 2,  8, 14, 20],
        [ 3,  9, 15, 21],
        [ 4, 10, 16, 22],
        [ 5, 11, 17, 23]])

In [25]:
# attempt to select the block [[9, 10], [15, 16], [21, 22]] in both w and w.T

block_idxs = torch.tensor([1, 2, 3])
ocol_idxs = torch.tensor([3, 4])

W_block_idxs = block_idxs[:, None] * 6 + ocol_idxs[None, :]

W_block_idxs

tensor([[ 9, 10],
        [15, 16],
        [21, 22]])

In [26]:
w.flatten()[W_block_idxs]

tensor([[ 9, 10],
        [15, 16],
        [21, 22]])

In [27]:
Wt_block_idxs = ocol_idxs[None, :] * 4 + block_idxs[:, None]
Wt_block_idxs

tensor([[13, 17],
        [14, 18],
        [15, 19]])

In [28]:
w.T.flatten()[Wt_block_idxs]

tensor([[ 9, 10],
        [15, 16],
        [21, 22]])

In [29]:
tweights.shape, tweightsT.shape

(torch.Size([1024, 4096]), torch.Size([4096, 1024]))

In [30]:
ocol_idxs = torch.arange(0, 4) + 12
block_idxs = torch.arange(0, 8) + 16
ocol_idxs, block_idxs

(tensor([12, 13, 14, 15]), tensor([16, 17, 18, 19, 20, 21, 22, 23]))

In [31]:
N_FEAT_IN, N_FEAT_OUT = tweights.shape
N_FEAT_IN, N_FEAT_OUT

(1024, 4096)

In [32]:
W_block_idxs = block_idxs[:, None] * N_FEAT_OUT + ocol_idxs[None, :]
W_block_idxs

tensor([[65548, 65549, 65550, 65551],
        [69644, 69645, 69646, 69647],
        [73740, 73741, 73742, 73743],
        [77836, 77837, 77838, 77839],
        [81932, 81933, 81934, 81935],
        [86028, 86029, 86030, 86031],
        [90124, 90125, 90126, 90127],
        [94220, 94221, 94222, 94223]])

In [33]:
tweights.flatten()[W_block_idxs]

tensor([[ 0.0194,  0.0143, -0.0078, -0.0289],
        [ 0.0435, -0.0051,  0.0195,  0.0290],
        [ 0.0318,  0.0176, -0.0008,  0.0447],
        [-0.0306,  0.0089,  0.0175,  0.0062],
        [-0.0122,  0.0100,  0.0699,  0.0270],
        [-0.0278, -0.0342,  0.0610, -0.0172],
        [ 0.0528,  0.0565, -0.0195,  0.0314],
        [-0.0176, -0.0471,  0.0369,  0.0560]], device='cuda:0')

In [34]:
Wt_block_idxs = ocol_idxs[None, :] * N_FEAT_IN + block_idxs[:, None]
Wt_block_idxs

tensor([[12304, 13328, 14352, 15376],
        [12305, 13329, 14353, 15377],
        [12306, 13330, 14354, 15378],
        [12307, 13331, 14355, 15379],
        [12308, 13332, 14356, 15380],
        [12309, 13333, 14357, 15381],
        [12310, 13334, 14358, 15382],
        [12311, 13335, 14359, 15383]])

In [35]:
tweightsT.flatten()[Wt_block_idxs]

tensor([[ 0.0194,  0.0143, -0.0078, -0.0289],
        [ 0.0435, -0.0051,  0.0195,  0.0290],
        [ 0.0318,  0.0176, -0.0008,  0.0447],
        [-0.0306,  0.0089,  0.0175,  0.0062],
        [-0.0122,  0.0100,  0.0699,  0.0270],
        [-0.0278, -0.0342,  0.0610, -0.0172],
        [ 0.0528,  0.0565, -0.0195,  0.0314],
        [-0.0176, -0.0471,  0.0369,  0.0560]], device='cuda:0')
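
The two index computations above address the same tile: element `(r, c)` of `W` sits at flat offset `r * N_FEAT_OUT + c`, while the same element of the contiguous transpose sits at `c * N_FEAT_IN + r`. Here is a minimal NumPy sketch of that equivalence. It uses small made-up shapes rather than the real `tweights`, and NumPy instead of PyTorch purely for brevity.

```python
import numpy as np

# Small stand-in shapes (assumptions for illustration, not the real 1024 x 4096 weights).
n_feat_in, n_feat_out = 6, 4
w = np.arange(n_feat_in * n_feat_out).reshape(n_feat_in, n_feat_out)
wt = np.ascontiguousarray(w.T)  # row-major copy of the transpose

block_idxs = np.arange(0, 3) + 1  # rows of W (the reduction dimension)
ocol_idxs = np.arange(0, 2) + 2   # columns of W (output features)

# Flat offsets into W: the row index is scaled by the row length of W.
w_block_idxs = block_idxs[:, None] * n_feat_out + ocol_idxs[None, :]
# Flat offsets into W.T: the column index is scaled by the row length of W.T.
wt_block_idxs = ocol_idxs[None, :] * n_feat_in + block_idxs[:, None]

tile_from_w = w.ravel()[w_block_idxs]
tile_from_wt = wt.ravel()[wt_block_idxs]
```

Both gathers return the `[3, 2]` tile `w[1:4, 2:4]`, which is why the kernel can read from the transposed weights without changing the shape of the block it multiplies.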

In [36]:
@triton.jit
def kernel_lola_v3(x_ptr, Wt_ptr, Ws_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                   N_BROWS: tl.constexpr, BLOCK_LEN: tl.constexpr,
                   N_FEAT_IN: tl.constexpr, N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel instance (i, j) computes:
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS].

    Iteration k of the inner loop computes:
    x[i * N_BROWS: (i+1) * N_BROWS, k * BLOCK_LEN: (k+1) * BLOCK_LEN]
        @ W[k * BLOCK_LEN: (k+1) * BLOCK_LEN, j * N_OCOLS: (j+1) * N_OCOLS]

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN] - input token embeddings.
    Wt_ptr: [N_FEAT_OUT, N_FEAT_IN] - transposed linear layer weights (`W.T`, row-major).
    Ws_ptr: [N_FEAT_OUT,] - sum of the weights `W` over axis 0.
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [BATCH_SIZE, N_FEAT_OUT] - Lola output.
    
    Parameters
    ----------
    N_OCOLS - number of output columns computed per kernel instance.
    N_BROWS - number of batch elements (i.e. rows of `x`) computed per kernel instance.
    BLOCK_LEN - size of the block of `x` and `W` processed each iteration of the inner loop.
    N_FEAT_IN - number of input features.
    N_FEAT_OUT - number of output features.
    """
    # Each instance will process x[x_brow_start: x_brow_start + N_BROWS, :].
    x_brow_start = tl.program_id(0) * N_BROWS
    x_brow_idxs = tl.arange(0, N_BROWS) + x_brow_start

    # Each instance will compute out[<b-rows>, ocol_start: ocol_start + N_OCOLS].
    ocol_start = tl.program_id(1) * N_OCOLS
    ocol_idxs = tl.arange(0, N_OCOLS) + ocol_start

    # Initialize accumulators. We build up partial results while iterating over `x` and `W`.
    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=tl.float32)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=tl.float32)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=tl.float32)

    # Iterate over N_FEAT_IN elements, in blocks of size BLOCK_LEN.
    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_LEN)
    for block_i in range(0, n_blocks):

        # Indices into the block dimension - columns of `x` and rows of `W`.
        block_idxs = tl.arange(0, BLOCK_LEN) + block_i * BLOCK_LEN

        # Load the current block of the input.
        x_block_idxs = x_brow_idxs[:, None] * N_FEAT_IN + block_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs)  # [N_BROWS, BLOCK_LEN]

        # Load the matching block of `W` from the transposed weights: element (r, c) of
        # `W` is at offset c * N_FEAT_IN + r in `Wt`. The untransposed equivalent would be
        # W_block_idxs = block_idxs[:, None] * N_FEAT_OUT + ocol_idxs[None, :].
        Wt_block_idxs = ocol_idxs[None, :] * N_FEAT_IN + block_idxs[:, None]
        W_block = tl.load(Wt_ptr + Wt_block_idxs)  # [BLOCK_LEN, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_LEN] @ [BLOCK_LEN, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

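    bias = tl.load(b_ptr + ocol_idxs)
    Wsum = tl.load(Ws_ptr + ocol_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    # Layer norm folded into the matmul:
    # ((x - mean) / std) @ W == (x @ W - mean * colsum(W)) / std, with var = E[x^2] - E[x]^2.
    numer = w_dot_x_acc - x_mean[:, None] * Wsum[None, :]
    # `tl.abs` keeps the argument of `sqrt` non-negative if rounding makes the variance
    # slightly negative.
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    # out = fast_gelu(numer / denom[:, None] + bias[None, :])
    out = numer / denom[:, None] + bias[None, :]
    out = gelu(out)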

    out_idxs = x_brow_idxs[:, None] * N_FEAT_OUT + ocol_idxs[None, :]
    tl.store(out_ptr + out_idxs, out)
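
The epilogue of this kernel relies on the fact that layer norm can be pulled out of the matrix product: `((x - mean) / std) @ W = (x @ W - mean * colsum(W)) / std`, with the variance computed as `E[x^2] - E[x]^2`. That is why a single pass over `x` only needs to accumulate `x @ W`, `sum(x)` and `sum(x * x)`, plus the precomputed column sums in `Ws`. Below is a minimal NumPy sketch that checks the identity on random data. It assumes layer norm without a learned scale and shift (or with them already folded into `W` and `b`), and it leaves out the GELU.

```python
import numpy as np

rng = np.random.default_rng(0)
n_batch, n_feat_in, n_feat_out = 8, 32, 16  # illustrative sizes only
x = rng.standard_normal((n_batch, n_feat_in))
w = rng.standard_normal((n_feat_in, n_feat_out))
b = rng.standard_normal(n_feat_out)
eps = 1e-5

# Reference: normalise each row first, then apply the linear layer.
mean = x.mean(axis=1, keepdims=True)
var = x.var(axis=1, keepdims=True)
reference = ((x - mean) / np.sqrt(var + eps)) @ w + b

# Fused form: accumulate x @ W, sum(x) and sum(x * x), then correct at the end.
w_dot_x = x @ w
x_mean = x.sum(axis=1) / n_feat_in
x_sq_mean = (x * x).sum(axis=1) / n_feat_in
w_sum = w.sum(axis=0)  # plays the role of the array behind Ws_ptr

numer = w_dot_x - x_mean[:, None] * w_sum[None, :]
denom = np.sqrt(x_sq_mean - x_mean * x_mean + eps)
fused = numer / denom[:, None] + b[None, :]
```

In float64 the two results agree to rounding error. In the kernel the accumulation is in float32 and the matrix product runs on tensor cores, so a larger difference against the PyTorch reference is to be expected.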

In [37]:
def dispatch_lola_v3(
    x,
    Wt,
    Wsum,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_LEN: int,
    num_warps=4,
    num_stages=1,
):
    assert N_BROWS >= 16 and BLOCK_LEN >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    assert N_BROWS >= 32, "Triton gives 'CUDA error: an illegal memory access was encountered' for N_BROWS = 16 - no idea why."

    N_BATCH = x.shape[0]
    N_FEAT_OUT, N_FEAT_IN = Wt.shape

    # Each instance processes `N_BROWS` rows of `x` and `N_OCOLS` columns of `W`, so we need
    # `(N_BATCH // N_BROWS) * (N_FEAT_OUT // N_OCOLS)` instances.
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate the output buffer on the same device as the input.
    out = torch.zeros((N_BATCH, N_FEAT_OUT), dtype=x.dtype, device=x.device)

    # Launch the kernel.
    kernel_lola_v3[grid](x,
                         Wt,
                         Wsum,
                         b,
                         out,
                         N_OCOLS=N_OCOLS,
                         N_BROWS=N_BROWS,
                         BLOCK_LEN=BLOCK_LEN,
                         N_FEAT_IN=N_FEAT_IN,
                         N_FEAT_OUT=N_FEAT_OUT,
                         num_warps=num_warps,
                         num_stages=num_stages)

    return out
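
Because the kernel loads and stores without masks and the grid is built with floor division, the dispatch function implicitly assumes that every tile size divides its dimension evenly. Here is a small sketch of a check that could run before the launch. `lola_grid` is a hypothetical helper, not part of the original code.

```python
def lola_grid(n_batch, n_feat_in, n_feat_out, n_brows, n_ocols, block_len):
    """Return the launch grid, or raise if any tile size leaves a remainder."""
    remainders = {
        "N_BATCH % N_BROWS": n_batch % n_brows,
        "N_FEAT_OUT % N_OCOLS": n_feat_out % n_ocols,
        "N_FEAT_IN % BLOCK_LEN": n_feat_in % block_len,
    }
    bad = [name for name, remainder in remainders.items() if remainder != 0]
    if bad:
        raise ValueError("tile sizes must divide the shapes: " + ", ".join(bad))
    return (n_batch // n_brows, n_feat_out // n_ocols)
```

For example, with a made-up batch of 256 rows, `lola_grid(256, 1024, 4096, 64, 128, 32)` returns `(4, 32)`, while `n_brows=1024` raises instead of silently producing an empty grid.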

### Output comparison

In [38]:
tweightsT.shape

torch.Size([4096, 1024])

In [39]:
v3_out = dispatch_lola_v3(tx, tweightsT, tweights_sum, tbias, 128, 64, 32)

In [41]:
max_abs_diff(torch_out, v3_out)

0.003700018

In [48]:
%%script false --no-raise-error

v3_perf = grid_search(partial(dispatch_lola_v3, tx, tweightsT, tweights_sum,
                              tbias),
                      N_OCOLS=(16, 1024),
                      N_BROWS=(32, 1024),
                      BLOCK_LEN=(16, 1024),
                      min_val_prod=50_000,
                      do_print=True)

n_ocols=16 n_brows=32 block_len=128 : 463.87 us
n_ocols=16 n_brows=32 block_len=256 : 595.97 us
n_ocols=16 n_brows=32 block_len=512 : 12061.18 us
out of resource: shared memory, Required: 197120, Hardware limit: 166912. Reducing block sizes or `num_stages` may help.
n_ocols=16 n_brows=64 block_len=64 : 412.67 us
n_ocols=16 n_brows=64 block_len=128 : 403.46 us
n_ocols=16 n_brows=64 block_len=256 : 773.12 us
n_ocols=16 n_brows=128 block_len=32 : 285.70 us
n_ocols=16 n_brows=128 block_len=64 : 340.99 us
n_ocols=16 n_brows=128 block_len=128 : 663.55 us
n_ocols=16 n_brows=256 block_len=16 : 368.64 us
n_ocols=16 n_brows=256 block_len=32 : 324.61 us
n_ocols=16 n_brows=256 block_len=64 : 679.94 us
n_ocols=16 n_brows=512 block_len=16 : 447.49 us
n_ocols=16 n_brows=512 block_len=32 : 872.45 us
n_ocols=16 n_brows=1024 block_len=16 : 11.26 us
n_ocols=32 n_brows=32 block_len=64 : 237.57 us
n_ocols=32 n_brows=32 block_len=128 : 264.19 us
n_ocols=32 n_brows=32 block_len=256 : 435.20 us
n_ocols=32 n_b

In [49]:
v3_perf[:5]

[({'N_OCOLS': 16, 'N_BROWS': 1024, 'BLOCK_LEN': 16}, 11.264000087976456),
 ({'N_OCOLS': 128, 'N_BROWS': 64, 'BLOCK_LEN': 32}, 122.8799968957901),
 ({'N_OCOLS': 64, 'N_BROWS': 128, 'BLOCK_LEN': 32}, 135.16800105571747),
 ({'N_OCOLS': 128, 'N_BROWS': 32, 'BLOCK_LEN': 64}, 138.2399946451187),
 ({'N_OCOLS': 128, 'N_BROWS': 64, 'BLOCK_LEN': 16}, 157.69599378108978)]
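
The top entry is roughly ten times faster than the next best, which is suspicious. A configuration whose grid is empty (for example when `N_BATCH // N_BROWS` is zero) or whose launch does no useful work could also report a very short time. Here is a minimal sketch of ranking only those configurations whose output matches the reference. `rank_valid_configs` and its arguments are hypothetical and are not the `grid_search` helper used above.

```python
def rank_valid_configs(configs, run, time_us, is_correct):
    """Time only the configs whose output passes `is_correct`, fastest first.

    configs    - iterable of dicts of meta-parameters.
    run        - callable(**config) returning the kernel output.
    time_us    - callable(**config) returning a runtime in microseconds.
    is_correct - callable(output) returning True if it matches the reference.
    """
    results = []
    for config in configs:
        try:
            output = run(**config)
        except Exception:
            continue  # e.g. out-of-resource errors: skip rather than rank
        if is_correct(output):
            results.append((config, time_us(**config)))
    return sorted(results, key=lambda item: item[1])
```

In this notebook, `is_correct` could compare against `torch_out` using `max_abs_diff` with a tolerance chosen for TF32 arithmetic.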

## Kernel v4 - Grouping instances

The [Triton matrix multiplication tutorial](https://triton-lang.org/master/getting-started/tutorials/03-matrix-multiplication.html) suggests that performance can be improved by having successive kernel instances read from the same memory locations. This makes better use of the GPU's L2 cache, which is shared between all SMs.

To do instance grouping, we introduce a new meta-parameter `GROUP_SIZE`, which is the number of instances per group in the `N_FEAT_OUT` axis. This axis is divided into `N_FEAT_OUT / N_OCOLS` blocks, and therefore into `N_FEAT_OUT / N_OCOLS / GROUP_SIZE` groups.

In [None]:
@triton.jit
def kernel_lola_v4(x_ptr, Wt_ptr, Ws_ptr, b_ptr, out_ptr,
                   N_OCOLS: tl.constexpr, N_BROWS: tl.constexpr,
                   BLOCK_LEN: tl.constexpr, GROUP_SIZE: tl.constexpr,
                   N_FEAT_IN: tl.constexpr, N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel instance (i, j) computes:
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS].

    Iteration k of the inner loop computes:
    x[i * N_BROWS: (i+1) * N_BROWS, k * BLOCK_LEN: (k+1) * BLOCK_LEN]
        @ W[k * BLOCK_LEN: (k+1) * BLOCK_LEN, j * N_OCOLS: (j+1) * N_OCOLS]

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN] - input token embeddings.
    Wt_ptr: [N_FEAT_OUT, N_FEAT_IN] - transposed linear layer weights (`W.T`, row-major).
    Ws_ptr: [N_FEAT_OUT,] - sum of the weights `W` over axis 0.
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [BATCH_SIZE, N_FEAT_OUT] - Lola output.
    
    Parameters
    ----------
    N_OCOLS - number of output columns computed per kernel instance.
    N_BROWS - number of batch elements (i.e. rows of `x`) computed per kernel instance.
    BLOCK_LEN - size of the block of `x` and `W` processed each iteration of the inner loop.
    GROUP_SIZE - number of blocks/instances per group in the `N_FEAT_OUT` axis.
        (Not used in the body yet: tiles are still taken directly from the program ids.)
    N_FEAT_IN - number of input features.
    N_FEAT_OUT - number of output features.
    """
    # Each instance will process x[x_brow_start: x_brow_start + N_BROWS, :].
    x_brow_start = tl.program_id(0) * N_BROWS
    x_brow_idxs = tl.arange(0, N_BROWS) + x_brow_start

    # Each instance will compute out[<b-rows>, ocol_start: ocol_start + N_OCOLS].
    ocol_start = tl.program_id(1) * N_OCOLS
    ocol_idxs = tl.arange(0, N_OCOLS) + ocol_start

    # Initialize accumulators. We build up partial results while iterating over `x` and `W`.
    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=tl.float32)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=tl.float32)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=tl.float32)

    # Iterate over N_FEAT_IN elements, in blocks of size BLOCK_LEN.
    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_LEN)
    for block_i in range(0, n_blocks):

        # Indices into the block dimension - columns of `x` and rows of `W`.
        block_idxs = tl.arange(0, BLOCK_LEN) + block_i * BLOCK_LEN

        # Load the current block of the input.
        x_block_idxs = x_brow_idxs[:, None] * N_FEAT_IN + block_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs)  # [N_BROWS, BLOCK_LEN]

        # Load the matching block of `W` from the transposed weights: element (r, c) of
        # `W` is at offset c * N_FEAT_IN + r in `Wt`. The untransposed equivalent would be
        # W_block_idxs = block_idxs[:, None] * N_FEAT_OUT + ocol_idxs[None, :].
        Wt_block_idxs = ocol_idxs[None, :] * N_FEAT_IN + block_idxs[:, None]
        W_block = tl.load(Wt_ptr + Wt_block_idxs)  # [BLOCK_LEN, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_LEN] @ [BLOCK_LEN, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

    bias = tl.load(b_ptr + ocol_idxs)
    Wsum = tl.load(Ws_ptr + ocol_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * Wsum[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    # out = fast_gelu(numer / denom[:, None] + bias[None, :])
    out = numer / denom[:, None] + bias[None, :]
    out = gelu(out)

    out_idxs = x_brow_idxs[:, None] * N_FEAT_OUT + ocol_idxs[None, :]
    tl.store(out_ptr + out_idxs, out)
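
As written, the v4 kernel accepts `GROUP_SIZE` but still takes its tile directly from `tl.program_id(0)` and `tl.program_id(1)`, so it behaves like v3. One way to add grouping, adapted from the ordering used in the Triton matrix multiplication tutorial, is to launch a 1D grid and remap the linear program id. Consecutive instances then cycle through the `GROUP_SIZE` output-column blocks of a group before moving to the next batch-row block, and each group covers all batch-row blocks before the next group starts. The plain-Python sketch below shows that index mapping. `grouped_tile` is a hypothetical name, and the choice of ordering is an assumption rather than the author's final design.

```python
def grouped_tile(pid, n_row_blocks, n_col_blocks, group_size):
    """Map a linear program id to a (row_block, col_block) tile."""
    pids_per_group = group_size * n_row_blocks
    group_id = pid // pids_per_group
    first_col = group_id * group_size
    # The last group may be narrower if group_size does not divide n_col_blocks.
    cols_in_group = min(n_col_blocks - first_col, group_size)
    pid_in_group = pid % pids_per_group
    col_block = first_col + pid_in_group % cols_in_group
    row_block = pid_in_group // cols_in_group
    return row_block, col_block


# Example: 3 batch-row blocks, 5 output-column blocks, groups of 2 column blocks.
order = [grouped_tile(pid, 3, 5, 2) for pid in range(3 * 5)]
```

Here `order` starts `(0, 0), (0, 1), (1, 0), (1, 1), ...`, so the first six instances touch only two column blocks of `Wt`. Inside a Triton kernel the same arithmetic would use `tl.program_id(0)` for `pid` and `tl.cdiv` for the block counts.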

## CUDA Graphs

In [58]:
cgraph_lola_v3 = torch.compile(dispatch_lola_v3, mode="reduce-overhead")

AttributeError: module 'torch' has no attribute 'compile'
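
`torch.compile` was introduced in PyTorch 2.0, so this error indicates that the notebook was running an earlier release. Here is a small sketch of guarding the call so that the cell still runs on older versions. `maybe_compile` is a hypothetical helper, and it takes the `torch` module as an argument purely so that the fallback path can be exercised without PyTorch installed.

```python
def maybe_compile(fn, torch_module, mode="reduce-overhead"):
    """Return `torch_module.compile(fn, mode=mode)` if available, else `fn` unchanged."""
    compile_fn = getattr(torch_module, "compile", None)
    if compile_fn is None:
        return fn  # older PyTorch: fall back to the eager function
    return compile_fn(fn, mode=mode)
```

In the notebook this would be called as `maybe_compile(dispatch_lola_v3, torch)`. Whether `torch.compile` can capture the Triton launch inside `dispatch_lola_v3` into a CUDA graph has not been verified here.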

## Benchmarking

In [50]:
from triton.testing import do_bench

# We use `do_bench` for both Triton and PyTorch, for fairness. `do_bench` does things like clearing the L2 cache, which `%%timeit` doesn't.

In [54]:
do_bench(partial(torch_lola, tx, tweightsT, tbias), warmup=100, rep=250)[0] * 1000

301.05599761009216

In [57]:
do_bench(partial(dispatch_lola_v3, tx, tweightsT, tweights_sum, tbias, 128, 64, 32))[0] * 1000

155.64799308776855
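
To put the two measurements side by side, the ratio can be computed directly. The sketch assumes both values are in microseconds (that is, that `do_bench` reported milliseconds before the `* 1000`), which is consistent with the `us` labels in the grid search output.

```python
torch_us = 301.05599761009216   # PyTorch reference, from the cell above
triton_us = 155.64799308776855  # kernel_lola_v3 with N_OCOLS=128, N_BROWS=64, BLOCK_LEN=32

speedup = torch_us / triton_us
saved_fraction = 1 - triton_us / torch_us
print(f"speedup: {speedup:.2f}x, time saved: {saved_fraction:.0%}")
```

On these numbers the fused kernel is a little under twice as fast as the PyTorch baseline.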

In [None]:
dispatch_lola_v3(tx, tweightsT, tweights_sum, tbias, 128, 64, 32)

In [None]:
v2_pa = PTXAnalyser.FromKernel(kernel_lola_v2, N_OCOLS=128, N_BROWS=64, BLOCK_LEN=32)

In [None]:
v3_pa = PTXAnalyser.FromKernel(kernel_lola_v3, N_OCOLS=128, N_BROWS=64, BLOCK_LEN=32)

In [None]:
with open('cubin.bin', 'wb') as f:
    f.write(v3_pa.kernel_asm["cubin"])

In [None]:
v3_pa.op_counts

Counter({'mov.f32': 1059,
         'fma.rn.ftz.f32': 455,
         'mov.b32': 443,
         'mul.f32': 194,
         'add.f32': 174,
         'div.full.f32': 136,
         'bra': 131,
         'or.b32': 110,
         'fma.rn.f32': 83,
         'and.b32': 79,
         'ld.shared.u32': 73,
         'add.s32': 72,
         'neg.f32': 69,
         'setp.ge.f32': 69,
         'sub.f32': 69,
         'selp.f32': 69,
         'abs.ftz.f32': 65,
         'setp.ltu.f32': 65,
         'ex2.approx.ftz.f32': 65,
         'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32': 64,
         'ld.shared.f32': 54,
         'shl.b32': 49,
         'st.shared.v2.f32': 32,
         'add.s64': 30,
         'mul.wide.s32': 29,
         'bar.sync': 24,
         'st.shared.u32': 18,
         'ld.shared.v4.u32': 16,
         'st.global.v4.b32': 16,
         'xor.b32': 14,
         'ld.global.v4.b32': 12,
         'shfl.sync.bfly.b32': 12,
         'bfe.u32': 8,
         'st.shared.v4.u32': 8,
         'ldmatrix.
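
The implementation of `PTXAnalyser` is not shown in this section. As a rough illustration of how opcode counts like these can be obtained, here is a simplified stand-alone sketch that tallies instruction mnemonics from PTX text. `count_ptx_ops` is a hypothetical function and may not match how `PTXAnalyser` actually parses PTX.

```python
from collections import Counter


def count_ptx_ops(ptx: str) -> Counter:
    """Count instruction mnemonics in PTX source, ignoring directives and labels."""
    counts = Counter()
    for raw_line in ptx.splitlines():
        line = raw_line.split("//", 1)[0].strip()  # drop comments
        if not line or line[0] in ".{}()" or line.endswith(":"):
            continue  # directives, braces, parameter lists, labels
        tokens = line.split()
        if tokens[0].startswith("@"):
            tokens = tokens[1:]  # skip a predicate guard such as `@%p1`
        if tokens:
            counts[tokens[0].rstrip(";")] += 1
    return counts


# A tiny hand-written PTX fragment, used only to exercise the function.
sample_ptx = """
.visible .entry k(
    .param .u64 k_param_0
)
{
    .reg .f32 %f<3>;
    ld.param.u64 %rd1, [k_param_0];
    mov.f32 %f1, 0f3F800000;  // 1.0
    mov.f32 %f2, 0f00000000;
    @%p1 bra $L__BB0_2;
    add.f32 %f1, %f1, %f2;
$L__BB0_2:
    ret;
}
"""
sample_counts = count_ptx_ops(sample_ptx)
```

This line-based approach assumes one instruction per line, which holds for the LLVM NVPTX output shown in this notebook.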

In [None]:
print(v3_pa.summarize_ptx())

// 
// Generated by LLVM NVPTX Back-End 
// 
 
.version 8.0 
.target sm_80 
.address_size 64 
 
	// .globl	kernel_lola_v3_0d1d2d3d4d 
.extern .shared .align 1 .b8 global_smem[]; 
.global .align 1 .b8 _$_str[11] = {95, 95, 67, 85, 68, 65, 95, 70, 84, 90, 0}; 
.global .align 1 .b8 _$_str_$_2[17] = {95, 95, 67, 85, 68, 65, 95, 80, 82, 69, 67, 95, 83, 81, 82, 84, 0}; 
 
.visible .entry kernel_lola_v3_0d1d2d3d4d( 
	.param .u64 kernel_lola_v3_0d1d2d3d4d_param_0, 
	.param .u64 kernel_lola_v3_0d1d2d3d4d_param_1, 
	.param .u64 kernel_lola_v3_0d1d2d3d4d_param_2, 
	.param .u64 kernel_lola_v3_0d1d2d3d4d_param_3, 
	.param .u64 kernel_lola_v3_0d1d2d3d4d_param_4 
) 
.maxntid 128, 1, 1 
{ 
	.reg .pred 	%p<234>; 
	.reg .b16 	%rs<3>; 
	.reg .b32 	%r<1552>; 
	.reg .f32 	%f<4587>; 
	.reg .b64 	%rd<65>; 
 
	ld.param.u64 	%rd5, [kernel_lola_v3_0d1d2d3d4d_param_4]; 
	ld.param.u64 	%rd4, [kernel_lola_v3_0d1d2d3d4d_param_3]; 
	ld.param.u64 	%rd3, [kernel_lola_v3_0d1d2d3d4d_param_2]; 
	ld.param.u64 	%rd2, [kern

In [None]:
import more_itertools
from collections import Counter

In [None]:
Counter(more_itertools.substrings(v1_pa.op_names))

Counter({('ld.param.u64',): 4,
         ('mov.u32',): 6,
         ('and.b32',): 79,
         ('shr.u32',): 4,
         ('bfe.u32',): 8,
         ('or.b32',): 111,
         ('shl.b32',): 56,
         ('xor.b32',): 14,
         ('add.s32',): 96,
         ('mad.lo.s32',): 4,
         ('setp.lt.u32',): 2,
         ('selp.b32',): 3,
         ('setp.eq.s32',): 3,
         ('cvt.u16.u32',): 1,
         ('and.b16',): 1,
         ('mul.wide.u16',): 1,
         ('mov.f32',): 99,
         ('setp.lt.s32',): 1,
         ('mov.pred',): 1,
         ('mul.wide.s32',): 29,
         ('add.s64',): 29,
         ('ld.global.v4.b32',): 12,
         ('mov.b32',): 475,
         ('bar.sync',): 46,
         ('st.shared.v4.f32',): 12,
         ('st.shared.u32',): 17,
         ('ld.shared.f32',): 134,
         ('ldmatrix.sync.aligned.m8n8.x4.shared.b16',): 8,
         ('ld.shared.u32',): 73,
         ('mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32',): 64,
         ('add.f32',): 279,
         ('st.shared.f32'
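
`more_itertools.substrings` yields every contiguous run of every length, shortest first, which is why the output begins with single-instruction tuples. The number of substrings grows quadratically with the number of instructions. If only runs of one fixed length are of interest, a sliding window is cheaper. Here is a minimal standard-library sketch. `count_op_ngrams` is a hypothetical name and is not necessarily how `get_op_motifs` is implemented.

```python
from collections import Counter


def count_op_ngrams(op_names, n):
    """Count every run of `n` consecutive opcodes in a list, using a sliding window."""
    # zip stops at the shortest slice, giving len(op_names) - n + 1 windows.
    windows = zip(*(op_names[i:] for i in range(n)))
    return Counter(windows)


ops = ["mul.f32", "mul.f32", "add.f32", "mul.f32", "mul.f32", "add.f32"]
pairs = count_op_ngrams(ops, 2)
```

For the toy list above, `pairs` counts `('mul.f32', 'mul.f32')` and `('mul.f32', 'add.f32')` twice each and `('add.f32', 'mul.f32')` once.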

In [None]:
v1_pa.op_counts - v2_pa.op_counts

Counter({'mov.b32': 32,
         'add.f32': 40,
         'ld.shared.f32': 32,
         'or.b32': 1,
         'mov.f32': 4,
         'add.s32': 12,
         'shl.b32': 1,
         'bar.sync': 18,
         'st.shared.f32': 16,
         'st.shared.v4.f32': 8,
         'selp.b32': 2,
         'setp.eq.s32': 1,
         'setp.lt.u32': 1})

In [None]:
v2_pa.op_counts - v1_pa.op_counts

Counter({'add.s64': 1,
         'ld.shared.v4.f32': 4,
         'st.shared.u32': 1,
         'st.shared.v4.u32': 8,
         'ld.param.u64': 1,
         'ld.global.b32': 1,
         'ld.shared.v2.f32': 2})
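
Subtracting one `Counter` from another keeps only the positive differences. `v1_pa.op_counts - v2_pa.op_counts` therefore lists the instructions that v1 uses more often, and the reverse subtraction lists those that v2 uses more often. A small sketch using a few of the counts printed in this notebook:

```python
from collections import Counter

# A subset of the opcode counts shown for v1 and v2 in this notebook.
v1 = Counter({"bar.sync": 46, "add.f32": 279, "add.s64": 29, "mul.f32": 451})
v2 = Counter({"bar.sync": 28, "add.f32": 239, "add.s64": 30, "mul.f32": 451})

more_in_v1 = v1 - v2  # add.f32: 40, bar.sync: 18
more_in_v2 = v2 - v1  # add.s64: 1
```

Opcodes with equal counts, such as `mul.f32` here, drop out of both differences.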

In [None]:
v2_pa.op_counts["bar.sync"]

28

In [None]:
v2_pa.op_counts

Counter({'mul.f32': 451,
         'mov.b32': 443,
         'fma.rn.ftz.f32': 390,
         'add.f32': 239,
         'fma.rn.f32': 147,
         'bra.uni': 130,
         'or.b32': 110,
         'ld.shared.f32': 102,
         'mov.f32': 95,
         'add.s32': 84,
         'and.b32': 79,
         'ld.shared.u32': 73,
         'div.full.f32': 72,
         'setp.ge.f32': 69,
         'selp.f32': 69,
         'bra': 66,
         'abs.ftz.f32': 65,
         'setp.ltu.f32': 65,
         'ex2.approx.ftz.f32': 65,
         'rcp.approx.ftz.f32': 65,
         'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32': 64,
         'st.shared.v2.f32': 64,
         'shl.b32': 55,
         'add.s64': 30,
         'mul.wide.s32': 29,
         'bar.sync': 28,
         'ld.shared.v4.f32': 19,
         'st.shared.u32': 18,
         'st.global.v4.b32': 16,
         'xor.b32': 14,
         'ld.global.v4.b32': 12,
         'shfl.sync.bfly.b32': 12,
         'bfe.u32': 8,
         'st.shared.v4.u32': 8,
         

In [None]:
v1_pa.op_counts

Counter({'mov.b32': 475,
         'mul.f32': 451,
         'fma.rn.ftz.f32': 390,
         'add.f32': 279,
         'fma.rn.f32': 147,
         'ld.shared.f32': 134,
         'bra.uni': 130,
         'or.b32': 111,
         'mov.f32': 99,
         'add.s32': 96,
         'and.b32': 79,
         'ld.shared.u32': 73,
         'div.full.f32': 72,
         'setp.ge.f32': 69,
         'selp.f32': 69,
         'bra': 66,
         'abs.ftz.f32': 65,
         'setp.ltu.f32': 65,
         'ex2.approx.ftz.f32': 65,
         'rcp.approx.ftz.f32': 65,
         'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32': 64,
         'st.shared.v2.f32': 64,
         'shl.b32': 56,
         'bar.sync': 46,
         'mul.wide.s32': 29,
         'add.s64': 29,
         'st.shared.f32': 23,
         'st.shared.u32': 17,
         'st.global.v4.b32': 16,
         'ld.shared.v4.f32': 15,
         'xor.b32': 14,
         'ld.global.v4.b32': 12,
         'st.shared.v4.f32': 12,
         'shfl.sync.bfly.b32': 12,
 

In [None]:
v2_pa.get_float_constant_counts()

Counter({1.0: 196,
         0.6000000238418579: 65,
         2.885390043258667: 65,
         9.010913848876953: 65,
         0.044714998453855515: 64,
         0.7978845834732056: 64,
         0.5: 64,
         0.0: 7,
         9.999999747378752e-06: 4,
         -2.0: 2,
         -0.05230396240949631: 2,
         0.01573968306183815: 2,
         0.13315297663211823: 2,
         -0.33332768082618713: 2})
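
Several of these constants look like the tanh approximation of GELU, `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`. `0.7978845834732056` is `sqrt(2/pi)` rounded to float32, and `2.885390043258667` is close to `2 / ln(2)`, which is what one would expect if `tanh` is evaluated through the base-2 exponential (`ex2.approx`). `9.999999747378752e-06` is the float32 value of the `1e-5` epsilon. This reading is an inference from the numbers, since the `gelu` helper itself is not shown in this section. A short sketch checking the arithmetic:

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)  # ~0.79788456
TWO_OVER_LN2 = 2.0 / math.log(2.0)         # ~2.88539008


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x**3)))


def gelu_exact(x: float) -> float:
    """Exact GELU via the error function, for comparison."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))
```

The approximation stays close to the exact form: for `x = 1.0` the two differ by less than `1e-3`.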

In [None]:
v1_pa.get_op_motifs(10)

Counter({('mul.f32',
          'mul.f32',
          'mul.f32',
          'mul.f32',
          'mul.f32',
          'mul.f32',
          'mul.f32',
          'mul.f32',
          'mul.f32',
          'mul.f32'): 110,
         ('add.f32',
          'add.f32',
          'add.f32',
          'add.f32',
          'add.f32',
          'add.f32',
          'add.f32',
          'add.f32',
          'add.f32',
          'add.f32'): 84,
         ('mov.f32',
          'mov.f32',
          'mov.f32',
          'mov.f32',
          'mov.f32',
          'mov.f32',
          'mov.f32',
          'mov.f32',
          'mov.f32',
          'mov.f32'): 75,
         ('bra',
          'bra.uni',
          'mul.f32',
          'fma.rn.ftz.f32',
          'fma.rn.ftz.f32',
          'fma.rn.ftz.f32',
          'fma.rn.ftz.f32',
          'fma.rn.ftz.f32',
          'bra.uni',
          'mul.f32'): 64,
         ('bra.uni',
          'mul.f32',
          'fma.rn.ftz.f32',
          'fma.rn.ftz.f32',
          