diff --git a/helion/_compiler/aten_lowering.py b/helion/_compiler/aten_lowering.py
index eff835c10..12f622b9e 100644
--- a/helion/_compiler/aten_lowering.py
+++ b/helion/_compiler/aten_lowering.py
@@ -221,6 +221,62 @@ def codegen_view(ctx: LoweringContext, node: Node) -> object:
     return expr_from_string(f"tl.reshape({{tensor}}, {shape_str})", tensor=tensor)
 
 
+view_dtype_lowering = register_lowering(
+    torch.ops.aten.view.dtype,
+    masked_value_fn=passthrough_masked_value,
+)
+
+
+@view_dtype_lowering.register_codegen("triton")
+def codegen_view_dtype(ctx: LoweringContext, node: Node) -> object:
+    """Generate tl.cast with bitcast=True for dtype reinterpretation.
+
+    Args:
+        ctx: Lowering context passed to _env_arg to resolve the input.
+        node: The aten.view.dtype node; args[0] is the input tensor and
+            args[1] is the target torch.dtype.
+
+    Returns:
+        The expression built by expr_from_string for
+        tl.cast(<tensor>, <triton dtype>, bitcast=True).
+    """
+    tensor = map_arg(node.args[0], lambda arg: _env_arg(ctx, arg))
+    assert isinstance(tensor, ast.AST)
+    target_dtype = node.args[1]
+    assert isinstance(target_dtype, torch.dtype)
+    # NOTE(review): nothing here checks that the source and target dtypes have
+    # the same bit width; presumably tl.cast(bitcast=True) rejects a mismatch
+    # later -- confirm.
+    # The doubled braces leave a literal {tensor} placeholder in the f-string;
+    # its value is handed to expr_from_string through the tensor= keyword.
+    return expr_from_string(
+        f"tl.cast({{tensor}}, {triton_type(target_dtype)}, bitcast=True)",
+        tensor=tensor,
+    )
+
+
+alias_lowering = register_lowering(
+    torch.ops.aten.alias.default,
+    masked_value_fn=passthrough_masked_value,
+)
+
+
+@alias_lowering.register_codegen("triton")
+def codegen_alias(ctx: LoweringContext, node: Node) -> object:
+    """Alias is a no-op view, just pass through the input tensor.
+
+    Args:
+        ctx: Lowering context passed to _env_arg to resolve the input.
+        node: The aten.alias.default node; args[0] is the aliased tensor.
+
+    Returns:
+        The resolved input expression, unchanged.
+    """
+    tensor = map_arg(node.args[0], lambda arg: _env_arg(ctx, arg))
+    assert isinstance(tensor, ast.AST)
+    return tensor
+
+
 permute_lowering = register_lowering(
     torch.ops.aten.permute.default,
     masked_value_fn=passthrough_masked_value,
diff --git a/test/test_views.expected b/test/test_views.expected
index f1ccf0ec8..251761dca 100644
--- a/test/test_views.expected
+++ b/test/test_views.expected
@@ -1,53 +1,6 @@
 This file is automatically generated by assertExpectedJournal calls in test_views.py.
 Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
 
---- assertExpectedJournal(TestViews.test_specialize_reshape) -from __future__ import annotations - -import torch -import triton -import triton.language as tl -from helion.runtime import default_launcher as _default_launcher - -@triton.jit -def _helion_fn(reshaped, out, _BLOCK_SIZE_2: tl.constexpr): - # src[test_views.py:N]: for tile in hl.tile(reshaped.size()): - num_blocks_0 = 2 - num_blocks_1 = 3 - pid_0 = tl.program_id(0) % num_blocks_0 - pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1 - pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) - offset_0 = pid_0 - indices_0 = offset_0 + tl.zeros([1], tl.int32) - offset_1 = pid_1 - indices_1 = offset_1 + tl.zeros([1], tl.int32) - offset_2 = pid_2 * _BLOCK_SIZE_2 - indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) - # src[test_views.py:N]: out[tile] = reshaped[tile] + 1 - load = tl.load(reshaped + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), None) - v_0 = 1.0 - v_1 = load + v_0 - tl.store(out + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), v_1, None) - -def fn(x: torch.Tensor, chunk_size: int, *, _launcher=_default_launcher): - # src[test_views.py:N]: batch, seqlen = x.shape - batch, seqlen = x.shape - # src[test_views.py:N]: chunk_size = hl.specialize(chunk_size) - chunk_size = 32 - # src[test_views.py:N]: nchunks = (seqlen + chunk_size - 1) // chunk_size - nchunks = (seqlen + chunk_size - 1) // chunk_size - # src[test_views.py:N]: reshaped = x.reshape(batch, nchunks, chunk_size) - reshaped = x.reshape(batch, nchunks, chunk_size) - # src[test_views.py:N]: out = torch.empty_like(reshaped) - out = torch.empty_like(reshaped) - # src[test_views.py:N]: for tile in hl.tile(reshaped.size()): - _BLOCK_SIZE_2 = 32 - # src[test_views.py:N]: for tile in hl.tile(reshaped.size()): - # src[test_views.py:N]: out[tile] = reshaped[tile] + 1 - _launcher(_helion_fn, (2 * 3 * triton.cdiv(32, 
_BLOCK_SIZE_2),), reshaped, out, _BLOCK_SIZE_2, num_warps=4, num_stages=1) - # src[test_views.py:N]: return out.reshape(batch, seqlen) - return out.reshape(batch, seqlen) - --- assertExpectedJournal(TestViews.test_reshape_sum) from __future__ import annotations @@ -198,6 +151,53 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_views.py:N]: return out return out +--- assertExpectedJournal(TestViews.test_specialize_reshape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_fn(reshaped, out, _BLOCK_SIZE_2: tl.constexpr): + # src[test_views.py:N]: for tile in hl.tile(reshaped.size()): + num_blocks_0 = 2 + num_blocks_1 = 3 + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1 + pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) + offset_1 = pid_1 + indices_1 = offset_1 + tl.zeros([1], tl.int32) + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + # src[test_views.py:N]: out[tile] = reshaped[tile] + 1 + load = tl.load(reshaped + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), None) + v_0 = 1.0 + v_1 = load + v_0 + tl.store(out + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), v_1, None) + +def fn(x: torch.Tensor, chunk_size: int, *, _launcher=_default_launcher): + # src[test_views.py:N]: batch, seqlen = x.shape + batch, seqlen = x.shape + # src[test_views.py:N]: chunk_size = hl.specialize(chunk_size) + chunk_size = 32 + # src[test_views.py:N]: nchunks = (seqlen + chunk_size - 1) // chunk_size + nchunks = (seqlen + chunk_size - 1) // chunk_size + # src[test_views.py:N]: reshaped = x.reshape(batch, nchunks, chunk_size) + reshaped = 
x.reshape(batch, nchunks, chunk_size) + # src[test_views.py:N]: out = torch.empty_like(reshaped) + out = torch.empty_like(reshaped) + # src[test_views.py:N]: for tile in hl.tile(reshaped.size()): + _BLOCK_SIZE_2 = 32 + # src[test_views.py:N]: for tile in hl.tile(reshaped.size()): + # src[test_views.py:N]: out[tile] = reshaped[tile] + 1 + _launcher(_helion_fn, (2 * 3 * triton.cdiv(32, _BLOCK_SIZE_2),), reshaped, out, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_views.py:N]: return out.reshape(batch, seqlen) + return out.reshape(batch, seqlen) + --- assertExpectedJournal(TestViews.test_squeeze) from __future__ import annotations @@ -433,3 +433,44 @@ def foo(x: torch.Tensor, *, _launcher=_default_launcher): _launcher(_helion_foo, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_0 // 2, _SHAPE_DIM, num_warps=4, num_stages=1) # src[test_views.py:N]: return out return out + +--- assertExpectedJournal(TestViews.test_view_dtype_reinterpret) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_view_dtype_kernel(x, out, _BLOCK_SIZE_0: tl.constexpr): + # src[test_views.py:N]: for tile in hl.tile(n): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + # src[test_views.py:N]: val = x[tile] + val = tl.load(x + indices_0 * 1, None) + # src[test_views.py:N]: val_as_int = val.view(dtype=torch.int16) + view = tl.cast(val, tl.int16, bitcast=True) + # src[test_views.py:N]: val_as_int = val_as_int + 1 + v_0 = tl.full([], 1, tl.int16) + v_1 = view + v_0 + # src[test_views.py:N]: val_back = val_as_int.view(dtype=torch.bfloat16) + view_1 = tl.cast(v_1, tl.bfloat16, bitcast=True) + # src[test_views.py:N]: out[tile] = val_back + tl.store(out + indices_0 * 1, view_1, None) + +def view_dtype_kernel(x: torch.Tensor, *, 
_launcher=_default_launcher):
+    # src[test_views.py:N]: n = x.size(0)
+    n = x.size(0)
+    # src[test_views.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_views.py:N]: for tile in hl.tile(n):
+    _BLOCK_SIZE_0 = 32
+    # src[test_views.py:N]: for tile in hl.tile(n):
+    # src[test_views.py:N]:     val = x[tile]
+    # src[test_views.py:N]:     # View bf16 as int16, add 1 to raw bits, view back as bf16
+    # src[test_views.py:N-N]: ...
+    _launcher(_helion_view_dtype_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
+    # src[test_views.py:N]: return out
+    return out
diff --git a/test/test_views.py b/test/test_views.py
index d91c7a6da..9b797c54b 100644
--- a/test/test_views.py
+++ b/test/test_views.py
@@ -466,6 +466,36 @@ def capture_graph(graph):
         )
         assert "aten.cat" in self._graph and "aten.stack" not in self._graph
 
+    def test_view_dtype_reinterpret(self):
+        """Test viewing a tensor with a different dtype (bitcast/reinterpret).
+
+        The kernel views bf16 values as int16, adds 1 to the raw bits, and views
+        the result back as bf16; the output must match eager PyTorch bit for bit.
+        """
+
+        @helion.kernel(static_shapes=True)
+        def view_dtype_kernel(x: torch.Tensor) -> torch.Tensor:
+            # x is bfloat16, view as int16 to access raw bits
+            n = x.size(0)
+            out = torch.empty_like(x)
+            for tile in hl.tile(n):
+                val = x[tile]
+                # View bf16 as int16, add 1 to raw bits, view back as bf16
+                val_as_int = val.view(dtype=torch.int16)
+                val_as_int = val_as_int + 1
+                val_back = val_as_int.view(dtype=torch.bfloat16)
+                out[tile] = val_back
+            return out
+
+        x = torch.randn(1024, dtype=torch.bfloat16, device=DEVICE)
+        code, result = code_and_output(view_dtype_kernel, (x,))
+        # Verify that the operation is a bitcast (add 1 to raw bits). Compare the
+        # int16 bit patterns exactly: the default bf16 tolerances of assert_close
+        # exceed one ULP, so a float comparison would pass even without the +1.
+        expected_bits = x.view(dtype=torch.int16) + 1
+        torch.testing.assert_close(result.view(dtype=torch.int16), expected_bits)
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()