Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions benchmarks/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ class RunResult:
"examples.welford",
"welford",
),
"gather_gemv": (
"tritonbench.operators.gather_gemv.operator",
"examples.gather_gemv",
"gather_gemv_tritonbench",
),
"int4_gemm": (
"tritonbench.operators.int4_gemm.int4_gemm",
"examples.int4_gemm",
Expand Down Expand Up @@ -271,6 +276,14 @@ class RunResult:
"helion_kl_div_tritonbench-speedup": "helion_speedup",
"helion_kl_div_tritonbench-accuracy": "helion_accuracy",
},
"gather_gemv": {
"test_0-speedup": "triton_speedup",
"test_0-accuracy": "triton_accuracy",
"test_inductor-speedup": "torch_compile_speedup",
"test_inductor-accuracy": "torch_compile_accuracy",
"helion_gather_gemv_tritonbench-speedup": "helion_speedup",
"helion_gather_gemv_tritonbench-accuracy": "helion_accuracy",
},
"int4_gemm": {
"triton_int4_gemm-speedup": "triton_speedup",
"triton_int4_gemm-accuracy": "triton_accuracy",
Expand Down
140 changes: 140 additions & 0 deletions examples/gather_gemv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""
Helion Gather GEMV Kernel Example
=================================
This example demonstrates a Helion kernel implementation of a gather operation
followed by general matrix-vector multiplication (GEMV). The operation is:
w[idx].to(x.dtype) @ x, where w is a 3D tensor, idx contains indices to gather,
and x is a vector.

Based on the tritonbench gather_gemv operator that is motivated by Mixtral performance
where gather + gemv is the primary kernel.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch import Tensor

import helion
from helion._testing import run_example
import helion.language as hl

if TYPE_CHECKING:
from collections.abc import Callable


# %%
# Gather GEMV Kernel
# ------------------
@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def gather_gemv(w: Tensor, idx: Tensor, x: Tensor) -> Tensor:
"""
Performs a gather operation on w using idx, then matrix-vector multiplication with x.

Args:
w (Tensor): Weight matrix of shape [B, S, S] where B is batch size, S is sequence length.
idx (Tensor): Index tensor of shape [N] containing indices to gather from dimension 0 of w.
x (Tensor): Vector of shape [S] to multiply with the gathered matrices.

Returns:
Tensor: Result of shape [N, S] where each row i is w[idx[i]] @ x.
"""
B, S1, S2 = w.size()
N = idx.size(0)
S = x.size(0)
assert S1 == S2, f"Weight matrix must be square, got {S1} != {S2}"
assert S == S1, f"Vector size {S} must match matrix size {S1}"

# Rearrange shapes for matrix-vector multiplication
w_view = w.contiguous().view(B * S, S).to(x.dtype) # Shape: [N, S, S]
x = x.view(S, 1)

# Create output tensor
out = torch.empty([N * S, 1], dtype=x.dtype, device=x.device)

# Perform matrix-vector multiplication for each gathered matrix
for tile_n_s in hl.tile(N * S):
acc = hl.zeros([tile_n_s, 1], dtype=torch.float32)
idx_id = tile_n_s.index // S
idx_gather = idx[idx_id]
for tile_k in hl.tile(S):
# Matrix-vector multiplication
gathered = w_view[idx_gather * S + tile_n_s.index % S, tile_k]
acc += hl.dot(gathered, x[tile_k, :])
out[tile_n_s, :] = acc

return out.contiguous().view(N, S)


# %%
# Verification Function
# ---------------------
def check(B: int, S: int, N: int) -> None:
"""
Verify the gather_gemv kernel implementation against PyTorch's baseline.

Args:
B (int): Batch size for weight matrix.
S (int): Sequence length (matrix size).
N (int): Number of indices to gather.
"""
# Create test tensors matching tritonbench format
w = torch.randn((B, S, S), device="cuda:0", dtype=torch.float16)
idx = torch.randint(0, B, [N], device="cuda:0", dtype=torch.int32)
x = torch.randn((S), device="cuda:0", dtype=torch.float16)

def baseline_gather_gemv(w: Tensor, idx: Tensor, x: Tensor) -> Tensor:
"""PyTorch baseline implementation."""
outputs = []
for idx_val in idx.tolist():
outputs.append(w[idx_val].to(x.dtype) @ x)
return torch.stack(outputs, dim=0)

run_example(gather_gemv, baseline_gather_gemv, (w, idx, x))


# %%
# Tritonbench Integration
# -----------------------
def gather_gemv_tritonbench(
tb_op: object, w: Tensor, idx: Tensor, x: Tensor
) -> Callable:
"""
Wrapper for tritonbench that matches its interface.

Args:
w (Tensor): Weight matrix of shape [B, S, S].
idx (Tensor): Index tensor of shape [N].
x (Tensor): Vector of shape [S].

Returns:
Callable: A callable that runs the gather_gemv kernel.
"""
return lambda: gather_gemv(w, idx, x)


# %%
# Main Function
# -------------
def main() -> None:
"""
Main entry point that runs the gather_gemv kernel verification.
Uses sizes similar to tritonbench for consistency.
"""
# Test with sizes from tritonbench
B = 8 # Batch size, could be number of experts in MoE
N = 2 # Number of indices, experts selected
for i in range(11, 15):
S = 2**i
print(f"Testing with B={B}, S={S}, N={N}")
check(B, S, N)


# %%
if __name__ == "__main__":
main()
70 changes: 70 additions & 0 deletions test/test_examples.expected
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,76 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
_launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestExamples.test_gather_gemv)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_gather_gemv(out, idx, w_view, x, out_size_0, idx_stride_0, out_stride_0, w_view_stride_0, w_view_stride_1, x_stride_0, S1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < out_size_0
acc = tl.full([_BLOCK_SIZE_0, 1], 0.0, tl.float32)
v_0 = tl.cast(S1, tl.int32)
v_1 = tl.where((indices_0 < 0) != (v_0 < 0), tl.where(indices_0 % v_0 != 0, indices_0 // v_0 - 1, indices_0 // v_0), indices_0 // v_0)
idx_gather = tl.load(idx + v_1 * idx_stride_0, mask_0, other=0)
for offset_1 in tl.range(0, S1.to(tl.int32), _BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < S1
idx_gather_copy = idx_gather
acc_copy = acc
idx_gather_copy_0 = idx_gather_copy
acc_copy_0 = acc_copy
v_2 = tl.cast(S1, tl.int32)
v_3 = idx_gather_copy_0 * v_2
v_4 = tl.cast(S1, tl.int32)
v_5 = indices_0 % v_4
v_6 = tl.full([], 0, tl.int32)
v_7 = v_5 != v_6
v_8 = libdevice.signbit(v_5) != 0 if v_5.dtype is tl.float32 else v_5 < 0
v_9 = libdevice.signbit(v_4) != 0 if v_4.dtype is tl.float32 else v_4 < 0
v_10 = v_8 != v_9
v_11 = v_7 & v_10
v_12 = v_5 + v_4
v_13 = tl.where(v_11, v_12, v_5)
v_14 = v_3 + v_13
gathered = tl.load(w_view + (v_14[:, None] * w_view_stride_0 + indices_1[None, :] * w_view_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
load_1 = tl.load(x + indices_1[:, None] * x_stride_0, mask_1[:, None], other=0)
dot = tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.dot(tl.cast(gathered, tl.float32), tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]))), [0, 2, 1]), [16, 8]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]))), [0, 2, 1]), [16, 8]))), [0, 2, 1]), [16, 16]), input_precision='tf32', out_dtype=tl.float32), [16, 2, 8]), [0, 2, 1]))[0], [16, 2, 4]), [0, 2, 1]))[0], [16, 2, 2]), [0, 2, 1]))[0], [16, 2, 1]), [0, 2, 1]))[0]
acc = acc_copy_0 + dot
tl.store(out + indices_0[:, None] * out_stride_0, acc, mask_0[:, None])

def gather_gemv(w: Tensor, idx: Tensor, x: Tensor, *, _launcher=_default_launcher):
"""
Performs a gather operation on w using idx, then matrix-vector multiplication with x.

Args:
w (Tensor): Weight matrix of shape [B, S, S] where B is batch size, S is sequence length.
idx (Tensor): Index tensor of shape [N] containing indices to gather from dimension 0 of w.
x (Tensor): Vector of shape [S] to multiply with the gathered matrices.

Returns:
Tensor: Result of shape [N, S] where each row i is w[idx[i]] @ x.
"""
B, S1, S2 = w.size()
N = idx.size(0)
S = x.size(0)
assert S1 == S2, f'Weight matrix must be square, got {S1} != {S2}'
assert S == S1, f'Vector size {S} must match matrix size {S1}'
w_view = w.contiguous().view(B * S, S).to(x.dtype)
x = x.view(S, 1)
out = torch.empty([N * S, 1], dtype=x.dtype, device=x.device)
_BLOCK_SIZE_0 = 16
_BLOCK_SIZE_1 = 16
_launcher(_helion_gather_gemv, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, idx, w_view, x, out.size(0), idx.stride(0), out.stride(0), w_view.stride(0), w_view.stride(1), x.stride(0), S1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=8, num_stages=1)
return out.contiguous().view(N, S)

--- assertExpectedJournal(TestExamples.test_geglu)
from __future__ import annotations

Expand Down
24 changes: 24 additions & 0 deletions test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from helion._testing import TestCase
from helion._testing import check_example
from helion._testing import import_path
from helion._testing import is_cuda
from helion._testing import skipIfRefEager
from helion._testing import skipIfRocm

Expand Down Expand Up @@ -1112,6 +1113,29 @@ def test_kl_div(self):
)
)

def test_gather_gemv(self):
args = (
torch.randn([8, 1024, 1024], device=DEVICE, dtype=torch.float32),
torch.randint(0, 8, [2], device=DEVICE, dtype=torch.int32),
torch.randn([1024], device=DEVICE, dtype=torch.float32),
)

def expected(w, idx, x):
return w[idx].to(x.dtype) @ x

code = check_example(
"gather_gemv",
args,
expected(*args),
fn_name="gather_gemv",
block_sizes=[16, 16],
num_warps=8,
num_stages=1,
)

if is_cuda():
self.assertExpectedJournal(code)

def test_int4_gemm(self):
# Matrix dimensions
M, K, N = 256, 512, 256
Expand Down
Loading