-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[te] Fix bugs with shift operators #49271
Changes from all commits
d3d36ec
a7422dc
95c5d28
6230dbb
60f9edd
a7bd0b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -179,69 +179,51 @@ class Mod : public BinaryOpNode<Mod> { | |
: BinaryOpNode(lhs, rhs, IRNodeType::kMod) {} | ||
}; | ||
|
||
class And : public BinaryOpNode<And> { | ||
template <typename Op> | ||
class BitwiseOpNode : public BinaryOpNode<Op> { | ||
public: | ||
And(const Expr* lhs, const Expr* rhs) | ||
: BinaryOpNode(lhs, rhs, IRNodeType::kAnd) { | ||
if (!lhs->dtype().is_integral()) { | ||
BitwiseOpNode(const Expr* lhs, const Expr* rhs, IRNodeType type) | ||
: BinaryOpNode<Op>(lhs, rhs, type) {} | ||
|
||
static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) { | ||
if (!lhs.dtype().is_integral()) { | ||
throw unsupported_dtype(); | ||
} | ||
if (lhs->dtype() != rhs->dtype()) { | ||
throw malformed_input("bad dtype in And"); | ||
if (lhs.dtype() != rhs.dtype()) { | ||
throw malformed_input("lhs/rhs dtype mismatch"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to double check, we can't do `(uint64_t)42 << (uint8_t)5 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah the TE dsl requires operand types to match, and when coming from pytorch we enforce that at the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks! |
||
} | ||
return BinaryOpNode<Op>::make(lhs, rhs); | ||
} | ||
}; | ||
|
||
class Or : public BinaryOpNode<Or> { | ||
class And : public BitwiseOpNode<And> { | ||
public: | ||
And(const Expr* lhs, const Expr* rhs) | ||
: BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {} | ||
}; | ||
|
||
class Or : public BitwiseOpNode<Or> { | ||
public: | ||
Or(const Expr* lhs, const Expr* rhs) | ||
: BinaryOpNode(lhs, rhs, IRNodeType::kOr) { | ||
if (!lhs->dtype().is_integral()) { | ||
throw unsupported_dtype(); | ||
} | ||
if (lhs->dtype() != rhs->dtype()) { | ||
throw malformed_input("bad dtype in Or"); | ||
} | ||
} | ||
: BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {} | ||
}; | ||
|
||
class Xor : public BinaryOpNode<Xor> { | ||
class Xor : public BitwiseOpNode<Xor> { | ||
public: | ||
Xor(const Expr* lhs, const Expr* rhs) | ||
: BinaryOpNode(lhs, rhs, IRNodeType::kXor) { | ||
if (!lhs->dtype().is_integral()) { | ||
throw unsupported_dtype(); | ||
} | ||
if (lhs->dtype() != rhs->dtype()) { | ||
throw malformed_input("bad dtype in Xor"); | ||
} | ||
} | ||
: BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {} | ||
}; | ||
|
||
class Lshift : public BinaryOpNode<Lshift> { | ||
class Lshift : public BitwiseOpNode<Lshift> { | ||
public: | ||
Lshift(const Expr* lhs, const Expr* rhs) | ||
: BinaryOpNode(lhs, rhs, IRNodeType::kLshift) { | ||
if (lhs->dtype().scalar_type() != ScalarType::Int) { | ||
throw unsupported_dtype(); | ||
} | ||
if (lhs->dtype() != rhs->dtype()) { | ||
throw malformed_input("bad dtype in Lshift"); | ||
} | ||
} | ||
: BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where can I find the definition of whatever operation that maps to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. native_functions.yaml is a good source - it lists all (or majority?) of the ops that we have in pytorch. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @ZolotukhinM
In the constructor of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The In building the TE from a jit::Graph ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, also, I don't think we handle non-constant scalars in the TE codegen. I don't remember why but it wouldn't surprise me if there are a ton of corner cases since scalars are coming from the python type system. |
||
}; | ||
|
||
class Rshift : public BinaryOpNode<Rshift> { | ||
class Rshift : public BitwiseOpNode<Rshift> { | ||
public: | ||
Rshift(const Expr* lhs, const Expr* rhs) | ||
: BinaryOpNode(lhs, rhs, IRNodeType::kRshift) { | ||
if (lhs->dtype().scalar_type() != ScalarType::Int) { | ||
throw unsupported_dtype(); | ||
} | ||
if (lhs->dtype() != rhs->dtype()) { | ||
throw malformed_input("bad dtype in Rshift"); | ||
} | ||
} | ||
: BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {} | ||
}; | ||
|
||
class Max : public BinaryOpNode<Max> { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out of curiosity: these throws used to be in a c-tor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, throwing in an
Expr
ctor is bad news because of the wayKernelScopedObject
andKernelArena
work.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks!