In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from typing import Optional

In [3]:
import triton
import triton.language as tl

In [4]:
from triton.language.extra.cuda.libdevice import tanh, sqrt, erf

In [5]:
import math

In [6]:
# Print tensors with 16 digits of precision so that the small differences
# between the error metrics displayed at the end of the notebook stay visible.
torch.set_printoptions(precision=16)

In [7]:
@triton.jit
def gelu_approx(x):
    """
    GeLU_ activation - Gaussian error linear unit, with tanh approximation

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``
    element-wise.

    NOTE: a function with this same name is defined again in the next cell and
    shadows this one; the sqrt(2/pi) constant is kept identical in both so the
    result does not depend on which cell was executed last.

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    # 0.7978845608028654 == sqrt(2 / pi)
    return 0.5 * x * (1.0 + tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))

In [8]:
@triton.jit
def gelu_approx(x):
    """
    Tanh-approximated GeLU_ (Gaussian error linear unit), applied element-wise.

    Evaluates ``0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))``.

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    # 0.7978845608028654 == sqrt(2 / pi); the cubic term is factored as x * (1 + c * x^2).
    tanh_arg = 0.7978845608028654 * x * (1.0 + 0.044715 * x * x)
    return 0.5 * x * (1.0 + tanh(tanh_arg))


@triton.jit
def gelu_approx_grad(x):
    """
    Element-wise derivative of the tanh-approximated GeLU with respect to ``x``.

    With ``t = tanh(c * x * (1 + 0.044715 * x^2))`` and ``c = 0.79788456``, returns
    ``0.5 * x * (1 - t^2) * (c + 0.1070322243 * x^2) + 0.5 * (1 + t)``.
    """
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    t = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # d/dx tanh(u) = (1 - tanh(u)^2) * du/dx
    one_minus_t_sq = 1 - t * t
    # du/dx, where 0.1070322243 == 3 * 0.044715 * 0.79788456
    inner_grad = 0.79788456 + 0.1070322243 * x * x
    return 0.5 * x * (one_minus_t_sq * inner_grad) + 0.5 * (1 + t)


@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
            num_stages=5,
            num_warps=2,
        ),
        # good for int8
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
            num_stages=5,
            num_warps=2,
        ),
    ],
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
)
@triton.heuristics(
    {
        # When K is a multiple of the K tile, the inner loop can skip masking.
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_linear_gelu_fwd(
    C,  # Pointers to matrices
    ACT_INPUT,
    A,
    B,
    bias,
    # Matrix dimensions
    M,
    N,
    K,
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn,  # Assume that stride_cn == 1
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    A_ROWMAJOR: tl.constexpr,
    B_COLMAJOR: tl.constexpr,
    SAVE_ACT_INPUT: tl.constexpr,
):
    """
    Fused kernel: ``C = gelu_approx(A @ B + bias)``.

    Each program computes one ``BLOCK_M x BLOCK_N`` tile of the ``M x N`` output,
    accumulating in float32 over K in steps of ``BLOCK_K``.

    - ``A`` is addressed as ``A[m, k] = A + m * stride_am + k * stride_ak``
      (``stride_ak`` is taken to be 1 when ``A_ROWMAJOR``).
    - ``B`` is addressed as ``B[k, n] = B + k * stride_bk + n * stride_bn``
      (``stride_bk`` is taken to be 1 when ``B_COLMAJOR``).
    - ``bias`` is a length-``N`` vector that is always loaded and added.
    - ``C`` (and ``ACT_INPUT`` when ``SAVE_ACT_INPUT``) use row stride ``stride_cm``
      and an implicit column stride of 1.
    - When ``SAVE_ACT_INPUT`` is set, the pre-activation ``A @ B + bias`` is also
      written to ``ACT_INPUT``.

    Both output stores are masked so that tiles overhanging the ``M``/``N`` edges
    never write outside the output or clobber valid entries.
    """
    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # trick to avoid masking on M and N axis: out-of-range rows/cols wrap around
    # to valid ones for the *loads*; the corresponding results are discarded by
    # the masked stores at the end.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    if A_ROWMAJOR:
        A = A + (ram[:, None] * stride_am + rk[None, :])
    else:
        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    if B_COLMAJOR:
        B = B + (rk[:, None] + rbn[None, :] * stride_bn)
    else:
        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # k counts the K elements still to process; it doubles as the bound for the
    # tail mask when K is not a multiple of BLOCK_K.
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        if A_ROWMAJOR:
            A += BLOCK_K
        else:
            A += BLOCK_K * stride_ak
        if B_COLMAJOR:
            B += BLOCK_K
        else:
            B += BLOCK_K * stride_bk

    bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
    acc += bias[None, :]

    # recompute the (unwrapped) output indices and the in-bounds mask
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    out_offsets = rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]

    if SAVE_ACT_INPUT:
        # Store with the unwrapped indices and the mask: writing through the
        # wrapped indices would let padded columns (whose bias was masked to 0)
        # overwrite valid pre-activations owned by another program.
        tl.store(ACT_INPUT + out_offsets, acc, mask=mask)

    acc = gelu_approx(acc)

    # Masked write-back: without the mask, edge tiles write past the end of C.
    tl.store(C + out_offsets, acc, mask=mask)


def triton_linear_gelu(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    save_act_input: bool = False,
) -> "torch.Tensor | tuple[torch.Tensor, torch.Tensor]":
    """
    Fused ``gelu_tanh(x @ weight.T + bias)`` computed by ``kernel_linear_gelu_fwd``.

    Args:
        x: input of shape ``(*batch, in_features)``; leading dims are flattened.
        weight: ``(out_features, in_features)``, same dtype as ``x``.
        bias: ``(out_features,)`` with the same dtype as ``x``, or ``None`` for
            no bias.
        save_act_input: if True, also return the pre-activation
            ``x @ weight.T + bias``.

    Returns:
        The activated output of shape ``(*batch, out_features)``; when
        ``save_act_input`` is True, a tuple ``(output, act_input)`` of two
        tensors with that shape.

    Raises:
        AssertionError: on dtype or shape mismatch between the arguments.
    """
    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = batch_shape.numel()
    x_reshaped = x.reshape(batch_dim, n)

    # The kernel handles either dimension being unit-stride; only copy when
    # neither is.
    if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
        x_reshaped = x_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()

    assert x.dtype == weight.dtype, (
        f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
    )
    if bias is not None:
        assert x.dtype == bias.dtype, (
            f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
        )
    assert x_reshaped.shape[1] == weight.shape[1], (
        f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"
    )

    assert bias is None or bias.shape[0] == weight.shape[0], (
        "Incompatible dimensions in between weight and bias"
    )

    M, K = x_reshaped.shape
    N = weight.shape[0]

    if bias is None:
        # The kernel unconditionally loads and adds a length-N bias vector, so
        # "no bias" must be an explicit zero vector (passing any other tensor
        # as a placeholder would add its contents to the result).
        bias = torch.zeros(N, device=x.device, dtype=x.dtype)
    else:
        bias = bias.contiguous()

    # torch.empty is contiguous, which satisfies the kernel's assumption that
    # the output column stride is 1.
    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    act_input = torch.empty_like(output) if save_act_input else None

    # 1D launch grid: one program per (BLOCK_M x BLOCK_N) output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )  # noqa

    kernel_linear_gelu_fwd[grid](
        output,
        act_input,
        x_reshaped,
        weight,  # data ptrs
        bias,
        M,  # shapes
        N,
        K,
        M // 32,  # key for triton cache (limit number of compilations)
        N // 32,
        K // 32,
        stride_cm=output.stride(0),  # strides
        # stride_cn=output.stride(1),
        stride_am=x_reshaped.stride(0),
        stride_ak=x_reshaped.stride(1),
        # weight is (N, K): the kernel's B[k, n] is weight[n, k]
        stride_bk=weight.stride(1),
        stride_bn=weight.stride(0),
        SAVE_ACT_INPUT=save_act_input,
        A_ROWMAJOR=x_reshaped.stride(1) == 1,
        B_COLMAJOR=weight.stride(1) == 1,
        GROUP_M=8,  # speed optimization: group the programs
    )

    output = output.reshape(*batch_shape, N)
    if not save_act_input:
        return output
    return output, act_input.reshape(*batch_shape, N)

In [9]:
def gelu_tanh(z: torch.Tensor) -> torch.Tensor:
    """
    Element-wise tanh approximation of GELU.

    z: Tensor of any shape
    returns: 0.5 * z * [1 + tanh(√(2/π)*(z + 0.044715*z^3))], same shape as z
    """
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    cubic_poly = z + 0.044715 * z.pow(3)
    gate = 1.0 + torch.tanh(sqrt_2_over_pi * cubic_poly)
    return 0.5 * z * gate

def gelu_linear(X: torch.Tensor, W: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Unfused reference: linear layer followed by tanh-approximated GELU.

    X: (batch, in_features)
    W: (out_features, in_features)
    b: (out_features,)
    returns: gelu_tanh(X @ W.T + b)

    No device or dtype conversion happens here; the computation runs wherever
    (and in whatever precision) the inputs live.
    """
    # F.linear(X, W, b) is equivalent to X @ W.T + b
    return gelu_tanh(F.linear(X, W, b))

In [10]:
# Test configuration: inputs are created on the GPU in bfloat16.
device = 'cuda'
dtype = torch.bfloat16

In [11]:
# Random activations: 1024 rows (batch) by 1024 input features.
x = torch.randn(1024, 1024, device=device, dtype=dtype)

In [12]:
# Random weight matrix, laid out as (out_features, in_features).
W = torch.randn(1024, 1024, device=device, dtype=dtype)

In [13]:
# Random bias vector, one entry per output feature.
b = torch.randn(1024, device=device, dtype=dtype)

In [14]:
# float32 CPU copies of the same inputs, used for the higher-precision reference.
x_cpu, W_cpu, b_cpu = (
    t.to(device='cpu', dtype=torch.float32) for t in (x, W, b)
)

In [15]:
# Baseline: unfused PyTorch linear + tanh-GELU on the bfloat16 GPU inputs.
out_normal = F.gelu(F.linear(x, W, b), approximate="tanh")

In [16]:
# Fused Triton kernel on the same bfloat16 inputs (pre-activation not saved).
out_triton = triton_linear_gelu(x, W, b)

In [17]:
# float32 reference computed on the CPU from the float32 copies of the inputs.
out_cpu = gelu_linear(x_cpu, W_cpu, b_cpu)

In [18]:
# Mean absolute error of the unfused bfloat16 PyTorch result against the float32 CPU reference.
(out_cpu - out_normal.to(device='cpu', dtype=torch.float32)).abs().mean()

tensor(0.0180812384933233)

In [20]:
# Mean absolute error of the fused Triton result against the float32 CPU reference.
(out_cpu - out_triton.to(device='cpu', dtype=torch.float32)).abs().mean()

tensor(0.0180397387593985)

In [22]:
# float32 reference computed on the GPU: upcast the bfloat16 inputs, then run
# the unfused PyTorch linear + tanh-GELU.
out_gpu = F.gelu(
    F.linear(*(t.to(dtype=torch.float32) for t in (x, W, b))),
    approximate="tanh",
)

In [23]:
# Mean absolute error of the unfused bfloat16 PyTorch result against the float32 GPU reference.
(out_gpu - out_normal.to(dtype=torch.float32)).abs().mean()

tensor(0.0180812347680330, device='cuda:0')

In [25]:
# Mean absolute error of the fused Triton result against the float32 GPU reference.
(out_gpu - out_triton.to(dtype=torch.float32)).abs().mean()

tensor(0.0180397368967533, device='cuda:0')