Skip to content

NotImplementedError: torch.matmul with different input tensor dims is not supported in Helion kernel #1109

@mengluy0125

Description

@mengluy0125

Is your feature request related to a problem? Please describe.

We are implementing an LCE Helion kernel that computes (A.T @ B).T + bias = B.T @ A + bias, where A has shape [batch, K, M] and B has shape [K, N] (B is shared across the batch). Our initial implementation is shown below, but it fails with the error "NotImplementedError: torch.matmul with different input tensor dims is not supported in Helion kernel".

"""
Helion implementation of matrix multiplication with broadcast support for LCE operations.
"""

from typing import Optional

import helion
import helion.language as hl

import torch
from torch import Tensor


@helion.kernel(
    static_shapes=True,
)
def helion_lce_fwd(
    A: Tensor,
    B: Tensor,
    bias: Optional[Tensor] = None,
) -> Tensor:
    """
    Compute the LCE forward product ``B.T @ A`` (equivalently ``(A.T @ B).T``
    per batch), adding ``bias`` when one is supplied.

    Args:
        A: Input tensor of shape [batch, K, M].
        B: Weight matrix of shape [K, N]; the same matrix is used for every
            batch entry.
        bias: Optional bias of shape [N, 1]; it is indexed as
            ``bias[None, tile_n, :]`` and added to the accumulator.

    Returns:
        Output tensor of shape [batch, N, M] whose dtype is
        ``torch.promote_types(A.dtype, B.dtype)``, allocated on ``A.device``.

    Raises:
        AssertionError: If ``A.size(1)`` and ``B.size(0)`` differ.
    """

    batch, K, M = A.size()
    K2, N = B.size()
    # The contraction runs over K, so both operands must agree on it.
    assert K == K2, f"Inner dimension mismatch: A.size(1)={K}, B.size(0)={K2}"

    out = torch.empty(
        [batch, N, M],
        dtype=torch.promote_types(A.dtype, B.dtype),
        device=A.device,
    )

    # Outer loop: one iteration per tile of the (batch, M, N) output space.
    for tile_b, tile_m, tile_n in hl.tile([batch, M, N]):
        # Accumulate in float32 regardless of the input dtype; the result is
        # cast to out.dtype only when it is stored.
        acc = hl.zeros([tile_b, tile_n, tile_m], dtype=torch.float32)

        # Inner loop: reduce over the shared K dimension tile by tile.
        for tile_k in hl.tile(K):
            a_tile = A[tile_b, tile_k, tile_m]  # 3D: [tile_b, tile_k, tile_m]
            b_tile = B[tile_k, tile_n].T  # 2D after transpose: [tile_n, tile_k]
            # NOTE(review): torch.baddbmm is documented for 3-D batch1/batch2,
            # but b_tile is 2-D here while a_tile is 3-D. This mixed-rank call
            # is presumably what raises the reported "torch.matmul with
            # different input tensor dims is not supported" error -- confirm.
            acc = torch.baddbmm(acc, b_tile, a_tile)

        if bias is not None:
            # [1, tile_n, 1] slice of bias, broadcast against the
            # [tile_b, tile_n, tile_m] accumulator.
            acc += bias[None, tile_n, :]

        out[tile_b, tile_n, tile_m] = acc.to(out.dtype)

    return out


def main() -> None:
    """
    Run a single correctness check of ``helion_lce_fwd`` on CUDA.

    Builds random bfloat16 inputs, runs the kernel, verifies the output
    shape, and compares the values against ``weight.T @ A + bias`` computed
    with plain PyTorch.
    """
    batch_size, m_dim, n_dim, k_dim = 1152, 1152, 768, 256
    # Shared construction options for every random test tensor.
    tensor_opts = {
        "device": "cuda",
        "dtype": torch.bfloat16,
        "requires_grad": False,
    }

    print("===== Helion MatMul Broadcast Test =====")

    # Creation order (activations, weight, bias) is kept so the random
    # stream is consumed in the same sequence.
    activations = torch.randn(batch_size, k_dim, m_dim, **tensor_opts)
    weight = torch.randn(k_dim, n_dim, **tensor_opts)
    bias = torch.randn(n_dim, 1, **tensor_opts)

    print(
        f"Input shapes: A={activations.shape}, weight={weight.shape}, bias={bias.shape}"
    )

    result = helion_lce_fwd(activations, weight, bias)
    print(f"Output shape: {result.shape}")

    expected_shape = (batch_size, n_dim, m_dim)
    assert result.shape == expected_shape, f"Expected {expected_shape}, got {result.shape}"

    # Reference: weight.T ([N, K]) broadcasts over the batch of A ([B, K, M]).
    reference = torch.matmul(weight.T, activations) + bias
    torch.testing.assert_close(result, reference, atol=1e-2, rtol=1e-2)
    print("✓ Correctness check passed")

# Run the test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

Labels

matmul: matmul / gemm / mm / bmm / tl.dot / hl.dot related issues

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions