
Commit

Update
[ghstack-poisoned]
jansel committed Jun 19, 2024
1 parent c5139c9 commit 62caf66
Showing 1 changed file with 75 additions and 0 deletions.
torch/_inductor/codegen/halide.py (75 additions, 0 deletions)
@@ -32,6 +32,7 @@
from .. import config, ir
from ..codecache import HalideCodeCache
from ..metrics import is_metric_table_enabled, log_kernel_metadata
from ..ops_handler import MockHandler

from ..runtime.hints import HalideInputSpec, HalideMeta, ReductionHint
from ..utils import (
@@ -1193,6 +1194,78 @@ def welford_combine_impl(self, mean, m2, weight):
self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]")
return tuple(unpacked)

def scan(
self,
dtypes: Tuple[torch.dtype, ...],
combine_fn: Callable[
[Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
],
values_orig: Tuple[CSEVariable, ...],
) -> Tuple[CSEVariable, ...]:
assert self.inside_reduction
assert len(dtypes) == len(values_orig)
values: List[HalideCSEVariable] = []
all_used_dims = set()
for value in values_orig:
assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
if value.used_dims and value.used_dims[-1] == "rindex":
values.append(value)
else:
values.append(self.genfunc(f"{value}", [*value.used_dims, "rindex"]))
all_used_dims.update(value.used_dims)
used_dims = [
tree.name for tree in self.range_trees if tree.name in all_used_dims
]
result_var = self.newfunc(used_dims)
assert result_var.used_dims and result_var.used_dims[-1] == "rindex"
prefix = result_var.used_dims[:-1]
initial = [
f"hl.cast({halide_acc_type(dtype)}, {value})"
for dtype, value in zip(dtypes, values)
]

length = self.kexpr(self.rename_indexing(self.range_trees[-1].numel))
scan_dom = self.genfunc(f"hl.RDom([hl.Range(1, {length})])", [])
scan = f"{scan_dom}.x"

if len(values) == 1:

def maybe_tuple(x):
return x[0]

read_left = [result_var.index_str([*prefix, f"{scan} - 1"])]
read_right = [result_var.index_str([*prefix, scan])]
else:

def maybe_tuple(x):
return f"hl.Tuple([{', '.join(x)}])"

read_left = [
result_var.index_str([*prefix, f"{scan} - 1"]) + f"[{i}]"
for i in range(len(values))
]
read_right = [
result_var.index_str([*prefix, scan]) + f"[{i}]"
for i in range(len(values))
]

self.body.writeline(f"{result_var} = {maybe_tuple(initial)}")

# Disable CSE for update fn
with V.set_ops_handler(HalideOverrides(MockHandler())):
combine_str = combine_fn(read_left, read_right) # type: ignore[arg-type]
self.body.writeline(
f"{result_var.index_str([*prefix, scan])} = {maybe_tuple(combine_str)}"
)

if len(values) == 1:
return (result_var,)

unpack_vars = [self.newfunc(used_dims) for _ in values]
for i, v in enumerate(unpack_vars):
self.body.writeline(f"{v} = {result_var}[{i}]")
return tuple(unpack_vars)

def genfunc(
self, line, used_dims, *, bounds=ValueRanges.unknown()
) -> HalideCSEVariable:
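Note: the scan hunk above emits a Halide RDom update that performs an inclusive, associative prefix combine along the reduction ("rindex") dimension: the result Func is initialized to the cast input values, then updated at each reduction index r as combine(result[r - 1], result[r]). A minimal plain-Python sketch of the combine_fn contract follows; the names ref_scan and add_combine are hypothetical and only illustrate the semantics, they are not part of this commit.

# Illustrative sketch of the contract scan() expects from combine_fn: an
# associative binary operator applied as an inclusive prefix combine.
from typing import Callable, List, Tuple


def ref_scan(
    combine_fn: Callable[[Tuple[float, ...], Tuple[float, ...]], Tuple[float, ...]],
    values: Tuple[List[float], ...],
) -> Tuple[List[float], ...]:
    # Inclusive scan: out[i] = combine(out[i - 1], values[i]), with out[0] = values[0].
    # This mirrors the RDom update above, where read_left is result[scan - 1]
    # and read_right is result[scan] (still holding the initial value).
    n = len(values[0])
    out: List[List[float]] = [list(v) for v in values]
    for i in range(1, n):
        combined = combine_fn(
            tuple(o[i - 1] for o in out), tuple(v[i] for v in values)
        )
        for o, c in zip(out, combined):
            o[i] = c
    return tuple(out)


def add_combine(left: Tuple[float, ...], right: Tuple[float, ...]) -> Tuple[float, ...]:
    # Cumulative-sum combine: elementwise add of running total and next input.
    return tuple(a + b for a, b in zip(left, right))


assert ref_scan(add_combine, ([1.0, 2.0, 3.0, 4.0],)) == ([1.0, 3.0, 6.0, 10.0],)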
@@ -1493,6 +1566,8 @@ def get_backend_features(cls, device: torch.device):
BackendFeature.PREFER_STORE_LOOP_ORDER,
]
)
if config.halide.scan_kernels:
result[BackendFeature.SCAN] = None
return result

def define_kernel(self, src_code, node_schedule, kernel):
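Usage sketch (an assumption, not part of this commit): scan codegen is reached through ops that lower to an associative scan, such as torch.cumsum, once the Halide backend is selected and the halide.scan_kernels option referenced above is enabled. The backend-selection knob below is assumed from the surrounding Halide-backend stack and may differ.

# Hypothetical end-to-end exercise of the new SCAN backend feature.
import torch
import torch._inductor.config as inductor_config

inductor_config.cpu_backend = "halide"       # assumed knob selecting the Halide backend
inductor_config.halide.scan_kernels = True   # flag gated in get_backend_features above


@torch.compile
def running_total(x):
    # cumsum lowers to an associative scan, which the Halide backend can now emit
    return torch.cumsum(x, dim=-1)


print(running_total(torch.arange(8, dtype=torch.float32)))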
