[te] Fix clamp with uint8 args #49143
Conversation
Riddle me this, Batman: how could `torch.clamp(torch.tensor([0], dtype=torch.uint8), -10, 10)` equal `10`? The answer: the min/max args are first cast to the dtype of the input, giving min=246 and max=10. Then you have to apply Min and Max in the right order: `Min(Max(in, min), max)`. Differ in any way and you're doomed. Hooray. This PR makes TE match eager mode for this operator, plus fixes a major facepalm in the LLVM min/max codegen, where we were always generating signed comparisons. Differential Revision: [D25456366](https://our.internmc.facebook.com/intern/diff/D25456366/)
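To see why the application order matters, here is a small standalone C++ sketch (not code from the PR) that mirrors the uint8 wrap-around and compares the two orders:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // The scalar bounds -10 and 10 are first cast to the input dtype (uint8),
  // so -10 wraps around to 246 under modular conversion.
  uint8_t in = 0;
  uint8_t lo = static_cast<uint8_t>(-10); // 246
  uint8_t hi = static_cast<uint8_t>(10);  // 10

  // Eager-mode order: Min(Max(in, lo), hi).
  uint8_t eager = std::min<uint8_t>(std::max<uint8_t>(in, lo), hi);
  // Swapped order: Max(Min(in, hi), lo). Same bounds, different answer.
  uint8_t swapped = std::max<uint8_t>(std::min<uint8_t>(in, hi), lo);

  std::cout << int(eager) << "\n";   // 10, matching eager torch.clamp
  std::cout << int(swapped) << "\n"; // 246, the wrong answer
  return 0;
}
```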
💊 CI failures summary and remediations. As of commit 6e611f2 (more details on the Dr. CI page): ✅ None of the CI failures appear to be your fault 💚
🚧 3 ongoing upstream failures: these were probably caused by upstream breakages that are not fixed yet.
🚧 2 fixed upstream failures: these were probably caused by upstream breakages that were already fixed.
Pull Request resolved: #49143. ghstack-source-id: 118276737. Differential Revision: [D25456366](https://our.internmc.facebook.com/intern/diff/D25456366/)
```diff
@@ -101,7 +101,8 @@ class TORCH_API TensorExprKernel {
       const torch::jit::Value* v,
       const std::function<
           ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
-          innerExpr);
+          innerExpr,
+      bool promote_inputs = true);
```
Minor: it doesn't feel like `promote_inputs = true/false` gives two flavors of the same function. My first impression is that `promote_inputs = true` is the public behavior and `promote_inputs = false` is an internal one. I would prefer to remove this argument from the public function and refactor the `= false` flavor into an `_internal`/`_impl` function.
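A minimal sketch of the refactor the reviewer is suggesting (signatures abridged; the `_impl` name and `ThreeOpFn` alias are hypothetical). The flag disappears from the public API and lives only on the internal overload:

```cpp
// Alias for the three-operand expression builder, for brevity.
using ThreeOpFn = std::function<
    ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>;

// Public flavor: always promotes inputs.
Tensor* computeThreeOperand(
    const torch::jit::Value* v,
    const ThreeOpFn& innerExpr) {
  return computeThreeOperandImpl(v, innerExpr, /*promote_inputs=*/true);
}

// Internal flavor: callers like clamp, which cast the bounds to the input
// dtype themselves, call the impl directly with promote_inputs = false.
Tensor* computeThreeOperandImpl(
    const torch::jit::Value* v,
    const ThreeOpFn& innerExpr,
    bool promote_inputs);
```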
```diff
@@ -9,6 +9,10 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {
 
+static bool is_c10_type(const ScalarType& type) {
+  return type < ScalarType::Undefined;
```
Minor: this is fine, but it feels a bit dangerous to rely on enum value ordering. In other systems I've seen, when a list of enums is imported, the same macros are also used to define type traits on those enums; otherwise the enum values might keep changing, which complicates serialization compatibility. It is up to you.
Right, so the main reason I want to use ordering here is that we define the first N values of te::ScalarType in lockstep with c10::ScalarType (and we depend on that all over the place). Maybe something less brittle would be better. But I think I'd rather put the effort towards making te::ScalarType just go away in favor of c10::ScalarType, if possible. :)
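One less brittle option along these lines (a sketch of the idea, not something the PR adds; the specific enumerator names are assumed to mirror c10) is to pin the lockstep assumption with `static_assert`s, so a reordering of either enum fails to compile:

```cpp
// Assumes te::ScalarType mirrors the first N values of c10::ScalarType.
static_assert(
    static_cast<int>(tensorexpr::ScalarType::Byte) ==
        static_cast<int>(c10::ScalarType::Byte),
    "te::ScalarType must stay in lockstep with c10::ScalarType");
static_assert(
    static_cast<int>(tensorexpr::ScalarType::Float) ==
        static_cast<int>(c10::ScalarType::Float),
    "te::ScalarType must stay in lockstep with c10::ScalarType");
```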
```diff
@@ -38,6 +42,13 @@ bool is_floating_point(const ScalarType& type) {
   return false;
 }
 
+bool is_signed(const ScalarType& type) {
```
The generality of `is_signed` confused me a bit: given that it's only used on paths handling integrals anyway, I don't know if it's warranted. If we made it handle integrals only, it'd side-step the `is_c10_type` trick as well.
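A sketch of the integrals-only variant being floated (the function name is hypothetical): every signed integral case is enumerated explicitly, so no `is_c10_type` guard is needed.

```cpp
// Only meaningful for integral types; enumerating the cases avoids any
// reliance on enum value ordering.
bool is_signed_integral(const ScalarType& type) {
  switch (type) {
    case ScalarType::Char:  // int8
    case ScalarType::Short: // int16
    case ScalarType::Int:   // int32
    case ScalarType::Long:  // int64
      return true;
    default:
      return false;
  }
}
```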
```diff
@@ -720,7 +720,8 @@ void LLVMCodeGenImpl::visit(const Max* v) {
   auto rhs = this->value_;
 
   if (v->dtype().is_integral()) {
-    auto icmp = irb_.CreateICmpSGT(lhs, rhs);
+    auto icmp = v->dtype().is_signed() ? irb_.CreateICmpSGT(lhs, rhs)
```
You could use `CreateICmp` and `llvm_comparison_predicate` here and avoid the ternary; I had to do a similar fix for comparisons.
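Roughly the shape of that suggestion, sketched against the raw LLVM `IRBuilder` API rather than guessing at `llvm_comparison_predicate`'s exact signature: pick the predicate up front, then issue a single `CreateICmp` call.

```cpp
// Select signed vs. unsigned greater-than, then emit one comparison.
auto pred = v->dtype().is_signed() ? llvm::ICmpInst::ICMP_SGT
                                   : llvm::ICmpInst::ICMP_UGT;
auto icmp = irb_.CreateICmp(pred, lhs, rhs);
```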
```diff
@@ -688,7 +689,9 @@ Tensor* TensorExprKernel::computeThreeOperand(
       tensorOrConstant(n->inputs()[2], indices),
   };
 
-  promoteInputs(inputs);
+  if (promote_inputs) {
```
Should we still be demoting output with this flag?
Yes, if I understand correctly, the output type should still be unchanged.
Related: #49178
This pull request has been merged in ae88d25.