
Port padding, RoPE, router, recipe, and fp8_partial_cast to stable ABI#2

Open
pstjohn wants to merge 2 commits into pstjohn/stable-abi-v2 from pstjohn/stable-abi-v2-step2

Conversation


@pstjohn pstjohn commented Apr 3, 2026

Port 18 additional simple functions to the libtorch stable ABI, following the same pattern established by the softmax port:

  • padding.cpp: fused_multi_row_padding, fused_multi_row_unpadding
  • apply_rope.cpp: fused_rope_forward/backward, fused_qkv_rope_forward/backward
  • router.cpp: fused_topk_with_score_function_fwd/bwd, fused_score_for_moe_aux_loss_fwd/bwd, fused_moe_aux_loss_fwd/bwd
  • recipe.cpp: compute_amax, fused_amax_and_scale_update_after_reduction
  • fp8_partial_cast.cpp: fp8_block_scaling_compute_partial_amax, fp8_block_scaling_partial_cast, mxfp8_scaling_compute_partial_amax, mxfp8_scaling_partial_cast

All functions ported in-place with minimal diffs. Schemas added to registration.cpp, pybind registrations removed, Python callers updated to use torch.ops.transformer_engine.
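
The schema-registration step can be illustrated with a short sketch. This is not the PR's actual code: the real schemas live in registration.cpp, and the argument lists below are hypothetical, using the padding pair from the list above as an example.

```cpp
// Illustrative sketch only -- real schemas are defined in registration.cpp
// of this PR; the argument lists here are invented for illustration.
#include <torch/csrc/stable/library.h>

STABLE_TORCH_LIBRARY(transformer_engine, m) {
  // One schema string per ported function, e.g. the padding pair:
  m.def("fused_multi_row_padding(Tensor input, Tensor output, "
        "int[] row_lengths, int[] padded_row_lengths) -> ()");
  m.def("fused_multi_row_unpadding(Tensor input, Tensor output, "
        "int[] row_lengths, int[] unpadded_row_lengths) -> ()");
}
```

Once a schema is registered this way, the Python side resolves the op as `torch.ops.transformer_engine.fused_multi_row_padding(...)`, which is what "Python callers updated" refers to above.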

Proof of concept for migrating pybind11 functions to the PyTorch
stable ABI. Ports all 8 scaled softmax functions:

- Add stable_common.h with stable ABI helpers (tensor allocation,
  TensorWrapper construction, CUDA stream, dtype converters)
- Add registration.cpp with STABLE_TORCH_LIBRARY schema definitions
- Rewrite softmax.cpp: at::Tensor -> torch::stable::Tensor, use
  stable allocation and stream APIs, TORCH_BOX() for impl registration
- Remove softmax registrations from pybind.cpp
- Update Python callers to use torch.ops.transformer_engine_stable

The pattern is mechanical (API translation, no logic changes) and
establishes the template for porting the remaining ~70 Category A
functions that have no py::handle/py::object dependencies.
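
The mechanical translation described above can be sketched as a before/after for a single function. Function names, signatures, and the placeholder body are illustrative assumptions, not the PR's actual diff:

```cpp
// Hypothetical before/after for one ported function (illustrative only).
//
// Before: pybind11 registration with full libtorch types:
//   at::Tensor scaled_softmax_forward(at::Tensor input, float scale);
//
// After: stable ABI types plus boxed registration:
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>

torch::stable::Tensor scaled_softmax_forward(
    torch::stable::Tensor input, double scale) {
  // Allocate outputs via the stable helpers (stable_common.h), fetch the
  // CUDA stream through the stable API, and launch the unchanged kernel.
  // Placeholder body -- the real implementation is unchanged kernel logic.
  return input;
}

STABLE_TORCH_LIBRARY_IMPL(transformer_engine_stable, CUDA, m) {
  m.impl("scaled_softmax_forward", TORCH_BOX(&scaled_softmax_forward));
}
```

The point of the pattern is that only the API surface changes (tensor type, allocation, stream access, registration macro); the kernel launch and numerical logic are untouched, which is what makes the remaining ports mechanical.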

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/stable-abi-v2-step2 branch from 56fde89 to a235d44 Compare April 3, 2026 19:50
Port 18 additional pure Category A functions to the libtorch stable
ABI, following the same pattern established by the softmax port:

- padding.cpp: fused_multi_row_padding, fused_multi_row_unpadding
- apply_rope.cpp: fused_rope_forward/backward, fused_qkv_rope_forward/backward
- router.cpp: fused_topk_with_score_function_fwd/bwd,
  fused_score_for_moe_aux_loss_fwd/bwd, fused_moe_aux_loss_fwd/bwd
- recipe.cpp: compute_amax, fused_amax_and_scale_update_after_reduction
- fp8_partial_cast.cpp: fp8_block_scaling_compute_partial_amax,
  fp8_block_scaling_partial_cast, mxfp8_scaling_compute_partial_amax,
  mxfp8_scaling_partial_cast

All functions ported in-place with minimal diffs. Schemas added to
registration.cpp, pybind registrations removed, Python callers
updated to use torch.ops.transformer_engine.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>