Skip to content

[AMD][Gfx950] Add the support of 160K LDS and copy.async #2058

Merged
LeiWang1999 merged 5 commits into
tile-ai:mainfrom
zhangnju:gfx950_support
Apr 23, 2026
Merged

[AMD][Gfx950] Add the support of 160K LDS and copy.async #2058
LeiWang1999 merged 5 commits into
tile-ai:mainfrom
zhangnju:gfx950_support

Conversation

@zhangnju
Copy link
Copy Markdown
Collaborator

@zhangnju zhangnju commented Apr 17, 2026

HI TileLang Team

This PR adds hardware-specific optimizations for the AMD gfx950 (CDNA4 /MI350) GPU architecture, targeting two key improvements:

  1. 160KB LDS capacity support — gfx950 doubles the LDS (Local Data Share / shared memory) capacity compared to gfx942 (64KB → 160KB). The arch carver now correctly reports this limit and includes a safety-net override in case an older driver under-reports it.
  2. 128-bit direct-to-LDS async copy (buffer_load_dwordx4 lds) — On gfx950, the new cp_async_gs path uses buffer_load_dwordx4 with the lds modifier for 128-bit (16-byte) loads. This instruction bypasses VGPRs entirely, reducing register pressure and enabling overlap with MFMA computation for higher memory bandwidth.

Changes:
src/tl_templates/hip/copy.h : 1) Added async_buffer_load_dwordx4_v() — a gfx950-specific inline assembly function that issues buffer_load_dwordx4 ... lds for direct global-to-LDS async transfer. 2) Updated cp_async_gs<16> and cp_async_gs_conditional<16> with #if defined(gfx950) guards to dispatch to the new 128-bit path on gfx950, falling back to the existing uint4 pointer-copy path on all other targets.

tilelang/carver/arch/cdna.py: 1) Added _GFX950_LDS_SIZE = 160 * 1024 constant for gfx950's expanded LDS capacity. 2) Updated CDNA.init() to detect gfx950 from the target's mcpu attribute and override smem_cap when the driver-reported value is below 160KB.

Impact

  • No behavior change on non-gfx950 hardware (all new code is gated by #if defined(gfx950) or runtime mcpu checks).
  • On gfx950, tilelang kernels can now utilize the full 160KB shared memory and benefit from higher-throughput async copies, which is especially impactful for memory-bound operations like GEMM and Flash Attention.

Thanks

Summary by CodeRabbit

  • New Features

    • Faster async memory transfer on gfx950 GPUs via a 128-bit direct-to-shared-memory path for certain large copies.
  • Bug Fixes

    • Improved shared-memory capacity detection on gfx950 so the runtime reports and can use the larger 160KB LDS when appropriate.
  • Tests

    • Added gfx950-specific tests validating code generation, shared-memory capacity reporting, and numerical correctness for pipelined and non‑pipelined GEMM.

@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 17, 2026

📝 Walkthrough

Walkthrough

Added a gfx950-only 128-bit async buffer→LDS load helper and switched cp_async_gs<16> / cp_async_gs_conditional<16> to use it on gfx950; non-gfx950 paths retain prior 4×32-bit global→LDS copies. CDNA target setup conditionally raises shared-memory-per-block to 163840 bytes for gfx950 and added gfx950-specific tests.

Changes

Cohort / File(s) Summary
Async LDS load helper & cp_async_gs updates
src/tl_templates/hip/copy.h
Added CK_TILE_DEVICE void async_buffer_load_dwordx4_v(void *smem, int32x4_t rsrc, index_t voffset); (compiled only for __gfx950__). Updated cp_async_gs<16> and cp_async_gs_conditional<16> to use 128-bit direct-to-LDS async load on gfx950; non-gfx950 retains previous uint4 (4×32-bit) load/conditional-zero behavior.
CDNA target SMEM override
tilelang/carver/arch/cdna.py
Introduced _GFX950_LDS_SIZE = 163840 and set self.smem_cap to this value when target mcpu contains "gfx950" and reported shared-memory-per-block is smaller; otherwise use reported value.
Tests: gfx950 copy.async
testing/python/amd/test_tilelang_gfx950_copy_async.py
Added tests that assert codegen emits cp_async_gs<16> for coalesced_width=8, validate CDNA reports 160KB smem cap on gfx950, and check numerical correctness for pipelined and non-pipelined GEMM variants.

Sequence Diagram(s)

sequenceDiagram
    participant Kernel
    participant AsyncASM as Async ASM
    participant GlobalMem as Global Memory
    participant LDS
    Kernel->>AsyncASM: call async_buffer_load_dwordx4_v(smem, rsrc, voffset)
    AsyncASM->>GlobalMem: buffer_load_dwordx4 ... offen (m0 set, ptr from smem)
    GlobalMem-->>AsyncASM: 128-bit data
    AsyncASM->>LDS: write 128-bit data directly to LDS (smem pointer)
    Kernel->>LDS: subsequent reads use loaded data
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • Gongen-Ali

Poem

🐰
I hopped through lanes of dword and thread,
Packed four at once, straight into shared,
For gfx950 I set the pace,
Async whispers, LDS embrace,
A tiny hop — a faster race. 🎉

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main changes: adding support for gfx950's 160K LDS capacity and copy.async optimizations.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
src/tl_templates/hip/copy.h (1)

79-89: Consider matching async_buffer_load_dword_v's pre_nop template for consistency.

async_buffer_load_dword_v is defined with template <bool pre_nop = false> (Line 64). The new async_buffer_load_dwordx4_v omits this. If not needed, fine; otherwise adding the same template shell keeps both helpers symmetric for future s_nop insertion needs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/tl_templates/hip/copy.h` around lines 79 - 89, The helper
async_buffer_load_dwordx4_v is missing the template<bool pre_nop = false> used
by async_buffer_load_dword_v; update async_buffer_load_dwordx4_v to the same
template signature and, where async_buffer_load_dword_v conditionally emits an
s_nop when pre_nop is true, add the same conditional s_nop insertion before the
asm in async_buffer_load_dwordx4_v so both helpers are symmetric (refer to
async_buffer_load_dword_v, async_buffer_load_dwordx4_v, pre_nop and s_nop to
locate the changes).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tilelang/carver/arch/cdna.py`:
- Line 26: Fix the typo in the comment that reads "TVM runtime should orrectly
report 160 KB (163840 B) for gfx950; the" by changing "orrectly" to "correctly"
so the comment reads "TVM runtime should correctly report 160 KB (163840 B) for
gfx950; the". Reference the exact comment string containing "orrectly" to locate
the change in cdna.py.

---

Nitpick comments:
In `@src/tl_templates/hip/copy.h`:
- Around line 79-89: The helper async_buffer_load_dwordx4_v is missing the
template<bool pre_nop = false> used by async_buffer_load_dword_v; update
async_buffer_load_dwordx4_v to the same template signature and, where
async_buffer_load_dword_v conditionally emits an s_nop when pre_nop is true, add
the same conditional s_nop insertion before the asm in
async_buffer_load_dwordx4_v so both helpers are symmetric (refer to
async_buffer_load_dword_v, async_buffer_load_dwordx4_v, pre_nop and s_nop to
locate the changes).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0a1088c0-9643-438b-aa90-b12f49002ea2

📥 Commits

Reviewing files that changed from the base of the PR and between 04468a3 and 703c121.

📒 Files selected for processing (2)
  • src/tl_templates/hip/copy.h
  • tilelang/carver/arch/cdna.py

Comment thread tilelang/carver/arch/cdna.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@testing/python/amd/test_tilelang_gfx950_copy_async.py`:
- Around line 89-109: The test test_gfx950_cp_async_gs_16_in_codegen currently
only asserts the generic wrapper "cp_async_gs<16>" is present; update it to
assert the gfx950-specific lowering is emitted by checking
kernel.get_kernel_source() for the direct-to-LDS instruction(s) introduced by
the PR (e.g., the "buffer_load_dwordx4" pattern and usage of "lds" in the
emitted HIP) or, if the PR adds a helper symbol, assert that helper symbol name
appears; specifically, replace or add to the existing assert that searches for
"cp_async_gs<16>" with an assertion that the source contains the gfx950
instruction sequence (e.g., "buffer_load_dwordx4" and "lds") so the test
verifies the gfx950 path is used.
- Around line 141-169: The parametrization for
test_gfx950_copy_async_gemm_pipelined only uses k_pack=1 so coalesced_width is
always 4 and the 16-byte gfx950 path (cp_async_gs<16>) is never exercised;
update the param list in the test (the pytest.mark.parametrize tuple for
"trans_A, trans_B, k_pack") to include at least one case with k_pack=2 (which
yields coalesced_width=8) so the call to _matmul_kernel(...) will exercise the
128-bit/cp_async_gs<16> path and catch any data corruption there.
- Around line 23-30: Replace the manual mcpu detection in _is_gfx950() with the
existing helper target_is_gfx950() from tilelang.utils.target (import and call
target_is_gfx950()) or, if you prefer explicit targets, construct
tvm.target.Target("rocm", options="-mcpu=gfx950") when detecting/creating
targets used by test_gfx950_smem_cap_160kb; also remove the broad "except
Exception" so errors surface instead of being swallowed. Use the symbol
target_is_gfx950 (or explicit Target(...) with options) and update any calls
that currently use _is_gfx950 or Target("rocm") in this file.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4dbae2b9-21c7-419b-b185-9efa681037c5

📥 Commits

Reviewing files that changed from the base of the PR and between 772cc56 and 1607dff.

📒 Files selected for processing (1)
  • testing/python/amd/test_tilelang_gfx950_copy_async.py

Comment on lines +23 to +30
def _is_gfx950() -> bool:
try:
from tilelang import tvm

mcpu = str(tvm.target.Target("rocm").attrs.get("mcpu", ""))
return "gfx950" in mcpu
except Exception:
return False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's read the test file to understand the full context
cat -n testing/python/amd/test_tilelang_gfx950_copy_async.py

Repository: tile-ai/tilelang

Length of output: 8695


🏁 Script executed:

# Check if there are other device detection helpers in the codebase
fd -t f -i "device\|detect\|target" testing/python | head -20
rg -i "is_gfx\|device_detection\|get.*mcpu" testing/python -A 2 -B 2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Look at the CDNA implementation mentioned in the comment
cat -n tilelang/carver/arch/cdna.py | head -60

Repository: tile-ai/tilelang

Length of output: 2179


🏁 Script executed:

# Search for how Target is used elsewhere in tests
rg "Target\(.*rocm" testing/ -A 2 -B 2 | head -40

Repository: tile-ai/tilelang

Length of output: 888


🏁 Script executed:

# Search for Target usage with explicit -mcpu in tests
rg "Target\(" testing/ -A 2 | grep -i "mcpu\|rocm" -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 621


🏁 Script executed:

# Look for TVM documentation or examples of device detection
find . -name "*.py" -type f -exec grep -l "target.*mcpu\|mcpu.*target" {} \; | head -10

Repository: tile-ai/tilelang

Length of output: 173


🏁 Script executed:

# Check if there are any device query mechanisms that explicitly report gfx950
rg "gfx950\|MI350" . -t py --iglob "*.py" | grep -v test | head -20

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Look at how the requires_rocm decorator works
rg "requires_rocm" . -B 5 -A 10 | head -60

Repository: tile-ai/tilelang

Length of output: 4109


🏁 Script executed:

# Let's check the target utilities file
cat -n tilelang/utils/target.py

Repository: tile-ai/tilelang

Length of output: 10181


🏁 Script executed:

# Search for how TVM Target auto-detects attributes
rg "rocm\|mcpu" tilelang/utils/ -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check if there are any explicit examples with -mcpu in the codebase
rg "\-mcpu=\|mcpu.*=" . -t py | head -20

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check if target_is_gfx950 is used elsewhere in tests
rg "target_is_gfx950" testing/ -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check what the FFI API returns
rg "TargetIsGfx950" . -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 1169


🏁 Script executed:

# Look at the imports in the test file more carefully
head -20 testing/python/amd/test_tilelang_gfx950_copy_async.py

Repository: tile-ai/tilelang

Length of output: 796


🏁 Script executed:

# Check if there are any other examples of hardware detection in AMD tests
ls -la testing/python/amd/

Repository: tile-ai/tilelang

Length of output: 515


🏁 Script executed:

# Search for explicit target string construction in other AMD tests
rg "Target\(" testing/python/amd/ -B 1 -A 1

Repository: tile-ai/tilelang

Length of output: 551


🏁 Script executed:

# Check the C++ implementation to understand TargetIsGfx950 behavior
cat -n src/target/utils.cc | grep -A 10 "TargetIsGfx950"

Repository: tile-ai/tilelang

Length of output: 1153


🏁 Script executed:

# Verify if target_is_gfx950 is exported from tilelang
rg "target_is_gfx950|from.*target import" tilelang/__init__.py -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check if the test file ever uses the built-in target_is_gfx950 function
rg "target_is_gfx950\|from tilelang.utils.target" testing/python/amd/test_tilelang_gfx950_copy_async.py

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check what happens when Target is created without explicit options - look at TVM docs or comments
rg "Target.*rocm\|Target.*hip" tilelang/ -B 1 -A 3 | head -40

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for how other tests use gfx950 detection or target_is_gfx950
rg "gfx950\|target_is_gfx950\|TargetIsGfx950" testing/ -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check if there are any explicit mcpu examples in other files
rg "\-mcpu\|mcpu=" . --type py | head -15

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check what attributes Target("rocm") actually gets
python3 << 'EOF'
try:
    # Simple test to understand TVM behavior
    from tvm.target import Target
    t = Target("rocm")
    print("Target attrs keys:", list(t.attrs.keys()))
    print("mcpu value:", t.attrs.get("mcpu", "NOT_PRESENT"))
    print("Full attrs:", dict(t.attrs))
except Exception as e:
    print(f"Error: {e}")
    print("TVM may not be available in sandbox")
EOF

Repository: tile-ai/tilelang

Length of output: 124


🏁 Script executed:

# Look at the tilelang/__init__.py to see what's exported
grep -n "target_is_gfx950\|from.*utils.target" tilelang/__init__.py

Repository: tile-ai/tilelang

Length of output: 42


Use target_is_gfx950() helper instead of manual mcpu detection.

Lines 27 and 123 construct Target("rocm") without explicit -mcpu=gfx950. When mcpu is not present in target attributes, _is_gfx950() returns False and test_gfx950_smem_cap_160kb takes the non-gfx950 branch, bypassing validation of the 160 KB override logic. Use the existing target_is_gfx950() helper from tilelang.utils.target instead, or explicitly construct targets with -mcpu=gfx950. Also remove the broad except Exception to avoid silently hiding target-detection failures.

🧰 Tools
🪛 Ruff (0.15.10)

[warning] 29-29: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/amd/test_tilelang_gfx950_copy_async.py` around lines 23 - 30,
Replace the manual mcpu detection in _is_gfx950() with the existing helper
target_is_gfx950() from tilelang.utils.target (import and call
target_is_gfx950()) or, if you prefer explicit targets, construct
tvm.target.Target("rocm", options="-mcpu=gfx950") when detecting/creating
targets used by test_gfx950_smem_cap_160kb; also remove the broad "except
Exception" so errors surface instead of being swallowed. Use the symbol
target_is_gfx950 (or explicit Target(...) with options) and update any calls
that currently use _is_gfx950 or Target("rocm") in this file.

Comment on lines +89 to +109
@tilelang.testing.requires_rocm
def test_gfx950_cp_async_gs_16_in_codegen():
"""coalesced_width=8 (16 bytes) must emit cp_async_gs<16> in generated HIP source."""
prog = _matmul_kernel(
256,
256,
256,
128,
128,
32,
False,
True,
T.float16,
T.float32,
T.float32,
num_stages=2,
coalesced_width=8, # 8 fp16 = 16 bytes → cp_async_gs<16>
)
kernel = tl.compile(prog, out_idx=[2])
src = kernel.get_kernel_source()
assert "cp_async_gs<16>" in src, "Expected cp_async_gs<16> in generated HIP source for 128-bit async copy path"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Assert the gfx950-specific async load, not just the wrapper call.

cp_async_gs<16> only proves the generic 16-byte wrapper was emitted. It does not prove the gfx950 path lowers to buffer_load_dwordx4 ... lds, which is the behavior this PR adds. On gfx950, add an assertion against the generated/emitted source that contains the direct-to-LDS instruction, or add a lower-level codegen check for the helper implementation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/amd/test_tilelang_gfx950_copy_async.py` around lines 89 - 109,
The test test_gfx950_cp_async_gs_16_in_codegen currently only asserts the
generic wrapper "cp_async_gs<16>" is present; update it to assert the
gfx950-specific lowering is emitted by checking kernel.get_kernel_source() for
the direct-to-LDS instruction(s) introduced by the PR (e.g., the
"buffer_load_dwordx4" pattern and usage of "lds" in the emitted HIP) or, if the
PR adds a helper symbol, assert that helper symbol name appears; specifically,
replace or add to the existing assert that searches for "cp_async_gs<16>" with
an assertion that the source contains the gfx950 instruction sequence (e.g.,
"buffer_load_dwordx4" and "lds") so the test verifies the gfx950 path is used.

Comment on lines +141 to +169
@pytest.mark.parametrize(
"trans_A, trans_B, k_pack",
[
(False, False, 1),
(False, True, 1),
(True, True, 1),
(True, False, 1),
],
)
@tilelang.testing.requires_rocm
def test_gfx950_copy_async_gemm_pipelined(trans_A, trans_B, k_pack):
"""Pipelined GEMM (num_stages=2) with gfx950 copy.async must be numerically correct."""
prog = _matmul_kernel(
512,
512,
512,
128,
128,
32,
trans_A,
trans_B,
T.float16,
T.float32,
T.float32,
num_stages=2,
threads=128,
k_pack=k_pack,
coalesced_width=4 * k_pack,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Exercise the 16-byte path in the correctness test.

All k_pack values are 1, so Line 168 always sets coalesced_width=4 and never reaches the cp_async_gs<16> path. Add at least one k_pack=2 / coalesced_width=8 correctness case so data corruption in the new 128-bit gfx950 copy path is caught.

Proposed test coverage adjustment
     [
         (False, False, 1),
         (False, True, 1),
         (True, True, 1),
         (True, False, 1),
+        # Exercise coalesced_width=8 -> cp_async_gs<16>.
+        (False, False, 2),
     ],
 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/amd/test_tilelang_gfx950_copy_async.py` around lines 141 -
169, The parametrization for test_gfx950_copy_async_gemm_pipelined only uses
k_pack=1 so coalesced_width is always 4 and the 16-byte gfx950 path
(cp_async_gs<16>) is never exercised; update the param list in the test (the
pytest.mark.parametrize tuple for "trans_A, trans_B, k_pack") to include at
least one case with k_pack=2 (which yields coalesced_width=8) so the call to
_matmul_kernel(...) will exercise the 128-bit/cp_async_gs<16> path and catch any
data corruption there.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (3)
testing/python/amd/test_tilelang_gfx950_copy_async.py (3)

141-169: ⚠️ Potential issue | 🟠 Major

Add a correctness case for the 16-byte async-copy path.

All current cases use k_pack=1, so Line 168 keeps coalesced_width=4 and never exercises cp_async_gs<16>. Add at least one k_pack=2 case.

Proposed coverage addition
     [
         (False, False, 1),
         (False, True, 1),
         (True, True, 1),
         (True, False, 1),
+        # coalesced_width=8 -> cp_async_gs<16>.
+        (False, False, 2),
     ],
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/amd/test_tilelang_gfx950_copy_async.py` around lines 141 -
169, The test only uses k_pack=1 so coalesced_width=4 never exercises the
16-byte async-copy path; update the parameterization for
test_gfx950_copy_async_gemm_pipelined to include at least one case with k_pack=2
(for example add a tuple (False, False, 2) or add 2 to the k_pack choices) so
that _matmul_kernel(...) is invoked with coalesced_width=8 and triggers the
cp_async_gs<16> code path.

23-30: ⚠️ Potential issue | 🟠 Major

Use a real gfx950 target detector for the smem-cap assertion.

Target("rocm") can lack mcpu, so _is_gfx950() returns False and the gfx950 branch is bypassed on the hardware this test is meant to validate. Prefer the existing target utility or pass an explicit -mcpu=gfx950 target; also avoid swallowing detection failures with broad except Exception.

Run this read-only check to confirm the helper signature and remaining manual mcpu probes:

#!/bin/bash
set -euo pipefail

# Expect: target_is_gfx950 helper definition/callable wrapper is visible.
rg -n -C4 '\btarget_is_gfx950\b|TargetIsGfx950' tilelang src testing

# Expect: this test no longer relies on Target("rocm").attrs["mcpu"] for gfx950 detection.
rg -n -C4 'Target\("rocm"\)|attrs\.get\("mcpu"' testing/python/amd/test_tilelang_gfx950_copy_async.py tilelang/carver/arch/cdna.py

Also applies to: 123-127

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/amd/test_tilelang_gfx950_copy_async.py` around lines 23 - 30,
The _is_gfx950() helper uses Target("rocm").attrs.get("mcpu", "") and swallows
all errors with a broad except, causing false negatives; replace its logic to
call the existing target utility (e.g. target_is_gfx950) or detect via an
explicit target string/mcpu flag (e.g. parse "-mcpu=gfx950"), and remove the
broad except by either letting errors propagate or catching only specific
attribute/key errors (AttributeError/KeyError) so genuine failures are not
hidden; update references inside _is_gfx950 to avoid using
Target("rocm").attrs.get("mcpu") directly and ensure the new implementation
reliably returns True for gfx950 hardware.

107-109: ⚠️ Potential issue | 🟠 Major

Assert the gfx950 direct-to-LDS load, not only the wrapper.

cp_async_gs<16> is emitted from the copy size alone, so this can pass even if the gfx950 buffer_load_dwordx4 ... lds path is broken or removed. Add an assertion for the emitted helper/instruction pattern, e.g. buffer_load_dwordx4 plus lds, when compiling for gfx950.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/amd/test_tilelang_gfx950_copy_async.py` around lines 107 -
109, The test currently only asserts the wrapper "cp_async_gs<16>" was emitted
but not the actual gfx950 direct-to-LDS load; update the test around kernel =
tl.compile(...) / src = kernel.get_kernel_source() to also assert the
gfx950-specific helper/instruction pattern when compiling for gfx950 by checking
src contains the direct-load tokens (e.g. "buffer_load_dwordx4" and "lds" or the
combined "buffer_load_dwordx4 ... lds") in addition to "cp_async_gs<16>", so the
test fails if the buffer_load_dwordx4 to LDS path is missing or removed. Ensure
the check is gated to the gfx950 compilation case.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@testing/python/amd/test_tilelang_gfx950_copy_async.py`:
- Around line 141-169: The test only uses k_pack=1 so coalesced_width=4 never
exercises the 16-byte async-copy path; update the parameterization for
test_gfx950_copy_async_gemm_pipelined to include at least one case with k_pack=2
(for example add a tuple (False, False, 2) or add 2 to the k_pack choices) so
that _matmul_kernel(...) is invoked with coalesced_width=8 and triggers the
cp_async_gs<16> code path.
- Around line 23-30: The _is_gfx950() helper uses
Target("rocm").attrs.get("mcpu", "") and swallows all errors with a broad
except, causing false negatives; replace its logic to call the existing target
utility (e.g. target_is_gfx950) or detect via an explicit target string/mcpu
flag (e.g. parse "-mcpu=gfx950"), and remove the broad except by either letting
errors propagate or catching only specific attribute/key errors
(AttributeError/KeyError) so genuine failures are not hidden; update references
inside _is_gfx950 to avoid using Target("rocm").attrs.get("mcpu") directly and
ensure the new implementation reliably returns True for gfx950 hardware.
- Around line 107-109: The test currently only asserts the wrapper
"cp_async_gs<16>" was emitted but not the actual gfx950 direct-to-LDS load;
update the test around kernel = tl.compile(...) / src =
kernel.get_kernel_source() to also assert the gfx950-specific helper/instruction
pattern when compiling for gfx950 by checking src contains the direct-load
tokens (e.g. "buffer_load_dwordx4" and "lds" or the combined
"buffer_load_dwordx4 ... lds") in addition to "cp_async_gs<16>", so the test
fails if the buffer_load_dwordx4 to LDS path is missing or removed. Ensure the
check is gated to the gfx950 compilation case.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 774178c1-8e8d-4f55-9032-b5ae6745ffb4

📥 Commits

Reviewing files that changed from the base of the PR and between 1607dff and acf714e.

📒 Files selected for processing (1)
  • testing/python/amd/test_tilelang_gfx950_copy_async.py

@LeiWang1999 LeiWang1999 merged commit 6d0bffb into tile-ai:main Apr 23, 2026
5 of 6 checks passed
LeiWang1999 pushed a commit that referenced this pull request Apr 26, 2026
)

* Fix HIP codegen for sync_warp, sync_grid, and local.var initialisation

* [AMD/HIP] Fix warp_reduce VGPR bug, ShuffleNode packing, and Pipelined LDS overflow

Extends PR #2096 with three additional fixes for CDNA (MI350) targets:

Fix 1 — src/tl_templates/hip/reduce.h: warp_reduce width=32
  The old 6-step butterfly called __shfl_xor(value, 32) without a width
  argument. On CDNA (wave64) with 32 active threads, lanes 32-63 are
  inactive and hold uninitialised VGPRs, producing NaN in reduce_max /
  reduce_sum / AllReduce. Fix: remove the step-32 shuffle; pass width=32
  to all remaining 5 steps so every shuffle stays within the [0,31] group.

Fix 2 — src/target/codegen_hip.cc + src/tl_templates/hip/common.h:
         ShuffleNode bfloat16x2 / float16x2 packing
  CodeGenC emitted `uint1(a, b)` for bfloat16x2 construction, which is an
  invalid HIP constructor call. Fix: override VisitExpr_(ShuffleNode) in
  CodeGenTileLangHIP to emit `uint1{__pack_bfloat162(a, b)}` / `uint1{
  __pack_half2(a, b)}` using aggregate initialisation. Also add five
  bfloat16x2 math overloads for uint1 carrier (abs2/max2/min2/add2/mul2).

Fix 3 — src/transform/pipeline_planning.cc: skip T.Pipelined(num_stages>1)
  Double-buffering doubled LDS per loop-body buffer. On CDNA (≤128 KB LDS
  per workgroup), this caused hipModuleLaunchKernel EINVAL. Fix: when
  TargetIsRocm() && num_stages > 1, skip pipeline planning and fall back
  to a plain sequential loop with synchronous T.copy.

Also: fix __habs(hip_bfloat16) and __habs(float16_t) in common.h to use
__builtin_memcpy instead of reinterpret_cast to avoid strict-aliasing UB
(as flagged by CodeRabbit on PR #2096).

Tests: 19 new cases added to testing/python/amd/test_tilelang_hip_codegen.py
covering all three fixes. All 42 tests pass on MI350 (gfx950).

* [AMD/HIP] Merge test_tilelang_hip_bugfixes.py into test_tilelang_hip_codegen.py

Consolidate all HIP regression tests into a single file.  The merged file
covers all six fixes with 32 tests total (previously split across two files
with duplicated test cases for warp_reduce, pipelined GEMM, and ShuffleNode).

Changes versus the two individual files:
- Deduplicated test_warp_reduce_no_nan (identical in both files)
- Deduplicated test_pipelined_no_lds_overflow / test_pipelined_shared_mem_no_launch_error
- Deduplicated test_pipelined_multi_stage_fp16_gemm
- Merged bfloat16 shuffle tests: source check + runtime correctness in one function
- Kept PR #2096 source-inspection tests (alloc_var, sync_warp, sync_grid)
- Added runtime tests from bugfixes: inf init, serial loop accumulation,
  float scalar readback, two-group wave64 reduce, float16 shuffle

* fixup: correct LDS size comment — gfx950 has 160 KB, not 128 KB

gfx942 (CDNA3 / MI300X) has 64 KB LDS per workgroup.
gfx950 (CDNA4 / MI350)  has 160 KB LDS per workgroup (see PR #2058).

The old comment said '128 KB' which is wrong for both generations.
Updated pipeline_planning.cc and the test docstrings to reflect the
correct per-architecture limits.

* update for format checking
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants