From 0379412204b61e28c2d9f51a42c504672e9af64a Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 22 May 2024 19:02:49 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- torch/_inductor/codegen/common.py | 5 ++++- torch/_inductor/codegen/simd.py | 20 +++++++++++++++++--- torch/_inductor/codegen/triton.py | 23 ++++++++--------------- torch/_inductor/virtualized.py | 8 ++++++++ 4 files changed, 37 insertions(+), 19 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 20d5f200a393b..6ae7a503de05d 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -120,7 +120,7 @@ def device_guard(self, device_idx): # https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 def register_backend_for_device( device: str, - device_scheduling: type, + device_scheduling: Any, device_wrapper_codegen: type, device_cpp_wrapper_codegen: type = type(None), ): @@ -1140,6 +1140,9 @@ def __eq__(self, other) -> bool: def update_on_args(self, name, args, kwargs): pass + def __repr__(self): + return f"{self.__class__.__name__}({self.name!r})" + class CppWrapperKernelArgs(KernelArgs): def wrap_ptr_arg(self, buf, dtype): diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 1d9ebce334f24..8010ee0e80857 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -50,10 +50,11 @@ sympy_subs, unique, ) -from ..virtualized import V +from ..virtualized import ops, OpsValue, V from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter from .multi_kernel import MultiKernel + if TYPE_CHECKING: pass @@ -320,7 +321,7 @@ def __eq__(self, other): return self.name == other.name -def triton_constant(value): +def constant_repr(value): if value == float("inf"): return 'float("inf")' elif value == float("-inf"): @@ -862,8 +863,9 @@ def mask_loads(self, mask): """Context manager to add an additional mask to tl.load/store""" prior = self._load_mask if prior: - mask = self.cse.generate(self.compute, f"{mask} & {prior}") + mask = ops.logical_and(mask, prior) + mask = OpsValue.unwrap(mask) self._load_mask = mask try: # TODO(jansel): do we need a reshape here? @@ -1043,6 +1045,18 @@ def warn_mix_layout(self, kernel_name): ) log.warning(msg) + def welford_reduce_fallback(self, dtype, value): + sum_ = ops.reduction(dtype, dtype, "sum", value) + self.inside_reduction = False + rnumel = ops.index_expr(self.numels[-1], dtype) + mean = ops.truediv(sum_, rnumel) + + self.inside_reduction = True + dx = ops.sub(value, mean) + dx2 = ops.mul(dx, dx) + m2 = ops.reduction(dtype, dtype, "sum", dx2) + return OpsValue.unwrap((mean, m2, rnumel)) + def codegen_kernel(self): raise NotImplementedError diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index c4b85472a6eb4..62177b6fbe094 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -47,12 +47,12 @@ TensorArg, ) from .simd import ( + constant_repr, IndexingOptions, IterationRangesEntry, pexpr, SIMDKernel, SIMDScheduling, - triton_constant, ) from .triton_utils import config_of, signature_of, signature_to_meta @@ -492,7 +492,7 @@ def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): @staticmethod def _shaped_constant(value, dtype, shape): type_ = torch._prims_common.dtype_to_type(dtype) - triton_val = triton_constant(type_(value)) + triton_val = constant_repr(type_(value)) triton_type = triton_compute_type(dtype) if triton_type == "tl.float32": @@ -866,7 +866,7 @@ def masked(mask, body, other): # Take dtype from result to prevent accidental promotion other = V.kernel.cse.generate( V.kernel.compute, - f"tl.full({result}.shape, {triton_constant(other)}, {result}.dtype)", + f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)", bounds=ValueRanges.wrap(other), ) return ops.where(new_mask, result, other) @@ -1342,7 +1342,7 @@ def where_cond(tval, fval): if self.persistent_reduction: default = ir.Reduction.default_value(reduction_type, src_dtype) - default = self._map_tuple_or_scalar(triton_constant, default) + default = self._map_tuple_or_scalar(constant_repr, default) def _mask_value(value, default): return self.cse.generate(self.compute, where_cond(value, default)) @@ -1367,16 +1367,7 @@ def _mask_value(value, default): # For persistent reductions, don't bother with # welford's algorithm since it uses more registers, and # taking two reductions doesn't increase memory usage. - sum_ = ops.reduction(dtype, dtype, "sum", value) - self.inside_reduction = False - rnumel = ops.index_expr(self.numels[-1], dtype) - mean = ops.truediv(sum_, rnumel) - - self.inside_reduction = True - dx = ops.sub(value, mean) - dx2 = ops.mul(dx, dx) - m2 = ops.reduction(dtype, dtype, "sum", dx2) - result_var = (mean, m2, rnumel) + result_var = self.welford_reduce_fallback(dtype, value) elif reduction_type == "welford_combine": mean, m2, weight = masked_value welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})" @@ -1394,7 +1385,7 @@ def _mask_value(value, default): else: accumulator = f"_{result_var}" default = ir.Reduction.default_accumulator(reduction_type, src_dtype) - default = self._map_tuple_or_scalar(triton_constant, default) + default = self._map_tuple_or_scalar(constant_repr, default) if not isinstance(default, tuple): self.body.writeline( f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})" @@ -1501,8 +1492,10 @@ def _mask_value(value, default): self.cse.reduction_cache[cache_key] = result_var if isinstance(result_var, tuple): + assert all(isinstance(x, TritonCSEVariable) for x in result_var) self.outside_loop_vars |= set(result_var) else: + assert isinstance(result_var, TritonCSEVariable) self.outside_loop_vars.add(result_var) return result_var diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 07c6ea8190a61..4362b3fecfe87 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -248,6 +248,14 @@ def __rshfit__(self, n): def __lshift__(self, n): return ops.bitwise_left_shift(self, n) + @staticmethod + def unwrap(x): + if isinstance(x, OpsValue): + return x.value + if isinstance(x, (list, tuple)): + return x.__class__(map(OpsValue.unwrap, x)) + return x + class OpsWrapper: """This wraps any returned IR values into an `OpsValue` instance, so that we