# Triton Compilation Stages — Hands-On Notebook

[Triton](https://openai.com/index/triton/) offers a high-level, Python-based way to write efficient GPU code.  
In this notebook, we walk through how a Triton program is compiled, focusing on the intermediate representations (IR) you can inspect along the way.


## Triton Language

As a running example, we follow (with minor tweaks) the Triton vector-add tutorial.  
The kernel and a small helper are defined below.

```python
import torch
import triton
import triton.language as tl

def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


@triton.jit
def add_kernel(
    x_ptr, y_ptr, z_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    blockidx = tl.program_id(axis=0)
    offsets = blockidx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)

    z = x + y

    tl.store(z_ptr + offsets, z, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    assert x.shape == y.shape, "Input tensors must have the same shape"
    assert x.is_cuda and y.is_cuda, "Input tensors must be on CUDA device"
    assert x.dtype == y.dtype, "Input tensors must have the same dtype"

    z = torch.empty_like(x)
    N = x.numel()
    BLOCK_SIZE = 1024
    grid = (N + BLOCK_SIZE - 1) // BLOCK_SIZE
    grid = (grid,)

    triton_kernel = add_kernel[grid](
        x_ptr=x,
        y_ptr=y,
        z_ptr=z,
        N=N,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    
    # Save the compilation stages. TTIR, TTGIR, and LLVM IR exist on every backend;
    # the assembly and binary stages below are backend-specific.
    with open('triton_IR.txt', 'w') as f:
        print(triton_kernel.asm['ttir'], file=f)
    with open('triton_TTGIR.txt', 'w') as f:
        print(triton_kernel.asm['ttgir'], file=f)
    with open('triton_LLVMIR.txt', 'w') as f:
        print(triton_kernel.asm['llir'], file=f)
    if is_cuda():
        with open('triton_PTX.ptx', 'w') as f:
            print(triton_kernel.asm['ptx'], file=f)
        # The cubin is a binary blob (bytes), so write it in binary mode.
        with open('triton_cubin.bin', 'wb') as f:
            f.write(triton_kernel.asm['cubin'])
    else:
        # Assumes an AMD (HIP) backend when the target is not CUDA.
        with open('triton_AMDGCN.s', 'w') as f:
            print(triton_kernel.asm['amdgcn'], file=f)
        # The hsaco is also binary (bytes).
        with open('triton_hsaco.bin', 'wb') as f:
            f.write(triton_kernel.asm['hsaco'])
    
    return z
```

Key points:

1. The vector-add kernel is decorated with `@triton.jit`. Functions marked with `@triton.jit` are compiled by Triton and lowered through multiple stages.
2. The helper function `add` allocates the output tensor, computes an appropriate GPU grid, and captures intermediate IR artifacts during compilation.
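
The grid computation in `add` is a ceiling division: it launches enough program instances to cover all `N` elements, and the mask disables the out-of-range lanes of the last instance. The following is a minimal pure-Python sketch of that arithmetic (no GPU needed). The names `cdiv` and `covered_elements` are illustrative and are not part of the notebook's code.

```python
def cdiv(n: int, block_size: int) -> int:
    """Ceiling division, the same arithmetic as (N + BLOCK_SIZE - 1) // BLOCK_SIZE."""
    return (n + block_size - 1) // block_size


def covered_elements(n: int, block_size: int) -> list[int]:
    """Count how many in-bounds elements each program instance handles after masking."""
    counts = []
    for pid in range(cdiv(n, block_size)):
        offsets = range(pid * block_size, (pid + 1) * block_size)
        counts.append(sum(1 for off in offsets if off < n))  # mask = offsets < N
    return counts


print(cdiv(1024, 1024))              # 1
print(cdiv(1025, 1024))              # 2
print(covered_elements(2500, 1024))  # [1024, 1024, 452]
```

Only the last program instance has masked-off lanes, and the per-instance counts always sum to `N`.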

## Triton Compiler

We’ll focus on how a Triton kernel is lowered to device-specific assembly through a sequence of stages (see the figure below).

![Triton Compilation Stages](./imgs/triton_compilation_stages.png)

**High level flow**

1. A Python Triton kernel (your `@triton.jit` function) enters the compiler.

**Top box — “MLIR” (Triton front/middle end)**

2. **Triton-IR (TTIR)** — Produced by walking the kernel’s AST. It’s unoptimized, machine-independent, tile-oriented, and implemented as an MLIR dialect.  
3. **Triton-GPU-IR (TTGIR)** — Still MLIR, but GPU-aware (NVIDIA or AMD). Triton applies GPU-specific optimizations here.  
4. **LLVM IR** — TTGIR is lowered to standard LLVM IR.

**Bottom box — “LLVM” (back end)**

5. **Device Assembly** — LLVM’s targets emit ISA text: PTX (NVIDIA) or AMDGCN (AMD).  
6. **Loadable Code Objects** — PTX → assembled by `ptxas` into a **cubin**; AMDGCN → linked into an **hsaco** (HSA code object).
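
The stage sequence above can be summarized as a small lookup table. This is only a sketch: the stage keys are the `asm` dictionary entries used earlier in this notebook, the backend name `"hip"` is assumed from the `hip:gfx942` target string, and `stages_for` is a hypothetical helper, not a Triton API.

```python
# Stages shared by both backends, in lowering order.
COMMON_STAGES = ["ttir", "ttgir", "llir"]

# Backend-specific tail of the pipeline: device assembly, then loadable binary.
BACKEND_STAGES = {
    "cuda": ["ptx", "cubin"],    # PTX text, then the ptxas-assembled cubin
    "hip": ["amdgcn", "hsaco"],  # AMDGCN text, then the HSA code object
}


def stages_for(backend: str) -> list[str]:
    """Return the ordered list of inspectable stages for a backend."""
    if backend not in BACKEND_STAGES:
        raise ValueError(f"unknown backend: {backend!r}")
    return COMMON_STAGES + BACKEND_STAGES[backend]


print(" -> ".join(stages_for("cuda")))  # ttir -> ttgir -> llir -> ptx -> cubin
print(" -> ".join(stages_for("hip")))   # ttir -> ttgir -> llir -> amdgcn -> hsaco
```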


## Vector Addition - TTIR

Let's walk through the TTIR generated for vector addition. Here is the Triton IR:
```
module {
  tt.func public @add_kernel(
    %arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} ..., 
    %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} ..., 
    %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} ...
    ) 
  
  attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32> loc(#loc1)
    %cst_0 = arith.constant dense<1024> : tensor<1024xi32> loc(#loc1)
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)
    %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5)
    %5 = arith.cmpi slt, %4, %cst_0 : tensor<1024xi32> loc(#loc6)
    %6 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc7)
    %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc7)
    %8 = tt.load %7, %5, %cst : tensor<1024x!tt.ptr<f32>> loc(#loc8)
    %9 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc9)
    %10 = tt.addptr %9, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc9)
    %11 = tt.load %10, %5, %cst : tensor<1024x!tt.ptr<f32>> loc(#loc10)
    %12 = arith.addf %8, %11 : tensor<1024xf32> loc(#loc11)
    %13 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc12)
    %14 = tt.addptr %13, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc12)
    tt.store %14, %12, %5 : tensor<1024x!tt.ptr<f32>> loc(#loc13)
    tt.return loc(#loc14)
  } loc(#loc)
} loc(#loc)
```
Mapping from Triton Kernel to TTIR:
1. Constants: the `other=0.0` fill value, `N` (1024 here, broadcast to a tensor), and `BLOCK_SIZE` (1024):
```
    %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32> loc(#loc1)
    %cst_0 = arith.constant dense<1024> : tensor<1024xi32> loc(#loc1)
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
```
2. Which Block am I?
```
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3) # base = pid_x * 1024
```
3. Per-lane indices inside the CTA:
```
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)     # [0,1,...,1023] (layouts are only assigned later, in TTGIR)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)                                    # broadcast base to all lanes
    %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5)                                     # global indices for this CTA’s 1024 elems
    %5 = arith.cmpi slt, %4, %cst_0 : tensor<1024xi32> loc(#loc6)                            # mask = (idx < N)             
```
4. Compute x, y addresses then load with mask and elementwise add
```
    %6 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc7)              # broadcast x base ptr
    %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc7)          # x_ptr + idx
    %8 = tt.load %7, %5, %cst : tensor<1024x!tt.ptr<f32>> loc(#loc8)                        # masked load x (else 0.0)                    


    %9 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc9)
    %10 = tt.addptr %9, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc9)
    %11 = tt.load %10, %5, %cst : tensor<1024x!tt.ptr<f32>> loc(#loc10)                     # masked load y

    %12 = arith.addf %8, %11 : tensor<1024xf32> loc(#loc11)                # x + y
```
5.  Compute out addresses and store with mask
```
    %13 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc12)
    %14 = tt.addptr %13, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc12)
    tt.store %14, %12, %5 : tensor<1024x!tt.ptr<f32>> loc(#loc13)                          # masked store
```
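
To make the mapping concrete, the TTIR ops can be emulated one by one in plain Python. This is a sketch for intuition only: `add_block` is a hypothetical function, Python lists stand in for tensors and device memory, and nothing here runs on a GPU.

```python
BLOCK_SIZE = 1024
N = 1024


def add_block(x, y, out, pid, n=N, block_size=BLOCK_SIZE):
    """Emulate one program instance of add_kernel, one TTIR op at a time."""
    base = pid * block_size                 # arith.muli
    rng = list(range(block_size))           # tt.make_range
    idx = [base + r for r in rng]           # tt.splat + arith.addi
    mask = [i < n for i in idx]             # arith.cmpi slt
    # tt.load with a mask: masked-off lanes take the `other` value (0.0).
    xs = [x[i] if m else 0.0 for i, m in zip(idx, mask)]
    ys = [y[i] if m else 0.0 for i, m in zip(idx, mask)]
    zs = [a + b for a, b in zip(xs, ys)]    # arith.addf
    for i, m, z in zip(idx, mask, zs):      # tt.store with a mask
        if m:
            out[i] = z


x = [float(i) for i in range(N)]
y = [2.0 * i for i in range(N)]
out = [None] * N
add_block(x, y, out, pid=0)
print(out[:4])  # [0.0, 3.0, 6.0, 9.0]
```

Passing a smaller `n` than `block_size` exercises the mask: the out-of-range lanes are never read or written.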

## Vector Addition — TTGIR

**Environment**

1. triton 3.4.0  
2. ROCm 6.2.0  
3. AMD MI300X  
4. torch 2.4.1+rocm6.0

With this setup, the vector-add kernel produces the following Triton GPU IR (abridged, with lines omitted for clarity):

```
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
module attributes {
  "triton_gpu.num-ctas" = 1 : i32, 
  "triton_gpu.num-warps" = 4 : i32, 
  triton_gpu.target = "hip:gfx942", 
  "triton_gpu.threads-per-warp" = 64 : i32
} {
  tt.func public @add_kernel(
    %arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
    %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
    %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}
  )
    attributes {noinline = false} {
    ...
    %8 = tt.load %7, %5, %cst : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc8)
    ...
    %11 = tt.load %10, %5, %cst : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc10)
    %12 = arith.addf %8, %11 : tensor<1024xf32, #blocked> loc(#loc11)
    ...
    tt.store %14, %12, %5 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc13)
    ...
  } loc(#loc)
} loc(#loc)
```

**Header & layout breakdown**


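The module header and the `#blocked` layout together describe how the 1024-element tile is distributed. `warpsPerCTA = [4]` matches `num-warps = 4`, and with `threadsPerWarp = [64]` each block (CTA) has 4 * 64 = 256 threads. Each thread handles `sizePerThread = [4]` contiguous elements, so one block covers 256 * 4 = 1024 elements, which equals `BLOCK_SIZE`. `blocked` is only one of Triton's layouts; the others are listed below and are covered in more detail in the Optimizing Triton Kernel notebook.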
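
The layout arithmetic can be checked with a few lines of Python. This is a sketch under one assumption: the tensor has exactly `sizePerThread * threadsPerWarp * warpsPerCTA` elements, as it does here, so each thread owns a single contiguous chunk. `Blocked1D` is a hypothetical stand-in and not a Triton class.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Blocked1D:
    """Hypothetical stand-in for a 1-D #blocked layout (not a Triton class)."""
    size_per_thread: int
    threads_per_warp: int
    warps_per_cta: int

    @property
    def threads_per_cta(self) -> int:
        return self.threads_per_warp * self.warps_per_cta

    @property
    def elements_per_cta(self) -> int:
        return self.size_per_thread * self.threads_per_cta

    def owned_elements(self, thread_id: int) -> list[int]:
        """Elements owned by a thread when the tensor has exactly elements_per_cta entries."""
        start = thread_id * self.size_per_thread
        return list(range(start, start + self.size_per_thread))


layout = Blocked1D(size_per_thread=4, threads_per_warp=64, warps_per_cta=4)
print(layout.threads_per_cta)      # 256
print(layout.elements_per_cta)     # 1024
print(layout.owned_elements(0))    # [0, 1, 2, 3]
print(layout.owned_elements(255))  # [1020, 1021, 1022, 1023]
```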


At this stage, hardware-specific information appears in the IR. The module attributes record the target (here `hip:gfx942`; on NVIDIA this would be the compute capability), and the layout attached to each tensor type describes how its elements are distributed across threads and warps.  
In this example, the tensors use a `blocked` layout, in which each warp owns a contiguous portion of the tensor. Other layouts currently available include:
1. slice - restructures and distributes a tensor along a dimension
2. dot_op - optimized layout for block matrix product
3. shared - indicates GPU shared memory
4. nvidia_mma - produced by NVIDIA tensor cores
5. amd_mfma - produced by AMD MFMA matrix core
6. amd_wmma - produced by AMD WMMA matrix core

Note: As announced at the recent Triton conference, this layout representation will transition to a new linear layout to unify layouts within and across backends. 

## Vector Addition — LLVM IR

TTGIR is lowered to LLVM IR, LLVM’s standard representation.  
Today, Triton supports NVIDIA and AMD through third-party backends; support for other devices is being actively developed by the OSS community.

Below is an LLVM IR snippet for the vector-add kernel (abridged):
```
define amdgpu_kernel void @add_kernel(
    ptr addrspace(1) nocapture readonly %0, 
    ptr addrspace(1) nocapture readonly %1, 
    ptr addrspace(1) nocapture writeonly %2
  )
 
local_unnamed_addr #0 !dbg !4 {
  %4 = tail call i32 @llvm.amdgcn.workgroup.id.x(), !dbg !7
  %5 = shl i32 %4, 10, !dbg !8
  %6 = tail call i32 @llvm.amdgcn.workitem.id.x(), !dbg !9
  %7 = shl i32 %6, 2, !dbg !9
  %8 = and i32 %7, 1020, !dbg !9
  %9 = or disjoint i32 %8, %5, !dbg !10
  %10 = icmp slt i32 %9, 1024, !dbg !11
  br i1 %10, label %.critedge, label %.critedge2, !dbg !12

.critedge:                                        ; preds = %3
  %11 = or disjoint i32 %9, 3, !dbg !10
  %12 = or disjoint i32 %9, 2, !dbg !10
  %13 = or disjoint i32 %9, 1, !dbg !10
  %14 = sext i32 %9 to i64, !dbg !13
  %15 = getelementptr float, ptr addrspace(1) %0, i64 %14, !dbg !13
  %16 = addrspacecast ptr addrspace(1) %15 to ptr, !dbg !12
  %17 = load float, ptr %16, align 16, !dbg !12
  %18 = getelementptr inbounds i8, ptr %16, i64 4, !dbg !12
  %19 = load float, ptr %18, align 4, !dbg !12
  %20 = getelementptr inbounds i8, ptr %16, i64 8, !dbg !12
  %21 = load float, ptr %20, align 8, !dbg !12
  %22 = getelementptr inbounds i8, ptr %16, i64 12, !dbg !12
  %23 = load float, ptr %22, align 4, !dbg !12
  %24 = getelementptr float, ptr addrspace(1) %1, i64 %14, !dbg !14
  %25 = addrspacecast ptr addrspace(1) %24 to ptr, !dbg !15
  %26 = sext i32 %11 to i64, !dbg !16
  %27 = getelementptr float, ptr addrspace(1) %2, i64 %26, !dbg !16
  %28 = sext i32 %12 to i64, !dbg !16
  %29 = getelementptr float, ptr addrspace(1) %2, i64 %28, !dbg !16
  %30 = sext i32 %13 to i64, !dbg !16
  %31 = getelementptr float, ptr addrspace(1) %2, i64 %30, !dbg !16
  %32 = getelementptr inbounds i8, ptr %25, i64 12, !dbg !15
  %33 = load float, ptr %32, align 4, !dbg !15
  %34 = fadd float %23, %33, !dbg !17
  %35 = getelementptr inbounds i8, ptr %25, i64 8, !dbg !15
  %36 = load float, ptr %35, align 8, !dbg !15
  %37 = fadd float %21, %36, !dbg !17
  %38 = getelementptr inbounds i8, ptr %25, i64 4, !dbg !15
  %39 = load float, ptr %38, align 4, !dbg !15
  %40 = fadd float %19, %39, !dbg !17
  %41 = load float, ptr %25, align 16, !dbg !15
  %42 = fadd float %17, %41, !dbg !17
  %43 = sext i32 %9 to i64, !dbg !16
  %44 = getelementptr float, ptr addrspace(1) %2, i64 %43, !dbg !16
  store float %42, ptr addrspace(1) %44, align 4, !dbg !18
  store float %40, ptr addrspace(1) %31, align 4, !dbg !18
  store float %37, ptr addrspace(1) %29, align 4, !dbg !18
  store float %34, ptr addrspace(1) %27, align 4, !dbg !18
  br label %.critedge2, !dbg !18

.critedge2:                                       ; preds = %3, %.critedge
  ret void, !dbg !19
}

attributes #0 = { mustprogress nofree norecurse nosync nounwind willreturn 
                  memory(argmem: readwrite) 
                  "amdgpu-flat-work-group-size"="1,256" 
                  "amdgpu-waves-per-eu"="1" 
                  "denormal-fp-math-f32"="ieee"
                }
```

### How to read this IR

#### Header
* Kernel entry:
```
define amdgpu_kernel void @add_kernel(
    ptr addrspace(1) nocapture readonly %0,   ; x
    ptr addrspace(1) nocapture readonly %1,   ; y
    ptr addrspace(1) nocapture writeonly %2   ; out
) 
```
`addrspace(1)` is global memory.
* **Attributes include**
  * `"amdgpu-flat-work-group-size"="1,256"` -> the work-group has between 1 and 256 threads (here 256: 4 warps of 64 threads)
  * `"amdgpu-waves-per-eu"="1"` -> at least 1 wavefront (warp) per execution unit (EU, roughly a SIMD unit); this serves as an occupancy hint to the backend

#### Thread/Block indexing
```
  %4 = tail call i32 @llvm.amdgcn.workgroup.id.x(), !dbg !7   ; blockIdx.x
  %5 = shl i32 %4, 10, !dbg !8                                ; base = blockIdx.x * 1024
  %6 = tail call i32 @llvm.amdgcn.workitem.id.x(), !dbg !9    ; threadIdx.x in [0, 255]
  %7 = shl i32 %6, 2, !dbg !9                                 ; threadIdx.x * 4 (4 elements/thread)
  %8 = and i32 %7, 1020, !dbg !9                              ; bitmask, not a clamp: keeps bits 2-9, i.e. a multiple of 4 in [0, 1020]
  %9 = or disjoint i32 %8, %5, !dbg !10                       ; first global index for this thread: base + {0, 4, ..., 1020}; the bits do not overlap, so `or` acts as `add`
  %10 = icmp slt i32 %9, 1024, !dbg !11                       ; mask: idx < N (N = 1024)
  br i1 %10, label %.critedge, label %.critedge2, !dbg !12    ; loads/stores execute only when the mask is true
```
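
The shift, mask, and `or disjoint` sequence is equivalent to a multiply-and-add, which can be verified exhaustively in Python. `global_start` is a hypothetical helper that mirrors the instructions above on plain integers.

```python
def global_start(block_id: int, thread_id: int) -> int:
    """Mirror the shl / and / or sequence from the LLVM IR above."""
    base = block_id << 10           # shl i32 %4, 10 -> block_id * 1024
    lane = (thread_id << 2) & 1020  # shl by 2, then and with 1020
    return lane | base              # or disjoint: no shared bits, so it equals lane + base


# Check against plain multiply-and-add for every thread in a few blocks.
for block_id in range(3):
    for thread_id in range(256):
        assert global_start(block_id, thread_id) == block_id * 1024 + thread_id * 4

print(global_start(2, 255))  # 3068
```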

#### Addressing & loads

```
%14 = sext i32 %9 to i64
%15 = getelementptr float, ptr addrspace(1) %0, i64 %14   ; &x[i]
%16 = addrspacecast ptr addrspace(1) %15 to ptr           ; cast to flat for byte GEPs

; Load 4 consecutive floats from x:
%17 = load float, ptr %16,        align 16    ; x[i]
%18 = getelementptr i8, ptr %16, i64 4
%19 = load float, ptr %18,        align 4     ; x[i+1]
%20 = getelementptr i8, ptr %16, i64 8
%21 = load float, ptr %20,        align 8     ; x[i+2]
%22 = getelementptr i8, ptr %16, i64 12
%23 = load float, ptr %22,        align 4     ; x[i+3]

; Load 4 consecutive floats from y (%38, %35, %32 are byte GEPs at %25 + 4, + 8, + 12, as in the full listing):
%24 = getelementptr float, ptr addrspace(1) %1, i64 %14   ; &y[i] in AS1
%25 = addrspacecast ptr addrspace(1) %24 to ptr
%41 = load float, ptr %25,        align 16    ; y[i]
%39 = load float, ptr %38,        align 4     ; y[i+1]
%36 = load float, ptr %35,        align 8     ; y[i+2]
%33 = load float, ptr %32,        align 4     ; y[i+3]
```
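
The byte-offset addressing can be reproduced with Python's `struct` module. This is a sketch, assuming a little-endian `float32` buffer as a stand-in for global memory; `load_four` and `known_alignment` are hypothetical helpers. The second helper gives one plausible reading of the `align` values in the IR: the base pointer is 16-byte aligned (`tt.divisibility = 16`) and each thread starts at a multiple of 4 elements, so the four loads sit at byte offsets 0, 4, 8, and 12 from a 16-byte boundary.

```python
import struct

# Hypothetical buffer standing in for global memory that holds x as float32.
x_values = [float(i) for i in range(16)]
memory = struct.pack(f"<{len(x_values)}f", *x_values)


def load_four(buffer: bytes, index: int) -> list[float]:
    """Load 4 consecutive float32 values the way the IR does."""
    base = index * 4  # getelementptr float: element index -> byte address
    # getelementptr i8 with offsets 0, 4, 8, 12, each followed by a float load.
    return [struct.unpack_from("<f", buffer, base + off)[0] for off in (0, 4, 8, 12)]


def known_alignment(byte_offset: int, base_align: int = 16) -> int:
    """Largest power of two (up to base_align) guaranteed to divide base + byte_offset."""
    if byte_offset % base_align == 0:
        return base_align
    return byte_offset & -byte_offset  # lowest set bit of the offset


print(load_four(memory, 4))                            # [4.0, 5.0, 6.0, 7.0]
print([known_alignment(off) for off in (0, 4, 8, 12)])  # [16, 4, 8, 4]
```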

#### Adds and stores

```
%42 = fadd float %17, %41   ; x[i]   + y[i]
%40 = fadd float %19, %39   ; x[i+1] + y[i+1]
%37 = fadd float %21, %36   ; x[i+2] + y[i+2]
%34 = fadd float %23, %33   ; x[i+3] + y[i+3]

; Compute output addresses (again using lane indices {i, i+1, i+2, i+3}) and store:
%43 = sext i32 %9  to i64
%44 = getelementptr float, ptr addrspace(1) %2, i64 %43   ; &out[i]
store float %42, ptr addrspace(1) %44, align 4

; %31 -> &out[i+1], %29 -> &out[i+2], %27 -> &out[i+3]
store float %40, ptr addrspace(1) %31, align 4
store float %37, ptr addrspace(1) %29, align 4
store float %34, ptr addrspace(1) %27, align 4
```

After LLVM IR, Triton (via the backend toolchain) lowers to device assembly and then to a loadable binary:

* **NVIDIA**: LLVM → PTX → `ptxas` → **cubin**  
* **AMD**: LLVM → AMDGCN → linker → **hsaco**

At this point the kernel is ready to run. For most kernel-level optimization, understanding LLVM IR is usually sufficient.
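
To inspect every stage, including the binary ones, a small helper can choose text or binary mode based on each artifact's type. This is a sketch that assumes `triton_kernel.asm` behaves like a dictionary mapping stage names to `str` or `bytes` values; `dump_stages` is a hypothetical helper, and the dictionary below is a placeholder rather than real compiler output.

```python
import os
import tempfile


def dump_stages(asm, out_dir):
    """Write every stage in `asm` to out_dir, choosing text or binary mode by value type."""
    written = []
    for name, artifact in asm.items():
        if isinstance(artifact, (bytes, bytearray)):
            path = os.path.join(out_dir, f"{name}.bin")
            mode = "wb"
        else:
            path = os.path.join(out_dir, f"{name}.txt")
            mode = "w"
        with open(path, mode) as f:
            f.write(artifact)
        written.append(path)
    return written


# Stand-in for `triton_kernel.asm`; the contents are placeholders, not real IR.
fake_asm = {"ttir": "module {}", "llir": "; llvm ir", "hsaco": b"\x7fELF"}
with tempfile.TemporaryDirectory() as tmp:
    names = sorted(os.path.basename(p) for p in dump_stages(fake_asm, tmp))
print(names)  # ['hsaco.bin', 'llir.txt', 'ttir.txt']
```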

# Reference
1. [Triton kernel compilation stages, PyTorch blog](https://pytorch.org/blog/triton-kernel-compilation-stages/)