[Test] Add 1D TMA regression test for issue #1842 (#2005)
Add a regression test covering a 1D single-dimension tensor TMA copy (global -> shared -> global) with warp specialization disabled. The underlying bug was fixed in #1840, but the test suite only covered 2D descriptor-based TMA paths. This test ensures the 1D bulk copy path (`cp.async.bulk`) also works correctly with proper mbarrier allocation.

Closes #1842

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
📝 Walkthrough

A new regression test (`testing/python/issue/test_tilelang_issue_tma_no_ws.py`) was added.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (warning)
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@testing/python/issue/test_tilelang_issue_tma_no_ws.py`:
- Around line 81-99: The test only runs one configuration; extend it to iterate
the intended regression matrix by parameterizing length, dtype
(float32/float16/bfloat16), and warp-specialized (WS) enabled/disabled; modify
the test around the tma_copy_1d definition and the pass_configs block to loop
over a list of sizes, dtypes, and WS boolean values, for each build a matching
pass_configs (toggling tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED) and
call _compile_tvm_ffi(tma_copy_1d, pass_configs, out_idx=[1]) for each
combination, and assert/verify outputs per combination so the test covers all
sizes, fp32/fp16/bf16, and WS on/off as intended.
- Around line 100-104: The test currently only checks that "mbarrier_mem" is
declared but not that tl::tma_load actually uses it; update the assertion after
src = kernel.get_kernel_source() to confirm the tl::tma_load invocation includes
the mbarrier reference (e.g., ensure the generated source contains a
tl::tma_load call with "mbarrier" as the barrier argument such as
"tl::tma_load(...mbarrier" or "mbarrier[") so we catch the case where the load
was passed 0 instead of the mbarrier buffer.
📒 Files selected for processing (1)
testing/python/issue/test_tilelang_issue_tma_no_ws.py
```python
length = 7168

@T.prim_func
def tma_copy_1d(
    a: T.Tensor((length,), T.float32),
    out: T.Tensor((length,), T.float32),
):
    with T.Kernel(1, threads=256):
        a_shared = T.alloc_shared((length,), T.float32)
        T.copy(a, a_shared)
        T.copy(a_shared, out)

pass_configs = {
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False,
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
kernel = _compile_tvm_ffi(tma_copy_1d, pass_configs, out_idx=[1])
```
Coverage is narrower than the stated regression matrix.
This test currently exercises only one shape/dtype/WS setting. To lock issue #1842 down, please cover the intended matrix (sizes, fp32/fp16/bf16, WS on/off).
Proposed refactor (matrix coverage)

```diff
-length = 7168
-
-@T.prim_func
-def tma_copy_1d(
-    a: T.Tensor((length,), T.float32),
-    out: T.Tensor((length,), T.float32),
-):
-    with T.Kernel(1, threads=256):
-        a_shared = T.alloc_shared((length,), T.float32)
-        T.copy(a, a_shared)
-        T.copy(a_shared, out)
-
-pass_configs = {
-    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
-    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False,
-    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-}
-kernel = _compile_tvm_ffi(tma_copy_1d, pass_configs, out_idx=[1])
+lengths = [100, 256, 1024, 7168, 32768]
+dtypes = [
+    (T.float32, torch.float32),
+    (T.float16, torch.float16),
+    (T.bfloat16, torch.bfloat16),
+]
+for length in lengths:
+    for tl_dtype, torch_dtype in dtypes:
+        for disable_ws in [False, True]:
+
+            @T.prim_func
+            def tma_copy_1d(
+                a: T.Tensor((length,), tl_dtype),
+                out: T.Tensor((length,), tl_dtype),
+            ):
+                with T.Kernel(1, threads=256):
+                    a_shared = T.alloc_shared((length,), tl_dtype)
+                    T.copy(a, a_shared)
+                    T.copy(a_shared, out)
+
+            pass_configs = {
+                tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
+                tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False,
+                tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: disable_ws,
+            }
+            kernel = _compile_tvm_ffi(tma_copy_1d, pass_configs, out_idx=[1])
-src = kernel.get_kernel_source()
-assert "tl::tma_load" in src
-assert "mbarrier_mem" in src
-assert "tl::tma_store" in src
+            src = kernel.get_kernel_source()
+            assert "tl::tma_load" in src
+            assert "mbarrier_mem" in src
+            assert "tl::tma_store" in src
-t = torch.randn((length,), device="cuda", dtype=torch.float32)
-out = kernel(t)
-torch.testing.assert_close(out, t)
+            t = torch.randn((length,), device="cuda", dtype=torch_dtype)
+            out = kernel(t)
+            torch.testing.assert_close(out, t)
 torch.cuda.synchronize()
```

Also applies to: 105-107
```python
src = kernel.get_kernel_source()
assert "tl::tma_load" in src
assert "mbarrier_mem" in src
assert "tl::tma_store" in src
```
Assert mbarrier is passed into tl::tma_load, not just declared.
Line 102 only proves barrier storage exists; it doesn’t catch the original failure mode (tl::tma_load(..., 0, ...)). Please assert the load call uses mbarrier[...] directly.
Proposed fix

```diff
 src = kernel.get_kernel_source()
 assert "tl::tma_load" in src
 assert "mbarrier_mem" in src
 assert "tl::tma_store" in src
+flat_src = " ".join(src.split())
+assert re.search(r"tl::tma_load\([^,]+,\s*mbarrier\[[0-9]+\]", flat_src)
+assert not re.search(r"tl::tma_load\([^,]+,\s*0\b", flat_src)
```

Based on learnings: validating transformation behavior via generated kernel source pattern checks is an expected and appropriate testing strategy.
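The regex check proposed above can be exercised in isolation against toy source strings. A sketch, where the two snippets below are hand-written stand-ins for `kernel.get_kernel_source()` output, not actual TileLang codegen:

```python
import re

# Stand-ins for generated kernel source; real output is multi-line,
# so whitespace is flattened before matching.
good_src = "tl::tma_load(desc,\n    mbarrier[0], a_shared, 0);"
bad_src = "tl::tma_load(desc, 0, a_shared, 0);"  # the pre-fix failure mode

def load_uses_mbarrier(src: str) -> bool:
    flat = " ".join(src.split())
    # The barrier argument must be mbarrier[N], and must not be a literal 0.
    return (re.search(r"tl::tma_load\([^,]+,\s*mbarrier\[[0-9]+\]", flat) is not None
            and re.search(r"tl::tma_load\([^,]+,\s*0\b", flat) is None)

print(load_uses_mbarrier(good_src))  # True
print(load_uses_mbarrier(bad_src))   # False
```

Flattening whitespace first keeps the check robust to the pretty-printer splitting a call across lines, which is exactly why the proposed fix builds `flat_src` before matching.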
Summary
Context
Issue #1842 reported that `T.copy` on a 1D tensor with `TL_DISABLE_WARP_SPECIALIZED: True` failed. The root cause was that without warp specialization, the barrier allocation passes were skipped, so the mbarrier argument was emitted as the literal `0` instead of a proper `mbarrier[0]` reference. PR #1840 fixed this by running the TMA barrier passes regardless of the WS setting.

The existing tests only covered 2D descriptor-based TMA paths. This PR adds coverage for the 1D bulk copy path (`cp.async.bulk`).

Closes #1842
Test plan
- `test_tma_lower_1d_no_warp_specialized` passes on SM100 (GB200)
- Existing tests in `test_tilelang_issue_tma_no_ws.py` still pass (1 pre-existing skip on SM100 for sparse gemm)

🤖 Generated with Claude Code