[inductor] Fix argmin/max with duplicate values #99920

Closed · wants to merge 15 commits
63 changes: 52 additions & 11 deletions test/inductor/test_torchinductor.py
@@ -5100,17 +5100,58 @@ def fn(x):
aten.argmin(x, 1),
)

self.common(
fn,
[
torch.randn([144, 144]),
],
# Mismatched elements: 1 / 144 (0.7%)
# Greatest absolute difference: 26 at index (71,)
# Greatest relative difference: 0.4126984179019928 at index (71,)
atol=1e-5,
rtol=0.5,
)
self.common(fn, (torch.randn([144, 144]),))
Collaborator: great that this test is fixed!


def test_argmax_argmin_with_duplicates(self):
def fn(x):
return (
aten.argmax(x, 0),
aten.argmin(x, 0),
aten.argmax(x, 1),
aten.argmin(x, 1),
)

# Unrolled reduction
t1 = torch.randint(2, size=(6, 6))
self.common(fn, (t1,))

# Persistent reduction
t1 = torch.randint(8, size=(32, 32))
self.common(fn, (t1,))

# Non-persistent reduction
t1 = torch.randint(8, size=(1028, 1028))
self.common(fn, (t1,))

def test_argmax_argmin_with_nan(self):
def fn(x):
return (
aten.argmax(x, 0),
aten.argmin(x, 0),
aten.argmax(x, 1),
aten.argmin(x, 1),
)

if self.device == "cpu":
raise unittest.SkipTest("broken on CPU")

# Unrolled reduction
t1 = torch.randn((6, 6))
t1[:, 1] = float("nan")
t1[:, 3] = float("nan")
self.common(fn, (t1,))

# Persistent reduction
t1 = torch.randn((32, 32))
t1[:, 4] = float("nan")
t1[:, 8] = float("nan")
self.common(fn, (t1,))

# Non-persistent reduction
t1 = torch.randn((1028, 1028))
t1[:, 40] = float("nan")
t1[:, 100] = float("nan")
self.common(fn, (t1,))

def test_conv_backward(self):
def fn(rank4_inps, rank3_inps, rank5_inps):
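As an illustration of what the new tests lock in (not part of the diff): a compiled arg-reduction over a tensor full of duplicates should return the same indices as eager mode. The CUDA device and the use of torch.compile's default inductor backend below are assumptions.

```python
import torch

# Sketch only: compiled argmax/argmin over a tensor with many duplicate values
# must agree with eager mode. Assumes a CUDA device and the default inductor
# backend of torch.compile.
def fn(x):
    return torch.argmax(x, 1), torch.argmin(x, 1)

x = torch.randint(8, size=(32, 32), device="cuda")  # many ties per row
eager_out = fn(x)
compiled_out = torch.compile(fn)(x)
for e, c in zip(eager_out, compiled_out):
    torch.testing.assert_close(e, c)
```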
100 changes: 57 additions & 43 deletions torch/_inductor/codegen/triton.py
@@ -1195,72 +1195,86 @@ def final_reduction(value):
return f"{module}.{reduction_type}2({value}, {dim})[{', '.join(sizes)}]"
return f"{module}.{reduction_type}({value}, {dim})[{', '.join(sizes)}]"

def final_argreduce(buffer, result_var, value, index):
buffer.splice(
f"""\
_, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
{result_var} = {result_var}_tmp[{', '.join(sizes)}]
"""
)

dim = len(self.range_trees) - 1
result_var = self.cse.newvar()
result_var.mask_vars = {var for var in masks if var[0] != "r"}
cond = " & ".join(masks)

if self.persistent_reduction:
cond = " & ".join(masks)
masked_value = self.cse.generate(
self.compute, f"tl.where({cond}, {value}, {default})"
)
result_var = self.cse.generate(self.compute, final_reduction(masked_value))
if reduction_type in {"argmax", "argmin"}:
accumulator_index = self.cse.generate(
self.compute,
f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)",
)
result_var = self.cse.newvar()
root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
final_argreduce(
self.compute, result_var, masked_value, accumulator_index
)
else:
result_var = self.cse.generate(
self.compute, final_reduction(masked_value)
)
elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
accumulator = f"_{result_var}"
# NOTE: We should be using tl.full here, but this also does type
# promotion e.g. bool to int32, which is sometimes necessary if
# similar promotion happened elsewhere in the pre-reduction
# operation. We should identify any such cases and fix them.
default_value = f" + {default}" if default != 0 else ""
self.body.writeline(
f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}"
Collaborator: nit: can you replace this with tl.full?

Collaborator (author): Turns out that doing this uncovers some bugs that I'd rather not dig into in this PR. Essentially some boolean ops actually return int32, but this was masked by tl.zeros() + x promoting the accumulator to integer anyway.
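The masking effect described here can be sketched with plain PyTorch standing in for the Triton ops (illustrative only, not part of the diff):

```python
import torch

# Stand-in illustration (plain PyTorch, not Triton): `zeros(...) + x` silently
# adopts x's wider dtype, while `full(..., dtype=...)` pins the accumulator dtype,
# so an upstream op that returns int32 where bool was expected is no longer hidden.
x = torch.ones(4, dtype=torch.int32)  # a "boolean" op that actually yields int32

acc_from_zeros = torch.zeros(4, dtype=torch.bool) + x      # promoted: dtype becomes int32
acc_from_full = torch.full((4,), False, dtype=torch.bool)  # pinned: dtype stays bool

print(acc_from_zeros.dtype)  # torch.int32, the mismatch is masked by promotion
print(acc_from_full.dtype)   # torch.bool, the mismatch surfaces when combined with x
```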

)
accumulator_index = None

if reduction_type in {"argmax", "argmin"}:
accumulator_index = f"_{result_var}_index"
long_max = torch.iinfo(torch.int64).max
self.body.writeline(
f"{accumulator_index} = tl.zeros({self.dense_size_str()}, tl.int64)"
f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)"
)
root_op = {"argmax": "max", "argmin": "min"}[reduction_type]

updated = value
if reduction_type == "argmin":
masks.append(f"({accumulator} > {value})")
elif reduction_type == "argmax":
masks.append(f"({accumulator} < {value})")
elif reduction_type == "min":
updated = f"triton_helpers.minimum({accumulator}, {value})"
elif reduction_type == "max":
updated = f"triton_helpers.maximum({accumulator}, {value})"
elif reduction_type == "sum":
updated = f"{accumulator} + {value}"
elif reduction_type == "prod":
updated = f"{accumulator} * {value}"
elif reduction_type == "xor_sum":
updated = f"{accumulator} ^ {value}"
self.compute.splice(
f"""\
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
)
{accumulator} = tl.where({cond}, {accumulator}_next, {accumulator})
{accumulator_index} = tl.where({cond}, {accumulator_index}_next, {accumulator_index})
"""
)
idx_dtype = self.index_dtype
final_argreduce(self.suffix, result_var, accumulator, accumulator_index)
else:
raise NotImplementedError(f"reduction_type {reduction_type}")

cond = " & ".join(masks)
updated = value
if reduction_type == "min":
updated = f"triton_helpers.minimum({accumulator}, {value})"
elif reduction_type == "max":
updated = f"triton_helpers.maximum({accumulator}, {value})"
elif reduction_type == "sum":
updated = f"{accumulator} + {value}"
elif reduction_type == "prod":
updated = f"{accumulator} * {value}"
elif reduction_type == "xor_sum":
updated = f"{accumulator} ^ {value}"
else:
raise NotImplementedError(f"reduction_type {reduction_type}")

if accumulator_index:
# argmax or argmin
self.compute.writeline(
f"{accumulator_index} = tl.where({cond}, {reduction_range_prefix}index, {accumulator_index})",
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
)
self.compute.writeline(
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
)

if accumulator_index:
# argmax, argmin
idx_dtype = self.index_dtype
self.suffix.writelines(
[
f"{accumulator_index}_reduce = "
f"tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}].to(tl.int32)",
f"{accumulator_index}_mask = tl.arange(0, {reduction_range_prefix.upper()}BLOCK)"
f"[{', '.join(reduction_sizes)}] == {accumulator_index}_reduce",
f"{result_var} = tl.sum("
f"tl.where({accumulator_index}_mask, {accumulator_index}, 0), {dim})[{', '.join(sizes)}]",
]
)
else:
self.suffix.writeline(f"{result_var} = {final_reduction(accumulator)}")
else:
var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)]
24 changes: 20 additions & 4 deletions torch/_inductor/ir.py
@@ -724,15 +724,31 @@ def combine_fn(a, b):
elif reduction_type == "argmin":

def combine_fn(a, b):
return ops.minimum(a[0], b[0]), ops.where(
ops.lt(b[0], a[0]), b[1], a[1]
a_value, a_index = a
b_value, b_index = b
mask = ops.lt(b_value, a_value)
a_isnan = ops.ne(a_value, a_value)
b_isnan = ops.ne(b_value, b_value)
mask = ops.logical_or(mask, ops.gt(b_isnan, a_isnan))

return (
ops.where(mask, b_value, a_value),
ops.where(mask, b_index, a_index),
)

elif reduction_type == "argmax":

def combine_fn(a, b):
return ops.maximum(a[0], b[0]), ops.where(
ops.gt(b[0], a[0]), b[1], a[1]
a_value, a_index = a
b_value, b_index = b
mask = ops.gt(b_value, a_value)
a_isnan = ops.ne(a_value, a_value)
b_isnan = ops.ne(b_value, b_value)
mask = ops.logical_or(mask, ops.gt(b_isnan, a_isnan))

return (
ops.where(mask, b_value, a_value),
ops.where(mask, b_index, a_index),
)

else:
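To make the new combine rule concrete, here is a plain-Python model (illustrative only; the real combine_fn operates on inductor ops values): candidate b wins only when its value is strictly smaller, or when b is NaN and a is not, so an in-order reduction keeps the first occurrence and NaN propagates.

```python
import math
from functools import reduce

# Plain-Python model of the NaN-aware argmin combine above (sketch only).
def argmin_combine(a, b):
    a_value, a_index = a
    b_value, b_index = b
    # b wins on a strictly smaller value, or when b is NaN and a is not.
    take_b = b_value < a_value or (math.isnan(b_value) and not math.isnan(a_value))
    return (b_value, b_index) if take_b else (a_value, a_index)

values = [3.0, float("nan"), 1.0, float("nan")]
print(reduce(argmin_combine, zip(values, range(len(values)))))  # (nan, 1): the first NaN wins
```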
48 changes: 42 additions & 6 deletions torch/_inductor/triton_helpers.py
@@ -100,14 +100,50 @@ def maximum_with_index(a_value, a_index, b_value, b_index):
return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def min_with_index(value, index, dim):
return tl.reduce((value, index), dim, minimum_with_index)
if TRITON_HAS_REDUCE:

@triton.jit
def min_with_index(value, index, dim):
return tl.reduce((value, index), dim, minimum_with_index)

@triton.jit
def max_with_index(value, index, dim):
return tl.reduce((value, index), dim, maximum_with_index)
@triton.jit
def max_with_index(value, index, dim):
return tl.reduce((value, index), dim, maximum_with_index)

else:

@triton.jit
def _argreduce_index(reduction_result, value, index, dim):
reduction_result_keepdim = reduction_result[None, :]
if dim == 0:
pass
elif dim == 1:
reduction_result_keepdim = reduction_result[:, None]
else:
tl.device_assert(False)

equal = value == reduction_result_keepdim
if is_floating(value):
# Treat nan as equal
result_is_nan = reduction_result_keepdim != reduction_result_keepdim
equal |= (value != value) and result_is_nan

invalid_index = 2**62
indices = tl.where(equal, index, invalid_index)
index = tl.min(indices, dim)
return index

@triton.jit
def min_with_index(value, index, dim):
min_values = min2(value, dim)
min_index = _argreduce_index(min_values, value, index, dim)
return min_values, min_index

@triton.jit
def max_with_index(value, index, dim):
max_values = max2(value, dim)
max_index = _argreduce_index(max_values, value, index, dim)
return max_values, max_index


@triton.jit
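The pre-tl.reduce fallback can be modeled in NumPy (illustrative only, not the Triton code): after the plain min/max reduction, positions that attain the reduced value (with NaN compared equal to NaN) are masked, and the smallest matching index along the reduction dimension is taken, using a large sentinel index for non-matching positions.

```python
import numpy as np

# NumPy model of the fallback _argreduce_index above (sketch only): recover the
# argmin index from an already-reduced min by masking matching positions and
# taking the smallest index; NaN is treated as equal to NaN.
def argreduce_index(reduction_result, value, index, dim):
    keepdim = np.expand_dims(reduction_result, dim)
    equal = value == keepdim
    if np.issubdtype(value.dtype, np.floating):
        equal |= np.isnan(value) & np.isnan(keepdim)
    invalid_index = 2**62  # sentinel, same idea as in the Triton helper
    return np.where(equal, index, invalid_index).min(axis=dim)

value = np.array([[1.0, 0.0, 0.0],
                  [np.nan, 2.0, np.nan]])
index = np.broadcast_to(np.arange(3), value.shape)
print(argreduce_index(value.min(axis=1), value, index, dim=1))  # [1 0]
```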