From c40fb26b358692d7eb66b6b06f8d925803bb2241 Mon Sep 17 00:00:00 2001
From: Sibylau
Date: Wed, 3 Sep 2025 17:43:18 -0700
Subject: [PATCH 1/3] [Benchmark] add addmm example and test

---
 benchmarks/run.py           | 21 +++++++---
 examples/addmm.py           | 83 +++++++++++++++++++++++++++++++++++++
 test/test_examples.expected | 61 +++++++++++++++++++++++++++
 test/test_examples.py       | 20 ++++++++-
 4 files changed, 177 insertions(+), 8 deletions(-)
 create mode 100644 examples/addmm.py

diff --git a/benchmarks/run.py b/benchmarks/run.py
index 10f93862c..c7b7f70a5 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -61,6 +61,11 @@ class RunResult:
 KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
     # <tritonbench_op>: (<tritonbench_module>, <helion_module>, <helion_func>)
     "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
+    "addmm": (
+        "tritonbench.operators.addmm.operator",
+        "examples.addmm",
+        "addmm",
+    ),
     "embedding": (
         "tritonbench.operators.embedding.operator",
         "examples.embedding",
@@ -89,9 +94,11 @@ class RunResult:
         "tritonbench.operators.jagged_mean.operator",
         "examples.jagged_mean",
         "jagged_mean_tritonbench",
-        {"B": 32, "M": 8, "seqlen": 64}
-        if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
-        else {},
+        (
+            {"B": 32, "M": 8, "seqlen": 64}
+            if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
+            else {}
+        ),
     ),
     "fp8_gemm": (
         "tritonbench.operators.fp8_gemm.fp8_gemm",
@@ -110,9 +117,11 @@ class RunResult:
         "tritonbench.operators.cross_entropy.operator",
         "examples.cross_entropy",
         "cross_entropy",
-        {"B": 4, "T": 512, "v_range": "10,15"}
-        if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
-        else {},
+        (
+            {"B": 4, "T": 512, "v_range": "10,15"}
+            if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
+            else {}
+        ),
     ),
     "fp8_attention": (
         "tritonbench.operators.fp8_attention.operator",
diff --git a/examples/addmm.py b/examples/addmm.py
new file mode 100644
index 000000000..1ec7c81fe
--- /dev/null
+++ b/examples/addmm.py
@@ -0,0 +1,83 @@
+"""
+Helion Addmm Kernel Example
+============================
+This example demonstrates a Helion kernel implementation of matrix multiplication followed by an addition.
+It includes correctness checks against the PyTorch baseline and integration with tritonbench.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+import helion
+from helion._testing import run_example
+import helion.language as hl
+
+
+# %%
+# Addmm Kernel
+# --------------
+@helion.kernel()
+def addmm(a: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
+    """
+    Performs a matrix multiplication of the matrices `mat1` and `mat2` and adds `a` to the result.
+    Args:
+        a (Tensor): Input tensor broadcastable to shape [m, n].
+        mat1 (Tensor): Input tensor of shape [m, k].
+        mat2 (Tensor): Input tensor of shape [k, n].
+    Returns:
+        Tensor: Resulting tensor of shape [m, n].
+    """
+    m, k = mat1.size()
+    k2, n = mat2.size()
+    assert k == k2, f"size mismatch {k} != {k2}"
+    out = torch.empty(
+        [m, n], dtype=torch.promote_types(mat1.dtype, mat2.dtype), device=mat1.device
+    )
+    a = torch.broadcast_to(a, (m, n))
+    for tile_m, tile_n in hl.tile([m, n]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for tile_k in hl.tile(k):
+            acc = torch.addmm(
+                acc,
+                mat1[tile_m, tile_k],
+                mat2[tile_k, tile_n],
+            )
+        out[tile_m, tile_n] = acc + a[tile_m, tile_n]
+    return out
+
+
+# %%
+# Verification Function
+# ---------------------
+def check(m: int, n: int, k: int) -> None:
+    """
+    Verify the addmm kernel implementation against PyTorch's native addmm function.
+
+    Args:
+        m (int): Number of rows in matrix mat1.
+        n (int): Number of columns in matrix mat2.
+        k (int): Number of columns in matrix mat1 and rows in matrix mat2.
+    """
+    a = torch.randn([m], device="cuda", dtype=torch.float16)
+    mat1 = torch.randn([m, k], device="cuda", dtype=torch.float16)
+    mat2 = torch.randn([k, n], device="cuda", dtype=torch.float16)
+    run_example(addmm, torch.addmm, (a, mat1, mat2))
+
+
+# %%
+# Main Function
+# -------------
+def main() -> None:
+    """
+    Main entry point that runs the addmm kernel verification with a 1024x512 tensor and a 512x1024 tensor.
+    """
+    check(1024, 1024, 512)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 86a394f2d..a742c1b9e 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -38,6 +38,67 @@ def add(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_add, (triton.cdiv(x.size(0) * x.size(1), _BLOCK_SIZE_0_1), 1, 1), x, y, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0_1, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_addmm)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_addmm(mat1, mat2, a, out, a_stride_0, a_stride_1, mat1_stride_0, mat1_stride_1, mat2_stride_0, mat2_stride_1, out_stride_0, out_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_pid_m = tl.cdiv(m, _BLOCK_SIZE_0)
+    num_pid_n = tl.cdiv(n, _BLOCK_SIZE_1)
+    inner_2d_pid = tl.program_id(0)
+    num_pid_in_group = 4 * num_pid_n
+    group_id = inner_2d_pid // num_pid_in_group
+    first_pid_m = group_id * 4
+    group_size_m = min(num_pid_m - first_pid_m, 4)
+    pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m
+    pid_1 = inner_2d_pid % num_pid_in_group // group_size_m
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < n
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < k
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(mat1 + (indices_0[:, None] * mat1_stride_0 + indices_2[None, :] * mat1_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        load_1 = tl.load(mat2 + (indices_2[:, None] * mat2_stride_0 + indices_1[None, :] * mat2_stride_1), mask_2[:, None] & mask_1[None, :], other=0)
+        acc = acc_copy_0 + tl.cast(tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), input_precision='tf32'), tl.float32)
+    load_2 = tl.load(a + (indices_0[:, None] * a_stride_0 + indices_1[None, :] * a_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = tl.cast(load_2, tl.float32)
+    v_1 = acc + v_0
+    v_2 = tl.cast(v_1, tl.float16)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_2, mask_0[:, None] & mask_1[None, :])
+
+def addmm(a: Tensor, mat1: Tensor, mat2: Tensor, *, _launcher=_default_launcher):
+    """
+    Performs a matrix multiplication of the matrices `mat1` and `mat2` and adds `a` to the result.
+    Args:
+        a (Tensor): Input tensor broadcastable to shape [m, n].
+        mat1 (Tensor): Input tensor of shape [m, k].
+        mat2 (Tensor): Input tensor of shape [k, n].
+    Returns:
+        Tensor: Resulting tensor of shape [m, n].
+    """
+    m, k = mat1.size()
+    k2, n = mat2.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(mat1.dtype, mat2.dtype), device=mat1.device)
+    a = torch.broadcast_to(a, (m, n))
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _launcher(_helion_addmm, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), mat1, mat2, a, out, a.stride(0), a.stride(1), mat1.stride(0), mat1.stride(1), mat2.stride(0), mat2.stride(1), out.stride(0), out.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_attention_block_pointer)
 from __future__ import annotations
 
diff --git a/test/test_examples.py b/test/test_examples.py
index c8ddf07a9..d3ddfd5fd 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -14,8 +14,8 @@
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm
 
-torch.backends.cuda.matmul.fp32_precision = "tf32"
-torch.backends.cudnn.conv.fp32_precision = "tf32"
+# torch.backends.cuda.matmul.fp32_precision = "tf32"
+# torch.backends.cudnn.conv.fp32_precision = "tf32"
 
 
 class TestExamples(RefEagerTestBase, TestCase):
@@ -30,6 +30,22 @@ def test_add(self):
             )
         )
 
+    def test_addmm(self):
+        args = (
+            torch.randn((1), device=DEVICE, dtype=torch.float16),
+            torch.randn([512, 256], device=DEVICE, dtype=torch.float16),
+            torch.randn([256, 512], device=DEVICE, dtype=torch.float16),
+        )
+        self.assertExpectedJournal(
+            check_example(
+                "addmm",
+                args,
+                torch.addmm(*args),
+                block_sizes=[16, 16, 16],
+                l2_grouping=4,
+            )
+        )
+
     def test_matmul(self):
         args = (
             torch.randn([128, 128], device=DEVICE, dtype=torch.float32),

From c0dc1e9b7935d0c27f44ab4d13ea6b83f885bd90 Mon Sep 17 00:00:00 2001
From: Jie Liu
Date: Wed, 3 Sep 2025 17:54:32 -0700
Subject: [PATCH 2/3] Uncomment tf32 precision for CUDA matmul and conv

---
 test/test_examples.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/test_examples.py b/test/test_examples.py
index d3ddfd5fd..bca061dd0 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -14,8 +14,8 @@
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm
 
-# torch.backends.cuda.matmul.fp32_precision = "tf32"
-# torch.backends.cudnn.conv.fp32_precision = "tf32"
+torch.backends.cuda.matmul.fp32_precision = "tf32"
+torch.backends.cudnn.conv.fp32_precision = "tf32"
 
 
 class TestExamples(RefEagerTestBase, TestCase):

From b1e7ccbe294e746a612755351b3cdd3a5107cffc Mon Sep 17 00:00:00 2001
From: Sibylau
Date: Thu, 4 Sep 2025 12:15:08 -0700
Subject: [PATCH 3/3] [Benchmark] add addmm tritonbench wrapper using Helion
 matmul example

---
 benchmarks/run.py           | 20 ++-----
 examples/addmm.py           | 83 -------------------------------------
 examples/matmul.py          | 27 ++++++++++++
 test/test_examples.expected | 61 ---------------------------
 test/test_examples.py       | 16 -------
 5 files changed, 35 insertions(+), 172 deletions(-)
 delete mode 100644 examples/addmm.py

diff --git a/benchmarks/run.py b/benchmarks/run.py
index c7b7f70a5..463fcfa20 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -63,8 +63,8 @@ class RunResult:
     "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
     "addmm": (
         "tritonbench.operators.addmm.operator",
-        "examples.addmm",
-        "addmm",
+        "examples.matmul",
+        "addmm_tritonbench",
     ),
     "embedding": (
         "tritonbench.operators.embedding.operator",
@@ -94,11 +94,9 @@ class RunResult:
         "tritonbench.operators.jagged_mean.operator",
         "examples.jagged_mean",
         "jagged_mean_tritonbench",
-        (
-            {"B": 32, "M": 8, "seqlen": 64}
-            if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
-            else {}
-        ),
+        {"B": 32, "M": 8, "seqlen": 64}
+        if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
+        else {},
     ),
     "fp8_gemm": (
         "tritonbench.operators.fp8_gemm.fp8_gemm",
@@ -117,11 +115,9 @@ class RunResult:
         "tritonbench.operators.cross_entropy.operator",
         "examples.cross_entropy",
         "cross_entropy",
-        (
-            {"B": 4, "T": 512, "v_range": "10,15"}
-            if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
-            else {}
-        ),
+        {"B": 4, "T": 512, "v_range": "10,15"}
+        if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
+        else {},
     ),
     "fp8_attention": (
         "tritonbench.operators.fp8_attention.operator",
diff --git a/examples/addmm.py b/examples/addmm.py
deleted file mode 100644
index 1ec7c81fe..000000000
--- a/examples/addmm.py
+++ /dev/null
@@ -1,83 +0,0 @@
-"""
-Helion Addmm Kernel Example
-============================
-This example demonstrates a Helion kernel implementation of matrix multiplication followed by an addition.
-It includes correctness checks against the PyTorch baseline and integration with tritonbench.
-"""
-
-# %%
-# Imports
-# -------
-from __future__ import annotations
-
-import torch
-from torch import Tensor
-
-import helion
-from helion._testing import run_example
-import helion.language as hl
-
-
-# %%
-# Addmm Kernel
-# --------------
-@helion.kernel()
-def addmm(a: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
-    """
-    Performs a matrix multiplication of the matrices `mat1` and `mat2` and adds `a` to the result.
-    Args:
-        a (Tensor): Input tensor broadcastable to shape [m, n].
-        mat1 (Tensor): Input tensor of shape [m, k].
-        mat2 (Tensor): Input tensor of shape [k, n].
-    Returns:
-        Tensor: Resulting tensor of shape [m, n].
-    """
-    m, k = mat1.size()
-    k2, n = mat2.size()
-    assert k == k2, f"size mismatch {k} != {k2}"
-    out = torch.empty(
-        [m, n], dtype=torch.promote_types(mat1.dtype, mat2.dtype), device=mat1.device
-    )
-    a = torch.broadcast_to(a, (m, n))
-    for tile_m, tile_n in hl.tile([m, n]):
-        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
-        for tile_k in hl.tile(k):
-            acc = torch.addmm(
-                acc,
-                mat1[tile_m, tile_k],
-                mat2[tile_k, tile_n],
-            )
-        out[tile_m, tile_n] = acc + a[tile_m, tile_n]
-    return out
-
-
-# %%
-# Verification Function
-# ---------------------
-def check(m: int, n: int, k: int) -> None:
-    """
-    Verify the addmm kernel implementation against PyTorch's native addmm function.
-
-    Args:
-        m (int): Number of rows in matrix mat1.
-        n (int): Number of columns in matrix mat2.
-        k (int): Number of columns in matrix mat1 and rows in matrix mat2.
-    """
-    a = torch.randn([m], device="cuda", dtype=torch.float16)
-    mat1 = torch.randn([m, k], device="cuda", dtype=torch.float16)
-    mat2 = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    run_example(addmm, torch.addmm, (a, mat1, mat2))
-
-
-# %%
-# Main Function
-# -------------
-def main() -> None:
-    """
-    Main entry point that runs the addmm kernel verification with a 1024x512 tensor and a 512x1024 tensor.
-    """
-    check(1024, 1024, 512)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/matmul.py b/examples/matmul.py
index bffbde7ee..2e9d8c563 100644
--- a/examples/matmul.py
+++ b/examples/matmul.py
@@ -90,9 +90,19 @@ def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
     bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    bias_scalar = torch.randn([1], device="cuda", dtype=torch.float16)
 
     # Test without bias
     run_example(matmul, torch.matmul, (x, y))
 
+    # Test for addmm with scalar bias
+    def addmm(bias: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
+        m, k = mat1.size()
+        k2, n = mat2.size()
+        bias = torch.broadcast_to(bias, [m, n])
+        return matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
+
+    run_example(addmm, torch.addmm, (bias_scalar, x, y))
+
     # Test with bias
     def helion_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
         return matmul(x, y, lambda acc, tile: acc + bias[tile[1]])
@@ -138,6 +148,23 @@ def matmul_tritonbench(
     return lambda: matmul(a, b)
 
 
+def addmm_tritonbench(bias: Tensor, mat1: Tensor, mat2: Tensor) -> Callable:
+    """
+    Wrapper for tritonbench that performs a matrix multiplication of the matrices
+    `mat1` and `mat2` followed by adding `bias` to the result.
+    Args:
+        bias (torch.Tensor): Bias to add in the epilogue.
+        mat1 (torch.Tensor): Left matrix.
+        mat2 (torch.Tensor): Right matrix.
+    Returns:
+        Callable: A callable that runs the matmul kernel with bias.
+    """
+    m, k = mat1.size()
+    k2, n = mat2.size()
+    bias = torch.broadcast_to(bias, [m, n])
+    return lambda: matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
+
+
 # %%
 def main() -> None:
     """
diff --git a/test/test_examples.expected b/test/test_examples.expected
index a742c1b9e..86a394f2d 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -38,67 +38,6 @@ def add(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_add, (triton.cdiv(x.size(0) * x.size(1), _BLOCK_SIZE_0_1), 1, 1), x, y, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0_1, num_warps=4, num_stages=3)
     return out
 
---- assertExpectedJournal(TestExamples.test_addmm)
-from __future__ import annotations
-
-import torch
-import triton
-import triton.language as tl
-from helion.runtime import default_launcher as _default_launcher
-
-@triton.jit
-def _helion_addmm(mat1, mat2, a, out, a_stride_0, a_stride_1, mat1_stride_0, mat1_stride_1, mat2_stride_0, mat2_stride_1, out_stride_0, out_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
-    num_pid_m = tl.cdiv(m, _BLOCK_SIZE_0)
-    num_pid_n = tl.cdiv(n, _BLOCK_SIZE_1)
-    inner_2d_pid = tl.program_id(0)
-    num_pid_in_group = 4 * num_pid_n
-    group_id = inner_2d_pid // num_pid_in_group
-    first_pid_m = group_id * 4
-    group_size_m = min(num_pid_m - first_pid_m, 4)
-    pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m
-    pid_1 = inner_2d_pid % num_pid_in_group // group_size_m
-    offset_0 = pid_0 * _BLOCK_SIZE_0
-    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
-    mask_0 = indices_0 < m
-    offset_1 = pid_1 * _BLOCK_SIZE_1
-    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
-    mask_1 = indices_1 < n
-    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
-    for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2):
-        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
-        mask_2 = indices_2 < k
-        acc_copy = acc
-        acc_copy_0 = acc_copy
-        load = tl.load(mat1 + (indices_0[:, None] * mat1_stride_0 + indices_2[None, :] * mat1_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
-        load_1 = tl.load(mat2 + (indices_2[:, None] * mat2_stride_0 + indices_1[None, :] * mat2_stride_1), mask_2[:, None] & mask_1[None, :], other=0)
-        acc = acc_copy_0 + tl.cast(tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), input_precision='tf32'), tl.float32)
-    load_2 = tl.load(a + (indices_0[:, None] * a_stride_0 + indices_1[None, :] * a_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-    v_0 = tl.cast(load_2, tl.float32)
-    v_1 = acc + v_0
-    v_2 = tl.cast(v_1, tl.float16)
-    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_2, mask_0[:, None] & mask_1[None, :])
-
-def addmm(a: Tensor, mat1: Tensor, mat2: Tensor, *, _launcher=_default_launcher):
-    """
-    Performs a matrix multiplication of the matrices `mat1` and `mat2` and adds `a` to the result.
-    Args:
-        a (Tensor): Input tensor broadcastable to shape [m, n].
-        mat1 (Tensor): Input tensor of shape [m, k].
-        mat2 (Tensor): Input tensor of shape [k, n].
-    Returns:
-        Tensor: Resulting tensor of shape [m, n].
-    """
-    m, k = mat1.size()
-    k2, n = mat2.size()
-    assert k == k2, f'size mismatch {k} != {k2}'
-    out = torch.empty([m, n], dtype=torch.promote_types(mat1.dtype, mat2.dtype), device=mat1.device)
-    a = torch.broadcast_to(a, (m, n))
-    _BLOCK_SIZE_0 = 16
-    _BLOCK_SIZE_1 = 16
-    _BLOCK_SIZE_2 = 16
-    _launcher(_helion_addmm, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), mat1, mat2, a, out, a.stride(0), a.stride(1), mat1.stride(0), mat1.stride(1), mat2.stride(0), mat2.stride(1), out.stride(0), out.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
-    return out
-
 --- assertExpectedJournal(TestExamples.test_attention_block_pointer)
 from __future__ import annotations
 
diff --git a/test/test_examples.py b/test/test_examples.py
index bca061dd0..c8ddf07a9 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -30,22 +30,6 @@ def test_add(self):
             )
         )
 
-    def test_addmm(self):
-        args = (
-            torch.randn((1), device=DEVICE, dtype=torch.float16),
-            torch.randn([512, 256], device=DEVICE, dtype=torch.float16),
-            torch.randn([256, 512], device=DEVICE, dtype=torch.float16),
-        )
-        self.assertExpectedJournal(
-            check_example(
-                "addmm",
-                args,
-                torch.addmm(*args),
-                block_sizes=[16, 16, 16],
-                l2_grouping=4,
-            )
-        )
-
     def test_matmul(self):
         args = (
             torch.randn([128, 128], device=DEVICE, dtype=torch.float32),
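
Usage note (not part of the patches above): a minimal sketch of how the addmm_tritonbench wrapper from PATCH 3/3 composes with the matmul kernel's epilogue hook. It assumes a CUDA device, that the repository's examples package is importable, and illustrative shapes; the epilogue-lambda form mirrors what the patch shows inside examples/matmul.py.

import torch

from examples.matmul import addmm_tritonbench, matmul

m, k, n = 1024, 512, 1024
bias = torch.randn([1], device="cuda", dtype=torch.float16)  # scalar bias, broadcast to [m, n]
mat1 = torch.randn([m, k], device="cuda", dtype=torch.float16)
mat2 = torch.randn([k, n], device="cuda", dtype=torch.float16)

# The wrapper returns a zero-argument callable, matching tritonbench's calling convention.
bench_fn = addmm_tritonbench(bias, mat1, mat2)
result = bench_fn()

# The same computation expressed directly: matmul's epilogue receives the fp32
# accumulator and the current output tile indices, and adds the broadcast bias tile.
expanded = torch.broadcast_to(bias, [m, n])
direct = matmul(mat1, mat2, lambda acc, tile: acc + expanded[tile[0], tile[1]])

# fp16 inputs with an fp32 accumulator: compare against torch.addmm with loose tolerances.
torch.testing.assert_close(result, torch.addmm(bias, mat1, mat2), rtol=1e-2, atol=1e-2)
torch.testing.assert_close(direct, result, rtol=1e-2, atol=1e-2)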