From 738c2908b3d821dc623e0552570ad35fb23372e0 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 17 Sep 2025 19:25:07 -0700 Subject: [PATCH] [Example] grouped_gemm kernel example and tritonbench integration stack-info: PR: https://github.com/pytorch/helion/pull/620, branch: yf225/stack/58 --- benchmarks/run.py | 15 ++ examples/grouped_gemm.py | 313 +++++++++++++++++++++++++++ helion/_compiler/type_propagation.py | 25 ++- helion/language/loops.py | 38 +++- test/test_examples.expected | 217 +++++++++++++++++++ test/test_examples.py | 72 ++++++ test/test_type_propagation.expected | 109 +++++++++- test/test_type_propagation.py | 39 ++++ 8 files changed, 806 insertions(+), 22 deletions(-) create mode 100644 examples/grouped_gemm.py diff --git a/benchmarks/run.py b/benchmarks/run.py index c0b763ab0..05589f2c2 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -180,6 +180,13 @@ class RunResult: "examples.jagged_softmax", "jagged_softmax_tritonbench", ), + "grouped_gemm": ( + "tritonbench.operators.grouped_gemm.operator", + [ + ("examples.grouped_gemm", "grouped_gemm_jagged_tritonbench"), + ("examples.grouped_gemm", "grouped_gemm_jagged_persistent_tritonbench"), + ], + ), # Multiple kernel variants: "gemm": ( "tritonbench.operators.gemm.operator", @@ -306,6 +313,14 @@ class RunResult: "helion_int4_gemm_tritonbench-speedup": "helion_speedup", "helion_int4_gemm_tritonbench-accuracy": "helion_accuracy", }, + "grouped_gemm": { + "triton-speedup": "triton_speedup", + "triton-accuracy": "triton_accuracy", + "pt2_triton_grouped_mm-speedup": "torch_compile_speedup", + "pt2_triton_grouped_mm-accuracy": "torch_compile_accuracy", + "helion_grouped_gemm_jagged_persistent_tritonbench-speedup": "helion_speedup", + "helion_grouped_gemm_jagged_persistent_tritonbench-accuracy": "helion_accuracy", + }, } diff --git a/examples/grouped_gemm.py b/examples/grouped_gemm.py new file mode 100644 index 000000000..d89623844 --- /dev/null +++ b/examples/grouped_gemm.py @@ -0,0 +1,313 @@ +""" +Grouped GEMM Example +==================== + +This example demonstrates grouped matrix multiplication (GEMM) where multiple +input matrices ``A_i`` (with potentially different numbers of rows ``M_i``) +are multiplied against a single shared weight matrix ``B``. The results are +concatenated in the original group order. + +Key ideas used in this implementation: +- Pack all groups' rows into one contiguous tensor ``A_packed`` with shape + ``[sum(M_i), K]``. This improves memory locality and simplifies indexing. +- Represent group boundaries with ``group_offsets`` (size ``G+1``), so that + rows for group ``g`` live in ``A_packed[group_offsets[g]:group_offsets[g+1]]``. +- Use data-dependent tiling over the concatenated row dimension to efficiently + support jagged group sizes (different ``M_i`` per group) without padding. + +Two kernels are provided: +1) ``grouped_gemm_jagged`` - a simple kernel that iterates groups and tiles + dynamically. +2) ``grouped_gemm_jagged_persistent`` - a persistent kernel with dynamic tile + assignment for better load balancing across streaming multiprocessors (SMs). 
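# A minimal sketch of the packing layout described above (plain PyTorch, no
# Helion required), assuming three small groups with row counts M = [2, 3, 1]:
import torch

M_sizes = [2, 3, 1]
K, N = 4, 5
group_A = [torch.randn(m, K) for m in M_sizes]
B = torch.randn(K, N)

A_packed = torch.cat(group_A, dim=0)          # shape [6, 4] == [sum(M_i), K]
group_offsets = torch.tensor([0, 2, 5, 6])    # length G+1, cumulative row offsets

# Rows of group g live in A_packed[group_offsets[g]:group_offsets[g+1]]
for g, a in enumerate(group_A):
    start, end = int(group_offsets[g]), int(group_offsets[g + 1])
    assert torch.equal(A_packed[start:end], a)

# The grouped GEMM computed by the kernels in this example is equivalent to
# concatenating the per-group matmuls:
expected = torch.cat([a @ B for a in group_A], dim=0)   # shape [6, 5] == [sum(M_i), N]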
+""" + +# %% +# Imports and Dependencies +# ------------------------ +from __future__ import annotations + +from typing import Callable + +import torch + +import helion +import helion.language as hl + + +# %% +# Grouped GEMM Kernel - Basic Implementation +# ------------------------------------------- +@helion.kernel(static_shapes=False) +def grouped_gemm_jagged( + A_packed: torch.Tensor, # [total_M, K], where total_M == sum(M_i) + B: torch.Tensor, # [K, N] shared across all groups + group_offsets: torch.Tensor, # [G+1], int32/int64, row offsets into A_packed +) -> torch.Tensor: # [total_M, N] concatenated outputs for all groups + """ + Perform grouped GEMM on jagged inputs using row offsets. + + Args: + A_packed: Row-wise concatenation of per-group inputs ``A_i``, + shape ``[sum(M_i), K]``. + B: Shared weight matrix, shape ``[K, N]``. + group_offsets: Row offsets delimiting each group within ``A_packed``, + shape ``[G+1]``. For group ``g``: rows are + ``start = group_offsets[g]`` to ``end = group_offsets[g+1]``. + + Returns: + Output tensor of shape ``[sum(M_i), N]`` equal to + ``torch.cat([A_i @ B for i in groups], dim=0)``. + """ + total_M, K = A_packed.shape + K2, N = B.shape + assert K == K2, "K dimension mismatch between A_packed and B" + + out = torch.empty( + total_M, + N, + dtype=torch.promote_types(A_packed.dtype, B.dtype), + device=A_packed.device, + ) + + G = group_offsets.size(0) - 1 + + # Process each group independently, tiling over its specific M_g dimension + for g in hl.grid(G): + start = group_offsets[g] + end = group_offsets[g + 1] + M_g = end - start + if M_g != 0: + # Create 2D tiling pattern over output dimensions (M_g x N) for current group + for tile_m, tile_n in hl.tile([M_g, N]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # K-reduction loop: multiply tiles along K dimension + for tile_k in hl.tile(K): + a_blk = A_packed[start + tile_m.index, tile_k] + b_blk = B[tile_k, tile_n] + # Perform fused multiply-add with FP32 accumulation for numerical stability + acc = torch.addmm(acc, a_blk, b_blk) + # Convert accumulator to output dtype and store result + out[start + tile_m.index, tile_n] = acc.to(out.dtype) + + return out + + +# %% +# Grouped GEMM Kernel - Persistent Implementation +# ------------------------------------------------ +@helion.kernel(static_shapes=False) +def grouped_gemm_jagged_persistent( + A_packed: torch.Tensor, # [total_M, K] + B: torch.Tensor, # [K, N] + group_offsets: torch.Tensor, # [G+1], row offsets into A_packed +) -> torch.Tensor: + """ + Persistent grouped GEMM with dynamic tile metadata computation. + + This variant computes tile assignments dynamically in the kernel, + similar to TritonBench's WS variant. + + Args: + A_packed: Packed A, concatenated by rows across groups, ``[sum(M_i), K]``. + B: Shared weight matrix, ``[K, N]``. + group_offsets: Row offsets delimiting each group within ``A_packed``. + + Returns: + Output tensor of shape ``[sum(M_i), N]``. 
+ """ + # Set worker count to match GPU streaming multiprocessor count + device = A_packed.device + num_workers = torch.cuda.get_device_properties(device).multi_processor_count # type: ignore[arg-type] + + # Define tunable block sizes for M, N dimensions (auto-tuned at runtime) + BLOCK_M = hl.register_block_size(32, 128) + BLOCK_N = hl.register_block_size(32, 128) + total_M, K = A_packed.shape + K2, N = B.shape + assert K == K2 + + out = torch.zeros( + total_M, + N, + dtype=torch.promote_types(A_packed.dtype, B.dtype), + device=A_packed.device, + ) + + G = group_offsets.size(0) - 1 + + for worker_id in hl.grid(num_workers): + # Persistent thread pattern: each worker processes tiles across all groups + # using strided/interleaved assignment for load balancing. + # (i.e. each worker takes every num_workers-th tile. e.g., worker 0 takes tiles 0, N, 2N, ...) + for g in hl.grid(G): + group_start = group_offsets[g] + group_end = group_offsets[g + 1] + m_size = group_end - group_start + + if m_size > 0: + # Compute tile grid dimensions for current group + num_m_tiles = (m_size + BLOCK_M - 1) // BLOCK_M + # Calculate number of N tiles (shared across all groups) + num_n_tiles = (N + BLOCK_N - 1) // BLOCK_N + num_group_tiles = num_m_tiles * num_n_tiles + + # Distribute tiles among workers using strided access pattern + for local_tile in hl.grid(num_group_tiles): + tile_in_group = local_tile * num_workers + worker_id + if tile_in_group < num_group_tiles: + # Convert linear tile index to 2D (M, N) tile coordinates + m_tile_idx = tile_in_group % num_m_tiles # pyright: ignore[reportOperatorIssue] + n_tile_idx = tile_in_group // num_m_tiles + + # Compute global memory indices for current tile + base_row = group_start + m_tile_idx * BLOCK_M + base_col = n_tile_idx * BLOCK_N # pyright: ignore[reportOperatorIssue] + + # Generate row and column index ranges for tile access + row_idx = base_row + hl.arange(BLOCK_M) + col_idx = base_col + hl.arange(BLOCK_N) + + # Apply boundary masks to handle partial tiles at edges + rows_valid = row_idx < group_end + cols_valid = col_idx < N + + # Initialize FP32 accumulator for numerical precision + acc = hl.zeros([BLOCK_M, BLOCK_N], dtype=torch.float32) + + # Iterate over K dimension in blocks for matrix multiplication + for k_tile in hl.tile(K): + k_idx = k_tile.index + + # Load tiles from A_packed and B with boundary checking + a_blk = hl.load( + A_packed, + [row_idx, k_idx], + extra_mask=rows_valid[:, None], + ) + b_blk = hl.load( + B, + [k_idx, col_idx], + extra_mask=cols_valid[None, :], + ) + + # Perform tile-level matrix multiplication and accumulate + acc = torch.addmm(acc, a_blk, b_blk) + + # Write accumulated result to output with boundary masking + valid_2d = rows_valid[:, None] & cols_valid[None, :] + hl.store( + out, + [row_idx, col_idx], + acc.to(out.dtype), + extra_mask=valid_2d, + ) + + return out + + +# %% +# Data Preparation Utilities +# -------------------------- +def _pack_group_inputs( + group_A: list[torch.Tensor], group_B: list[torch.Tensor] +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Build ``A_packed``, shared ``B``, and ``group_offsets`` from grouped inputs. + + Expectations: + - All ``A_i`` share the same ``K`` and dtype/device. + - All groups share the same ``B`` (as produced by TritonBench inputs). + + Returns ``(A_packed, B_shared, group_offsets)`` where + ``group_offsets`` has length ``G+1`` with ``group_offsets[0] == 0`` and + ``group_offsets[g+1] - group_offsets[g] == M_g``. 
+ """ + assert len(group_A) > 0 + device = group_A[0].device + dtype = group_A[0].dtype + + # Extract shared weight matrix B (same for all groups in TritonBench) + B_shared = group_B[0] + + # Compute group offsets and concatenate all A matrices row-wise + M_sizes = [int(a.size(0)) for a in group_A] + starts = [0] + for m in M_sizes: + starts.append(starts[-1] + m) + group_offsets = torch.tensor(starts, device=device, dtype=torch.int32) + A_packed = torch.cat(group_A, dim=0).to(device=device, dtype=dtype).contiguous() + return A_packed, B_shared, group_offsets + + +# %% +# TritonBench Integration Wrappers +# --------------------------------- +def grouped_gemm_jagged_tritonbench( + tb_op: object, group_A: list[torch.Tensor], group_B: list[torch.Tensor] +) -> Callable[[], torch.Tensor]: + """Adapter for basic grouped GEMM kernel to work with TritonBench benchmark suite.""" + A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B) + return lambda: grouped_gemm_jagged(A_packed, B_shared, group_offsets) + + +def grouped_gemm_jagged_persistent_tritonbench( + tb_op: object, group_A: list[torch.Tensor], group_B: list[torch.Tensor] +) -> Callable[[], torch.Tensor]: + """Adapter for persistent grouped GEMM kernel with dynamic work distribution for TritonBench.""" + A_packed, B_shared, group_offsets = _pack_group_inputs(group_A, group_B) + + return lambda: grouped_gemm_jagged_persistent( + A_packed, + B_shared, + group_offsets, + ) + + +# %% +# Reference Implementation for Validation +# --------------------------------------- +def _reference_grouped_gemm( + group_A: list[torch.Tensor], group_B: list[torch.Tensor] +) -> torch.Tensor: + B_shared = group_B[0] + outs = [a @ B_shared for a in group_A] + return torch.cat(outs, dim=0) + + +# %% +# Test Harness and Validation +# --------------------------- +def main() -> None: + torch.manual_seed(0) # Ensure reproducible test results + device = "cuda" + dtype = torch.bfloat16 + G = 4 # Number of groups to test + K, N = 256, 128 # Shared dimensions: K (reduction), N (output columns) + # Create test data with varying group sizes (M_i = 64, 128, 192, 256) + group_A = [ + torch.randn(64 * (i + 1), K, device=device, dtype=dtype).contiguous() + for i in range(G) + ] + # Shared weight matrix B replicated for each group (as per TritonBench convention) + group_B = [torch.randn(K, N, device=device, dtype=dtype).contiguous()] * G + + ref = _reference_grouped_gemm(group_A, group_B) + + print("Testing grouped GEMM kernels...") + + # Test basic jagged kernel correctness + out = grouped_gemm_jagged_tritonbench(None, group_A, group_B)() + torch.testing.assert_close(out.float(), ref.float(), atol=1e-2, rtol=1e-2) + print("✓ Non-persistent kernel passed") + + # Test persistent kernel with dynamic tiling + out_p = grouped_gemm_jagged_persistent_tritonbench(None, group_A, group_B)() + torch.testing.assert_close(out_p.float(), ref.float(), atol=1e-2, rtol=1e-2) + print("✓ Persistent kernel passed") + + print("\nAll tests passed!") + + +if __name__ == "__main__": + main() diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index e0e1aca5c..41b2d1fb4 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -271,6 +271,21 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo: # This allows zip to work in list comprehensions zipped_tuples = tuple(tuple(items) for items in value) return cls.from_example(zipped_tuples, origin) + if isinstance(value, 
torch.cuda._CudaDeviceProperties): + attrs = {} + env = CompileEnvironment.current() + + # Only `multi_processor_count` attribute is supported for now + # TODO(yf225): support other torch.cuda._CudaDeviceProperties attributes + attr_origin = AttributeOrigin(origin, "multi_processor_count") + # Create a symbolic integer that can be passed as kernel argument + sym = env.create_unbacked_symint() + HostFunction.current().expr_to_origin[sym._sympy_()] = SymbolOrigin( + origin=attr_origin + ) + attrs["multi_processor_count"] = SymIntType(attr_origin, sym) + + return ClassType(origin, attrs) raise exc.UnsupportedPythonType(type(value).__name__) @staticmethod @@ -1306,7 +1321,15 @@ def tree_map(self, fn: Callable[[TypeInfo], object]) -> dict[str | int, object]: class ClassType(DictType): def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo: - return self.element_types[attr] + try: + return self.element_types[attr] + except KeyError: + desc = str( + getattr(origin.value, "location", origin.value.__class__.__name__) + ) + raise exc.TypeInferenceError( + f"Attribute '{attr}' is not supported on {desc}" + ) from None class StackTensorType(ClassType): diff --git a/helion/language/loops.py b/helion/language/loops.py index 75ee71268..5ba7dab9a 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -50,6 +50,7 @@ from collections.abc import Sequence from .._compiler.inductor_lowering import CodegenState + from .constexpr import ConstExpr __all__ = ["grid", "static_range", "tile"] @@ -572,10 +573,14 @@ def _codegen_loop_helper( is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True ) def grid( - begin_or_end: int | torch.Tensor, - end_or_none: int | torch.Tensor | None = None, + begin_or_end: int | torch.Tensor | ConstExpr, + end_or_none: int | torch.Tensor | ConstExpr | None = None, /, - step: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None, + step: int + | torch.Tensor + | ConstExpr + | Sequence[int | torch.Tensor | ConstExpr] + | None = None, ) -> Iterator[torch.SymInt]: ... @@ -585,10 +590,14 @@ def grid( is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True ) def grid( - begin_or_end: Sequence[int | torch.Tensor], - end_or_none: Sequence[int | torch.Tensor] | None = None, + begin_or_end: Sequence[int | torch.Tensor | ConstExpr], + end_or_none: Sequence[int | torch.Tensor | ConstExpr] | None = None, /, - step: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None, + step: int + | torch.Tensor + | ConstExpr + | Sequence[int | torch.Tensor | ConstExpr] + | None = None, ) -> Iterator[Sequence[torch.SymInt]]: ... @@ -597,10 +606,21 @@ def grid( is_device_loop=True, is_device_only=False, cache_type=True, tiles_as_sizes=True ) def grid( - begin_or_end: int | torch.Tensor | Sequence[int | torch.Tensor], - end_or_none: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None, + begin_or_end: int + | torch.Tensor + | ConstExpr + | Sequence[int | torch.Tensor | ConstExpr], + end_or_none: int + | torch.Tensor + | ConstExpr + | Sequence[int | torch.Tensor | ConstExpr] + | None = None, /, - step: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None, + step: int + | torch.Tensor + | ConstExpr + | Sequence[int | torch.Tensor | ConstExpr] + | None = None, ) -> Iterator[torch.SymInt] | Iterator[Sequence[torch.SymInt]]: # type: ignore[type-arg] """Iterate over individual indices of the given iteration space. 
diff --git a/test/test_examples.expected b/test/test_examples.expected index 1499ddb2b..6afb94238 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -900,6 +900,223 @@ def geglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher): _launcher(_helion_geglu, (triton.cdiv(total_elements, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, a_flat.stride(0), b_flat.stride(0), out_flat.stride(0), total_elements, _BLOCK_SIZE_0, num_warps=4, num_stages=3) return out +--- assertExpectedJournal(TestExamples.test_grouped_gemm_jagged) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_grouped_gemm_jagged(group_offsets, A_packed, B, out, A_packed_stride_0, A_packed_stride_1, B_stride_0, B_stride_1, group_offsets_stride_0, out_stride_0, out_stride_1, N, K, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + start = tl.load(group_offsets + offset_0 * group_offsets_stride_0, None) + add = 1 + offset_0 + end = tl.load(group_offsets + add * group_offsets_stride_0, None) + v_0 = end - start + v_1 = tl.full([], 0, tl.int32) + v_2 = v_0 != v_1 + if v_2: + v_0_copy = v_0 + start_copy = start + v_0_copy_0 = v_0_copy + start_copy_0 = start_copy + for offset_1 in tl.range(0, v_0_copy_0.to(tl.int32), _BLOCK_SIZE_1): + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + mask_1 = indices_1 < v_0_copy_0 + for offset_2 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < N + start_copy_0_copy = start_copy_0 + start_copy_0_copy_0 = start_copy_0_copy + acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32) + for offset_3 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_3): + indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) + mask_3 = indices_3 < K + start_copy_0_copy_0_copy = start_copy_0_copy_0 + acc_copy = acc + start_copy_0_copy_0_copy_0 = start_copy_0_copy_0_copy + acc_copy_0 = acc_copy + v_3 = start_copy_0_copy_0_copy_0[None] + v_4 = v_3 + indices_1 + a_blk = tl.load(A_packed + (v_4[:, None] * A_packed_stride_0 + indices_3[None, :] * A_packed_stride_1), mask_1[:, None] & mask_3[None, :], other=0) + b_blk = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_3[:, None] & mask_2[None, :], other=0) + acc = tl.dot(tl.cast(a_blk, tl.bfloat16), tl.cast(b_blk, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + v_5 = tl.cast(acc, tl.bfloat16) + v_6 = start_copy_0_copy_0[None] + v_7 = v_6 + indices_1 + tl.store(out + (v_7[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_5, mask_1[:, None] & mask_2[None, :]) + +def grouped_gemm_jagged(A_packed: torch.Tensor, B: torch.Tensor, group_offsets: torch.Tensor, *, _launcher=_default_launcher): + """ + Perform grouped GEMM on jagged inputs using row offsets. + + Args: + A_packed: Row-wise concatenation of per-group inputs ``A_i``, + shape ``[sum(M_i), K]``. + B: Shared weight matrix, shape ``[K, N]``. + group_offsets: Row offsets delimiting each group within ``A_packed``, + shape ``[G+1]``. For group ``g``: rows are + ``start = group_offsets[g]`` to ``end = group_offsets[g+1]``. + + Returns: + Output tensor of shape ``[sum(M_i), N]`` equal to + ``torch.cat([A_i @ B for i in groups], dim=0)``. 
+ """ + total_M, K = A_packed.shape + K2, N = B.shape + assert K == K2, 'K dimension mismatch between A_packed and B' + out = torch.empty(total_M, N, dtype=torch.promote_types(A_packed.dtype, B.dtype), device=A_packed.device) + G = group_offsets.size(0) - 1 + _BLOCK_SIZE_2 = 16 + _BLOCK_SIZE_1 = 16 + _BLOCK_SIZE_3 = 16 + _launcher(_helion_grouped_gemm_jagged, (G,), group_offsets, A_packed, B, out, A_packed.stride(0), A_packed.stride(1), B.stride(0), B.stride(1), group_offsets.stride(0), out.stride(0), out.stride(1), N, K, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestExamples.test_grouped_gemm_jagged_persistent) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from torch._inductor.runtime.triton_compat import libdevice +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_grouped_gemm_jagged_persistent(group_offsets, A_packed, B, out, A_packed_stride_0, A_packed_stride_1, B_stride_0, B_stride_1, group_offsets_stride_0, out_stride_0, out_stride_1, num_workers, G, N, K, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_5: tl.constexpr): + pid_0 = tl.program_id(0) + offset_2 = pid_0 + for offset_3 in tl.range(0, G.to(tl.int32)): + group_start = tl.load(group_offsets + offset_3 * group_offsets_stride_0, None) + add = 1 + offset_3 + group_end = tl.load(group_offsets + add * group_offsets_stride_0, None) + v_0 = group_end - group_start + v_1 = tl.full([], 0, tl.int32) + v_2 = v_0 > v_1 + if v_2: + v_0_copy = v_0 + group_start_copy = group_start + group_end_copy = group_end + v_0_copy_0 = v_0_copy + group_start_copy_0 = group_start_copy + group_end_copy_0 = group_end_copy + _BLOCK_SIZE_0_ = _BLOCK_SIZE_0 + v_3 = tl.cast(v_0_copy_0, tl.int64) + v_4 = v_3 + _BLOCK_SIZE_0_ + v_5 = tl.full([], 1, tl.int32) + v_6 = v_4 - v_5 + _BLOCK_SIZE_0__1 = _BLOCK_SIZE_0 + v_7 = tl.cast(v_6, tl.int64) + v_8 = tl.where((v_7 < 0) != (_BLOCK_SIZE_0__1 < 0), tl.where(v_7 % _BLOCK_SIZE_0__1 != 0, v_7 // _BLOCK_SIZE_0__1 - 1, v_7 // _BLOCK_SIZE_0__1), v_7 // _BLOCK_SIZE_0__1) + add_1 = N + _BLOCK_SIZE_1 + sub_1 = -1 + N + _BLOCK_SIZE_1 + floordiv = triton_helpers.div_floor_integer(-1 + N + _BLOCK_SIZE_1, _BLOCK_SIZE_1) + v_9 = tl.cast(v_8, tl.int64) + v_10 = v_9 * floordiv + for offset_4 in tl.range(0, v_10.to(tl.int32)): + v_10_copy = v_10 + v_8_copy = v_8 + group_start_copy_0_copy = group_start_copy_0 + group_end_copy_0_copy = group_end_copy_0 + v_10_copy_0 = v_10_copy + v_8_copy_0 = v_8_copy + group_start_copy_0_copy_0 = group_start_copy_0_copy + group_end_copy_0_copy_0 = group_end_copy_0_copy + mul = num_workers * offset_4 + add_2 = offset_2 + num_workers * offset_4 + v_11 = tl.cast(v_10_copy_0, tl.int64) + v_12 = v_11 > add_2 + if v_12: + v_8_copy_0_copy = v_8_copy_0 + group_start_copy_0_copy_0_copy = group_start_copy_0_copy_0 + group_end_copy_0_copy_0_copy = group_end_copy_0_copy_0 + v_8_copy_0_copy_0 = v_8_copy_0_copy + group_start_copy_0_copy_0_copy_0 = group_start_copy_0_copy_0_copy + group_end_copy_0_copy_0_copy_0 = group_end_copy_0_copy_0_copy + v_13 = tl.cast(v_8_copy_0_copy_0, tl.int64) + v_14 = add_2 % v_13 + v_15 = tl.full([], 0, tl.int32) + v_16 = v_14 != v_15 + v_17 = libdevice.signbit(v_14) != 0 if v_14.dtype is tl.float32 else v_14 < 0 + v_18 = libdevice.signbit(v_13) != 0 if v_13.dtype is tl.float32 else v_13 < 0 + v_19 = v_17 != v_18 + v_20 = v_16 & v_19 + v_21 = v_14 + 
v_13 + v_22 = tl.where(v_20, v_21, v_14) + v_23 = tl.cast(v_8_copy_0_copy_0, tl.int64) + v_24 = tl.where((add_2 < 0) != (v_23 < 0), tl.where(add_2 % v_23 != 0, add_2 // v_23 - 1, add_2 // v_23), add_2 // v_23) + _BLOCK_SIZE_0__2 = _BLOCK_SIZE_0 + v_25 = tl.cast(v_22, tl.int64) + v_26 = v_25 * _BLOCK_SIZE_0__2 + v_27 = group_start_copy_0_copy_0_copy_0 + v_26 + _BLOCK_SIZE_1_ = _BLOCK_SIZE_1 + v_28 = tl.cast(v_24, tl.int64) + v_29 = v_28 * _BLOCK_SIZE_1_ + iota = tl.arange(0, _BLOCK_SIZE_0) + v_30 = v_27[None] + v_31 = v_30 + iota + iota_1 = tl.arange(0, _BLOCK_SIZE_1) + v_32 = v_29[None] + v_33 = v_32 + iota_1 + v_34 = group_end_copy_0_copy_0_copy_0[None] + v_35 = v_31 < v_34 + v_36 = tl.cast(N, tl.int32) + v_37 = v_33 < v_36 + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_5 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_5): + indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_5).to(tl.int32) + mask_5 = indices_5 < K + v_31_copy = v_31 + v_35_copy = v_35 + v_33_copy = v_33 + v_37_copy = v_37 + acc_copy = acc + v_31_copy_0 = v_31_copy + v_35_copy_0 = v_35_copy + v_33_copy_0 = v_33_copy + v_37_copy_0 = v_37_copy + acc_copy_0 = acc_copy + subscript = v_35_copy_0[:, None] + a_blk = tl.load(A_packed + (v_31_copy_0[:, None] * A_packed_stride_0 + indices_5[None, :] * A_packed_stride_1), mask_5[None, :] & subscript, other=0) + subscript_1 = v_37_copy_0[None, :] + b_blk = tl.load(B + (indices_5[:, None] * B_stride_0 + v_33_copy_0[None, :] * B_stride_1), mask_5[:, None] & subscript_1, other=0) + acc = tl.dot(tl.cast(a_blk, tl.bfloat16), tl.cast(b_blk, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + subscript_2 = v_35[:, None] + subscript_3 = v_37[None, :] + v_38 = subscript_2 & subscript_3 + v_39 = tl.cast(acc, tl.bfloat16) + tl.store(out + (v_31[:, None] * out_stride_0 + v_33[None, :] * out_stride_1), v_39, v_38) + +def grouped_gemm_jagged_persistent(A_packed: torch.Tensor, B: torch.Tensor, group_offsets: torch.Tensor, *, _launcher=_default_launcher): + """ + Persistent grouped GEMM with dynamic tile metadata computation. + + This variant computes tile assignments dynamically in the kernel, + similar to TritonBench's WS variant. + + Args: + A_packed: Packed A, concatenated by rows across groups, ``[sum(M_i), K]``. + B: Shared weight matrix, ``[K, N]``. + group_offsets: Row offsets delimiting each group within ``A_packed``. + + Returns: + Output tensor of shape ``[sum(M_i), N]``. 
+ """ + device = A_packed.device + num_workers = torch.cuda.get_device_properties(device).multi_processor_count + total_M, K = A_packed.shape + K2, N = B.shape + assert K == K2 + out = torch.zeros(total_M, N, dtype=torch.promote_types(A_packed.dtype, B.dtype), device=A_packed.device) + G = group_offsets.size(0) - 1 + _BLOCK_SIZE_5 = 16 + _launcher(_helion_grouped_gemm_jagged_persistent, (num_workers,), group_offsets, A_packed, B, out, A_packed.stride(0), A_packed.stride(1), B.stride(0), B.stride(1), group_offsets.stride(0), out.stride(0), out.stride(1), num_workers, G, N, K, 32, 32, _BLOCK_SIZE_5, num_warps=4, num_stages=3) + return out + --- assertExpectedJournal(TestExamples.test_int4_gemm) from __future__ import annotations diff --git a/test/test_examples.py b/test/test_examples.py index bb35334b2..be385c1d4 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1030,6 +1030,78 @@ def test_jagged_hstu_attn(self): ) ) + def test_grouped_gemm_jagged(self): + # Build small jagged grouped GEMM inputs + torch.manual_seed(0) + G = 3 + K, N = 64, 64 + dtype = torch.bfloat16 + group_A = [ + torch.randn(32 * (i + 1), K, device=DEVICE, dtype=dtype).contiguous() + for i in range(G) + ] + B_shared = torch.randn(K, N, device=DEVICE, dtype=dtype).contiguous() + + # Pack A and offsets + M_sizes = [int(a.size(0)) for a in group_A] + starts = [0] + for m in M_sizes: + starts.append(starts[-1] + m) + group_offsets = torch.tensor(starts, device=DEVICE, dtype=torch.int32) + A_packed = torch.cat(group_A, dim=0).contiguous() + + # Reference result + expected = torch.cat([a @ B_shared for a in group_A], dim=0) + + # Run kernel and check + args = (A_packed, B_shared, group_offsets) + self.assertExpectedJournal( + check_example( + "grouped_gemm", + args, + expected, + fn_name="grouped_gemm_jagged", + ) + ) + + def test_grouped_gemm_jagged_persistent(self): + # Build small jagged grouped GEMM inputs + torch.manual_seed(0) + G = 3 + K, N = 64, 64 + dtype = torch.bfloat16 + group_A = [ + torch.randn(32 * (i + 1), K, device=DEVICE, dtype=dtype).contiguous() + for i in range(G) + ] + B_shared = torch.randn(K, N, device=DEVICE, dtype=dtype).contiguous() + + # Pack A and offsets + M_sizes = [int(a.size(0)) for a in group_A] + starts = [0] + for m in M_sizes: + starts.append(starts[-1] + m) + group_offsets = torch.tensor(starts, device=DEVICE, dtype=torch.int32) + A_packed = torch.cat(group_A, dim=0).contiguous() + + # Reference result + expected = torch.cat([a @ B_shared for a in group_A], dim=0) + + # Run kernel and check + args = ( + A_packed, + B_shared, + group_offsets, + ) + self.assertExpectedJournal( + check_example( + "grouped_gemm", + args, + expected, + fn_name="grouped_gemm_jagged_persistent", + ) + ) + def test_geglu(self): args = ( torch.randn([1024, 1024], device=DEVICE, dtype=torch.float16), diff --git a/test/test_type_propagation.expected b/test/test_type_propagation.expected index 1f058e2be..2a103e16f 100644 --- a/test/test_type_propagation.expected +++ b/test/test_type_propagation.expected @@ -452,6 +452,91 @@ def root_graph_0(): store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], add, None); out = block_size_0 = block_size_1 = add = store = None return None +--- assertExpectedJournal(TestTypePropagation.test_cuda_device_properties) +def use_device_properties(x: torch.Tensor): + # Attribute: LiteralType(device(type='cuda', index=0)) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device') + # Name: TensorType([x_size0], torch.float32) ArgumentOrigin(name='x') 
+ device = x.device + # Call: ClassType({'multi_processor_count': SymIntType(u0)}) SourceOrigin(location=) + # Attribute: CallableType(get_device_properties) AttributeOrigin(value=AttributeOrigin(value=GlobalOrigin(name='torch'), key='cuda'), key='get_device_properties') + # Attribute: PythonModuleType(torch.cuda) AttributeOrigin(value=GlobalOrigin(name='torch'), key='cuda') + # Name: PythonModuleType(torch) GlobalOrigin(name='torch') + # Name: LiteralType(device(type='cuda', index=0)) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device') + props = torch.cuda.get_device_properties(device) + # Attribute: SymIntType(u0) AttributeOrigin(value=SourceOrigin(location=), key='multi_processor_count') + # Name: ClassType({'multi_processor_count': SymIntType(u0)}) SourceOrigin(location=) + sm_count = props.multi_processor_count + # Subscript: SymIntType(s77) GetItemOrigin(value=AttributeOrigin(value=ArgumentOrigin(name='x'), key='shape'), key=0) + # Attribute: SequenceType((SymIntType(s77), )) AttributeOrigin(value=ArgumentOrigin(name='x'), key='shape') + # Name: TensorType([x_size0], torch.float32) ArgumentOrigin(name='x') + # Constant: LiteralType(0) SourceOrigin(location=) + n = x.shape[0] + # Call: TensorType([x_size0], torch.float32) SourceOrigin(location=) + # Attribute: CallableType(_VariableFunctionsClass.zeros_like) AttributeOrigin(value=GlobalOrigin(name='torch'), key='zeros_like') + # Name: PythonModuleType(torch) GlobalOrigin(name='torch') + # Name: TensorType([x_size0], torch.float32) ArgumentOrigin(name='x') + # For: loop_type=GRID + out = torch.zeros_like(x) + # Call: IterType(GridIndexType(0)) SourceOrigin(location=) + # Attribute: CallableType(grid) AttributeOrigin(value=GlobalOrigin(name='hl'), key='grid') + # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl') + # Name: SymIntType(u0) AttributeOrigin(value=SourceOrigin(location=), key='multi_processor_count') + # For: loop_type=DEVICE + for worker_id in hl.grid(sm_count): + # Call: IterType(GridIndexType(1)) DeviceOrigin(location=) + # Attribute: CallableType(grid) AttributeOrigin(value=GlobalOrigin(name='hl'), key='grid') + # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl') + # Name: SymIntType(s77) GetItemOrigin(value=AttributeOrigin(value=ArgumentOrigin(name='x'), key='shape'), key=0) + for i in hl.grid(n): + # BinOp: SymIntType(u0*u4 + u2) DeviceOrigin(location=) + # Name: GridIndexType(0) SourceOrigin(location=) + # BinOp: SymIntType(u0*u4) DeviceOrigin(location=) + # Name: GridIndexType(1) DeviceOrigin(location=) + # Name: SymIntType(u0) AttributeOrigin(value=SourceOrigin(location=), key='multi_processor_count') + idx = worker_id + i * sm_count + # Compare: SymBoolType(Eq(u11, 1)) DeviceOrigin(location=) + # Name: SymIntType(u0*u4 + u2) DeviceOrigin(location=) + # Name: SymIntType(s77) GetItemOrigin(value=AttributeOrigin(value=ArgumentOrigin(name='x'), key='shape'), key=0) + if idx < n: + # Subscript: TensorType([], torch.float32) DeviceOrigin(location=) + # Name: TensorType([x_size0], torch.float32) SourceOrigin(location=) + # Name: SymIntType(u0*u4 + u2) DeviceOrigin(location=) + # Subscript: TensorType([], torch.float32) DeviceOrigin(location=) + # Name: TensorType([x_size0], torch.float32) ArgumentOrigin(name='x') + # Name: SymIntType(u0*u4 + u2) DeviceOrigin(location=) + out[idx] = x[idx] + # Name: TensorType([x_size0], torch.float32) SourceOrigin(location=) + return out + +def if_else_graph_0(): + # File: .../test_type_propagation.py:111 in use_device_properties, code: out[idx] = 
x[idx] + x: "f32[s77]" = helion_language__tracing_ops__host_tensor('x') + symnode: "Sym(u0*u4 + u2)" = helion_language__tracing_ops__get_symnode('u0*u4 + u2') + load: "f32[]" = helion_language_memory_ops_load(x, [symnode], None); x = None + out: "f32[s77]" = helion_language__tracing_ops__host_tensor('out') + store = helion_language_memory_ops_store(out, [symnode], load, None); out = symnode = load = store = None + return [] + +def for_loop_1(): + # File: .../test_type_propagation.py:109 in use_device_properties, code: idx = worker_id + i * sm_count + u4: "Sym(u4)" = helion_language__tracing_ops__get_symnode('u4') + u0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('u0') + mul: "Sym(u0*u4)" = u4 * u0; u4 = u0 = None + u2: "Sym(u2)" = helion_language__tracing_ops__get_symnode('u2') + add: "Sym(u0*u4 + u2)" = u2 + mul; u2 = mul = None + + # File: .../test_type_propagation.py:110 in use_device_properties, code: if idx < n: + x_size0: "Sym(s77)" = helion_language__tracing_ops__get_symnode('x_size0') + lt: "Sym(u0*u4 + u2 < s77)" = add < x_size0; add = x_size0 = None + _if = helion_language__tracing_ops__if(lt, 0, []); lt = _if = None + return [] + +def root_graph_2(): + # File: .../test_type_propagation.py:108 in use_device_properties, code: for i in hl.grid(n): + x_size0: "Sym(s77)" = helion_language__tracing_ops__get_symnode('x_size0') + _for_loop = helion_language__tracing_ops__for_loop(1, [0], [x_size0], []); x_size0 = _for_loop = None + return None + --- assertExpectedJournal(TestTypePropagation.test_hl_full_usage) def hl_full_usage(x: torch.Tensor): # Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) @@ -693,33 +778,33 @@ def root_graph_1(): --- assertExpectedJournal(TestTypePropagation.test_method_call) def fn(x): - # Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) + # Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) # Attribute: CallableType(_VariableFunctionsClass.empty_like) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty_like') # Name: PythonModuleType(torch) GlobalOrigin(name='torch') # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x') # For: loop_type=GRID out = torch.empty_like(x) - # Call: IterType(SequenceType((TileIndexType(0), TileIndexType(1)))) SourceOrigin(location=) + # Call: IterType(SequenceType((TileIndexType(0), TileIndexType(1)))) SourceOrigin(location=) # Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile') # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl') - # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=) + # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=) # Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size') # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x') for tile in hl.tile(x.size()): - # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) - # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) - # Name: SequenceType((TileIndexType(0), TileIndexType(1))) SourceOrigin(location=) - # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) - # Attribute: TensorAttributeType AttributeOrigin(value=DeviceOrigin(location=), key='sin') - # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], 
torch.int32) DeviceOrigin(location=) + # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) + # Name: SequenceType((TileIndexType(0), TileIndexType(1))) SourceOrigin(location=) + # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) + # Attribute: TensorAttributeType AttributeOrigin(value=DeviceOrigin(location=), key='sin') + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x') - # Name: SequenceType((TileIndexType(0), TileIndexType(1))) SourceOrigin(location=) + # Name: SequenceType((TileIndexType(0), TileIndexType(1))) SourceOrigin(location=) out[tile] = x[tile].sin() - # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) + # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) return out def root_graph_0(): - # File: .../test_type_propagation.py:79 in fn, code: out[tile] = x[tile].sin() + # File: .../test_type_propagation.py:80 in fn, code: out[tile] = x[tile].sin() x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x') block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0') block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1') diff --git a/test/test_type_propagation.py b/test/test_type_propagation.py index 45a89c26c..79fa3a6b6 100644 --- a/test/test_type_propagation.py +++ b/test/test_type_propagation.py @@ -7,6 +7,7 @@ import torch import helion +from helion import exc from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import import_path @@ -93,6 +94,44 @@ def test_matmul(self): ) self.assertExpectedJournal(output) + def test_cuda_device_properties(self): + @helion.kernel + def use_device_properties(x: torch.Tensor) -> torch.Tensor: + device = x.device + props = torch.cuda.get_device_properties(device) + sm_count = props.multi_processor_count + + n = x.shape[0] + out = torch.zeros_like(x) + + for worker_id in hl.grid(sm_count): + for i in hl.grid(n): + idx = worker_id + i * sm_count + if idx < n: + out[idx] = x[idx] + return out + + x = torch.ones([128], device="cuda") + output = type_propagation_report(use_device_properties, x) + self.assertExpectedJournal(output) + + def test_cuda_device_properties_unsupported_attribute(self): + @helion.kernel + def use_unsupported_property(x: torch.Tensor) -> torch.Tensor: + device = x.device + props = torch.cuda.get_device_properties(device) + for i in hl.grid(x.shape[0]): + unsupported = props.total_memory # attribute not supported yet + x[i] = unsupported + return x + + x = torch.ones([16], device="cuda") + with self.assertRaisesRegex( + exc.TypeInferenceError, + r"Attribute 'total_memory' is not supported on .*test_type_propagation.py", + ): + type_propagation_report(use_unsupported_property, x) + if __name__ == "__main__": unittest.main()
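# A usage sketch for the two adapters added in examples/grouped_gemm.py
# (assumes a CUDA device; mirrors the shapes and tolerances of the example's
# own main() and tests):
import torch

from examples.grouped_gemm import (
    grouped_gemm_jagged_persistent_tritonbench,
    grouped_gemm_jagged_tritonbench,
)

group_A = [
    torch.randn(64 * (i + 1), 256, device="cuda", dtype=torch.bfloat16)
    for i in range(4)
]
group_B = [torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)] * 4
ref = torch.cat([a @ group_B[0] for a in group_A], dim=0)

# Each adapter packs the inputs once and returns a zero-argument callable,
# matching the TritonBench calling convention.
for adapter in (
    grouped_gemm_jagged_tritonbench,
    grouped_gemm_jagged_persistent_tritonbench,
):
    out = adapter(None, group_A, group_B)()
    torch.testing.assert_close(out.float(), ref.float(), atol=1e-2, rtol=1e-2)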