Revert tl.reduce usage (#100521)
Test Plan: sandcastle

Reviewed By: bertmaher

Differential Revision: D45513572

fbshipit-source-id: a03df851503f72313dfb50238e7d6db9239bf42e
ezyang committed May 3, 2023
1 parent 287f74c commit db4572d
Showing 7 changed files with 122 additions and 221 deletions.
71 changes: 11 additions & 60 deletions test/inductor/test_torchinductor.py
@@ -824,14 +824,6 @@ def fn(a, b):
for dtype in dtypes:
self.common(fn, (torch.randn(8, 8).to(dtype), torch.randn(8, 8).to(dtype)))

def test_min_max_reduction_nan(self):
def fn(a):
return (torch.max(a), torch.min(a))

t1 = torch.randn(32)
t1[16] = float("nan")
self.common(fn, (t1,))

def test_fmin_fmax(self):
def fn(a, b):
return (
@@ -5100,58 +5092,17 @@ def fn(x):
aten.argmin(x, 1),
)

self.common(fn, (torch.randn([144, 144]),))

def test_argmax_argmin_with_duplicates(self):
def fn(x):
return (
aten.argmax(x, 0),
aten.argmin(x, 0),
aten.argmax(x, 1),
aten.argmin(x, 1),
)

# Unrolled reduction
t1 = torch.randint(2, size=(6, 6))
self.common(fn, (t1,))

# Persistent reduction
t1 = torch.randint(8, size=(32, 32))
self.common(fn, (t1,))

# Non-persistent reduction
t1 = torch.randint(8, size=(1028, 1028))
self.common(fn, (t1,))

def test_argmax_argmin_with_nan(self):
def fn(x):
return (
aten.argmax(x, 0),
aten.argmin(x, 0),
aten.argmax(x, 1),
aten.argmin(x, 1),
)

if self.device == "cpu":
raise unittest.SkipTest("broken on CPU")

# Unrolled reduction
t1 = torch.randn((6, 6))
t1[:, 1] = float("nan")
t1[:, 3] = float("nan")
self.common(fn, (t1,))

# Persistent reduction
t1 = torch.randn((32, 32))
t1[:, 4] = float("nan")
t1[:, 8] = float("nan")
self.common(fn, (t1,))

# Non-persistent reduction
t1 = torch.randn((1028, 1028))
t1[:, 40] = float("nan")
t1[:, 100] = float("nan")
self.common(fn, (t1,))
self.common(
fn,
[
torch.randn([144, 144]),
],
# Mismatched elements: 1 / 144 (0.7%)
# Greatest absolute difference: 26 at index (71,)
# Greatest relative difference: 0.4126984179019928 at index (71,)
atol=1e-5,
rtol=0.5,
)

def test_conv_backward(self):
def fn(rank4_inps, rank3_inps, rank5_inps):
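The removal of the NaN and duplicate argmax/argmin tests above is consistent with the comparison-mask accumulator scheme restored in torch/_inductor/codegen/triton.py further down: the running accumulator is only replaced where a comparison against the incoming value holds, and any ordering comparison against NaN is false, so NaN elements never win the reduction. A minimal plain-Python illustration (not Inductor code):

```python
# Any ordering comparison against NaN is False, so an accumulator that is only
# replaced under `acc < value` (or `acc > value`) never picks up a NaN element.
acc, value = 1.0, float("nan")
print(acc < value)  # False -> a max/argmax accumulator keeps 1.0
print(acc > value)  # False -> a min/argmin accumulator keeps 1.0
```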
24 changes: 19 additions & 5 deletions torch/_inductor/codegen/cpp.py
Expand Up @@ -118,9 +118,7 @@ def reduction_combine(reduction_type, var, next_value):
return f"{var} ^= {next_value}"
if reduction_type == "any":
return f"{var} = {var} || {next_value}"
if reduction_type in ("min", "max"):
return f"{var} = {reduction_type}_propagate_nan({var}, {next_value})"
raise AssertionError(reduction_type)
return f"{var} = std::{reduction_type}({var}, {next_value})"


def reduction_combine_vec(reduction_type, var, next_value):
@@ -563,6 +561,14 @@ def minimum(a, b):
def maximum(a, b):
return f"at::vec::maximum({a}, {b})"

@staticmethod
def int_minimum(a, b):
return f"at::vec::minimum({a}, {b})"

@staticmethod
def int_maximum(a, b):
return f"at::vec::maximum({a}, {b})"

@staticmethod
def square(a):
return f"{a} * {a}"
@@ -835,11 +841,19 @@ def relu(x):

@staticmethod
def minimum(a, b):
return f"min_propagate_nan({a}, {b})"
return f"({b} != {b}) ? {b} : std::min({a}, {b})"

@staticmethod
def maximum(a, b):
return f"max_propagate_nan({a}, {b})"
return f"({b} != {b}) ? {b} : std::max({a}, {b})"

@staticmethod
def int_minimum(a, b):
return f"std::min({a}, {b})"

@staticmethod
def int_maximum(a, b):
return f"std::max({a}, {b})"

@staticmethod
def where(a, b, c):
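For reference, a minimal sketch (not the real CppOverrides class) of the strings the reverted scalar overrides emit: floating-point minimum/maximum gets an explicit NaN check on the second operand before falling back to std::min/std::max, while the new int_minimum/int_maximum call std::min/std::max directly. The variable names passed in stand for already-emitted kernel temporaries.

```python
# Sketch of the reverted scalar codegen strings; `a` and `b` stand for
# C++ variable names that already exist in the generated kernel.
def minimum(a, b):
    # NaN in `b` is returned as-is; otherwise std::min decides.
    return f"({b} != {b}) ? {b} : std::min({a}, {b})"

def int_minimum(a, b):
    # Integers cannot be NaN, so plain std::min suffices.
    return f"std::min({a}, {b})"

print(minimum("tmp0", "tmp1"))      # (tmp1 != tmp1) ? tmp1 : std::min(tmp0, tmp1)
print(int_minimum("tmp0", "tmp1"))  # std::min(tmp0, tmp1)
```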
17 changes: 0 additions & 17 deletions torch/_inductor/codegen/cpp_prefix.h
@@ -7,7 +7,6 @@
#include <limits>
#include <omp.h>

#include <ATen/NumericUtils.h>
#include <ATen/core/PhiloxRNGEngine.h>
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
#include <ATen/cpu/vec/functional.h>
@@ -23,22 +22,6 @@ template <typename T> inline T mod(T a, T b) { return a % b; }
template <> inline float mod(float a, float b) { return std::fmod(a, b); }
template <> inline double mod(double a, double b) { return std::fmod(a, b); }

template <typename scalar_t>
inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
if (at::_isnan(a)) {
return a;
}
return a > b ? a : b;
}

template <typename scalar_t>
inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
if (at::_isnan(a)) {
return a;
}
return a < b ? a : b;
}

constexpr float uint32_to_uniform_float(uint32_t value) {
// maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
constexpr float scale = 4.6566127342e-10;
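The deleted helpers propagated a NaN found in the first (accumulator) argument. A plain-Python analogue of the removed C++ templates, kept here only to document the behaviour being dropped:

```python
import math

# Python analogue of the removed max_propagate_nan / min_propagate_nan templates:
# NaN in the first argument wins, otherwise an ordinary comparison decides.
def max_propagate_nan(a: float, b: float) -> float:
    return a if math.isnan(a) else (a if a > b else b)

def min_propagate_nan(a: float, b: float) -> float:
    return a if math.isnan(a) else (a if a < b else b)

assert math.isnan(max_propagate_nan(float("nan"), 3.0))
assert max_propagate_nan(1.0, 3.0) == 3.0
```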
160 changes: 66 additions & 94 deletions torch/_inductor/codegen/triton.py
@@ -13,7 +13,6 @@
import torch

import torch._logging
from torch._prims_common import is_integer_dtype
from ..._dynamo import config as dynamo_config
from ..._dynamo.utils import counters
from .. import config, ir, scheduler
@@ -111,13 +110,6 @@ def triton_compute_type(dtype):
return f"tl.{triton_type_name}"


def triton_acc_type(dtype):
if is_integer_dtype(dtype) and dtype.is_signed:
nbits = 64 if dtype == torch.int64 else 32
return f"tl.int{nbits}"
return triton_compute_type(dtype)


def triton_constant(value):
if value == float("inf"):
return 'float("inf")'
@@ -216,11 +208,19 @@ def relu(x):

@staticmethod
def minimum(a, b):
return f"triton_helpers.minimum({a}, {b})"
return f"tl.where({a} != {a}, {a}, tl.where({a} < {b}, {a}, {b}))"

@staticmethod
def maximum(a, b):
return f"triton_helpers.maximum({a}, {b})"
return f"tl.where({a} != {a}, {a}, tl.where({a} > {b}, {a}, {b}))"

@staticmethod
def int_minimum(a, b):
return f"tl.where({a} < {b}, {a}, {b})"

@staticmethod
def int_maximum(a, b):
return f"tl.where({a} > {b}, {a}, {b})"

@staticmethod
def where(a, b, c):
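The reverted Triton overrides encode the same NaN handling inline with tl.where, using `a != a` as the usual IEEE NaN test: NaN in the first operand propagates, otherwise the smaller (or larger) value is chosen. A plain-Python model of the emitted expression, purely for illustration:

```python
import math

# Model of `tl.where(a != a, a, tl.where(a < b, a, b))`, the string emitted by
# the reverted minimum(): NaN in `a` propagates, otherwise pick the smaller value.
def triton_minimum_model(a: float, b: float) -> float:
    return a if math.isnan(a) else (a if a < b else b)

assert triton_minimum_model(2.0, 5.0) == 2.0
assert math.isnan(triton_minimum_model(float("nan"), 5.0))
```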
@@ -389,11 +389,11 @@ def libdevice_log(x):

@staticmethod
def isinf(x):
return f"tl.math.isinf({x}).to(tl.int1)"
return f"tl.math.isinf({x})"

@staticmethod
def isnan(x):
return f"tl.math.isnan({x}).to(tl.int1)"
return f"tl.math.isnan({x})"

@staticmethod
def round(x):
@@ -966,18 +966,20 @@ def indexing(

expand_str = None

if isinstance(index, sympy.Integer):
expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
return index_str, set(), "None", expand_str

if need_dense and not have_dense:
expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
index_str = f"tl.broadcast_to({index_str}, {expand_str})"
mask_vars = dense_mask_vars
if (need_dense and not have_dense) or isinstance(index, sympy.Integer):
if copy_shape:
index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
expand_str = f"{copy_shape}.shape"
else:
index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
expand_str = self.dense_size_str()
if isinstance(index, sympy.Integer):
return index_str, set(), "None", expand_str
else:
mask_vars = dense_mask_vars
elif not have_loop_vars and copy_shape:
index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)"
mask_vars = dense_mask_vars
index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"

if override_mask:
mask_vars = {override_mask}
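The indexing change swaps tl.full / tl.broadcast_to back to the older idiom of adding an all-zero tensor of the target shape, which densifies the index through broadcasting. A NumPy model of that trick (NumPy stands in for Triton here; the shapes are made up for the example):

```python
import numpy as np

# Adding a zero tensor of the dense [XBLOCK, RBLOCK] shape broadcasts a narrower
# index up to that shape, which is what `index + tl.zeros(shape, tl.int32)` does
# in the reverted codegen.
xindex = np.arange(4).reshape(4, 1)                 # [XBLOCK, 1]
dense_index = xindex + np.zeros((4, 8), np.int32)   # [XBLOCK, RBLOCK]
assert dense_index.shape == (4, 8)
assert (dense_index[:, 0] == xindex[:, 0]).all()    # values are unchanged
```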
@@ -1195,102 +1197,72 @@ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
reduction_type = "max"

def final_reduction(value):
use_helper = reduction_type in {"argmax", "argmin", "max", "min", "prod"}
module = "triton_helpers" if use_helper else "tl"
module = "triton_helpers" if reduction_type in ("prod",) else "tl"
return f"{module}.{reduction_type}({value}, {dim})[{', '.join(sizes)}]"

def final_argreduce(buffer, result_var, value, index):
buffer.splice(
f"""\
_, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
{result_var} = {result_var}_tmp[{', '.join(sizes)}]
"""
)

dim = len(self.range_trees) - 1
result_var = self.cse.newvar()
result_var.mask_vars = {var for var in masks if var[0] != "r"}
cond = " & ".join(masks)

if self.persistent_reduction:
cond = " & ".join(masks)
masked_value = self.cse.generate(
self.compute, f"tl.where({cond}, {value}, {default})"
)
if reduction_type in {"argmax", "argmin"}:
accumulator_index = self.cse.generate(
self.compute,
f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)",
)
result_var = self.cse.newvar()
root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
final_argreduce(
self.compute, result_var, masked_value, accumulator_index
)
else:
result_var = self.cse.generate(
self.compute, final_reduction(masked_value)
)
result_var = self.cse.generate(self.compute, final_reduction(masked_value))
elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
accumulator = f"_{result_var}"
default_value = f" + {default}" if default != 0 else ""
self.body.writeline(
f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {triton_acc_type(src_dtype)})"
f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}"
)

accumulator_index = None
if reduction_type in {"argmax", "argmin"}:
accumulator_index = f"_{result_var}_index"
long_max = torch.iinfo(torch.int64).max
self.body.writeline(
f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)"
f"{accumulator_index} = tl.zeros({self.dense_size_str()}, tl.int64)"
)
root_op = {"argmax": "max", "argmin": "min"}[reduction_type]

self.compute.splice(
f"""\
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
)
{accumulator} = tl.where({cond}, {accumulator}_next, {accumulator})
{accumulator_index} = tl.where({cond}, {accumulator_index}_next, {accumulator_index})
"""
)
idx_dtype = self.index_dtype
final_argreduce(self.suffix, result_var, accumulator, accumulator_index)
updated = value
if reduction_type in {"min", "argmin"}:
masks.append(f"({accumulator} > {value})")
elif reduction_type in {"max", "argmax"}:
masks.append(f"({accumulator} < {value})")
elif reduction_type == "sum":
updated = f"{accumulator} + {value}"
elif reduction_type == "prod":
updated = f"{accumulator} * {value}"
elif reduction_type == "xor_sum":
updated = f"{accumulator} ^ {value}"
else:
updated = value
if reduction_type == "min":
updated = f"triton_helpers.minimum({accumulator}, {value})"
elif reduction_type == "max":
updated = f"triton_helpers.maximum({accumulator}, {value})"
elif reduction_type == "sum":
updated = f"{accumulator} + {value}"
elif reduction_type == "prod":
updated = f"{accumulator} * {value}"
elif reduction_type == "xor_sum":
updated = f"{accumulator} ^ {value}"
else:
raise NotImplementedError(f"reduction_type {reduction_type}")
raise NotImplementedError(f"reduction_type {reduction_type}")

cond = " & ".join(masks)

if accumulator_index:
# argmax or argmin
self.compute.writeline(
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
f"{accumulator_index} = tl.where({cond}, {reduction_range_prefix}index, {accumulator_index})",
)
self.compute.writeline(
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
)

if src_dtype == torch.bool:
# This is only really used for aten.any. It changes the
# final reduction of a non-persistent reduction from
# tmp5 = triton_helpers.max(_tmp5, 1)[:, None]
# to
# tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1)
# which is needed because tl.reduce doesn't support tl.int1
accumulator = f"{accumulator}.to(tl.int8)"
result_type = triton_compute_type(dtype)
self.suffix.writeline(
f"{result_var} = {final_reduction(accumulator)}.to({result_type})"
)
else:
self.suffix.writeline(
f"{result_var} = {final_reduction(accumulator)}"
)
if accumulator_index:
# argmax, argmin
idx_dtype = self.index_dtype
self.suffix.writelines(
[
f"{accumulator_index}_reduce = "
f"tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}].to(tl.int32)",
f"{accumulator_index}_mask = tl.arange(0, {reduction_range_prefix.upper()}BLOCK)"
f"[{', '.join(reduction_sizes)}] == {accumulator_index}_reduce",
f"{result_var} = tl.sum("
f"tl.where({accumulator_index}_mask, {accumulator_index}, 0), {dim})[{', '.join(sizes)}]",
]
)
else:
self.suffix.writeline(f"{result_var} = {final_reduction(accumulator)}")
else:
var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)]
self.suffix.writeline(f"{result_var} = {var_name}")
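In the non-persistent (looped) path restored above, argmax/argmin keep two running accumulators per lane: the best value seen so far and the global index where it was seen, both updated under a comparison mask; the suffix then locates the winning lane and reads its stored index. A NumPy sketch of that scheme (illustrative only; the block size, names, and 1-D layout are simplifications, and tie-breaking can differ from a single-pass argmax when duplicates are present):

```python
import numpy as np

def looped_argmax(x: np.ndarray, rblock: int = 8) -> int:
    """Model of the reverted looped argmax: value + index accumulators per lane."""
    acc = np.full(rblock, -np.inf)          # running maximum in each lane
    acc_idx = np.zeros(rblock, np.int64)    # global index of that maximum
    for start in range(0, len(x), rblock):
        vals = x[start:start + rblock]
        idx = np.arange(start, start + len(vals))
        found_larger = acc[: len(vals)] < vals                        # update mask per lane
        acc[: len(vals)] = np.where(found_larger, vals, acc[: len(vals)])
        acc_idx[: len(vals)] = np.where(found_larger, idx, acc_idx[: len(vals)])
    winning_lane = int(np.argmax(acc))      # suffix step: which lane holds the overall max
    return int(acc_idx[winning_lane])       # its recorded global index

rng = np.random.default_rng(0)
x = rng.standard_normal(37)
assert looped_argmax(x) == int(np.argmax(x))
```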
