[te] Ban uint8 tensors from fusion groups
Pull Request resolved: #49247

uint8 tensors expose all kinds of corner cases in type promotion. As an example, consider:
```
>>> torch.tensor([1], dtype=torch.uint8).lt(-1)
tensor([True])
>>> torch.tensor([1], dtype=torch.uint8).lt(torch.tensor(-1))
tensor([True])
>>> torch.tensor([1], dtype=torch.uint8).lt(torch.tensor([-1]))
tensor([False])
```
The difference is in how promotions involving scalars (or 0-dim tensors, which are treated like scalars) are prioritized relative to tensor dtypes.
Per @eellison, the order is something like:
1. Tensor FP types
2. Scalar FP types
3. Tensor Int types
4. Scalar Int types

The logic for this is here: https://github.com/pytorch/pytorch/blob/c73e97033a3aef97a5685588ea014d54a5cc11cc/aten/src/ATen/native/TypeProperties.cpp#L93
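For illustration (not part of this change), the ordering can be spot-checked with `torch.result_type`, which reports the promoted dtype directly:
```
>>> t = torch.tensor([1], dtype=torch.uint8)
>>> torch.result_type(t, -1)                  # tensor int beats scalar int
torch.uint8
>>> torch.result_type(t, torch.tensor(-1))    # 0-dim tensor is treated like a scalar
torch.uint8
>>> torch.result_type(t, torch.tensor([-1]))  # tensor int vs. tensor int
torch.int64
>>> torch.result_type(t, 1.5)                 # scalar FP beats tensor int
torch.float32
```
In the `lt(-1)` example above, the scalar therefore wraps to uint8 (255), which is why the comparison returns True, while the 1-dim int64 tensor forces int64 and returns False.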

AFAICT the effects are mainly visible for the unsigned byte type (the only unsigned type besides bool), since the others degrade more or less gracefully.

It's hard to re-use this logic from TensorIterator/TypeProperties as-is, and it's complicated enough that it's not worth re-implementing in TE unless there's evidence that it matters for real models.
ghstack-source-id: 118430556

Differential Revision: [D25489035](https://our.internmc.facebook.com/intern/diff/D25489035/)
bertmaher committed Dec 11, 2020
1 parent 6413a96 commit b53d79d
Showing 2 changed files with 46 additions and 118 deletions.
137 changes: 28 additions & 109 deletions test/test_jit_fuser_te.py
@@ -71,6 +71,19 @@ def setUp(self):
torch._C._jit_set_texpr_fuser_enabled(True)

self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
self.int_dtypes = [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.bool,
]
self.fp_dtypes = [
torch.float16,
torch.float32,
torch.float64,
]
self.dtypes = self.int_dtypes + self.fp_dtypes

def tearDown(self):
torch._C._jit_set_profiling_executor(self.old_profiling_executor)
@@ -461,21 +474,13 @@ def test_bitwise_ops(self):
def apply(fn):
return lambda x, y, z: fn(fn(x, y), z)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.bool,
]
binary_ops = [
operator.__and__,
operator.__or__,
operator.__xor__
]
devices = self.devices
for dtype, op, device in product(dtypes, binary_ops, devices):
for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
try:
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
@@ -500,20 +505,12 @@ def test_minmax_int_ops(self):
def apply(fn):
return lambda x, y, z: fn(fn(x, y), z)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.bool,
]
binary_ops = [
torch.min,
torch.max
]
devices = self.devices
for dtype, op, device in product(dtypes, binary_ops, devices):
for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
try:
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
@@ -1215,17 +1212,6 @@ def test_unary_ops(self):
def apply(fn):
return lambda x: fn(x)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool,
]
unary_ops = [
torch.lgamma,
torch.sigmoid,
@@ -1262,7 +1248,7 @@ def apply(fn):
lambda x: torch.clamp(x, -10, 10),
]
sizes = [(1,), (2,), (4, 4)]
for dtype, op, device, size in product(dtypes, unary_ops, self.devices, sizes):
for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes):
try:
x = self.data_for(dtype, device, size=size)
fn = apply(op)
@@ -1286,18 +1272,7 @@ def test_binary_ops(self):
def apply(fn):
return lambda x, y: fn(x, y)

dtypes = [
# FIXME: Fails in IR Eval: torch.int8 and_ cpu
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool,
]
# FIXME: Fails in IR Eval: torch.int8 and_ cpu
binary_ops = [
operator.__and__,
operator.__or__,
@@ -1329,7 +1304,7 @@ def apply(fn):
torch.remainder,
]
devices = self.devices
for dtype, op, device in product(dtypes, binary_ops, devices):
for dtype, op, device in product(self.dtypes, binary_ops, devices):
try:
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
@@ -1355,18 +1330,7 @@ def test_binary_tensor_scalar_ops(self):
def apply_with_scalar(fn, scalar):
return lambda x: fn(x, scalar)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
# FIXME: Fails in IR Eval: torch.int64 and_ cpu
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool
]
# FIXME: Fails in IR Eval: torch.int64 and_ cpu
binary_ops = [
operator.__and__,
operator.__or__,
@@ -1376,11 +1340,9 @@ def apply_with_scalar(fn, scalar):
torch.mul,
torch.eq,
torch.ne,

# FIXME: fails with dtype=uint8, scalar=-1
# torch.ge,
# torch.lt,
# torch.gt,
torch.ge,
torch.lt,
torch.gt,

# FIXME: segfaults on CPU backend
# operator.__rshift__,
@@ -1390,7 +1352,7 @@ def apply_with_scalar(fn, scalar):
# Maybe we should split this into separate tests to speed it up by
# only using scalar values relevant to particular ops
scalars = [1.5, 3, 0, -2.0, -1]
for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars):
for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars):
try:
x = self.data_for(dtype, device)
fn = apply_with_scalar(op, scalar)
@@ -1413,17 +1375,6 @@ def test_binary_div_ops(self):
def apply_with_scalar(fn, scalar):
return lambda x: fn(x, scalar)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool
]
binary_ops = [
torch.div,
torch.remainder,
@@ -1433,7 +1384,7 @@ def apply_with_scalar(fn, scalar):
# Maybe we should split this into separate tests to speed it up by
# only using scalar values relevant to particular ops
scalars = [1.5, 3, -2.0, -1] # skip 0
for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars):
for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars):
try:
x = self.data_for(dtype, device)
fn = apply_with_scalar(op, scalar)
@@ -1498,23 +1449,12 @@ def test_ternary_ops(self):
def apply(fn):
return lambda x, y, z: fn(x, y, z)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool,
]
ternary_ops = [
torch.lerp,
torch.addcmul,
]
devices = self.devices
for dtype, op, device in product(dtypes, ternary_ops, devices):
for dtype, op, device in product(self.dtypes, ternary_ops, devices):
try:
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
@@ -1540,22 +1480,11 @@ def test_list_ops(self):
def apply(fn):
return lambda x, y, z: fn([x * x, y * y, z * z])

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool,
]
devices = self.devices
list_ops = [
torch.cat,
]
for dtype, op, device in product(dtypes, list_ops, devices):
for dtype, op, device in product(self.dtypes, list_ops, devices):
try:
x = self.data_for(dtype, device, size=[5, 4, 1, 7])
y = self.data_for(dtype, device, size=[5, 4, 1, 7])
@@ -1580,24 +1509,13 @@ def test_where_ops(self):
def apply(fn):
return lambda cond, x, y: fn(cond, x, y)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bool,
]
ops = [
torch.where,
lambda cond, x, y: torch.where(cond, x, 3.1415),
lambda cond, x, y: torch.where(cond, 42, y),
]
devices = self.devices
for dtype, op, device in product(dtypes, ops, devices):
for dtype, op, device in product(self.dtypes, ops, devices):
try:
cond = self.data_for(torch.bool, device)
x = self.data_for(dtype, device)
@@ -1624,6 +1542,7 @@ def fn(x):
return x * x + x

unsupported_dtypes = [
torch.uint8,
torch.bfloat16,
torch.complex32,
torch.complex64,
27 changes: 18 additions & 9 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -739,15 +739,24 @@ class TensorExprFuser {
};
// clang-format on

// Value is only supported if operands are floats.
if (node->isMemberOf(float_only_operator_set)) {
for (const Value* v : node->inputs()) {
if (auto const& tt = v->type()->cast<TensorType>()) {
auto const& st = tt->scalarType();
if (!st || !isFloatingType(*st)) {
return false;
}
} else if (!v->type()->cast<FloatType>()) {
for (const Value* v : node->inputs()) {
if (auto const& tt = v->type()->cast<TensorType>()) {
auto const& st = tt->scalarType();

// All tensors must be typed.
if (!st) {
return false;
}

// Byte tensors introduce too many corner cases in type promotion.
// Better not to try to handle them.
if (*st == c10::ScalarType::Byte) {
return false;
}

// These operators only support floats, because integer divisors need to
// raise ZeroDivisionError.
if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
return false;
}
}
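As a rough end-to-end check of the intended effect (a hypothetical snippet, not part of this commit), one can confirm that a scripted function over uint8 inputs no longer ends up in a `prim::TensorExprGroup` after warm-up:
```
import torch

torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_texpr_fuser_enabled(True)
torch._C._jit_override_can_fuse_on_cpu(True)  # let the TE fuser run on CPU

@torch.jit.script
def fn(x):
    return x * x + x

x = torch.randint(0, 10, (8,), dtype=torch.uint8)
for _ in range(3):
    fn(x)  # warm up the profiling executor so fusion has a chance to kick in

graph = torch.jit.last_executed_optimized_graph()
# With this change, byte tensors should not be pulled into a fusion group.
assert "TensorExprGroup" not in str(graph)
```
Running the same check with `torch.float32` inputs should (on a build where TE CPU fusion is enabled) still produce a fusion group.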
