
pyptx v0.1.1

patrick-toulme released this 03 May 17:12

First PyPI release. pip install pyptx==0.1.1.

What's in this release

  • First-class Ampere support (sm_80): mma.sync and cp.async ISA wrappers, ldmatrix, and a full
    GEMM + RMSNorm / LayerNorm / Softmax / SwiGLU example suite tuned for A100.
  • Cross-arch portability: pyptx.detect_arch() and @kernel(arch="auto") resolve the target architecture
    at decorator time. Datacenter Hopper / Blackwell get the "a" suffix automatically (sm_90a, sm_100a);
    workstation Blackwell (sm_120) resolves to plain sm_120 with PTX 8.7.
  • Multi-card bring-up validated on real hardware (see below).
  • Manylinux x86_64 + aarch64 wheels for Python 3.10–3.13. Pre-built shim included; no compile step at
    install time.
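The target-resolution rule described above can be sketched as a small lookup. This is a hypothetical illustration, not pyptx's code: resolve_target() is an invented name, and it assumes the input is a CUDA compute capability (major, minor) tuple such as the one torch.cuda.get_device_capability() returns. It also assumes only compute capabilities 9.0 and 10.0 take the "a" suffix, matching the sm_90a / sm_100a targets listed in this release.

```python
# Hypothetical sketch of the arch-resolution rule; NOT pyptx's implementation.
# Input is a CUDA compute capability (major, minor), e.g. as returned by
# torch.cuda.get_device_capability() on a real machine.

# Assumption: only datacenter Hopper (9.0) and datacenter Blackwell (10.0)
# use the architecture-specific "a" suffix.
_ARCH_SPECIFIC = {(9, 0), (10, 0)}


def resolve_target(major: int, minor: int) -> str:
    """Map a compute capability to a PTX target string such as 'sm_90a'."""
    base = f"sm_{major}{minor}"
    if (major, minor) in _ARCH_SPECIFIC:
        # The "a" targets unlock features such as wgmma (sm_90a) and
        # tcgen05 (sm_100a) that are not forward compatible.
        return base + "a"
    # Ampere, Ada, workstation Blackwell (sm_120), etc. keep the plain target.
    return base


for cc in [(8, 0), (8, 9), (9, 0), (10, 0), (12, 0)]:
    print(cc, "->", resolve_target(*cc))
```

Keeping the suffixed architectures in one set makes the rule easy to audit against the validated-hardware table.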

Validated hardware

Every kernel below was run end-to-end (JAX + Torch backends, numerics checked against fp32 reference)
on the listed card before this release:

| Arch | Card | Status |
|---|---|---|
| Turing (sm_75) | T4 | Framework works for user-written sm_75 kernels. Bundled examples/ampere/* require sm_80 features (cp.async, bf16) and fail cleanly with CUDA_ERROR_INVALID_PTX. |
| Ampere (sm_80) | A100-SXM4-80GB | All 7 Ampere examples ✅ |
| Ada (sm_89) | L4 | All 7 Ampere examples ✅ via sm_80 forward compatibility |
| Hopper (sm_90a) | H100 80GB HBM3 | All 6 Hopper examples ✅ |
| Datacenter Blackwell (sm_100a) | B200 | All headline Blackwell examples ✅ (rms_norm, layer_norm, …) |
| Workstation Blackwell (sm_120) | RTX Pro 6000 Blackwell | All 7 Ampere examples ✅ via sm_80 forward compatibility |
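The "via sm_80 forward compatibility" rows rely on standard CUDA PTX compatibility: plain PTX built for sm_XY is JIT-compiled by the driver on any device of equal or newer compute capability, while architecture-specific targets such as sm_90a run only on that exact architecture. A minimal sketch of that rule; can_load() is a hypothetical helper, not a pyptx API.

```python
# Hypothetical sketch of CUDA's PTX compatibility rule (not a pyptx API).

def parse_target(ptx_target: str) -> tuple[tuple[int, int], bool]:
    """'sm_100a' -> ((10, 0), True); 'sm_80' -> ((8, 0), False)."""
    arch_specific = ptx_target.endswith("a")
    digits = ptx_target.removeprefix("sm_").removesuffix("a")
    # Last digit is the minor version, the rest is the major version.
    return (int(digits[:-1]), int(digits[-1])), arch_specific


def can_load(ptx_target: str, device_cc: tuple[int, int]) -> bool:
    """Can PTX built for ptx_target be JIT-compiled for device_cc?"""
    target_cc, arch_specific = parse_target(ptx_target)
    if arch_specific:
        # Arch-specific PTX (sm_90a, sm_100a) is not forward compatible.
        return device_cc == target_cc
    # Plain PTX runs on any device with an equal or newer compute capability.
    return device_cc >= target_cc


for device in [(7, 5), (8, 0), (8, 9), (12, 0)]:
    print("sm_80 PTX on", device, "->", can_load("sm_80", device))
```

Under this rule the table's results follow: sm_80 PTX loads on the L4 (8.9) and the RTX Pro 6000 Blackwell (12.0), but not on the T4 (7.5).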

Performance highlights

Validated on B200 (driver 595.45.04, CUDA 13, jax 0.10):

  • gemm_highperf_blackwell (1SM, bf16): 1329 TFLOPS on an 8192 × 8192 × 8192 problem, 83% of cuBLAS 13
    throughput, from a hand-written PTX kernel.
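To put the headline number in context: a GEMM of shape M × N × K performs 2·M·N·K floating-point operations (one multiply and one add per inner-product term), so the throughput figure implies a per-call runtime and an approximate cuBLAS baseline. The arithmetic below uses only the figures quoted above; the helper names are illustrative.

```python
# Back-of-envelope arithmetic for the B200 GEMM figure quoted above.

def gemm_flops(m: int, n: int, k: int) -> int:
    """Floating-point operations in an M x N x K GEMM (multiply + add)."""
    return 2 * m * n * k


def runtime_ms(flops: int, tflops: float) -> float:
    """Runtime in milliseconds at a sustained throughput given in TFLOPS."""
    return flops / (tflops * 1e12) * 1e3


flops = gemm_flops(8192, 8192, 8192)  # 2 * 8192**3 operations
t = runtime_ms(flops, 1329.0)         # per-call time at 1329 TFLOPS
cublas_tflops = 1329.0 / 0.83         # cuBLAS 13 throughput implied by "83%"

print(f"{flops} FLOPs, {t:.3f} ms per call, cuBLAS ~ {cublas_tflops:.0f} TFLOPS")
```

Because 83% is a rounded figure, the implied cuBLAS baseline is only approximate (roughly 1.59 to 1.61 PFLOPS).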

Known issues

  • Turing reference kernels not yet shipped. The framework runs on T4; the bundled examples don't.
    Bring your own sm_75 kernel for now.
  • examples/blackwell/grouped_gemm.py: the JAX path for the MoE-scale case (G=4, M=2048, N=256, K=2048)
    intermittently fails when run as part of the script's full validation loop, yet passes when called in
    isolation. This is a pre-existing cross-call state issue in the FFI path; a fix is tracked for v0.1.2.
    Workaround: call the kernel once per program invocation.
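One way to apply this workaround when sweeping several cases is to give each case its own interpreter, so no FFI state carries over between calls. A minimal sketch, assuming you supply one Python snippet per case. run_isolated() is a hypothetical helper, and the sketch assumes nothing about the entry points in grouped_gemm.py.

```python
import subprocess
import sys


def run_isolated(snippet: str, timeout: float = 600.0) -> str:
    """Run a Python snippet in a fresh interpreter and return its stdout.

    Each call starts a new process, so no FFI or CUDA context state
    survives between cases. Raises subprocess.CalledProcessError if the
    snippet exits with a non-zero status.
    """
    result = subprocess.run(
        [sys.executable, "-c", snippet],
        capture_output=True,
        text=True,
        timeout=timeout,
        check=True,
    )
    return result.stdout.strip()


# Placeholder snippet; a real one would launch the kernel once for a
# single (G, M, N, K) case.
print(run_isolated("print('case ok')"))
```

Process startup cost is the price paid for the isolation, which is acceptable for a validation sweep with a handful of cases.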

Install

pip install pyptx==0.1.1               # CPU + GPU (manylinux x86_64 / aarch64)
pip install "pyptx[jax]==0.1.1"        # adds jax[cuda12]
pip install "pyptx[torch]==0.1.1"      # adds torch + cuda-python
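After installing, one quick way to confirm which pieces are present is to query package metadata with the standard library. A sketch using only importlib; report_install() is a hypothetical helper, and it assumes the extras are detectable through the top-level jax and torch packages they pull in.

```python
import importlib.metadata
import importlib.util


def report_install() -> dict:
    """Return the installed pyptx version and which optional backends import."""
    try:
        pyptx_version = importlib.metadata.version("pyptx")
    except importlib.metadata.PackageNotFoundError:
        pyptx_version = None  # pyptx is not installed in this environment
    return {
        "pyptx": pyptx_version,
        # find_spec returns None when a top-level package cannot be found.
        "jax": importlib.util.find_spec("jax") is not None,
        "torch": importlib.util.find_spec("torch") is not None,
    }


print(report_install())
```

This only checks that the packages resolve; it does not touch the GPU, so it is safe to run on a CPU-only machine.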
                                                                                                      
Links                                                                                                 
                                                                                                      
- Docs: https://pyptx.dev (or docs/ in the repo)                                                      
- Examples: https://github.com/patrick-toulme/pyptx/tree/v0.1.1/examples
- Comparison vs Pallas / Triton / CuTe DSL: docs/comparison.md