Releases: patrick-toulme/pyptx

pyptx v0.1.1

03 May 17:12

First PyPI release. pip install pyptx==0.1.1.

What's in this release

  • First-class Ampere support (sm_80) — mma.sync and cp.async ISA wrappers, ldmatrix, full
    GEMM + RMSNorm / LayerNorm / Softmax / SwiGLU example suite tuned for A100.
  • Cross-arch portability: pyptx.detect_arch() and @kernel(arch="auto") resolve the target
    architecture at decorator time. Datacenter Hopper / Blackwell get the "a" suffix automatically
    (sm_90a, sm_100a); workstation Blackwell (sm_120) returns plain sm_120 with PTX 8.7.
  • Multi-card bring-up validated on real hardware (see below).
  • Manylinux x86_64 + aarch64 wheels for Python 3.10–3.13. Pre-built shim included; no compile step at
    install time.
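
The suffix rule in the portability bullet can be sketched in a few lines. This is an illustrative stand-in, not pyptx's implementation: resolve_target is a hypothetical helper, and the set of compute capabilities that receive the "a" suffix is an assumption drawn only from the cards named in this release (sm_90 and sm_100).

```python
# Hypothetical sketch of the arch-suffix rule; not pyptx's actual code.
# Assumption: only datacenter Hopper (9.0) and Blackwell (10.0) take the "a" suffix.
ARCH_SPECIFIC = {(9, 0), (10, 0)}


def resolve_target(major: int, minor: int) -> str:
    """Map a CUDA compute capability to a PTX target string."""
    base = f"sm_{major}{minor}"
    return base + "a" if (major, minor) in ARCH_SPECIFIC else base


for cc in [(7, 5), (8, 0), (9, 0), (10, 0), (12, 0)]:
    print(cc, "->", resolve_target(*cc))
```

In practice the compute capability would come from the driver or from a framework query; the real pyptx.detect_arch() may use different logic.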

Validated hardware

Every kernel below was run end-to-end (JAX + Torch backends, numerics checked against fp32 reference)
on the listed card before this release:

Arch | Card | Status
Turing (sm_75) | T4 | Framework works for user-written sm_75 kernels. Bundled examples/ampere/* require sm_80 features (cp.async, bf16) and fail cleanly with CUDA_ERROR_INVALID_PTX.
Ampere (sm_80) | A100-SXM4-80GB | All 7 Ampere examples ✅
Ada (sm_89) | L4 | All 7 Ampere examples ✅ via sm_80 forward-compat
Hopper (sm_90a) | H100 80GB HBM3 | All 6 Hopper examples ✅
Datacenter Blackwell (sm_100a) | B200 | All headline Blackwell examples ✅ (rms_norm, layer_norm,
Workstation Blackwell (sm_120) | RTX Pro 6000 Blackwell | All 7 Ampere examples ✅ via sm_80 forward-compat
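
The forward-compat behaviour in the table follows from the general CUDA rule that PTX for a plain sm_XY target can be JIT-compiled for any device of equal or higher compute capability, while an arch-specific "a" target runs only on that exact architecture. A minimal sketch of that rule follows; ptx_target_runs_on is a hypothetical helper written for illustration and is not part of pyptx.

```python
import re


def ptx_target_runs_on(target: str, device_cc: tuple[int, int]) -> bool:
    """Return True if PTX built for `target` can load on a device with `device_cc`."""
    m = re.fullmatch(r"sm_(\d+)(\d)(a?)", target)
    if m is None:
        raise ValueError(f"unrecognised target: {target!r}")
    target_cc = (int(m.group(1)), int(m.group(2)))
    if m.group(3):
        # Arch-specific targets ("a" suffix) are not forward compatible.
        return device_cc == target_cc
    # Plain targets run on the same or any newer compute capability.
    return device_cc >= target_cc


print(ptx_target_runs_on("sm_80", (8, 9)))   # Ada L4 running the Ampere examples
print(ptx_target_runs_on("sm_80", (7, 5)))   # Turing T4 cannot load sm_80 PTX
print(ptx_target_runs_on("sm_90a", (10, 0))) # Hopper-specific PTX on Blackwell
```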

Performance highlights

Validated on B200 (driver 595.45.04, CUDA 13, jax 0.10):

  • gemm_highperf_blackwell (1SM, bf16): 1329 TFLOPS @ 8192³ — 83% of cuBLAS 13 in a hand-written
    PTX kernel.
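
To put the headline number in context, the arithmetic below converts it into a per-call runtime and an implied cuBLAS figure. It assumes the usual 2·M·N·K FLOP count for a GEMM, and the "83%" in the release is rounded, so the derived cuBLAS throughput is approximate.

```python
def gemm_flops(m: int, n: int, k: int) -> int:
    """FLOP count of one GEMM, counting each multiply-add as 2 operations."""
    return 2 * m * n * k


n = 8192
tflops = 1329.0                       # figure quoted in this release
flop = gemm_flops(n, n, n)
kernel_ms = flop / (tflops * 1e12) * 1e3
implied_cublas_tflops = tflops / 0.83  # from the rounded "83% of cuBLAS"

print(f"{flop:.3e} FLOP per GEMM")
print(f"{kernel_ms:.3f} ms per call at {tflops:.0f} TFLOPS")
print(f"implied cuBLAS throughput: about {implied_cublas_tflops:.0f} TFLOPS")
```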

Known issues

  • Turing reference kernels not yet shipped. The framework runs on T4; the bundled examples don't.
    Bring your own sm_75 kernel for now.
  • examples/blackwell/grouped_gemm.py MoE-scale (G=4, M=2048, N=256, K=2048) JAX path
    intermittently fails when run as part of the script's full validation loop. The same case passes when
    called in isolation. Pre-existing cross-call state issue in the FFI path — fix tracked for v0.1.2.
    Workaround: call the kernel once per program invocation.
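
One way to apply the workaround is to run each validation case in a fresh interpreter, so no state carries over between kernel calls. The sketch below shows that generic pattern using only the standard library; the case snippets are placeholders, and nothing here reflects the actual command-line interface of grouped_gemm.py.

```python
import subprocess
import sys


def run_isolated(snippet: str, timeout: float = 600.0) -> int:
    """Run a Python snippet in a fresh interpreter and return its exit code."""
    result = subprocess.run([sys.executable, "-c", snippet], timeout=timeout)
    return result.returncode


# Placeholder cases; replace each snippet with one real kernel invocation.
cases = {
    "case_a": "print('kernel call for case A goes here')",
    "case_b": "print('kernel call for case B goes here')",
}

results = {name: run_isolated(code) for name, code in cases.items()}
print(results)
```

Process startup adds overhead per case, which is usually acceptable for a validation loop but not for benchmarking.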

Install

pip install pyptx==0.1.1               # CPU + GPU (manylinux x86_64 / aarch64)
pip install "pyptx[jax]==0.1.1"        # adds jax[cuda12]
pip install "pyptx[torch]==0.1.1"      # adds torch + cuda-python
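
After installing, a quick check of the installed distribution version can confirm that the pin took effect. The sketch below uses only the standard library and does not import pyptx itself; installed_version is a helper name chosen for this example.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name: str):
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# After the pinned install above this should report 0.1.1.
print("pyptx:", installed_version("pyptx"))
```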
                                                                                                      
Links

- Docs: https://pyptx.dev (or docs/ in the repo)
- Examples: https://github.com/patrick-toulme/pyptx/tree/v0.1.1/examples
- Comparison vs Pallas / Triton / CuTe DSL: docs/comparison.md

Initial Release of PyPTX

25 Apr 19:25

Today I'm open-sourcing a project I've been building on personal time: pyptx, a Python DSL where the function body is the PTX instruction stream. One PTX instruction = one Python call. No optimizer, no autotuner, no tile IR between you and the hardware.
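
To make "one PTX instruction = one Python call" concrete, here is a toy recorder that appends one line of PTX text per call. It is purely illustrative: ToyBuilder, emit, and double_in_place are invented for this sketch and are not the pyptx API, and the register names are arbitrary.

```python
# Toy illustration only: this is NOT the pyptx API.
class ToyBuilder:
    """Records one PTX text line per emit() call."""

    def __init__(self) -> None:
        self.lines: list[str] = []

    def emit(self, opcode: str, *operands: str) -> None:
        self.lines.append(f"{opcode} {', '.join(operands)};")


def double_in_place(b: ToyBuilder) -> None:
    # Each call below corresponds to exactly one PTX instruction.
    b.emit("ld.global.f32", "%f1", "[%rd1]")
    b.emit("add.f32", "%f2", "%f1", "%f1")
    b.emit("st.global.f32", "[%rd1]", "%f2")


builder = ToyBuilder()
double_in_place(builder)
print("\n".join(builder.lines))
```

Because nothing sits between the calls and the emitted text, the instruction order in the output is exactly the call order in the function body.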

Why? Because the newest GPU features (Hopper's wgmma, TMA multicast, mbarrier-based pipelines, Blackwell's tcgen05.mma + TMEM + cooperative 2-SM MMA) often only exist at the PTX level. For developers chasing peak performance, that has historically meant writing inline PTX inside CUDA C++.

pyptx brings that whole path into Python. Kernels are callable from JAX (via typed XLA FFI) and PyTorch (eager, torch.compile, and a C++ extension fast path).

A few numbers from real silicon:

• H100 bf16 GEMM: 815 TFLOPS, competitive with cuBLAS at matrix sizes ≥ 6K
• B200 bf16 GEMM: 1240 TFLOPS on the 1SM kernel
• RMSNorm: 2.6 TB/s (88% of HBM3 peak, 3.9× PyTorch eager)
• SwiGLU: 2.8 TB/s (94% of HBM3 peak)
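
For memory-bound kernels such as RMSNorm, a bandwidth figure is commonly derived as bytes moved divided by runtime. The post does not say how its numbers were measured, so the sketch below is one plausible traffic model (each element read once and written once, weights ignored) with a hypothetical shape and timing; only the 2.6 TB/s and 88% values come from the list above.

```python
def effective_bandwidth_tbps(rows: int, cols: int, bytes_per_elem: int,
                             seconds: float) -> float:
    """Effective bandwidth in TB/s, assuming one read and one write per element."""
    bytes_moved = 2 * rows * cols * bytes_per_elem
    return bytes_moved / seconds / 1e12


# Hypothetical measurement: a 1024 x 1024 bf16 tensor processed in 1 microsecond.
example = effective_bandwidth_tbps(1024, 1024, 2, 1e-6)
print(f"example kernel: {example:.3f} TB/s")

# Peak bandwidth implied by the quoted RMSNorm figures (2.6 TB/s at 88% of peak).
implied_peak = 2.6 / 0.88
print(f"implied HBM3 peak: about {implied_peak:.2f} TB/s")
```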

The other half of the project is a transpiler in the opposite direction:

python -m pyptx.codegen kernel.ptx --sugar

takes PTX from anywhere — nvcc, Triton, CUTLASS output, DeepGEMM kernels — and emits editable pyptx Python. The parser/emitter round-trips byte-identical on 218+ real-world kernels. So you can read someone else's kernel as Python, modify it, and ship the result.
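
A byte-identical round-trip means that parsing a PTX file and emitting it again reproduces the input exactly. The check can be sketched as below. The parse and emit functions here are toy stand-ins (split into lines, join again), since the pyptx parser and emitter API is not shown in this post, and the sample string is a fragment rather than a complete kernel.

```python
from typing import Callable, List


def round_trips(text: str,
                parse: Callable[[str], List[str]],
                emit: Callable[[List[str]], str]) -> bool:
    """Return True if emit(parse(text)) reproduces text byte for byte."""
    return emit(parse(text)).encode("utf-8") == text.encode("utf-8")


def toy_parse(text: str) -> List[str]:
    return text.splitlines(keepends=True)


def toy_emit(lines: List[str]) -> str:
    return "".join(lines)


def lossy_emit(lines: List[str]) -> str:
    # Drops indentation, so the output no longer matches the input.
    return "".join(line.lstrip() for line in lines)


sample = ".version 8.7\n.target sm_90a\n.address_size 64\n    ret;\n"
print(round_trips(sample, toy_parse, toy_emit))
print(round_trips(sample, toy_parse, lossy_emit))
```

Comparing bytes rather than parsed structures is the stricter test: it also catches changes in whitespace, comments, and operand formatting.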

Built end-to-end: parser, IR, emitter, transpiler, JAX integration, PyTorch integration, full Hopper + Blackwell ISA coverage, multi-arch wheels published to PyPI. ~17K lines of Python total.

Ships with maintained GEMM, grouped GEMM, RMSNorm, LayerNorm, and SwiGLU kernels for both Hopper and Blackwell, plus the PTX → Python transpiler.

pip install "pyptx[torch]"   # for PyTorch
pip install "pyptx[jax]"     # for JAX
pip install "pyptx[all]"     # both