# Lola L. (Layer nOrm + Linear + Activation + Linear)

Based on `lola.ipynb`, implement batched float16 lola, with an additional linear and stack these layers to simulate the amount of computation performed by GPT.

It looks like batching is required for Triton to beat jax/pytorch.

In [2]:
import os

os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'

In [3]:
# Simulate gpt-medium
n_embd = 1024
n_layer = 24

n_btch = 1024

# Flax Reference

In [4]:
import flax.linen as nn
import jax.numpy as jnp
import jax

from nimblegpt import param_shapes

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (n_btch, n_embd), dtype=jnp.float16)

In [6]:
def GELU(x):
    return 0.5 * x * (1.0 +
                      jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))


class FlaxLolal(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        *_, n_embd = x.shape

        for _ in range(n_layer):
            x = nn.LayerNorm(use_bias=False, use_scale=False)(x)
            x = nn.Dense(self.features)(x)
            x = GELU(x)
            x = nn.Dense(n_embd)(x)

        return x

In [7]:
fl_module = FlaxLolal(features=4 * n_embd)
params = fl_module.init(key, x)
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float16), params)

fl_apply = jax.jit(fl_module.apply)
fy = fl_apply(params, x)

In [8]:
%%timeit -n 100

fl_apply(params, x).block_until_ready()

3.77 ms ± 205 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
# param_shapes(params)

# Pytorch reference

In [10]:
import torch.nn.functional as F
import torch
import numpy as np

In [11]:
tkernels = [
    torch.tensor(np.array(params["params"][f"Dense_{i}"]["kernel"]),
                 device="cuda") for i in range(2 * n_layer)
]
tweights = [tk.T for tk in tkernels]
tbias = [
    torch.tensor(np.array(params["params"][f"Dense_{i}"]["bias"]),
                 device="cuda") for i in range(2 * n_layer)
]

tkernel1s = tkernels[0::2]
tkernel2s = tkernels[1::2]
tweight1s = tweights[0::2]
tweight2s = tweights[1::2]
tbias1s = tbias[0::2]
tbias2s = tbias[1::2]

tx = torch.tensor(np.array(x), device="cuda")

In [12]:
def torch_lolal(x, weight1s, bias1s, weight2s, bias2s):

    for w1, b1, w2, b2 in zip(weight1s, bias1s, weight2s, bias2s):
        x = F.layer_norm(x, (n_embd, ))
        x = F.linear(x, w1, b1)
        x = F.gelu(x)
        x = F.linear(x, w2, b2)

    return x

In [13]:
ty = torch_lolal(tx, tweight1s, tbias1s, tweight2s, tbias2s)
ty.shape

torch.Size([1024, 1024])

In [14]:
(fy - ty.cpu().numpy()).max()

Array(0.01904, dtype=float16)

In [15]:
%%timeit -n 100

torch_lolal(tx, tweight1s, tbias1s, tweight2s, tbias2s)
torch.cuda.synchronize()

3.04 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Jax Incremental Reference

In [16]:
from jax import lax

# Triton Batched Multi-row

In [17]:
import triton
import triton.language as tl

In [18]:
import math

sqrt2pi = math.sqrt(2.0 / math.pi)

@triton.jit
def tanh(x):
    """Tanh activation function"""
    return tl.libdevice.tanh(x)

@triton.jit
def fast_gelu(x):
    """Fast approximation of the gelu function. May slightly decrease accuracy."""
    return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x)))

In [62]:
acc_dtype = tl.float16


@triton.jit
def lola_kernel(x_ptr, W_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                N_BROWS: tl.constexpr, BLOCK_ROWS: tl.constexpr,
                N_FEAT_IN: tl.constexpr, N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes 
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN,] - current token embedding.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    x_brows_start = tl.program_id(0) * N_BROWS
    # This instance will process x[b_rows_start:b_rows_start + N_BROWS, :]
    x_brows_idxs = tl.arange(0, N_BROWS) + x_brows_start

    ocols_start = tl.program_id(1) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=acc_dtype)
    w_sum_acc = tl.zeros((N_OCOLS, ), dtype=acc_dtype)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)

    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

        # Load the current block of the input.
        x_block_idxs = x_brows_idxs[:,
                                    None] * N_FEAT_IN + block_row_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs).to(
            acc_dtype)  # [N_BROWS, BLOCK_ROWS]

        W_block_idxs = block_row_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs).to(
            acc_dtype)  # [BLOCK_ROWS, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_ROWS] @ [BLOCK_ROWS, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block).to(acc_dtype)
        # w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0)
        w_sum_acc += tl.sum(W_block, axis=0)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

    bias = tl.load(b_ptr + col_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * w_sum_acc[None, :] + bias[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = fast_gelu(numer / denom[:, None])

    out_idxs = x_brows_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
    tl.store(out_ptr + out_idxs, out.to(tl.float16))

In [63]:
def triton_lola(
    x,
    W,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_ROWS: int,
    num_warps=4,
    num_stages=1,
):
    assert N_BROWS >= 16 and BLOCK_ROWS >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    N_BATCH = x.shape[0]
    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT), dtype=x.dtype, device="cuda")

    # Launch the kernel.
    lola_kernel[grid](x,
                      W,
                      b,
                      out,
                      N_OCOLS=N_OCOLS,
                      N_BROWS=N_BROWS,
                      BLOCK_ROWS=BLOCK_ROWS,
                      N_FEAT_IN=N_FEAT_IN,
                      N_FEAT_OUT=N_FEAT_OUT,
                      num_warps=num_warps,
                      num_stages=num_stages)

    return out

In [64]:
def triton_lolal(x, kernel1s, bias1s, weight2s, bias2s, **kernel_kwargs):

    for k1, b1, w2, b2 in zip(kernel1s, bias1s, weight2s, bias2s):
        x = triton_lola(x, k1, b1, **kernel_kwargs)
        x = F.linear(x, w2, b2)
    return x

In [65]:
ky = triton_lolal(tx, tkernel1s, tbias1s, tweight2s, tbias2s, N_OCOLS=16, N_BROWS=16, BLOCK_ROWS=16, num_warps=4, num_stages=1)
ky.shape

torch.Size([1024, 1024])

In [66]:
(ty - ky).abs().max()

tensor(0.0426, device='cuda:0', dtype=torch.float16)

: 

# Triton Batched Multi-row Transposed

We test transposing the weights matrix to see if it improves performance.

In [48]:
acc_dtype = tl.float16


@triton.jit
def lola_trans_kernel(x_ptr, Wt_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                N_BROWS: tl.constexpr, BLOCK_ROWS: tl.constexpr,
                N_FEAT_IN: tl.constexpr, N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes 
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN,] - current token embedding.
    Wt_ptr: [N_FEAT_OUT, N_FEAT_IN] - linear layer weights (transposed).
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    x_brows_start = tl.program_id(0) * N_BROWS
    # This instance will process x[b_rows_start:b_rows_start + N_BROWS, :]
    x_brows_idxs = tl.arange(0, N_BROWS) + x_brows_start

    ocols_start = tl.program_id(1) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=acc_dtype)
    w_sum_acc = tl.zeros((N_OCOLS, ), dtype=acc_dtype)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)

    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

        # Load the current block of the input.
        x_block_idxs = x_brows_idxs[:,
                                    None] * N_FEAT_IN + block_row_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs).to(
            acc_dtype)  # [N_BROWS, BLOCK_ROWS]

        Wt_block_idxs = col_idxs[:, None] * N_FEAT_IN + block_row_idxs[None, :]
        Wt_block = tl.load(Wt_ptr + Wt_block_idxs).to(acc_dtype)
        # W_block = tl.trans(Wt_block)

        print(x_block, Wt_block)
        # Update the accumulators.
        # [N_BROWS, BLOCK_ROWS] @ [BLOCK_ROWS, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, tl.trans(Wt_block)).to(acc_dtype)
        # w_dot_x_acc += tl.dot(Wt_block, tl.trans(x_block)).to(acc_dtype)
        # w_dot_x_acc += tl.dot(x_block, W_block).to(acc_dtype)
        # w_dot_x_acc += tl.dot(x_block, tl.trans(Wt_block)).to(acc_dtype)
        # w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0)
        # w_sum_acc += tl.sum(W_block, axis=0)
        w_sum_acc += tl.sum(Wt_block, axis=1)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

    bias = tl.load(b_ptr + col_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * w_sum_acc[None, :] + bias[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = fast_gelu(numer / denom[:, None])

    out_idxs = x_brows_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
    tl.store(out_ptr + out_idxs, out.to(tl.float16))

In [49]:
def triton_lola_trans(
    x,
    W_trans,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_ROWS: int,
    num_warps=4,
    num_stages=1,
):
    assert N_BROWS >= 16 and BLOCK_ROWS >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    N_BATCH = x.shape[0]
    N_FEAT_OUT, N_FEAT_IN = W_trans.shape
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT), dtype=x.dtype, device="cuda")

    # Launch the kernel.
    lola_trans_kernel[grid](x,
                      W_trans,
                      b,
                      out,
                      N_OCOLS=N_OCOLS,
                      N_BROWS=N_BROWS,
                      BLOCK_ROWS=BLOCK_ROWS,
                      N_FEAT_IN=N_FEAT_IN,
                      N_FEAT_OUT=N_FEAT_OUT,
                      num_warps=num_warps,
                      num_stages=num_stages)

    return out

In [26]:
def triton_lolal_trans(x, weight1s, bias1s, weight2s, bias2s, **kernel_kwargs):

    for w1, b1, w2, b2 in zip(weight1s, bias1s, weight2s, bias2s):
        x = triton_lola_trans(x, w1, b1, **kernel_kwargs)
        x = F.linear(x, w2, b2)
    return x

In [27]:
tky = triton_lolal_trans(tx, tweight1s, tbias1s, tweight2s, tbias2s, N_OCOLS=16, N_BROWS=16, BLOCK_ROWS=16, num_warps=4, num_stages=1)
tky.shape

fp16[constexpr[16],constexpr[16]] fp16[constexpr[16],constexpr[16]]
fp16[constexpr[16],constexpr[16]] fp16[constexpr[16],constexpr[16]]


torch.Size([1024, 1024])

In [28]:
(ty - tky).abs().max()

tensor(3.7734, device='cuda:0', dtype=torch.float16)

## Debugging

In [29]:
w1 = tweight1s[0]
w1.shape

torch.Size([4096, 1024])

In [52]:
w1 = torch.zeros_like(w1)
w1[:, 1] = 1
w1

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16)

In [50]:
kout = triton_lola_trans(tx, tweight1s[0], tbias1s[0], N_OCOLS=16, N_BROWS=32, BLOCK_ROWS=64,num_warps=4, num_stages=1)
kout

tensor([[-0.0402,  0.2347, -0.0448,  ...,  0.0795,  0.4070, -0.1065],
        [ 0.1570, -0.1503,  1.0068,  ...,  0.2034, -0.0872, -0.1332],
        [ 0.5347, -0.0126,  0.0631,  ...,  0.7222,  0.5054,  1.6523],
        ...,
        [ 0.6846, -0.1473,  2.4023,  ...,  0.0865,  0.6001,  0.1892],
        [ 0.0549, -0.1543,  0.9678,  ...,  1.3838,  1.2451,  0.4380],
        [-0.1488,  1.0879,  0.0668,  ..., -0.1276,  0.5479, -0.1592]],
       device='cuda:0', dtype=torch.float16)

In [53]:
kout = triton_lola_trans(tx, w1, tbias1s[0], N_OCOLS=16, N_BROWS=32, BLOCK_ROWS=64,num_warps=4, num_stages=1)
kout

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16)

In [61]:
kout[:30, 2:10]

tensor([[ 0.0000,  0.0000, -0.0153, -0.0153, -0.0153, -0.0153,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0165,  0.0165,  0.0165,  0.0165,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0331,  0.0331,  0.0331,  0.0331,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0412,  0.0412,  0.0412,  0.0412,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0158, -0.0158, -0.0158, -0.0158,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0078, -0.0078, -0.0078, -0.0078,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0077,  0.0077,  0.0077,  0.0077,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0079,  0.0079,  0.0079,  0.0079,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0079, -0.0079, -0.0079, -0.0079,  0.0000,  0.0000],
        [ 0.0000

In [61]:
(kout[:, 0].cpu() - np.array(out[:, 0])).abs().max()

tensor(0.0039, dtype=torch.float16)

In [None]:
kout.sum()

tensor(1205., device='cuda:0', dtype=torch.float16)

In [34]:
k1 = params["params"]["Dense_0"]["kernel"]
b1 = params["params"]["Dense_0"]["bias"]

In [35]:
w_dot_x = jnp.dot(x, k1)
w_sum = jnp.sum(k1, axis = 0)
x_mean = jnp.mean(x, axis=1)
x_sq_mean = jnp.mean(x * x, axis=1)

numer = w_dot_x - jnp.expand_dims(x_mean, axis=1) * w_sum + b1
denom = jnp.sqrt(jnp.abs(x_sq_mean - x_mean * x_mean + 1e-5))

out = GELU(numer / jnp.expand_dims(denom, axis=1)) # [BATCH_SIZE, N_FEAT_OUT]

In [36]:
out

Array([[ 1.194  , -0.1282 , -0.1147 , ..., -0.1663 , -0.05746, -0.1678 ],
       [-0.1307 ,  0.00922,  0.1714 , ..., -0.1691 , -0.1509 ,  0.1203 ],
       [ 0.5483 ,  0.06   ,  0.02933, ...,  0.11194, -0.08325, -0.1069 ],
       ...,
       [ 1.316  , -0.0973 ,  0.3074 , ..., -0.0041 ,  1.623  ,  0.4402 ],
       [ 1.132  ,  0.443  , -0.1521 , ..., -0.10846,  0.2368 , -0.0942 ],
       [-0.1692 ,  0.181  ,  0.531  , ..., -0.1691 ,  0.595  , -0.1252 ]],      dtype=float16)

In [None]:
w_dot_x

Array([[ 1.305  , -1.251  , -0.2788 , ..., -0.8896 , -1.838  , -0.904  ],
       [-0.35   , -0.0188 ,  0.2534 , ..., -0.819  , -1.114  ,  0.2705 ],
       [ 0.7075 ,  0.07965,  0.03482, ...,  0.1906 , -0.2156 , -1.359  ],
       ...,
       [ 1.442  , -0.2241 ,  0.477  , ..., -0.01009,  1.745  ,  0.58   ],
       [ 1.325  ,  0.6284 , -1.131  , ..., -1.504  ,  0.3782 , -1.606  ],
       [-0.815  ,  0.281  ,  0.6904 , ..., -0.6914 ,  0.7534 , -1.282  ]],      dtype=float16)

In [None]:
w_dot_x.sum()

Array(-1078., dtype=float16)

In [None]:
(np.array(out) - triton_lola(tx, tkernel1s[0], tbias1s[0], N_OCOLS=16, N_BROWS=16, BLOCK_ROWS=16, num_warps=4, num_stages=1).cpu().numpy()).max()

0.01758

In [None]:
(np.array(out) - kout.cpu().numpy()).max()

4.938

In [50]:
kout = triton_lola_trans(tx, tkernel1s[0], tbias1s[0], N_OCOLS=16, N_BROWS=32, BLOCK_ROWS=64,num_warps=4, num_stages=1)
kout.shape

torch.Size([1024, 1024])

In [52]:
kout

tensor([[-0.2438,  1.3438,  2.8867,  ..., -2.0781, -0.3979, -0.1449],
        [-0.4087,  2.7031,  1.5830,  ..., -0.8726,  0.0990,  1.9658],
        [ 0.2698,  1.9844, -1.7441,  ...,  1.7529,  2.0352, -1.2568],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', dtype=torch.float16)

In [51]:
w_dot_x

Array([[ 1.305  , -1.251  , -0.2788 , ..., -0.8896 , -1.838  , -0.904  ],
       [-0.35   , -0.0188 ,  0.2534 , ..., -0.819  , -1.114  ,  0.2705 ],
       [ 0.7075 ,  0.07965,  0.03482, ...,  0.1906 , -0.2156 , -1.359  ],
       ...,
       [ 1.442  , -0.2241 ,  0.477  , ..., -0.01009,  1.745  ,  0.58   ],
       [ 1.325  ,  0.6284 , -1.131  , ..., -1.504  ,  0.3782 , -1.606  ],
       [-0.815  ,  0.281  ,  0.6904 , ..., -0.6914 ,  0.7534 , -1.282  ]],      dtype=float16)

## Triton Benchmarking

In [33]:
from triton.testing import do_bench
from functools import partial

In [34]:
func = partial(triton_lolal, tx, tkernel1s, tbias1s, tweight2s, tbias2s)

for n_ocols in [16, 32, 64, 128, 256, 512]:
    for n_brows in [16, 32, 64, 128, 256, 512]:
        for block_rows in [16, 32, 64, 128, 256, 512]:
            if n_ocols * n_brows * block_rows < 200_000:
                continue
            if n_ocols * n_brows * block_rows / 4 > 166_000:
                continue
            print(f"{n_ocols=}, {n_brows=}, {block_rows=}", end=": ")
            print(f"{do_bench(partial(func, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows), warmup = 100, rep = 100)[0]:.2f} ms")

n_ocols=16, n_brows=32, block_rows=512: 29713.41 us
n_ocols=16, n_brows=64, block_rows=256: 22185.47 us
n_ocols=16, n_brows=64, block_rows=512: 470851.59 us
n_ocols=16, n_brows=128, block_rows=128: 18547.71 us
n_ocols=16, n_brows=128, block_rows=256: 41585.66 us
n_ocols=16, n_brows=256, block_rows=64: 17542.14 us
n_ocols=16, n_brows=256, block_rows=128: 25800.70 us
n_ocols=16, n_brows=512, block_rows=32: 14923.26 us
n_ocols=16, n_brows=512, block_rows=64: 18748.93 us
n_ocols=32, n_brows=16, block_rows=512: 21115.90 us
n_ocols=32, n_brows=32, block_rows=256: 16519.68 us
n_ocols=32, n_brows=32, block_rows=512: 17985.54 us
n_ocols=32, n_brows=64, block_rows=128: 12158.46 us
n_ocols=32, n_brows=64, block_rows=256: 12904.45 us
n_ocols=32, n_brows=128, block_rows=64: 11728.38 us
n_ocols=32, n_brows=128, block_rows=128: 10908.67 us
n_ocols=32, n_brows=256, block_rows=32: 12298.75 us
n_ocols=32, n_brows=256, block_rows=64: 10683.39 us
n_ocols=32, n_brows=512, block_rows=16: 15394.82 us
n_ocols

In [35]:
func = partial(triton_lolal, tx, tkernel1s, tbias1s, tweight2s, tbias2s)

for n_ocols in [64, 128, 256, 512]:
    for n_brows in [16, 32, 64]:
        for block_rows in [16, 32, 64, 128, 256]:
            if n_ocols * n_brows * block_rows / 4 > 166_000:
                continue
            print(f"{n_ocols=}, {n_brows=}, {block_rows=}", end=": ")
            print(f"{do_bench(partial(func, N_OCOLS=n_ocols, N_BROWS=n_brows, BLOCK_ROWS=block_rows), warmup = 100, rep = 100)[0]:.2f} ms")

n_ocols=64, n_brows=16, block_rows=16: 71.27 ms
n_ocols=64, n_brows=16, block_rows=32: 38.81 ms
n_ocols=64, n_brows=16, block_rows=64: 21.98 ms
n_ocols=64, n_brows=16, block_rows=128: 14.59 ms
n_ocols=64, n_brows=16, block_rows=256: 11.91 ms
n_ocols=64, n_brows=32, block_rows=16: 38.95 ms
n_ocols=64, n_brows=32, block_rows=32: 20.96 ms
n_ocols=64, n_brows=32, block_rows=64: 12.78 ms
n_ocols=64, n_brows=32, block_rows=128: 9.32 ms
n_ocols=64, n_brows=32, block_rows=256: 9.79 ms
n_ocols=64, n_brows=64, block_rows=16: 21.61 ms
n_ocols=64, n_brows=64, block_rows=32: 14.17 ms
n_ocols=64, n_brows=64, block_rows=64: 8.78 ms
n_ocols=64, n_brows=64, block_rows=128: 8.10 ms
n_ocols=128, n_brows=16, block_rows=16: 31.89 ms
n_ocols=128, n_brows=16, block_rows=32: 19.09 ms
n_ocols=128, n_brows=16, block_rows=64: 11.65 ms
n_ocols=128, n_brows=16, block_rows=128: 9.48 ms
n_ocols=128, n_brows=16, block_rows=256: 9.35 ms
n_ocols=128, n_brows=32, block_rows=16: 18.41 ms
n_ocols=128, n_brows=32, block_ro