
Commit c91ace4

NikhilAPatel authored and etaf committed
[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#165036)
Make sure you're on cutlass 4.2.0+.

Test Plan:

Tritonbench (OSS):
`clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy`

Unit Tests (OSS):
`clear; python test/inductor/test_cutedsl_grouped_mm.py`

Differential Revision: D82010227
Pull Request resolved: #165036
Approved by: https://github.com/alexsamardzic, https://github.com/drisspg, https://github.com/mlazos
1 parent eeab794 commit c91ace4
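
For orientation, a minimal sketch of what this kernel path looks like from user code. It mirrors the unit tests added in this commit (test/inductor/test_cutedsl_grouped_mm.py below): the backend/config names and the packed-A, batched-B, cumulative-offsets calling convention are taken from that test file; actually selecting the kernel assumes a Blackwell GPU with the CuTeDSL library (cutlass 4.2.0+) available.

import torch
from torch._inductor import config

G, K, N = 8, 64, 128
dtype = torch.bfloat16
A = torch.randn(G * 128, K, device="cuda", dtype=dtype)   # rows of all groups packed along M
B = torch.randn(G, K, N, device="cuda", dtype=dtype)      # one weight matrix per group
# Cumulative row offsets: strictly increasing, no leading zero (as in the tests).
offs = torch.arange(128, G * 128 + 1, 128, device="cuda", dtype=torch.int32)

def grouped_mm(a, b, offs):
    return torch._grouped_mm(a, b, offs=offs)

with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTEDSL",
        "test_configs.autotune_choice_name_regex": "cutedsl",
    }
):
    out = torch.compile(grouped_mm, dynamic=False)(A, B, offs)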

File tree

10 files changed: +807 -33 lines changed

.ci/pytorch/test.sh

Lines changed: 1 addition & 1 deletion
@@ -337,7 +337,7 @@ test_python() {
 
 test_python_smoke() {
   # Smoke tests for H100/B200
-  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
+  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
   assert_git_not_dirty
 }

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -127,6 +127,7 @@ torch/test/
 torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
 torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
 torch/version.py
+torch/_inductor/kernel/vendored_templates/*
 minifier_launcher.py
 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

setup.py

Lines changed: 34 additions & 0 deletions
@@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None:
         raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
 
 
+def mirror_inductor_external_kernels() -> None:
+    """
+    Copy external kernels into Inductor so they are importable.
+    """
+    paths = [
+        (
+            CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
+            CWD
+            / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
+        ),
+    ]
+    for new_path, orig_path in paths:
+        # Create the dirs involved in new_path if they don't exist
+        if not new_path.exists():
+            new_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Copy the files from the orig location to the new location
+        if orig_path.is_file():
+            shutil.copyfile(orig_path, new_path)
+            continue
+        if orig_path.is_dir():
+            if new_path.exists():
+                # copytree fails if the tree exists already, so remove it.
+                shutil.rmtree(new_path)
+            shutil.copytree(orig_path, new_path)
+            continue
+        raise RuntimeError(
+            "Check the file paths in `mirror_inductor_external_kernels()`"
+        )
+
+
 # ATTENTION: THIS IS AI SLOP
 def extract_variant_from_version(version: str) -> str:
     """Extract variant from version string, defaulting to 'cpu'."""
@@ -1616,6 +1647,8 @@ def main() -> None:
     if RUN_BUILD_DEPS:
         build_deps()
 
+    mirror_inductor_external_kernels()
+
     (
         ext_modules,
         cmdclass,
@@ -1649,6 +1682,7 @@ def main() -> None:
         "_inductor/codegen/aoti_runtime/*.cpp",
         "_inductor/script.ld",
         "_inductor/kernel/flex/templates/*.jinja",
+        "_inductor/kernel/templates/*.jinja",
         "_export/serde/*.yaml",
         "_export/serde/*.thrift",
         "share/cmake/ATen/*.cmake",
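
The vendored copy created by mirror_inductor_external_kernels() is generated at build time and, per the .gitignore change above, never checked in. A quick sanity check from a source checkout, as a sketch (the path is simply the destination declared above):

from pathlib import Path

vendored = Path("torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py")
print("vendored CuTeDSL grouped GEMM template present:", vendored.is_file())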
test/inductor/test_cutedsl_grouped_mm.py

Lines changed: 154 additions & 0 deletions (new file)
# Owner(s): ["module: inductor"]


import unittest

import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)


@unittest.skipIf(
    not (ensure_cute_available() and is_datacenter_blackwell_arch()),
    "CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
    def _get_inputs(
        self,
        group_size: int,
        M_hint: int,
        K: int,
        N: int,
        device: str,
        dtype: torch.dtype,
        alignment: int = 16,
    ) -> tuple[Tensor, Tensor, Tensor]:
        # --- Random, tile-aligned M sizes ---
        M_sizes = (
            torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
            * alignment
        )

        M_total = torch.sum(M_sizes).item()

        # --- Construct input tensors ---
        A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
        B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01

        # --- Build offsets (no leading zero, strictly increasing) ---
        offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)

        return (A, B, offsets)

    @parametrize("group_size", (2, 8))
    @parametrize("M_hint", (256, 1024))
    @parametrize("K", (64, 128))
    @parametrize("N", (128, 256))
    def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
        device = "cuda"
        dtype = torch.bfloat16

        A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)

        def grouped_gemm_fn(A_packed, B_batched, offs):
            return torch._grouped_mm(A_packed, B_batched, offs=offs)

        # Eager execution
        c_eager = grouped_gemm_fn(A, B, offsets)

        # Test with Cute backend
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "CUTEDSL",
                "test_configs.autotune_choice_name_regex": "cutedsl",
                "autotune_fallback_to_aten": False,
            }
        ):
            grouped_gemm_compiled = torch.compile(
                grouped_gemm_fn, backend="inductor", dynamic=False
            )
            c_compiled = grouped_gemm_compiled(A, B, offsets)

        self.assertEqual(c_eager.dtype, dtype)
        self.assertEqual(c_compiled.dtype, dtype)
        torch.testing.assert_close(c_eager, c_compiled)

    @parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
    @parametrize("layout_B", ("contiguous", "broadcasted"))
    def test_grouped_gemm_assorted_layouts(
        self,
        layout_A: str,
        layout_B: str,
    ):
        device = "cuda"
        dtype = torch.bfloat16

        G, K, N = 8, 64, 128
        M_sizes = [128] * G
        sum_M = sum(M_sizes)
        offsets = torch.tensor(
            [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
        )

        A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
        A = A_base

        if layout_A == "offset":
            # allocate bigger buffer than needed, use nonzero storage offset
            storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
            offset = 128  # skip first 128 elements
            A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
        elif layout_A == "padded":
            # simulate row pitch > K (row_stride = K + pad)
            row_pitch = K + 8
            storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
            A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
        elif layout_A == "view":
            A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
            A = A_storage.view(sum_M, K)
            assert A._base is not None
            assert A.shape == (sum_M, K)

        B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01

        if layout_B == "broadcasted":
            # Broadcast B across groups (zero stride along G)
            B = B[0].expand(G, K, N)
            assert B.stride(0) == 0

        def grouped_gemm_fn(A_packed, B_batched, offs):
            return torch._grouped_mm(A_packed, B_batched, offs=offs)

        # --- eager ---
        c_eager = grouped_gemm_fn(A, B, offsets)

        # --- compiled (CUTE backend) ---
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "CUTEDSL",
                "test_configs.autotune_choice_name_regex": "cutedsl",
                "autotune_fallback_to_aten": False,
            }
        ):
            grouped_gemm_compiled = torch.compile(
                grouped_gemm_fn, backend="inductor", dynamic=False
            )
            c_compiled = grouped_gemm_compiled(A, B, offsets)

        self.assertEqual(c_eager.dtype, dtype)
        self.assertEqual(c_compiled.dtype, dtype)
        torch.testing.assert_close(c_eager, c_compiled)


if __name__ == "__main__":
    run_tests()

torch/_inductor/config.py

Lines changed: 4 additions & 0 deletions
@@ -546,6 +546,10 @@ def prologue_fusion_enabled() -> bool:
     "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
 ).upper()  # type: ignore[assignment]
 
+cutedsl_enable_autotuning: bool = (
+    os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
+)
+
 # DEPRECATED. This setting is ignored.
 autotune_fallback_to_aten = False
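
The new cutedsl_enable_autotuning flag is evaluated once, at import time of torch._inductor.config, so the environment variable has to be set before that module is first imported; afterwards the attribute itself can be toggled. A small sketch (which code paths consult the flag is not visible in this hunk):

import os

# Must happen before torch._inductor.config is first imported.
os.environ["CUTEDSL_ENABLE_AUTOTUNING"] = "1"

import torch._inductor.config as inductor_config
assert inductor_config.cutedsl_enable_autotuning

# Or toggle it at runtime for a scoped region via the config patcher.
with inductor_config.patch({"cutedsl_enable_autotuning": True}):
    pass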

torch/_inductor/kernel/mm_common.py

Lines changed: 7 additions & 0 deletions
@@ -1,6 +1,8 @@
 # mypy: allow-untyped-defs
 import logging
 from collections.abc import Sequence
+from functools import partial
+from pathlib import Path
 from typing import Any
 
 import torch
@@ -12,6 +14,7 @@
 from .. import config
 from ..codegen.wrapper import PythonWrapperCodegen
 from ..ir import _IntLike, Layout, TensorBox
+from ..utils import load_template
 
 
 log = logging.getLogger(__name__)
@@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
         return False
 
     return True
+
+
+_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
+load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)
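
load_kernel_template resolves Jinja sources from the new torch/_inductor/kernel/templates/ directory that setup.py now ships as package data. load_template itself lives in torch/_inductor/utils and is not part of this diff, so the following is only a sketch of an equivalent helper with the signature implied by the partial(...) above; the template name at the end is hypothetical.

from functools import partial
from pathlib import Path

def load_template(name: str, template_dir: Path) -> str:
    """Assumed behavior: return the Jinja template source stored under template_dir."""
    return (template_dir / f"{name}.jinja").read_text()

_KERNEL_TEMPLATE_DIR = Path("torch/_inductor/kernel") / "templates"
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)

# Hypothetical usage (template name not confirmed by this diff):
# grouped_gemm_source = load_kernel_template("cutedsl_mm_grouped")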
