[inductor] Fix argmin/max with duplicate values
Fixes pytorch#99879

This adds `minimum_with_index` helper functions to compute the minimum
value and its index simultaneously, preferring the smaller index, which
is required to match eager in the case of duplicate values.

I also replace the mask-and-sum hack with a `tl.reduce` that uses the
previously mentioned helper. This additionally fixes the indices being
added together in the case of duplicates.

ghstack-source-id: 47407b98133251392fbe329fa4b741eebc4000e2
Pull Request resolved: pytorch#99920
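
For intuition, here is a minimal pure-Python sketch of the tie-breaking rule the helper is meant to implement; `combine_min_with_index` and the `argmin` driver below are illustrative names only, and the real helper operates on Triton blocks via `tl.where` and `tl.reduce`:

```python
# Illustrative only: a scalar version of the "min with index" combine step.
def combine_min_with_index(a_value, a_index, b_value, b_index):
    # Prefer the smaller value; on ties, prefer the smaller index to match eager.
    if a_value < b_value or (a_value == b_value and a_index < b_index):
        return a_value, a_index
    return b_value, b_index


def argmin(values):
    best_value, best_index = values[0], 0
    for i in range(1, len(values)):
        best_value, best_index = combine_min_with_index(
            best_value, best_index, values[i], i
        )
    return best_index


assert argmin([3.0, 1.0, 1.0, 2.0]) == 1  # duplicate minimum: the smaller index wins
```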
peterbell10 committed Apr 25, 2023
1 parent dba70f2 commit 0743311
Showing 4 changed files with 160 additions and 57 deletions.
63 changes: 52 additions & 11 deletions test/inductor/test_torchinductor.py
@@ -5014,17 +5014,58 @@ def fn(x):
aten.argmin(x, 1),
)

self.common(
fn,
[
torch.randn([144, 144]),
],
# Mismatched elements: 1 / 144 (0.7%)
# Greatest absolute difference: 26 at index (71,)
# Greatest relative difference: 0.4126984179019928 at index (71,)
atol=1e-5,
rtol=0.5,
)
self.common(fn, (torch.randn([144, 144]),))

def test_argmax_argmin_with_duplicates(self):
def fn(x):
return (
aten.argmax(x, 0),
aten.argmin(x, 0),
aten.argmax(x, 1),
aten.argmin(x, 1),
)

# Unrolled reduction
t1 = torch.randint(2, size=(6, 6))
self.common(fn, (t1,))

# Persistent reduction
t1 = torch.randint(8, size=(32, 32))
self.common(fn, (t1,))

# Non-persistent reduction
t1 = torch.randint(8, size=(1024, 1024))
self.common(fn, (t1,))

def test_argmax_argmin_with_nan(self):
def fn(x):
return (
aten.argmax(x, 0),
aten.argmin(x, 0),
aten.argmax(x, 1),
aten.argmin(x, 1),
)

# Unrolled reduction
t1 = torch.randn((6, 6))
t1[:, 1] = float("nan")
t1[:, 3] = float("nan")
self.common(fn, (t1,))

if self.device == "cpu":
raise unittest.SkipTest("broken on CPU")

# Persistent reduction
t1 = torch.randn((32, 32))
t1[:, 4] = float("nan")
t1[:, 8] = float("nan")
self.common(fn, (t1,))

# Non-persistent reduction
t1 = torch.randn((1024, 1024))
t1[:, 40] = float("nan")
t1[:, 100] = float("nan")
self.common(fn, (t1,))

def test_conv_backward(self):
def fn(rank4_inps, rank3_inps, rank5_inps):
97 changes: 53 additions & 44 deletions torch/_inductor/codegen/triton.py
@@ -1150,74 +1150,83 @@ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
reduction_type = "max"

def final_reduction(value):
use_helper = reduction_type in {"max", "min", "prod"}
use_helper = reduction_type in {"argmax", "argmin", "max", "min", "prod"}
module = "triton_helpers" if use_helper else "tl"
return f"{module}.{reduction_type}({value}, {dim})[{', '.join(sizes)}]"

def final_argreduce(buffer, result_var, value, index):
buffer.splice(
f"""\
_, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
{result_var} = {result_var}_tmp[{', '.join(sizes)}]
"""
)

dim = len(self.range_trees) - 1
result_var = self.cse.newvar()
result_var.mask_vars = {var for var in masks if var[0] != "r"}
cond = " & ".join(masks)

if self.persistent_reduction:
cond = " & ".join(masks)
masked_value = self.cse.generate(
self.compute, f"tl.where({cond}, {value}, {default})"
)
result_var = self.cse.generate(self.compute, final_reduction(masked_value))
if reduction_type in {"argmax", "argmin"}:
accumulator_index = self.cse.generate(
self.compute,
f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)",
)
result_var = self.cse.newvar()
root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
final_argreduce(
self.compute, result_var, masked_value, accumulator_index
)
else:
result_var = self.cse.generate(
self.compute, final_reduction(masked_value)
)
elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
accumulator = f"_{result_var}"
default_value = f" + {default}" if default != 0 else ""
self.body.writeline(
f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}"
f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {triton_compute_type(src_dtype)})"
)
accumulator_index = None

if reduction_type in {"argmax", "argmin"}:
accumulator_index = f"_{result_var}_index"
long_max = torch.iinfo(torch.int64).max
self.body.writeline(
f"{accumulator_index} = tl.zeros({self.dense_size_str()}, tl.int64)"
f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)"
)
root_op = {"argmax": "max", "argmin": "min"}[reduction_type]

updated = value
if reduction_type == "argmin":
masks.append(f"({accumulator} > {value})")
elif reduction_type == "argmax":
masks.append(f"({accumulator} < {value})")
elif reduction_type == "min":
updated = f"triton_helpers.minimum({accumulator}, {value})"
elif reduction_type == "max":
updated = f"triton_helpers.maximum({accumulator}, {value})"
elif reduction_type == "sum":
updated = f"{accumulator} + {value}"
elif reduction_type == "prod":
updated = f"{accumulator} * {value}"
self.compute.splice(
f"""\
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
)
{accumulator} = tl.where({cond}, {accumulator}_next, {accumulator})
{accumulator_index} = tl.where({cond}, {accumulator_index}_next, {accumulator_index})
"""
)
idx_dtype = self.index_dtype
final_argreduce(self.suffix, result_var, accumulator, accumulator_index)
else:
raise NotImplementedError(f"reduction_type {reduction_type}")

cond = " & ".join(masks)
updated = value
if reduction_type == "min":
updated = f"triton_helpers.minimum({accumulator}, {value})"
elif reduction_type == "max":
updated = f"triton_helpers.maximum({accumulator}, {value})"
elif reduction_type == "sum":
updated = f"{accumulator} + {value}"
elif reduction_type == "prod":
updated = f"{accumulator} * {value}"
else:
raise NotImplementedError(f"reduction_type {reduction_type}")

if accumulator_index:
# argmax or argmin
self.compute.writeline(
f"{accumulator_index} = tl.where({cond}, {reduction_range_prefix}index, {accumulator_index})",
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
)
self.compute.writeline(
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
)

if accumulator_index:
# argmax, argmin
idx_dtype = self.index_dtype
self.suffix.writelines(
[
f"{accumulator_index}_reduce = "
f"tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}].to(tl.int32)",
f"{accumulator_index}_mask = tl.arange(0, {reduction_range_prefix.upper()}BLOCK)"
f"[{', '.join(reduction_sizes)}] == {accumulator_index}_reduce",
f"{result_var} = tl.sum("
f"tl.where({accumulator_index}_mask, {accumulator_index}, 0), {dim})[{', '.join(sizes)}]",
]
)
else:
self.suffix.writeline(f"{result_var} = {final_reduction(accumulator)}")
else:
var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)]
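
Based on the format strings above, the codegen for, say, a persistent-reduction argmin along the last dimension now emits a single fused value-and-index reduce rather than the old mask-and-sum sequence. The snippet below is a rough, hand-assembled approximation of such generated code (the `tmp*` names and the `float("inf")` default are illustrative, not taken from an actual kernel dump):

```python
# Approximate shape of the generated persistent-reduction argmin (illustrative).
tmp1 = tl.where(rmask & xmask, tmp0, float("inf"))          # mask out-of-range lanes
tmp2 = tl.broadcast_to(rindex, tmp1.shape)                  # candidate indices
_, tmp3_tmp = triton_helpers.min_with_index(tmp1, tmp2, 1)  # fused value+index reduce
tmp3 = tmp3_tmp[:, None]
```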
11 changes: 9 additions & 2 deletions torch/_inductor/ir.py
@@ -707,8 +707,15 @@ def combine_fn(a, b):
elif reduction_type == "argmax":

def combine_fn(a, b):
return ops.maximum(a[0], b[0]), ops.where(
ops.gt(b[0], a[0]), b[1], a[1]
a_value, a_index = a
b_value, b_index = b
mask = ops.gt(a_value, b_value)
isnan = ops.neq(a_value, a_value)
mask = ops.logical_or(mask, isnan)

return (
ops.where(mask, a_value, b_value),
ops.where(mask, a_index, b_index),
)

else:
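
In plain terms, the new `combine_fn` keeps the left operand whenever its value is strictly greater or is NaN (the `ops.neq(a_value, a_value)` test), so NaNs now propagate through the reduction instead of being dropped. A scalar sketch of the same selection logic, for illustration only (the real code builds `ops.*` IR nodes rather than comparing Python floats):

```python
import math

# Illustrative scalar restatement of the new argmax combine_fn.
def argmax_combine(a, b):
    a_value, a_index = a
    b_value, b_index = b
    # Keep `a` when its value is strictly greater, or when it is NaN.
    keep_a = a_value > b_value or math.isnan(a_value)
    return (a_value, a_index) if keep_a else (b_value, b_index)
```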
46 changes: 46 additions & 0 deletions torch/_inductor/triton_helpers.py
@@ -43,3 +43,49 @@ def min(a, dim):
@triton.jit
def max(a, dim):
return tl.reduce(a, dim, maximum)


@triton.jit
def minimum_with_index(a_value, a_index, b_value, b_index):
mask = a_value < b_value
equal = a_value == b_value
if is_floating(a_value):
a_isnan = a_value != a_value
b_isnan = b_value != b_value
mask |= a_isnan and not b_isnan
# Consider NaNs as equal
equal |= a_isnan and b_isnan
else:
equal = a_value == b_value

# Prefer lowest index if values are equal
mask |= equal & (a_index < b_index)
return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def maximum_with_index(a_value, a_index, b_value, b_index):
mask = a_value > b_value
equal = a_value == b_value
if is_floating(a_value):
a_isnan = a_value != a_value
b_isnan = b_value != b_value
mask |= a_isnan and not b_isnan
# Consider NaNs as equal
equal |= a_isnan and b_isnan
else:
equal = a_value == b_value

# Prefer lowest index if values are equal
mask |= equal & (a_index < b_index)
return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def min_with_index(value, index, dim):
return tl.reduce((value, index), dim, minimum_with_index)


@triton.jit
def max_with_index(value, index, dim):
return tl.reduce((value, index), dim, maximum_with_index)
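
Outside of generated code, the new helpers can also be exercised from a hand-written Triton kernel. The sketch below is a usage example under stated assumptions, not part of this PR: `argmin_1d_kernel` and `argmin_1d` are invented names, and the kernel assumes the whole input fits in a single block.

```python
import torch
import triton
import triton.language as tl
from torch._inductor import triton_helpers


@triton.jit
def argmin_1d_kernel(in_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # Single-block kernel: BLOCK is assumed to be a power of two >= n.
    offsets = tl.arange(0, BLOCK)
    mask = offsets < n
    values = tl.load(in_ptr + offsets, mask=mask, other=float("inf"))
    indices = offsets.to(tl.int64)
    # Reduce (value, index) pairs together; ties resolve to the smaller index.
    _, idx = triton_helpers.min_with_index(values, indices, 0)
    tl.store(out_ptr, idx)


def argmin_1d(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty((), dtype=torch.int64, device=x.device)
    argmin_1d_kernel[(1,)](x, out, x.numel(), BLOCK=triton.next_power_of_2(x.numel()))
    return out
```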
