
Commit

Update
[ghstack-poisoned]
jansel committed May 23, 2024
1 parent 7e87d8e commit 0379412
Showing 4 changed files with 37 additions and 19 deletions.
5 changes: 4 additions & 1 deletion torch/_inductor/codegen/common.py
@@ -120,7 +120,7 @@ def device_guard(self, device_idx):
 # https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
 def register_backend_for_device(
     device: str,
-    device_scheduling: type,
+    device_scheduling: Any,
     device_wrapper_codegen: type,
     device_cpp_wrapper_codegen: type = type(None),
 ):
@@ -1140,6 +1140,9 @@ def __eq__(self, other) -> bool:
     def update_on_args(self, name, args, kwargs):
         pass
 
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.name!r})"
+
 
 class CppWrapperKernelArgs(KernelArgs):
     def wrap_ptr_arg(self, buf, dtype):
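
The __repr__ added here (to what appears to be CSEVariable, given the update_on_args/__eq__ context) follows the usual ClassName('name') convention, so codegen variables print readably in logs and debuggers instead of as bare object addresses. A minimal sketch of the same pattern on a stand-in class (Var is hypothetical, not the real CSEVariable):

class Var:
    """Stand-in for a named code-gen variable; only the repr pattern matters here."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name!r})"


print(Var("tmp3"))                  # Var('tmp3')
print([Var("tmp0"), Var("tmp1")])   # [Var('tmp0'), Var('tmp1')]
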
20 changes: 17 additions & 3 deletions torch/_inductor/codegen/simd.py
@@ -50,10 +50,11 @@
     sympy_subs,
     unique,
 )
-from ..virtualized import V
+from ..virtualized import ops, OpsValue, V
 from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter
 from .multi_kernel import MultiKernel
 
+
 if TYPE_CHECKING:
     pass
 
@@ -320,7 +321,7 @@ def __eq__(self, other):
         return self.name == other.name
 
 
-def triton_constant(value):
+def constant_repr(value):
     if value == float("inf"):
         return 'float("inf")'
     elif value == float("-inf"):
@@ -862,8 +863,9 @@ def mask_loads(self, mask):
         """Context manager to add an additional mask to tl.load/store"""
         prior = self._load_mask
         if prior:
-            mask = self.cse.generate(self.compute, f"{mask} & {prior}")
+            mask = ops.logical_and(mask, prior)
 
+        mask = OpsValue.unwrap(mask)
         self._load_mask = mask
         try:
             # TODO(jansel): do we need a reshape here?
@@ -1043,6 +1045,18 @@ def warn_mix_layout(self, kernel_name):
             )
             log.warning(msg)
 
+    def welford_reduce_fallback(self, dtype, value):
+        sum_ = ops.reduction(dtype, dtype, "sum", value)
+        self.inside_reduction = False
+        rnumel = ops.index_expr(self.numels[-1], dtype)
+        mean = ops.truediv(sum_, rnumel)
+
+        self.inside_reduction = True
+        dx = ops.sub(value, mean)
+        dx2 = ops.mul(dx, dx)
+        m2 = ops.reduction(dtype, dtype, "sum", dx2)
+        return OpsValue.unwrap((mean, m2, rnumel))
+
     def codegen_kernel(self):
         raise NotImplementedError
 
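
welford_reduce_fallback produces the same (mean, m2, weight) triple that a fused Welford reduction would, but via two ordinary "sum" reductions: mean = sum(x) / rnumel, then m2 = sum((x - mean)^2), with rnumel as the weight. The ops.* calls above build Inductor IR rather than compute numbers, so here is a rough numerical sketch of the same two-pass math on plain Python floats (the input data is made up):

def two_pass_mean_m2(values):
    # First reduction: sum, divided by the reduction numel to get the mean.
    n = len(values)
    mean = sum(values) / n
    # Second reduction: sum of squared deviations from that mean.
    m2 = sum((v - mean) ** 2 for v in values)
    return mean, m2, n


mean, m2, weight = two_pass_mean_m2([1.0, 2.0, 4.0, 8.0])
print(mean, m2, weight)   # 3.75 28.75 4
print(m2 / weight)        # 7.1875, the population variance
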
23 changes: 8 additions & 15 deletions torch/_inductor/codegen/triton.py
@@ -47,12 +47,12 @@
     TensorArg,
 )
 from .simd import (
+    constant_repr,
     IndexingOptions,
     IterationRangesEntry,
     pexpr,
     SIMDKernel,
     SIMDScheduling,
-    triton_constant,
 )
 from .triton_utils import config_of, signature_of, signature_to_meta
 
@@ -492,7 +492,7 @@ def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
     @staticmethod
     def _shaped_constant(value, dtype, shape):
         type_ = torch._prims_common.dtype_to_type(dtype)
-        triton_val = triton_constant(type_(value))
+        triton_val = constant_repr(type_(value))
         triton_type = triton_compute_type(dtype)
 
         if triton_type == "tl.float32":
@@ -866,7 +866,7 @@ def masked(mask, body, other):
         # Take dtype from result to prevent accidental promotion
         other = V.kernel.cse.generate(
             V.kernel.compute,
-            f"tl.full({result}.shape, {triton_constant(other)}, {result}.dtype)",
+            f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)",
             bounds=ValueRanges.wrap(other),
         )
         return ops.where(new_mask, result, other)
@@ -1342,7 +1342,7 @@ def where_cond(tval, fval):
 
         if self.persistent_reduction:
             default = ir.Reduction.default_value(reduction_type, src_dtype)
-            default = self._map_tuple_or_scalar(triton_constant, default)
+            default = self._map_tuple_or_scalar(constant_repr, default)
 
             def _mask_value(value, default):
                 return self.cse.generate(self.compute, where_cond(value, default))
@@ -1367,16 +1367,7 @@ def _mask_value(value, default):
                 # For persistent reductions, don't bother with
                 # welford's algorithm since it uses more registers, and
                 # taking two reductions doesn't increase memory usage.
-                sum_ = ops.reduction(dtype, dtype, "sum", value)
-                self.inside_reduction = False
-                rnumel = ops.index_expr(self.numels[-1], dtype)
-                mean = ops.truediv(sum_, rnumel)
-
-                self.inside_reduction = True
-                dx = ops.sub(value, mean)
-                dx2 = ops.mul(dx, dx)
-                m2 = ops.reduction(dtype, dtype, "sum", dx2)
-                result_var = (mean, m2, rnumel)
+                result_var = self.welford_reduce_fallback(dtype, value)
             elif reduction_type == "welford_combine":
                 mean, m2, weight = masked_value
                 welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})"
@@ -1394,7 +1385,7 @@ def _mask_value(value, default):
         else:
             accumulator = f"_{result_var}"
             default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
-            default = self._map_tuple_or_scalar(triton_constant, default)
+            default = self._map_tuple_or_scalar(constant_repr, default)
             if not isinstance(default, tuple):
                 self.body.writeline(
                     f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})"
@@ -1501,8 +1492,10 @@ def _mask_value(value, default):
         self.cse.reduction_cache[cache_key] = result_var
 
         if isinstance(result_var, tuple):
+            assert all(isinstance(x, TritonCSEVariable) for x in result_var)
             self.outside_loop_vars |= set(result_var)
         else:
+            assert isinstance(result_var, TritonCSEVariable)
             self.outside_loop_vars.add(result_var)
 
         return result_var
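
The triton_constant -> constant_repr rename is behavior-preserving: the helper still turns a Python scalar into a source-level literal that gets spliced into generated Triton, e.g. the tl.full(...) call above. Only its infinity branches are visible in this diff, so the sketch below fills in a plausible rest of the body (the nan branch and the final repr fallback are assumptions, not taken from the diff):

import math


def constant_repr(value):
    # Branches visible in the diff: infinities need an explicit constructor call.
    if value == float("inf"):
        return 'float("inf")'
    elif value == float("-inf"):
        return 'float("-inf")'
    # Assumed here: nan needs the same treatment; everything else reprs cleanly.
    elif isinstance(value, float) and math.isnan(value):
        return 'float("nan")'
    return repr(value)


result, other = "tmp7", float("-inf")
print(f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)")
# tl.full(tmp7.shape, float("-inf"), tmp7.dtype)
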
8 changes: 8 additions & 0 deletions torch/_inductor/virtualized.py
@@ -248,6 +248,14 @@ def __rshfit__(self, n):
     def __lshift__(self, n):
         return ops.bitwise_left_shift(self, n)
 
+    @staticmethod
+    def unwrap(x):
+        if isinstance(x, OpsValue):
+            return x.value
+        if isinstance(x, (list, tuple)):
+            return x.__class__(map(OpsValue.unwrap, x))
+        return x
+
 
 class OpsWrapper:
     """This wraps any returned IR values into an `OpsValue` instance, so that we
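
OpsValue.unwrap strips the OpsValue wrapper from a single handle or from every element of a list/tuple while preserving the container type; that is what lets mask_loads and welford_reduce_fallback hand plain CSE values back to the Triton codegen. A self-contained sketch with a stripped-down stand-in for OpsValue (only the pieces the example needs):

class OpsValue:
    # Stand-in: a wrapped .value plus the staticmethod added in this commit.
    def __init__(self, value):
        self.value = value

    @staticmethod
    def unwrap(x):
        if isinstance(x, OpsValue):
            return x.value
        if isinstance(x, (list, tuple)):
            return x.__class__(map(OpsValue.unwrap, x))
        return x


print(OpsValue.unwrap(OpsValue("tmp0")))                       # tmp0
print(OpsValue.unwrap((OpsValue("mean"), OpsValue("m2"), 8)))  # ('mean', 'm2', 8)
print(OpsValue.unwrap("already_plain"))                        # already_plain
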
