[te] Ban uint8 tensors from fusion groups #49247

Closed
wants to merge 7 commits
138 changes: 28 additions & 110 deletions test/test_jit_fuser_te.py
@@ -71,6 +71,19 @@ def setUp(self):
torch._C._jit_set_texpr_fuser_enabled(True)

self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
self.int_dtypes = [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.bool,
]
self.fp_dtypes = [
torch.float16,
torch.float32,
torch.float64,
]
self.dtypes = self.int_dtypes + self.fp_dtypes

def tearDown(self):
torch._C._jit_set_profiling_executor(self.old_profiling_executor)
@@ -461,21 +474,13 @@ def test_bitwise_ops(self):
def apply(fn):
return lambda x, y, z: fn(fn(x, y), z)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.bool,
]
binary_ops = [
operator.__and__,
operator.__or__,
operator.__xor__
]
devices = self.devices
for dtype, op, device in product(dtypes, binary_ops, devices):
for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
try:
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
@@ -500,20 +505,12 @@ def test_minmax_int_ops(self):
def apply(fn):
return lambda x, y, z: fn(fn(x, y), z)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.bool,
]
binary_ops = [
torch.min,
torch.max
]
devices = self.devices
for dtype, op, device in product(dtypes, binary_ops, devices):
for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
try:
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
@@ -1215,17 +1212,6 @@ def test_unary_ops(self):
def apply(fn):
return lambda x: fn(x)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool,
]
unary_ops = [
torch.lgamma,
torch.sigmoid,
@@ -1262,7 +1248,7 @@ def apply(fn):
lambda x: torch.clamp(x, -10, 10),
]
sizes = [(1,), (2,), (4, 4)]
for dtype, op, device, size in product(dtypes, unary_ops, self.devices, sizes):
for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes):
try:
x = self.data_for(dtype, device, size=size)
fn = apply(op)
@@ -1286,18 +1272,7 @@ def test_binary_ops(self):
def apply(fn):
return lambda x, y: fn(x, y)

dtypes = [
# FIXME: Fails in IR Eval: torch.int8 and_ cpu
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool,
]
# FIXME: Fails in IR Eval: torch.int8 and_ cpu
binary_ops = [
operator.__and__,
operator.__or__,
@@ -1329,7 +1304,7 @@ def apply(fn):
torch.remainder,
]
devices = self.devices
for dtype, op, device in product(dtypes, binary_ops, devices):
for dtype, op, device in product(self.dtypes, binary_ops, devices):
try:
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
@@ -1355,18 +1330,7 @@ def test_binary_tensor_scalar_ops(self):
def apply_with_scalar(fn, scalar):
return lambda x: fn(x, scalar)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
# FIXME: Fails in IR Eval: torch.int64 and_ cpu
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool
]
# FIXME: Fails in IR Eval: torch.int64 and_ cpu
binary_ops = [
operator.__and__,
operator.__or__,
@@ -1376,11 +1340,9 @@ def apply_with_scalar(fn, scalar):
torch.mul,
torch.eq,
torch.ne,

# FIXME: fails with dtype=uint8, scalar=-1
# torch.ge,
# torch.lt,
# torch.gt,
torch.ge,
torch.lt,
torch.gt,

# FIXME: segfaults on CPU backend
# operator.__rshift__,
@@ -1390,7 +1352,7 @@ def apply_with_scalar(fn, scalar):
# Maybe we should split this into separate tests to speed it up by
# only using scalar values relevant to particular ops
scalars = [1.5, 3, 0, -2.0, -1]
for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars):
for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars):
try:
x = self.data_for(dtype, device)
fn = apply_with_scalar(op, scalar)
@@ -1413,17 +1375,6 @@ def test_binary_div_ops(self):
def apply_with_scalar(fn, scalar):
return lambda x: fn(x, scalar)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool
]
binary_ops = [
torch.div,
torch.remainder,
@@ -1433,7 +1384,7 @@ def apply_with_scalar(fn, scalar):
# Maybe we should split this into separate tests to speed it up by
# only using scalar values relevant to particular ops
scalars = [1.5, 3, -2.0, -1] # skip 0
for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars):
for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars):
try:
x = self.data_for(dtype, device)
fn = apply_with_scalar(op, scalar)
@@ -1457,7 +1408,6 @@ def apply_with_scalar(fn, scalar):

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
@@ -1498,23 +1448,12 @@ def test_ternary_ops(self):
def apply(fn):
return lambda x, y, z: fn(x, y, z)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool,
]
ternary_ops = [
torch.lerp,
torch.addcmul,
]
devices = self.devices
for dtype, op, device in product(dtypes, ternary_ops, devices):
for dtype, op, device in product(self.dtypes, ternary_ops, devices):
try:
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
@@ -1540,22 +1479,11 @@ def test_list_ops(self):
def apply(fn):
return lambda x, y, z: fn([x * x, y * y, z * z])

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool,
]
devices = self.devices
list_ops = [
torch.cat,
]
for dtype, op, device in product(dtypes, list_ops, devices):
for dtype, op, device in product(self.dtypes, list_ops, devices):
try:
x = self.data_for(dtype, device, size=[5, 4, 1, 7])
y = self.data_for(dtype, device, size=[5, 4, 1, 7])
@@ -1580,24 +1508,13 @@ def test_where_ops(self):
def apply(fn):
return lambda cond, x, y: fn(cond, x, y)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool,
]
ops = [
torch.where,
lambda cond, x, y: torch.where(cond, x, 3.1415),
lambda cond, x, y: torch.where(cond, 42, y),
]
devices = self.devices
for dtype, op, device in product(dtypes, ops, devices):
for dtype, op, device in product(self.dtypes, ops, devices):
try:
cond = self.data_for(torch.bool, device)
x = self.data_for(dtype, device)
@@ -1624,6 +1541,7 @@ def fn(x):
return x * x + x

unsupported_dtypes = [
torch.uint8,
torch.bfloat16,
torch.complex32,
torch.complex64,
4 changes: 2 additions & 2 deletions test/test_tensorexpr.py
@@ -420,11 +420,11 @@ def easy(x, y):
traced = torch.jit.trace(
easy,
(torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8),
torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.uint8)),
torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)),
)

a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.uint8)
b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
x = warmup_and_run_forward(traced, a, b)
self.assertLastGraphAllFused()
np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
32 changes: 23 additions & 9 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -739,15 +739,29 @@ class TensorExprFuser {
};
// clang-format on

// Value is only supported if operands are floats.
if (node->isMemberOf(float_only_operator_set)) {
for (const Value* v : node->inputs()) {
if (auto const& tt = v->type()->cast<TensorType>()) {
auto const& st = tt->scalarType();
if (!st || !isFloatingType(*st)) {
return false;
}
} else if (!v->type()->cast<FloatType>()) {
for (const Value* v : node->inputs()) {
if (auto const& tt = v->type()->cast<TensorType>()) {
auto const& st = tt->scalarType();

// All tensors must be typed.
if (!st) {
return false;
}

// Byte tensors introduce too many corner cases in type promotion.
// Better not to try to handle them.
if (*st == c10::ScalarType::Byte) {
return false;
}

// These operators only support floats, because integer divisors need to
// raise ZeroDivisionError.
if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
return false;
}
} else if (node->isMemberOf(float_only_operator_set)) {
// Check scalar operands of float-only ops.
if (!v->type()->cast<FloatType>()) {
return false;
}
}
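For context on the new "Byte tensors introduce too many corner cases in type promotion" check above, here is a small eager-mode Python sketch (illustration only, not part of this PR) of the uint8 behavior a fused kernel would otherwise have to reproduce exactly:

import torch

# uint8 (Byte) arithmetic wraps modulo 256 and promotes asymmetrically.
a = torch.tensor([200], dtype=torch.uint8)
b = torch.tensor([100], dtype=torch.uint8)

print(a + b)    # tensor([44], dtype=torch.uint8): 300 wraps modulo 256
print(b - a)    # tensor([156], dtype=torch.uint8): -100 wraps modulo 256
print(a + 100)  # a Python int scalar keeps uint8, so this also wraps to 44

c = torch.tensor([-1], dtype=torch.int8)
print(torch.promote_types(torch.uint8, torch.int8))  # torch.int16
print(a + c)    # tensor([199], dtype=torch.int16): mixed signedness widens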
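The float_only_operator_set check keeps integer division-like ops out of fusion groups because eager mode raises when an integer divisor is zero, and a generated fused kernel cannot surface that error. A hedged sketch (exact error type and message may vary by PyTorch version and backend):

import torch

x = torch.tensor([5, 7], dtype=torch.int64)
d = torch.tensor([0, 2], dtype=torch.int64)

# Integer remainder with a zero divisor is expected to raise on CPU
# (a RuntimeError mentioning ZeroDivisionError); generated fused code
# has no equivalent error path.
try:
    print(torch.remainder(x, d))
except RuntimeError as e:
    print("raised:", e)

# Float division by zero does not raise, so the float path is safe to fuse.
print(x.float() / d.float())  # tensor([inf, 3.5000])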