In [1]:
import torch
from torch.utils._triton import has_triton

## Vector ADD Triton Kernel

In [6]:
# Guard the cell: everything in the else-branch needs Triton, so only a notice
# is printed when has_triton() reports that it is unavailable.
if not has_triton():
  print(f'Skipping because triton is not present')
else:
  # Imported inside the guarded branch so a missing Triton install does not
  # raise ImportError when this cell runs.
  import triton
  import triton.language as tl

  @triton.jit
  def add_kernel(
      x_ptr,
      y_ptr,
      output_ptr,
      n_elements,
      BLOCK_SIZE: tl.constexpr
  ):
    """Element-wise addition kernel: output[i] = x[i] + y[i] for i < n_elements.

    Each program instance along grid axis 0 processes one contiguous slice of
    BLOCK_SIZE offsets. Offsets at or beyond n_elements are masked out of both
    the loads and the store.

    Args:
      x_ptr: base pointer of the first operand.
      y_ptr: base pointer of the second operand.
      output_ptr: base pointer the sums are written to.
      n_elements: number of valid elements to process.
      BLOCK_SIZE: compile-time number of offsets handled per program instance.
    """
    # Index of this program instance on grid axis 0 (a block index, not a row).
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)

    # The last block may extend past the end of the data.
    in_bounds = idx < n_elements

    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)

    tl.store(output_ptr + idx, lhs + rhs, mask=in_bounds)


  def add_fn(x: torch.Tensor, y: torch.Tensor):
    """Add two tensors element-wise by launching `add_kernel`.

    Args:
      x: first operand; the output is allocated with `torch.empty_like(x)`.
      y: second operand.
        NOTE(review): assumes x and y are contiguous CUDA tensors with the same
        number of elements -- nothing here validates that; confirm with callers.

    Returns:
      A new tensor shaped like `x` holding the kernel's result.
    """
    z = torch.empty_like(x)

    n_elements = x.numel()

    # One program per BLOCK_SIZE-wide chunk, rounded up so a partial final
    # chunk is still covered. The previous `n_elements + 1 // BLOCK_SIZE`
    # parsed as `n_elements + (1 // BLOCK_SIZE)`, i.e. one program per element
    # for any BLOCK_SIZE > 1; the results were right only because of the mask.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

    add_kernel[grid](x, y, z, n_elements, BLOCK_SIZE = 4)

    return z

  # Small demo inputs; x and y are reused by later cells.
  x = torch.randn(4, device='cuda')
  y = torch.randn(4, device ='cuda')

  out = add_fn(x, y)
  # Validate the Triton result against PyTorch's own addition before printing,
  # so a broken kernel fails loudly instead of showing plausible numbers.
  torch.testing.assert_close(out, x + y)
  print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

Vector addition of
X:	tensor([-0.1338, -0.9214,  0.2518,  0.0580], device='cuda:0')
Y:	tensor([-1.6395,  0.7362,  0.2845, -1.1175], device='cuda:0')
is equal to
tensor([-1.7733, -0.1852,  0.5362, -1.0595], device='cuda:0')


## Triton Autotune + Compile

**NOTE**

`torch.compile` only supports the `configs` and `key` arguments of `triton.autotune`.

As of June 2024

LIMITATIONS:

- Tensor Subclasses: Currently, there is no support for tensor subclasses and other advanced features.

- Triton Features: While `triton.heuristics` can be used either standalone or before `triton.autotune`, it cannot be used after `triton.autotune`. This implies that if triton.heuristics and triton.autotune are to be used together, `triton.heuristics` must be used first.

In [14]:
# Guard the cell: the autotuned kernel and its compiled wrapper in the
# else-branch need Triton, so only a notice is printed when it is unavailable.
if not has_triton():
  print(f'Skipping because triton is not present')
else:
  # Imported inside the guarded branch so a missing Triton install does not
  # raise ImportError when this cell runs.
  import triton
  import triton.language as tl

  @triton.autotune(
      configs=[
          triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
          triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
          triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
      ],
      key = [],
  )
  @triton.jit
  def add_kernel_autotuned(
      x_ptr,
      y_ptr,
      output_ptr,
      n_elements,
      BLOCK_SIZE: tl.constexpr
  ):
    """Autotuned element-wise addition: output[i] = x[i] + y[i] for i < n_elements.

    BLOCK_SIZE is not passed by the caller; it comes from whichever of the
    configs listed in the `triton.autotune` decorator is selected.

    Args:
      x_ptr: base pointer of the first operand.
      y_ptr: base pointer of the second operand.
      output_ptr: base pointer the sums are written to.
      n_elements: number of valid elements to process.
      BLOCK_SIZE: compile-time number of offsets handled per program instance.
    """
    # Index of this program instance on grid axis 0.
    pid = tl.program_id(0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Mask off offsets past the end of the data in the final block.
    valid = idx < n_elements

    a = tl.load(x_ptr + idx, mask=valid)
    b = tl.load(y_ptr + idx, mask=valid)
    total = a + b

    tl.store(output_ptr + idx, total, mask=valid)

  @torch.compile(fullgraph=True)
  def add_fn_autotuned(x: torch.Tensor, y: torch.Tensor):
    """Add two tensors with `add_kernel_autotuned`, traced by torch.compile.

    Args:
      x: first operand; the output is allocated with `torch.empty_like(x)`.
      y: second operand.

    Returns:
      A new tensor shaped like `x` holding the kernel's result.
    """
    result = torch.empty_like(x)
    count = x.numel()

    # The grid depends on the BLOCK_SIZE of the config being run, so it is
    # given as a callable of the launch meta-parameters.
    launch_grid = lambda meta: (triton.cdiv(count, meta['BLOCK_SIZE']), )
    add_kernel_autotuned[launch_grid](x, y, result, count)

    return result

  # Small demo inputs; x and y are reused by the timing cells below.
  x = torch.randn(4, device='cuda')
  y = torch.randn(4, device ='cuda')

  out = add_fn_autotuned(x, y)
  # Validate the compiled + autotuned result against PyTorch's own addition
  # before printing, so a broken kernel fails loudly.
  torch.testing.assert_close(out, x + y)
  print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

Vector addition of
X:	tensor([-0.4056, -2.5999, -0.6385,  0.2939], device='cuda:0')
Y:	tensor([-0.0749, -0.0383,  1.1321, -0.0747], device='cuda:0')
is equal to
tensor([-0.4805, -2.6382,  0.4936,  0.2192], device='cuda:0')


In [21]:
# naive add
def add(x, y):
  """Eager baseline: return `x + y` using the operands' own `+` operator."""
  total = x + y
  return total

# compiled
@torch.compile(fullgraph=True, mode='reduce-overhead')
def add_compiled(x, y):
  """Same `x + y` as the eager baseline, wrapped in torch.compile.

  Compiled with `fullgraph=True` and `mode='reduce-overhead'`.
  """
  total = x + y
  return total

In [12]:
%timeit - 10
add_fn_autotuned(x, y)

11.3 ns ± 0.0772 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)


tensor([-2.9640,  2.1937, -0.8272,  0.3639], device='cuda:0')

In [13]:
%timeit - 10
add_fn(x, y)

13.4 ns ± 3.43 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)


tensor([-2.9640,  2.1937, -0.8272,  0.3639], device='cuda:0')

In [22]:
%timeit - 10
add(x, y)

13.7 ns ± 3.18 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)


tensor([-0.4805, -2.6382,  0.4936,  0.2192], device='cuda:0')

In [25]:
%timeit - 10
add_compiled(x, y)

12.6 ns ± 1.13 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)


tensor([-0.4805, -2.6382,  0.4936,  0.2192], device='cuda:0')

In [26]:
def timed(fn, x, y):
  """Time a single call `fn(x, y)` using a pair of CUDA events.

  Args:
    fn: callable invoked as `fn(x, y)`; its return value is discarded.
    x: first argument forwarded to `fn`.
    y: second argument forwarded to `fn`.

  Returns:
    The value of `start.elapsed_time(end)` for events recorded immediately
    before and after the call (milliseconds, per the torch.cuda.Event docs).
  """
  start = torch.cuda.Event(enable_timing=True)
  end = torch.cuda.Event(enable_timing=True)
  start.record()
  fn(x, y)
  end.record()

  # Kernel launches are asynchronous: block until the `end` event has actually
  # been reached, otherwise the elapsed time is read before the work finishes.
  end.synchronize()

  return start.elapsed_time(end)

In [31]:
# One event-timed call per implementation, printed one per line in the order:
# eager, torch.compile, Triton, Triton autotuned + torch.compile.
print(
    *(timed(impl, x, y) for impl in (add, add_compiled, add_fn, add_fn_autotuned)),
    sep="\n",
)

0.9092479944229126
0.5697600245475769
0.34828799962997437
0.42905598878860474
