[inductor] Fix argmin/max with duplicate values
Fixes pytorch#99879

This adds `minimum_with_index` / `maximum_with_index` helper functions that
compute the extremal value and its index simultaneously, preferring the
smaller index, which is required to match eager mode when duplicate values
are present (see the short eager-mode example below).

I also replace the mask-and-sum hack with a `tl.reduce` that uses the
helpers above. This additionally fixes a bug where the indices of duplicate
values were added together.

ghstack-source-id: 73dc04c8f3a249b046f440f4f3df7c06d83f183e
Pull Request resolved: pytorch#99920
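
For context (not part of this commit), a minimal eager-mode check of the tie-breaking rule the compiled kernels must reproduce; the tensor values are illustrative only:

import torch

x = torch.tensor([3.0, 1.0, 1.0, 2.0])
# Eager argmin/argmax return the first index among duplicate extremal values,
# so the compiled kernel has to prefer the smaller index on ties.
print(torch.argmin(x))  # tensor(1), not tensor(2)
print(torch.argmax(torch.tensor([2.0, 5.0, 5.0])))  # tensor(1), not tensor(2)
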
peterbell10 committed Apr 25, 2023
1 parent 4a6edee commit 4b6d31f
Showing 3 changed files with 140 additions and 52 deletions.
41 changes: 30 additions & 11 deletions test/inductor/test_torchinductor.py
@@ -5014,17 +5014,36 @@ def fn(x):
aten.argmin(x, 1),
)

self.common(
fn,
[
torch.randn([144, 144]),
],
# Mismatched elements: 1 / 144 (0.7%)
# Greatest absolute difference: 26 at index (71,)
# Greatest relative difference: 0.4126984179019928 at index (71,)
atol=1e-5,
rtol=0.5,
)
self.common(fn, (torch.randn([144, 144]),))

def test_argmax_argmin_with_duplicates(self):
def fn(x):
return (
aten.argmax(x, 0),
aten.argmin(x, 0),
aten.argmax(x, 1),
aten.argmin(x, 1),
)

t1 = torch.randint(8, size=(32, 32))
self.common(fn, (t1,))

t1 = torch.randint(8, size=(1024, 1024))
self.common(fn, (t1,))

def test_argmax_argmin_with_nan(self):
def fn(x):
return (
aten.argmax(x, 0),
aten.argmin(x, 0),
aten.argmax(x, 1),
aten.argmin(x, 1),
)

t1 = torch.randn((1024, 1024))
t1[:, 40] = float('nan')
t1[:, 100] = float('nan')
self.common(fn, (t1,))

def test_conv_backward(self):
def fn(rank4_inps, rank3_inps, rank5_inps):
78 changes: 37 additions & 41 deletions torch/_inductor/codegen/triton.py
@@ -1150,15 +1150,16 @@ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
reduction_type = "max"

def final_reduction(value):
use_helper = reduction_type in {"max", "min", "prod"}
use_helper = reduction_type in {"argmax", "argmin", "max", "min", "prod"}
module = "triton_helpers" if use_helper else "tl"
return f"{module}.{reduction_type}({value}, {dim})[{', '.join(sizes)}]"

dim = len(self.range_trees) - 1
result_var = self.cse.newvar()
result_var.mask_vars = {var for var in masks if var[0] != "r"}
cond = " & ".join(masks)

if self.persistent_reduction:
cond = " & ".join(masks)
masked_value = self.cse.generate(
self.compute, f"tl.where({cond}, {value}, {default})"
)
@@ -1170,54 +1171,49 @@ def final_reduction(value):
self.body.writeline(
f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}"
)
accumulator_index = None

if reduction_type in {"argmax", "argmin"}:
accumulator_index = f"_{result_var}_index"
long_max = torch.iinfo(torch.int64).max
self.body.writeline(
f"{accumulator_index} = tl.zeros({self.dense_size_str()}, tl.int64)"
f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)"
)
root_op = {"argmax": "max", "argmin": "min"}[reduction_type]

updated = value
if reduction_type == "argmin":
masks.append(f"({accumulator} > {value})")
elif reduction_type == "argmax":
masks.append(f"({accumulator} < {value})")
elif reduction_type == "min":
updated = f"triton_helpers.minimum({accumulator}, {value})"
elif reduction_type == "max":
updated = f"triton_helpers.maximum({accumulator}, {value})"
elif reduction_type == "sum":
updated = f"{accumulator} + {value}"
elif reduction_type == "prod":
updated = f"{accumulator} * {value}"
else:
raise NotImplementedError(f"reduction_type {reduction_type}")

cond = " & ".join(masks)

if accumulator_index:
# argmax or argmin
self.compute.writeline(
f"{accumulator_index} = tl.where({cond}, {reduction_range_prefix}index, {accumulator_index})",
self.compute.splice(
f"""
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
)
{accumulator} = tl.where({cond}, {accumulator}_next, {accumulator})
{accumulator_index} = tl.where({cond}, {accumulator_index}_next, {accumulator_index})
"""
)
self.compute.writeline(
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
)

if accumulator_index:
# argmax, argmin
idx_dtype = self.index_dtype
self.suffix.writelines(
[
f"{accumulator_index}_reduce = "
f"tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}].to(tl.int32)",
f"{accumulator_index}_mask = tl.arange(0, {reduction_range_prefix.upper()}BLOCK)"
f"[{', '.join(reduction_sizes)}] == {accumulator_index}_reduce",
f"{result_var} = tl.sum("
f"tl.where({accumulator_index}_mask, {accumulator_index}, 0), {dim})[{', '.join(sizes)}]",
]
self.suffix.splice(
f"""
_, {result_var}_tmp = triton_helpers.{root_op}_with_index(
{accumulator}, {accumulator_index}, {dim}
)
{result_var} = {result_var}_tmp[{', '.join(sizes)}]
"""
)
else:
updated = value
if reduction_type == "min":
updated = f"triton_helpers.minimum({accumulator}, {value})"
elif reduction_type == "max":
updated = f"triton_helpers.maximum({accumulator}, {value})"
elif reduction_type == "sum":
updated = f"{accumulator} + {value}"
elif reduction_type == "prod":
updated = f"{accumulator} * {value}"
else:
raise NotImplementedError(f"reduction_type {reduction_type}")

self.compute.writeline(
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
)
self.suffix.writeline(f"{result_var} = {final_reduction(accumulator)}")
else:
var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)]
73 changes: 73 additions & 0 deletions torch/_inductor/triton_helpers.py
@@ -43,3 +43,76 @@ def min(a, dim):
@triton.jit
def max(a, dim):
return tl.reduce(a, dim, maximum)


@triton.jit
def minimum_with_index(a_value, a_index, b_value, b_index):
mask = a_value < b_value
equal = a_value == b_value
if is_floating(a_value):
a_isnan = (a_value != a_value)
b_isnan = (b_value != b_value)
mask |= a_isnan and not b_isnan
# Consider NaNs as equal
equal |= a_isnan and b_isnan

# Prefer lowest index if values are equal
mask |= equal & (a_index < b_index)
return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def maximum_with_index(a_value, a_index, b_value, b_index):
mask = a_value > b_value
equal = a_value == b_value
if is_floating(a_value):
a_isnan = (a_value != a_value)
b_isnan = (b_value != b_value)
mask |= a_isnan and not b_isnan
# Consider NaNs as equal
equal |= a_isnan and b_isnan

# Prefer lowest index if values are equal
mask |= equal & (a_index < b_index)
return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def min_with_index(value, index, dim):
return tl.reduce((value, index), dim, minimum_with_index)


@triton.jit
def max_with_index(value, index, dim):
return tl.reduce((value, index), dim, maximum_with_index)


@triton.jit
def make_index_tensor_for(input, dim):
n = input.shape[dim]
index = tl.arange(0, n)

if len(input.shape) > 1:
# Broadcast index across the non-reduced axes
expand_dims_index = [None] * len(input.shape)
expand_dims_index[dim] = slice(None)
index = index[expand_dims_index]
index = tl.broadcast_to(index, input.shape)

return index


@triton.jit
def argmax(value, dim):
index = make_index_tensor_for(value, dim)
return max_with_index(value, index, dim)


@triton.jit
def argmin(value, dim):
index = make_index_tensor_for(value, dim)
return min_with_index(value, index, dim)
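
As a sanity check (not part of the change), here is a minimal pure-Python sketch of the combine rule these helpers implement, applied to scalar (value, index) pairs; the name minimum_with_index_ref is hypothetical and used only for illustration:

import math
from functools import reduce


def minimum_with_index_ref(a, b):
    # Scalar mirror of minimum_with_index above: prefer the smaller value,
    # prefer NaN over non-NaN, and prefer the lower index on ties.
    (a_value, a_index), (b_value, b_index) = a, b
    a_isnan, b_isnan = math.isnan(a_value), math.isnan(b_value)
    take_a = a_value < b_value or (a_isnan and not b_isnan)
    equal = a_value == b_value or (a_isnan and b_isnan)
    take_a = take_a or (equal and a_index < b_index)
    return (a_value, a_index) if take_a else (b_value, b_index)


values = [2.0, 5.0, 2.0, float("nan")]
pairs = list(zip(values, range(len(values))))
print(reduce(minimum_with_index_ref, pairs))  # (nan, 3), matching torch.argmin

Because this rule amounts to taking the minimum under a total order (NaNs first, then value, then index), it is associative and commutative, so tl.reduce can combine the pairs in any tree order and still produce the eager result.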
