Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/compiler_internals/inject_fence_proxy.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the
### Timeline View

```
generic initialize_descriptor → generic shared-store → async wgmma
generic initialize_wgmma_descriptor → generic shared-store → async wgmma
│ │ │
└─ generic proxy ┴─ generic proxy ┴─ async proxy
│ fence inserted here ↑
Expand Down Expand Up @@ -53,7 +53,7 @@ def kernel():
with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0)
T.ptx_wgmma_ss(
"float16",
Expand Down Expand Up @@ -83,7 +83,7 @@ def kernel():
with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0)
T.fence_proxy_async()
T.ptx_wgmma_ss(
Expand Down
6 changes: 6 additions & 0 deletions src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,12 @@ TVM_FFI_STATIC_INIT_BLOCK({
return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_tcgen05mma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) {
return makeGemmABLayoutSm100(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_full_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeFullBankSwizzleLayout(stride, continuous, element_size);
Expand Down
17 changes: 16 additions & 1 deletion src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Register the tl.ptx_tcgen05_mma_ss builtin (tcgen05 MMA with both A and B
// operands in shared memory). Marked kOpaque: the call mutates tensor memory
// and must not be reordered or eliminated.
//
// Arity fixed 13 -> 14: the Python/TIR wrapper (tilelang/language/tir/op.py)
// passes 14 arguments to tl.ptx_tcgen05_mma_ss, so registering 13 inputs
// would fail arity checking at call construction time.
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
    .set_num_inputs(14)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

Comment on lines +157 to +161
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix arity: ptx_tcgen05_mma_ss requires 14 inputs (not 13).

Python/TIR wrapper passes 14 args; registering 13 will error at call time. Update set_num_inputs to 14.

 TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
-    .set_num_inputs(13)
+    .set_num_inputs(14)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

Reference: tilelang/language/tir/op.py expects 14 args for tl.ptx_tcgen05_mma_ss. Based on provided snippets.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
.set_num_inputs(13)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
.set_num_inputs(14)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
🤖 Prompt for AI Agents
In src/op/builtin.cc around lines 157 to 161, the registration for
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss) incorrectly sets .set_num_inputs(13);
update this to .set_num_inputs(14) so the C++ op registration matches the
Python/TIR wrapper which passes 14 arguments, and verify no other registrations
for this op remain with the old arity.

TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down Expand Up @@ -218,6 +223,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Register tl.warpgroup_fence_operand: fences accumulator operand registers
// before/after WGMMA operations. Takes 4 inputs — per the header declaration,
// (dtype, ptr, offset, num_regs). kOpaque because it constrains instruction
// ordering and must not be moved by the compiler.
TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(wait_wgmma)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down Expand Up @@ -265,11 +275,16 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(initialize_descriptor)
TIR_DEFINE_TL_BUILTIN(initialize_wgmma_descriptor)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Register tl.initialize_tcgen05_descriptor: initializes a descriptor buffer
// for tcgen05 MMA (see the matching declaration in builtin.h). Takes 7
// inputs. kOpaque because it writes the descriptor buffer in place.
TIR_DEFINE_TL_BUILTIN(initialize_tcgen05_descriptor)
    .set_num_inputs(7)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down
31 changes: 25 additions & 6 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,19 @@ TVM_DLL const Op &ptx_wgmma_ss();
/*!
* \brief tvm intrinsics for ptx tensor core wgmma instructions.
*
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
* b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
* A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
* scale_out, bool scale_in_a, bool scale_in_b);
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix,
* bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
* StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
* B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out,
* bool scale_in_a, bool scale_in_b);
*/
TVM_DLL const Op &ptx_wgmma_rs();

/*!
* \brief tvm intrinsic for tcgen05 mma shared-shared instructions.
*/
TVM_DLL const Op &ptx_tcgen05_mma_ss();

/*!
* \brief tvm intrinsics for initializing tensor memory
*
Expand Down Expand Up @@ -358,6 +363,14 @@ TVM_DLL const Op &warpgroup_commit_batch();
*/
TVM_DLL const Op &warpgroup_wait();

/*!
* \brief Fence accumulator operand registers for upcoming WGMMA operations
*
* warpgroup_fence_operand(dtype, ptr, offset, num_regs)
*
*/
TVM_DLL const Op &warpgroup_fence_operand();

/*!
* \brief Wait the previous wgmma to finish
*
Expand Down Expand Up @@ -459,7 +472,13 @@ TVM_DLL const Op &tl_shuffle_elect();
* This op is used to represent a descriptor initialization operation in
* tilelang.
*/
TVM_DLL const Op &initialize_descriptor();
TVM_DLL const Op &initialize_wgmma_descriptor();

/*!
* \brief tilelang intrinsic for initializing a descriptor buffer for
* tcgen05 mma.
*/
TVM_DLL const Op &initialize_tcgen05_descriptor();

/*!
* \brief tilelang intrinsic for setting the start address of a descriptor
Expand Down
72 changes: 4 additions & 68 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,79 +12,13 @@
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "tcgen5_meta.h"

namespace tvm {
namespace tl {

using namespace tir;

struct TCGEN5MMAMeta {
int atom_m, atom_n, atom_k;
};

// Return {is_success, meta}
static inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
return { \
false, TCGEN5MMAMeta { 0, 0, 0 } \
}
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
}
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 16 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 16);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 16);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 16);
FAIL;
} else {
FAIL;
}
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 32 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 32);
FAIL;
} else {
FAIL;
}
}
FAIL;
#undef FAIL
#undef SUCCESS
}

/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
Expand Down Expand Up @@ -199,7 +133,7 @@ GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
TargetIsSm100(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
ICHECK(0) << "Unsupported target for gemm: " << target;
}
}

Expand Down Expand Up @@ -582,6 +516,8 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

if (A.scope() == "local.fragment") {
ICHECK(B.scope() != "local.fragment");
ICHECK(!trans_A)
<< "gemm_rs requires the A operand to be in non-transposed layout.";
op_name = "tl::gemm_rs";
} else if (B.scope() == "local.fragment") {
op_name = "tl::gemm_sr";
Expand Down
72 changes: 67 additions & 5 deletions src/op/gemm_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "../target/utils.h"
#include "tvm/ffi/string.h"
#include "tcgen5_meta.h"

namespace tvm {
namespace tl {
Expand Down Expand Up @@ -76,6 +77,20 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
}
if (args.size() > 16) {
node->mbarptr = args[16];
} else {
node->mbarptr = IntImm(DataType::UInt(32), 0);
}
if (args.size() > 18) {
node->C_coords = Array<PrimExpr>({args[17], args[18]});
} else if (args.size() > 17) {
node->C_coords =
Array<PrimExpr>({args[17], IntImm(DataType::Int(32), 0)});
} else {
node->C_coords = Array<PrimExpr>(
{IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)});
}
data_ = std::move(node);
}

Expand All @@ -92,16 +107,37 @@ TileOperator GemmPyNode::Clone() const {
return GemmPy(op);
}

GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
// Whether this GEMM may be lowered to a TCGEN5 MMA (SM100) instruction.
// Requires: an SM100 target; A in "shared", "shared.dyn", or "shared.tmem"
// scope; B in "shared" or "shared.dyn" scope; C in "shared.tmem" scope; and
// an (M, N, K, dtype) combination that GetTCGEN5MMAMeta reports as supported.
bool GemmPyNode::AllowTCGEN5MMA(Target target) const {
  return TargetIsSm100(target) &&
         ((A.scope() == "shared.dyn" || A.scope() == "shared" ||
           A.scope() == "shared.tmem") &&
          (B.scope() == "shared.dyn" || B.scope() == "shared") &&
          C.scope() == "shared.tmem") &&
         // GetTCGEN5MMAMeta returns {is_success, meta}; only the success
         // flag matters for gating.
         GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
}
Comment on lines +110 to +117
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

AllowTCGEN5MMA gate LGTM; verify TargetIsSm100 logic

Condition looks right. However, TargetIsSm100 in src/target/utils.cc appears to use bitwise '&' instead of '&&', which can mis-gate SM100 detection.

If it shows arch >= 100 & arch <= 110, change to arch >= 100 && arch <= 110 in src/target/utils.cc. This affects AllowTCGEN5MMA().


🏁 Script executed:

#!/bin/bash
# Verify TargetIsSm100 implementation uses logical &&, not bitwise &.
rg -nC2 'bool\s+TargetIsSm100' src | sed -n '1,120p'

Length of output: 489


🏁 Script executed:

#!/bin/bash
sed -n '56,90p' src/target/utils.cc

Length of output: 1002


Replace bitwise & with logical && in TargetIsSm100
In src/target/utils.cc:58 change arch >= 100 & arch <= 110 to use && for correct SM100 detection.

🤖 Prompt for AI Agents
In src/op/gemm_py.cc around lines 110 to 117, the review points out that
TargetIsSm100 uses a bitwise & instead of logical && when checking architecture
range; update TargetIsSm100 in src/target/utils.cc (line ~58) to use `&&` (arch
>= 100 && arch <= 110) so the SM100 detection is correct, then rebuild and run
tests that exercise GemmPyNode::AllowTCGEN5MMA to ensure the change fixes
selection logic.


bool GemmPyNode::AllowWGMMA(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
if (allow_wgmma) {
return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
CheckWGMMA();
}

GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = AllowTCGEN5MMA(target);
bool allow_wgmma = AllowWGMMA(block_size, target);
if (allow_tcgen5mma) {
return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
} else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
TargetIsTuring(target) || TargetIsHopper(target) ||
TargetIsSm100(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
Expand Down Expand Up @@ -289,5 +325,31 @@ TVM_FFI_STATIC_INIT_BLOCK({
});
});

// FFI reflection bindings exposing the TCGEN5 MMA helpers to the Python side.
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  // tl.get_tcgen5_mma_meta: returns [atom_m, atom_n, atom_k] when the
  // (M, N, K, ab_dtype, c_dtype) combination is supported by TCGEN5MMA,
  // or an empty array when it is not (callers test for emptiness).
  refl::GlobalDef().def(
      "tl.get_tcgen5_mma_meta",
      [](int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
        auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype);
        Array<Integer> result;
        if (success) {
          result.push_back(Integer(meta.atom_m));
          result.push_back(Integer(meta.atom_n));
          result.push_back(Integer(meta.atom_k));
        }
        return result;
      });
  // tl.get_tcgen5_instr_desc: builds the 32-bit TCGEN5 instruction
  // descriptor for one MMA atom. The uint32 result is widened to int64
  // before wrapping in Integer so large descriptors are not misread as
  // negative values.
  refl::GlobalDef().def(
      "tl.get_tcgen5_instr_desc",
      [](int atom_m, int atom_n, int atom_k, DataType ab_dtype,
         DataType c_dtype, bool a_is_k_major, bool b_is_k_major,
         int scale_in_a, int scale_in_b) {
        uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype,
                                           c_dtype, a_is_k_major, b_is_k_major,
                                           scale_in_a, scale_in_b);
        return Integer(static_cast<int64_t>(desc));
      });
});

} // namespace tl
} // namespace tvm
12 changes: 11 additions & 1 deletion src/op/gemm_py.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ using namespace tir;
class GemmPyNode : public TileOperatorNode {
public:
bool CheckWGMMA() const;
bool AllowTCGEN5MMA(Target target) const;
bool AllowWGMMA(int block_size, Target target) const;
tir::Buffer A, B, C;
// pointer to the A, B, C
PrimExpr Aptr, Bptr, Cptr;
Expand All @@ -27,6 +29,8 @@ class GemmPyNode : public TileOperatorNode {
int stride_A, stride_B;
int offset_A, offset_B;
PrimExpr clear_accum = const_false();
PrimExpr mbarptr;
Array<PrimExpr> C_coords;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
int kPack = 1;
Expand Down Expand Up @@ -55,6 +59,8 @@ class GemmPyNode : public TileOperatorNode {
.def_ro("offset_A", &GemmPyNode::offset_A)
.def_ro("offset_B", &GemmPyNode::offset_B)
.def_ro("clear_accum", &GemmPyNode::clear_accum)
.def_ro("mbarptr", &GemmPyNode::mbarptr)
.def_ro("C_coords", &GemmPyNode::C_coords)
.def_ro("kPack", &GemmPyNode::kPack)
.def_ro("wg_wait", &GemmPyNode::wg_wait)
.def_ro("policy", &GemmPyNode::policy);
Expand All @@ -71,6 +77,8 @@ class GemmPyNode : public TileOperatorNode {
equal(offset_A, other->offset_B) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
Comment on lines 77 to 79
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Correct SEqualReduce comparisons for offsets.

Currently compares offset_A to other->offset_B twice. This breaks structural equality and can cause caching bugs.

-           equal(offset_A, other->offset_B) &&
-           equal(offset_B, other->offset_B) &&
+           equal(offset_A, other->offset_A) &&
+           equal(offset_B, other->offset_B) &&
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
equal(offset_A, other->offset_B) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(offset_A, other->offset_A) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
🤖 Prompt for AI Agents
In src/op/gemm_py.h around lines 77 to 79, the SEqualReduce comparison
mistakenly compares offset_A to other->offset_B; change the comparisons so
offset_A is compared to other->offset_A and offset_B to other->offset_B (i.e.,
replace the first equal(offset_A, other->offset_B) with equal(offset_A,
other->offset_A) and keep/verify the second is equal(offset_B, other->offset_B))
to restore correct structural equality and avoid caching bugs.

equal(mbarptr, other->mbarptr) &&
equal(C_coords, other->C_coords) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
equal(policy, other->policy);
}
Expand All @@ -92,6 +100,8 @@ class GemmPyNode : public TileOperatorNode {
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(mbarptr);
hash_reduce(C_coords);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);
Expand Down Expand Up @@ -122,4 +132,4 @@ class GemmPy : public TileOperator {
} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_GEMM_PY_H_
#endif // TVM_TL_OP_GEMM_PY_H_
Loading