Skip to content

feat: support cdna4 v_mfma_i32_16x16x64_i8 & v_mfma_i32_32x32x32_i8#2097

Merged
LeiWang1999 merged 1 commit into
tile-ai:mainfrom
Paran0idy:jx/cdna4
Apr 26, 2026
Merged

feat: support cdna4 v_mfma_i32_16x16x64_i8 & v_mfma_i32_32x32x32_i8#2097
LeiWang1999 merged 1 commit into
tile-ai:mainfrom
Paran0idy:jx/cdna4

Conversation

@Paran0idy
Copy link
Copy Markdown
Contributor

@Paran0idy Paran0idy commented Apr 24, 2026

  • CDNA4 v_mfma_i32_16x16x64_i8 & v_mfma_i32_32x32x32_i8
  • Matrix A preshuffle
  • TODO: use async copy

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for MFMA instruction shape (32,32,32) for AMD architectures
    • Extended data type mappings for int8 and int32 vector operations
    • New preshuffle configurations for matrix multiplication optimization
    • Added configurable autotuning parameters for GEMM kernel tuning
  • Tests

    • Expanded test coverage for preshuffle configurations
    • Added dedicated test suite for extended int8 MFMA operations

@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 24, 2026

📝 Walkthrough

Walkthrough

This pull request extends MFMA (Matrix Fused Multiply-Add) support for AMD GPUs by introducing the (32,32,32) MFMA instruction shape and an a_preshuffle feature for matrix A. It adds new HIP data type mappings (int8x16, int32x16), layout transformation helpers for 64x16-to-32x32 thread access patterns, and refactors the macro generator to parameterize MFMA shapes with corresponding codegen, intrinsics, and test harness updates.

Changes

Cohort / File(s) Summary
HIP Type System Extensions
src/tl_templates/hip/common.h, src/target/codegen_hip.cc
Added new int32x16 vector type alias and extended HIP codegen dtype mappings to recognize int8x16→int32x4 and int32x16→int32x16 operand type conversions for MFMA emission.
MFMA Layout Transformation Helpers
tilelang/intrinsics/mfma_layout.py
Introduced six new functions (shared_32x32_to_local_64x16_layout_* and thread_id_shared_access_64x16_to_32x32_layout_*) that define bidirectional thread-id/local-id mappings for accessing 32x32 shared tiles through 64x16 local layout across A, B, and C matrices.
MFMA Macro Generator & Shape Support
tilelang/intrinsics/mfma_macro_generator.py, tilelang/intrinsics/utils.py
Extended MatrixCoreIntrinEmitter and MatrixCorePreshuffleIntrinEmitter to parameterize MFMA shapes; added (32,32,32) shape-specific LDMATRIX and store index mapping logic; refactored preshuffle emitter to support a_preshuffle alongside b_preshuffle with updated global-load conditions and thread-binding access patterns.
Preshuffle GEMM Test Suite
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
Refactored kernel builder to expose mfma_shape selection and extended autotuning controls (block_*_warps, warp_*_tiles, chunk, num_stages, panel_size); added a_preshuffle parameter with packed A tensor layouts and preshuffled load paths; extended test coverage to include both A-only and A+B preshuffle configurations, plus dedicated test suite for CDNA4 int8 MFMA shapes (16,16,64) and (32,32,32).

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • Gongen-Ali

Poem

🐰 A tile of shapes now thirty-two by two,
With preshuffle magic for the A's debut,
Layout maps dancing through local and thread,
New MFMA whispers—int8 wisdom spread!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 15.63% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title accurately describes the main feature additions: support for two specific CDNA4 MFMA instructions (v_mfma_i32_16x16x64_i8 and v_mfma_i32_32x32x32_i8), which aligns with the detailed changes across multiple files extending MFMA dtype mappings, layout functions, and test coverage for these new instruction shapes.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/intrinsics/mfma_macro_generator.py (1)

105-112: ⚠️ Potential issue | 🟡 Minor

_normalize_gfx950_f16_bf16_kpack can silently override an explicit mfma_shape.

If a future caller passes mfma_shape=(16, 16, 16) together with k_pack=2 on gfx950 f16/bf16, the subsequent _normalize_gfx950_f16_bf16_kpack call will rewrite k_dim/k_pack and effectively change the emitted instruction shape away from what the user asked for. Not exercised by this PR (int8-only tests) but worth guarding against.

🛡️ Suggested guard
     def _normalize_gfx950_f16_bf16_kpack(self):
+        if getattr(self, "_mfma_shape_explicit", False):
+            return
         is_f16_or_bf16 = self.a_dtype in {T.float16, T.bfloat16} and self.b_dtype in {T.float16, T.bfloat16}

and set self._mfma_shape_explicit = mfma_shape is not None inside _initialize_mfma_shape.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/mfma_macro_generator.py` around lines 105 - 112, The
_normalize_gfx950_f16_bf16_kpack method can silently override an
explicitly-provided mfma_shape; to fix, record whether mfma_shape was
user-supplied by setting self._mfma_shape_explicit = (mfma_shape is not None)
inside _initialize_mfma_shape, and then modify _normalize_gfx950_f16_bf16_kpack
to early-return (or skip changing k_pack/k_dim) when self._mfma_shape_explicit
is true so an explicit mfma_shape and k_pack are not overwritten.
🧹 Nitpick comments (5)
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (2)

390-401: Nice coverage — optionally cross a_preshuffle=True with the extended shapes.

The existing test_assert_tl_matmul already exercises a_preshuffle + a_preshuffle+b_preshuffle on the default (16,16,32) shape, and test_assert_tl_matmul_extended_mfma exercises (16,16,64) / (32,32,32) with b_preshuffle only. Adding one or two rows that set a_preshuffle=True on the extended shapes would close the matrix and specifically verify the 32x32 reverse-map path through the preshuffle ldmatrix_a codepath.

🧪 Suggested extra cases
     [
         # v_mfma_i32_16x16x64_i8 — doubled K throughput (kp=1 only, micro_k=64)
         (256, 256, 512, T.int8, T.int32, T.int32, True, 1, False, (16, 16, 64)),
         (256, 256, 512, T.int8, T.int32, T.int32, True, 1, True, (16, 16, 64)),
         # v_mfma_i32_32x32x32_i8 — doubled MN throughput
         (256, 256, 512, T.int8, T.int32, T.int32, True, 1, False, (32, 32, 32)),
         (256, 256, 512, T.int8, T.int32, T.int32, True, 1, True, (32, 32, 32)),
+        # A-preshuffle on extended shapes (requires threading a_preshuffle through the helper)
+        # (256, 256, 512, T.int8, T.int32, T.int32, True, 1, True, True, (32, 32, 32)),
     ],

Threading this through requires parametrizing a_preshuffle in test_assert_tl_matmul_extended_mfma and forwarding it to assert_tl_matmul_correctness(..., a_preshuffle=...).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py` around lines 390 -
401, Parametrize test_assert_tl_matmul_extended_mfma to include a_preshuffle
(True) for the extended MFMA shapes and forward that argument into the call to
assert_tl_matmul_correctness; specifically add entries where a_preshuffle=True
alongside the existing (16,16,64) and (32,32,32) cases and update the test
signature/param list to accept a_preshuffle so the call to
assert_tl_matmul_correctness(..., a_preshuffle=...) exercises the a_preshuffle +
b_preshuffle reverse-map path through the preshuffle ldmatrix_a codepath.

300-301: Weak smoke assert — consider strengthening.

assert kernel.get_kernel_source() is not None really only verifies the compile call returned something. Since numerical correctness is already checked by assert_close below, this line is effectively redundant; if the goal is to gate on successful codegen you could drop it or assert that the source contains the expected MFMA builtin suffix (e.g., i32_32x32x32_i8 for the new shape), which would catch silent fall-through to the wrong intrinsic.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py` around lines 300 -
301, The current smoke assert only checks that tilelang.compile(matmul) returned
something; strengthen it to verify the generated kernel source actually uses the
expected MFMA intrinsic by inspecting kernel.get_kernel_source(). Instead of (or
in addition to) assert kernel.get_kernel_source() is not None, assert that the
returned source string contains the expected MFMA builtin suffix (e.g.,
"i32_32x32x32_i8") so that kernel.get_kernel_source() and
tilelang.compile(matmul) are both validated to have produced the correct MFMA
code path (reference symbols: tilelang.compile, kernel.get_kernel_source,
expected MFMA suffix).
tilelang/intrinsics/mfma_macro_generator.py (3)

292-320: Path logic for 32x32 LDMATRIX is correct; consider a small robustness note.

The A/B/transposed routing mirrors the existing 16x16 code and is consistent with the forward/reverse maps in mfma_layout.py. Two small nits that you can pick up later:

  • The ValueError at Line 302 is informative but long; consider wrapping it onto two lines to help readers.
  • Ruff flags the × / characters in the docstring/comment at Lines 296 and 300 (RUF002/RUF003). Replacing with plain x and <-> silences the warning.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/mfma_macro_generator.py` around lines 292 - 320, In
_get_ldmatrix_index_map_32, shorten/wrap the long ValueError message for
readability (replace the single long f-string with a shorter message like
"unsupported k_dim; only k_dim=32 implemented" or split into two concatenated
strings) and replace non-ASCII symbols in the docstring/comments (change `×` to
`x` and `↔` to `<->`) to satisfy Ruff warnings; update the docstring lines and
the ValueError text in the _get_ldmatrix_index_map_32 function accordingly.

521-526: Consolidate the 32x32 C-map import with the module-level imports.

Lines 36–37 already pull thread_id_shared_access_64x16_to_32x32_layout_A/_B from .mfma_layout. Adding _C_n_m to that block (and dropping the inline from .mfma_layout import … here) would be slightly cleaner and match the existing style. Same observation for the mfma_store_index_map_32x32 local import at Line 325.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/mfma_macro_generator.py` around lines 521 - 526, Move the
two inline imports into the module-level import block from .mfma_layout: add
thread_id_shared_access_64x16_to_32x32_layout_C_n_m and
mfma_store_index_map_32x32 to the existing imports (the same place that already
imports thread_id_shared_access_64x16_to_32x32_layout_A/_B), then remove the
local "from .mfma_layout import ..." statements in mfma_macro_generator.py and
use the module-level symbols directly; specifically update the M_DIM conditional
to set _store_map = thread_id_shared_access_64x16_to_32x32_layout_C_n_m when
M_DIM == 32 and replace the local mfma_store_index_map_32x32 import/usage with
the module-level mfma_store_index_map_32x32, keeping the fallback to
mfma_store_index_map unchanged.

599-608: Use N_DIM for B-side layouts to keep this resilient to future asymmetric shapes.

mn_dim = self.M_DIM (with the "M_DIM == N_DIM for all supported shapes" comment) works today but will quietly misroute if a shape with M_DIM ≠ N_DIM is ever added (e.g. a hypothetical 16×32×K variant). Since the function already branches on matrix_is_a below, it's cheap to pick the correct axis here too.

♻️ Suggested change
-        k_dim = self.k_dim * self.k_pack
-        mn_dim = self.M_DIM  # M_DIM == N_DIM for all supported shapes
+        k_dim = self.k_dim * self.k_pack
+        mn_dim = self.M_DIM if matrix_is_a else self.N_DIM
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/intrinsics/mfma_macro_generator.py` around lines 599 - 608, The code
currently sets mn_dim = self.M_DIM for both A and B paths which will misroute
B-side layouts if M_DIM != N_DIM later; change the logic to pick the correct
dimension based on matrix_is_a (e.g., use self.M_DIM when matrix_is_a is True
and self.N_DIM when False) so that when selecting transform_func_sr_a and
transform_func_sr_b (and any subsequent branches that use mn_dim) the B-side
uses N_DIM instead of M_DIM.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@tilelang/intrinsics/mfma_macro_generator.py`:
- Around line 105-112: The _normalize_gfx950_f16_bf16_kpack method can silently
override an explicitly-provided mfma_shape; to fix, record whether mfma_shape
was user-supplied by setting self._mfma_shape_explicit = (mfma_shape is not
None) inside _initialize_mfma_shape, and then modify
_normalize_gfx950_f16_bf16_kpack to early-return (or skip changing k_pack/k_dim)
when self._mfma_shape_explicit is true so an explicit mfma_shape and k_pack are
not overwritten.

---

Nitpick comments:
In `@testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py`:
- Around line 390-401: Parametrize test_assert_tl_matmul_extended_mfma to
include a_preshuffle (True) for the extended MFMA shapes and forward that
argument into the call to assert_tl_matmul_correctness; specifically add entries
where a_preshuffle=True alongside the existing (16,16,64) and (32,32,32) cases
and update the test signature/param list to accept a_preshuffle so the call to
assert_tl_matmul_correctness(..., a_preshuffle=...) exercises the a_preshuffle +
b_preshuffle reverse-map path through the preshuffle ldmatrix_a codepath.
- Around line 300-301: The current smoke assert only checks that
tilelang.compile(matmul) returned something; strengthen it to verify the
generated kernel source actually uses the expected MFMA intrinsic by inspecting
kernel.get_kernel_source(). Instead of (or in addition to) assert
kernel.get_kernel_source() is not None, assert that the returned source string
contains the expected MFMA builtin suffix (e.g., "i32_32x32x32_i8") so that
kernel.get_kernel_source() and tilelang.compile(matmul) are both validated to
have produced the correct MFMA code path (reference symbols: tilelang.compile,
kernel.get_kernel_source, expected MFMA suffix).

In `@tilelang/intrinsics/mfma_macro_generator.py`:
- Around line 292-320: In _get_ldmatrix_index_map_32, shorten/wrap the long
ValueError message for readability (replace the single long f-string with a
shorter message like "unsupported k_dim; only k_dim=32 implemented" or split
into two concatenated strings) and replace non-ASCII symbols in the
docstring/comments (change `×` to `x` and `↔` to `<->`) to satisfy Ruff
warnings; update the docstring lines and the ValueError text in the
_get_ldmatrix_index_map_32 function accordingly.
- Around line 521-526: Move the two inline imports into the module-level import
block from .mfma_layout: add thread_id_shared_access_64x16_to_32x32_layout_C_n_m
and mfma_store_index_map_32x32 to the existing imports (the same place that
already imports thread_id_shared_access_64x16_to_32x32_layout_A/_B), then remove
the local "from .mfma_layout import ..." statements in mfma_macro_generator.py
and use the module-level symbols directly; specifically update the M_DIM
conditional to set _store_map =
thread_id_shared_access_64x16_to_32x32_layout_C_n_m when M_DIM == 32 and replace
the local mfma_store_index_map_32x32 import/usage with the module-level
mfma_store_index_map_32x32, keeping the fallback to mfma_store_index_map
unchanged.
- Around line 599-608: The code currently sets mn_dim = self.M_DIM for both A
and B paths which will misroute B-side layouts if M_DIM != N_DIM later; change
the logic to pick the correct dimension based on matrix_is_a (e.g., use
self.M_DIM when matrix_is_a is True and self.N_DIM when False) so that when
selecting transform_func_sr_a and transform_func_sr_b (and any subsequent
branches that use mn_dim) the B-side uses N_DIM instead of M_DIM.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 35764cea-e8ad-4140-be74-b13cb14bc510

📥 Commits

Reviewing files that changed from the base of the PR and between 264efe2 and 6ceec32.

📒 Files selected for processing (6)
  • src/target/codegen_hip.cc
  • src/tl_templates/hip/common.h
  • testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
  • tilelang/intrinsics/mfma_layout.py
  • tilelang/intrinsics/mfma_macro_generator.py
  • tilelang/intrinsics/utils.py

@LeiWang1999 LeiWang1999 merged commit 6a29c76 into tile-ai:main Apr 26, 2026
7 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants