In [3]:
%pip install triton

Collecting triton
  Downloading triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.7 kB)
Downloading triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (188.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.3/188.3 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-3.6.0


In [4]:
import torch

import triton
import triton.language as tl

In [5]:
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


In [6]:
DEVICE = triton.runtime.driver.active.get_active_torch_device()

RuntimeError: 0 active drivers ([]). There should only be one.

In [None]:
@triton.jit
def add_kernel(x_ptr,
               y_ptr,
               output_ptr,
               n_elements,
               BLOCK_SIZE: tl.constexpr):

  pid = tl.program_id(axis=0) # this is like x.blockIdx * x.blockDim + x.threadIdx ?
  block_start= pid * BLOCK_SIZE
  offsets = block_start + tl.arange(0,BLOCK_SIZE)
  # we create an entire guard array
  mask = offsets < n_elements
  x = tl.load(x_ptr + offsets,mask=mask)
  y = tl.load(y_ptr + offsets,mask=mask)
  output = x +y
  tl.store(output_ptr + offsets, output, mask=mask)


def add(x:torch.Tensor,y:torch.Tensor) -> torch.Tensor:

  output = torch.empty_like(x)

  assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE

  n_elements = output.numel()
  #cdiv gives ceiling division e.g int(n_elements / block_size) +1
  grid = lambda meta: (triton.cdiv(n_elements,meta['BLOCK_SIZE']), )
  add_kernel[grid](x,y,output,n_elements,BLOCK_SIZE=1024)
  return output


In [None]:
torch.manual_seed(0)

size = 76783

x = torch.rand(size,device=DEVICE)
y = torch.rand(size,device=DEVICE)

output_torch = x + y
output_triton = add(x,y)

max(torch.abs(output_torch - output_triton))



In [None]:
x.element_size()

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],
        x_vals=[2**i for i in range (12,28,1)],
        x_log = True,
        line_arg ='provider',
        line_vals = ['triton','torch'],
        line_names = ['Triton','Torch'],
        styles= [('blue','-'),('green','-')],
        ylabel='GB/s',
        plot_name='vector-add-performance',
        args={}
    )
)
def benchmark(size,provider):
  x = torch.randn(size,device=DEVICE,dtype=torch.float32)
  y = torch.randn(size,device=DEVICE,dtype=torch.float32)
  quantiles=[0.5,0.2,0.8]
  if provider=='torch':
    ms,min_ms,max_ms = triton.testing.do_bench(lambda: x+y, quantiles=quantiles)
  if provider=='triton':
    ms,min_ms,max_ms = triton.testing.do_bench(lambda: add(x,y), quantiles=quantiles)
  gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms* 1e-3)
  return gbps(ms),gbps(max_ms),gbps(min_ms)

benchmark.run(print_data=True,show_plots=True)