
[te] Ban uint8 tensors from fusion groups #49247

Closed
wants to merge 7 commits
28 changes: 19 additions & 9 deletions test/test_jit_fuser_te.py
@@ -490,7 +490,8 @@ def apply(fn):
 try:
     t = torch.jit.trace(fn, (x, y, z))
     self.assertEqual(ref, t(x, y, z))
-    self.assertAllFused(t.graph_for(x, y, z))
+    if dtype != torch.uint8:

Contributor:
Can we move the dtypes to self.dtypes and exclude torch.uint8?

Contributor (Author):
Oh yeah, that's a better idea than this. I had kind of wanted to test a few uint8 cases just to make sure nothing was broken, but there's no reason to exhaustively go through them here.

+        self.assertAllFused(t.graph_for(x, y, z))
 except Exception as e:
     raise RuntimeError(
         " ".join(["Failed:", str(dtype), op.__name__, device])
@@ -528,7 +529,8 @@ def apply(fn):
 try:
     t = torch.jit.trace(fn, (x, y, z))
     self.assertEqual(ref, t(x, y, z))
-    self.assertAllFused(t.graph_for(x, y, z))
+    if dtype != torch.uint8:
+        self.assertAllFused(t.graph_for(x, y, z))
 except Exception as e:
     raise RuntimeError(
         " ".join(["Failed:", str(dtype), op.__name__, device])
@@ -1275,7 +1277,8 @@ def apply(fn):
 try:
     t = torch.jit.trace(fn, (x,))
     torch.testing.assert_allclose(ref, t(x))
-    self.assertAllFused(t.graph_for(x))
+    if dtype != torch.uint8:
+        self.assertAllFused(t.graph_for(x))
 except Exception as e:
     raise RuntimeError(
         " ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
@@ -1344,7 +1347,8 @@ def apply(fn):
     t = torch.jit.trace(fn, (x, y))
     self.assertEqual(ref, t(x, y))
     if op not in fp_only or dtype.is_floating_point:
-        self.assertAllFused(t.graph_for(x, y))
+        if dtype != torch.uint8:
+            self.assertAllFused(t.graph_for(x, y))
 except Exception as e:
     raise RuntimeError(
         " ".join(["Failed:", str(dtype), op.__name__, device])
@@ -1403,7 +1407,8 @@ def apply_with_scalar(fn, scalar):
 try:
     t = torch.jit.trace(fn, (x))
     self.assertEqual(ref, t(x))
-    self.assertAllFused(t.graph_for(x))
+    if dtype != torch.uint8:
+        self.assertAllFused(t.graph_for(x))
 except Exception as e:
     raise RuntimeError(
         " ".join(["Failed:", str(dtype), op.__name__, device])
@@ -1487,7 +1492,8 @@ def apply_with_scalar(fn, scalar):
 try:
     t = torch.jit.trace(fn, (x))
     self.assertEqual(ref, t(x))
-    self.assertAllFused(t.graph_for(x))
+    if dtype != torch.uint8:
+        self.assertAllFused(t.graph_for(x))
 except Exception as e:
     raise RuntimeError(
         " ".join(["Failed:", str(dtype), op.__name__, device])
@@ -1529,7 +1535,8 @@ def apply(fn):
 try:
     t = torch.jit.trace(fn, (x, y, z))
     self.assertEqual(ref, t(x, y, z))
-    self.assertAllFused(t.graph_for(x, y, z))
+    if dtype != torch.uint8:
+        self.assertAllFused(t.graph_for(x, y, z))
 except Exception as e:
     raise RuntimeError(
         " ".join(["Failed:", str(dtype), op.__name__, device])
@@ -1570,7 +1577,8 @@ def apply(fn):
 try:
     t = torch.jit.trace(fn, (x, y, z))
     self.assertEqual(ref, t(x, y, z))
-    self.assertAllFused(t.graph_for(x, y, z))
+    if dtype != torch.uint8:
+        self.assertAllFused(t.graph_for(x, y, z))
 except Exception as e:
     raise RuntimeError(
         " ".join(["Failed:", str(dtype), op.__name__, device])
@@ -1612,7 +1620,8 @@ def apply(fn):
 try:
     t = torch.jit.trace(fn, (cond, x, y))
     self.assertEqual(ref, t(cond, x, y))
-    self.assertAllFused(t.graph_for(cond, x, y))
+    if dtype != torch.uint8:
+        self.assertAllFused(t.graph_for(cond, x, y))
 except Exception as e:
     raise RuntimeError(
         " ".join(["Failed:", str(dtype), op.__name__, device])
@@ -1624,6 +1633,7 @@ def fn(x):
     return x * x + x

 unsupported_dtypes = [
+    torch.uint8,
     torch.bfloat16,
     torch.complex32,
     torch.complex64,
25 changes: 16 additions & 9 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -739,15 +739,22 @@ class TensorExprFuser {
   };
   // clang-format on

-  // Value is only supported if operands are floats.
-  if (node->isMemberOf(float_only_operator_set)) {
-    for (const Value* v : node->inputs()) {
-      if (auto const& tt = v->type()->cast<TensorType>()) {
-        auto const& st = tt->scalarType();
-        if (!st || !isFloatingType(*st)) {
-          return false;
-        }
-      } else if (!v->type()->cast<FloatType>()) {
+  for (const Value* v : node->inputs()) {
+    if (auto const& tt = v->type()->cast<TensorType>()) {
+      auto const& st = tt->scalarType();
+
+      // All tensors must be typed.
+      if (!st) {
+        return false;
+      }
+
+      // Byte tensors introduce too many corner cases in type promotion. Better not to try to handle them.
+      if (*st == c10::ScalarType::Byte) {
+        return false;
+      }
+
+      // These operators only support floats, because integer divisors need to raise ZeroDivisionError.
+      if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
         return false;
       }
     }
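
To illustrate the "corner cases in type promotion" that the diff comment above refers to, here is a small eager-mode sketch of uint8 behavior (my own example, not code from the PR): arithmetic wraps modulo 256, and uint8 is the one integer dtype that promotes with int8 to a wider type.

import torch

a = torch.tensor([200], dtype=torch.uint8)
b = torch.tensor([100], dtype=torch.uint8)

# uint8 arithmetic wraps modulo 256, so 200 + 100 silently becomes 44.
print(a + b)                                          # tensor([44], dtype=torch.uint8)

# Subtraction below zero wraps the other way: 100 - 200 -> 156.
print(b - a)                                          # tensor([156], dtype=torch.uint8)

# Pairing uint8 with int8 promotes to int16 -- a widening that no other
# same-width integer pair triggers in PyTorch's promotion table.
print(torch.promote_types(torch.uint8, torch.int8))   # torch.int16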
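The float-only restriction in the last hunk is similarly visible in eager mode. A hedged sketch (again my own, not from the PR): float division by zero is well defined and yields inf, while integer division by zero has no valid result and must raise, which a fused kernel cannot easily reproduce. The exact exception type and message may vary across PyTorch versions and devices.

import torch

# Float division by zero follows IEEE semantics, so a fused kernel can
# compute it with no error handling.
print(torch.tensor([1.0]) / torch.tensor([0.0]))   # tensor([inf])

# Integer division by zero must raise instead; on CPU the error message
# mentions ZeroDivisionError.
try:
    torch.tensor([1]) // torch.tensor([0])
except RuntimeError as err:
    print(err)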