[TIR] Add cooperative_tensor builtins and metal.cooperative_tensor storage scope #34
oraluben wants to merge 4 commits into tile-ai:tilelang_main
…orage scope

Add TIR builtins for Metal cooperative_tensor operations (MetalPerformancePrimitives):
- cooperative_tensor_fill: fill a cooperative_tensor with a value
- cooperative_tensor_load: load from device/threadgroup memory
- cooperative_tensor_store: store to device/threadgroup memory
- cooperative_tensor_multiply_accumulate: matrix multiply-accumulate via matmul2d

Add metal.cooperative_tensor storage scope (StorageRank::kMetalCooperativeTensor) for buffers backed by MPP cooperative_tensor registers, analogous to the existing metal.simdgroup scope but targeting the Metal 4 tensor operations API.

These primitives enable code generation for MetalPerformancePrimitives matmul2d, which routes to NAX tensor cores on Apple M5 and falls back to simdgroup matrix instructions on M1-M4.
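The four operations named in this commit message can be read as tile-level fill, load, store, and multiply-accumulate. Below is a rough pure-Python model of that arithmetic only. The function names (`ref_fill`, `ref_load`, `ref_store`, `ref_multiply_accumulate`) and their parameters are hypothetical and are not the argument lists of the TIR builtins, which this page does not show. The sketch also ignores how a cooperative_tensor is distributed across threads.

```python
# Hypothetical reference model; NOT the TVM builtin signatures.
# A "tile" here is just a list of rows, and "memory" is a flat list.


def ref_fill(rows, cols, value):
    """Model of fill: a rows x cols tile with every element set to value."""
    return [[value] * cols for _ in range(rows)]


def ref_load(memory, offset, stride, rows, cols):
    """Model of load: read a row-major tile out of flat memory."""
    return [
        [memory[offset + r * stride + c] for c in range(cols)]
        for r in range(rows)
    ]


def ref_store(tile, memory, offset, stride):
    """Model of store: write a tile back into flat memory, row-major."""
    for r, row in enumerate(tile):
        for c, value in enumerate(row):
            memory[offset + r * stride + c] = value


def ref_multiply_accumulate(acc, a, b):
    """Model of multiply-accumulate: acc += a @ b, updated in place."""
    inner = len(b)
    for i in range(len(a)):
        for j in range(len(b[0])):
            acc[i][j] += sum(a[i][p] * b[p][j] for p in range(inner))
    return acc


# Usage: a 2x2 GEMM tile computed through the four modelled steps.
a_mem = [1, 2, 3, 4]          # A = [[1, 2], [3, 4]]
b_mem = [5, 6, 7, 8]          # B = [[5, 6], [7, 8]]
out = [0.0] * 4

acc = ref_fill(2, 2, 0.0)
a_tile = ref_load(a_mem, 0, 2, 2, 2)
b_tile = ref_load(b_mem, 0, 2, 2, 2)
ref_multiply_accumulate(acc, a_tile, b_tile)
ref_store(acc, out, 0, 2)
print(out)
```

The fill, load, multiply-accumulate, store order in the usage lines is the common shape of a tiled GEMM inner loop; whether the downstream codegen emits exactly this sequence is not stated here.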
MTLLanguageVersion4_0 is only available in the macOS 26+ SDK. Fall back to 3_1 (macOS 14+) or 3_0 for older SDKs to fix CI builds.
Downstream PR: tile-ai/tilelang#1869 ([Metal] Add Metal GEMM support with cooperative_tensor MMA)

This PR adds the TIR builtins (

Regarding testing: the TIR builtins are exercised end-to-end through tilelang's Metal GEMM tests (18 tests covering correctness across various tile configurations). We're not sure what additional TVM-level tests would be appropriate for these builtins. If there's a specific test pattern you'd like us to add (e.g. TIR script roundtrip tests, or builtin registration checks), we're happy to do so.

Upstream PR: apache#19423 (same change targeting apache/tvm main)
The cython backend uses strict __slots__ on ObjectBase, which prevents setting _inst on PyFunctionPass without declaring it. This works with the ctypes backend (whose objects have a __dict__) but fails on CI with cython.
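The failure mode in this commit message can be reproduced in plain Python. A minimal sketch with stand-in classes: `StrictBase`, `UndeclaredPass`, and `DeclaredPass` are hypothetical names, not TVM's classes. It assumes the strictness amounts to every class in the hierarchy declaring `__slots__`, so instances have no `__dict__`. In plain Python a subclass that omits `__slots__` would get a `__dict__` back, so the sketch declares empty slots explicitly to model the strict case.

```python
class StrictBase:
    """Stand-in for a base class whose instances have no __dict__."""

    __slots__ = ()


class UndeclaredPass(StrictBase):
    """Hypothetical wrapper that sets _inst without declaring a slot for it."""

    __slots__ = ()

    def __init__(self, inst):
        self._inst = inst  # raises AttributeError: no slot and no __dict__


class DeclaredPass(StrictBase):
    """Same wrapper, with _inst declared so the assignment has a slot."""

    __slots__ = ("_inst",)

    def __init__(self, inst):
        self._inst = inst


def try_construct(cls, inst):
    """Return (instance, None) on success or (None, error) on AttributeError."""
    try:
        return cls(inst), None
    except AttributeError as err:
        return None, err


failed, error = try_construct(UndeclaredPass, object())
ok, no_error = try_construct(DeclaredPass, object())
print("undeclared:", type(error).__name__)
print("declared:", "ok" if no_error is None else "failed")
```

Declaring the attribute in `__slots__` keeps instances free of a per-object `__dict__`, which is the behaviour the strict base is presumably trying to preserve.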
This reverts commit d7158e9.
Summary
- cooperative_tensor_fill
- cooperative_tensor_load
- cooperative_tensor_store
- cooperative_tensor_multiply_accumulate
- metal.cooperative_tensor storage scope (StorageRank::kMetalCooperativeTensor)

Motivation
MetalPerformancePrimitives (MPP) provides matmul2d with cooperative_tensor operands that route to NAX tensor cores on Apple M5 and fall back to simdgroup matrix on M1-M4. These TIR builtins enable Metal backend codegen to emit MPP calls, analogous to the existing simdgroup_* builtins for the older Metal simdgroup matrix API.

Changes
- include/tvm/tir/builtin.h: 4 new Op declarations
- src/tir/op/builtin.cc: 4 new Op registrations
- python/tvm/tir/op.py: Python wrapper functions
- python/tvm/script/ir_builder/tir/ir.py: Script parser exports + __all__
- src/runtime/thread_storage_scope.h: kMetalCooperativeTensor StorageRank + scope string parsing

Companion tilelang PR uses these builtins for Metal GEMM codegen targeting MPP matmul2d.
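The "scope string parsing" item in the change list maps a scope name such as metal.cooperative_tensor to a storage rank. The real change is C++ in src/runtime/thread_storage_scope.h. The Python below is only an illustrative sketch of prefix-based parsing, with a hypothetical `StorageRank` enum, member values, and `parse_scope` helper; it also assumes an empty scope string means global. Only the scope names and the rank's purpose come from this PR.

```python
from enum import Enum, auto


class StorageRank(Enum):
    """Hypothetical subset of ranks; values do not mirror the C++ enum."""

    GLOBAL = auto()
    SHARED = auto()
    METAL_SIMDGROUP = auto()
    METAL_COOPERATIVE_TENSOR = auto()


# Prefix table for this sketch: each scope string starts with a known name,
# optionally followed by a tag suffix (for example "shared.dyn").
_PREFIXES = [
    ("metal.cooperative_tensor", StorageRank.METAL_COOPERATIVE_TENSOR),
    ("metal.simdgroup", StorageRank.METAL_SIMDGROUP),
    ("shared", StorageRank.SHARED),
    ("global", StorageRank.GLOBAL),
]


def parse_scope(scope):
    """Return (rank, tag) for a scope string, or raise ValueError."""
    if scope == "":
        # Assumption: an empty scope string is treated as global memory.
        return StorageRank.GLOBAL, ""
    for prefix, rank in _PREFIXES:
        if scope.startswith(prefix):
            return rank, scope[len(prefix):]
    raise ValueError(f"unknown storage scope: {scope!r}")


print(parse_scope("metal.cooperative_tensor"))
print(parse_scope("shared.dyn"))
```

Keeping the new scope as its own rank, rather than as a tag on metal.simdgroup, lets later passes distinguish the two kinds of register-backed buffers by rank alone.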