<a href="https://colab.research.google.com/github/simondanielsson/triton-kernels/blob/main/triton_testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install triton

Collecting triton
  Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.4/209.4 MB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-3.0.0


In [None]:
import functools
import time

import torch
import triton
import triton.language as tl
def timing(func, sync=None):
    """Wrap ``func`` so every call prints its wall-clock duration and returns its result.

    Args:
        func: the callable to time.
        sync: optional zero-argument callable invoked right before the clock starts
            and right before it stops (e.g. ``torch.cuda.synchronize``). Without it,
            asynchronous work such as a GPU kernel launch returns before finishing,
            so only the launch overhead would be measured.

    Returns:
        A wrapper that keeps ``func``'s ``__name__`` and ``__doc__``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if sync is not None:
            sync()  # drain previously queued work so it is not billed to func
        t0 = time.perf_counter()
        result = func(*args, **kwargs)
        if sync is not None:
            sync()  # wait for func's own asynchronous work to finish
        print(f"Ran in {time.perf_counter() - t0}s")
        return result

    return wrapper

In [None]:
# Seed PyTorch's RNG so the random test inputs generated below are reproducible.
SEED = 0
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr

<torch._C.Generator at 0x7d51b2f65e70>

In [None]:
# Element-wise vector addition as a Triton GPU kernel.
# Each launched program instance handles one BLOCK_SIZE-wide slice of the
# inputs; different instances may run in parallel.
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """
    x_ptr, y_ptr: pointers to the first element of the two input vectors
    output_ptr: pointer to the first element of the output vector
    n_elements: number of valid elements, used to mask out-of-bounds lanes.
        Passed as a regular runtime argument (not tl.constexpr) so that a new
        input length does not trigger a fresh kernel compilation.
    BLOCK_SIZE: number of elements each program processes. Must be constexpr
        because it is used as a shape in tl.arange.
    """
    # The launch grid is 1-D, so the program id along axis 0 identifies this program.
    pid = tl.program_id(0)

    # This program covers [block_start, block_start + BLOCK_SIZE). E.g. with 256
    # elements and BLOCK_SIZE=64, programs 0..3 cover [0:64], [64:128], [128:192], [192:256].
    # `offsets` is a block of integer indices; adding it to a base pointer below
    # produces the block of pointers that is loaded from / stored to.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # The last program can run past the end of the input when n_elements is not a
    # multiple of BLOCK_SIZE; masked-off lanes are neither loaded nor stored.
    mask = offsets < n_elements

    # Fetch the input slices via pointer arithmetic rather than indexing operators.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Write the sums back for the in-bounds indices only.
    tl.store(output_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor, block_size: int = 16):
    """Element-wise ``x + y`` computed on the GPU by ``add_kernel``.

    Args:
        x, y: CUDA tensors of identical shape.
        block_size: elements processed per kernel instance (program); forwarded to
            the kernel as its ``BLOCK_SIZE`` meta-parameter. Must be a power of two
            because the kernel uses it as the length of ``tl.arange``.

    Returns:
        A new tensor holding the sums. The launch is asynchronous: the values may
        still be computing when this returns, so call ``torch.cuda.synchronize()``
        before timing.
    """
    assert x.is_cuda and y.is_cuda, "inputs must be CUDA tensors"
    # Without this check a shorter y would be read out of bounds by the kernel.
    assert x.shape == y.shape, f"shape mismatch: {tuple(x.shape)} vs {tuple(y.shape)}"

    # The kernel walks flat memory with base pointer + offset, so inputs must be
    # densely packed. .contiguous() is a no-op for already-contiguous tensors.
    x = x.contiguous()
    y = y.contiguous()

    # Preallocate the output (same shape/dtype/device as x).
    output = torch.empty_like(x)

    n_elements = output.numel()
    if n_elements == 0:
        return output  # nothing to launch

    # One program per block. Ceiling division makes sure a final, partial block is
    # still covered (its tail is masked inside the kernel):
    #   1024 elements, BLOCK_SIZE=64 -> grid (16,)
    #   1000 elements, BLOCK_SIZE=64 -> grid (16,), last 24 lanes masked off
    #   1024 elements, BLOCK_SIZE=2048 -> grid (1,)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    # Tensors are passed as pointers to their first element; meta-parameters
    # must be passed as keyword arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size)

    return output

def simple_add(x, y):
    """Eager reference addition: returns ``x + y``; the baseline `add` is compared against."""
    total = x + y
    return total

def test_add(vector_size: int = 512):
    """Check `add` (Triton) against eager PyTorch addition and time both.

    Each timed call waits for the GPU before returning, so the printed duration
    covers kernel execution rather than only its asynchronous launch. A warm-up
    call runs first so the kernel's first-call compilation is not timed.

    Args:
        vector_size: length of the random input vectors.
    """
    x = torch.rand(vector_size, device='cuda')
    y = torch.rand(vector_size, device='cuda')

    def synchronized(fn):
        # Wrap fn so it only returns after all queued GPU work has finished.
        def run(*args):
            out = fn(*args)
            torch.cuda.synchronize()
            return out
        return run

    # Warm-up: compile the Triton kernel and finish generating the inputs.
    add(x, y)
    torch.cuda.synchronize()

    output_torch = timing(synchronized(simple_add))(x, y)
    output_triton = timing(synchronized(add))(x, y)
    print(f'The maximum difference between torch and triton is '
          f'{torch.max(torch.abs(output_torch - output_triton))}')
    torch.testing.assert_close(output_triton, output_torch)


test_add()

Ran in 4.1091000184678705e-05s
Ran in 0.08164957000008144s
The maximum difference between torch and triton is 0.0


In [None]:
import torch
import triton
import triton.language as tl

def naive_softmax(x):
    """Row-wise softmax of x in eager PyTorch, used as a reference implementation.

    x has M rows and N columns; the result has the same shape and every row sums
    to 1. The per-step comments count element reads and writes, to show how much
    memory traffic this unfused version generates.
    """
    # Per-row maximum. max(dim=1) returns (values, indices); keep the values.
    # reads MN elements; writes M elements
    row_max = x.max(dim=1).values

    # Subtract each row's max for numerical stability. unsqueeze(1) makes the (M,)
    # vector an (M, 1) view (no data read) so it broadcasts across the columns.
    # reads MN + M elements; writes MN elements
    shifted = x - row_max.unsqueeze(1)

    # reads MN elements; writes MN elements
    exps = torch.exp(shifted)

    # Per-row sum, kept as an (M, 1) column so it broadcasts in the division.
    # reads MN elements; writes M elements
    row_sum = exps.sum(dim=1, keepdim=True)

    # reads MN + M elements; writes MN elements
    # Totals: 5MN + 2M reads, 3MN + 2M writes
    return exps / row_sum

@triton.jit
def softmax_kernel_parallelize_rows(X, Y, stride_xm, stride_xn, stride_ym, stride_yn, n_elements: tl.constexpr):
    """Row-wise softmax: each program instance computes the softmax of one row of X into Y.

    X, Y: pointers to the first element of the input and output matrices.
    stride_xm / stride_ym: element distance between consecutive rows of X / Y;
        program pid starts reading its row at offset pid * stride_xm.
    stride_xn / stride_yn: element distance between consecutive entries of a row
        (1 for a contiguous row-major tensor, but not necessarily in general).
    n_elements: number of entries loaded per row. The load is unmasked, so this
        must equal the row length; and because it is the length passed to
        tl.arange, Triton requires it to be a power of two.
    """
    # Row index of this program: the launcher uses a 1-D grid with one program per row.
    pid = tl.program_id(0)

    # Offsets of this row's entries: start of row pid, then stride_xn per entry.
    # Example (contiguous, 8 columns): stride_xm=8, stride_xn=1, pid=1 => [8, 9, ..., 15]
    offsets = tl.arange(0, n_elements) * stride_xn + pid * stride_xm

    # No mask: relies on n_elements being exactly the row length, so every offset is in bounds
    x = tl.load(X + offsets)

    # Numerically stable softmax over the row: subtract the row max before exponentiating
    row_max = tl.max(x, axis=0) # axis=0 because x is a 1-D block (one row)
    exp_row = tl.exp(x - row_max)
    sum_exp = tl.sum(exp_row, axis=0) # axis=0 because x is a 1-D block (one row)

    # Normalize the output
    result = exp_row / sum_exp

    # Store the result at the same row/column positions of Y, using Y's own strides
    output_offsets = tl.arange(0, n_elements) * stride_yn + pid * stride_ym
    tl.store(Y + output_offsets, result)
def softmax_triton_parallelize_rows(X):
    """Row-wise softmax of a 2-D CUDA tensor, one Triton program per row.

    The kernel loads a whole row with ``tl.arange(0, n_elements)``, which Triton only
    accepts for power-of-two lengths. When the column count is not a power of two,
    the input is padded on the right with -inf up to the next power of two:
    exp(-inf - row_max) == 0, so the padding adds nothing to each row's sum, and the
    padded columns are dropped from the result (at the cost of one extra copy).

    Args:
        X: tensor of shape (n_rows, n_cols) on a CUDA device.

    Returns:
        Tensor created with ``torch.empty_like(X)`` holding softmax(X, dim=1).
    """
    Y = torch.empty_like(X)
    n_rows, n_cols = X.shape
    if n_rows == 0 or n_cols == 0:
        return Y  # nothing to compute

    block = triton.next_power_of_2(n_cols)
    if block == n_cols:
        X_in, Y_out = X, Y
    else:
        X_in = torch.nn.functional.pad(X, (0, block - n_cols), value=float('-inf'))
        Y_out = torch.empty_like(X_in)

    # stride(0) is the element distance between consecutive rows and stride(1)
    # between consecutive columns; for a contiguous row-major tensor they are
    # (n_cols, 1). Passing both lets the kernel handle other memory layouts.
    softmax_kernel_parallelize_rows[(n_rows,)](
        X_in,
        Y_out,
        X_in.stride(0),
        X_in.stride(1),
        Y_out.stride(0),
        Y_out.stride(1),
        # entries per row loaded by each program instance (a power of two)
        block,
    )

    if Y_out is not Y:
        Y.copy_(Y_out[:, :n_cols])  # drop the padded columns
    return Y

# Usage example: compare the Triton kernel against the eager and built-in softmax.
size = 128
x = torch.rand((size, size), device='cuda')

# Compute each result once and reuse it in every comparison.
ref_naive = naive_softmax(x)
ref_torch = torch.softmax(x, 1)
out_triton = softmax_triton_parallelize_rows(x)

print(f"diff naive vs torch: {torch.max(torch.abs(ref_naive - ref_torch))}")
print(f"diff naive vs triton parallelize rows: {torch.max(torch.abs(ref_naive - out_triton))}")
print(f"diff torch vs triton parallelize rows: {torch.max(torch.abs(ref_torch - out_triton))}")
torch.testing.assert_close(out_triton, ref_torch)

diff naive vs torch: 1.862645149230957e-09
diff naive vs triton parallelize rows: 2.7939677238464355e-09
diff torch vs triton parallelize rows: 2.7939677238464355e-09


In [None]:
# TODO: also parallelize across columns if feasible (e.g. for rows too long for a single program), or implement a fused kernel.

In [None]:
from triton.runtime import driver

# Hardware properties of the current CUDA device, queried through Triton's active driver.
device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)

# Pull the individual limits out of the properties dict (keys as reported by Triton).
NUM_SM, NUM_REGS, SIZE_SMEM, WARP_SIZE = (
    properties[key]
    for key in ("multiprocessor_count", "max_num_regs", "max_shared_mem", "warpSize")
)

# Triton's description of the target it currently compiles for.
target = driver.active.get_current_target()