-
Couldn't load subscription status.
- Fork 280
[Language] Expose T.warpgroup_fence_operand for nvcc code motion
#986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
add8cca
8783cd9
3a51846
b7fa37a
bac37ae
4a96032
a60701b
c8eec62
9fd79eb
73fa0af
3e90be7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
|
|
||
| #include "../target/utils.h" | ||
| #include "tvm/ffi/string.h" | ||
| #include "tcgen5_meta.h" | ||
|
|
||
| namespace tvm { | ||
| namespace tl { | ||
|
|
@@ -76,6 +77,20 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { | |
| if (args.size() > 15) { | ||
| node->wg_wait = args[15].as<IntImm>().value()->value; | ||
| } | ||
| if (args.size() > 16) { | ||
| node->mbarptr = args[16]; | ||
| } else { | ||
| node->mbarptr = IntImm(DataType::UInt(32), 0); | ||
| } | ||
| if (args.size() > 18) { | ||
| node->C_coords = Array<PrimExpr>({args[17], args[18]}); | ||
| } else if (args.size() > 17) { | ||
| node->C_coords = | ||
| Array<PrimExpr>({args[17], IntImm(DataType::Int(32), 0)}); | ||
| } else { | ||
| node->C_coords = Array<PrimExpr>( | ||
| {IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)}); | ||
| } | ||
| data_ = std::move(node); | ||
| } | ||
|
|
||
|
|
@@ -92,16 +107,37 @@ TileOperator GemmPyNode::Clone() const { | |
| return GemmPy(op); | ||
| } | ||
|
|
||
| GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const { | ||
| bool GemmPyNode::AllowTCGEN5MMA(Target target) const { | ||
| return TargetIsSm100(target) && | ||
| ((A.scope() == "shared.dyn" || A.scope() == "shared" || | ||
| A.scope() == "shared.tmem") && | ||
| (B.scope() == "shared.dyn" || B.scope() == "shared") && | ||
| C.scope() == "shared.tmem") && | ||
| GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first; | ||
| } | ||
|
Comment on lines
+110
to
+117
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainAllowTCGEN5MMA gate LGTM; verify TargetIsSm100 logic Condition looks right. However, TargetIsSm100 in src/target/utils.cc appears to use bitwise '&' instead of '&&', which can mis-gate SM100 detection. If it shows 🏁 Script executed: #!/bin/bash
# Verify TargetIsSm100 implementation uses logical &&, not bitwise &.
rg -nC2 'bool\s+TargetIsSm100' src | sed -n '1,120p'Length of output: 489 🏁 Script executed: #!/bin/bash
sed -n '56,90p' src/target/utils.ccLength of output: 1002 Replace bitwise & with logical && in TargetIsSm100 🤖 Prompt for AI Agents |
||
|
|
||
| bool GemmPyNode::AllowWGMMA(int block_size, Target target) const { | ||
| tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); | ||
|
|
||
| int warp_size = TargetGetWarpSize(target); | ||
| int num_warps = block_size / warp_size; | ||
| bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) && | ||
| (num_warps % 4 == 0) && CheckWGMMA(); | ||
| if (allow_wgmma) { | ||
| return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) && | ||
| TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) && | ||
| CheckWGMMA(); | ||
| } | ||
|
|
||
| GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const { | ||
| bool allow_tcgen5mma = AllowTCGEN5MMA(target); | ||
| bool allow_wgmma = AllowWGMMA(block_size, target); | ||
| if (allow_tcgen5mma) { | ||
| return GemmInst::kTCGEN5MMA; | ||
| } else if (allow_wgmma) { | ||
| return GemmInst::kWGMMA; | ||
| } else if (TargetIsCDNA(target)) { | ||
| return GemmInst::kMFMA; | ||
| } else if (TargetIsCuda(target)) { | ||
| } else if (TargetIsVolta(target) || TargetIsAmpere(target) || | ||
| TargetIsTuring(target) || TargetIsHopper(target) || | ||
| TargetIsSm100(target)) { | ||
| return GemmInst::kMMA; | ||
| } else { | ||
| ICHECK(0) << "Unsupported target for gemm: " << target->str(); | ||
|
|
@@ -289,5 +325,31 @@ TVM_FFI_STATIC_INIT_BLOCK({ | |
| }); | ||
| }); | ||
|
|
||
| TVM_FFI_STATIC_INIT_BLOCK({ | ||
| namespace refl = tvm::ffi::reflection; | ||
| refl::GlobalDef().def( | ||
| "tl.get_tcgen5_mma_meta", | ||
| [](int M, int N, int K, DataType ab_dtype, DataType c_dtype) { | ||
| auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype); | ||
| Array<Integer> result; | ||
| if (success) { | ||
| result.push_back(Integer(meta.atom_m)); | ||
| result.push_back(Integer(meta.atom_n)); | ||
| result.push_back(Integer(meta.atom_k)); | ||
| } | ||
| return result; | ||
| }); | ||
| refl::GlobalDef().def( | ||
| "tl.get_tcgen5_instr_desc", | ||
| [](int atom_m, int atom_n, int atom_k, DataType ab_dtype, | ||
| DataType c_dtype, bool a_is_k_major, bool b_is_k_major, | ||
| int scale_in_a, int scale_in_b) { | ||
| uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype, | ||
| c_dtype, a_is_k_major, b_is_k_major, | ||
| scale_in_a, scale_in_b); | ||
| return Integer(static_cast<int64_t>(desc)); | ||
| }); | ||
| }); | ||
|
|
||
| } // namespace tl | ||
| } // namespace tvm | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -19,6 +19,8 @@ using namespace tir; | |||||||||||||
| class GemmPyNode : public TileOperatorNode { | ||||||||||||||
| public: | ||||||||||||||
| bool CheckWGMMA() const; | ||||||||||||||
| bool AllowTCGEN5MMA(Target target) const; | ||||||||||||||
| bool AllowWGMMA(int block_size, Target target) const; | ||||||||||||||
| tir::Buffer A, B, C; | ||||||||||||||
| // pointer to the A, B, C | ||||||||||||||
| PrimExpr Aptr, Bptr, Cptr; | ||||||||||||||
|
|
@@ -27,6 +29,8 @@ class GemmPyNode : public TileOperatorNode { | |||||||||||||
| int stride_A, stride_B; | ||||||||||||||
| int offset_A, offset_B; | ||||||||||||||
| PrimExpr clear_accum = const_false(); | ||||||||||||||
| PrimExpr mbarptr; | ||||||||||||||
| Array<PrimExpr> C_coords; | ||||||||||||||
| // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack | ||||||||||||||
| // only will be enabled under cdna mfma instructions | ||||||||||||||
| int kPack = 1; | ||||||||||||||
|
|
@@ -55,6 +59,8 @@ class GemmPyNode : public TileOperatorNode { | |||||||||||||
| .def_ro("offset_A", &GemmPyNode::offset_A) | ||||||||||||||
| .def_ro("offset_B", &GemmPyNode::offset_B) | ||||||||||||||
| .def_ro("clear_accum", &GemmPyNode::clear_accum) | ||||||||||||||
| .def_ro("mbarptr", &GemmPyNode::mbarptr) | ||||||||||||||
| .def_ro("C_coords", &GemmPyNode::C_coords) | ||||||||||||||
| .def_ro("kPack", &GemmPyNode::kPack) | ||||||||||||||
| .def_ro("wg_wait", &GemmPyNode::wg_wait) | ||||||||||||||
| .def_ro("policy", &GemmPyNode::policy); | ||||||||||||||
|
|
@@ -71,6 +77,8 @@ class GemmPyNode : public TileOperatorNode { | |||||||||||||
| equal(offset_A, other->offset_B) && | ||||||||||||||
| equal(offset_B, other->offset_B) && | ||||||||||||||
| equal(clear_accum, other->clear_accum) && | ||||||||||||||
|
Comment on lines
77
to
79
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct SEqualReduce comparisons for offsets. Currently compares - equal(offset_A, other->offset_B) &&
- equal(offset_B, other->offset_B) &&
+ equal(offset_A, other->offset_A) &&
+ equal(offset_B, other->offset_B) &&📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||
| equal(mbarptr, other->mbarptr) && | ||||||||||||||
| equal(C_coords, other->C_coords) && | ||||||||||||||
| equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) && | ||||||||||||||
| equal(policy, other->policy); | ||||||||||||||
| } | ||||||||||||||
|
|
@@ -92,6 +100,8 @@ class GemmPyNode : public TileOperatorNode { | |||||||||||||
| hash_reduce(offset_A); | ||||||||||||||
| hash_reduce(offset_B); | ||||||||||||||
| hash_reduce(clear_accum); | ||||||||||||||
| hash_reduce(mbarptr); | ||||||||||||||
| hash_reduce(C_coords); | ||||||||||||||
| hash_reduce(kPack); | ||||||||||||||
| hash_reduce(wg_wait); | ||||||||||||||
| hash_reduce(policy); | ||||||||||||||
|
|
@@ -122,4 +132,4 @@ class GemmPy : public TileOperator { | |||||||||||||
| } // namespace tl | ||||||||||||||
| } // namespace tvm | ||||||||||||||
|
|
||||||||||||||
| #endif // TVM_TL_OP_GEMM_PY_H_ | ||||||||||||||
| #endif // TVM_TL_OP_GEMM_PY_H_ | ||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix arity:
ptx_tcgen05_mma_ssrequires 14 inputs (not 13).Python/TIR wrapper passes 14 args; registering 13 will error at call time. Update set_num_inputs to 14.
Reference: tilelang/language/tir/op.py expects 14 args for tl.ptx_tcgen05_mma_ss. Based on provided snippets.
📝 Committable suggestion
🤖 Prompt for AI Agents