v0.4.3 — correct silent NKI fallback across v0.4.x

@scttfrdmn scttfrdmn released this 13 Apr 20:57

Correction release. Every "trn1 NKI" perf number published in v0.4.0 / v0.4.1 / v0.4.2 was actually trn1's 8-vCPU Xeon running torch.matmul, not the Trainium Tensor Engine.

What went wrong

Our SSM runners launched the Neuron venv's python directly without prepending its bin/ to $PATH. torch_neuronx.initializer calls subprocess.run(["libneuronpjrt-path"]) to locate the PJRT plugin library — that binary lives in the venv's bin/ and was unresolvable. Every NKI dispatch raised FileNotFoundError, and our _nki_*_impl try/except wrappers swallowed the exception and fell back to torch.matmul.
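The failure mode above can be reproduced without any Neuron software. The following is a minimal sketch, not the trnblas or torch_neuronx code: `venv_env` and `dispatch` are hypothetical names, and the "kernel" is reduced to a helper lookup so that only the PATH and exception behaviour is shown. It assumes a POSIX host, where `subprocess.run` resolves a bare command name against the PATH of the environment it is given.

```python
import os
import subprocess


def venv_env(venv_dir, base_env=None):
    """Return a copy of the environment with <venv_dir>/bin first on PATH.

    Hypothetical helper: the Python equivalent of prepending
    $NEURON_VENV/bin to $PATH before launching the interpreter.
    """
    env = dict(os.environ if base_env is None else base_env)
    bin_dir = os.path.join(venv_dir, "bin")
    env["PATH"] = bin_dir + os.pathsep + env.get("PATH", "")
    return env


def dispatch(tool, env=None):
    """Illustrative stand-in for a wrapper with an overly broad except."""
    try:
        # Raises FileNotFoundError when `tool` is not on PATH.
        subprocess.run([tool], check=True, capture_output=True, env=env)
        return "nki"
    except Exception:
        # The broad except hides the real cause and silently changes
        # which backend does the work.
        return "torch.matmul fallback"


if __name__ == "__main__":
    # A helper name that does not exist on PATH stands in for the
    # unresolvable libneuronpjrt-path binary.
    print(dispatch("trnblas-demo-missing-helper"))
```

The caller gets a valid result either way, which is why nothing in the output hinted at the missing binary.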

Correctness tests kept passing because torch.matmul gives the same answer as nki_gemm; only perf attribution was wrong. The v0.4.2 cross-vendor comparison was A10G's Ampere GPU vs trn1's Xeon, not vs trn1's Tensor Engine.

What's fixed

  • scripts/run_neuron_tests.sh, scripts/run_df_mp2_bench.sh — prepend $NEURON_VENV/bin to $PATH + set TRNBLAS_REQUIRE_NKI=1 in the test runner.
  • trnblas.nki.NkiFallbackWarning — emitted once per distinct error when the fallback triggers. Makes future misconfigurations visible.
  • tests/test_nki_really_runs.py — anti-regression test that forces TRNBLAS_REQUIRE_NKI=1 and asserts a GEMM dispatch completes.
  • Re-measured trn1 numbers on docs/benchmarks.md with retraction banner.
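As a rough illustration of the warning and strict-mode behaviour listed above, the sketch below shows one way a fallback wrapper could warn once per distinct error and re-raise when `TRNBLAS_REQUIRE_NKI=1`. It is an assumption-laden sketch, not the trnblas implementation: the `NkiFallbackWarning` class here is a local stand-in for `trnblas.nki.NkiFallbackWarning`, `run_with_fallback` is a hypothetical name, and the exact semantics of the environment variable in trnblas may differ.

```python
import os
import warnings


class NkiFallbackWarning(RuntimeWarning):
    """Local stand-in; the real trnblas class may be defined differently."""


# Keys of errors already reported, so each distinct error warns only once.
_seen_errors = set()


def run_with_fallback(nki_impl, cpu_impl, *args):
    """Try the NKI path; fall back to the CPU path loudly, or not at all."""
    try:
        return nki_impl(*args)
    except Exception as exc:
        # Assumed strict mode: refuse to fall back when NKI is required.
        if os.environ.get("TRNBLAS_REQUIRE_NKI") == "1":
            raise
        key = (type(exc).__name__, str(exc))
        if key not in _seen_errors:
            _seen_errors.add(key)
            warnings.warn(
                f"NKI dispatch failed, falling back to CPU: {exc!r}",
                NkiFallbackWarning,
                stacklevel=2,
            )
        return cpu_impl(*args)
```

With a wrapper of this shape, an anti-regression test only needs to set the environment variable and call one GEMM: any misconfiguration then surfaces as the original exception instead of a slower but correct result.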

Side finding

trnblas.nki.nki_mp2_energy kernel tests had a partition-limit bug that was masked by the silent fallback (nl.load(eps_vir[0:NVIR]) exceeds 128 partitions for nvir > 128). Tests skipped pending kernel rewrite under #15. Not in the production DF-MP2 path (examples/df_mp2.py uses the torch reduction).
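To make the 128-partition limit concrete, here is a small sketch of splitting an index range into slices that each fit within the limit. It is plain Python and only illustrates the arithmetic; `partition_tiles` is a hypothetical helper, and the kernel rewrite tracked under #15 may take a different approach.

```python
def partition_tiles(n, max_partitions=128):
    """Split range(n) into contiguous (start, stop) slices.

    Each slice covers at most `max_partitions` elements, so a load of
    eps_vir[start:stop] would stay within the partition limit.
    """
    return [
        (start, min(start + max_partitions, n))
        for start in range(0, n, max_partitions)
    ]


if __name__ == "__main__":
    # nvir = 300 needs three slices instead of one oversized load.
    for start, stop in partition_tiles(300):
        print(start, stop)
```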

Re-measured (trn1.2xlarge, real NKI, warm cache)

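| Op     | Shape       | v0.4.x (was CPU) | v0.4.3 (real NKI) |
|--------|-------------|------------------|-------------------|
| GEMM   | 1024³       | 4.5 ms           | 2.3 ms            |
| SYRK   | 1024²       | 7.91 ms          | 5.71 ms           |
| TRSM   | 2048×512    | 27.75 ms         | 35.82 ms          |
| DF-MP2 | medium warm | 9.77 s           | 9.91 s            |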

Relative A10G vs trn1 ratios still land in the same 19–45× range, so the cross-vendor story is unchanged; only the attribution is now correct.
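The re-measured rows can be turned into per-op ratios of the old (CPU) time to the new (real NKI) time, which makes it easier to see where the Tensor Engine helps and where it does not. The snippet below is only arithmetic on the numbers in the table above; the `measurements` dict and `speedup` function are illustrative names.

```python
# (old v0.4.x time on CPU, new v0.4.3 time on real NKI); units match per row.
measurements = {
    "GEMM 1024³": (4.5, 2.3),          # ms
    "SYRK 1024²": (7.91, 5.71),        # ms
    "TRSM 2048×512": (27.75, 35.82),   # ms
    "DF-MP2 medium warm": (9.77, 9.91),  # s
}


def speedup(old_time, new_time):
    """Ratio > 1 means real NKI is faster than the CPU fallback was."""
    return old_time / new_time


if __name__ == "__main__":
    for op, (old, new) in measurements.items():
        print(f"{op}: {speedup(old, new):.2f}x")
```

A ratio below 1 marks an op where the real NKI path is slower than the CPU fallback was, which is the case for TRSM and, marginally, for the warm DF-MP2 run.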

See the CHANGELOG.