-
Notifications
You must be signed in to change notification settings - Fork 77
Closed
Labels
matmulmatmul / gemm / mm / bmm / tl.dot / hl.dot related issuesmatmul / gemm / mm / bmm / tl.dot / hl.dot related issues
Description
Is your feature request related to a problem? Please describe.
We are implementing LCE helion kernel to compute (A.T @ B).T = B.T @ A + bias, where A has tensor shape [B, K, M] and B has tensor shape [K, N]. The initial implementation is as follows. But we have the error "NotImplementedError: torch.matmul with different input tensor dims is not supported in Helion kernel".
"""
Helion implementation of matrix multiplication with broadcast support for LCE operations.
"""
from typing import Optional
import helion
import helion.language as hl
import torch
from torch import Tensor
@helion.kernel(
static_shapes=True,
)
def helion_lce_fwd(
A: Tensor,
B: Tensor,
bias: Optional[Tensor] = None,
) -> Tensor:
"""
Forward to calculate LCE (A.T @ B).T = B.T @ A + bias if bias is not None
Args:
A: Input tensor of shape [batch, K, M]
B: Weight matrix of shape [K, N] to be broadcast across batches
bias: Optional bias vector of shape [N, 1]
Returns:
Output tensor of shape [batch, N, M]
"""
batch, K, M = A.size()
K2, N = B.size()
assert K == K2, f"Inner dimension mismatch: A.size(1)={K}, B.size(0)={K2}"
out = torch.empty(
[batch, N, M],
dtype=torch.promote_types(A.dtype, B.dtype),
device=A.device,
)
for tile_b, tile_m, tile_n in hl.tile([batch, M, N]):
acc = hl.zeros([tile_b, tile_n, tile_m], dtype=torch.float32)
for tile_k in hl.tile(K):
a_tile = A[tile_b, tile_k, tile_m] # 3D indexing: [tile_b, tile_k, tile_m]
b_tile = B[tile_k, tile_n].T # 2D: [tile_n, tile_k]
acc = torch.baddbmm(acc, b_tile, a_tile)
if bias is not None:
acc += bias[None, tile_n, :]
out[tile_b, tile_n, tile_m] = acc.to(out.dtype)
return out
def main() -> None:
"""
Main function to test and benchmark the Helion matmul_broadcast kernel.
"""
B, M, N, K = 1152, 1152, 768, 256
dtype = torch.bfloat16
device = "cuda"
print("===== Helion MatMul Broadcast Test =====")
A = torch.randn(B, K, M, device=device, dtype=dtype, requires_grad=False)
weight = torch.randn(K, N, device=device, dtype=dtype, requires_grad=False)
bias = torch.randn(N, 1, device=device, dtype=dtype, requires_grad=False)
print(f"Input shapes: A={A.shape}, weight={weight.shape}, bias={bias.shape}")
C1 = helion_lce_fwd(A, weight, bias)
print(f"Output shape: {C1.shape}")
expected_shape = (B, N, M)
assert C1.shape == expected_shape, f"Expected {expected_shape}, got {C1.shape}"
C1_ref = torch.matmul(weight.T, A) + bias
torch.testing.assert_close(C1, C1_ref, atol=1e-2, rtol=1e-2)
print("✓ Correctness check passed")
if __name__ == "__main__":
main()
Metadata
Metadata
Assignees
Labels
matmulmatmul / gemm / mm / bmm / tl.dot / hl.dot related issuesmatmul / gemm / mm / bmm / tl.dot / hl.dot related issues