Paper artifact and reference implementation for:
TiledAttention: a CUDA Tile SDPA Kernel for PyTorch
Author: Taimur Khan (taimur.khan@ufz.de)
TiledAttention is a scaled dot-product attention (SDPA) forward operator for SDPA research on NVIDIA GPUs. Implemented in cuTile Python (TileIR) and exposed as a PyTorch-callable function, it is easier to modify than low-level CUDA templates while retaining realistic behavior via online softmax and tiled K/V streaming.
```mermaid
flowchart LR
    A["Load Q tile"] --> B["Stream K,V tiles + score/mask"]
    B --> C["Update online softmax state (m, l, o)"]
    C --> D["Normalize and store output tile"]
```
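The streaming scheme in the flowchart can be sketched in plain NumPy. This is an illustrative reference of the online-softmax algorithm, not the repo's cuTile kernel; the function name and tiling parameter here are hypothetical:

```python
import numpy as np

def online_sdpa(q, k, v, tile_n=4, causal=False):
    """Tiled SDPA forward with an online-softmax running state (m, l, o).

    K/V are consumed in tiles of tile_n rows, so the full S x S score
    matrix is never materialized -- only one (S_q x tile_n) score tile
    exists at a time.
    """
    s_q, d = q.shape
    s_k = k.shape[0]
    scale = 1.0 / np.sqrt(d)
    m = np.full(s_q, -np.inf)      # running row-wise max of scores
    l = np.zeros(s_q)              # running softmax denominator
    o = np.zeros((s_q, d))         # unnormalized output accumulator
    for start in range(0, s_k, tile_n):
        kt = k[start:start + tile_n]
        vt = v[start:start + tile_n]
        s = (q @ kt.T) * scale     # score tile, shape (s_q, tile)
        if causal:
            cols = np.arange(start, start + kt.shape[0])[None, :]
            s = np.where(cols > np.arange(s_q)[:, None], -np.inf, s)
        m_new = np.maximum(m, s.max(axis=1))
        alpha = np.exp(m - m_new)  # rescale old state to the new max
        p = np.exp(s - m_new[:, None])
        l = l * alpha + p.sum(axis=1)
        o = o * alpha[:, None] + p @ vt
        m = m_new
    return o / l[:, None]          # final normalization
```

Because only `(m, l, o)` are carried between tiles, the result is bit-for-bit independent of the tile size up to floating-point rounding.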
- SDPA API: `sdpa(q, k, v, causal=False, scale=None) -> o`
- cuTile forward kernels:
  - tiled streaming over K/V
  - online softmax (`m, l, o` running state)
  - no `S x S` score materialization
- compile cache: in-memory and disk-backed helpers in `src/tiledattention/kernels/compile_cache.py`
- reproducible study + profiling scripts: `benchmark-gb10/run_study.py`, `benchmark-gb10/run_ncu_profile.py`
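The compile cache can be sketched generically as a two-level cache keyed by kernel configuration: an in-memory dict backed by on-disk files. This is an illustrative sketch with hypothetical names, not the actual code in `src/tiledattention/kernels/compile_cache.py`:

```python
import hashlib
import json
from pathlib import Path

class CompileCache:
    """Hypothetical two-level compile cache: in-memory dict + disk files.

    Keys are kernel configurations (tile sizes, dtype, ...); values are
    whatever compiled artifact the backend produces (here: plain strings).
    """

    def __init__(self, cache_dir):
        self._mem = {}
        self._dir = Path(cache_dir)
        self._dir.mkdir(parents=True, exist_ok=True)

    def _path(self, config):
        digest = hashlib.sha256(
            json.dumps(config, sort_keys=True).encode()
        ).hexdigest()
        return self._dir / f"{digest}.json"

    def get(self, config):
        key = json.dumps(config, sort_keys=True)
        if key in self._mem:              # fast path: in-memory hit
            return self._mem[key]
        path = self._path(config)
        if path.exists():                 # slow path: disk hit
            artifact = json.loads(path.read_text())
            self._mem[key] = artifact     # promote to memory
            return artifact
        return None                       # miss: caller compiles and puts

    def put(self, config, artifact):
        key = json.dumps(config, sort_keys=True)
        self._mem[key] = artifact
        self._path(config).write_text(json.dumps(artifact))
```

The disk level is what lets a fresh process skip recompilation across runs; the in-memory level avoids repeated deserialization within one run.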
- NVIDIA GPU with a Blackwell-class compute capability supported by this repo's runtime checks (`10.x` or `12.x`)
- CUDA 13.1+ (required by the `cuda-tile` library)
- Linux
- Python `>=3.10`
- NVIDIA driver compatible with CUDA 13.x
- CUDA Toolkit `13.1+` with:
  - `nvcc`
  - `tileiras`
  - `ncu` (for Nsight Compute runs)
  - `nsys` (optional, for Nsight Systems traces)
- Python packages:
  - `torch==2.10.0+cu130`
  - `torchvision==0.25.0+cu130`
  - `cupy-cuda13x==13.6.0`
  - `cuda-tile==1.1.0`
  - `matplotlib` (figure generation)
```bash
cd TiledAttention
python3 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip setuptools wheel

# Install from requirements.txt (single source of dependency truth)
pip install -r requirements.txt

# Install project in editable mode
pip install -e . --no-build-isolation
```

Run this before benchmarking:
```bash
python - <<'PY'
import shutil
import torch, cupy, cuda.tile as ct

print("torch:", torch.__version__, "torch.cuda:", torch.version.cuda)
print("cupy:", cupy.__version__)
print("cuda.tile:", ct.__version__)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
    print("capability:", torch.cuda.get_device_capability(0))
for cmd in ("nvcc", "tileiras"):
    print(cmd, "->", shutil.which(cmd))
PY
```

Quick usage check:

```python
import torch
from tiledattention import sdpa

q = torch.randn(1, 2, 64, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 2, 64, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 2, 64, 64, device="cuda", dtype=torch.float16)

o = sdpa(q, k, v, causal=True)
print(o.shape, o.dtype)
```

```bash
# Optional cleanup
rm -rf benchmark-gb10/results benchmark-gb10/figures
mkdir -p benchmark-gb10/results benchmark-gb10/figures

# Recommended paper run configuration
export TILEDATTN_SYNC_MODE=async
export TILEDATTN_TILE_M=64
export TILEDATTN_TILE_N=64
export TILEDATTN_ACCUM_MODE=fp32

python benchmark-gb10/run_study.py --warmup 5 --iters 15 --batch 1 --heads 8 --disable-flashattention
```

Outputs:
- `benchmark-gb10/results/benchmark_results.csv`
- `benchmark-gb10/results/tuning_results.csv`
- `benchmark-gb10/results/table3_reproducibility.md`
- `benchmark-gb10/results/table4_tiling_sensitivity.md`
- `benchmark-gb10/results/study_summary.md`
- `benchmark-gb10/figures/figure3_throughput_vs_s.png`
- `benchmark-gb10/figures/figure4_regime_map.png`
- `benchmark-gb10/figures/figure5_bw_proxy.png`
- `benchmark-gb10/figures/figure_fa_style_tflops_fp16.png`
- `benchmark-gb10/figures/figure6_explicit_baselines_tflops_fp16.png`
Recommended focused profiles:
```bash
# Non-causal, D=128
sudo -E TiledAttention/.venv/bin/python \
  benchmark-gb10/run_ncu_profile.py \
  --ncu-path /usr/local/cuda/bin/ncu \
  --output-dir benchmark-gb10/results \
  --batch 1 --heads 8 --seq-len 4096 --head-dim 128 \
  --dtype float16 --accum-mode fp32 --warmup 5 --repeats 1 \
  --sync-mode async --tile-m 64 --tile-n 64

# Causal, D=128
sudo -E TiledAttention/.venv/bin/python \
  benchmark-gb10/run_ncu_profile.py \
  --ncu-path /usr/local/cuda/bin/ncu \
  --output-dir benchmark-gb10/results \
  --batch 1 --heads 8 --seq-len 4096 --head-dim 128 \
  --dtype float16 --accum-mode fp32 --causal --warmup 5 --repeats 1 \
  --sync-mode async --tile-m 64 --tile-n 64

# Non-causal stress point, D=64
sudo -E TiledAttention/.venv/bin/python \
  benchmark-gb10/run_ncu_profile.py \
  --ncu-path /usr/local/cuda/bin/ncu \
  --output-dir benchmark-gb10/results \
  --batch 1 --heads 8 --seq-len 2048 --head-dim 64 \
  --dtype float16 --accum-mode fp32 --warmup 5 --repeats 1 \
  --sync-mode async --tile-m 64 --tile-n 64
```

Nsight artifacts:
- `.ncu-rep` reports
- raw CSV exports (`*_raw.csv`)
- summary markdown (`ncu_profile_summary_*.md`)
Environment variables:

- `TILEDATTN_SYNC_MODE=async | post | strict`
- `TILEDATTN_ACCUM_MODE=auto | fp32 | fp16`
- `TILEDATTN_TILE_M`, `TILEDATTN_TILE_N`
- `TILEDATTN_KERNEL_OPT_LEVEL`
- `TILEDATTN_KERNEL_OCCUPANCY`
- `TILEDATTN_KERNEL_NUM_CTAS`
- `TILEDATTN_CHUNKED_HEAD_DIMS` (experimental)
- `TILEDATTN_DISABLE_ALIGNED_FASTPATH` (debug/ablation)
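A minimal sketch of how the most common knobs could be read, with defaults assumed from the paper run configuration above. The helper name and defaults are hypothetical; the actual parsing lives inside the `tiledattention` package:

```python
import os

def read_tiledattn_env(environ=os.environ):
    """Hypothetical helper illustrating the TILEDATTN_* configuration surface."""
    def pick(name, default, choices):
        val = environ.get(name, default)
        if val not in choices:
            raise ValueError(f"{name}={val!r} not in {choices}")
        return val

    return {
        # enumerated modes, validated against their documented choices
        "sync_mode": pick("TILEDATTN_SYNC_MODE", "async",
                          ("async", "post", "strict")),
        "accum_mode": pick("TILEDATTN_ACCUM_MODE", "auto",
                           ("auto", "fp32", "fp16")),
        # integer tile sizes; 64x64 matches the recommended paper run
        "tile_m": int(environ.get("TILEDATTN_TILE_M", "64")),
        "tile_n": int(environ.get("TILEDATTN_TILE_N", "64")),
    }
```

Since the variables are read from the process environment, they must be exported before the Python process (or at least the first kernel compile) starts, as the benchmark scripts above do.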
- If Nsight is run with `sudo`, output files may be owned by `root`. Fix ownership with `sudo chown -R $USER:$USER benchmark-gb10/results`.
- PyTorch `cu130` with host CUDA `13.1` can emit a minor-version mismatch warning; this is expected in this setup.
- This repository currently targets forward-pass SDPA only.
Apache-2.0. See LICENSE.
