
# Optimizing Triton Kernels on AMD MI300X (ROCm)

This notebook is a hands‑on companion for optimizing Triton kernels on AMD Instinct™ MI300X GPUs under ROCm.
It follows the themes from AMD’s official guidance on **Triton kernel optimization** and adapts them into practical, reproducible experiments you can run on your MI300X box.

"*Broadly, Triton kernel optimization is similar to HIP and CUDA kernel optimization.*" - AMD ROCm Doc


## 0. Environment check

1. ROCm version: 6.2.0
2. PyTorch version: 2.4.1+rocm6.0
3. Triton version: 3.4.0
4. HIP version: 6.2.41133-dd7f95766
5. Hardware: MI300X
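
To collect the same information on your own box, here is a minimal sketch. It reads package versions through `importlib.metadata`, so it works even when a package cannot be imported, and it reports `torch.version.hip` if PyTorch imports. It assumes the Triton wheel is installed under the distribution name `triton` or `pytorch-triton-rocm`; any other name simply shows up as `None`.

```python
import importlib.metadata
import platform


def collect_versions(packages=("torch", "triton", "pytorch-triton-rocm")):
    """Map each distribution name to its installed version (None if absent).

    The distribution names are assumptions; ROCm builds of Triton are often
    published as `pytorch-triton-rocm` rather than `triton`.
    """
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    try:
        import torch  # only needed to read the HIP version PyTorch was built against
        versions["torch.version.hip"] = torch.version.hip  # None on CUDA/CPU builds
    except ImportError:
        versions["torch.version.hip"] = None
    return versions


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print(f"{key:>20}: {value}")
```

From a shell, `hipconfig --version` reports the HIP version of the installed ROCm stack.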

## 1. Memory Access Efficiency

Modern accelerators/GPU cores expose a memory hierarchy:

- **Global memory** — large capacity, **high latency**.
- **Local Data Share (LDS) / Shared memory** — much **lower latency**, but **limited size**.
- **Registers** — **fastest**, **smallest**.

**Guidelines**
- Minimize round-trips to **global memory** (load/store as few times as possible).
- When multiple threads in a workgroup need the same data, **stage it in LDS** once, then let threads reuse it from there.

![GPU Memory Hierarchy](./imgs/mem_hierarchy.png)

*Image source: [Simon Oz — GPU Programming Ep. 6](https://www.youtube.com/watch?v=Zrbw0zajhJM).*
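
To put a rough number on "load once, reuse from LDS", the sketch below counts global-memory element reads for a GEMM with and without tiling. It is a back-of-the-envelope model under stated assumptions: perfect reuse inside each `BLOCK_M x BLOCK_N` tile, dimensions divisible by the block sizes, and no help from L2 or Infinity Cache (which in practice would absorb part of the naive traffic).

```python
def naive_gemm_global_reads(M, N, K):
    """Element reads if every output element re-reads its full row of A and
    column of B from global memory (no on-chip reuse at all)."""
    return 2 * M * N * K


def tiled_gemm_global_reads(M, N, K, BLOCK_M, BLOCK_N):
    """Element reads when each BLOCK_M x BLOCK_N workgroup loads its A panel
    (BLOCK_M x K) and B panel (K x BLOCK_N) once and reuses them from LDS/registers.
    Assumes M and N are multiples of the block sizes."""
    tiles = (M // BLOCK_M) * (N // BLOCK_N)
    return tiles * (BLOCK_M * K + K * BLOCK_N)


M = N = K = 4096
naive = naive_gemm_global_reads(M, N, K)
tiled = tiled_gemm_global_reads(M, N, K, BLOCK_M=128, BLOCK_N=128)
print(f"reuse factor: {naive / tiled:.0f}x fewer global reads")
```

In this model the reduction factor is `2 / (1/BLOCK_M + 1/BLOCK_N)`, which for square tiles equals the tile edge (128 here).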


## 2. Know Your Hardware, Win Your Kernel

A solid grasp of the hardware leads to better heuristics for maximizing utilization.  
For **AMD MI300X**, here’s a practical rule of thumb:

![MI300X-CU](./imgs/mi300x_CU.png)

According to the [hardware spec](https://rocm.docs.amd.com/en/docs-6.1.1/reference/gpu-arch-specs.html), **MI300X** has **4 SIMD units per CU** and a **wavefront (warp) size of 64**.  
Implications:

- **Block size ≥ 256 threads (4 wavefronts)** → you can keep all 4 SIMDs in a CU busy at once.
- With **304 CUs**, aim for a grid with **≥ 1024 workgroups** (≈ **3–4 workgroups per CU**) to improve latency hiding and keep the device saturated.
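
Both rules of thumb can be checked mechanically before launching. The sketch below hard-codes the MI300X numbers quoted above (64-wide wavefronts, 4 SIMDs per CU, 304 CUs) and treats Triton's `num_warps` as the number of wavefronts per workgroup, which is how it is interpreted on AMD GPUs; the function name is illustrative.

```python
WAVEFRONT_SIZE = 64   # MI300X wavefront (warp) width
SIMDS_PER_CU = 4      # SIMD units per CU
NUM_CUS = 304         # CUs on MI300X


def check_launch(num_warps, grid_size):
    """Apply the two rules of thumb; on AMD, Triton's num_warps counts wavefronts."""
    return {
        "threads_per_workgroup": num_warps * WAVEFRONT_SIZE,
        "fills_all_simds": num_warps >= SIMDS_PER_CU,   # i.e. >= 256 threads
        "workgroups_per_cu": grid_size / NUM_CUS,
        "grid_ok": grid_size >= 1024,
    }


print(check_launch(num_warps=4, grid_size=1024))
print(check_launch(num_warps=2, grid_size=256))
```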

To further raise parallelism and utilization, design algorithms that expose more independent work.  
For GEMM, a common tactic is **split-K** (and, when the output grid is small, a larger split-K factor): partition the K dimension so more partial products can run across additional CUs, improving concurrency and throughput (with a final reduction to combine partial results).
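
Here is a pure-Python sketch of the idea. Each K slice is an independent partial GEMM (on a GPU, an extra set of workgroups), and a final elementwise sum combines the partials. Real split-K kernels often do that reduction with atomic adds or a separate pass and accumulate in fp32; the plain sum here only shows that the decomposition is exact and how it multiplies the grid size. Function names are illustrative.

```python
import math


def matmul(A, B, k_lo=0, k_hi=None):
    """Plain GEMM restricted to the K range [k_lo, k_hi)."""
    k_hi = len(B) if k_hi is None else k_hi
    return [[sum(A[i][k] * B[k][j] for k in range(k_lo, k_hi)) for j in range(len(B[0]))]
            for i in range(len(A))]


def split_k_matmul(A, B, split_k):
    """Each of the split_k slices is an independent partial GEMM; a final
    reduction sums the partial C tiles."""
    K = len(B)
    chunk = math.ceil(K / split_k)
    partials = [matmul(A, B, s * chunk, min((s + 1) * chunk, K)) for s in range(split_k)]
    # Reduction step: elementwise sum of the partial results.
    return [[sum(p[i][j] for p in partials) for j in range(len(B[0]))]
            for i in range(len(A))]


def gemm_grid(M, N, BLOCK_M, BLOCK_N, split_k=1):
    """Workgroups launched: number of output tiles times number of K slices."""
    return math.ceil(M / BLOCK_M) * math.ceil(N / BLOCK_N) * split_k


A = [[1, 2, 3, 4], [5, 6, 7, 8]]
B = [[1, 0], [0, 1], [1, 1], [2, -1]]
assert split_k_matmul(A, B, split_k=2) == matmul(A, B)
print(gemm_grid(256, 256, 128, 128), "->", gemm_grid(256, 256, 128, 128, split_k=16))
```

For a small 256x256 output with 128x128 tiles, `split_k=16` turns 4 workgroups into 64.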

## 3. IR Analysis

"*In Triton, there are several layouts including blocked, shared, sliced, and MFMA.*"

From Triton GPU IR (TTGIR) you can infer **where** a value lives (global → registers → LDS/shared) and **how** it’s laid out.  
If you’re new to TTGIR, see my notebook that **demystifies Triton’s compilation stages**: [`triton_compilation_stages.ipynb`](./triton_compilation_stages.ipynb).

Below is a snippet from a FlashAttention **decode** kernel that **dequantizes int4 key/value** to `f16`.
```
    %190 = tt.load %189 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x64xi32, #blocked6> loc(#loc159)

    %266 = arith.andi %190, %cst_28 : tensor<1x64xi32, #blocked6> loc(#loc250)

    %267 = arith.trunci %266 : tensor<1x64xi32, #blocked6> to tensor<1x64xi16, #blocked6> loc(#loc251)

    %268 = tt.bitcast %267 : tensor<1x64xi16, #blocked6> -> tensor<1x64xf16, #blocked6> loc(#loc252)

    %269 = triton_gpu.convert_layout %268 : (tensor<1x64xf16, #blocked6>) -> tensor<1x64xf16, #shared1> loc(#loc252)

    %270 = tt.trans %269 : (tensor<1x64xf16, #shared1>) -> tensor<64x1xf16, #shared2> loc(#loc194)

    %276 = triton_gpu.convert_layout %270 : (tensor<64x1xf16, #shared2>) -> tensor<64x1xf16, #blocked5> loc(#loc254)

    %293 = arith.mulf %276, %cst_30 : tensor<64x1xf16, #blocked5> loc(#loc254)

    %295 = arith.mulf %292, %294 : tensor<64x32xf16, #blocked5> loc(#loc264)

    %297 = arith.addf %295, %296 : tensor<64x32xf16, #blocked5> loc(#loc255)

    %298 = triton_gpu.convert_layout %297 : (tensor<64x32xf16, #blocked5>) -> tensor<64x32xf16, #shared1> loc(#loc255)

    %299 = tt.trans %298 : (tensor<64x32xf16, #shared1>) -> tensor<32x64xf16, #shared2> loc(#loc196)

    %300 = triton_gpu.convert_layout %299 : (tensor<32x64xf16, #shared2>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> loc(#loc197)
```

### What this IR does

1) **Load + scalar ops in registers**  
   Load `i32` from **global** into registers; mask, truncate, and reinterpret to `f16`:
    ```
    %190, %266, %267, %268
    ```

2) **Stage to LDS for cross-thread transpose**  
   Move to **shared (LDS)** and **transpose** so subsequent accesses are coalesced and bank-friendly:
    ```
    %269, %270
    ```

3) **Compute in registers, restage to LDS**  
   Bring back to a blocked/register view, apply elementwise mul/add, then **store back to LDS**:
    ```
    %276, %293, %295, %297 -> %298
    ```

   This is the classic “**stage in shared, transpose for coalescing/bank access**” pattern.  
   More: [Memory Coalescing](https://modal.com/gpu-glossary/perf/memory-coalescing), [Bank Conflict](https://modal.com/gpu-glossary/perf/bank-conflict)

4) **Prepare dot-operand layout (MFMA)**  
   Transpose again in **shared**, then convert to the **dot-op (MFMA) layout** for matrix ops:
    ```
    %299 -> %300
    ```
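
Step 1 hinges on `tt.bitcast` reinterpreting bits rather than converting values. The pure-Python emulation below mimics the `%266` → `%267` → `%268` chain with `struct` (format code `e` is IEEE half precision). The input words and masks are hypothetical, since the value of `%cst_28` is not shown in the snippet.

```python
import struct


def emulate_dequant_chain(word_i32: int, mask: int) -> float:
    """Emulate arith.andi -> arith.trunci (i32 -> i16) -> tt.bitcast (i16 -> f16)."""
    masked = word_i32 & mask          # %266 = arith.andi (mask value is hypothetical)
    low16 = masked & 0xFFFF           # %267 = arith.trunci keeps the low 16 bits
    # %268 = tt.bitcast: reinterpret the 16 bits as an IEEE half, no numeric conversion.
    return struct.unpack("<e", struct.pack("<H", low16))[0]


print(emulate_dequant_chain(0x12346403, 0xFFFF))  # 1027.0
print(emulate_dequant_chain(0x00003C00, 0xFFFFFFFF), float(0x3C00))  # 1.0 15360.0
```

The bit pattern `0x6403` reinterprets as `1027.0` (`1024 + 3`), whereas numerically converting the integer `0x3C00` gives `15360.0`, not the `1.0` that the bitcast produces.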

**Takeaway:** LDS is used **twice**—first to enable a safe/efficient **transpose**, then to bridge from a **blocked** layout into an **MFMA-friendly dot operand** layout, minimizing global traffic while enabling coalesced loads and conflict-free shared accesses.
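
Why a transpose needs care in LDS: reading a column of a row-major tile makes every lane stride by a full row. The toy model below counts how many lanes of a 32-lane group land in the same bank, assuming 32 banks that are each 4 bytes wide and 32-bit elements; the real MI300X rules, especially for wide `ds_read_b128` accesses, are more involved. Padding each row by one dword is the classic fix; Triton's `#shared` layouts can also use swizzling for the same purpose.

```python
from collections import Counter

NUM_BANKS = 32   # assumption: 32 LDS banks, each 4 bytes (one dword) wide
LANES = 32       # simplified: check conflicts over one 32-lane group


def worst_bank_conflict(row_stride_dwords, col=0):
    """Lane i reads element (row=i, col) of a row-major tile of 32-bit values
    whose rows are `row_stride_dwords` apart. Returns the largest number of
    lanes that hit the same bank (1 means conflict-free)."""
    banks = Counter((lane * row_stride_dwords + col) % NUM_BANKS for lane in range(LANES))
    return max(banks.values())


print(worst_bank_conflict(32))  # unpadded 32-dword rows: all lanes hit one bank
print(worst_bank_conflict(33))  # one dword of padding per row spreads lanes over all banks
```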

## 4. Assembly Analysis

### Global memory: vectorized loads
- Prefer **`global_load_dwordx4`** (per-lane 16 B = 4×32-bit) — especially inside loops.  
  Vectorizing reduces LD instruction count while remaining coalesced across lanes.

### LDS (shared) width: b128 vs b64
- **`ds_read_b128` / `ds_write_b128`**: 16 B per lane (preferred).
- **`ds_read_b64` / `ds_write_b64`**: 8 B per lane (fallback when alignment/aliasing/layout prevents b128).
- **Rule of thumb:** design shared layouts so **b128** is *legal and bank-friendly*; expect ~2× fewer LDS ops vs b64.
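
A quick count shows why the wide variants matter: for a fixed tile, the number of wavefront-level memory instructions scales inversely with the bytes moved per lane. The tile shape below is hypothetical, and the count ignores address arithmetic and masking.

```python
import math

WAVEFRONT_SIZE = 64  # lanes per wavefront on MI300X


def mem_instructions(tile_bytes, bytes_per_lane):
    """Wavefront-level instructions to move `tile_bytes` when each lane moves
    `bytes_per_lane` per instruction (4 = dword, 8 = b64, 16 = dwordx4 / b128)."""
    return math.ceil(tile_bytes / (WAVEFRONT_SIZE * bytes_per_lane))


tile_bytes = 128 * 64 * 2  # hypothetical 128x64 fp16 tile
for name, width in [("global_load_dword", 4), ("ds_read_b64", 8),
                    ("global_load_dwordx4 / ds_read_b128", 16)]:
    print(f"{name:>36}: {mem_instructions(tile_bytes, width)} instructions")
```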

### `s_waitcnt`: fencing only when needed
- `s_waitcnt lgkmcnt(n)`: waits on **LDS/GDS**, scalar-memory (constant) loads, and message ops.  
  `lgkmcnt(0)` ⇒ all such accesses complete; `1` ⇒ allow one still in flight, etc.
- `s_waitcnt vmcnt(n)`: waits on **vector (global) memory** ops.  
  Same semantics: fence only as much as required before first use.

**Latency-hiding pattern:** issue memory ops early → do independent compute while they are in flight → `s_waitcnt` right before first use. A skeleton in AMD-flavored pseudocode:

```
    ; ---- Prefetch next tile from HBM (global) ----
    global_load_dwordx4 v[a0:a3], [rptrA]     ; A_next
    global_load_dwordx4 v[b0:b3], [rptrB]     ; B_next

    ; ---- Do math on the current tile while the loads are in flight ----
    v_fma_f32 acc0, a_cur0, b_cur0, acc0
    v_fma_f32 acc1, a_cur1, b_cur1, acc1
    ; ... more independent VALU / MFMA ...

    ; ---- Only now, fence the loads before consuming them ----
    s_waitcnt vmcnt(0)                         ; A_next/B_next ready

    ; Optionally stage into LDS using wide ops
    ds_write_b128 [ldsA + off], v[a0:a3]
    ds_write_b128 [ldsB + off], v[b0:b3]
    s_waitcnt lgkmcnt(0)                       ; this wave's LDS writes are done
    s_barrier                                  ; needed if other waves read what this wave wrote

    ; Read back (conflict-free layout) and continue compute
    ds_read_b128  v[a_cur0:a_cur3], [ldsA + off2]
    ds_read_b128  v[b_cur0:b_cur3], [ldsB + off3]
    s_waitcnt lgkmcnt(0)                       ; before first use
```
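
The payoff of this skeleton can be estimated with a simple timing model. In the serial version every tile's load latency is exposed; with one tile prefetched, each steady-state step costs `max(load, compute)`, and only the first load and the last compute remain exposed. The cycle counts are hypothetical, and the model ignores issue overhead and bandwidth limits; deeper pipelines (`num_stages`, Section 6) keep more than one load in flight.

```python
def serial_cycles(n_tiles, load_latency, compute_cycles):
    """Load, wait, compute, repeat: every load latency is exposed."""
    return n_tiles * (load_latency + compute_cycles)


def prefetched_cycles(n_tiles, load_latency, compute_cycles):
    """Issue tile i+1's load before computing on tile i (one tile in flight)."""
    if n_tiles == 0:
        return 0
    steady = (n_tiles - 1) * max(load_latency, compute_cycles)
    return load_latency + steady + compute_cycles


# Hypothetical numbers: 500-cycle load latency, 300 cycles of math per tile.
print(serial_cycles(8, 500, 300), prefetched_cycles(8, 500, 300))
```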

---

### Practical checklist
- **Vectorize** global loads/stores (target **`dwordx4` / 16 B per lane**).
- **Prefer LDS b128**; if you see **b64**, investigate alignment, aliasing, or bank mapping.
- **Fence late**: place `s_waitcnt` immediately before data consumption, not earlier.
- **Trace inefficiencies**: if codegen looks suboptimal, walk back through **LLVM IR → TTGIR → TTIR**.  
  Enable MLIR/LLVM dumps to spot which pass/layout choice caused the issue.
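
One way to capture all the stages in a single log is to turn on Triton's dump switches for one fresh run. This is a sketch: `MLIR_ENABLE_DUMP`, `LLVM_IR_ENABLE_DUMP`, `AMDGCN_ENABLE_DUMP`, and `TRITON_ALWAYS_COMPILE` are Triton environment variables, but their output format varies by version, and `kernel.py` is a placeholder for your own script.

```python
import os
import subprocess
import sys


def dump_env(base=None):
    """Copy of the environment with Triton's IR/ISA dump switches turned on."""
    env = dict(os.environ if base is None else base)
    env.update({
        "MLIR_ENABLE_DUMP": "1",       # MLIR (TTIR/TTGIR) before each pass
        "LLVM_IR_ENABLE_DUMP": "1",    # LLVM IR before each LLVM pass
        "AMDGCN_ENABLE_DUMP": "1",     # final AMDGCN assembly (AMD backend)
        "TRITON_ALWAYS_COMPILE": "1",  # bypass the cache so the dumps are produced
    })
    return env


def run_with_dumps(script, log_path="dump.log"):
    """Run `script` in a fresh interpreter; dumps go to stderr, so merge both streams."""
    with open(log_path, "w") as log:
        subprocess.run([sys.executable, script], env=dump_env(),
                       stdout=log, stderr=subprocess.STDOUT, check=False)
```

Call `run_with_dumps("kernel.py")`, then search `dump.log` for, say, `ds_read_b64` in the assembly and walk back to the `convert_layout` that produced it.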

## 5. Kernel Occupancy

Occupancy = how many **workgroups** (and thus **wavefronts**) can reside on a **Compute Unit (CU)** concurrently, given limits from **LDS**, **wavefront slots**, and **VGPRs**. Below is a practical, reproducible way to estimate it for MI300-class GPUs.

1. LDS-limited workgroups per CU
    * Get the LDS allocated per workgroup (call it `L`) from the MLIR dump:
        * `export MLIR_ENABLE_DUMP=1`
        * `rm -rf ~/.triton/cache`
        * `python kernel.py 2>&1 | grep -E "(triton_gpu|ttg)\.shared = " | tail -n 1`
        * You should see something like `triton_gpu.shared = 65536`, meaning 65536 bytes of LDS are allocated for the kernel. The dump goes to stderr, hence `2>&1`; newer Triton releases print the dialect prefix as `ttg.` instead of `triton_gpu.`, hence the `-E` pattern.
    * A CU has 64 KiB of LDS in total. If each workgroup needs `L` bytes: `occ_lds = floor(65536 / L)`

2. Warp-limited workgroups per CU
    * Get the number of waves per workgroup (call it `nW`):
        * `export MLIR_ENABLE_DUMP=1`
        * `rm -rf ~/.triton/cache`
        * `python kernel.py 2>&1 | grep -E "(triton_gpu|ttg)\.num-warps" | tail -n 1`
        * You should see something like `"triton_gpu.num-warps" = 8`, indicating 8 waves per workgroup.
    * A CU holds at most 32 waves. If each workgroup needs `nW` waves: `occ_warp = floor(32 / nW)`

3. VGPR-limited workgroups per CU
    * Get the VGPR count per thread (call it `N`) by searching for `.vgpr_count` in the ISA.
    * Look up the VGPR-limited occupancy for `N` in the table below; it gives the number of waves per EU (SIMD), `occ_vgpr_eu`.

    ![VGPRs Occupancy Table](./imgs/vgpr_occ_table.png)
    
    * CDNA2/3 GPUs have 4 SIMDs per CU, so `occ_vgpr_cu = occ_vgpr_eu * 4` is the most waves a CU can hold under VGPR pressure. Each workgroup consumes `nW` waves, so `occ_vgpr = floor(occ_vgpr_cu / nW)`.
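
The three lookups can be scripted. The helpers below are a sketch that pull the last match out of TTGIR or AMDGCN text, mirroring `grep ... | tail -n 1`; the regexes accept both the `triton_gpu.` and the newer `ttg.` prefixes, and the exact attribute spelling may vary between Triton versions.

```python
import re
from typing import Optional


def _last_int(pattern: str, text: str) -> Optional[int]:
    """Last integer captured by `pattern`, like `grep ... | tail -n 1`."""
    matches = re.findall(pattern, text)
    return int(matches[-1]) if matches else None


def parse_lds_bytes(ttgir: str) -> Optional[int]:
    # Matches `triton_gpu.shared = 65536` as well as `ttg.shared = 65536`.
    return _last_int(r'"?(?:triton_gpu|ttg)\.shared"?\s*=\s*(\d+)', ttgir)


def parse_num_warps(ttgir: str) -> Optional[int]:
    # The attribute is quoted because of the hyphen: `"triton_gpu.num-warps" = 8`.
    return _last_int(r'"?(?:triton_gpu|ttg)\.num-warps"?\s*=\s*(\d+)', ttgir)


def parse_vgpr_count(amdgcn: str) -> Optional[int]:
    # Kernel metadata in the ISA contains a line such as `.vgpr_count:     170`.
    return _last_int(r"\.vgpr_count:\s*(\d+)", amdgcn)
```

In recent Triton versions, launching a `@triton.jit` kernel returns a compiled-kernel handle whose `asm` dictionary holds the stage texts, so something like `parse_lds_bytes(h.asm["ttgir"])` and `parse_vgpr_count(h.asm["amdgcn"])` should work without dump files; treat the key names as an assumption to verify for your version.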

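The achievable occupancy is `occupancy = min(occ_lds, occ_warp, occ_vgpr)`. Because a SIMD holds at most 8 waves (32 per CU ÷ 4 SIMDs), `occ_vgpr_eu ≤ 8`, so `occ_vgpr = floor(4 * occ_vgpr_eu / nW) ≤ floor(32 / nW) = occ_warp`. In practice the minimum therefore reduces to `min(occ_lds, occ_vgpr)`.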

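Optional read: AMD provides the VGPR occupancy table, but you can also derive it by hand. Per the MI300X spec, the VGPR file is 512 KiB per CU, i.e. 128 KiB per SIMD. With `vgpr_count` registers per thread, each wave uses `vgpr_used = vgpr_count * warp_size * 4` bytes (each register is 32 bits), so `occ_vgpr_eu = floor(128 KiB / vgpr_used)`, which for `warp_size = 64` is `floor(512 / vgpr_count)`. Round `vgpr_count` up to the allocation granularity first (see the `waves_per_eu` example in Section 6) and cap the result at 8 waves per SIMD.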
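
Putting the three limits together, here is a small calculator. It hard-codes the MI300X figures used in this section (64 KiB of LDS and 32 waves per CU, 4 SIMDs, a 512-VGPR budget per SIMD). The VGPR allocation granule defaults to 16 to match the `waves_per_eu` example in Section 6, but it is a parameter because it decides where the occupancy steps fall, so check it against the occupancy table for your GPU.

```python
import math

LDS_BYTES_PER_CU = 64 * 1024
MAX_WAVES_PER_CU = 32
SIMDS_PER_CU = 4
MAX_WAVES_PER_SIMD = MAX_WAVES_PER_CU // SIMDS_PER_CU  # 8
VGPRS_PER_SIMD = 512  # 128 KiB / (64 lanes * 4 B)


def occ_vgpr_eu(vgpr_count, granule=16):
    """Waves per SIMD allowed by VGPRs (granule is an assumption, see lead-in)."""
    allocated = max(1, math.ceil(vgpr_count / granule)) * granule
    return min(MAX_WAVES_PER_SIMD, VGPRS_PER_SIMD // allocated)


def occupancy(lds_bytes, num_warps, vgpr_count, granule=16):
    """Workgroups per CU under the LDS, wave-slot, and VGPR limits."""
    occ_lds = LDS_BYTES_PER_CU // lds_bytes if lds_bytes > 0 else math.inf
    occ_warp = MAX_WAVES_PER_CU // num_warps
    occ_vgpr = occ_vgpr_eu(vgpr_count, granule) * SIMDS_PER_CU // num_warps
    return {"occ_lds": occ_lds, "occ_warp": occ_warp, "occ_vgpr": occ_vgpr,
            "occupancy": min(occ_lds, occ_warp, occ_vgpr)}


print(occupancy(lds_bytes=65536, num_warps=8, vgpr_count=170))
```

With the example values from this notebook (65536 B of LDS, `num_warps = 8`, `vgpr_count = 170`), both LDS and VGPRs limit the CU to a single workgroup.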

## 6. Auto-tunable Kernel Configurations and Environment Variables

### Triton reserved keywords

These are common Triton meta-parameters you’ll want to tune.


#### 1. `BLOCK_M`, `BLOCK_N`, `BLOCK_K`

Tile sizes for GEMM-like workloads. They control the **memory-to-compute ratio** and **grid-level parallelism**:

- Make tiles **large enough** to:
  - Improve arithmetic intensity.
  - Reuse data in registers/LDS efficiently.
- But also **small enough** to:
  - Launch many workgroups.
  - Keep all CUs busy and hide latency.

You typically explore a grid of `(BLOCK_M, BLOCK_N, BLOCK_K)` configs in the autotuner.
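
To see this tension numerically, the sketch below computes the arithmetic intensity of one K-step of a tile (fp16 inputs, ignoring the C write and caches) and the resulting grid size on the 304-CU MI300X. `BLOCK_K` cancels out of this ratio; it mainly sets the LDS/register footprint and loop length. The 4096x4096 problem size is hypothetical.

```python
import math

NUM_CUS = 304  # MI300X


def tile_tradeoff(M, N, BLOCK_M, BLOCK_N, bytes_per_elem=2):
    """Per K-step a workgroup loads (BLOCK_M + BLOCK_N) * BLOCK_K elements and does
    2 * BLOCK_M * BLOCK_N * BLOCK_K FLOPs, so BLOCK_K cancels in FLOP/byte."""
    intensity = 2 * BLOCK_M * BLOCK_N / ((BLOCK_M + BLOCK_N) * bytes_per_elem)
    workgroups = math.ceil(M / BLOCK_M) * math.ceil(N / BLOCK_N)
    return intensity, workgroups, workgroups / NUM_CUS


for bm, bn in [(64, 64), (128, 128), (256, 256)]:
    ai, wg, per_cu = tile_tradeoff(4096, 4096, bm, bn)
    print(f"{bm}x{bn}: {ai:.0f} FLOP/B, {wg} workgroups ({per_cu:.2f} per CU)")
```

For this shape, 256x256 tiles double the intensity of 128x128 but launch only 256 workgroups, fewer than the 304 CUs, while 128x128 lands on the ~1024-workgroup target.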


#### 2. `num_stages = n`

Controls the number of pipeline stages (software pipelining / prefetching behavior). AMD’s guidelines:

- Kernels with **a single GEMM** → `num_stages = 0`.
- Kernels with **two GEMMs fused** (e.g., Flash Attention, or any 2-GEMM fusion) → `num_stages = 1`.
- Kernels with **one GEMM fused with a non-GEMM op** (e.g., GEMM + ReLU) → `num_stages = 0`.
- Kernels with **no GEMM** → `num_stages = 1`.

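These are good defaults; fine-tuning is still workload-dependent. They follow the ROCm 6.1 guidance linked in the references; how the AMD backend interprets `num_stages` may differ in newer Triton releases (this notebook lists Triton 3.4.0), so re-check against the docs for your version.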
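
Encoded as a lookup, the guideline becomes a starting point for an autotuning search rather than a final answer. The function name is hypothetical, and it only reproduces the four cases listed above.

```python
def suggested_num_stages(num_gemms):
    """Starting num_stages on MI300X per the guideline list (ROCm 6.1-era)."""
    if num_gemms == 0:
        return 1   # no GEMM
    if num_gemms == 1:
        return 0   # single GEMM, alone or fused with a non-GEMM op such as ReLU
    if num_gemms == 2:
        return 1   # two fused GEMMs, e.g. FlashAttention
    raise ValueError("the guideline only covers kernels with up to two GEMMs")


print([suggested_num_stages(n) for n in (0, 1, 2)])
```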


#### 3. `matrix_instr_nonkdim`

Experimental parameter for FlashAttention-like kernels that selects the MFMA instruction shape:

- `matrix_instr_nonkdim = 16` → use `mfma_16x16`.
- `matrix_instr_nonkdim = 32` → use `mfma_32x32`.

On AMD MI300X, **`mfma_16x16` usually outperforms `mfma_32x32`** for GEMM kernels, even with large tiles.
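
These knobs are usually swept together by the autotuner. The generator below is a sketch that enumerates candidate configurations as plain dictionaries. The meta-parameter names must match your kernel's `tl.constexpr` arguments, the value ranges are illustrative, and placing `matrix_instr_nonkdim` inside the meta-parameter dict follows AMD's Triton examples (an assumption to verify for your Triton version).

```python
import itertools


def gemm_config_space(block_sizes=(64, 128, 256), block_k=(32, 64),
                      nonkdims=(16, 32), num_warps=(4, 8), num_stages=0):
    """Yield (meta_params, launch_options) pairs for an autotuner.
    num_stages=0 follows the single-GEMM guideline above."""
    for bm, bn, bk, nk, nw in itertools.product(block_sizes, block_sizes,
                                                block_k, nonkdims, num_warps):
        meta = {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk,
                "matrix_instr_nonkdim": nk}
        yield meta, {"num_warps": nw, "num_stages": num_stages}


configs = list(gemm_config_space())
print(len(configs), configs[0])
```

Each pair maps onto `triton.Config(meta, **opts)`, and the resulting list can be passed to `@triton.autotune(configs=..., key=["M", "N", "K"])`.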


#### 4. `waves_per_eu = n`

Hints the backend to **cap VGPR usage** so that a desired number of waves per EU (SIMD) can be resident.

Useful when:

- **Occupancy is VGPR-limited**, and
- Your current VGPR usage is just **slightly above** a VGPR occupancy boundary (see the VGPRs Occupancy Table).

Example:

- Each EU (SIMD) has **512 VGPRs**, allocated in **chunks of 16**.
- Suppose `vgpr_count = 170` is reported:
  - Actual allocation is rounded to 176.
  - `176 × 3 > 512` → only **2 waves per EU** fit.
- If you set `waves_per_eu = 3`, the LLVM backend tries to reduce VGPR usage enough to allow **3 waves per EU**, potentially improving occupancy.
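
The arithmetic behind this example is easy to check. The sketch below uses the numbers stated above (512 VGPRs per EU, 16-VGPR allocation chunks) and computes the largest per-thread allocation that still fits a requested `waves_per_eu`; the function names are hypothetical.

```python
import math

VGPRS_PER_EU = 512


def allocated_vgprs(vgpr_count, granule=16):
    """Round the reported per-thread VGPR count up to the allocation chunk."""
    return math.ceil(vgpr_count / granule) * granule


def vgpr_budget_for(waves_per_eu, granule=16):
    """Largest allocation that still fits `waves_per_eu` waves on one EU:
    floor(512 / n), rounded down to a multiple of the chunk size."""
    return (VGPRS_PER_EU // waves_per_eu) // granule * granule


print(f"vgpr_count=170 -> {allocated_vgprs(170)} allocated; "
      f"waves_per_eu=3 needs <= {vgpr_budget_for(3)}")
```

So for `waves_per_eu = 3` the backend has to bring the allocation from 176 down to at most 160 VGPRs, shaving at least 10 registers; if it cannot do that cleanly it may spill to scratch memory, which is why the hint helps most when usage is only slightly above a boundary.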


### Environment variable: `OPTIMIZE_EPILOGUE`

#### What is the “epilogue” in this context?

For an MFMA GEMM kernel, the high-level flow:

1. **Main loop**
   - Load A/B tiles.
   - Perform MFMA → accumulate into registers in **MFMA layout** (`#mfma`).
2. **Epilogue**
   - Take the accumulator tile in MFMA’s per-lane layout.
   - Optionally apply **bias / activation / etc.**
   - **Repack** into a normal (blocked) layout.
   - Store the final result to **global memory**.

The **repacking** usually happens via `convert_layout` (often using LDS as a staging buffer):  
`#mfma → [LDS] → #blocked`.  
Padding is added to avoid LDS bank conflicts, but this **increases LDS usage** and can **reduce occupancy**.
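
A quick calculation shows how padding feeds into the LDS limit from Section 5. The sketch assumes a hypothetical 128x128 accumulator tile staged in LDS after conversion to fp16, with 8 elements of padding per row; the padding Triton actually chooses may differ, and the kernel's other LDS buffers (which the epilogue may reuse) are ignored.

```python
LDS_BYTES_PER_CU = 64 * 1024


def epilogue_lds_bytes(BLOCK_M, BLOCK_N, bytes_per_elem, pad_elems_per_row=0):
    """LDS needed to stage a BLOCK_M x BLOCK_N tile with per-row padding (elements)."""
    return BLOCK_M * (BLOCK_N + pad_elems_per_row) * bytes_per_elem


for pad in (0, 8):  # hypothetical padding amounts
    lds = epilogue_lds_bytes(128, 128, 2, pad)
    print(f"pad={pad}: {lds} B -> occ_lds = {LDS_BYTES_PER_CU // lds}")
```

Here 2 KiB of padding is enough to drop `occ_lds` from 2 to 1.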


#### Behavior of `OPTIMIZE_EPILOGUE`

- **`OPTIMIZE_EPILOGUE=0` (default)**  
  - Convert MFMA accumulators to a blocked layout.
  - Enables fully vectorized stores, e.g. `global_store_dwordx4`.
  - Cost: extra LDS usage + extra data movement/ops.

- **`OPTIMIZE_EPILOGUE=1`**  
  - Store MFMA results **directly in MFMA layout**.
  - Global stores may be narrower (less vectorized).
  - Benefit: lower LDS usage, fewer instructions, often **better occupancy**.
  - In practice, the impact on runtime is usually **small or positive** despite less perfect stores.

**In short:**

- `OPTIMIZE_EPILOGUE=0`:  
  - Pay LDS + extra ops to get **ideal, vectorized global stores**.

- `OPTIMIZE_EPILOGUE=1`:  
  - Skip the expensive repack;  
  - **Store MFMA layout directly**, accept slightly worse store patterns,  
  - Gain **occupancy** and reduce instruction count.
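
To measure the trade-off on your own kernel, run it once per setting in a fresh process. This is a sketch under two assumptions: your Triton build still reads `OPTIMIZE_EPILOGUE` (it comes from the ROCm 6.1-era guidance, so check your version), and `bench.py` is a hypothetical script that launches and times the kernel. Giving each run its own `TRITON_CACHE_DIR` keeps it from reusing a binary compiled under the other setting.

```python
import os
import subprocess
import sys
import tempfile


def epilogue_env(value, cache_dir, base=None):
    """Environment for one run with OPTIMIZE_EPILOGUE=value and a private cache."""
    env = dict(os.environ if base is None else base)
    env["OPTIMIZE_EPILOGUE"] = str(value)
    env["TRITON_CACHE_DIR"] = cache_dir  # fresh cache forces recompilation
    return env


def compare_epilogues(script):
    """Run `script` (hypothetical benchmark) once per setting in a new interpreter."""
    for value in (0, 1):
        with tempfile.TemporaryDirectory() as cache_dir:
            print(f"--- OPTIMIZE_EPILOGUE={value} ---", flush=True)
            subprocess.run([sys.executable, script],
                           env=epilogue_env(value, cache_dir), check=True)
```

Call `compare_epilogues("bench.py")` and compare the timings; with dumps enabled you can also compare the `global_store_*` widths and the shared-memory sizes of the two builds.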


## References
- AMD ROCm Docs - [Optimizing Triton kernels](https://rocm.docs.amd.com/en/docs-6.1.1/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html)
- Modal GPU Glossary - [Performance](https://modal.com/gpu-glossary/perf)