Small GPU kernel lab for LLM inference primitives in Python, Triton, PyTorch, and CUDA. The repo focuses on readable kernels and reproducible validation: row-wise softmax, FP16 GEMM, and a FlashAttention-style fused attention forward kernel using tiled online softmax.
This is a cleaned-up reconstruction of earlier kernel experiments. Benchmark numbers are intentionally generated by the local harness instead of being baked into the README, because GPU model, driver, Triton version, and tensor shapes change the results materially.
- `row_softmax`: one Triton program per row, numerically stable max-subtract softmax, intended for bandwidth-bound reductions where a row fits in SRAM.
- `fp16_matmul`: tiled FP16 GEMM with FP32 accumulation, block-level `tl.dot`, grouped program ordering for better L2 reuse, and configurable `BLOCK_M`/`BLOCK_N`/`BLOCK_K`.
- `flash_attention_forward`: fused attention forward for prefill and simple decode-style shapes, using tiled QK and PV blocks plus the online softmax recurrence to avoid materializing the full attention matrix.
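A minimal pure-Python sketch of the online softmax recurrence behind `flash_attention_forward`, for a single query row. It is an illustration, not the Triton kernel: `online_attention_row` and `reference_attention_row` are hypothetical names, `block_size` stands in for the kernel's KV tile size, and the real kernel processes 2D tiles with `tl.dot` instead of Python loops. The streaming version keeps a running max `m`, a running normalizer, and an unnormalized output accumulator, and rescales the old statistics by `exp(m_old - m_new)` whenever a new KV block raises the max, so it never needs the full row of scores at once.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def reference_attention_row(q: Vector, keys: Sequence[Vector],
                            values: Sequence[Vector], scale: float) -> List[float]:
    """Two-pass reference: materialize every score, subtract the max, normalize."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def online_attention_row(q: Vector, keys: Sequence[Vector], values: Sequence[Vector],
                         scale: float, block_size: int = 4) -> List[float]:
    """One pass over KV blocks using the online softmax recurrence."""
    dim = len(values[0])
    m = -math.inf       # running max of the scores seen so far
    denom = 0.0         # running sum of exp(score - m)
    acc = [0.0] * dim   # running sum of exp(score - m) * v, not yet normalized
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        s_blk = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)  # rescale factor for old stats; 0.0 on the first block
        p_blk = [math.exp(s - m_new) for s in s_blk]
        denom = alpha * denom + sum(p_blk)
        acc = [alpha * a + sum(p * v[d] for p, v in zip(p_blk, v_blk))
               for d, a in enumerate(acc)]
        m = m_new
    return [a / denom for a in acc]  # normalize once, after the last block


if __name__ == "__main__":
    q = [0.5, -1.0, 2.0]
    keys = [[0.1 * i, -0.2 * i, 0.3] for i in range(10)]  # 10 keys -> blocks of 4, 4, 2
    values = [[float(i), 1.0 - i] for i in range(10)]
    scale = 1.0 / math.sqrt(len(q))
    ref = reference_attention_row(q, keys, values, scale)
    out = online_attention_row(q, keys, values, scale, block_size=4)
    print("max abs error:", max(abs(a - b) for a, b in zip(ref, out)))
```

Because the running max is subtracted before every `exp`, very large scores stay finite, and changing `block_size` only changes floating-point rounding, not the mathematical result.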
```text
src/triton_llm_kernel_lab/
  bench.py              # CLI benchmark harness
  configs.py            # kernel configs and LLM-like benchmark shapes
  reference.py          # PyTorch reference implementations
  runtime.py            # CUDA/Triton availability checks
  kernels/
    attention.py        # FlashAttention-style fused attention forward
    gemm.py             # FP16 GEMM kernel
    softmax.py          # row-wise fused softmax kernel
tests/
  test_references.py    # CPU-safe reference and config checks
  test_gpu_kernels.py   # CUDA/Triton correctness checks, skipped otherwise
docs/
  profiling.md          # Nsight Compute workflow and metrics
  tradeoffs.md          # prefill vs decode kernel selection notes
```
Use a Linux environment with an NVIDIA GPU for the Triton kernels.

```bash
python -m venv .venv
source .venv/bin/activate
pip install -U pip
pip install -e ".[gpu,dev]"
```

On a CPU-only machine, install only the dev extras and run the CPU-safe reference tests:

```bash
pip install -e ".[dev]"
pytest tests/test_references.py
```

The GPU tests compare every custom kernel against a PyTorch reference and report the max absolute error. On a CUDA machine, run:

```bash
pytest tests/test_gpu_kernels.py -q
```

The test coverage includes:
- softmax rows with non-power-of-two column counts
- FP16 GEMM with masked edge tiles
- causal and non-causal fused attention forward
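The grouped program ordering used by `fp16_matmul` and the masked edge tiles exercised by the GEMM test are both plain index arithmetic, so they can be sketched without a GPU. This sketch assumes the kernel follows the grouped-ordering scheme from the Triton matrix multiplication tutorial cited in this README; `grouped_tile_coords`, `edge_mask`, and `GROUP_M` are illustrative names, not the lab's API. In a Triton kernel the mask would instead be built from `tl.arange` offsets and passed to `tl.load`/`tl.store`.

```python
from typing import List, Tuple


def cdiv(a: int, b: int) -> int:
    """Ceiling division: number of tiles needed to cover `a` elements."""
    return (a + b - 1) // b


def grouped_tile_coords(pid: int, M: int, N: int, BLOCK_M: int, BLOCK_N: int,
                        GROUP_M: int) -> Tuple[int, int]:
    """Map a flat program id to (tile_row, tile_col) using grouped ordering.

    Consecutive program ids walk down a column of up to GROUP_M tile rows before
    moving one tile column right, so programs scheduled close together tend to
    reuse the same A row-blocks and B column-blocks from L2.
    """
    num_pid_m = cdiv(M, BLOCK_M)
    num_pid_n = cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)  # the last group may be shorter
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


def edge_mask(tile_idx: int, block: int, size: int) -> List[bool]:
    """True for lanes inside the matrix; False lanes are masked off at the ragged edge."""
    return [tile_idx * block + i < size for i in range(block)]


if __name__ == "__main__":
    M, N, BLOCK_M, BLOCK_N, GROUP_M = 300, 200, 64, 32, 3
    print([grouped_tile_coords(p, M, N, BLOCK_M, BLOCK_N, GROUP_M) for p in range(6)])
    # -> [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]
    # The last row tile (index 4) covers rows 256..319, but M = 300,
    # so only 44 of its 64 lanes are valid.
    print(sum(edge_mask(4, BLOCK_M, M)))  # -> 44
```

Every program id still maps to exactly one tile, so grouping changes only the launch order, not the work done; the masks are what let non-multiple shapes like 300 x 200 reuse the same tile sizes.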
The harness uses 50 warmup iterations and 200 timed iterations by default. It prints latency, estimated TFLOPS, estimated memory bandwidth, and max error.
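A minimal sketch of how the warmup/timed iterations and the derived metrics fit together. It is not the code in `bench.py`: `time_kernel`, `tflops`, `gbps`, `gemm_flops`, and `gemm_bytes` are hypothetical helpers, and the real harness may time with CUDA events rather than a host clock. On a GPU, pass a synchronizing callable such as `torch.cuda.synchronize` as `sync`, because kernel launches return before the kernel finishes. The byte count is an idealized lower bound that assumes each FP16 matrix crosses DRAM exactly once.

```python
import statistics
import time
from typing import Callable, Optional


def time_kernel(fn: Callable[[], object], warmup: int = 50, iters: int = 200,
                sync: Optional[Callable[[], None]] = None) -> float:
    """Median latency of fn() in milliseconds."""
    sync = sync or (lambda: None)
    for _ in range(warmup):   # absorbs JIT compilation, autotuning, and cold caches
        fn()
    sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()                # wait for queued GPU work before reading the clock
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def tflops(flop_count: float, ms: float) -> float:
    return flop_count / (ms / 1e3) / 1e12


def gbps(bytes_moved: float, ms: float) -> float:
    return bytes_moved / (ms / 1e3) / 1e9


def gemm_flops(M: int, N: int, K: int) -> int:
    return 2 * M * N * K  # one multiply and one add per multiply-accumulate


def gemm_bytes(M: int, N: int, K: int, elem_bytes: int = 2) -> int:
    return elem_bytes * (M * K + K * N + M * N)  # read A and B, write C, once each
```

For example, a 4096 x 4096 x 4096 GEMM is about 1.37e11 FLOPs, so a 1 ms median latency corresponds to roughly 137 TFLOPS.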
```bash
python -m triton_llm_kernel_lab.bench --kernel all
python -m triton_llm_kernel_lab.bench --kernel attention --warmup 50 --iters 200 --csv results/attention.csv
```

Representative shape groups are defined in `configs.py`:
- prefill: longer query/key lengths where QK and PV dominate arithmetic
- decode: short query length with long KV cache where memory traffic dominates
- GEMM: common projection and MLP matrix sizes
- softmax: row lengths that stress SRAM fit and reduction behavior
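The prefill/decode split follows from arithmetic intensity (FLOPs per byte of DRAM traffic). A back-of-the-envelope sketch, assuming a non-causal attention forward, FP16 tensors, no grouped-query attention (K/V have as many heads as Q), and idealized traffic where Q, K, and V are read once and O is written once; the function names and shapes below are illustrative, not the values in `configs.py`. Under these assumptions the batch, head, and head-dim factors cancel and the intensity reduces to `Sq * Sk / (Sq + Sk)` FLOP/byte.

```python
def attention_fwd_flops(B: int, H: int, Sq: int, Sk: int, D: int) -> int:
    # QK^T and PV are each Sq x Sk x D multiply-accumulates per head, 2 FLOPs each.
    return 4 * B * H * Sq * Sk * D


def attention_fwd_bytes(B: int, H: int, Sq: int, Sk: int, D: int,
                        elem_bytes: int = 2) -> int:
    # Read Q (Sq rows), K and V (Sk rows each), write O (Sq rows); no GQA assumed.
    return elem_bytes * B * H * D * (2 * Sq + 2 * Sk)


def attention_intensity(B: int, H: int, Sq: int, Sk: int, D: int) -> float:
    return attention_fwd_flops(B, H, Sq, Sk, D) / attention_fwd_bytes(B, H, Sq, Sk, D)


if __name__ == "__main__":
    prefill = attention_intensity(B=1, H=32, Sq=4096, Sk=4096, D=128)
    decode = attention_intensity(B=1, H=32, Sq=1, Sk=4096, D=128)
    print(f"prefill: {prefill:.1f} FLOP/byte, decode: {decode:.4f} FLOP/byte")
```

With these shapes, prefill lands at 2048 FLOP/byte while single-token decode stays just under 1 FLOP/byte, far below what FP16 tensor cores need to be compute-bound, so a decode kernel's speed is set mostly by how fast it streams the KV cache.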
Use Nsight Compute for detailed GPU metrics:
```bash
bash scripts/profile_ncu.sh attention
```

The profiling notes in `docs/profiling.md` track the metrics that matter for this lab: achieved occupancy, memory throughput, L2 hit rate, warp stalls, tensor core utilization, and DRAM read/write transactions.
The implementation style was informed by the public Triton tutorials and FlashAttention papers, but the code in this repository is written as a compact teaching/lab version rather than a copy of those tutorials.
- Triton fused softmax tutorial: https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html
- Triton matrix multiplication tutorial: https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py
- Triton fused attention tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
- FlashAttention: https://arxiv.org/abs/2205.14135
- FlashAttention-2: https://tridao.me/publications/flash2/flash2.pdf