Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
jansel committed Jun 19, 2024
2 parents 2779d26 + 4fd2646 commit 695e91b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 79 deletions.
18 changes: 18 additions & 0 deletions benchmarks/dynamo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3123,6 +3123,12 @@ def parse_args(args=None):
parser.add_argument(
"--freezing", action="store_true", help="turn on freezing", default=False
)
parser.add_argument(
"--inductor-config",
"-c",
action="append",
help="key=value in torch._inductor.config",
)
parser.add_argument(
"--ci", action="store_true", help="Flag to tell that its a CI run"
)
Expand Down Expand Up @@ -4025,6 +4031,18 @@ def run(runner, args, original_dir=None):
inductor_config.triton.divisible_by_16 = not args.disable_divisible_by_16
if args.inference:
inductor_config.freezing = args.freezing
if args.inductor_config:
for config in args.inductor_config:
key, value = config.split("=")
typ = type(inductor_config.__getattr__(key))
if issubclass(typ, bool):
assert value in ("0", "1", "True", "False")
value = value in ("1", "True")
elif issubclass(typ, (str, int, float)):
value = typ(value)
else:
raise NotImplementedError(typ)
inductor_config.__setattr__(key, value)

runner.setup_amp()

Expand Down
8 changes: 5 additions & 3 deletions torch/_inductor/codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3014,7 +3014,6 @@ def _codegen_glue(cls, meta, headerfile):
),
buffers=buffers,
buffer_names=", ".join(buffer_names),
cuda_device=meta.cuda_device,
)
return glue_code

Expand Down Expand Up @@ -3204,8 +3203,11 @@ def build_standalone_runtime(cls):
afile = str(dirpath / "standalone_halide_runtime.a")
sofile = str(dirpath / libname)
if not os.path.exists(donefile):
import filelock
import halide as hl # type: ignore[import-untyped]
try:
import filelock
import halide as hl # type: ignore[import-untyped]
except ImportError as e:
raise RuntimeError("requires halide/filelock") from e

with filelock.FileLock(lockfile, LOCK_TIMEOUT):
if not os.path.exists(donefile):
Expand Down
77 changes: 1 addition & 76 deletions torch/_inductor/codegen/halide.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from .. import config, ir
from ..codecache import HalideCodeCache
from ..metrics import is_metric_table_enabled, log_kernel_metadata
from ..ops_handler import MockHandler

from ..runtime.hints import HalideInputSpec, HalideMeta, ReductionHint
from ..utils import (
Expand Down Expand Up @@ -1023,7 +1022,6 @@ def load(self, name: str, index: sympy.Expr):
line = f"{var}[{index_str},]" # trailing comma workaround for https://github.com/halide/Halide/issues/8299
dtype = V.graph.get_dtype(name)
if dtype in (torch.float16, torch.bfloat16):
dtype = torch.float32
line = f"hl.cast(hl.Float(32), {line})"

if self._load_mask:
Expand Down Expand Up @@ -1093,6 +1091,7 @@ def reduction(
"""Codegen a reduction operation"""
assert self.inside_reduction
assert not self._load_mask
assert isinstance(value, HalideCSEVariable) and value.used_dims is not None

cache_key = (src_dtype, reduction_type, value)
if cache_key in self.cse.reduction_cache:
Expand Down Expand Up @@ -1194,78 +1193,6 @@ def welford_combine_impl(self, mean, m2, weight):
self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]")
return tuple(unpacked)

def scan(
    self,
    dtypes: Tuple[torch.dtype, ...],
    combine_fn: Callable[
        [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
    ],
    values_orig: Tuple[CSEVariable, ...],
) -> Tuple[CSEVariable, ...]:
    """Codegen an associative scan (prefix reduction) over the reduction axis.

    Emits Halide source that seeds a Func with the per-element initial
    values, then updates it sequentially over an ``hl.RDom`` starting at 1,
    combining each element with its predecessor via ``combine_fn``.

    Args:
        dtypes: one accumulator dtype per scanned value (same length as
            ``values_orig``).
        combine_fn: builds the combine expression from two tuples of
            operand strings (left = previous position, right = current).
        values_orig: the values to scan; must be HalideCSEVariables with
            known used_dims.

    Returns:
        One HalideCSEVariable per input value holding the scan result.
    """
    assert self.inside_reduction
    assert len(dtypes) == len(values_orig)
    values: List[HalideCSEVariable] = []
    all_used_dims = set()
    for value in values_orig:
        assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
        if value.used_dims and value.used_dims[-1] == "rindex":
            values.append(value)
        else:
            # The scan runs along "rindex"; re-materialize any value whose
            # trailing dim isn't rindex so every operand is indexed by it.
            values.append(self.genfunc(f"{value}", [*value.used_dims, "rindex"]))
        all_used_dims.update(value.used_dims)
    # Order the result's dims to match the kernel's range-tree order.
    used_dims = [
        tree.name for tree in self.range_trees if tree.name in all_used_dims
    ]
    result_var = self.newfunc(used_dims)
    assert result_var.used_dims and result_var.used_dims[-1] == "rindex"
    # All dims except the scanned (last) one are carried through unchanged.
    prefix = result_var.used_dims[:-1]
    # Element 0 of the scan is just the input cast to the accumulator type.
    initial = [
        f"hl.cast({halide_acc_type(dtype)}, {value})"
        for dtype, value in zip(dtypes, values)
    ]

    # Reduction domain [1, numel): position 0 is handled by `initial`.
    length = self.kexpr(self.rename_indexing(self.range_trees[-1].numel))
    scan_dom = self.genfunc(f"hl.RDom([hl.Range(1, {length})])", [])
    scan = f"{scan_dom}.x"

    if len(values) == 1:
        # Single value: no hl.Tuple wrapper and no [i] element indexing.
        def maybe_tuple(x):
            return x[0]

        read_left = [result_var.index_str([*prefix, f"{scan} - 1"])]
        read_right = [result_var.index_str([*prefix, scan])]
    else:
        # Multiple values scan together as one hl.Tuple-valued Func;
        # individual operands are read back via [i] indexing.
        def maybe_tuple(x):
            return f"hl.Tuple([{', '.join(x)}])"

        read_left = [
            result_var.index_str([*prefix, f"{scan} - 1"]) + f"[{i}]"
            for i in range(len(values))
        ]
        read_right = [
            result_var.index_str([*prefix, scan]) + f"[{i}]"
            for i in range(len(values))
        ]

    # Pure definition: seed every position with its own (cast) input value.
    self.body.writeline(f"{result_var} = {maybe_tuple(initial)}")

    # Disable CSE for update fn
    # (MockHandler makes combine_fn produce raw expression strings instead
    # of routing through the CSE cache, so the update is emitted inline.)
    with V.set_ops_handler(HalideOverrides(MockHandler())):
        combine_str = combine_fn(read_left, read_right)  # type: ignore[arg-type]
    # Update definition over the RDom: out[r] = combine(out[r-1], out[r]).
    self.body.writeline(
        f"{result_var.index_str([*prefix, scan])} = {maybe_tuple(combine_str)}"
    )

    if len(values) == 1:
        return (result_var,)

    # Unpack the tuple-valued Func into one scalar Func per scanned value.
    unpack_vars = [self.newfunc(used_dims) for _ in values]
    for i, v in enumerate(unpack_vars):
        self.body.writeline(f"{v} = {result_var}[{i}]")
    return tuple(unpack_vars)

def genfunc(
self, line, used_dims, *, bounds=ValueRanges.unknown()
) -> HalideCSEVariable:
Expand Down Expand Up @@ -1566,8 +1493,6 @@ def get_backend_features(cls, device: torch.device):
BackendFeature.PREFER_STORE_LOOP_ORDER,
]
)
if config.halide.scan_kernels:
result[BackendFeature.SCAN] = None
return result

def define_kernel(self, src_code, node_schedule, kernel):
Expand Down

0 comments on commit 695e91b

Please sign in to comment.