From 9cc96bd95b1de574a67b9ee1c2158a16c41122fd Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Tue, 23 Sep 2025 14:23:01 -0700
Subject: [PATCH 1/2] Fix misaligned address error for matmul

---
 helion/_compiler/indexing_strategy.py |  6 ++-
 test/test_tensor_descriptor.py        | 65 +++++++++++++++++++++++++++
 2 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py
index f26cc333f..9e53d0559 100644
--- a/helion/_compiler/indexing_strategy.py
+++ b/helion/_compiler/indexing_strategy.py
@@ -216,8 +216,10 @@ def valid_block_size(
             if fake_tensor.ndim == 2 and block_size < threshold:
                 return False
 
-            # was getting some IMAs with small block sizes even in non-stride 1 dims
-            return block_size * element_size >= 16 or (block_size == 1 and stride != 1)
+            # Tensor-descriptor path (TMA + WGMMA / stmatrix writes)
+            # moves data in 16-byte chunks. Enforce a 16-byte minimum so the
+            # generated stores stay aligned and avoid misaligned-address errors.
+            return block_size * element_size >= 16
 
         # 4) Check minimum 16 bytes in each dimension
         sizes = fake_tensor.size()

diff --git a/test/test_tensor_descriptor.py b/test/test_tensor_descriptor.py
index 39b0fdba0..54d0d0f1f 100644
--- a/test/test_tensor_descriptor.py
+++ b/test/test_tensor_descriptor.py
@@ -198,6 +198,71 @@ def kernel_different_blocks(x: torch.Tensor) -> torch.Tensor:
         # The block sizes should also be permuted in the tensor descriptor
         # This is important for correctness
 
+    @unittest.skipUnless(
+        supports_tensor_descriptor(), "Tensor descriptor support is required"
+    )
+    def test_tiny_matmul_tile_fallback(self) -> None:
+        """Tensor descriptor indexing should be rejected when the tile is too small."""
+
+        @helion.kernel(
+            config=helion.Config(
+                block_sizes=[1, 16, 16],
+                indexing="tensor_descriptor",
+                l2_groupings=[2],
+                loop_orders=[[0, 1]],
+                num_stages=4,
+                num_warps=1,
+                pid_type="persistent_blocked",
+                range_flattens=[True, True],
+                range_multi_buffers=[False, True],
+                range_num_stages=[0, 1],
+                range_unroll_factors=[0, 4],
+            ),
+            static_shapes=True,
+        )
+        def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            k2, n = y.size()
+            assert k == k2
+            out = torch.empty(
+                [m, n],
+                dtype=torch.promote_types(x.dtype, y.dtype),
+                device=x.device,
+            )
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+                out[tile_m, tile_n] = acc.to(out.dtype)
+            return out
+
+        x = torch.randn((64, 64), device=DEVICE, dtype=torch.float16)
+        y = torch.randn((64, 64), device=DEVICE, dtype=torch.float16)
+
+        code, result = code_and_output(matmul, (x, y))
+        torch.cuda.synchronize()
+        expected = torch.matmul(x, y)
+        torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)
+
+        # Ensure we fall back to pointer indexing for accesses that would use the
+        # 1x16 tile; there should be no tensor descriptor for the x or out tensors.
+        self.assertNotIn("x_desc = tl.make_tensor_descriptor", code)
+        self.assertNotIn("out_desc = tl.make_tensor_descriptor", code)
+        # Both of y's tile dimensions (K and N) are 16 elements wide, so the
+        # y operand can keep using tensor descriptors.
+ self.assertIn("y_desc = tl.make_tensor_descriptor", code) + + # A larger tile should still be able to use tensor descriptors + code_large, result_large = code_and_output( + matmul, + (x, y), + block_sizes=[16, 16, 16], + indexing="tensor_descriptor", + ) + torch.cuda.synchronize() + torch.testing.assert_close(result_large, expected, atol=1e-2, rtol=1e-2) + self.assertIn(get_tensor_descriptor_fn_name(), code_large) + @unittest.skipUnless( supports_tensor_descriptor(), "Tensor descriptor support is required" ) From 0a54b5aefd688f8bc491bdac40ff075995b863c8 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Tue, 23 Sep 2025 14:59:11 -0700 Subject: [PATCH 2/2] update expected --- test/test_tensor_descriptor.expected | 46 ++++++++++++---------------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/test/test_tensor_descriptor.expected b/test/test_tensor_descriptor.expected index 3a1bbbb90..63f2aea9d 100644 --- a/test/test_tensor_descriptor.expected +++ b/test/test_tensor_descriptor.expected @@ -5,32 +5,27 @@ Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environmen from __future__ import annotations import torch -import helion import triton import triton.language as tl from torch._inductor.runtime import triton_helpers from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher -helion.runtime.set_triton_allocator() - @triton.jit -def _helion_attention(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2, out_size_0, out_size_1, q_in_size_1, q_view_size_0, q_view_size_1, v_view_size_0, v_view_size_1, k_view_stride_0, k_view_stride_1, k_view_stride_2, out_stride_0, out_stride_1, out_stride_2, q_view_stride_0, q_view_stride_1, q_view_stride_2, v_view_stride_0, v_view_stride_1, v_view_stride_2, m_dim, n_dim, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): - q_view_desc = tl.make_tensor_descriptor(q_view, [q_view_size_0, q_view_size_1, 64], [q_view_stride_0, q_view_stride_1, q_view_stride_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64]) - k_view_desc = tl.make_tensor_descriptor(k_view, [k_view_size_0, k_view_size_2, 64], [k_view_stride_0, k_view_stride_2, k_view_stride_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64]) - v_view_desc = tl.make_tensor_descriptor(v_view, [v_view_size_0, v_view_size_1, 64], [v_view_stride_0, v_view_stride_1, v_view_stride_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64]) - out_desc = tl.make_tensor_descriptor(out, [out_size_0, out_size_1, 64], [out_stride_0, out_stride_1, out_stride_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64]) +def _helion_attention(q_view, k_view, v_view, out, q_in_size_1, k_view_stride_0, k_view_stride_1, k_view_stride_2, out_stride_0, out_stride_1, out_stride_2, q_view_stride_0, q_view_stride_1, q_view_stride_2, v_view_stride_0, v_view_stride_1, v_view_stride_2, m_dim, n_dim, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): num_blocks_0 = q_in_size_1 pid_0 = tl.program_id(0) % num_blocks_0 pid_1 = tl.program_id(0) // num_blocks_0 offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) offset_1 = pid_1 * _BLOCK_SIZE_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < m_dim + indices_4 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32) m_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], float('-inf'), tl.float32) l_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 1.0, tl.float32) acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64], 0.0, tl.float32) - q = 
+    q = tl.load(q_view + (indices_0[:, None, None] * q_view_stride_0 + indices_1[None, :, None] * q_view_stride_1 + indices_4[None, None, :] * q_view_stride_2), mask_1[None, :, None], other=0)
     for offset_2 in tl.range(0, n_dim.to(tl.int32), _BLOCK_SIZE_3):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
         mask_3 = indices_2 < n_dim
@@ -42,7 +37,7 @@ def _helion_attention(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
         m_i_copy_0 = m_i_copy
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
-        k = tl.permute(k_view_desc.load([offset_0, offset_2, 0]), [0, 2, 1])
+        k = tl.load(k_view + (indices_0[:, None, None] * k_view_stride_0 + indices_4[None, :, None] * k_view_stride_1 + indices_2[None, None, :] * k_view_stride_2), mask_3[None, None, :], other=0)
         qk = tl.reshape(tl.dot(tl.reshape(tl.cast(q_copy_0, tl.float32), [_BLOCK_SIZE_1, 64]), tl.reshape(tl.cast(k, tl.float32), [64, _BLOCK_SIZE_3]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         _mask_to_2 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), qk, tl.full([], float('-inf'), tl.float32))
         amax = tl.cast(tl.max(_mask_to_2, 2), tl.float32)
@@ -62,12 +57,12 @@ def _helion_attention(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
         l_i = v_9 + l_ij
         subscript_1 = v_8[:, :, None]
         v_11 = acc_copy_0 * subscript_1
-        v = v_view_desc.load([offset_0, offset_2, 0])
+        v = tl.load(v_view + (indices_0[:, None, None] * v_view_stride_0 + indices_2[None, :, None] * v_view_stride_1 + indices_4[None, None, :] * v_view_stride_2), mask_3[None, :, None], other=0)
         acc = tl.reshape(tl.dot(tl.reshape(tl.cast(_mask_to_3, tl.float32), [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(tl.cast(v, tl.float32), [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
     m_i = v_2
     subscript_2 = l_i[:, :, None]
     v_12 = acc / subscript_2
-    out_desc.store([offset_0, offset_1, 0], v_12)
+    tl.store(out + (indices_0[:, None, None] * out_stride_0 + indices_1[None, :, None] * out_stride_1 + indices_4[None, None, :] * out_stride_2), v_12, mask_1[None, :, None])
 
 def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _launcher=_default_launcher):
     """
@@ -93,39 +88,37 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     _BLOCK_SIZE_1 = 16
+    _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 16
-    _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())
 
--- assertExpectedJournal(TestTensorDescriptor.test_attention_tensor_descriptor)
 from __future__ import annotations
 
 import torch
-import helion
 import triton
 import triton.language as tl
 from torch._inductor.runtime import triton_helpers
 from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher
 
-helion.runtime.set_triton_allocator()
-
 @triton.jit
-def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
-    q_view_desc = tl.make_tensor_descriptor(q_view, [64, 1024, 64], [65536, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
-    k_view_desc = tl.make_tensor_descriptor(k_view, [64, 512, 64], [32768, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64])
-    v_view_desc = tl.make_tensor_descriptor(v_view, [64, 512, 64], [32768, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_3, 64])
-    out_desc = tl.make_tensor_descriptor(out, [64, 1024, 64], [65536, 64, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
+def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
     num_blocks_0 = 64
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0
     offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
     offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    indices_4 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
     m_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], float('-inf'), tl.float32)
     l_i = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 1.0, tl.float32)
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64], 0.0, tl.float32)
-    q = q_view_desc.load([offset_0, offset_1, 0])
+    q = tl.load(q_view + (indices_0[:, None, None] * 65536 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
     for offset_2 in tl.range(0, 512, _BLOCK_SIZE_3):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
         q_copy = q
         m_i_copy = m_i
         l_i_copy = l_i
@@ -134,7 +127,7 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         m_i_copy_0 = m_i_copy
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
-        k = tl.permute(k_view_desc.load([offset_0, offset_2, 0]), [0, 2, 1])
+        k = tl.load(k_view + (indices_0[:, None, None] * 32768 + indices_4[None, :, None] * 1 + indices_2[None, None, :] * 64), None)
         qk = tl.reshape(tl.dot(tl.reshape(tl.cast(q_copy_0, tl.float16), [_BLOCK_SIZE_1, 64]), tl.reshape(tl.cast(k, tl.float16), [64, _BLOCK_SIZE_3]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         amax = tl.cast(tl.max(qk, 2), tl.float16)
         v_0 = 0.18033688
@@ -154,14 +147,14 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         l_i = v_11 + l_ij
         subscript_1 = v_10[:, :, None]
         v_13 = acc_copy_0 * subscript_1
-        v = v_view_desc.load([offset_0, offset_2, 0])
+        v = tl.load(v_view + (indices_0[:, None, None] * 32768 + indices_2[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
         v_14 = tl.cast(v_8, tl.float16)
         acc = tl.reshape(tl.dot(tl.reshape(tl.cast(v_14, tl.float16), [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(tl.cast(v, tl.float16), [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_13, [_BLOCK_SIZE_1, 64]), input_precision='tf32', out_dtype=tl.float32), [_BLOCK_SIZE_0, _BLOCK_SIZE_1, 64])
     m_i = v_3
     subscript_2 = l_i[:, :, None]
     v_15 = acc / subscript_2
     v_16 = tl.cast(v_15, tl.float16)
-    out_desc.store([offset_0, offset_1, 0], v_16)
+    tl.store(out + (indices_0[:, None, None] * 65536 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), v_16, None)
 
 def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _launcher=_default_launcher):
     """
@@ -187,6 +180,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     _BLOCK_SIZE_1 = 128
+    _RDIM_SIZE_2 = 64
    _BLOCK_SIZE_3 = 64
-    _launcher(_helion_attention, (64 * triton.cdiv(1024, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    _launcher(_helion_attention, (64 * triton.cdiv(1024, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())
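
Note: the rule the first hunk enforces can be sanity-checked in isolation. Below is a minimal standalone sketch; block_dim_ok is a hypothetical name, not Helion's API, and Helion's real check additionally special-cases 2-D tensors via the threshold seen above.

    # A tile dimension may use the tensor-descriptor (TMA) path only if its
    # block spans at least 16 bytes; smaller blocks can produce misaligned
    # stores, since TMA moves data in 16-byte chunks.
    def block_dim_ok(block_size: int, element_size: int) -> bool:
        return block_size * element_size >= 16

    # For the float16 matmul in the new test (element_size == 2):
    assert not block_dim_ok(1, 2)   # 1x16 tiles of x/out -> fall back to pointer indexing
    assert block_dim_ok(16, 2)      # 16x16 tile of y -> tensor descriptor stays enabled
    assert block_dim_ok(8, 2)       # 16 bytes exactly meets the minimum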