# Lola (Layer nOrm + Linear + Activation)

## Mathematical derivation

The fused dense layer computes a single element $y = \bm{y}_i$ of the output by taking a dot product
with the $i$ th column $\bm{w} = \bm{W}_{:, i}$ of the weights matrix and adding a bias term $b = \bm{b}_i$. The full expression for $y$, including layer norm and activation, is:

$$
\begin{align}
    y = \text{GELU} \left( \bm{w} \cdot \left( \frac{\bm{x} - \text{mean}(\bm{x})\bm{I}_{n}}{\sqrt{\text{var}(\bm{x}) + \epsilon}} \right) + b \right)
\end{align}
$$

Where $n$ is the number of elements in $\bm{x}$. We use multiplication by $\bm{I}_{n}$ as the mathematical equivalent of broadcasting.

We would like to compute this incrementally, which means we won't have access to the mean
and variance of $\bm{x}$ until the end. We note that:

$$
\begin{align}
    \bm{w} \cdot \left( \frac{\bm{x} - \text{mean}(\bm{x})\bm{I}_n}{\sqrt{\text{var}(\bm{x}) + \epsilon}} \right) &= \frac{1}{\sqrt{\text{var}(\bm{x}) + \epsilon}} \left( \bm{w} \cdot \bm{x} - \bm{w} \cdot \text{mean}(\bm{x}) \bm{I}_n\right) \\
\end{align}
$$

Observe that $\bm{w} \cdot \text{mean}(\bm{x}) \bm{I}_n = \text{mean}(\bm{x}) \times \text{sum}(\bm{w})$. Note also that $\text{var}(\bm{x}) = \text{mean}(\bm{x}^2) - \text{mean}(\bm{x})^2$. We can therefore compute the fused layer from the following pieces, all of which can be computed incrementally by streaming $w$ and $x$:

$$
\begin{align}
    \bm{w} \cdot \bm{x} && \text{sum}(\bm{w}) && \text{mean}(\bm{x}) && \text{mean}(\bm{x}^2) \\
\end{align}
$$

The equation for the output in terms of these pieces is:

$$
\begin{align}
    y = \text{GELU} \left( \frac{\bm{w} \cdot \bm{x} - \text{mean}(\bm{x}) \times \text{sum}(\bm{w})}{\sqrt{\text{mean}(\bm{x}^2) - \text{mean}(\bm{x})^2 + \epsilon}} + b \right)
\end{align}
$$



In [2]:
import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'

# Flax Reference

In [3]:
import flax.linen as nn
import jax.numpy as jnp
import jax

from nimblegpt import param_shapes

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
n_embd = 1024
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (n_embd,))

In [5]:
def GELU(x):
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))

class FlaxLola(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        
        x = nn.LayerNorm(use_scale = False, use_bias=False)(x)
        x = nn.Dense(self.features)(x)
        x = GELU(x)

        return x

In [6]:
fl_module = FlaxLola(features = 4 * n_embd)
params = fl_module.init(key, x)

fl_apply = jax.jit(fl_module.apply)
fy = fl_apply(params, x)

param_shapes(params)

{'params': {'Dense_0': {'kernel': '(1024, 4096)', 'bias': '(4096)'}}}

In [7]:
%%timeit -n 1000

fl_apply(params, x)

50.4 µs ± 9.73 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Batched

In [8]:
n_btch = 1024
bx = jax.random.normal(key, (n_btch, n_embd))
bx.shape

(1024, 1024)

In [9]:
bfy = fl_apply(params, bx)

In [10]:
%%timeit -n 100

fl_apply(params, bx).block_until_ready()

502 µs ± 237 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Batched half precision

In [11]:
from jax import tree_util

In [12]:
# convert params to half precision
params16 = tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
bx16 = bx.astype(jnp.bfloat16)

In [13]:
bfy16 = fl_apply(params16, bx16)

In [14]:
(bfy16 - bfy).max()

Array(0.0264864, dtype=float32)

In [15]:
%%timeit -n 1000

fl_apply(params16, bx16).block_until_ready()

298 µs ± 68.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [16]:
tree_util.tree_map(lambda x: x.dtype, params16)

FrozenDict({
    params: {
        Dense_0: {
            bias: dtype(bfloat16),
            kernel: dtype(bfloat16),
        },
    },
})

## Transposed kernel

In [17]:
kernel16 = params["params"]["Dense_0"]["kernel"].astype(jnp.bfloat16)
kernelT16 = kernel16.T
bias16 = params["params"]["Dense_0"]["bias"].astype(jnp.bfloat16)

In [18]:
class FlaxLolaTrans(nn.Module):
    
    @nn.compact
    def __call__(self, x, kernelT, bias):

        x = nn.LayerNorm(use_scale = False, use_bias=False)(x)
        # x = jnp.dot(x, kernelT) + bias
        x = jax.lax.dot_general(x, kernelT, (((1,), (1,)), ((), ())))
        x = GELU(x)

        return x

In [19]:
flt_module = FlaxLolaTrans()

flt_apply = jax.jit(flt_module.apply)
fty = flt_apply(params16, bx16, kernelT16, bias16)
fty.shape

(1024, 4096)

In [20]:
(bfy16 - fty).max()

Array(0, dtype=bfloat16)

In [21]:
%%timeit -n 1000

flt_apply(params16, bx16, kernelT16, bias16).block_until_ready()

267 µs ± 86.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Torch Reference

In [22]:
import torch.nn.functional as F
import torch
import numpy as np

In [23]:
def torch_lola(x, weight, bias):
    x = F.layer_norm(x, (n_embd,))
    x = F.linear(x, weight, bias)
    x = F.gelu(x, approximate="tanh")

    return x

In [24]:
tkernel = torch.tensor(np.array(params["params"]["Dense_0"]["kernel"]), device="cuda")
tkernel_sum = tkernel.sum(dim=0)
tweights = tkernel.T
tbias = torch.tensor(np.array(params["params"]["Dense_0"]["bias"]), device="cuda")
tx = torch.tensor(np.array(x), device="cuda")

In [25]:
ty = torch_lola(tx, tweights, tbias)

In [26]:
(np.array(fy) - np.array(ty.cpu())).max()

1.66893e-05

In [27]:
%%timeit -n 100

torch_lola(tx, tweights, tbias)

73.9 µs ± 6.44 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Batched

In [28]:
tbx = torch.tensor(np.array(bx), device="cuda")

In [29]:
tby = torch_lola(tbx, tweights, tbias)

In [30]:
(np.array(bfy) - np.array(tby.cpu())).max()

0.0016235113

In [31]:
%%timeit -n 100

torch_lola(tbx, tweights, tbias)
torch.cuda.synchronize()

668 µs ± 63.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Batched half precsion

In [32]:
tbx16 = tbx.half()
tweights16 = tweights.half()
tkernel16 = tkernel.half()
tkernel_sum16 = tkernel_sum.half()
tbias16 = tbias.half()
tbx16.dtype

torch.float16

In [33]:
tby16 = torch_lola(tbx16, tweights16, tbias16)
tby16.dtype

torch.float16

In [34]:
(np.array(bfy) - np.array(tby16.cpu())).max()

0.0030958652

In [35]:
%%timeit -n 100

torch_lola(tbx16, tweights16, tbias16)
torch.cuda.synchronize()

122 µs ± 8.24 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Jax Incremental Reference

We validate that we can correctly produce the output from the pieces described above.

In [36]:
from jax import lax

In [37]:
kernel = params["params"]["Dense_0"]["kernel"]
bias = params["params"]["Dense_0"]["bias"]

In [38]:
def jax_pieces_lola(x, kernel, bias):
    w_dot_x = jnp.dot(x, kernel)
    w_sum = jnp.sum(kernel, axis = 0)
    x_mean = jnp.mean(x)
    x_sq_mean = jnp.mean(x * x)

    numer = w_dot_x - x_mean * w_sum + bias
    denom = jnp.sqrt(jnp.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    return GELU(numer / denom)

In [39]:
py = jax_pieces_lola(x, kernel, bias)
py

Array([-0.13923575, -0.12308519,  1.4389666 , ..., -0.14115831,
       -0.1697047 , -0.14350325], dtype=float32)

In [40]:
(fy - py).max()

Array(1.692772e-05, dtype=float32)

In [41]:
jit_jpl = jax.jit(jax_pieces_lola)
jit_jpl(x, kernel, bias)

Array([-0.13923575, -0.12308519,  1.4389666 , ..., -0.1411583 ,
       -0.1697047 , -0.14350319], dtype=float32)

In [42]:
%%timeit -n1000

jit_jpl(x, kernel, bias).block_until_ready()

53.4 µs ± 2.29 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Batched

In [43]:
batched_jpl = jax.vmap(jax_pieces_lola, in_axes=(0, None, None))

In [44]:
bpy = batched_jpl(bx, kernel, bias)

In [45]:
(bfy - bpy).max()

Array(0.00158358, dtype=float32)

In [46]:
jit_bjpl = jax.jit(batched_jpl)
jit_bjpl(bx, kernel, bias).shape

(1024, 4096)

In [47]:
%%timeit -n1000

jit_bjpl(bx, kernel, bias).block_until_ready();

318 µs ± 70.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# Triton dummy baseline

Test Triton kernel launch latency

In [48]:
import triton
import triton.language as tl

In [49]:
@triton.jit
def dummy_kernel(x_ptr, W_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                BLOCK_ROWS: tl.constexpr, N_FEAT_IN: tl.constexpr,
                N_FEAT_OUT: tl.constexpr):
    """
    Dummy Triton kernel.

    Kernel cell i computes out[i * N_OCOLS: (i+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [N_FEAT_IN,] - current token embedding.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    ocols_start = tl.program_id(0) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    x = tl.load(x_ptr + col_idxs)

    # w_dot_x_acc = tl.zeros((N_OCOLS, ), dtype=tl.float32)
    # w_sum_acc = tl.zeros((N_OCOLS, ), dtype=tl.float32)
    # x_sum_acc = tl.zeros((1, ), dtype=tl.float32)
    # x_sq_sum_acc = tl.zeros((1, ), dtype=tl.float32)

    # n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    # for block_i in range(0, n_blocks):

    #     block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

    #     # Load the current block of the input.
    #     x_block = tl.load(x_ptr + block_row_idxs)
        
    #     W_block_idxs = block_row_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
    #     W_block = tl.load(W_ptr + W_block_idxs)

    #     # Update the accumulators.
    #     # w_dot_x_acc += tl.dot(w_block, x_block)
    #     w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0)
    #     w_sum_acc += tl.sum(W_block, axis=0)
    #     x_sum_acc += tl.sum(x_block, axis=0)
    #     x_sq_sum_acc += tl.sum(x_block * x_block, axis=0)


    # bias = tl.load(b_ptr + col_idxs)
    # x_mean = x_sum_acc / N_FEAT_IN
    # x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    # numer = w_dot_x_acc - x_mean * w_sum_acc + bias
    # denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    # out = fast_gelu(numer / denom)

    tl.store(out_ptr + col_idxs, x)

In [50]:
def triton_dummy(x, W, b, N_OCOLS: int, BLOCK_ROWS: int):

    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (N_FEAT_OUT // N_OCOLS, )

    # Allocate output buffer.
    out = torch.zeros((N_FEAT_OUT, ), dtype=torch.float32, device="cuda")

    # Launch the kernel.
    dummy_kernel[grid](x,
                      W,
                      b,
                      out,
                      N_OCOLS=N_OCOLS,
                      BLOCK_ROWS=BLOCK_ROWS,
                      N_FEAT_IN=N_FEAT_IN,
                      N_FEAT_OUT=N_FEAT_OUT)

    return out

In [51]:
triton_dummy(tx, tweights, tbias, N_OCOLS=32, BLOCK_ROWS=32)

tensor([-1.3119, -0.5691, -0.3651,  ...,  1.3971,  0.0326,  0.0858],
       device='cuda:0')

In [52]:
%%timeit

triton_dummy(tx, tweights, tbias, N_OCOLS=32, BLOCK_ROWS=32)

29 µs ± 727 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


# Triton Lola Kernel

In [53]:
import triton
import triton.language as tl

In [107]:
import math

sqrt2pi = math.sqrt(2.0 / math.pi)
sqrt2 = math.sqrt(2.0)


@triton.jit
def tanh(x):
    """Tanh activation function"""
    return tl.libdevice.tanh(x)


@triton.jit
def fast_gelu(x):
    """Fast approximation of the gelu function. May slightly decrease accuracy."""
    return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x)))


@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * torch.tensor(0.5, dtype=torch.float16) * (
        torch.tensor(1.0, dtype=torch.float16) + tl.libdevice.erf(x / sqrt2))


In [55]:
@triton.jit
def lola_kernel(x_ptr, W_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                BLOCK_ROWS: tl.constexpr, N_FEAT_IN: tl.constexpr,
                N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell i computes out[i * N_OCOLS: (i+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [N_FEAT_IN,] - current token embedding.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    ocols_start = tl.program_id(0) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    w_dot_x_acc = tl.zeros((N_OCOLS, ), dtype=tl.float32)
    w_sum_acc = tl.zeros((N_OCOLS, ), dtype=tl.float32)
    x_sum_acc = tl.zeros((1, ), dtype=tl.float32)
    x_sq_sum_acc = tl.zeros((1, ), dtype=tl.float32)

    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

        # Load the current block of the input.
        x_block = tl.load(x_ptr + block_row_idxs)
        
        W_block_idxs = block_row_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs)

        # Update the accumulators.
        # w_dot_x_acc += tl.dot(W_block, x_block[None, :])
        w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0)
        w_sum_acc += tl.sum(W_block, axis=0)
        x_sum_acc += tl.sum(x_block, axis=0)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=0)


    bias = tl.load(b_ptr + col_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean * w_sum_acc + bias
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = fast_gelu(numer / denom)

    tl.store(out_ptr + col_idxs, out)
    # tl.store(out_ptr + col_idxs, w_dot_x_acc)

## Batched Row

In [56]:
acc_dtype = tl.float16

@triton.jit
def batched_row_lola_kernel(x_ptr, W_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                BLOCK_ROWS: tl.constexpr, N_FEAT_IN: tl.constexpr,
                N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes out[i, j * N_OCOLS: (j+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN,] - current token embedding.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    x_row_idx = tl.program_id(0) # Which element of the batch we're processing.
    x_row_ptr = x_ptr + x_row_idx * N_FEAT_IN
    out_row_ptr = out_ptr + x_row_idx * N_FEAT_OUT

    ocols_start = tl.program_id(1) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    w_dot_x_acc = tl.zeros((N_OCOLS, ), dtype=acc_dtype)
    w_sum_acc = tl.zeros((N_OCOLS, ), dtype=acc_dtype)
    x_sum_acc = tl.zeros((1, ), dtype=acc_dtype)
    x_sq_sum_acc = tl.zeros((1, ), dtype=acc_dtype)

    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

        # Load the current block of the input.
        x_block = tl.load(x_row_ptr + block_row_idxs).to(acc_dtype)
        
        W_block_idxs = block_row_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs).to(acc_dtype)

        # Update the accumulators.
        # w_dot_x_acc += tl.dot(W_block, x_block[None, :])
        w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0).to(acc_dtype)
        w_sum_acc += tl.sum(W_block, axis=0)
        x_sum_acc += tl.sum(x_block, axis=0)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=0)


    bias = tl.load(b_ptr + col_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean * w_sum_acc + bias
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = fast_gelu(numer / denom)

    tl.store(out_row_ptr + col_idxs, out)
    # tl.store(out_ptr + col_idxs, w_dot_x_acc)

## Batched Multi-row

In [57]:
acc_dtype = tl.float16

@triton.jit
def batched_mrow_lola_kernel(x_ptr, W_ptr, b_ptr, out_ptr,
                             N_OCOLS: tl.constexpr, N_BROWS: tl.constexpr,
                             BLOCK_ROWS: tl.constexpr, N_FEAT_IN: tl.constexpr,
                             N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes 
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN,] - current token embedding.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    x_brows_start = tl.program_id(0) * N_BROWS
    # This instance will process x[b_rows_start:b_rows_start + N_BROWS, :]
    x_brows_idxs = tl.arange(0, N_BROWS) + x_brows_start

    ocols_start = tl.program_id(1) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=acc_dtype)
    w_sum_acc = tl.zeros((N_OCOLS, ), dtype=acc_dtype)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)

    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

        # Load the current block of the input.
        x_block_idxs = x_brows_idxs[:, None] * N_FEAT_IN + block_row_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs).to(acc_dtype) # [N_BROWS, BLOCK_ROWS]

        W_block_idxs = block_row_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs).to(acc_dtype) # [BLOCK_ROWS, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_ROWS] @ [BLOCK_ROWS, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block).to(acc_dtype)
        # w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0)
        w_sum_acc += tl.sum(W_block, axis=0)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

    bias = tl.load(b_ptr + col_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * w_sum_acc[None, :] + bias[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = fast_gelu(numer / denom[:, None])

    out_idxs = x_brows_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
    tl.store(out_ptr + out_idxs, out.to(tl.float16))

# Batched Multi-row with precomputed weights sum

In [111]:
acc_dtype = tl.float16

@triton.jit
def batched_mrow_ws_lola_kernel(x_ptr, W_ptr, Ws_ptr, b_ptr, out_ptr,
                             N_OCOLS: tl.constexpr, N_BROWS: tl.constexpr,
                             BLOCK_ROWS: tl.constexpr, N_FEAT_IN: tl.constexpr,
                             N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes 
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN,] - current token embedding.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    Ws_ptr: [N_FEAT_OUT,] - sum of the weights (in axis 0).
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    x_brows_start = tl.program_id(0) * N_BROWS
    # This instance will process x[b_rows_start:b_rows_start + N_BROWS, :]
    x_brows_idxs = tl.arange(0, N_BROWS) + x_brows_start

    ocols_start = tl.program_id(1) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=acc_dtype)
    # w_sum_acc = tl.zeros((N_OCOLS, ), dtype=acc_dtype)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)

    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

        # Load the current block of the input.
        x_block_idxs = x_brows_idxs[:, None] * N_FEAT_IN + block_row_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs).to(acc_dtype) # [N_BROWS, BLOCK_ROWS]

        W_block_idxs = block_row_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs).to(acc_dtype) # [BLOCK_ROWS, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_ROWS] @ [BLOCK_ROWS, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block).to(acc_dtype)
        # w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0)
        # w_sum_acc += tl.sum(W_block, axis=0)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

    bias = tl.load(b_ptr + col_idxs)
    Wsum = tl.load(Ws_ptr + col_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * Wsum[None, :] + bias[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = fast_gelu(numer / denom[:, None])
    # out = numer / denom[:, None]
    # out = gelu(numer / denom[:, None])

    out_idxs = x_brows_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
    tl.store(out_ptr + out_idxs, out.to(tl.float16))

## Native Triton (torch)

In [59]:
def triton_lola(x,
                W,
                b,
                N_OCOLS: int,
                BLOCK_ROWS: int,
                num_warps=4,):

    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (N_FEAT_OUT // N_OCOLS, )

    # Allocate output buffer.
    out = torch.zeros((N_FEAT_OUT, ), dtype=torch.float32, device="cuda")

    # Launch the kernel.
    lola_kernel[grid](x,
                      W,
                      b,
                      out,
                      N_OCOLS=N_OCOLS,
                      BLOCK_ROWS=BLOCK_ROWS,
                      N_FEAT_IN=N_FEAT_IN,
                      N_FEAT_OUT=N_FEAT_OUT,
                      num_warps=num_warps)

    return out

In [60]:
ky = triton_lola(tx, tkernel, tbias, N_OCOLS=32, BLOCK_ROWS=32)

In [61]:
(ty - ky).max()

tensor(1.1921e-06, device='cuda:0')

In [62]:
%%timeit

triton_lola(tx, tkernel, tbias, N_OCOLS=32, BLOCK_ROWS=512)

29.7 µs ± 949 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## Batched Row

In [63]:
def batched_row_triton_lola(x,
                W,
                b,
                N_OCOLS: int,
                BLOCK_ROWS: int,
                num_warps=4,):

    N_BATCH = x.shape[0]
    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (N_BATCH, N_FEAT_OUT // N_OCOLS, )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT), dtype=torch.float32, device="cuda")

    # Launch the kernel.
    batched_row_lola_kernel[grid](x,
                      W,
                      b,
                      out,
                      N_OCOLS=N_OCOLS,
                      BLOCK_ROWS=BLOCK_ROWS,
                      N_FEAT_IN=N_FEAT_IN,
                      N_FEAT_OUT=N_FEAT_OUT,
                      num_warps=num_warps)

    return out

In [64]:
bky = batched_row_triton_lola(tbx, tkernel, tbias, N_OCOLS=32, BLOCK_ROWS=512)

In [65]:
(tby - bky).max()

tensor(0.0052, device='cuda:0')

In [66]:
%%timeit -n10

batched_row_triton_lola(tbx, tkernel, tbias, N_OCOLS=32, BLOCK_ROWS=512)

The slowest run took 4.98 times longer than the fastest. This could mean that an intermediate result is being cached.
46.5 µs ± 40.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Batched Multi-Row

In [67]:
def batched_mrow_triton_lola(
    x,
    W,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_ROWS: int,
    num_warps=4,
    num_stages = 1,
):
    assert N_BROWS >= 16 and BLOCK_ROWS >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    N_BATCH = x.shape[0]
    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT),
                      dtype=x.dtype,
                      device="cuda")

    # Launch the kernel.
    batched_mrow_lola_kernel[grid](x,
                                   W,
                                   b,
                                   out,
                                   N_OCOLS=N_OCOLS,
                                   N_BROWS=N_BROWS,
                                   BLOCK_ROWS=BLOCK_ROWS,
                                   N_FEAT_IN=N_FEAT_IN,
                                   N_FEAT_OUT=N_FEAT_OUT,
                                   num_warps=num_warps,
                                   num_stages=num_stages)

    return out

In [68]:
bmky = batched_mrow_triton_lola(tbx16, tkernel16, tbias16, N_OCOLS=16, N_BROWS=32, BLOCK_ROWS=256)
bmky.dtype

torch.float16

In [69]:
(tby - bmky).max()

tensor(0.0050, device='cuda:0')

In [70]:
(np.array(bfy) - np.array(bmky.cpu())).max()

0.004814148

## Batched Multi-row precomputed weights sum

In [112]:
def batched_mrow_ws_triton_lola(
    x,
    W,
    Ws,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_ROWS: int,
    num_warps=4,
    num_stages = 1,
):
    assert N_BROWS >= 16 and BLOCK_ROWS >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    N_BATCH = x.shape[0]
    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT),
                      dtype=x.dtype,
                      device="cuda")

    # Launch the kernel.
    batched_mrow_ws_lola_kernel[grid](x,
                                   W,
                                   Ws,
                                   b,
                                   out,
                                   N_OCOLS=N_OCOLS,
                                   N_BROWS=N_BROWS,
                                   BLOCK_ROWS=BLOCK_ROWS,
                                   N_FEAT_IN=N_FEAT_IN,
                                   N_FEAT_OUT=N_FEAT_OUT,
                                   num_warps=num_warps,
                                   num_stages=num_stages)

    return out

In [72]:
k2y = batched_mrow_ws_triton_lola(tbx16, tkernel16, tkernel_sum16, tbias16, N_OCOLS=16, N_BROWS=32, BLOCK_ROWS=256)
bmky.dtype

torch.float16

In [73]:
(np.array(bfy) - np.array(k2y.cpu())).max()

0.0055150986

### Debugging

In [74]:
w_dot_x = jnp.dot(bx, kernel)
w_sum = jnp.sum(kernel, axis = 0)
x_mean = jnp.mean(bx, axis=1)
x_sq_mean = jnp.mean(bx * bx, axis=1)

numer = w_dot_x - jnp.expand_dims(x_mean, axis=1) * w_sum + bias
denom = jnp.sqrt(jnp.abs(x_sq_mean - x_mean * x_mean + 1e-5))

out = GELU(numer / jnp.expand_dims(denom, axis=1)) # [BATCH_SIZE, N_FEAT_OUT]

In [75]:
(out - bfy).max()

Array(0.00148189, dtype=float32)

### Benchmarking

In [76]:
%%timeit -n100

batched_mrow_triton_lola(tbx, tkernel, tbias, N_OCOLS=16, N_BROWS=32, BLOCK_ROWS=128)
torch.cuda.synchronize()

The slowest run took 6.91 times longer than the fastest. This could mean that an intermediate result is being cached.
2.11 ms ± 2.36 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [77]:
%%timeit -n1000

batched_mrow_triton_lola(tbx16, tkernel16, tbias16, N_OCOLS=128, N_BROWS=64, BLOCK_ROWS=64)

The slowest run took 6.95 times longer than the fastest. This could mean that an intermediate result is being cached.
345 µs ± 388 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [79]:
from triton.testing import do_bench
from functools import partial

In [113]:
print(f"{do_bench(partial(batched_mrow_ws_triton_lola, tbx16, tkernel16, tkernel_sum16, tbias16, N_OCOLS=128, N_BROWS=64, BLOCK_ROWS=64), warmup = 100, rep = 100)[0] * 1000:.2f} us")

133.12 us


In [124]:
print(f"{do_bench(partial(batched_mrow_triton_lola, tbx16, tkernel16, tbias16, N_OCOLS=128, N_BROWS=64, BLOCK_ROWS=64), warmup = 100, rep = 100)[0] * 1000:.2f} us")

198.66 us


In [125]:
key = [k for k in batched_mrow_lola_kernel.cache[0].keys() if k[-2] == (128, 64, 64, 1024, 4096)][0]
art = batched_mrow_lola_kernel.cache[0][key].asm
art.keys()

dict_keys(['ast', 'ttir', 'ttgir', 'llir', 'ptx', 'cubin'])

In [1]:
import struct
struct.unpack('!f', bytes.fromhex('3F19999A'))[0]

0.6000000238418579

In [126]:
print(art["ttir"])

module {
  func.func public @batched_mrow_lola_kernel_0d1d2d3d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
    %c16 = arith.constant 16 : index
    %cst = arith.constant dense<5.000000e-01> : tensor<64x128xf32>
    %cst_0 = arith.constant dense<4.471500e-02> : tensor<64x128xf32>
    %cst_1 = arith.constant dense<0.797884583> : tensor<64x128xf32>
    %cst_2 = arith.constant dense<1.000000e+00> : tensor<64x128xf32>
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<64xf16>
    %cst_5 = arith.constant dense<0.000000e+00> : tensor<128xf16>
    %cst_6 = arith.constant dense<0.000000e+00> : tensor<64x128xf16>
    %cst_7 = arith.constant dense<1.024000e+03> : tensor<64xf32>
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %cst_8 = arith.consta

In [127]:
print(art["ttgir"])

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func.func public @batched_mrow_lola_kernel_0d1d2d3d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
    %cst = arith.constant dense<1.024000e+03> : 

In [178]:
print(len(art["ptx"].splitlines()))

4341


In [197]:
print(summarize_ptx(art["ptx"]))

('mul.f32', 738)
('mov.b32', 716)
('fma.rn.ftz.f32', 523)
('add.f32', 265)
('bra.uni', 260)
('mov.b16', 227)
('add.f16', 212)
('fma.rn.f32', 199)
('or.b32', 189)
('add.s32', 163)
('cvt.rn.f16.f32', 161)
('bar.sync', 156)
('and.b32', 155)
('ld.shared.b16', 144)
('div.full.f32', 144)
('setp.ge.f32', 135)
('selp.f32', 135)
('abs.ftz.f32', 130)
('setp.ltu.f32', 130)
('ex2.approx.ftz.f32', 130)
('rcp.approx.ftz.f32', 130)
('cvt.u16.u32', 122)
('cvt.f32.f16', 108)
('mov.f32', 101)
('{', 96)
('shl.b32', 92)
('mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32', 80)
('st.shared.b16', 78)
('st.shared.v2.f32', 68)
('ld.shared.v4.f32', 43)
('mul.wide.s32', 42)
('add.s64', 42)
('@%p156', 42)
('st.shared.u16', 35)
('fma.rn.f16', 32)
('add.f16x2', 29)
('cvt.u32.u16', 24)
('shfl.sync.bfly.b32', 24)
('xor.b32', 20)
('st.shared.v4.u16', 17)
('ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16', 17)
('mov.u32', 12)
('shr.u32', 12)
('bfe.u32', 10)
//
// Generated by LLVM NVPTX Back-End
//

.version 8.0
.targe

In [119]:
key = [k for k in batched_mrow_ws_lola_kernel.cache[0].keys() if k[-2] == (128, 64, 64, 1024, 4096)][0]
art_ws = batched_mrow_ws_lola_kernel.cache[0][key].asm
art_ws.keys()

dict_keys(['ast', 'ttir', 'ttgir', 'llir', 'ptx', 'cubin'])

In [120]:
print(art_ws["ttir"])

module {
  func.func public @batched_mrow_ws_lola_kernel_0d1d2d3d4d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
    %c16 = arith.constant 16 : index
    %cst = arith.constant dense<5.000000e-01> : tensor<64x128xf32>
    %cst_0 = arith.constant dense<4.471500e-02> : tensor<64x128xf32>
    %cst_1 = arith.constant dense<0.797884583> : tensor<64x128xf32>
    %cst_2 = arith.constant dense<1.000000e+00> : tensor<64x128xf32>
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<64xf16>
    %cst_5 = arith.constant dense<0.000000e+00> : tensor<64x128xf16>
    %cst_6 = arith.constant dense<1.024000e+03> : tensor<64xf32>
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %cst_7 = arith.constant dense<9.

In [122]:
print(art_ws["ttgir"])

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func.func public @batched_mrow_ws_lola_kernel_0d1d2d3d4d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i

In [131]:
print(art_ws["ptx"])

//
// Generated by LLVM NVPTX Back-End
//

.version 8.0
.target sm_80
.address_size 64

	// .globl	batched_mrow_ws_lola_kernel_0d1d2d3d4d
.extern .shared .align 1 .b8 global_smem[];
.global .align 1 .b8 _$_str[11] = {95, 95, 67, 85, 68, 65, 95, 70, 84, 90, 0};
.global .align 1 .b8 _$_str_$_2[17] = {95, 95, 67, 85, 68, 65, 95, 80, 82, 69, 67, 95, 83, 81, 82, 84, 0};

.visible .entry batched_mrow_ws_lola_kernel_0d1d2d3d4d(
	.param .u64 batched_mrow_ws_lola_kernel_0d1d2d3d4d_param_0,
	.param .u64 batched_mrow_ws_lola_kernel_0d1d2d3d4d_param_1,
	.param .u64 batched_mrow_ws_lola_kernel_0d1d2d3d4d_param_2,
	.param .u64 batched_mrow_ws_lola_kernel_0d1d2d3d4d_param_3,
	.param .u64 batched_mrow_ws_lola_kernel_0d1d2d3d4d_param_4
)
.maxntid 128, 1, 1
{
	.reg .pred 	%p<162>;
	.reg .b16 	%h<780>;
	.reg .b16 	%rs<67>;
	.reg .b32 	%r<1385>;
	.reg .b32 	%hh<81>;
	.reg .f32 	%f<3144>;
	.reg .b64 	%rd<49>;

	ld.param.u64 	%rd5, [batched_mrow_ws_lola_kernel_0d1d2d3d4d_param_4];
	ld.param.u64 	%rd4, [batc

In [132]:
ptx = art_ws["ptx"]

In [195]:
from collections import Counter

def summarize_ptx(ptx):

    bolden = str.maketrans("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789",
    "𝗔𝗕𝗖𝗗𝗘𝗙𝗚𝗛𝗜𝗝𝗞𝗟𝗠𝗡𝗢𝗣𝗤𝗥𝗦𝗧𝗨𝗩𝗪𝗫𝗬𝗭𝗮𝗯𝗰𝗱𝗲𝗳𝗴𝗵𝗶𝗷𝗸𝗹𝗺𝗻𝗼𝗽𝗾𝗿𝘀𝘁𝘂𝘃𝘄𝘅𝘆𝘇𝟬𝟭𝟮𝟯𝟰𝟱𝟲𝟳𝟴𝟵",)

    op_counter = Counter()
    ptx = ptx.splitlines()
    out_lines = []
    i = 0

    while i < len(ptx):

        curr_line = ptx[i]
        
        if len(curr_line) < 2 or curr_line[0] != "\t" or curr_line[1] == ".":
            out_lines.append(ptx[i])
            i += 1
            continue

        op, *args = curr_line.split()
        op_counter.update([op])

        j = i
        while j < len(ptx):
            nxt_line = ptx[j]

            try:
                nxt_op, *nxt_args = nxt_line.split()
            except ValueError:
                break

            # print(op, nxt_op, op == nxt_op)

            if nxt_op != op:
                break

            op_counter.update([nxt_op])
            j += 1

        out_lines.append(ptx[i])

        if j - 1 > i:
            num = str(j-i).translate(bolden)
            out_lines.append(f"\t...<and {num} more>")

        i = j

    print(*[(op, count) for op, count in op_counter.most_common() if count >= 10], sep="\n")
    return "\n".join(out_lines)

In [187]:
print(len(art_ws["ptx"].splitlines()))

4038


In [200]:
print(summarize_ptx(art_ws["ptx"]))

('mul.f32', 738)
('mov.b32', 719)
('fma.rn.ftz.f32', 523)
('add.f32', 265)
('bra.uni', 260)
('mov.b16', 219)
('fma.rn.f32', 199)
('or.b32', 187)
('cvt.rn.f16.f32', 161)
('and.b32', 155)
('add.f16', 154)
('div.full.f32', 144)
('setp.ge.f32', 135)
('selp.f32', 135)
('abs.ftz.f32', 130)
('setp.ltu.f32', 130)
('ex2.approx.ftz.f32', 130)
('rcp.approx.ftz.f32', 130)
('add.s32', 125)
('cvt.f32.f16', 108)
('mov.f32', 101)
('shl.b32', 88)
('ld.shared.b16', 87)
('mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32', 80)
('st.shared.v2.f32', 68)
('cvt.u16.u32', 58)
('bar.sync', 56)
('add.s64', 44)
('@%p154', 44)
('ld.shared.v4.f32', 43)
('mul.wide.s32', 42)
('st.shared.u16', 37)
('{', 32)
('fma.rn.f16', 32)
('cvt.u32.u16', 24)
('shfl.sync.bfly.b32', 24)
('st.shared.v4.b32', 21)
('xor.b32', 20)
('ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16', 17)
('mov.u32', 12)
('shr.u32', 12)
('st.shared.b16', 12)
('bfe.u32', 10)
//
// Generated by LLVM NVPTX Back-End
//

.version 8.0
.target sm_80
.address_size

In [148]:
op, args = "".split(maxsplit=2)

ValueError: not enough values to unpack (expected 2, got 0)

In [146]:
ptx.splitlines()

['//',
 '// Generated by LLVM NVPTX Back-End',
 '//',
 '',
 '.version 8.0',
 '.target sm_80',
 '.address_size 64',
 '',
 '\t// .globl\tbatched_mrow_ws_lola_kernel_0d1d2d3d4d',
 '.extern .shared .align 1 .b8 global_smem[];',
 '.global .align 1 .b8 _$_str[11] = {95, 95, 67, 85, 68, 65, 95, 70, 84, 90, 0};',
 '.global .align 1 .b8 _$_str_$_2[17] = {95, 95, 67, 85, 68, 65, 95, 80, 82, 69, 67, 95, 83, 81, 82, 84, 0};',
 '',
 '.visible .entry batched_mrow_ws_lola_kernel_0d1d2d3d4d(',
 '\t.param .u64 batched_mrow_ws_lola_kernel_0d1d2d3d4d_param_0,',
 '\t.param .u64 batched_mrow_ws_lola_kernel_0d1d2d3d4d_param_1,',
 '\t.param .u64 batched_mrow_ws_lola_kernel_0d1d2d3d4d_param_2,',
 '\t.param .u64 batched_mrow_ws_lola_kernel_0d1d2d3d4d_param_3,',
 '\t.param .u64 batched_mrow_ws_lola_kernel_0d1d2d3d4d_param_4',
 ')',
 '.maxntid 128, 1, 1',
 '{',
 '\t.reg .pred \t%p<162>;',
 '\t.reg .b16 \t%h<780>;',
 '\t.reg .b16 \t%rs<67>;',
 '\t.reg .b32 \t%r<1385>;',
 '\t.reg .b32 \t%hh<81>;',
 '\t.reg .f32 \t

# Single vs multi row benchmarking and tuning

In [78]:
from triton.testing import do_bench
from functools import partial

In [71]:
n_ocols = 128
n_brows = 64
block_rows = 32

do_bench(partial(batched_mrow_triton_lola, tbx, tkernel, tbias, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows), warmup = 1000, rep = 1000)[0] * 1000

495.61598896980286

In [81]:
# n_ocols=64, n_brows=64, block_rows=64: 157.70 us
# n_ocols=256, n_brows=32, block_rows=64: 143.36 us
# n_ocols=64, n_brows=128, block_rows=64: 143.36 us
# n_ocols=128, n_brows=64, block_rows=64: 132.10 us


for n_ocols in [16, 32, 64, 128, 256]:
    for n_brows in [16, 32, 64, 128, 256]:
        for block_rows in [16, 32, 64, 128, 256]:
            if n_ocols * n_brows * block_rows < 100_000:
                continue
            print(f"{n_ocols=}, {n_brows=}, {block_rows=}", end=": ")
            # 1_120_000 breaks
            if n_ocols * n_brows * block_rows > 1_020_000:
                print("skipped")
                continue
            print(f"{do_bench(partial(batched_mrow_ws_triton_lola, tbx16, tkernel16, tkernel_sum16, tbias16, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows), warmup = 100, rep = 100)[0] * 1000:.2f} us")

n_ocols=16, n_brows=32, block_rows=256: 861.18 us
n_ocols=16, n_brows=64, block_rows=128: 435.20 us
n_ocols=16, n_brows=64, block_rows=256: 582.66 us
n_ocols=16, n_brows=128, block_rows=64: 338.94 us
n_ocols=16, n_brows=128, block_rows=128: 434.18 us
n_ocols=16, n_brows=128, block_rows=256: 772.10 us
n_ocols=16, n_brows=256, block_rows=32: 411.65 us
n_ocols=16, n_brows=256, block_rows=64: 363.52 us
n_ocols=16, n_brows=256, block_rows=128: 701.44 us
n_ocols=16, n_brows=256, block_rows=256: skipped
n_ocols=32, n_brows=16, block_rows=256: 474.11 us
n_ocols=32, n_brows=32, block_rows=128: 292.86 us
n_ocols=32, n_brows=32, block_rows=256: 361.47 us
n_ocols=32, n_brows=64, block_rows=64: 223.23 us
n_ocols=32, n_brows=64, block_rows=128: 253.95 us
n_ocols=32, n_brows=64, block_rows=256: 336.90 us
n_ocols=32, n_brows=128, block_rows=32: 232.45 us
n_ocols=32, n_brows=128, block_rows=64: 209.92 us
n_ocols=32, n_brows=128, block_rows=128: 266.24 us
n_ocols=32, n_brows=128, block_rows=256: skipped

In [259]:
bmky16 = batched_mrow_triton_lola(tbx16, tkernel16, tbias16, N_OCOLS=128, N_BROWS=64, BLOCK_ROWS=32)

In [260]:
(bmky16 - bmky).max()

tensor(0.0107, device='cuda:0', dtype=torch.float16)

In [261]:
do_bench(partial(batched_mrow_triton_lola, tbx16, tkernel16, tbias16, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows), warmup = 1000, rep = 1000)[0] * 1000

295.9359884262085

In [262]:
for n_ocols in [16, 32, 64, 128, 256]:
    for n_brows in [16, 32, 64, 128, 256]:
        for block_rows in [16, 32, 64, 128, 256]:
            print(f"{n_ocols=}, {n_brows=}, {block_rows=}", end=": ")
            if n_ocols * n_brows * block_rows / 2 > 210_000:
                print("skipped")
                continue
            print(f"{do_bench(partial(batched_mrow_triton_lola, tbx, tkernel, tbias, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows), warmup = 100, rep = 100)[0] * 1000:.2f} us")

n_ocols=16, n_brows=16, block_rows=16: 2652.16 us
n_ocols=16, n_brows=16, block_rows=32: 1923.07 us
n_ocols=16, n_brows=16, block_rows=64: 1715.20 us
n_ocols=16, n_brows=16, block_rows=128: 1796.10 us
n_ocols=16, n_brows=16, block_rows=256: 

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
2 ** 19

524288

In [None]:
for n_ocols in [16, 32, 64, 128, 256, 512]:
    for block_rows in [16, 32, 64, 128, 256, 512]:
        if n_ocols * block_rows < 0:
            # print("skipped")
            continue
        if n_ocols * block_rows / 4 > 166_000:
            # print("skipped")
            continue
        print(f"{n_ocols=}, {block_rows=}", end=": ")
        print(f"{do_bench(partial(batched_row_triton_lola, tbx16, tkernel16, tbias16, N_OCOLS=n_ocols, BLOCK_ROWS=block_rows), warmup = 100, rep = 100)[0] * 1000:.2f} us")

n_ocols=16, block_rows=16: 42020.35 us
n_ocols=16, block_rows=32: 48535.55 us
n_ocols=16, block_rows=64: 72646.65 us
n_ocols=16, block_rows=128: 42031.62 us
n_ocols=16, block_rows=256: 23421.44 us
n_ocols=16, block_rows=512: 14461.95 us
n_ocols=32, block_rows=16: 39963.65 us
n_ocols=32, block_rows=32: 61773.82 us
n_ocols=32, block_rows=64: 35853.31 us
n_ocols=32, block_rows=128: 23150.59 us
n_ocols=32, block_rows=256: 13998.08 us
n_ocols=32, block_rows=512: 10099.71 us
n_ocols=64, block_rows=16: 50528.26 us
n_ocols=64, block_rows=32: 30286.85 us
n_ocols=64, block_rows=64: 19648.51 us
n_ocols=64, block_rows=128: 16513.02 us
n_ocols=64, block_rows=256: 9954.82 us
n_ocols=64, block_rows=512: 7514.11 us
n_ocols=128, block_rows=16: 25782.27 us
n_ocols=128, block_rows=32: 18176.00 us
n_ocols=128, block_rows=64: 15496.70 us
n_ocols=128, block_rows=128: 16365.06 us
n_ocols=128, block_rows=256: 10807.30 us
n_ocols=128, block_rows=512: 130162.69 us
n_ocols=256, block_rows=16: 16906.24 us
n_ocols

In [None]:
for n_ocols in [16, 32, 64, 128, 256, 512]:
    for n_brows in [16, 32, 64, 128, 256, 512]:
        for block_rows in [16, 32, 64, 128, 256, 512]:
            if n_ocols * n_brows * block_rows < 200_000:
                # print("skipped")
                continue
            if n_ocols * n_brows * block_rows / 4 > 166_000:
                # print("skipped")
                continue
            print(f"{n_ocols=}, {n_brows=}, {block_rows=}", end=": ")
            print(f"{do_bench(partial(batched_mrow_triton_lola, tbx16, tkernel16, tbias16, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows), warmup = 100, rep = 100)[0] * 1000:.2f} us")

n_ocols=16, n_brows=32, block_rows=512: 18827.78 us
n_ocols=16, n_brows=64, block_rows=256: 13317.12 us
n_ocols=16, n_brows=64, block_rows=512: 317147.13 us
n_ocols=16, n_brows=128, block_rows=128: 10800.13 us
n_ocols=16, n_brows=128, block_rows=256: 26132.48 us
n_ocols=16, n_brows=256, block_rows=64: 9798.14 us
n_ocols=16, n_brows=256, block_rows=128: 15298.56 us
n_ocols=16, n_brows=512, block_rows=32: 7733.25 us
n_ocols=16, n_brows=512, block_rows=64: 8606.72 us
n_ocols=32, n_brows=16, block_rows=512: 12203.01 us
n_ocols=32, n_brows=32, block_rows=256: 9525.76 us
n_ocols=32, n_brows=32, block_rows=512: 10292.22 us
n_ocols=32, n_brows=64, block_rows=128: 6738.94 us
n_ocols=32, n_brows=64, block_rows=256: 7048.19 us
n_ocols=32, n_brows=128, block_rows=64: 5837.82 us
n_ocols=32, n_brows=128, block_rows=128: 5624.83 us
n_ocols=32, n_brows=256, block_rows=32: 6349.82 us
n_ocols=32, n_brows=256, block_rows=64: 5102.59 us
n_ocols=32, n_brows=512, block_rows=16: 5923.84 us
n_ocols=32, n_brow

In [316]:
for n_ocols in [16, 32, 64, 128, 256, 512]:
    for n_brows in [16, 32, 64, 128, 256, 512]:
        for block_rows in [16, 32, 64, 128, 256, 512]:
            if n_ocols * n_brows * block_rows < 200_000:
                # print("skipped")
                continue
            if n_ocols * n_brows * block_rows / 4 > 166_000:
                # print("skipped")
                continue
            print(f"{n_ocols=}, {n_brows=}, {block_rows=}", end=": ")
            print(f"{do_bench(partial(batched_mrow_triton_lola, tbx16, tkernel16, tbias16, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows), warmup = 100, rep = 100)[0] * 1000:.2f} us")

n_ocols=16, n_brows=32, block_rows=512: 17337.34 us
n_ocols=16, n_brows=64, block_rows=256: 891.90 us
n_ocols=16, n_brows=64, block_rows=512: 20571.14 us
n_ocols=16, n_brows=128, block_rows=128: 744.45 us
n_ocols=16, n_brows=128, block_rows=256: 1718.27 us
n_ocols=16, n_brows=256, block_rows=64: 702.46 us
n_ocols=16, n_brows=256, block_rows=128: 1714.18 us
n_ocols=16, n_brows=512, block_rows=32: 638.98 us
n_ocols=16, n_brows=512, block_rows=64: 754.69 us
n_ocols=32, n_brows=16, block_rows=512: 891.90 us
n_ocols=32, n_brows=32, block_rows=256: 648.19 us
n_ocols=32, n_brows=32, block_rows=512: 765.95 us
n_ocols=32, n_brows=64, block_rows=128: 466.94 us
n_ocols=32, n_brows=64, block_rows=256: 488.45 us
n_ocols=32, n_brows=128, block_rows=64: 456.70 us
n_ocols=32, n_brows=128, block_rows=128: 390.14 us
n_ocols=32, n_brows=256, block_rows=32: 479.23 us
n_ocols=32, n_brows=256, block_rows=64: 402.43 us
n_ocols=32, n_brows=512, block_rows=16: 613.38 us
n_ocols=32, n_brows=512, block_rows=32: 

In [262]:
n_ocols = 128
n_brows = 64
block_rows = 32

for num_warps in [1, 2, 4, 8, 16, 32, 64]:
    for num_stages in [1, 2, 4, 8, 16, 32, 64]:
        print(
            f"{n_ocols=}, {n_brows=}, {block_rows=}, {num_warps=}, {num_stages=}",
            end=": ")
        if n_ocols * n_brows * block_rows / 2 > 210_000:
            print("skipped")
            continue
        print(
            f"{do_bench(partial(batched_mrow_triton_lola, tbx, tkernel, tbias, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows, num_warps=num_warps, num_stages=num_stages), warmup = 100, rep = 100)[0] * 1000:.2f} us"
        )

n_ocols=128, n_brows=64, block_rows=32, num_warps=1, num_stages=1: 430.08 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=1, num_stages=2: 343.04 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=1, num_stages=4: 342.02 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=1, num_stages=8: 342.02 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=1, num_stages=16: 342.02 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=1, num_stages=32: 342.02 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=1, num_stages=64: 343.04 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=2, num_stages=1: 343.04 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=2, num_stages=2: 343.04 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=2, num_stages=4: 342.02 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=2, num_stages=8: 342.02 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=2, num_stages=16: 342.02 us
n_ocols=128, n_brows=64, block_rows=32, num_warps=2, num_stages=32: 342.

In [260]:
for n_ocols in [64, 128, 256]:
    for n_brows in [32, 64, 128]:
        for block_rows in [16, 32, 64]:
            for num_warps in [1, 2, 4, 8, 16]:
                for num_stages in [1, 2, 4, 8, 16]:
                    print(
                        f"{n_ocols=}, {n_brows=}, {block_rows=}, {num_warps=}, {num_stages=}",
                        end=": ")
                    if n_ocols * n_brows * block_rows / 2 > 210_000:
                        print("skipped")
                        continue
                    print(
                        f"{do_bench(partial(batched_mrow_triton_lola, tbx, tkernel, tbias, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows, num_warps=num_warps, num_stages=num_stages), warmup = 100, rep = 100)[0] * 1000:.2f} us"
                    )

n_ocols=64, n_brows=32, block_rows=16, num_warps=1, num_stages=1: 865.28 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=1, num_stages=2: 839.68 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=1, num_stages=4: 838.66 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=1, num_stages=8: 839.68 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=1, num_stages=16: 839.68 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=2, num_stages=1: 839.17 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=2, num_stages=2: 839.68 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=2, num_stages=4: 839.68 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=2, num_stages=8: 837.63 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=2, num_stages=16: 839.68 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=4, num_stages=1: 839.68 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=4, num_stages=2: 839.68 us
n_ocols=64, n_brows=32, block_rows=16, num_warps=4, num_stages=4: 838.66 us
n_ocols=64

KeyboardInterrupt: 

# Jax Triton

In [33]:
import jax_triton as jt

In [34]:
def jt_lola(x, W, b, N_OCOLS: int, BLOCK_ROWS: int):
    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (N_FEAT_OUT // N_OCOLS, )

    out_shape = jax.ShapeDtypeStruct((N_FEAT_OUT, ), jnp.float32)

    return jt.triton_call(x,
                          W,
                          b,
                          kernel=lola_kernel,
                          out_shape=out_shape,
                          grid=grid,
                          N_OCOLS=N_OCOLS,
                          BLOCK_ROWS=BLOCK_ROWS,
                          N_FEAT_IN=N_FEAT_IN,
                          N_FEAT_OUT=N_FEAT_OUT)

In [35]:
jty = jt_lola(x, kernel, bias, N_OCOLS=32, BLOCK_ROWS=512)

In [36]:
(fy - jty).max()

Array(1.692772e-05, dtype=float32)

In [37]:
%%timeit

jt_lola(x, kernel, bias, N_OCOLS=32, BLOCK_ROWS=512)

195 µs ± 9.24 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [38]:
from triton.testing import do_bench
from functools import partial

In [39]:
metaparam_sizes = list(map(int, (2 ** np.arange(np.log2(n_embd) + 1))))
metaparam_sizes

[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

In [40]:
%%script false --no-raise-error

for mps in metaparam_sizes:
    print(mps, do_bench(partial(triton_lola, tx, tkernel, tbias, N_OCOLS=mps, BLOCK_ROWS=32), warmup=100, rep=100)[0] * 1000)

In [41]:
%%script false --no-raise-error

for mps in metaparam_sizes:
    print(mps, do_bench(partial(triton_lola, tx, tkernel, tbias, N_OCOLS=32, BLOCK_ROWS=mps), warmup=100, rep=100)[0] * 1000)

In [42]:
for n_ocols in [4, 8, 16, 32]:
    for block_rows in [32, 64, 128, 256, 512, 1024]:
        print(n_ocols, block_rows, do_bench(partial(triton_lola, tx, tkernel, tbias, N_OCOLS=n_ocols, BLOCK_ROWS=block_rows), warmup=100, rep=100)[0] * 1000)

ImportError: /home/trist/notebooks/lola_kernel.so: cannot open shared object file: No such file or directory

`num_warps` does not appear to affect the performance of the kernel.

In [None]:
%%script false --no-raise-error

for num_warps in [1, 2, 4, 8, 16, 32, 64, 128]:
    print(f"{num_warps=}", do_bench(partial(triton_lola, tx, tkernel, tbias, N_OCOLS=8, BLOCK_ROWS=512, num_warps=num_warps), warmup=100, rep=100)[0] * 1000)

In [None]:
n_embd = 8192
x = jax.random.normal(key, (n_embd,))
params = FlaxLola(features = 4 * n_embd).init(key, x)

kernel = params["params"]["Dense_0"]["kernel"]
bias = params["params"]["Dense_0"]["bias"]

In [None]:
fl_apply = jax.jit(FlaxLola(features = 4 * n_embd).apply)
fl_apply(params, x)

Array([ 0.39584607, -0.15986574,  0.6896521 , ...,  0.20229153,
       -0.16803958, -0.1277257 ], dtype=float32)

In [None]:
%%timeit -n 100

fl_apply(params, x)

1.45 ms ± 97.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
jit_jpl(x, kernel, bias).shape

(32768,)

In [None]:
%%timeit -n 100

jit_jpl(x, kernel, bias)

1.42 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
tweights = torch.tensor(np.array(params["params"]["Dense_0"]["kernel"]), device="cuda").T
tbias = torch.tensor(np.array(params["params"]["Dense_0"]["bias"]), device="cuda")
tx = torch.tensor(np.array(x), device="cuda")

tkernel = tweights.T

In [None]:
triton_lola(tx, tkernel, tbias, N_OCOLS=32, BLOCK_ROWS=32).shape

torch.Size([32768])

In [None]:
for n_ocols in [4, 8, 16, 32, 64]:
    for block_rows in [128, 256, 512, 1024, 2048]:
        print(n_ocols, block_rows, do_bench(partial(triton_lola, tx, tkernel, tbias, N_OCOLS=n_ocols, BLOCK_ROWS=block_rows), warmup=100, rep=100)[0] * 1000)

4 128 2498.5599517822266
4 256 2411.520004272461
4 512 2513.9200687408447
4 1024 2570.240020751953
4 2048 2665.4720306396484
8 128 1842.1759605407715
8 256 1453.05597782135
8 512 1356.7999601364136
8 1024 1358.847975730896
8 2048 1291.2640571594238
16 128 1698.815941810608
16 256 1187.3279809951782
16 512 1086.4640474319458
16 1024 1047.551989555359
16 2048 7900.15983581543
32 128 1203.1999826431274
32 256 865.2799725532532
32 512 809.984028339386
32 1024 10091.520309448242
32 2048 9367.551803588867
64 128 1485.8239889144897
64 256 1266.6879892349243
64 512 11991.552352905273
64 1024 11247.103691101074
64 2048 11044.351577758789


In [None]:
%%timeit -n1000

triton_lola(tx, tkernel, tbias, N_OCOLS=32, BLOCK_ROWS=512)

733 µs ± 141 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
xla = jax.xla_computation(jit_jpl)(x, kernel, bias)

In [None]:
print(xla.as_hlo_text())

HloModule xla_computation_jax_pieces_lola, entry_computation_layout={(f32[8192]{0},f32[8192,32768]{1,0},f32[32768]{0})->(f32[32768]{0})}

region_0.4 {
  Arg_0.5 = f32[] parameter(0)
  Arg_1.6 = f32[] parameter(1)
  ROOT add.7 = f32[] add(Arg_0.5, Arg_1.6)
}

region_1.8 {
  Arg_0.9 = f32[] parameter(0)
  Arg_1.10 = f32[] parameter(1)
  ROOT add.11 = f32[] add(Arg_0.9, Arg_1.10)
}

region_2.12 {
  Arg_0.13 = f32[] parameter(0)
  Arg_1.14 = f32[] parameter(1)
  ROOT add.15 = f32[] add(Arg_0.13, Arg_1.14)
}

jax_pieces_lola.16 {
  Arg_0.17 = f32[8192]{0} parameter(0)
  Arg_1.18 = f32[8192,32768]{1,0} parameter(1)
  dot.31 = f32[32768]{0} dot(Arg_0.17, Arg_1.18), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  constant.30 = f32[] constant(0)
  reduce.33 = f32[] reduce(Arg_0.17, constant.30), dimensions={0}, to_apply=region_1.8
  constant.29 = f32[] constant(8192)
  divide.34 = f32[] divide(reduce.33, constant.29)
  broadcast.38 = f32[32768]{0} broadcast(divide.34), dimensions={}
  redu

In [None]:
jax.make_jaxpr(jit_jpl)(x, kernel, bias)

{ lambda ; a:f32[8192] b:f32[8192,32768] c:f32[32768]. let
    d:f32[32768] = xla_call[
      call_jaxpr={ lambda ; e:f32[8192] f:f32[8192,32768] g:f32[32768]. let
          h:f32[32768] = dot_general[dimension_numbers=(([0], [0]), ([], []))] e
            f
          i:f32[32768] = reduce_sum[axes=(0,)] f
          j:f32[] = reduce_sum[axes=(0,)] e
          k:f32[] = div j 8192.0
          l:f32[8192] = mul e e
          m:f32[] = reduce_sum[axes=(0,)] l
          n:f32[] = div m 8192.0
          o:f32[32768] = mul k i
          p:f32[32768] = sub h o
          q:f32[32768] = add p g
          r:f32[] = mul k k
          s:f32[] = sub n r
          t:f32[] = add s 9.999999747378752e-06
          u:f32[] = abs t
          v:f32[] = sqrt u
          w:f32[32768] = div q v
          x:f32[32768] = mul 0.5 w
          y:f32[] = sqrt 0.6366197723675814
          z:f32[32768] = integer_pow[y=3] w
          ba:f32[32768] = mul 0.044714998453855515 z
          bb:f32[32768] = add w ba
      

In [None]:
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
  jit_jpl(x, kernel, bias).block_until_ready()

2023-02-19 02:08:13.529812: E external/org_tensorflow/tensorflow/compiler/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
2023-02-19 02:08:13.556552: E external/org_tensorflow/tensorflow/compiler/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace


Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz


127.0.0.1 - - [19/Feb/2023 02:08:19] code 404, message File not found
127.0.0.1 - - [19/Feb/2023 02:08:19] "POST /status HTTP/1.1" 404 -
127.0.0.1 - - [19/Feb/2023 02:08:19] "GET /perfetto_trace.json.gz HTTP/1.1" 200 -


In [None]:
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
  triton_lola(tx, tkernel, tbias, N_OCOLS=32, BLOCK_ROWS=512)

2023-02-19 02:09:01.033134: E external/org_tensorflow/tensorflow/compiler/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
2023-02-19 02:09:01.058523: E external/org_tensorflow/tensorflow/compiler/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace


Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz


127.0.0.1 - - [19/Feb/2023 02:09:04] code 404, message File not found
127.0.0.1 - - [19/Feb/2023 02:09:04] "POST /status HTTP/1.1" 404 -
127.0.0.1 - - [19/Feb/2023 02:09:05] "GET /perfetto_trace.json.gz HTTP/1.1" 200 -


In [None]:
fl_apply = jax.jit(FlaxLola(features = 4 * n_embd).apply)
fl_apply(params, x)

Array([-0.12473921,  0.9696205 ,  1.3021432 , ...,  0.02534572,
       -0.15674949,  0.5233645 ], dtype=float32)

In [None]:
%%timeit -n 1000

fl_apply(params, x)

70.2 µs ± 4.09 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
  fl_apply(params, x).block_until_ready()

RuntimeError: Profile has already been started. Only one profile may be run at a time.