In [62]:
import math
from functools import partial

import torch
import torch.nn.functional as F

import triton
import triton.language as tl
from triton.testing import do_bench

In [2]:
%load_ext autoreload
%autoreload 2

In [33]:
import conch
from conch import float_from_hex, PTXAnalyser, grid_search
from conch import extra_ops as co

In [4]:
float_from_hex("0f3F19999A")

0.6000000238418579

In [85]:
n_embd = 1024
n_btch = 512

torch.manual_seed(0)
x = torch.randn(n_btch, n_embd, dtype=torch.float16, device="cuda")
weights = torch.randn(n_embd, 4 * n_embd, dtype=torch.float16, device="cuda")
weightsT = weights.T.contiguous()
bias = torch.randn(4 * n_embd, dtype=torch.float16, device="cuda")

In [6]:
def torch_lola(x, weight, bias):
    x = F.layer_norm(x, (n_embd,))
    x = F.linear(x, weight, bias)
    x = F.gelu(x, approximate="tanh")

    return x

In [89]:
%%timeit -n1000

torch_lola(x, weightsT, bias)
torch.cuda.synchronize()

119 µs ± 2.14 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
torch_y = torch_lola(x, weights.T, bias)
torch_y.shape

torch.Size([512, 4096])

In [8]:
acc_dtype = tl.float16

@triton.jit
def lola_kernel(x_ptr, W_ptr, Ws_ptr, b_ptr, out_ptr,
                             N_OCOLS: tl.constexpr, N_BROWS: tl.constexpr,
                             BLOCK_ROWS: tl.constexpr, N_FEAT_IN: tl.constexpr,
                             N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes 
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN,] - current token embedding.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    Ws_ptr: [N_FEAT_OUT,] - sum of the weights (in axis 0).
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    x_brows_start = tl.program_id(0) * N_BROWS
    # This instance will process x[b_rows_start:b_rows_start + N_BROWS, :]
    x_brows_idxs = tl.arange(0, N_BROWS) + x_brows_start

    ocols_start = tl.program_id(1) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=acc_dtype)
    # w_sum_acc = tl.zeros((N_OCOLS, ), dtype=acc_dtype)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)

    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

        # Load the current block of the input.
        x_block_idxs = x_brows_idxs[:, None] * N_FEAT_IN + block_row_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs).to(acc_dtype) # [N_BROWS, BLOCK_ROWS]

        W_block_idxs = block_row_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs).to(acc_dtype) # [BLOCK_ROWS, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_ROWS] @ [BLOCK_ROWS, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block).to(acc_dtype)
        # w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0)
        # w_sum_acc += tl.sum(W_block, axis=0)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

    bias = tl.load(b_ptr + col_idxs)
    Wsum = tl.load(Ws_ptr + col_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * Wsum[None, :] + bias[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = co.fast_gelu(numer / denom[:, None])

    out_idxs = x_brows_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
    tl.store(out_ptr + out_idxs, out.to(tl.float16))

In [9]:
def triton_lola(
    x,
    W,
    Ws,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_ROWS: int,
    num_warps=4,
    num_stages=1,
):
    assert N_BROWS >= 16 and BLOCK_ROWS >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    N_BATCH = x.shape[0]
    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT), dtype=x.dtype, device="cuda")

    # Launch the kernel.
    lola_kernel[grid](x,
                      W,
                      Ws,
                      b,
                      out,
                      N_OCOLS=N_OCOLS,
                      N_BROWS=N_BROWS,
                      BLOCK_ROWS=BLOCK_ROWS,
                      N_FEAT_IN=N_FEAT_IN,
                      N_FEAT_OUT=N_FEAT_OUT,
                      num_warps=num_warps,
                      num_stages=num_stages)

    return out

In [10]:
triton_y = triton_lola(x, weights, weights.sum(axis=0), bias, 16, 16, 16)

In [11]:
(torch_y - triton_y).abs().max()

tensor(0.5000, device='cuda:0', dtype=torch.float16)

In [60]:
grid_search(partial(triton_lola, x, weights, weights.sum(axis=0), bias),
            # do_print = True,
            min_val_prod = 100_000,
            N_OCOLS=(16, 512),
            N_BROWS=(16, 512),
            BLOCK_ROWS=(16, 512))[:5]


[({'N_OCOLS': 128, 'N_BROWS': 64, 'BLOCK_ROWS': 64}, 78.84799689054489),
 ({'N_OCOLS': 256, 'N_BROWS': 32, 'BLOCK_ROWS': 64}, 84.99199897050858),
 ({'N_OCOLS': 64, 'N_BROWS': 128, 'BLOCK_ROWS': 64}, 89.08800035715103),
 ({'N_OCOLS': 128, 'N_BROWS': 32, 'BLOCK_ROWS': 32}, 90.11200070381165),
 ({'N_OCOLS': 128, 'N_BROWS': 64, 'BLOCK_ROWS': 32}, 92.16000139713287)]

In [65]:
do_bench(
    partial(triton_lola, x, weights, weights.sum(axis=0), bias,
    N_OCOLS=128,
    N_BROWS=64,
    BLOCK_ROWS=64), warmup=1000, rep=1000)[0] * 1000

78.84799689054489

In [73]:
meta_params = dict(N_OCOLS=128,
    N_BROWS=64,
    BLOCK_ROWS=64)

In [81]:
pa = PTXAnalyser.FromKernel(lola_kernel, **meta_params)

In [82]:
pa.op_counts.most_common()

[('mov.b32', 455),
 ('mul.f32', 450),
 ('fma.rn.ftz.f32', 390),
 ('add.f32', 197),
 ('mov.b16', 157),
 ('add.f16', 138),
 ('fma.rn.f32', 132),
 ('bra.uni', 130),
 ('cvt.rn.f16.f32', 128),
 ('cvt.f32.f16', 104),
 ('or.b32', 94),
 ('add.s32', 92),
 ('and.b32', 78),
 ('mov.f32', 78),
 ('ld.shared.b16', 78),
 ('div.full.f32', 72),
 ('setp.ge.f32', 69),
 ('selp.f32', 69),
 ('abs.ftz.f32', 65),
 ('setp.ltu.f32', 65),
 ('ex2.approx.ftz.f32', 65),
 ('rcp.approx.ftz.f32', 65),
 ('mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32', 64),
 ('st.shared.v2.f32', 64),
 ('shl.b32', 50),
 ('st.shared.u16', 34),
 ('ld.shared.v4.f32', 32),
 ('fma.rn.f16', 31),
 ('cvt.u16.u32', 29),
 ('bar.sync', 28),
 ('add.s64', 22),
 ('@%p154', 22),
 ('mul.wide.s32', 21),
 ('ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16', 16),
 ('st.shared.v4.b32', 12),
 ('cvt.u32.u16', 12),
 ('shfl.sync.bfly.b32', 12),
 ('xor.b32', 10),
 ('ldmatrix.sync.aligned.m8n8.x4.shared.b16', 8),
 ('st.shared.b16', 7),
 ('mov.u32', 6),
 ('shr.u