[TensorExpr] Fix lowering for aten::div. (#48329)
Summary: Pull Request resolved: #48329

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D25130750

Pulled By: ZolotukhinM

fbshipit-source-id: 7c6345adcaec5f92cd6ce78b01f6a7d5923c0004
Mikhail Zolotukhin authored and facebook-github-bot committed Nov 21, 2020
1 parent 5e1faa1 commit b967119
Showing 2 changed files with 49 additions and 8 deletions.
55 changes: 48 additions & 7 deletions test/test_jit_fuser_te.py
@@ -1308,6 +1308,7 @@ def apply(fn):
torch.max,
lambda x, y: torch.lerp(x, y, 0.5),
torch.atan2,
torch.div,

# FIXME: comparison ops yield different results when fused
# torch.eq,
@@ -1316,12 +1317,13 @@ def apply(fn):
# torch.gt,
# torch.lt,

# TODO: test operators exercising division too
# FIXME: fails on CPU backend with int8
# torch.fmod,
# torch.remainder,

# FIXME: segfaults on CPU backend
# operator.__rshift__,
# operator.__lshift__,
# torch.div,
]
devices = self.devices
for dtype, op, device in product(dtypes, binary_ops, devices):
@@ -1358,7 +1360,7 @@ def apply_with_scalar(fn, scalar):
torch.float16,
torch.float32,
torch.float64,
# torch.bool intentionally not included
torch.bool
]
binary_ops = [
operator.__and__,
Expand All @@ -1375,12 +1377,51 @@ def apply_with_scalar(fn, scalar):
# torch.lt,
# torch.gt,

# FIXME: fails with integer dtype and scalar={3,0}
# torch.div,

# FIXME: segfaults on CPU backend
# operator.__rshift__,
# operator.__lshift__,
]
devices = self.devices
# Maybe we should split this into separate tests to speed it up by
# only using scalar values relevant to particular ops
scalars = [1.5, 3, 0, -2.0, -1]
for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars):
try:
x = self.data_for(dtype, device)
fn = apply_with_scalar(op, scalar)
ref = fn(x)
except Exception:
# If eager mode doesn't support a dtype/op/device combo,
# neither does the fuser. Catch everything to avoid needing to
# guess what errors might be thrown by eager.
continue
try:
t = torch.jit.trace(fn, (x))
self.assertEqual(ref, t(x))
self.assertAllFused(t.graph_for(x))
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)

def test_binary_div_ops(self):
def apply_with_scalar(fn, scalar):
return lambda x: fn(x, scalar)

dtypes = [
torch.int8,
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
# FIXME: breaks in IR eval
# torch.float16,
torch.float32,
torch.float64,
torch.bool
]
binary_ops = [
torch.div,

# FIXME: wrong results with int8 on cpu
# torch.remainder,
@@ -1389,7 +1430,7 @@ def apply_with_scalar(fn, scalar):
devices = self.devices
# Maybe we should split this into separate tests to speed it up by
# only using scalar values relevant to particular ops
scalars = [1.5, 3, 0, -2.0, -1]
scalars = [1.5, 3, -2.0, -1] # skip 0
for dtype, op, device, scalar in product(dtypes, binary_ops, devices, scalars):
try:
x = self.data_for(dtype, device)
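
For readers who want to reproduce the pattern outside the test harness, here is a minimal standalone sketch of what the new test_binary_div_ops exercises: bind a scalar, trace the function, and check that the traced (potentially fused) result matches eager. The input shape and values below are illustrative, not taken from the commit, and the fusion assertion itself (assertAllFused) is a test-suite helper that is omitted here.

import torch

def apply_with_scalar(fn, scalar):
    # Same helper shape as in the test: bind the scalar so trace sees a unary fn.
    return lambda x: fn(x, scalar)

# Illustrative integer input; torch.div performs true division, so the result is float.
x = torch.arange(1, 9, dtype=torch.int32)
fn = apply_with_scalar(torch.div, 3)

ref = fn(x)                              # eager reference
traced = torch.jit.trace(fn, (x,))
assert torch.allclose(ref, traced(x))    # the fused kernel must agree with eager
# The test additionally asserts the graph is fully fused via self.assertAllFused.
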
2 changes: 1 addition & 1 deletion torch/csrc/jit/tensorexpr/kernel.cpp
@@ -780,7 +780,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
case aten::div: {
return computeTwoOperand(
"aten_div", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
return boolToInteger(lhs) / boolToInteger(rhs);
return promoteIntegerToFloat(lhs) / promoteIntegerToFloat(rhs);
});
} break;

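
For context, the new lowering mirrors eager-mode torch.div, which performs true division: integer and bool operands are promoted to a floating-point dtype before dividing. A minimal sketch of that eager behavior (illustrative values, not part of the change):

import torch

a = torch.tensor([3, 7], dtype=torch.int32)
b = torch.tensor([2, 2], dtype=torch.int32)

out = torch.div(a, b)    # true division: integer inputs are promoted to float
print(out, out.dtype)    # tensor([1.5000, 3.5000]) torch.float32

With the previous lowering, which only converted bool operands to integers, a fused integer division would presumably truncate and diverge from the eager result; promoting through promoteIntegerToFloat keeps the fused kernel's dtype and values in line with eager.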
