[NNC] Fix lowering of aten::pow (#47795)
Summary:
The NNC lowering of aten::pow assumes that the exponent is either a float or an int cast to float, which doesn't work well with double (or half, for that matter).

Fixes #47304
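
For context, a minimal sketch of the failing pattern, adapted from the regression test added below (the exact repro in #47304 may differ):

import torch

# Hypothetical repro: the pow() exponent is an int literal, but the inputs are
# double tensors. The old lowering assumed the exponent was a float (or an int
# cast to float), so double inputs were not handled properly.
@torch.jit.script
def pow_double(x, z):
    return (x * 2) * torch.pow(z, 2)

x = torch.rand(10, dtype=torch.double)
z = torch.rand(10, dtype=torch.double)
# Run a few times so the profiling executor hands the graph to the fuser.
for _ in range(3):
    out = pow_double(x, z)
# The fused kernel should agree with eager mode.
assert torch.allclose(out, (x * 2) * torch.pow(z, 2))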

Pull Request resolved: #47795

Reviewed By: ZolotukhinM

Differential Revision: D24904201

Pulled By: nickgg

fbshipit-source-id: 43c3ea704399ebb36c33cd222db16c60e5b7ada5
nickgg authored and facebook-github-bot committed Nov 12, 2020
1 parent 149190c commit b1a4170
Showing 2 changed files with 37 additions and 46 deletions.
15 changes: 15 additions & 0 deletions test/test_tensorexpr.py
@@ -1307,6 +1307,21 @@ def bias_gelu(bias, y):
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()

    def test_exp_pow(self):
        devices = ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]

        @torch.jit.script
        def do_exp(x, y, z):
            return ((x * y) * 2) * torch.pow(z, 2)

        for device in devices:
            x = torch.rand(10, dtype=torch.double, device=device)
            y = torch.rand(10, dtype=torch.double, device=device)
            z = torch.rand(10, dtype=torch.double, device=device)
            traced = torch.jit.trace(do_exp, (x, y, z))
            x = warmup_and_run_forward(traced, x, y, z)
            self.assertLastGraphAllFused()

    def test_transpose(self):
        @torch.jit.script
        def test(x, y, z):
68 changes: 22 additions & 46 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1019,54 +1019,30 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
    case aten::pow: {
      return computeTwoOperand(
          "aten_pow", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            const FloatImm* floatImm = rhs.AsNode<FloatImm>();
            if (floatImm) {
              float imm = floatImm->value();
              if (imm == 1.0f) {
                return lhs;
              } else if (imm == 2.0f) { // NOLINT
                return lhs * lhs;
              } else if (imm == 3.0f) { // NOLINT
                return (lhs * lhs) * lhs;
              } else if (imm == 4.0f) { // NOLINT
                ExprHandle tmp = lhs * lhs;
                return tmp * tmp;
              } else if (imm == 0.5f) { // NOLINT
                return sqrt(lhs);
              } else if (imm == 0.0f) {
                return ExprHandle(1.0f);
              } else if (imm == -0.5f) { // NOLINT
                return rsqrt(lhs);
              } else if (imm == -1.0f) {
                return ExprHandle(1.0f) / lhs;
              } else if (imm == -2.0f) { // NOLINT
                return ExprHandle(1.0f) / (lhs * lhs);
              }
            double val = 0;
            if (rhs.node()->isConstant()) {
              val = immediateAs<double>(IRSimplifier::simplify(rhs.node()));
            }

            const Cast* floatCast = rhs.AsNode<Cast>();
            if (floatCast) {
              const IntImm* intImm =
                  dynamic_cast<const IntImm*>(floatCast->src_value());
              if (intImm) {
                float imm = static_cast<float>(intImm->value());
                if (imm == 1) {
                  return lhs;
                } else if (imm == 2) {
                  return lhs * lhs;
                } else if (imm == 3) {
                  return (lhs * lhs) * lhs;
                } else if (imm == 4) {
                  ExprHandle tmp = lhs * lhs;
                  return tmp * tmp;
                } else if (imm == 0) {
                  return ExprHandle(1.0f);
                } else if (imm == -1) {
                  return ExprHandle(1.0f) / lhs;
                } else if (imm == -2) {
                  return ExprHandle(1.0f) / (lhs * lhs);
                }
              }
            if (val == 1.0f) {
              return lhs;
            } else if (val == 2.0f) { // NOLINT
              return lhs * lhs;
            } else if (val == 3.0f) { // NOLINT
              return (lhs * lhs) * lhs;
            } else if (val == 4.0f) { // NOLINT
              ExprHandle tmp = lhs * lhs;
              return tmp * tmp;
            } else if (val == 0.5f) { // NOLINT
              return sqrt(lhs);
            } else if (val == 0.0f) {
              return ExprHandle(1.0f);
            } else if (val == -0.5f) { // NOLINT
              return rsqrt(lhs);
            } else if (val == -1.0f) {
              return ExprHandle(1.0f) / lhs;
            } else if (val == -2.0f) { // NOLINT
              return ExprHandle(1.0f) / (lhs * lhs);
            }
            return pow(lhs, rhs);
          });
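For readers skimming the C++ above: the new lowering folds the exponent with IRSimplifier::simplify and reads the constant back as a double via immediateAs<double>, so the specialization applies whatever the scalar type of the immediate. A rough Python sketch of that specialization table (illustrative only, not the NNC API):

import math

def lower_pow(base, exponent):
    # Illustrative stand-in for the aten::pow specialization above. `exponent`
    # plays the role of the folded constant; normalizing it to a double before
    # the comparisons is the point of the fix.
    val = float(exponent)  # analogous to immediateAs<double>(...)
    if val == 1.0:
        return base
    elif val == 2.0:
        return base * base
    elif val == 3.0:
        return (base * base) * base
    elif val == 4.0:
        tmp = base * base
        return tmp * tmp
    elif val == 0.5:
        return math.sqrt(base)  # sqrt intrinsic
    elif val == 0.0:
        return 1.0
    elif val == -0.5:
        return 1.0 / math.sqrt(base)  # rsqrt intrinsic
    elif val == -1.0:
        return 1.0 / base
    elif val == -2.0:
        return 1.0 / (base * base)
    return math.pow(base, val)  # generic fallback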
