Releases: patrick-toulme/pyptx
pyptx v0.1.1
First PyPI release. pip install pyptx==0.1.1.
What's in this release
- First-class Ampere support (sm_80): mma.sync and cp.async ISA wrappers, ldmatrix, and a full GEMM + RMSNorm / LayerNorm / Softmax / SwiGLU example suite tuned for A100.
- Cross-arch portability: pyptx.detect_arch() and @kernel(arch="auto") resolve at decorator time. Datacenter Hopper / Blackwell get the "a" suffix automatically; workstation Blackwell (sm_120) returns plain sm_120 with PTX 8.7.
- Multi-card bring-up validated on real hardware (see below).
- Manylinux x86_64 + aarch64 wheels for Python 3.10–3.13. Pre-built shim included; no compile step at install time.
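The suffix rule described in the portability bullet can be sketched in a few lines. This is an illustration only, not pyptx's implementation: the function name `ptx_target` and its `(major, minor)` signature are hypothetical, and the mapping covers just the cases the notes mention.

```python
def ptx_target(major: int, minor: int) -> str:
    """Map a CUDA compute capability to a PTX target string.

    Hypothetical helper for illustration; it is NOT pyptx.detect_arch().
    """
    base = f"sm_{major}{minor}"
    # Per the release notes: datacenter Hopper (9.0) and datacenter
    # Blackwell (10.0) take the arch-specific "a" suffix.
    if (major, minor) in {(9, 0), (10, 0)}:
        return base + "a"
    # Everything else, including workstation Blackwell (12.0), stays plain.
    return base


if __name__ == "__main__":
    for cc in [(7, 5), (8, 0), (8, 9), (9, 0), (10, 0), (12, 0)]:
        print(cc, "->", ptx_target(*cc))
```

The resulting strings match the Arch column of the validated-hardware table for the cards listed there.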
Validated hardware
Every kernel below was run end-to-end (JAX + Torch backends, numerics checked against fp32 reference)
on the listed card before this release:
| Arch | Card | Status |
|---|---|---|
| Turing (sm_75) | T4 | Framework works for user-written sm_75 kernels. Bundled examples/ampere/* require sm_80 features (cp.async, bf16) and fail cleanly with CUDA_ERROR_INVALID_PTX. |
| Ampere (sm_80) | A100-SXM4-80GB | All 7 Ampere examples ✅ |
| Ada (sm_89) | L4 | All 7 Ampere examples ✅ via sm_80 forward-compat |
| Hopper (sm_90a) | H100 80GB HBM3 | All 6 Hopper examples ✅ |
| Datacenter Blackwell (sm_100a) | B200 | All headline Blackwell examples ✅ (rms_norm, layer_norm, |
| Workstation Blackwell (sm_120) | RTX Pro 6000 Blackwell | All 7 Ampere examples ✅ via sm_80 forward-compat |
Performance highlights
Validated on B200 (driver 595.45.04, CUDA 13, jax 0.10):
gemm_highperf_blackwell (1SM, bf16): 1329 TFLOPS @ 8192³ (83% of cuBLAS 13), in a hand-written PTX kernel.
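For readers who want to sanity-check the headline number, here is a small worked calculation. It assumes the conventional 2·M·N·K FLOP count for a GEMM (one multiply and one add per inner-product term); the notes do not state how the figure was computed, so treat that convention as an assumption.

```python
def gemm_flops(m: int, n: int, k: int) -> int:
    """FLOPs for C = A @ B under the usual 2*M*N*K convention (assumed)."""
    return 2 * m * n * k


N = 8192
REPORTED_TFLOPS = 1329.0      # figure quoted in the release notes
FRACTION_OF_CUBLAS = 0.83     # figure quoted in the release notes

flops = gemm_flops(N, N, N)                        # exactly 2**40
kernel_ms = flops / (REPORTED_TFLOPS * 1e12) * 1e3
implied_cublas_tflops = REPORTED_TFLOPS / FRACTION_OF_CUBLAS

if __name__ == "__main__":
    print(f"FLOPs per GEMM:        {flops:,}")
    print(f"Implied kernel time:   {kernel_ms:.3f} ms")
    print(f"Implied cuBLAS TFLOPS: {implied_cublas_tflops:.0f}")
```

Under that convention a single 8192³ GEMM at the reported rate takes a little under a millisecond, and the 83% figure implies a cuBLAS baseline of roughly 1600 TFLOPS.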
Known issues
- Turing reference kernels not yet shipped. The framework runs on T4; the bundled examples don't. Bring your own sm_75 kernel for now.
- examples/blackwell/grouped_gemm.py MoE-scale (G=4, M=2048, N=256, K=2048) JAX path intermittently fails when run as part of the script's full validation loop. The same case passes when called in isolation. Pre-existing cross-call state issue in the FFI path; fix tracked for v0.1.2. Workaround: call the kernel once per program invocation.
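One way to apply the "once per program invocation" workaround is to run each validation case in a fresh interpreter, so no state carries over between calls. The sketch below uses only the standard library; `run_isolated` is a hypothetical helper, and the snippets passed to it are placeholders standing in for code that would import and call the grouped GEMM kernel for a single case.

```python
import subprocess
import sys


def run_isolated(snippet: str, timeout: float = 600.0) -> subprocess.CompletedProcess:
    """Run a Python snippet in a brand-new interpreter process.

    Each call gets its own process, so nothing from an earlier case
    (module globals, FFI registrations, device state held by the parent)
    is shared with the next one.
    """
    return subprocess.run(
        [sys.executable, "-c", snippet],
        capture_output=True,
        text=True,
        timeout=timeout,
    )


if __name__ == "__main__":
    # Placeholder snippets; a real case would import the example module
    # and invoke the kernel for exactly one (G, M, N, K) configuration.
    cases = [f"print('case {i} ok')" for i in range(3)]
    for snippet in cases:
        result = run_isolated(snippet)
        status = "PASS" if result.returncode == 0 else "FAIL"
        print(status, result.stdout.strip())
```

Process start-up and device initialization are paid once per case, which is acceptable for a validation loop but not for a hot path.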
Install
pip install pyptx==0.1.1 # CPU + GPU (manylinux x86_64 / aarch64)
pip install "pyptx[jax]==0.1.1" # adds jax[cuda12]
pip install "pyptx[torch]==0.1.1" # adds torch + cuda-python
Links
- Docs: https://pyptx.dev (or docs/ in the repo)
- Examples: https://github.com/patrick-toulme/pyptx/tree/v0.1.1/examples
- Comparison vs Pallas / Triton / CuTe DSL: docs/comparison.md
Initial Release of PyPTX
Today I'm open-sourcing a project I've been building on personal time: pyptx, a Python DSL where the function body is the PTX instruction stream. One PTX instruction = one Python call. No optimizer, no autotuner, no tile IR between you and the hardware.
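To make the "one PTX instruction = one Python call" idea concrete, here is a toy recorder in plain Python. It is deliberately not pyptx's API: the class `PtxRecorder` and its `emit` method are invented for illustration, and only the PTX mnemonics themselves (`mov.u32`, `add.s32`) are real.

```python
class PtxRecorder:
    """Toy instruction recorder: each emit() call appends exactly one PTX line."""

    def __init__(self) -> None:
        self.lines: list[str] = []

    def emit(self, opcode: str, *operands: str) -> None:
        # One call in, one instruction out; no reordering or optimization.
        self.lines.append(f"{opcode} {', '.join(operands)};")

    def text(self) -> str:
        return "\n".join(self.lines)


def build_demo() -> PtxRecorder:
    ptx = PtxRecorder()
    ptx.emit("mov.u32", "%r1", "%tid.x")      # r1 = thread index in x
    ptx.emit("add.s32", "%r2", "%r1", "%r1")  # r2 = r1 + r1
    return ptx


demo = build_demo()

if __name__ == "__main__":
    print(demo.text())
```

The Python statements and the emitted instruction lines correspond one to one and in the same order, which is the property the paragraph above describes.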
Why? Because the newest GPU features (Hopper's wgmma, TMA multicast, mbarrier-based pipelines, Blackwell's tcgen05.mma + TMEM + cooperative 2-SM MMA) often only exist at the PTX level. For developers chasing peak performance, that has historically meant writing inline PTX inside CUDA C++.
pyptx brings that whole path into Python. Callable from JAX (via typed XLA FFI) and PyTorch (eager, torch.compile, and a C++ extension fast path).
A few numbers from real silicon:
• H100 bf16 GEMM: 815 TFLOPS, competitive with cuBLAS at matrix sizes ≥ 6K
• B200 bf16 GEMM: 1240 TFLOPS on the 1SM kernel
• RMSNorm: 2.6 TB/s (88% of HBM3 peak, 3.9× PyTorch eager)
• SwiGLU: 2.8 TB/s (94% of HBM3 peak)
The other half of the project is a transpiler in the opposite direction:
python -m pyptx.codegen kernel.ptx --sugar
takes PTX from anywhere — nvcc, Triton, CUTLASS output, DeepGEMM kernels — and emits editable pyptx Python. The parser/emitter round-trips byte-identical on 218+ real-world kernels. So you can read someone else's kernel as Python, modify it, and ship the result.
Built end-to-end: parser, IR, emitter, transpiler, JAX integration, PyTorch integration, full Hopper + Blackwell ISA coverage, multi-arch wheels published to PyPI. ~17K lines of Python total.
Ships with maintained GEMM, grouped GEMM, RMSNorm, LayerNorm, and SwiGLU kernels for both Hopper and Blackwell, plus the PTX → Python transpiler.
pip install "pyptx[torch]" # for PyTorch
pip install "pyptx[jax]" # for JAX
pip install "pyptx[all]" # both