In [1]:
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
n_embd = 1024
n_btch = 512

dtype = torch.float16
x = torch.empty(n_btch, n_embd, dtype=dtype, device='cuda').normal_()

weights = torch.empty(n_embd, 4 * n_embd, dtype=dtype, device='cuda').normal_()
bias = torch.empty(4 * n_embd, dtype=dtype, device='cuda').normal_()
weights_sum = torch.sum(weights, dim=0)
weights_sum.shape

torch.Size([4096])

In [3]:
def torch_lola(x, weights, bias):
    x = F.layer_norm(x, (n_embd, ))
    x = F.linear(x, weights.T, bias)
    x = F.gelu(x, approximate="tanh")
    return x

In [4]:
ref_out = torch_lola(x, weights, bias)
ref_out.shape

torch.Size([512, 4096])

In [5]:
import math

sqrt2pi = math.sqrt(2.0 / math.pi)


@triton.jit
def tanh(x):
    """Tanh activation function"""
    return tl.libdevice.tanh(x)


@triton.jit
def fast_gelu(x):
    """Fast approximation of the gelu function. May slightly decrease accuracy."""
    return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x)))

In [6]:
acc_dtype = tl.float16


@triton.jit
def lola_kernel(x_ptr, W_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                N_BROWS: tl.constexpr, BLOCK_ROWS: tl.constexpr,
                N_FEAT_IN: tl.constexpr, N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes 
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN,] - current token embedding.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    x_brows_start = tl.program_id(0) * N_BROWS
    # This instance will process x[b_rows_start:b_rows_start + N_BROWS, :]
    x_brows_idxs = tl.arange(0, N_BROWS) + x_brows_start

    ocols_start = tl.program_id(1) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=acc_dtype)
    w_sum_acc = tl.zeros((N_OCOLS, ), dtype=acc_dtype)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)

    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

        # Load the current block of the input.
        x_block_idxs = x_brows_idxs[:,
                                    None] * N_FEAT_IN + block_row_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs).to(
            acc_dtype)  # [N_BROWS, BLOCK_ROWS]

        W_block_idxs = block_row_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs).to(
            acc_dtype)  # [BLOCK_ROWS, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_ROWS] @ [BLOCK_ROWS, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block).to(acc_dtype)
        # w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0)
        # w_sum_acc += tl.sum(W_block, axis=0)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)


    bias = tl.load(b_ptr + col_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * w_sum_acc[None, :] + bias[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = fast_gelu(numer / denom[:, None])

    out_idxs = x_brows_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
    tl.store(out_ptr + out_idxs, out.to(tl.float16))

In [7]:
def triton_lola(
    x,
    W,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_ROWS: int,
    num_warps=4,
    num_stages=1,
):
    assert N_BROWS >= 16 and BLOCK_ROWS >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    N_BATCH = x.shape[0]
    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT), dtype=x.dtype, device="cuda")

    # Launch the kernel.
    lola_kernel[grid](x,
                      W,
                      b,
                      out,
                      N_OCOLS=N_OCOLS,
                      N_BROWS=N_BROWS,
                      BLOCK_ROWS=BLOCK_ROWS,
                      N_FEAT_IN=N_FEAT_IN,
                      N_FEAT_OUT=N_FEAT_OUT,
                      num_warps=num_warps,
                      num_stages=num_stages)

    return out

In [8]:
tl_out = triton_lola(x,
                     weights,
                     bias,
                     N_OCOLS=32,
                     N_BROWS=32,
                     BLOCK_ROWS=32,
                     num_warps=4,
                     num_stages=1)
tl_out.shape

torch.Size([512, 4096])

In [9]:
(ref_out - tl_out).abs().max()

tensor(10.1875, device='cuda:0', dtype=torch.float16)

## Static w_sum


In [10]:
acc_dtype = tl.float16


@triton.jit
def lola_ws_kernel(x_ptr, W_ptr, Ws_ptr, b_ptr, out_ptr, N_OCOLS: tl.constexpr,
                N_BROWS: tl.constexpr, BLOCK_ROWS: tl.constexpr,
                N_FEAT_IN: tl.constexpr, N_FEAT_OUT: tl.constexpr):
    """
    Triton kernel implementing fused Layer nOrm, Linear and Activation.

    Kernel cell (i, j) computes 
    out[i * N_BROWS: (i+1) * N_BROWS, j * N_OCOLS: (j+1) * N_OCOLS], by iterating over
    `BLOCK_ROWS`-sized blocks of the inputs.

    Inputs
    ------
    x_ptr: [BATCH_SIZE, N_FEAT_IN,] - current token embedding.
    W_ptr: [N_FEAT_IN, N_FEAT_OUT] - linear layer weights.
    Ws_ptr: [N_FEAT_OUT,] - sum of the weights (in axis 0).
    b_ptr: [N_FEAT_OUT,] - linear layer bias.

    Outputs
    -------
    out_ptr: [N_FEAT_OUT,] - output of the fused layer.
    """
    x_brows_start = tl.program_id(0) * N_BROWS
    # This instance will process x[b_rows_start:b_rows_start + N_BROWS, :]
    x_brows_idxs = tl.arange(0, N_BROWS) + x_brows_start

    ocols_start = tl.program_id(1) * N_OCOLS
    col_idxs = tl.arange(0, N_OCOLS) + ocols_start

    w_dot_x_acc = tl.zeros((N_BROWS, N_OCOLS), dtype=acc_dtype)
    # w_sum_acc = tl.zeros((N_OCOLS, ), dtype=acc_dtype)
    x_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)
    x_sq_sum_acc = tl.zeros((N_BROWS, ), dtype=acc_dtype)

    n_blocks = tl.cdiv(N_FEAT_IN, BLOCK_ROWS)
    for block_i in range(0, n_blocks):

        block_row_idxs = tl.arange(0, BLOCK_ROWS) + block_i * BLOCK_ROWS

        # Load the current block of the input.
        x_block_idxs = x_brows_idxs[:,
                                    None] * N_FEAT_IN + block_row_idxs[None, :]
        x_block = tl.load(x_ptr + x_block_idxs).to(
            acc_dtype)  # [N_BROWS, BLOCK_ROWS]

        W_block_idxs = block_row_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
        W_block = tl.load(W_ptr + W_block_idxs).to(
            acc_dtype)  # [BLOCK_ROWS, N_OCOLS]

        # Update the accumulators.
        # [N_BROWS, BLOCK_ROWS] @ [BLOCK_ROWS, N_OCOLS] -> [N_BROWS, N_OCOLS]
        w_dot_x_acc += tl.dot(x_block, W_block).to(acc_dtype)
        # w_dot_x_acc += tl.sum(W_block * x_block[:, None], axis=0)
        # w_sum_acc += tl.sum(W_block, axis=0)
        x_sum_acc += tl.sum(x_block, axis=1)
        x_sq_sum_acc += tl.sum(x_block * x_block, axis=1)

    bias = tl.load(b_ptr + col_idxs)
    Wsum = tl.load(Ws_ptr + col_idxs)
    x_mean = x_sum_acc / N_FEAT_IN
    x_sq_mean = x_sq_sum_acc / N_FEAT_IN

    numer = w_dot_x_acc - x_mean[:, None] * Wsum[None, :] + bias[None, :]
    denom = tl.sqrt(tl.abs(x_sq_mean - x_mean * x_mean + 1e-5))
    out = fast_gelu(numer / denom[:, None])

    out_idxs = x_brows_idxs[:, None] * N_FEAT_OUT + col_idxs[None, :]
    tl.store(out_ptr + out_idxs, out.to(tl.float16))

In [13]:
def triton_lola_ws(
    x,
    W,
    Ws,
    b,
    N_OCOLS: int,
    N_BROWS: int,
    BLOCK_ROWS: int,
    num_warps=4,
    num_stages=1,
):
    assert N_BROWS >= 16 and BLOCK_ROWS >= 16 and N_OCOLS >= 16, "Triton matrix multiplication requires matrix dimensions to be at least 16."

    N_BATCH = x.shape[0]
    N_FEAT_IN, N_FEAT_OUT = W.shape
    grid = (
        N_BATCH // N_BROWS,
        N_FEAT_OUT // N_OCOLS,
    )

    # Allocate output buffer.
    out = torch.zeros((N_BATCH, N_FEAT_OUT), dtype=x.dtype, device="cuda")

    # Launch the kernel.
    lola_ws_kernel[grid](x,
                      W,
                      Ws,
                      b,
                      out,
                      N_OCOLS=N_OCOLS,
                      N_BROWS=N_BROWS,
                      BLOCK_ROWS=BLOCK_ROWS,
                      N_FEAT_IN=N_FEAT_IN,
                      N_FEAT_OUT=N_FEAT_OUT,
                      num_warps=num_warps,
                      num_stages=num_stages)

    return out

In [14]:
tlws_out = triton_lola_ws(x,
                     weights,
                     weights_sum,
                     bias,
                     N_OCOLS=32,
                     N_BROWS=32,
                     BLOCK_ROWS=32,
                     num_warps=4,
                     num_stages=1)
tlws_out.shape

torch.Size([512, 4096])

In [15]:
(ref_out - tlws_out).abs().max()

tensor(0.3750, device='cuda:0', dtype=torch.float16)

# Benchmark

- Default kernel (same as lola.ipynb): 941 us
- Comment out all updates to accumulators except w_dot_x_acc: 108 us
- Comment out all updates to accumulators except 2 of them:  889 us
- w_dot_x_acc and x_sum_acc updated only: 118 us
- w_dot_x_acc, x_sum_acc, x_sq_sum_acc updated only:  190 us

In [17]:
from triton.testing import do_bench
from functools import partial

In [None]:
do_bench(partial(triton_lola,
                 x,
                 weights,
                 bias,
                 N_OCOLS=32,
                 N_BROWS=32,
                 BLOCK_ROWS=32,
                 num_warps=4,
                 num_stages=1),
         warmup=100,
         rep=1000)[0] * 1000 # us


189.43999707698822

In [18]:
do_bench(partial(triton_lola_ws,
                 x,
                 weights,
                 weights_sum,
                 bias,
                 N_OCOLS=32,
                 N_BROWS=32,
                 BLOCK_ROWS=32,
                 num_warps=4,
                 num_stages=1),
         warmup=100,
         rep=1000)[0] * 1000 # us


190.46400487422943

In [21]:
torch.arange(12)[:, None] * 2 + torch.arange(3)[None, :]

tensor([[ 0,  1,  2],
        [ 2,  3,  4],
        [ 4,  5,  6],
        [ 6,  7,  8],
        [ 8,  9, 10],
        [10, 11, 12],
        [12, 13, 14],
        [14, 15, 16],
        [16, 17, 18],
        [18, 19, 20],
        [20, 21, 22],
        [22, 23, 24]])

In [22]:
torch.arange(12)[None, :] * 2 + torch.arange(3)[:, None]

tensor([[ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22],
        [ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23],
        [ 2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24]])