Skip to content

Commit

Permalink
[pytorch][tensorexpr] Promote integer arguments to sin/cos/tan to flo…
Browse files Browse the repository at this point in the history
…at (#46776)

Summary:
Pull Request resolved: #46776

Following numpy and (now) eager mode

Fixes #46458

Test Plan: test_jit_fuser_te

Reviewed By: navahgar

Differential Revision: D24509884

fbshipit-source-id: c063030fc609ba4aefcd9abd25b50f082fef1548
  • Loading branch information
bertmaher authored and facebook-github-bot committed Oct 24, 2020
1 parent 343260a commit c4892c8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
2 changes: 1 addition & 1 deletion test/test_jit_fuser_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,7 @@ def apply(fn):
torch.erf,
torch.erfc,
torch.cos,
# torch.sin, Note: Reference https://github.com/pytorch/pytorch/issues/46458
torch.sin,
torch.tan,
torch.acos,
torch.asin,
Expand Down
25 changes: 19 additions & 6 deletions torch/csrc/jit/tensorexpr/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,16 @@ ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) {
return scalars_.at(v->unique());
}

ExprHandle promoteIntegerToFloat(const ExprHandle& e) {
auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
if (!c10::isIntegralType(scalarType)) {
return e;
}
auto defaultType = static_cast<tensorexpr::ScalarType>(
c10::typeMetaToScalarType(c10::get_default_dtype()));
return Cast::make(Dtype(defaultType, e.dtype().lanes()), e);
}

void TensorExprKernel::promoteInputs(std::vector<ExprHandle>& inputs) {
if (inputs.empty()) {
return;
Expand Down Expand Up @@ -965,18 +975,21 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
} break;

case aten::cos: {
return computeOneOperand(
"aten_cos", v, [](const ExprHandle& a) { return cos(a); });
return computeOneOperand("aten_cos", v, [](const ExprHandle& a) {
return cos(promoteIntegerToFloat(a));
});
} break;

case aten::sin: {
return computeOneOperand(
"aten_sin", v, [](const ExprHandle& a) { return sin(a); });
return computeOneOperand("aten_sin", v, [](const ExprHandle& a) {
return sin(promoteIntegerToFloat(a));
});
} break;

case aten::tan: {
return computeOneOperand(
"aten_tan", v, [](const ExprHandle& a) { return tan(a); });
return computeOneOperand("aten_tan", v, [](const ExprHandle& a) {
return tan(promoteIntegerToFloat(a));
});
} break;

case aten::type_as: {
Expand Down

0 comments on commit c4892c8

Please sign in to comment.