-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[inductor] Fix argmin/max with duplicate values #99920
Changes from all commits
cdf8cff
a18bada
8cdafd4
ebe7b16
cef8a3f
ed19f97
4e5c903
90356de
d83317d
7ecfb27
ba7b2a1
0ae1451
902ea5e
1a0025b
54312fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1195,72 +1195,86 @@ def final_reduction(value): | |
return f"{module}.{reduction_type}2({value}, {dim})[{', '.join(sizes)}]" | ||
return f"{module}.{reduction_type}({value}, {dim})[{', '.join(sizes)}]" | ||
|
||
def final_argreduce(buffer, result_var, value, index): | ||
buffer.splice( | ||
f"""\ | ||
_, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) | ||
{result_var} = {result_var}_tmp[{', '.join(sizes)}] | ||
""" | ||
) | ||
|
||
dim = len(self.range_trees) - 1 | ||
result_var = self.cse.newvar() | ||
result_var.mask_vars = {var for var in masks if var[0] != "r"} | ||
cond = " & ".join(masks) | ||
|
||
if self.persistent_reduction: | ||
cond = " & ".join(masks) | ||
masked_value = self.cse.generate( | ||
self.compute, f"tl.where({cond}, {value}, {default})" | ||
) | ||
result_var = self.cse.generate(self.compute, final_reduction(masked_value)) | ||
if reduction_type in {"argmax", "argmin"}: | ||
accumulator_index = self.cse.generate( | ||
self.compute, | ||
f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", | ||
) | ||
result_var = self.cse.newvar() | ||
root_op = {"argmax": "max", "argmin": "min"}[reduction_type] | ||
final_argreduce( | ||
self.compute, result_var, masked_value, accumulator_index | ||
) | ||
else: | ||
result_var = self.cse.generate( | ||
self.compute, final_reduction(masked_value) | ||
) | ||
elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache: | ||
self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var | ||
accumulator = f"_{result_var}" | ||
# NOTE: We should be using tl.full here, but this also does type | ||
# promotion e.g. bool to int32, which is sometimes necessary if | ||
# similar promotion happened elsewhere in the pre-reduction | ||
# operation. We should identify any such cases and fix them. | ||
default_value = f" + {default}" if default != 0 else "" | ||
self.body.writeline( | ||
f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: can you replace this with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Turns out that doing this uncovers some bugs that I'd rather not dig into in this PR. Essentially some boolean |
||
) | ||
accumulator_index = None | ||
|
||
if reduction_type in {"argmax", "argmin"}: | ||
accumulator_index = f"_{result_var}_index" | ||
long_max = torch.iinfo(torch.int64).max | ||
self.body.writeline( | ||
f"{accumulator_index} = tl.zeros({self.dense_size_str()}, tl.int64)" | ||
f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)" | ||
) | ||
root_op = {"argmax": "max", "argmin": "min"}[reduction_type] | ||
|
||
updated = value | ||
if reduction_type == "argmin": | ||
masks.append(f"({accumulator} > {value})") | ||
elif reduction_type == "argmax": | ||
masks.append(f"({accumulator} < {value})") | ||
elif reduction_type == "min": | ||
updated = f"triton_helpers.minimum({accumulator}, {value})" | ||
elif reduction_type == "max": | ||
updated = f"triton_helpers.maximum({accumulator}, {value})" | ||
elif reduction_type == "sum": | ||
updated = f"{accumulator} + {value}" | ||
elif reduction_type == "prod": | ||
updated = f"{accumulator} * {value}" | ||
elif reduction_type == "xor_sum": | ||
updated = f"{accumulator} ^ {value}" | ||
self.compute.splice( | ||
f"""\ | ||
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( | ||
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index | ||
) | ||
{accumulator} = tl.where({cond}, {accumulator}_next, {accumulator}) | ||
{accumulator_index} = tl.where({cond}, {accumulator_index}_next, {accumulator_index}) | ||
""" | ||
) | ||
idx_dtype = self.index_dtype | ||
final_argreduce(self.suffix, result_var, accumulator, accumulator_index) | ||
else: | ||
raise NotImplementedError(f"reduction_type {reduction_type}") | ||
|
||
cond = " & ".join(masks) | ||
updated = value | ||
if reduction_type == "min": | ||
updated = f"triton_helpers.minimum({accumulator}, {value})" | ||
elif reduction_type == "max": | ||
updated = f"triton_helpers.maximum({accumulator}, {value})" | ||
elif reduction_type == "sum": | ||
updated = f"{accumulator} + {value}" | ||
elif reduction_type == "prod": | ||
updated = f"{accumulator} * {value}" | ||
elif reduction_type == "xor_sum": | ||
updated = f"{accumulator} ^ {value}" | ||
else: | ||
raise NotImplementedError(f"reduction_type {reduction_type}") | ||
|
||
if accumulator_index: | ||
# argmax or argmin | ||
self.compute.writeline( | ||
f"{accumulator_index} = tl.where({cond}, {reduction_range_prefix}index, {accumulator_index})", | ||
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})" | ||
) | ||
self.compute.writeline( | ||
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})" | ||
) | ||
|
||
if accumulator_index: | ||
# argmax, argmin | ||
idx_dtype = self.index_dtype | ||
self.suffix.writelines( | ||
[ | ||
f"{accumulator_index}_reduce = " | ||
f"tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}].to(tl.int32)", | ||
f"{accumulator_index}_mask = tl.arange(0, {reduction_range_prefix.upper()}BLOCK)" | ||
f"[{', '.join(reduction_sizes)}] == {accumulator_index}_reduce", | ||
f"{result_var} = tl.sum(" | ||
f"tl.where({accumulator_index}_mask, {accumulator_index}, 0), {dim})[{', '.join(sizes)}]", | ||
] | ||
) | ||
else: | ||
self.suffix.writeline(f"{result_var} = {final_reduction(accumulator)}") | ||
else: | ||
var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)] | ||
|
[Review comment] great that this test is fixed!
(Page chrome: "There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.")