Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions helion/_compiler/aten_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,39 @@ def codegen_view(ctx: LoweringContext, node: Node) -> object:
return expr_from_string(f"tl.reshape({{tensor}}, {shape_str})", tensor=tensor)


# aten.view.dtype reinterprets the raw bits of a tensor as another dtype
# of the same itemsize; masked-out lanes can simply pass through.
view_dtype_lowering = register_lowering(
    torch.ops.aten.view.dtype,
    masked_value_fn=passthrough_masked_value,
)


@view_dtype_lowering.register_codegen("triton")
def codegen_view_dtype(ctx: LoweringContext, node: Node) -> object:
    """Emit a bit-reinterpreting cast (``tl.cast(..., bitcast=True)``).

    Unlike a value-converting cast, ``bitcast=True`` keeps the underlying
    bytes and only changes how they are typed in the generated Triton code.
    """
    src = map_arg(node.args[0], lambda arg: _env_arg(ctx, arg))
    assert isinstance(src, ast.AST)
    dtype = node.args[1]
    assert isinstance(dtype, torch.dtype)
    # {tensor} is a placeholder substituted by expr_from_string; the target
    # dtype is baked into the template as a Triton type name.
    template = f"tl.cast({{tensor}}, {triton_type(dtype)}, bitcast=True)"
    return expr_from_string(template, tensor=src)


# aten.alias.default creates a view with identical metadata, so masked
# values flow through unchanged.
alias_lowering = register_lowering(
    torch.ops.aten.alias.default,
    masked_value_fn=passthrough_masked_value,
)


@alias_lowering.register_codegen("triton")
def codegen_alias(ctx: LoweringContext, node: Node) -> object:
    """Lower ``aten.alias`` as a pure passthrough (no code is emitted)."""
    aliased = map_arg(node.args[0], lambda arg: _env_arg(ctx, arg))
    # The input must already be a materialized expression in the codegen
    # environment; returning it directly makes the alias a zero-cost view.
    assert isinstance(aliased, ast.AST)
    return aliased


permute_lowering = register_lowering(
torch.ops.aten.permute.default,
masked_value_fn=passthrough_masked_value,
Expand Down
135 changes: 88 additions & 47 deletions test/test_views.expected
Original file line number Diff line number Diff line change
@@ -1,53 +1,6 @@
This file is automatically generated by assertExpectedJournal calls in test_views.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestViews.test_specialize_reshape)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_fn(reshaped, out, _BLOCK_SIZE_2: tl.constexpr):
# src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
num_blocks_0 = 2
num_blocks_1 = 3
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
offset_0 = pid_0
indices_0 = offset_0 + tl.zeros([1], tl.int32)
offset_1 = pid_1
indices_1 = offset_1 + tl.zeros([1], tl.int32)
offset_2 = pid_2 * _BLOCK_SIZE_2
indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
# src[test_views.py:N]: out[tile] = reshaped[tile] + 1
load = tl.load(reshaped + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), None)
v_0 = 1.0
v_1 = load + v_0
tl.store(out + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), v_1, None)

def fn(x: torch.Tensor, chunk_size: int, *, _launcher=_default_launcher):
# src[test_views.py:N]: batch, seqlen = x.shape
batch, seqlen = x.shape
# src[test_views.py:N]: chunk_size = hl.specialize(chunk_size)
chunk_size = 32
# src[test_views.py:N]: nchunks = (seqlen + chunk_size - 1) // chunk_size
nchunks = (seqlen + chunk_size - 1) // chunk_size
# src[test_views.py:N]: reshaped = x.reshape(batch, nchunks, chunk_size)
reshaped = x.reshape(batch, nchunks, chunk_size)
# src[test_views.py:N]: out = torch.empty_like(reshaped)
out = torch.empty_like(reshaped)
# src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
_BLOCK_SIZE_2 = 32
# src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
# src[test_views.py:N]: out[tile] = reshaped[tile] + 1
_launcher(_helion_fn, (2 * 3 * triton.cdiv(32, _BLOCK_SIZE_2),), reshaped, out, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
# src[test_views.py:N]: return out.reshape(batch, seqlen)
return out.reshape(batch, seqlen)

--- assertExpectedJournal(TestViews.test_reshape_sum)
from __future__ import annotations

Expand Down Expand Up @@ -198,6 +151,53 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_views.py:N]: return out
return out

--- assertExpectedJournal(TestViews.test_specialize_reshape)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_fn(reshaped, out, _BLOCK_SIZE_2: tl.constexpr):
# src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
num_blocks_0 = 2
num_blocks_1 = 3
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
offset_0 = pid_0
indices_0 = offset_0 + tl.zeros([1], tl.int32)
offset_1 = pid_1
indices_1 = offset_1 + tl.zeros([1], tl.int32)
offset_2 = pid_2 * _BLOCK_SIZE_2
indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
# src[test_views.py:N]: out[tile] = reshaped[tile] + 1
load = tl.load(reshaped + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), None)
v_0 = 1.0
v_1 = load + v_0
tl.store(out + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), v_1, None)

def fn(x: torch.Tensor, chunk_size: int, *, _launcher=_default_launcher):
# src[test_views.py:N]: batch, seqlen = x.shape
batch, seqlen = x.shape
# src[test_views.py:N]: chunk_size = hl.specialize(chunk_size)
chunk_size = 32
# src[test_views.py:N]: nchunks = (seqlen + chunk_size - 1) // chunk_size
nchunks = (seqlen + chunk_size - 1) // chunk_size
# src[test_views.py:N]: reshaped = x.reshape(batch, nchunks, chunk_size)
reshaped = x.reshape(batch, nchunks, chunk_size)
# src[test_views.py:N]: out = torch.empty_like(reshaped)
out = torch.empty_like(reshaped)
# src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
_BLOCK_SIZE_2 = 32
# src[test_views.py:N]: for tile in hl.tile(reshaped.size()):
# src[test_views.py:N]: out[tile] = reshaped[tile] + 1
_launcher(_helion_fn, (2 * 3 * triton.cdiv(32, _BLOCK_SIZE_2),), reshaped, out, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
# src[test_views.py:N]: return out.reshape(batch, seqlen)
return out.reshape(batch, seqlen)

--- assertExpectedJournal(TestViews.test_squeeze)
from __future__ import annotations

Expand Down Expand Up @@ -433,3 +433,44 @@ def foo(x: torch.Tensor, *, _launcher=_default_launcher):
_launcher(_helion_foo, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, _BLOCK_SIZE_0 // 2, _SHAPE_DIM, num_warps=4, num_stages=1)
# src[test_views.py:N]: return out
return out

--- assertExpectedJournal(TestViews.test_view_dtype_reinterpret)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_view_dtype_kernel(x, out, _BLOCK_SIZE_0: tl.constexpr):
# src[test_views.py:N]: for tile in hl.tile(n):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
# src[test_views.py:N]: val = x[tile]
val = tl.load(x + indices_0 * 1, None)
# src[test_views.py:N]: val_as_int = val.view(dtype=torch.int16)
view = tl.cast(val, tl.int16, bitcast=True)
# src[test_views.py:N]: val_as_int = val_as_int + 1
v_0 = tl.full([], 1, tl.int16)
v_1 = view + v_0
# src[test_views.py:N]: val_back = val_as_int.view(dtype=torch.bfloat16)
view_1 = tl.cast(v_1, tl.bfloat16, bitcast=True)
# src[test_views.py:N]: out[tile] = val_back
tl.store(out + indices_0 * 1, view_1, None)

def view_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_views.py:N]: n = x.size(0)
n = x.size(0)
# src[test_views.py:N]: out = torch.empty_like(x)
out = torch.empty_like(x)
# src[test_views.py:N]: for tile in hl.tile(n):
_BLOCK_SIZE_0 = 32
# src[test_views.py:N]: for tile in hl.tile(n):
# src[test_views.py:N]: val = x[tile]
# src[test_views.py:N]: # View bf16 as int16, add 1 to raw bits, view back as bf16
# src[test_views.py:N-N]: ...
_launcher(_helion_view_dtype_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0),), x, out, _BLOCK_SIZE_0, num_warps=4, num_stages=1)
# src[test_views.py:N]: return out
return out
24 changes: 24 additions & 0 deletions test/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,30 @@ def capture_graph(graph):
)
assert "aten.cat" in self._graph and "aten.stack" not in self._graph

def test_view_dtype_reinterpret(self):
    """Test viewing a tensor with a different dtype (bitcast/reinterpret)."""

    # NOTE(review): the kernel body below is reproduced verbatim in
    # test_views.expected (every source line, including comments, is
    # embedded in the journal) — do not reformat or re-comment it without
    # regenerating the journal via EXPECTTEST_ACCEPT=1.
    @helion.kernel(static_shapes=True)
    def view_dtype_kernel(x: torch.Tensor) -> torch.Tensor:
        # x is bfloat16, view as int16 to access raw bits
        n = x.size(0)
        out = torch.empty_like(x)
        for tile in hl.tile(n):
            val = x[tile]
            # View bf16 as int16, add 1 to raw bits, view back as bf16
            val_as_int = val.view(dtype=torch.int16)
            val_as_int = val_as_int + 1
            val_back = val_as_int.view(dtype=torch.bfloat16)
            out[tile] = val_back
        return out

    # bf16 and int16 share a 2-byte itemsize, so the view is elementwise.
    x = torch.randn(1024, dtype=torch.bfloat16, device=DEVICE)
    code, result = code_and_output(view_dtype_kernel, (x,))
    # Verify that the operation is a bitcast (add 1 to raw bits): the
    # eager reference applies the same view/add/view round trip.
    expected = (x.view(dtype=torch.int16) + 1).view(dtype=torch.bfloat16)
    torch.testing.assert_close(result, expected)
    # Pin the generated Triton source against the expected journal.
    self.assertExpectedJournal(code)


if __name__ == "__main__":
unittest.main()
Loading