Tests the performance of the various Triton kernels while varying their configuration parameters.
Also examines the impact of padding masking by speed-testing with different amounts of padding.

In [None]:
import os

# Sets the fraction of GPU memory that JAX's XLA client may preallocate to 0.8.
# This cell runs before the cell that imports jax, so the variable is already
# in the environment when JAX starts up.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"

In [1]:
import pickle

import jax
import jax.numpy as jnp
import flax.linen as nn

import jax_triton as jt
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from nimblegpt.triton_shcsa_kernels import (
    SHCSATritonSoftmax,
    SHCSATritonPaddedSoftmax,
    SHCSATritonPaddedSoftmaxV,
    SHCSATriton,
    padded_attention_kernel
)
from nimblegpt import get_config_for
from nimblegpt.model import SingleHeadCausalSelfAttention
from nimblegpt.jmodel import JSingleHeadCausalSelfAttention


In [3]:
# GPT-2 hyper-parameters; every benchmark below exercises a single attention head.
config = get_config_for("gpt2")
# Feature width of one head: the embedding width divided evenly across the heads.
n_feat = config.n_embd // config.n_head
# Fixed seed so that the input and the parameters are reproducible across runs.
rng = jax.random.PRNGKey(0)

In [4]:
# One unbatched, full-length input of shape (block_size, n_embd), shared by every benchmark.
x = jax.random.normal(rng, (config.block_size, config.n_embd))

In [5]:
# Initialise the parameters once and reuse them for every module benchmarked below;
# this presumably relies on all variants sharing one parameter tree -- TODO confirm.
# NOTE(review): `rng` was already consumed to sample `x`; consider jax.random.split
# so that the initialiser receives an independent key.
params = SingleHeadCausalSelfAttention(n_feat).init(rng, x)

In [6]:
# Reference timing: jitted JSingleHeadCausalSelfAttention with its default padding argument.
jit_jshcsa = jax.jit(JSingleHeadCausalSelfAttention(n_feat).apply)
# Warm-up call: the first invocation traces and compiles, so it is kept out of the timed loop.
jit_jshcsa(params, x)

# block_until_ready() makes each loop wait for the device result instead of
# timing only the asynchronous dispatch.
%timeit -n100 jit_jshcsa(params, x).block_until_ready()

304 µs ± 18.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
# Same module with n_padd=1000 (presumably the number of padded positions -- confirm).
# partial() binds n_padd before jax.jit sees the function, so it is a Python
# constant during tracing rather than a traced argument.
# Note: this rebinds `jit_jshcsa`, replacing the unpadded variant from the previous cell.
jit_jshcsa = jax.jit(partial(JSingleHeadCausalSelfAttention(n_feat).apply, n_padd=1000))
# Warm-up call so compilation is excluded from the measurement.
jit_jshcsa(params, x)

%timeit -n100 jit_jshcsa(params, x).block_until_ready()

303 µs ± 19.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
# SHCSATritonSoftmax variant (judging by its name, only the softmax runs in
# Triton -- see nimblegpt.triton_shcsa_kernels to confirm).
jit_triton_softmax = jax.jit(SHCSATritonSoftmax(n_feat).apply)
# Warm-up call so compilation is excluded from the measurement.
jit_triton_softmax(params, x)

%timeit -n100 jit_triton_softmax(params, x).block_until_ready()

308 µs ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
# SHCSATritonPaddedSoftmax variant, called with its default padding argument.
jit_triton_padded_softmax = jax.jit(SHCSATritonPaddedSoftmax(n_feat).apply)
# Warm-up call so compilation is excluded from the measurement.
jit_triton_padded_softmax(params, x)

%timeit -n100 jit_triton_padded_softmax(params, x).block_until_ready()

279 µs ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
# SHCSATritonPaddedSoftmax again, now with n_padd=1000 bound through partial()
# so that it is a compile-time constant. Rebinds `jit_triton_padded_softmax`.
jit_triton_padded_softmax = jax.jit(partial(SHCSATritonPaddedSoftmax(n_feat).apply, n_padd=1000))
# Warm-up call so compilation is excluded from the measurement.
jit_triton_padded_softmax(params, x)

%timeit -n100 jit_triton_padded_softmax(params, x).block_until_ready()

276 µs ± 16.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Performance improvement is presumably due to causal masking being done inside the kernel, rather than
in flax.

In [11]:
# SHCSATritonPaddedSoftmaxV variant (the name suggests the multiplication by V
# is also done in the kernel -- confirm), with its default padding argument.
jit_triton_padded_softmax_v = jax.jit(SHCSATritonPaddedSoftmaxV(n_feat).apply)
# Warm-up call so compilation is excluded from the measurement.
jit_triton_padded_softmax_v(params, x)

%timeit -n100 jit_triton_padded_softmax_v(params, x).block_until_ready()

1.02 ms ± 53.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
# SHCSATritonPaddedSoftmaxV with n_padd=1000 bound through partial(); compare
# with the default-padding timing of the same module. Rebinds the variable.
jit_triton_padded_softmax_v = jax.jit(partial(SHCSATritonPaddedSoftmaxV(n_feat).apply, n_padd=1000))
# Warm-up call so compilation is excluded from the measurement.
jit_triton_padded_softmax_v(params, x)

%timeit -n100 jit_triton_padded_softmax_v(params, x).block_until_ready()

395 µs ± 21.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
# SHCSATriton as shipped in nimblegpt.triton_shcsa_kernels, with default settings.
jit_triton_attn = jax.jit(SHCSATriton(n_feat).apply)
# Warm-up call so compilation is excluded from the measurement.
jit_triton_attn(params, x)

# Only 10 loops per run here (the other variants above use 100).
%timeit -n10 jit_triton_attn(params, x).block_until_ready()

40 ms ± 1.54 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
def padded_attention2(q, k, v, n_padd, n_ocols: int = 4, **tc_kwargs):
    """Launch `padded_attention_kernel` on q, k, v with a configurable column tiling.

    Args:
        q, k, v: arrays with at least two axes. `q.shape[0]` and `q.shape[1]`
            are handed to the kernel as SEQ_LEN and N_FEAT, `k.shape[1]` sets
            the softmax scale, and the output takes `v`'s shape and dtype.
        n_padd: padding amount; wrapped with `jnp.array` and passed as the
            fourth kernel operand (presumably the number of padded positions
            -- confirm against the kernel).
        n_ocols: handed to the kernel as N_OCOLS and used to size the second
            grid axis; it must divide `q.shape[1]` exactly.
        **tc_kwargs: forwarded unchanged to `jt.triton_call`.

    Returns:
        The result of `jt.triton_call`, declared with `v`'s shape and dtype.

    Raises:
        AssertionError: if `q.shape[1]` is not a multiple of `n_ocols`.
    """
    result_spec = jax.ShapeDtypeStruct(shape=v.shape, dtype=v.dtype)

    seq_len = q.shape[0]
    n_feat = q.shape[1]

    # The second grid axis covers the feature columns in groups of n_ocols;
    # a remainder would leave columns without a program instance.
    n_col_groups, leftover = divmod(n_feat, n_ocols)
    assert leftover == 0

    launch_grid = (seq_len, n_col_groups)

    return jt.triton_call(
        q,
        k,
        v,
        jnp.array(n_padd),
        kernel=padded_attention_kernel,
        out_shape=result_spec,
        grid=launch_grid,
        SM_SCALE=1.0 / k.shape[1] ** 0.5,
        SEQ_LEN=seq_len,
        N_FEAT=n_feat,
        N_OCOLS=n_ocols,
        **tc_kwargs
    )


class SHCSATriton2(nn.Module):
    """Single-head attention module whose attention step is `padded_attention2`.

    A local copy of the attention module that lets the notebook pass kernel
    launch options (e.g. `n_ocols`, `num_warps`) through `apply`.
    """

    # Width of each of the q, k and v projections.
    n_feat: int

    @nn.compact
    def __call__(self, x, n_padd: int = 0, **kernel_kwargs):
        """Project `x` to q, k, v and run the Triton attention kernel.

        Args:
            x: 2-D input of shape [T, C] (sequence length, embedding width).
            n_padd: padding amount forwarded to `padded_attention2`.
            **kernel_kwargs: forwarded unchanged to `padded_attention2`.

        Returns:
            The output of `padded_attention2` for the projected q, k, v.
        """
        # The sizes are not used below; the unpack still enforces that x has
        # exactly two axes.
        _seq_len, _n_embd = x.shape

        # One dense layer produces all three projections:
        # [T, C] @ [C, 3 * n_feat] -> [T, 3 * n_feat]
        qkv = nn.Dense(features=3 * self.n_feat)(x)

        # [T, 3 * n_feat] -> 3 * [T, n_feat]
        q, k, v = jnp.split(qkv, 3, axis=1)

        return padded_attention2(q, k, v, n_padd, **kernel_kwargs)

In [15]:
# n_ocols = 8 is optimal (padded_attention2 itself defaults to n_ocols=4).
# n_ocols travels through SHCSATriton2's **kernel_kwargs to padded_attention2.
jit_triton_attn = jax.jit(partial(SHCSATriton2(n_feat).apply, n_ocols=8))
# Warm-up call so compilation is excluded from the measurement.
jit_triton_attn(params, x)

%timeit -n10 jit_triton_attn(params, x).block_until_ready()

21.1 ms ± 615 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


`num_stages` has no effect

In [16]:
# num_warps = 8 is optimal.
# num_warps is not consumed by padded_attention2; it ends up in **tc_kwargs
# and is forwarded to jt.triton_call.
jit_triton_attn = jax.jit(partial(SHCSATriton2(n_feat).apply, n_ocols=8, num_warps=8))
# Warm-up call so compilation is excluded from the measurement.
jit_triton_attn(params, x)

%timeit -n10 jit_triton_attn(params, x).block_until_ready()

4.51 ms ± 75.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
# Best configuration found above (n_ocols=8, num_warps=8), now with n_padd=1000
# to measure how much the padding mask changes the runtime.
jit_triton_attn = jax.jit(partial(SHCSATriton2(n_feat).apply, n_ocols=8, num_warps=8, n_padd=1000))
# Warm-up call so compilation is excluded from the measurement.
jit_triton_attn(params, x)

%timeit -n10 jit_triton_attn(params, x).block_until_ready()

2.07 ms ± 56.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
# Show the traced program, including the parameters that the
# triton_kernel_call primitive receives (grid, N_OCOLS, SM_SCALE, ...).
jax.make_jaxpr(partial(SHCSATriton2(n_feat).apply, n_ocols=8))(params, x)

{ lambda ; a:f32[192] b:f32[768,192] c:f32[1024,768]. let
    d:f32[1024,192] = dot_general[
      dimension_numbers=(((1,), (0,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] c b
    e:f32[1,192] = reshape[dimensions=None new_sizes=(1, 192)] a
    f:f32[1024,192] = add d e
    g:f32[1024,64] = slice[
      limit_indices=(1024, 64)
      start_indices=(0, 0)
      strides=None
    ] f
    h:f32[1024,64] = slice[
      limit_indices=(1024, 128)
      start_indices=(0, 64)
      strides=None
    ] f
    i:f32[1024,64] = slice[
      limit_indices=(1024, 192)
      start_indices=(0, 128)
      strides=None
    ] f
    j:f32[1024,64] = triton_kernel_call[
      N_FEAT=64
      N_OCOLS=8
      SEQ_LEN=1024
      SM_SCALE=0.125
      asm=<jax_triton.triton_call.Asm object at 0x7fbb241ca530>
      call_name=triton_kernel_call
      dump_binary_path=None
      grid=(1024, 8)
      input_output_aliases=()
      kernel_name=padded_attention_kernel_0d1d2d3d4d
      num_

In [19]:
# Relative path given to the Triton call as `dump_binary_path`; the file is
# read back with pickle further down to inspect the generated code.
triton_dump_binary_path = "./triton-binary.pickle"

In [20]:
# Same configuration as the fastest timed run, plus dump_binary_path, which is
# forwarded via **kernel_kwargs / **tc_kwargs to jt.triton_call.
jit_triton_attn = jax.jit(partial(SHCSATriton2(n_feat).apply, n_ocols=8, num_warps=8, n_padd=1000, dump_binary_path=triton_dump_binary_path))
# Called for its side effect: presumably this is what writes the dump file that
# is loaded with pickle later -- confirm in jax_triton.
jit_triton_attn(params, x)

Array([[ 2.3998618 ,  0.69639087,  1.5806489 , ..., -1.0705826 ,
         1.7936468 , -1.013537  ],
       [ 0.32660243,  0.77041066,  0.8374685 , ...,  0.5515491 ,
         0.48014343, -0.3229033 ],
       [ 0.8742785 , -0.75947136, -0.56516486, ..., -0.6283845 ,
        -0.08896749,  0.2526689 ],
       ...,
       [-0.26084435, -0.07213971, -0.10078567, ...,  0.48843464,
        -0.4253391 , -0.3293603 ],
       [-0.15446025, -0.09784772, -0.12534119, ...,  0.5775845 ,
        -0.28566545, -0.29566   ],
       [-0.09842795, -0.06713469,  0.38127655, ...,  0.00422825,
         0.14253983, -0.30762005]], dtype=float32)

In [21]:
# Read back the dump written to `triton_dump_binary_path` and print the entry
# stored under ["asm"]["ttir"].
# The context manager closes the file handle, which a bare open() left dangling.
# NOTE: pickle.load can execute arbitrary code -- only load dumps that this
# notebook produced itself.
with open(triton_dump_binary_path, "rb") as dump_file:
    triton_dump = pickle.load(dump_file)
print(triton_dump["asm"]["ttir"])

module attributes {"triton_gpu.num-warps" = 8 : i32, triton_gpu.shared = 65792 : i32} {
  llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
  llvm.func @padded_attention_kernel_0d1d2d3d4d(%arg0: !llvm.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !llvm.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !llvm.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: !llvm.ptr<i32, 1> {tt.divisibility = 16 : i32}, %arg4: !llvm.ptr<f32, 1> {tt.divisibility = 16 : i32}) attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 256 : i32, sym_visibility = "public"} {
    %0 = nvvm.read.ptx.sreg.tid.x : i32
    %1 = llvm.mlir.constant(32 : i32) : i32
    %2 = llvm.urem %0, %1  : i32
    %3 = llvm.udiv %0, %1  : i32
    %4 = llvm.urem %3, %1  : i32
    %5 = llvm.mlir.constant(1024 : i32) : i32
    %6 = llvm.urem %2, %5  : i32
    %7 = llvm.mlir.constant(1 : i32) : i32
    %8 = llvm.mul %4, %1  : i32
    %9 = llvm.add %6, %8  : i32
    %10 = llvm.mul %9, %7  : i32
    %11 =

# Profiling

Profile `JSingleHeadCausalSelfAttention` against `SHCSATriton` to see why Triton is slower.

In [22]:
%%script false --no-raise-error

# perfetto profiling is currently broken.
# The `%%script false` magic above hands this cell body to the `false` command
# instead of the Python kernel, so none of the code below is executed;
# `--no-raise-error` keeps the command's non-zero exit status from being
# reported as a cell error. Remove the magic to re-enable the cell.

with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
  # Run the operations to be profiled
  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

In [24]:
# Profile a large matrix multiplication; the trace is written to /tmp/jax-trace.
# `jax` is already imported in the notebook's import cell, so it is not
# re-imported here.
# Cell-local names are used on purpose: the benchmark cells read the global
# `x` of shape (block_size, n_embd), and rebinding it to a 5000x5000 matrix
# would silently change what they measure when re-run.
with jax.profiler.trace("/tmp/jax-trace"):
  profile_key = jax.random.PRNGKey(0)
  profile_x = jax.random.normal(profile_key, (5000, 5000))
  profile_y = profile_x @ profile_x
  # Wait for the device so the matmul itself falls inside the trace.
  profile_y.block_until_ready()

2023-01-25 06:05:06.799142: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-01-25 06:05:06.799351: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-01-25 06:05:08.106274: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2163] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory.


XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory.