# Lola Caching

Explore how triton's cache modifiers affect Lola

In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'

# Flax Reference

In [2]:
import flax.linen as nn
import jax.numpy as jnp
import jax

from nimblegpt import param_shapes

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
n_embd = 1024
n_btch = 1024

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (n_btch, n_embd), dtype=jnp.float16)

In [4]:
def GELU(x):
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))

class FlaxLola(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        
        x = nn.LayerNorm()(x)
        x = nn.Dense(self.features)(x)
        x = GELU(x)

        return x

In [5]:
fl_module = FlaxLola(features=4 * n_embd)
params = fl_module.init(key, x)
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float16), params)

fl_apply = jax.jit(fl_module.apply)
fy = fl_apply(params, x)

In [6]:
%%timeit -n 1000

fl_apply(params, x)

96.6 µs ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# Torch

In [7]:
import torch
import numpy as np

In [33]:
tkernel = torch.tensor(np.array(params["params"]["Dense_0"]["kernel"]), device="cuda").contiguous()
tweights = tkernel.T.contiguous()
tbias = torch.tensor(np.array(params["params"]["Dense_0"]["bias"]), device="cuda").contiguous()
tx = torch.tensor(np.array(x), device="cuda").contiguous()

# Triton

In [9]:
import torch

In [10]:
import triton
import triton.language as tl

In [11]:
import math

sqrt2pi = math.sqrt(2.0 / math.pi)

@triton.jit
def tanh(x):
    """Tanh activation function"""
    return tl.libdevice.tanh(x)

@triton.jit
def fast_gelu(x):
    """Fast approximation of the gelu function. May slightly decrease accuracy."""
    return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x)))

In [12]:
acc_dtype = tl.float16


@triton.jit
def lola_kernel(x_ptr, W_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                N_BROWS: tl.constexpr, BLOCK_ROWS: tl.constexpr,
                N_FEAT_IN: tl.constexpr, N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes 
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN,] - current token embedding.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    x_brows_start = tl.program_id(0) * N_BROWS
    # This instance will process x[b_rows_start:b_rows_start + N_BROWS, :]
    x_brows_idxs = tl.arange(0, N_BROWS) + x_brows_start

    ocols_start = tl.program_id(1) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=acc_dtype)
    w_sum_acc = tl.zeros((N_OCOLS, ), dtype=acc_dtype)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)

    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

        # Load the current block of the input.
        x_block_idxs = x_brows_idxs[:,
                                    None] * N_FEAT_IN + block_row_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs).to(
            acc_dtype)  # [N_BROWS, BLOCK_ROWS]

        W_block_idxs = block_row_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs).to(
            acc_dtype)  # [BLOCK_ROWS, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_ROWS] @ [BLOCK_ROWS, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block).to(acc_dtype)
        # w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0)
        w_sum_acc += tl.sum(W_block, axis=0)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

    bias = tl.load(b_ptr + col_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * w_sum_acc[None, :] + bias[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = fast_gelu(numer / denom[:, None])

    out_idxs = x_brows_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
    tl.store(out_ptr + out_idxs, out.to(tl.float16))

In [13]:
def triton_lola(
    x,
    W,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_ROWS: int,
    num_warps=4,
    num_stages=1,
):
    assert N_BROWS >= 16 and BLOCK_ROWS >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    N_BATCH = x.shape[0]
    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT), dtype=x.dtype, device="cuda")

    # Launch the kernel.
    lola_kernel[grid](x,
                      W,
                      b,
                      out,
                      N_OCOLS=N_OCOLS,
                      N_BROWS=N_BROWS,
                      BLOCK_ROWS=BLOCK_ROWS,
                      N_FEAT_IN=N_FEAT_IN,
                      N_FEAT_OUT=N_FEAT_OUT,
                      num_warps=num_warps,
                      num_stages=num_stages)

    return out

In [14]:
ky = triton_lola(tx, tkernel, tbias, N_BROWS=16, N_OCOLS=32, BLOCK_ROWS=64)

In [15]:
(np.array(fy) - ky.cpu().numpy()).max()

0.007812

In [16]:
from triton.testing import do_bench
from functools import partial

Range find:

`n_ocols=256, n_brows=32, block_rows=64: 192.51 us`


In [18]:
%%script false --no-raise-error

for n_ocols in [32, 64, 128, 256, 512]:
    for n_brows in [16, 32, 64, 128, 256, 512]:
        for block_rows in [16, 32, 64, 128, 256, 512]:
            if n_ocols * n_brows * block_rows < 200_000:
                continue
            if n_ocols * n_brows * block_rows / 4 > 166_000:
                continue
            print(f"{n_ocols=}, {n_brows=}, {block_rows=}", end=": ")
            print(f"{do_bench(partial(triton_lola, tx, tkernel, tbias, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows), warmup = 100, rep = 100)[0] * 1000:.2f} us")

n_ocols=32, n_brows=16, block_rows=512: 1064.96 us
n_ocols=32, n_brows=32, block_rows=256: 638.98 us
n_ocols=32, n_brows=32, block_rows=512: 706.56 us
n_ocols=32, n_brows=64, block_rows=128: 457.73 us
n_ocols=32, n_brows=64, block_rows=256: 490.50 us
n_ocols=32, n_brows=128, block_rows=64: 441.34 us
n_ocols=32, n_brows=128, block_rows=128: 407.55 us
n_ocols=32, n_brows=256, block_rows=32: 464.90 us
n_ocols=32, n_brows=256, block_rows=64: 401.41 us
n_ocols=32, n_brows=512, block_rows=16: 569.34 us
n_ocols=32, n_brows=512, block_rows=32: 515.07 us
n_ocols=64, n_brows=16, block_rows=256: 447.49 us
n_ocols=64, n_brows=16, block_rows=512: 4956.16 us
n_ocols=64, n_brows=32, block_rows=128: 339.97 us
n_ocols=64, n_brows=32, block_rows=256: 359.42 us
n_ocols=64, n_brows=64, block_rows=64: 319.49 us
n_ocols=64, n_brows=64, block_rows=128: 290.82 us
n_ocols=64, n_brows=128, block_rows=32: 350.21 us
n_ocols=64, n_brows=128, block_rows=64: 290.82 us
n_ocols=64, n_brows=256, block_rows=16: 622.59 u

# `evict-first`

In [63]:
acc_dtype = tl.float16


@triton.jit
def lola_ef_kernel(x_ptr, W_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                N_BROWS: tl.constexpr, BLOCK_ROWS: tl.constexpr,
                N_FEAT_IN: tl.constexpr, N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes 
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN,] - current token embedding.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    x_brows_start = tl.program_id(0) * N_BROWS
    # This instance will process x[b_rows_start:b_rows_start + N_BROWS, :]
    x_brows_idxs = tl.arange(0, N_BROWS) + x_brows_start

    ocols_start = tl.program_id(1) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=acc_dtype)
    w_sum_acc = tl.zeros((N_OCOLS, ), dtype=acc_dtype)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)

    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

        # Load the current block of the input.
        x_block_idxs = x_brows_idxs[:,
                                    None] * N_FEAT_IN + block_row_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs, cache_modifier=".cg").to(
            acc_dtype)  # [N_BROWS, BLOCK_ROWS]

        W_block_idxs = block_row_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs, cache_modifier=".ca").to(
            acc_dtype)  # [BLOCK_ROWS, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_ROWS] @ [BLOCK_ROWS, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block).to(acc_dtype)
        # w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0)
        w_sum_acc += tl.sum(W_block, axis=0)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

    bias = tl.load(b_ptr + col_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * w_sum_acc[None, :] + bias[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = fast_gelu(numer / denom[:, None])

    out_idxs = x_brows_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
    tl.store(out_ptr + out_idxs, out.to(tl.float16))

In [64]:
def triton_lola_ef(
    x,
    W,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_ROWS: int,
    num_warps=4,
    num_stages=1,
):
    assert N_BROWS >= 16 and BLOCK_ROWS >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    N_BATCH = x.shape[0]
    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT), dtype=x.dtype, device="cuda")

    # Launch the kernel.
    lola_ef_kernel[grid](x,
                      W,
                      b,
                      out,
                      N_OCOLS=N_OCOLS,
                      N_BROWS=N_BROWS,
                      BLOCK_ROWS=BLOCK_ROWS,
                      N_FEAT_IN=N_FEAT_IN,
                      N_FEAT_OUT=N_FEAT_OUT,
                      num_warps=num_warps,
                      num_stages=num_stages,
                      )

    return out

In [65]:
for n_ocols in [128, 256, 512]:
    for n_brows in [16, 32, 64]:
        for block_rows in [32, 64, 128]:
            if n_ocols * n_brows * block_rows < 200_000:
                continue
            if n_ocols * n_brows * block_rows / 4 > 166_000:
                continue
            print(f"{n_ocols=}, {n_brows=}, {block_rows=}", end=": ")
            print(f"{do_bench(partial(triton_lola_ef, tx, tkernel, tbias, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows), warmup = 100, rep = 100)[0] * 1000:.2f} us")

n_ocols=128, n_brows=16, block_rows=128: 351.23 us
n_ocols=128, n_brows=32, block_rows=64: 284.67 us
n_ocols=128, n_brows=32, block_rows=128: 260.10 us
n_ocols=128, n_brows=64, block_rows=32: 297.98 us
n_ocols=128, n_brows=64, block_rows=64: 201.73 us
n_ocols=256, n_brows=16, block_rows=64: 321.54 us
n_ocols=256, n_brows=16, block_rows=128: 289.79 us
n_ocols=256, n_brows=32, block_rows=32: 267.26 us
n_ocols=256, n_brows=32, block_rows=64: 193.54 us
n_ocols=256, n_brows=64, block_rows=32: 269.31 us
n_ocols=512, n_brows=16, block_rows=32: 275.46 us
n_ocols=512, n_brows=16, block_rows=64: 283.65 us
n_ocols=512, n_brows=32, block_rows=32: 256.00 us
