diff --git a/benchmarks/run.py b/benchmarks/run.py
index 745414e27..c695dfade 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -61,6 +61,11 @@ class RunResult:
 KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
     # <tritonbench_op_name>: (<tritonbench_module_path>, <helion_module_path>, <helion_kernel_function_name>)
     "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
+    "addmm": (
+        "tritonbench.operators.addmm.operator",
+        "examples.matmul",
+        "addmm_tritonbench",
+    ),
     "ragged_attention": (
         "tritonbench.operators.ragged_attention.operator",
         "examples.jagged_hstu_attn",
diff --git a/examples/matmul.py b/examples/matmul.py
index bffbde7ee..2e9d8c563 100644
--- a/examples/matmul.py
+++ b/examples/matmul.py
@@ -90,9 +90,19 @@ def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
     bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    bias_scalar = torch.randn([1], device="cuda", dtype=torch.float16)
     # Test without bias
     run_example(matmul, torch.matmul, (x, y))
 
+    # Test for addmm with scalar bias
+    def addmm(bias: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
+        m, k = mat1.size()
+        k2, n = mat2.size()
+        bias = torch.broadcast_to(bias, [m, n])
+        return matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
+
+    run_example(addmm, torch.addmm, (bias_scalar, x, y))
+
     # Test with bias
     def helion_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
         return matmul(x, y, lambda acc, tile: acc + bias[tile[1]])
@@ -138,6 +148,23 @@ def matmul_tritonbench(
     return lambda: matmul(a, b)
 
 
+def addmm_tritonbench(bias: Tensor, mat1: Tensor, mat2: Tensor) -> Callable:
+    """
+    Wrapper for tritonbench that performs a matrix multiplication of the matrices
+    `mat1` and `mat2` followed by adding `bias` to the result.
+    Args:
+        bias (torch.Tensor): Bias to add in the epilogue.
+        mat1 (torch.Tensor): Left matrix.
+        mat2 (torch.Tensor): Right matrix.
+    Returns:
+        Callable: A callable that runs the matmul kernel with bias.
+    """
+    m, k = mat1.size()
+    k2, n = mat2.size()
+    bias = torch.broadcast_to(bias, [m, n])
+    return lambda: matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
+
+
 # %%
 def main() -> None:
     """
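
Note: the new addmm_tritonbench wrapper reuses the existing matmul kernel from examples/matmul.py and folds the bias addition into the kernel's epilogue lambda, so no separate addmm kernel is needed. Below is a minimal sketch of the equivalence being exercised, assuming a CUDA device and the matmul kernel defined in this file; shapes and tolerances are illustrative, not taken from the patch:

    import torch
    from examples.matmul import matmul  # Helion kernel used by this diff

    m, k, n = 64, 128, 32
    mat1 = torch.randn([m, k], device="cuda", dtype=torch.float16)
    mat2 = torch.randn([k, n], device="cuda", dtype=torch.float16)
    bias = torch.randn([1], device="cuda", dtype=torch.float16)  # scalar bias, as in check()

    # Broadcast the bias to the output shape, then add it tile-by-tile in the epilogue.
    bias_full = torch.broadcast_to(bias, [m, n])
    out = matmul(mat1, mat2, lambda acc, tile: acc + bias_full[tile[0], tile[1]])

    # The result should match eager addmm within fp16 tolerance.
    torch.testing.assert_close(out, torch.addmm(bias, mat1, mat2), rtol=1e-2, atol=1e-2)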