[ROCm][Quantization][2/N] Refactor quark_moe w4a8 w/ oracle #39136
BowenBao wants to merge 2 commits into vllm-project:main from
Conversation
Code Review
This pull request introduces support for ROCm-specific MXFP4 MoE backends, specifically targeting GFX950 architectures using AITER triton kernels. Key changes include the addition of the AiterW4A8ExpertsMonolithic class for W4A8 (MXFP4 weights with static FP8 activations) and the expansion of the oracle system to handle backend selection and weight conversion for these new schemes. Additionally, the QuarkOCP_MX_MoEMethod has been refactored to unify backend selection through the oracle, allowing for the removal of the redundant QuarkOCP_MX_MoEMethod_OSS class. Review feedback suggests enhancing the descriptiveness of error messages regarding missing input scales and simplifying the complex emulation mode logic to improve code maintainability.
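The oracle-based backend selection described above can be sketched roughly as follows. This is an illustrative sketch only: the enum members, function name, and selection rules here are assumptions, not vLLM's actual API.

```python
from enum import Enum


# Hypothetical backend enum; vLLM's real enum has different members.
class MoeBackend(Enum):
    AITER = "AITER"
    MARLIN = "MARLIN"
    EMULATION = "EMULATION"


def select_backend(weight_scheme: str, act_scheme: str, is_gfx950: bool) -> MoeBackend:
    """Toy oracle: pick a MoE backend from the quant scheme and platform.

    Mirrors the idea in the PR: GFX950 with MXFP4 weights and FP8/MXFP4
    activations routes to the AITER kernels, other MXFP4 configs fall back
    to Marlin, and everything else is emulated.
    """
    if is_gfx950 and weight_scheme == "mxfp4" and act_scheme in ("mxfp4", "fp8"):
        return MoeBackend.AITER
    if weight_scheme == "mxfp4":
        return MoeBackend.MARLIN
    return MoeBackend.EMULATION
```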
AndreasKaratzas
left a comment
LGTM. Some minor questions only.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed 1a54e39 to e2573f7.
Hi @BowenBao, the pre-commit checks have failed. Please run:

```shell
uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
Force-pushed 1655483 to 1c160f4.
- Add oracle backend selection for MXFP4 MoE
- Add unittest cases, fix w4a8 weight re-assign
- Refactor kernel selection and move out aiter kernel

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Bowen Bao <bowenbao@amd.com>
Force-pushed 1c160f4 to 7a6e293.
Signed-off-by: Bowen Bao <bowenbao@amd.com>
```diff
  MARLIN = "MARLIN"
- # ROCm AITER
+ # ROCm AITER backends
  AITER = "AITER"
```
I think we should rename to AITER_CK and AITER_TRITON_MXFP4_FP8 for clarity ... AITER_CK supports both W4A4 and W4A16, and we're experimenting with W4A8 right now.
In general that makes sense; either we keep a single backend enum & expert class for AITER_CK, or we keep a separate wrapper for each quant config, as in #41436.
With AITER_CK, the only issue I see so far is that you can't immediately tell which config combos it uses or supports. The current weight postprocessing (shuffling, etc.) branches on backend; with a single AITER_CK we'd introduce more complex logic there to further distinguish between configs. That is, assuming w4a16, w4a4, and potentially w4a8 each need different postprocessing logic, which seems to be the case in the existing code.
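One way to keep a single AITER_CK backend while still distinguishing per-config weight postprocessing is a dispatch table keyed on the quant scheme, so adding w4a8 becomes one table entry instead of another branch. This is a sketch of that design choice, not the code under review; the scheme names follow the PR's diff, but the shuffle functions are placeholders.

```python
from typing import Callable


# Placeholder postprocessing functions: the real per-scheme shuffles
# depend on the AITER kernel layouts, which are not shown in this PR.
def _postprocess_w4a4(weight: list[int]) -> list[int]:
    return list(reversed(weight))  # stand-in for a w4a4-specific shuffle


def _postprocess_w4a16(weight: list[int]) -> list[int]:
    return weight[:]  # stand-in: identity copy


# Scheme string -> postprocessing fn; a w4a8 entry ("w_mxfp4_a_fp8")
# would slot in here without touching the dispatch logic.
_POSTPROCESS: dict[str, Callable[[list[int]], list[int]]] = {
    "w_mxfp4_a_mxfp4": _postprocess_w4a4,
    "w_mxfp4": _postprocess_w4a16,
}


def postprocess_weight(scheme: str, weight: list[int]) -> list[int]:
    if scheme not in _POSTPROCESS:
        raise ValueError(f"no AITER_CK postprocessing registered for {scheme!r}")
    return _POSTPROCESS[scheme](weight)
```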
```diff
  # TODO: Remove once all OCP MX schemes use the kernel abstraction
- _AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4", "w_mxfp4_a_mxfp4")
+ _AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4", "w_mxfp4_a_mxfp4", "w_mxfp4_a_fp8")
  self.emulate = (
```
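The `self.emulate` assignment is cut off in the diff, but it presumably falls back to emulation when the configured scheme has no native AITER kernel. A minimal sketch of that check, with the predicate name and the `aiter_available` parameter assumed for illustration:

```python
# Tuple taken from the diff above; "w_mxfp4_a_fp8" is the new w4a8 scheme.
_AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4", "w_mxfp4_a_mxfp4", "w_mxfp4_a_fp8")


def needs_emulation(scheme: str, aiter_available: bool) -> bool:
    # Emulate when AITER is unavailable or the scheme lacks a native kernel.
    return not (aiter_available and scheme in _AITER_NATIVE_OCP_MX_SCHEMES)
```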
Shouldn't we remove this flag/override entirely and just let the oracle set the backend?
```diff
@@ -392,10 +430,15 @@ def _return_or_raise(
      )

      for backend in AVAILABLE_BACKENDS:
```
this should be `_get_priority_backends` if we're renaming `select_gpt_oss_mxfp4_moe_backend` -> `select_mxfp4_moe_backend` like this ... let's discuss offline
good catch, let's resolve in follow-ups.
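The loop under discussion walks candidate backends in priority order and returns the first supported one; the distinction being raised is whether the candidate list is the global `AVAILABLE_BACKENDS` or a scheme-specific priority list. A hedged sketch of the pattern (`_get_priority_backends` and the `supported` predicate are assumptions, not vLLM's actual signatures):

```python
def _get_priority_backends(scheme: str) -> list[str]:
    # Hypothetical priority ordering; the real list would be
    # platform- and scheme-dependent.
    if scheme == "w_mxfp4_a_fp8":
        return ["AITER", "MARLIN", "EMULATION"]
    return ["MARLIN", "EMULATION"]


def select_mxfp4_moe_backend(scheme: str, supported: set[str]) -> str:
    """Return the first priority backend that the current setup supports."""
    for backend in _get_priority_backends(scheme):
        if backend in supported:
            return backend
    raise RuntimeError(f"no MoE backend supports scheme {scheme!r}")
```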
QuarkOCP_MX_MoEMethod_OSS and add aiter w4a8 backend.

```shell
pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \
    --config-list-file=tests/evals/gpt_oss/configs/models-gfx950.txt
pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
    --config-list-file=tests/evals/gsm8k/configs/models-qwen35-mi355.txt
```