Port padding, RoPE, router, recipe, and fp8_partial_cast to stable ABI #2
Open
pstjohn wants to merge 2 commits into pstjohn/stable-abi-v2
Proof of concept for migrating pybind11 functions to the PyTorch stable ABI. Ports all 8 scaled softmax functions:

- Add stable_common.h with stable ABI helpers (tensor allocation, TensorWrapper construction, CUDA stream, dtype converters)
- Add registration.cpp with STABLE_TORCH_LIBRARY schema definitions
- Rewrite softmax.cpp: at::Tensor -> torch::stable::Tensor, use stable allocation and stream APIs, TORCH_BOX() for impl registration
- Remove softmax registrations from pybind.cpp
- Update Python callers to use torch.ops.transformer_engine_stable

The pattern is mechanical (API translation, no logic changes) and establishes the template for porting the remaining ~70 Category A functions that have no py::handle/py::object dependencies.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
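To illustrate the caller-side change, here is a minimal sketch of how a Python caller might look up a ported op by namespace and name. The helper `resolve_op`, the stand-in `fake_ops` object, and the op name `scaled_softmax_forward` are assumptions for illustration and are not taken from this PR; the stand-in only mimics the attribute-style lookup that `torch.ops.<namespace>.<op>` provides, so the snippet runs without PyTorch or the compiled extension.

```python
from types import SimpleNamespace


def resolve_op(ops_root, namespace, name):
    """Return ops_root.<namespace>.<name>, raising a readable error if absent."""
    ns = getattr(ops_root, namespace, None)
    if ns is None or not hasattr(ns, name):
        raise AttributeError(f"op {namespace}::{name} is not registered")
    return getattr(ns, name)


# Stand-in for torch.ops (hypothetical). The lambda is a placeholder that just
# scales its input; it is NOT the real softmax kernel.
fake_ops = SimpleNamespace(
    transformer_engine_stable=SimpleNamespace(
        scaled_softmax_forward=lambda x, scale: [v * scale for v in x]
    )
)

op = resolve_op(fake_ops, "transformer_engine_stable", "scaled_softmax_forward")
result = op([1.0, 2.0], 0.5)
```

With the extension loaded, the same helper could presumably be called as `resolve_op(torch.ops, "transformer_engine_stable", <op name>)`, since ops registered through a library schema are exposed as attributes under `torch.ops.<namespace>`.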
56fde89 to a235d44
Port 18 additional pure Category A functions to the libtorch stable ABI, following the same pattern established by the softmax port:

- padding.cpp: fused_multi_row_padding, fused_multi_row_unpadding
- apply_rope.cpp: fused_rope_forward/backward, fused_qkv_rope_forward/backward
- router.cpp: fused_topk_with_score_function_fwd/bwd, fused_score_for_moe_aux_loss_fwd/bwd, fused_moe_aux_loss_fwd/bwd
- recipe.cpp: compute_amax, fused_amax_and_scale_update_after_reduction
- fp8_partial_cast.cpp: fp8_block_scaling_compute_partial_amax, fp8_block_scaling_partial_cast, mxfp8_scaling_compute_partial_amax, mxfp8_scaling_partial_cast

All functions ported in-place with minimal diffs. Schemas added to registration.cpp, pybind registrations removed, Python callers updated to use torch.ops.transformer_engine.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
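As a quick sanity check on the stated count, the functions listed in the commit message above can be tallied per source file. The dictionary name `PORTED` is illustrative; the file and function names are copied from the commit message, with the `forward/backward` and `fwd/bwd` shorthand expanded into separate entries.

```python
# Functions ported in this commit, grouped by source file.
PORTED = {
    "padding.cpp": [
        "fused_multi_row_padding",
        "fused_multi_row_unpadding",
    ],
    "apply_rope.cpp": [
        "fused_rope_forward",
        "fused_rope_backward",
        "fused_qkv_rope_forward",
        "fused_qkv_rope_backward",
    ],
    "router.cpp": [
        "fused_topk_with_score_function_fwd",
        "fused_topk_with_score_function_bwd",
        "fused_score_for_moe_aux_loss_fwd",
        "fused_score_for_moe_aux_loss_bwd",
        "fused_moe_aux_loss_fwd",
        "fused_moe_aux_loss_bwd",
    ],
    "recipe.cpp": [
        "compute_amax",
        "fused_amax_and_scale_update_after_reduction",
    ],
    "fp8_partial_cast.cpp": [
        "fp8_block_scaling_compute_partial_amax",
        "fp8_block_scaling_partial_cast",
        "mxfp8_scaling_compute_partial_amax",
        "mxfp8_scaling_partial_cast",
    ],
}

# Per-file counts and the overall total (2 + 4 + 6 + 2 + 4).
counts = {path: len(funcs) for path, funcs in PORTED.items()}
total = sum(counts.values())
print(counts, total)
```

The total matches the 18 functions named in the commit title line.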
a235d44 to 4b905fb
5611150 to e0bae00
Port 18 additional simple functions to the libtorch stable ABI, following the same pattern established by the softmax port:
All functions ported in-place with minimal diffs. Schemas added to registration.cpp, pybind registrations removed, Python callers updated to use torch.ops.transformer_engine.