diff --git a/docs/triton.md b/docs/triton.md
new file mode 100644
index 000000000000..d744abb6ce32
--- /dev/null
+++ b/docs/triton.md
@@ -0,0 +1,66 @@
+# Custom GPU Kernels via Triton
+
+PyTorch/XLA now supports [Triton](https://openai.com/research/triton) kernels, enabling high-performance deep learning model execution on GPUs. Triton, a specialized language and compiler for GPU programming, empowers developers to write custom kernels that leverage the full potential of GPUs for various operations in deep learning models.
+
+Given a Triton kernel defined as follows:
+```python3
+@triton.jit
+def add_kernel(
+    x_ptr,  # *Pointer* to first input vector.
+    y_ptr,  # *Pointer* to second input vector.
+    output_ptr,  # *Pointer* to output vector.
+    n_elements,  # Size of the vector.
+    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
+    # NOTE: `constexpr` so it can be used as a shape value.
+):
+    # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    output = x + y
+    tl.store(output_ptr + offsets, output, mask=mask)
+```
+
+We can make this kernel part of the PyTorch/XLA execution graph as follows:
+
+```python3
+import torch
+
+import torch_xla.experimental.triton as xla_triton
+import torch_xla
+
+import triton
+import triton.language as tl
+
+size = 16
+x = torch.arange(size, dtype=torch.int64).to("xla")
+y = torch.arange(size, dtype=torch.int64).to("xla")
+output = torch.empty_like(x)
+block_size = 8
+grid = (triton.cdiv(size, block_size),)
+
+# triton_call takes the same arguments as the triton.jit function, in addition
+# to the kernel itself and the grid that is used to execute the kernel.
+# All the tl.constexpr terms are passed as kwargs at the end.
+payload = xla_triton.triton_call(
+    x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)
+
+# To make the Triton kernel a part of the PyTorch/XLA graph, we create a
+# custom call node with the expected inputs, the payload from triton_call,
+# the output shapes and the output dtypes. The payload already contains
+# information about how the GPU buffers will be loaded when this node is
+# executed.
+output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
+                                              [output.shape], [torch.int64])
+```
+
+For more complex kernels, you can also refer to the Triton Flash Attention kernel test in PyTorch/XLA.
+
+## Dependencies
+The Triton integration depends on the `triton` package to function. This code is tested with `triton==2.3.0`. To install:
+```bash
+pip install --no-deps triton==2.3.0
+```
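+
+To confirm which Triton build is active in your environment, a quick check can be run first. This is a minimal sketch; it assumes only that the `triton` package exposes the standard `__version__` attribute:
+
+```python3
+import triton
+
+# Should print 2.3.0 for the version this doc was tested against.
+print(triton.__version__)
+```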
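+
+## Verifying the example
+
+As a sanity check for the vector-addition walkthrough above, the custom-call result can be compared against the same addition performed with regular PyTorch ops. This is a minimal sketch, not part of the original example: it assumes the custom call returns its outputs as a list (one tensor per entry in the output-shape list, which is how the PyTorch/XLA Triton tests consume it) and uses `xm.mark_step()` to force execution of the pending graph.
+
+```python3
+import torch_xla.core.xla_model as xm
+
+# Reference result computed with ordinary (non-Triton) ops on the same inputs.
+expected = x + y
+
+# Force execution of the traced XLA graph, including the custom call node.
+xm.mark_step()
+
+# Compare on CPU. Indexing with [0] assumes the custom call returns a list of
+# output tensors (an assumption; see the note above).
+assert torch.equal(output[0].cpu(), expected.cpu())
+```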