feat: support cdna4 v_mfma_i32_16x16x64_i8 & v_mfma_i32_32x32x32_i8 (#2097)
Conversation
📝 Walkthrough

This pull request extends MFMA (Matrix Fused Multiply-Add) support for AMD GPUs by introducing the `v_mfma_i32_16x16x64_i8` and `v_mfma_i32_32x32x32_i8` instructions.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/intrinsics/mfma_macro_generator.py (1)
105-112: ⚠️ Potential issue | 🟡 Minor
`_normalize_gfx950_f16_bf16_kpack` can silently override an explicit `mfma_shape`.

If a future caller passes `mfma_shape=(16, 16, 16)` together with `k_pack=2` on gfx950 f16/bf16, the subsequent `_normalize_gfx950_f16_bf16_kpack` call will rewrite `k_dim`/`k_pack` and effectively change the emitted instruction shape away from what the user asked for. Not exercised by this PR (int8-only tests) but worth guarding against.

🛡️ Suggested guard

```diff
 def _normalize_gfx950_f16_bf16_kpack(self):
+    if getattr(self, "_mfma_shape_explicit", False):
+        return
     is_f16_or_bf16 = self.a_dtype in {T.float16, T.bfloat16} and self.b_dtype in {T.float16, T.bfloat16}
```

and set `self._mfma_shape_explicit = mfma_shape is not None` inside `_initialize_mfma_shape`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/mfma_macro_generator.py` around lines 105 - 112, The _normalize_gfx950_f16_bf16_kpack method can silently override an explicitly-provided mfma_shape; to fix, record whether mfma_shape was user-supplied by setting self._mfma_shape_explicit = (mfma_shape is not None) inside _initialize_mfma_shape, and then modify _normalize_gfx950_f16_bf16_kpack to early-return (or skip changing k_pack/k_dim) when self._mfma_shape_explicit is true so an explicit mfma_shape and k_pack are not overwritten.
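The guard the comment suggests can be sketched outside TileLang with a toy class. Everything here is illustrative: `MfmaShapeConfig` and its simplified `k_pack` folding are stand-ins, not the actual TileLang generator.

```python
class MfmaShapeConfig:
    """Toy sketch of the suggested guard (names are illustrative,
    not the real TileLang classes)."""

    def __init__(self, mfma_shape=None, k_pack=1):
        # Record whether the caller supplied mfma_shape explicitly,
        # so normalization can be skipped later.
        self._mfma_shape_explicit = mfma_shape is not None
        self.m_dim, self.n_dim, self.k_dim = mfma_shape or (16, 16, 32)
        self.k_pack = k_pack

    def normalize_kpack(self):
        # Early-return: never rewrite a shape the user asked for.
        if self._mfma_shape_explicit:
            return
        # Simplified stand-in for the gfx950 rewrite: fold k_pack into k_dim.
        if self.k_pack == 2:
            self.k_dim *= 2
            self.k_pack = 1


explicit = MfmaShapeConfig(mfma_shape=(16, 16, 16), k_pack=2)
explicit.normalize_kpack()
print(explicit.k_dim, explicit.k_pack)  # 16 2 — explicit shape preserved

implicit = MfmaShapeConfig(k_pack=2)
implicit.normalize_kpack()
print(implicit.k_dim, implicit.k_pack)  # 64 1 — default shape normalized
```

The point of the pattern is that the normalization pass stays free to tune defaults while an explicitly-requested shape is treated as immutable.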
🧹 Nitpick comments (5)
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (2)
390-401: Nice coverage — optionally cross `a_preshuffle=True` with the extended shapes.

The existing `test_assert_tl_matmul` already exercises `a_preshuffle` and `a_preshuffle`+`b_preshuffle` on the default `(16,16,32)` shape, and `test_assert_tl_matmul_extended_mfma` exercises `(16,16,64)`/`(32,32,32)` with `b_preshuffle` only. Adding one or two rows that set `a_preshuffle=True` on the extended shapes would close the matrix and specifically verify the 32x32 reverse-map path through the preshuffle `ldmatrix_a` codepath.

🧪 Suggested extra cases

```diff
 [
     # v_mfma_i32_16x16x64_i8 — doubled K throughput (kp=1 only, micro_k=64)
     (256, 256, 512, T.int8, T.int32, T.int32, True, 1, False, (16, 16, 64)),
     (256, 256, 512, T.int8, T.int32, T.int32, True, 1, True, (16, 16, 64)),
     # v_mfma_i32_32x32x32_i8 — doubled MN throughput
     (256, 256, 512, T.int8, T.int32, T.int32, True, 1, False, (32, 32, 32)),
     (256, 256, 512, T.int8, T.int32, T.int32, True, 1, True, (32, 32, 32)),
+    # A-preshuffle on extended shapes (requires threading a_preshuffle through the helper)
+    # (256, 256, 512, T.int8, T.int32, T.int32, True, 1, True, True, (32, 32, 32)),
 ],
```

Threading this through requires parametrizing `a_preshuffle` in `test_assert_tl_matmul_extended_mfma` and forwarding it to `assert_tl_matmul_correctness(..., a_preshuffle=...)`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py` around lines 390 - 401, Parametrize test_assert_tl_matmul_extended_mfma to include a_preshuffle (True) for the extended MFMA shapes and forward that argument into the call to assert_tl_matmul_correctness; specifically add entries where a_preshuffle=True alongside the existing (16,16,64) and (32,32,32) cases and update the test signature/param list to accept a_preshuffle so the call to assert_tl_matmul_correctness(..., a_preshuffle=...) exercises the a_preshuffle + b_preshuffle reverse-map path through the preshuffle ldmatrix_a codepath.
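A minimal, dependency-free sketch of what threading `a_preshuffle` through the case table could look like. The helper here is a stand-in that merely records its kwargs; the real `assert_tl_matmul_correctness` lives in the test module and actually compiles and runs the kernel.

```python
def assert_tl_matmul_correctness(M, N, K, a_preshuffle=False,
                                 b_preshuffle=False, mfma_shape=(16, 16, 32)):
    # Stand-in for the real helper: record the forwarded kwargs so the
    # sketch is self-contained and checkable.
    return {"M": M, "N": N, "K": K, "a_preshuffle": a_preshuffle,
            "b_preshuffle": b_preshuffle, "mfma_shape": mfma_shape}


# Extended-shape cases now carry an explicit a_preshuffle column.
CASES = [
    # (M, N, K, a_preshuffle, b_preshuffle, mfma_shape)
    (256, 256, 512, True, True, (16, 16, 64)),
    (256, 256, 512, True, True, (32, 32, 32)),
]

results = [
    assert_tl_matmul_correctness(M, N, K, a_preshuffle=a_ps,
                                 b_preshuffle=b_ps, mfma_shape=shape)
    for (M, N, K, a_ps, b_ps, shape) in CASES
]
```

In the actual test file the same table would feed `tilelang.testing`'s parametrization instead of a plain loop; the only substantive change is the extra column and the forwarded keyword.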
300-301: Weak smoke assert — consider strengthening.

`assert kernel.get_kernel_source() is not None` really only verifies the compile call returned something. Since numerical correctness is already checked by `assert_close` below, this line is effectively redundant; if the goal is to gate on successful codegen you could drop it or assert that the source contains the expected MFMA builtin suffix (e.g., `i32_32x32x32_i8` for the new shape), which would catch silent fall-through to the wrong intrinsic.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py` around lines 300 - 301, The current smoke assert only checks that tilelang.compile(matmul) returned something; strengthen it to verify the generated kernel source actually uses the expected MFMA intrinsic by inspecting kernel.get_kernel_source(). Instead of (or in addition to) assert kernel.get_kernel_source() is not None, assert that the returned source string contains the expected MFMA builtin suffix (e.g., "i32_32x32x32_i8") so that kernel.get_kernel_source() and tilelang.compile(matmul) are both validated to have produced the correct MFMA code path (reference symbols: tilelang.compile, kernel.get_kernel_source, expected MFMA suffix).

tilelang/intrinsics/mfma_macro_generator.py (3)
292-320: Path logic for 32x32 LDMATRIX is correct; consider a small robustness note.

The A/B/transposed routing mirrors the existing 16x16 code and is consistent with the forward/reverse maps in `mfma_layout.py`. Two small nits that you can pick up later:

- The `ValueError` at Line 302 is informative but long; consider wrapping it onto two lines to help readers.
- Ruff flags the `×`/`↔` characters in the docstring/comment at Lines 296 and 300 (RUF002/RUF003). Replacing them with plain `x` and `<->` silences the warning.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/mfma_macro_generator.py` around lines 292 - 320, In _get_ldmatrix_index_map_32, shorten/wrap the long ValueError message for readability (replace the single long f-string with a shorter message like "unsupported k_dim; only k_dim=32 implemented" or split into two concatenated strings) and replace non-ASCII symbols in the docstring/comments (change `×` to `x` and `↔` to `<->`) to satisfy Ruff warnings; update the docstring lines and the ValueError text in the _get_ldmatrix_index_map_32 function accordingly.
521-526: Consolidate the 32x32 C-map import with the module-level imports.

Lines 36–37 already pull `thread_id_shared_access_64x16_to_32x32_layout_A`/`_B` from `.mfma_layout`. Adding `_C_n_m` to that block (and dropping the inline `from .mfma_layout import …` here) would be slightly cleaner and match the existing style. Same observation for the `mfma_store_index_map_32x32` local import at Line 325.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/mfma_macro_generator.py` around lines 521 - 526, Move the two inline imports into the module-level import block from .mfma_layout: add thread_id_shared_access_64x16_to_32x32_layout_C_n_m and mfma_store_index_map_32x32 to the existing imports (the same place that already imports thread_id_shared_access_64x16_to_32x32_layout_A/_B), then remove the local "from .mfma_layout import ..." statements in mfma_macro_generator.py and use the module-level symbols directly; specifically update the M_DIM conditional to set _store_map = thread_id_shared_access_64x16_to_32x32_layout_C_n_m when M_DIM == 32 and replace the local mfma_store_index_map_32x32 import/usage with the module-level mfma_store_index_map_32x32, keeping the fallback to mfma_store_index_map unchanged.
599-608: Use `N_DIM` for B-side layouts to keep this resilient to future asymmetric shapes.

`mn_dim = self.M_DIM` (with the "M_DIM == N_DIM for all supported shapes" comment) works today but will quietly misroute if a shape with M_DIM ≠ N_DIM is ever added (e.g. a hypothetical 16×32×K variant). Since the function already branches on `matrix_is_a` below, it's cheap to pick the correct axis here too.

♻️ Suggested change

```diff
-        k_dim = self.k_dim * self.k_pack
-        mn_dim = self.M_DIM  # M_DIM == N_DIM for all supported shapes
+        k_dim = self.k_dim * self.k_pack
+        mn_dim = self.M_DIM if matrix_is_a else self.N_DIM
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/mfma_macro_generator.py` around lines 599 - 608, The code currently sets mn_dim = self.M_DIM for both A and B paths which will misroute B-side layouts if M_DIM != N_DIM later; change the logic to pick the correct dimension based on matrix_is_a (e.g., use self.M_DIM when matrix_is_a is True and self.N_DIM when False) so that when selecting transform_func_sr_a and transform_func_sr_b (and any subsequent branches that use mn_dim) the B-side uses N_DIM instead of M_DIM.
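The suggested fix boils down to a one-line axis selection. A toy version (the real code reads `self.M_DIM`/`self.N_DIM` on the generator instance rather than taking parameters):

```python
def select_mn_dim(m_dim, n_dim, matrix_is_a):
    # A-side layouts index along M, B-side layouts along N; only for a
    # square shape are the two interchangeable.
    return m_dim if matrix_is_a else n_dim


# Square shape: both sides agree, so the old M_DIM shortcut was safe.
print(select_mn_dim(32, 32, matrix_is_a=False))  # 32
# Hypothetical asymmetric 16x32xK shape: the B side must use N_DIM.
print(select_mn_dim(16, 32, matrix_is_a=False))  # 32, not 16
```

Making the choice explicit costs nothing today and removes a latent assumption that is only documented in a comment.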
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 35764cea-e8ad-4140-be74-b13cb14bc510
📒 Files selected for processing (6)
- src/target/codegen_hip.cc
- src/tl_templates/hip/common.h
- testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
- tilelang/intrinsics/mfma_layout.py
- tilelang/intrinsics/mfma_macro_generator.py
- tilelang/intrinsics/utils.py