[te] Fix bugs with shift operators #49271

Closed · wants to merge 6 commits (showing changes from 1 commit)
13 changes: 3 additions & 10 deletions test/test_jit_fuser_te.py
@@ -477,7 +477,9 @@ def apply(fn):
binary_ops = [
operator.__and__,
operator.__or__,
operator.__xor__
operator.__xor__,
operator.__lshift__,
operator.__rshift__,
]
devices = self.devices
for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
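(Not part of the diff.) A rough, standalone sketch of what one iteration of the dtype/op/device loop above exercises once the shift operators are added to binary_ops; the warm-up count and the use of torch.jit.script here are assumptions about the harness, not copied from the test file:

import torch

def fn(x, y):
    # An elementwise expression built from the newly listed shift operators.
    return (x << y) & (x >> y)

scripted = torch.jit.script(fn)
x = torch.randint(1, 16, (64,), dtype=torch.int32)
y = torch.randint(0, 4, (64,), dtype=torch.int32)

# A few warm-up runs so the profiling executor can specialize the graph and
# hand it to the TE fuser.
for _ in range(4):
    out = scripted(x, y)

assert torch.equal(out, fn(x, y))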
@@ -1292,11 +1294,6 @@ def apply(fn):
torch.lt,
torch.fmod,
torch.remainder,

# FIXME: segfaults on CPU backend
# operator.__rshift__,
# operator.__lshift__,

lambda x, y: y.type_as(x),
]
fp_only = [
@@ -1343,10 +1340,6 @@ def apply_with_scalar(fn, scalar):
torch.ge,
torch.lt,
torch.gt,

# FIXME: segfaults on CPU backend
# operator.__rshift__,
# operator.__lshift__,
]
devices = self.devices
# Maybe we should split this into separate tests to speed it up by
11 changes: 11 additions & 0 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -737,6 +737,12 @@ class TensorExprFuser {
"aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor",
"aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor",
};
static const OperatorSet int_only_operator_set{
"aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor",
"aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor",
"aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor",
"aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor",
};
// clang-format on

for (const Value* v : node->inputs()) {
@@ -759,6 +765,11 @@ class TensorExprFuser {
if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
return false;
}

// These operators have complicated casting rules for floats.
if (node->isMemberOf(int_only_operator_set) && isFloatingType(*st)) {
return false;
}
}
}

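For intuition about the new int_only_operator_set guard: a hedged sketch of how one might check, from Python, whether a shift ended up in a TE fusion group after warm-up. The use of torch.jit.last_executed_optimized_graph and the warm-up count are assumptions about the debugging workflow (and fusion also depends on the TE fuser being enabled for the backend), not part of this PR:

import torch

def shift(x, y):
    return (x << 1) + (y >> 1)

scripted = torch.jit.script(shift)
xi = torch.randint(1, 64, (32,), dtype=torch.int32)
yi = torch.randint(1, 64, (32,), dtype=torch.int32)

# Profile and optimize with integer inputs.
for _ in range(4):
    scripted(xi, yi)

graph = torch.jit.last_executed_optimized_graph()
# With integer profiled inputs the guard allows fusion, so a prim::TensorExprGroup
# node may appear; with floating-point profiled inputs the guard rejects the
# shift nodes and they stay out of any fusion group.
print("prim::TensorExprGroup" in str(graph))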
66 changes: 24 additions & 42 deletions torch/csrc/jit/tensorexpr/ir.h
@@ -150,69 +150,51 @@ class Mod : public BinaryOpNode<Mod> {
: BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
};

class And : public BinaryOpNode<And> {
template <typename Op>
class BitwiseOpNode : public BinaryOpNode<Op> {
public:
And(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kAnd) {
if (!lhs->dtype().is_integral()) {
BitwiseOpNode(const Expr* lhs, const Expr* rhs, IRNodeType type)
: BinaryOpNode<Op>(lhs, rhs, type) {}

static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
if (!lhs.dtype().is_integral()) {
throw unsupported_dtype();
}
if (lhs->dtype() != rhs->dtype()) {
throw malformed_input("bad dtype in And");
if (lhs.dtype() != rhs.dtype()) {
throw malformed_input("lhs/rhs dtype mismatch");
Review comments on this change:

Contributor: Out of curiosity: these throws used to be in a c-tor?

Contributor (author): Yeah, throwing in an Expr ctor is bad news because of the way KernelScopedObject and KernelArena work.

Contributor: Thanks!

Contributor: To double-check, we can't do `(uint64_t)42 << (uint8_t)5`?

Contributor (author): Yeah, the TE DSL requires operand types to match, and when coming from PyTorch we enforce that at the kernel translation layer.

Contributor: Thanks!
}
return BinaryOpNode<Op>::make(lhs, rhs);
}
};

class Or : public BinaryOpNode<Or> {
class And : public BitwiseOpNode<And> {
public:
And(const Expr* lhs, const Expr* rhs)
: BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {}
};

class Or : public BitwiseOpNode<Or> {
public:
Or(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kOr) {
if (!lhs->dtype().is_integral()) {
throw unsupported_dtype();
}
if (lhs->dtype() != rhs->dtype()) {
throw malformed_input("bad dtype in Or");
}
}
: BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {}
};

class Xor : public BinaryOpNode<Xor> {
class Xor : public BitwiseOpNode<Xor> {
public:
Xor(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kXor) {
if (!lhs->dtype().is_integral()) {
throw unsupported_dtype();
}
if (lhs->dtype() != rhs->dtype()) {
throw malformed_input("bad dtype in Xor");
}
}
: BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {}
};

class Lshift : public BinaryOpNode<Lshift> {
class Lshift : public BitwiseOpNode<Lshift> {
public:
Lshift(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kLshift) {
if (lhs->dtype().scalar_type() != ScalarType::Int) {
throw unsupported_dtype();
}
if (lhs->dtype() != rhs->dtype()) {
throw malformed_input("bad dtype in Lshift");
}
}
: BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {}
Review comments on this change:

Member: Where can I find the definition of whatever operation maps to kLshift? Intuitively I thought shift operators would take a tensor as the first operand and an integer as the second (e.g., tensorB = tensorA << 1). Or does it really support something like tensorB = tensorA << tensorC?

Reply: native_functions.yaml is a good source - it lists all (or the majority of) the ops we have in PyTorch.
Another useful list is the ops we accept from the fuser: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/tensorexpr_fuser.cpp#L137-L140
And one more useful place is what we know how to lower to TE: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/kernel.cpp#L879

Member: Thanks @ZolotukhinM. lshift is defined in native_functions.yaml, and it supports the second operand being either a tensor or a scalar:

- func: __ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  use_c10_dispatcher: full
  variants: method
  dispatch:
    CPU, CUDA: __ilshift__
- func: __ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  use_c10_dispatcher: full
  variants: method
  dispatch:
    CPU, CUDA: __ilshift__

In the constructor of kLshift we assume both operands are tensor types. What happens if the second operand is a scalar? Is it promoted to a tensor type somewhere?

Contributor (author): The Lshift constructor actually just takes Exprs, which can be either a scalar or a tensor element. For the scalar case we'd generate something like lhs(i) << scalar; for the tensor case we'd have lhs(i) << rhs(i) (where i stands in for whatever indexing expression we have for the tensor). In building the TE from a jit::Graph (kernel.cpp) we have a lookup function called tensorOrConstant that looks up a jit::Value in a map and returns the appropriate Expr, whether it's a tensor or a constant.

Contributor (author): Oh, also, I don't think we handle non-constant scalars in the TE codegen. I don't remember why, but it wouldn't surprise me if there are a ton of corner cases, since scalars come from the Python type system.
};
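To make the scalar-vs-tensor distinction discussed above concrete, a small eager-mode illustration; the lowering notes in the comments paraphrase the author's description (lhs(i) << scalar vs. lhs(i) << rhs(i)) and are not literal generated code:

import torch

x = torch.randint(1, 8, (4,), dtype=torch.int32)
y = torch.randint(0, 3, (4,), dtype=torch.int32)

print(x << 2)   # aten::__lshift__.Scalar: roughly lhs(i) << 2 in the TE kernel
print(x << y)   # aten::__lshift__.Tensor: roughly lhs(i) << rhs(i), elementwise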

class Rshift : public BinaryOpNode<Rshift> {
class Rshift : public BitwiseOpNode<Rshift> {
public:
Rshift(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kRshift) {
if (lhs->dtype().scalar_type() != ScalarType::Int) {
throw unsupported_dtype();
}
if (lhs->dtype() != rhs->dtype()) {
throw malformed_input("bad dtype in Rshift");
}
}
: BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {}
};

class Max : public BinaryOpNode<Max> {