
zengxiao-he/triton-llm-kernel-lab


Triton LLM Inference Kernel Lab

A small GPU kernel lab for LLM inference primitives in Python, Triton, PyTorch, and CUDA. The repo focuses on readable kernels and reproducible validation: a row-wise softmax, an FP16 GEMM, and a FlashAttention-style fused attention forward kernel built on tiled online softmax.

This is a cleaned-up reconstruction of earlier kernel experiments. Benchmark numbers are intentionally generated by the local harness instead of being baked into the README, because GPU model, driver, Triton version, and tensor shapes change the results materially.

Kernels

  • row_softmax: one Triton program per row, numerically stable max-subtract softmax, intended for bandwidth-bound reductions where a row fits in SRAM.
  • fp16_matmul: tiled FP16 GEMM with FP32 accumulation, block-level tl.dot, grouped program ordering for better L2 reuse, and configurable BLOCK_M/BLOCK_N/BLOCK_K.
  • flash_attention_forward: fused attention forward pass for prefill and simple decode-style shapes. It computes the QK^T and PV products tile by tile and applies the online softmax recurrence, so the full attention matrix is never materialized.
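The online softmax recurrence behind the fused attention kernel can be illustrated for a single query row in plain Python. This is a CPU sketch of the math only, not the Triton code in kernels/attention.py: the function names and the block_size parameter are hypothetical, and causal masking is omitted. The kernel keeps a running max m, a running denominator l, and an unnormalized accumulator, rescaling the old state by exp(m_old - m_new) whenever a new key block raises the max.

```python
import math
import random


def attention_row_online(q, keys, values, block_size=2, scale=None):
    """One query row of softmax(q K^T * scale) V, processed one key block at a time.

    Only a running max (m), running denominator (l), and an unnormalized output
    accumulator (acc) are kept, so the full score row is never stored.
    """
    if scale is None:
        scale = 1.0 / math.sqrt(len(q))
    m = -math.inf                    # running max of all scores seen so far
    l = 0.0                          # running sum of exp(score - m)
    acc = [0.0] * len(values[0])     # running sum of exp(score - m) * v
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        m_new = max(m, max(scores))
        alpha = math.exp(m - m_new)  # rescales old state to the new max; 0.0 on the first block
        p = [math.exp(s - m_new) for s in scores]
        l = alpha * l + sum(p)
        acc = [alpha * a + sum(pj * v[c] for pj, v in zip(p, v_blk))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]      # normalize once, after the last block


def attention_row_naive(q, keys, values, scale=None):
    """Reference: materialize every score, then apply a max-subtracted softmax."""
    if scale is None:
        scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    total = sum(w)
    return [sum(wj * v[c] for wj, v in zip(w, values)) / total
            for c in range(len(values[0]))]


if __name__ == "__main__":
    rng = random.Random(0)
    d, n = 8, 13                     # n is deliberately not a multiple of the block size
    q = [rng.gauss(0, 1) for _ in range(d)]
    keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    ref = attention_row_naive(q, keys, values)
    for bs in (1, 4, 16):
        out = attention_row_online(q, keys, values, block_size=bs)
        print(bs, max(abs(a - b) for a, b in zip(out, ref)))
```

The block size changes only the order in which scores are folded in, so every block size should agree with the naive reference up to floating-point rounding.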

Project Layout

src/triton_llm_kernel_lab/
  bench.py              # CLI benchmark harness
  configs.py            # kernel configs and LLM-like benchmark shapes
  reference.py          # PyTorch reference implementations
  runtime.py            # CUDA/Triton availability checks
  kernels/
    attention.py        # FlashAttention-style fused attention forward
    gemm.py             # FP16 GEMM kernel
    softmax.py          # row-wise fused softmax kernel
tests/
  test_references.py    # CPU-safe reference and config checks
  test_gpu_kernels.py   # CUDA/Triton correctness checks, skipped otherwise
docs/
  profiling.md          # Nsight Compute workflow and metrics
  tradeoffs.md          # prefill vs decode kernel selection notes

Install

Use a Linux environment with an NVIDIA GPU for the Triton kernels.

python -m venv .venv
source .venv/bin/activate
pip install -U pip
pip install -e ".[gpu,dev]"

On a CPU-only machine, install only the dev extras and run the CPU-safe tests:

pip install -e ".[dev]"
pytest tests/test_references.py
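The GPU tests are skipped when CUDA or Triton is missing, and runtime.py holds the availability checks. Its exact logic is not shown here, so the following is only a sketch of how such a check might look; the function name triton_gpu_available is hypothetical.

```python
import importlib.util


def triton_gpu_available() -> bool:
    """Hypothetical check: True only if torch, triton, and a CUDA device are all present."""
    # find_spec locates a package without importing it, so this is cheap on CPU-only hosts
    if importlib.util.find_spec("torch") is None or importlib.util.find_spec("triton") is None:
        return False
    import torch  # imported lazily, only once it is known to be installed
    return torch.cuda.is_available()


if __name__ == "__main__":
    print("Triton GPU path available:", triton_gpu_available())
```

A GPU test module could then skip itself wholesale with pytest's module-level marker, for example pytestmark = pytest.mark.skipif(not triton_gpu_available(), reason="requires CUDA and Triton").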

Correctness

The GPU tests compare every custom kernel against a PyTorch reference and report the max absolute error. On a CUDA machine:

pytest tests/test_gpu_kernels.py -q

The test coverage includes:

  • softmax rows with non-power-of-two column counts
  • FP16 GEMM with masked edge tiles
  • causal and non-causal fused attention forward

Benchmarking

The harness uses 50 warmup iterations and 200 timed iterations by default. It prints latency, estimated TFLOPS, estimated memory bandwidth, and max error.
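The README does not spell out how the harness turns latency into TFLOPS and bandwidth. The conventional estimates are 2·M·N·K FLOPs for a GEMM and 4·B·H·Lq·Lk·D FLOPs for attention forward (QK^T plus PV), with bandwidth taken as the minimum bytes that must be read and written once. A sketch under those assumptions, with hypothetical function names and all GEMM operands (including C) assumed to be FP16:

```python
def gemm_metrics(m, n, k, latency_ms, bytes_per_elem=2):
    """Return (TFLOPS, GB/s) for C[m, n] = A[m, k] @ B[k, n] given a measured latency."""
    seconds = latency_ms * 1e-3
    flops = 2 * m * n * k                                   # one multiply + one add per MAC
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)  # read A and B once, write C once
    return flops / seconds / 1e12, bytes_moved / seconds / 1e9


def attention_fwd_flops(batch, heads, q_len, kv_len, head_dim, causal=False):
    """FLOPs for QK^T plus PV, each 2 * q_len * kv_len * head_dim per batch and head."""
    flops = 4 * batch * heads * q_len * kv_len * head_dim
    # Causal masking skips roughly half the work when q_len == kv_len (approximation)
    return flops / 2 if causal else flops


if __name__ == "__main__":
    tflops, gbps = gemm_metrics(4096, 4096, 4096, latency_ms=1.0)
    print(f"{tflops:.1f} TFLOPS, {gbps:.1f} GB/s")
```

Because tiled kernels re-read A and B tiles, actual DRAM traffic is at least this estimate and usually higher; Nsight Compute's DRAM counters give the measured value.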

python -m triton_llm_kernel_lab.bench --kernel all
python -m triton_llm_kernel_lab.bench --kernel attention --warmup 50 --iters 200 --csv results/attention.csv

Representative shape groups are defined in configs.py:

  • prefill: long query/key lengths, where the QK^T and PV matmuls dominate the arithmetic
  • decode: short query length with long KV cache where memory traffic dominates
  • GEMM: common projection and MLP matrix sizes
  • softmax: row lengths that stress SRAM fit and reduction behavior
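The prefill/decode split follows from arithmetic intensity (FLOPs per byte moved). For one (batch, head) slice of fused, non-causal attention that reads Q, K, V once and writes O once in FP16, the ratio simplifies to Lq·Lk / (Lq + Lk). The sketch below is illustrative only; it ignores batching and grouped-query attention, both of which change the ratio.

```python
def attention_intensity(q_len, kv_len, head_dim, bytes_per_elem=2):
    """Approximate FLOPs per byte for one (batch, head) slice of non-causal attention."""
    flops = 4 * q_len * kv_len * head_dim                     # QK^T and PV
    # Read Q, K, V once and write O once; the fused kernel never stores the score matrix
    bytes_moved = bytes_per_elem * head_dim * (2 * q_len + 2 * kv_len)
    return flops / bytes_moved                                # == q_len*kv_len/(q_len+kv_len) for FP16


if __name__ == "__main__":
    print("prefill 4096 x 4096:", attention_intensity(4096, 4096, 128))
    print("decode     1 x 4096:", attention_intensity(1, 4096, 128))
```

Prefill at 4096 x 4096 comes out at 2048 FLOPs per byte, while single-token decode against a 4096-entry KV cache is just under 1 FLOP per byte, far below the compute-to-bandwidth ratio of any recent NVIDIA GPU. That is why decode shapes are bandwidth-bound.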

Profiling

Use Nsight Compute for detailed GPU metrics:

bash scripts/profile_ncu.sh attention

The profiling notes in docs/profiling.md track the metrics that matter for this lab: achieved occupancy, memory throughput, L2 hit rate, warp stalls, tensor core utilization, and DRAM read/write transactions.

Open-Source References

The implementation style was informed by the public Triton tutorials and FlashAttention papers, but the code in this repository is written as a compact teaching/lab version rather than a copy of those tutorials.
