diff --git a/benchmarks/run.py b/benchmarks/run.py
index ed712aeae..a9b9b16ff 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -166,6 +166,11 @@ class RunResult:
         "examples.welford",
         "welford",
     ),
+    "gather_gemv": (
+        "tritonbench.operators.gather_gemv.operator",
+        "examples.gather_gemv",
+        "gather_gemv_tritonbench",
+    ),
     "int4_gemm": (
         "tritonbench.operators.int4_gemm.int4_gemm",
         "examples.int4_gemm",
@@ -271,6 +276,14 @@ class RunResult:
         "helion_kl_div_tritonbench-speedup": "helion_speedup",
         "helion_kl_div_tritonbench-accuracy": "helion_accuracy",
     },
+    "gather_gemv": {
+        "test_0-speedup": "triton_speedup",
+        "test_0-accuracy": "triton_accuracy",
+        "test_inductor-speedup": "torch_compile_speedup",
+        "test_inductor-accuracy": "torch_compile_accuracy",
+        "helion_gather_gemv_tritonbench-speedup": "helion_speedup",
+        "helion_gather_gemv_tritonbench-accuracy": "helion_accuracy",
+    },
     "int4_gemm": {
         "triton_int4_gemm-speedup": "triton_speedup",
         "triton_int4_gemm-accuracy": "triton_accuracy",
diff --git a/examples/gather_gemv.py b/examples/gather_gemv.py
new file mode 100644
index 000000000..aead1d315
--- /dev/null
+++ b/examples/gather_gemv.py
@@ -0,0 +1,140 @@
+"""
+Helion Gather GEMV Kernel Example
+=================================
+This example demonstrates a Helion kernel implementation of a gather operation
+followed by a general matrix-vector multiplication (GEMV). The operation is
+w[idx].to(x.dtype) @ x, where w is a 3D tensor, idx contains indices to gather,
+and x is a vector.
+
+Based on the tritonbench gather_gemv operator, which is motivated by Mixtral
+performance, where gather + gemv is the primary kernel.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from torch import Tensor
+
+import helion
+from helion._testing import run_example
+import helion.language as hl
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
+# %%
+# Gather GEMV Kernel
+# ------------------
+@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
+def gather_gemv(w: Tensor, idx: Tensor, x: Tensor) -> Tensor:
+    """
+    Performs a gather operation on w using idx, then matrix-vector multiplication with x.
+
+    Args:
+        w (Tensor): Weight matrix of shape [B, S, S] where B is batch size, S is sequence length.
+        idx (Tensor): Index tensor of shape [N] containing indices to gather from dimension 0 of w.
+        x (Tensor): Vector of shape [S] to multiply with the gathered matrices.
+
+    Returns:
+        Tensor: Result of shape [N, S] where each row i is w[idx[i]] @ x.
+    """
+    B, S1, S2 = w.size()
+    N = idx.size(0)
+    S = x.size(0)
+    assert S1 == S2, f"Weight matrix must be square, got {S1} != {S2}"
+    assert S == S1, f"Vector size {S} must match matrix size {S1}"
+
+    # Rearrange shapes for matrix-vector multiplication
+    w_view = w.contiguous().view(B * S, S).to(x.dtype)  # Shape: [B * S, S]
+    x = x.view(S, 1)
+
+    # Create output tensor
+    out = torch.empty([N * S, 1], dtype=x.dtype, device=x.device)
+
+    # Tile the N * S flattened output rows; row r reads row r % S of w[idx[r // S]]
+    for tile_n_s in hl.tile(N * S):
+        acc = hl.zeros([tile_n_s, 1], dtype=torch.float32)
+        idx_id = tile_n_s.index // S
+        idx_gather = idx[idx_id]
+        for tile_k in hl.tile(S):
+            # Gather the needed rows of w, then accumulate the matrix-vector product
+            gathered = w_view[idx_gather * S + tile_n_s.index % S, tile_k]
+            acc += hl.dot(gathered, x[tile_k, :])
+        out[tile_n_s, :] = acc
+
+    return out.contiguous().view(N, S)
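+
+
+# %%
+# Reference Index Mapping (illustration only)
+# -------------------------------------------
+# A minimal eager sketch, not called anywhere, that spells out the index
+# arithmetic used by the kernel above: flattened output row r comes from row
+# r % S of the gathered matrix idx[r // S]. The helper name is hypothetical
+# and not part of the tritonbench interface.
+def _gather_gemv_reference(w: Tensor, idx: Tensor, x: Tensor) -> Tensor:
+    B, S, _ = w.size()
+    N = idx.size(0)
+    out = torch.empty([N * S], dtype=x.dtype, device=x.device)
+    for r in range(N * S):
+        # Same mapping as `idx_gather * S + tile_n_s.index % S` in the kernel
+        out[r] = w[idx[r // S], r % S].to(x.dtype) @ x
+    return out.view(N, S)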
+ """ + B, S1, S2 = w.size() + N = idx.size(0) + S = x.size(0) + assert S1 == S2, f"Weight matrix must be square, got {S1} != {S2}" + assert S == S1, f"Vector size {S} must match matrix size {S1}" + + # Rearrange shapes for matrix-vector multiplication + w_view = w.contiguous().view(B * S, S).to(x.dtype) # Shape: [N, S, S] + x = x.view(S, 1) + + # Create output tensor + out = torch.empty([N * S, 1], dtype=x.dtype, device=x.device) + + # Perform matrix-vector multiplication for each gathered matrix + for tile_n_s in hl.tile(N * S): + acc = hl.zeros([tile_n_s, 1], dtype=torch.float32) + idx_id = tile_n_s.index // S + idx_gather = idx[idx_id] + for tile_k in hl.tile(S): + # Matrix-vector multiplication + gathered = w_view[idx_gather * S + tile_n_s.index % S, tile_k] + acc += hl.dot(gathered, x[tile_k, :]) + out[tile_n_s, :] = acc + + return out.contiguous().view(N, S) + + +# %% +# Verification Function +# --------------------- +def check(B: int, S: int, N: int) -> None: + """ + Verify the gather_gemv kernel implementation against PyTorch's baseline. + + Args: + B (int): Batch size for weight matrix. + S (int): Sequence length (matrix size). + N (int): Number of indices to gather. + """ + # Create test tensors matching tritonbench format + w = torch.randn((B, S, S), device="cuda:0", dtype=torch.float16) + idx = torch.randint(0, B, [N], device="cuda:0", dtype=torch.int32) + x = torch.randn((S), device="cuda:0", dtype=torch.float16) + + def baseline_gather_gemv(w: Tensor, idx: Tensor, x: Tensor) -> Tensor: + """PyTorch baseline implementation.""" + outputs = [] + for idx_val in idx.tolist(): + outputs.append(w[idx_val].to(x.dtype) @ x) + return torch.stack(outputs, dim=0) + + run_example(gather_gemv, baseline_gather_gemv, (w, idx, x)) + + +# %% +# Tritonbench Integration +# ----------------------- +def gather_gemv_tritonbench( + tb_op: object, w: Tensor, idx: Tensor, x: Tensor +) -> Callable: + """ + Wrapper for tritonbench that matches its interface. + + Args: + w (Tensor): Weight matrix of shape [B, S, S]. + idx (Tensor): Index tensor of shape [N]. + x (Tensor): Vector of shape [S]. + + Returns: + Callable: A callable that runs the gather_gemv kernel. + """ + return lambda: gather_gemv(w, idx, x) + + +# %% +# Main Function +# ------------- +def main() -> None: + """ + Main entry point that runs the gather_gemv kernel verification. + Uses sizes similar to tritonbench for consistency. 
+ """ + # Test with sizes from tritonbench + B = 8 # Batch size, could be number of experts in MoE + N = 2 # Number of indices, experts selected + for i in range(11, 15): + S = 2**i + print(f"Testing with B={B}, S={S}, N={N}") + check(B, S, N) + + +# %% +if __name__ == "__main__": + main() diff --git a/test/test_examples.expected b/test/test_examples.expected index 51ca95eb1..ee0be2738 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -841,6 +841,76 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): _launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) return out +--- assertExpectedJournal(TestExamples.test_gather_gemv) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime.triton_compat import libdevice +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_gather_gemv(out, idx, w_view, x, out_size_0, idx_stride_0, out_stride_0, w_view_stride_0, w_view_stride_1, x_stride_0, S1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < out_size_0 + acc = tl.full([_BLOCK_SIZE_0, 1], 0.0, tl.float32) + v_0 = tl.cast(S1, tl.int32) + v_1 = tl.where((indices_0 < 0) != (v_0 < 0), tl.where(indices_0 % v_0 != 0, indices_0 // v_0 - 1, indices_0 // v_0), indices_0 // v_0) + idx_gather = tl.load(idx + v_1 * idx_stride_0, mask_0, other=0) + for offset_1 in tl.range(0, S1.to(tl.int32), _BLOCK_SIZE_1): + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + mask_1 = indices_1 < S1 + idx_gather_copy = idx_gather + acc_copy = acc + idx_gather_copy_0 = idx_gather_copy + acc_copy_0 = acc_copy + v_2 = tl.cast(S1, tl.int32) + v_3 = idx_gather_copy_0 * v_2 + v_4 = tl.cast(S1, tl.int32) + v_5 = indices_0 % v_4 + v_6 = tl.full([], 0, tl.int32) + v_7 = v_5 != v_6 + v_8 = libdevice.signbit(v_5) != 0 if v_5.dtype is tl.float32 else v_5 < 0 + v_9 = libdevice.signbit(v_4) != 0 if v_4.dtype is tl.float32 else v_4 < 0 + v_10 = v_8 != v_9 + v_11 = v_7 & v_10 + v_12 = v_5 + v_4 + v_13 = tl.where(v_11, v_12, v_5) + v_14 = v_3 + v_13 + gathered = tl.load(w_view + (v_14[:, None] * w_view_stride_0 + indices_1[None, :] * w_view_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + load_1 = tl.load(x + indices_1[:, None] * x_stride_0, mask_1[:, None], other=0) + dot = tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.dot(tl.cast(gathered, tl.float32), tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]))), [0, 2, 1]), [16, 8]), 
+
+
+# %%
+# Main Function
+# -------------
+def main() -> None:
+    """
+    Main entry point that runs the gather_gemv kernel verification.
+    Uses sizes similar to tritonbench for consistency.
+    """
+    # Test with sizes from tritonbench
+    B = 8  # Batch size, e.g. the number of experts in MoE
+    N = 2  # Number of indices, i.e. the experts selected
+    for i in range(11, 15):
+        S = 2**i
+        print(f"Testing with B={B}, S={S}, N={N}")
+        check(B, S, N)
+
+
+# %%
+if __name__ == "__main__":
+    main()
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 51ca95eb1..ee0be2738 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -841,6 +841,76 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_gather_gemv)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_gather_gemv(out, idx, w_view, x, out_size_0, idx_stride_0, out_stride_0, w_view_stride_0, w_view_stride_1, x_stride_0, S1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < out_size_0
+    acc = tl.full([_BLOCK_SIZE_0, 1], 0.0, tl.float32)
+    v_0 = tl.cast(S1, tl.int32)
+    v_1 = tl.where((indices_0 < 0) != (v_0 < 0), tl.where(indices_0 % v_0 != 0, indices_0 // v_0 - 1, indices_0 // v_0), indices_0 // v_0)
+    idx_gather = tl.load(idx + v_1 * idx_stride_0, mask_0, other=0)
+    for offset_1 in tl.range(0, S1.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < S1
+        idx_gather_copy = idx_gather
+        acc_copy = acc
+        idx_gather_copy_0 = idx_gather_copy
+        acc_copy_0 = acc_copy
+        v_2 = tl.cast(S1, tl.int32)
+        v_3 = idx_gather_copy_0 * v_2
+        v_4 = tl.cast(S1, tl.int32)
+        v_5 = indices_0 % v_4
+        v_6 = tl.full([], 0, tl.int32)
+        v_7 = v_5 != v_6
+        v_8 = libdevice.signbit(v_5) != 0 if v_5.dtype is tl.float32 else v_5 < 0
+        v_9 = libdevice.signbit(v_4) != 0 if v_4.dtype is tl.float32 else v_4 < 0
+        v_10 = v_8 != v_9
+        v_11 = v_7 & v_10
+        v_12 = v_5 + v_4
+        v_13 = tl.where(v_11, v_12, v_5)
+        v_14 = v_3 + v_13
+        gathered = tl.load(w_view + (v_14[:, None] * w_view_stride_0 + indices_1[None, :] * w_view_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        load_1 = tl.load(x + indices_1[:, None] * x_stride_0, mask_1[:, None], other=0)
+        dot = tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.dot(tl.cast(gathered, tl.float32), tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]))), [0, 2, 1]), [16, 8]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]))), [0, 2, 1]), [16, 8]))), [0, 2, 1]), [16, 16]), input_precision='tf32', out_dtype=tl.float32), [16, 2, 8]), [0, 2, 1]))[0], [16, 2, 4]), [0, 2, 1]))[0], [16, 2, 2]), [0, 2, 1]))[0], [16, 2, 1]), [0, 2, 1]))[0]
+        acc = acc_copy_0 + dot
+    tl.store(out + indices_0[:, None] * out_stride_0, acc, mask_0[:, None])
+
+def gather_gemv(w: Tensor, idx: Tensor, x: Tensor, *, _launcher=_default_launcher):
+    """
+    Performs a gather operation on w using idx, then matrix-vector multiplication with x.
+
+    Args:
+        w (Tensor): Weight matrix of shape [B, S, S] where B is batch size, S is sequence length.
+        idx (Tensor): Index tensor of shape [N] containing indices to gather from dimension 0 of w.
+        x (Tensor): Vector of shape [S] to multiply with the gathered matrices.
+
+    Returns:
+        Tensor: Result of shape [N, S] where each row i is w[idx[i]] @ x.
+    """
+    B, S1, S2 = w.size()
+    N = idx.size(0)
+    S = x.size(0)
+    assert S1 == S2, f'Weight matrix must be square, got {S1} != {S2}'
+    assert S == S1, f'Vector size {S} must match matrix size {S1}'
+    w_view = w.contiguous().view(B * S, S).to(x.dtype)
+    x = x.view(S, 1)
+    out = torch.empty([N * S, 1], dtype=x.dtype, device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _launcher(_helion_gather_gemv, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, idx, w_view, x, out.size(0), idx.stride(0), out.stride(0), w_view.stride(0), w_view.stride(1), x.stride(0), S1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=8, num_stages=1)
+    return out.contiguous().view(N, S)
+
 --- assertExpectedJournal(TestExamples.test_geglu)
 from __future__ import annotations
 
diff --git a/test/test_examples.py b/test/test_examples.py
index bb35334b2..cbd5ebbd5 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -12,6 +12,7 @@ from helion._testing import TestCase
 from helion._testing import check_example
 from helion._testing import import_path
+from helion._testing import is_cuda
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm
 
 
@@ -1112,6 +1113,29 @@ def test_kl_div(self):
             )
         )
 
+    def test_gather_gemv(self):
+        args = (
+            torch.randn([8, 1024, 1024], device=DEVICE, dtype=torch.float32),
+            torch.randint(0, 8, [2], device=DEVICE, dtype=torch.int32),
+            torch.randn([1024], device=DEVICE, dtype=torch.float32),
+        )
+
+        def expected(w, idx, x):
+            return w[idx].to(x.dtype) @ x
+
+        code = check_example(
+            "gather_gemv",
+            args,
+            expected(*args),
+            fn_name="gather_gemv",
+            block_sizes=[16, 16],
+            num_warps=8,
+            num_stages=1,
+        )
+
+        if is_cuda():
+            self.assertExpectedJournal(code)
+
     def test_int4_gemm(self):
         # Matrix dimensions
         M, K, N = 256, 512, 256