
[te] Fix bugs with shift operators (#49271)
Summary:
Pull Request resolved: #49271

Two things:

1. The shift ops throw exceptions from their constructors, which causes a
   segfault (*), so move the checks into the static ::make instead.
2. They technically support FP types, but the casting rules are complicated,
   so let's not bother fusing them for FP inputs.

(*) The reason for the segfault: all Exprs, including these, inherit from
KernelScopedObject, whose constructor adds the object to a list for destruction
at the end of the containing KernelArena's lifetime. But if the derived-class
constructor throws, the object's memory is freed even though the pointer is
still in the KernelArena's list. So when the KernelArena is itself deleted, it
double-frees the pointer and dies. I've also fixed And, Or, and Xor in this
diff, since they throw from their constructors in the same way. The sketch
below shows the bug in isolation.
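A minimal standalone reproduction of the lifetime bug (Arena and Scoped are
simplified stand-ins for KernelArena and KernelScopedObject, not the actual
PyTorch classes):

#include <stdexcept>
#include <vector>

// Simplified arena: owns every registered object and deletes them all
// when the arena itself is destroyed.
struct Arena {
  std::vector<struct Scoped*> objects;
  ~Arena();
};
Arena arena;

// Simplified scoped object: the base-class constructor registers `this`
// in the arena immediately, before any derived constructor runs.
struct Scoped {
  Scoped() { arena.objects.push_back(this); }
  virtual ~Scoped() = default;
};

Arena::~Arena() {
  for (Scoped* o : objects) {
    delete o; // double free if `o` was already deallocated
  }
}

struct ThrowingExpr : Scoped {
  ThrowingExpr() {
    // By now the base constructor has run, so `this` is in arena.objects.
    // Throwing here makes the new-expression release the allocation, but
    // the stale pointer stays behind in the arena's list.
    throw std::runtime_error("unsupported dtype");
  }
};

int main() {
  try {
    new ThrowingExpr(); // memory is freed here when the constructor throws
  } catch (const std::exception&) {
  }
  // When `arena` is destroyed at program exit, it deletes the dangling
  // pointer a second time: undefined behavior, typically a crash.
}

Validating in a static ::make that runs before the new-expression means a node
is only ever allocated (and registered) once its arguments are known to be
good.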
ghstack-source-id: 118594998

Test Plan: `buck test //caffe2/test:jit`

Differential Revision: D25512052

fbshipit-source-id: f3ca16f208c427cd3d740e8971302d8d504240fb
bertmaher authored and facebook-github-bot committed Dec 15, 2020
1 parent 39a10fb commit 67da5af
Showing 4 changed files with 50 additions and 54 deletions.
13 changes: 3 additions & 10 deletions test/test_jit_fuser_te.py
@@ -477,7 +477,9 @@ def apply(fn):
         binary_ops = [
             operator.__and__,
             operator.__or__,
-            operator.__xor__
+            operator.__xor__,
+            operator.__lshift__,
+            operator.__rshift__,
         ]
         devices = self.devices
         for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
@@ -1292,11 +1294,6 @@ def apply(fn):
             torch.lt,
             torch.fmod,
             torch.remainder,
-
-            # FIXME: segfaults on CPU backend
-            # operator.__rshift__,
-            # operator.__lshift__,
-
             lambda x, y: y.type_as(x),
         ]
         fp_only = [
@@ -1343,10 +1340,6 @@ def apply_with_scalar(fn, scalar):
             torch.ge,
             torch.lt,
             torch.gt,
-
-            # FIXME: segfaults on CPU backend
-            # operator.__rshift__,
-            # operator.__lshift__,
         ]
         devices = self.devices
         # Maybe we should split this into separate tests to speed it up by
15 changes: 15 additions & 0 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -737,6 +737,12 @@ class TensorExprFuser {
         "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor",
         "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor",
     };
+    static const OperatorSet int_only_operator_set{
+        "aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor",
+        "aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor",
+        "aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor",
+        "aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor",
+    };
     // clang-format on

     for (const Value* v : node->inputs()) {
@@ -759,11 +765,20 @@
       if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
         return false;
       }
+
+      // These operators have complicated casting rules for floats.
+      if (node->isMemberOf(int_only_operator_set) && isFloatingType(*st)) {
+        return false;
+      }
     } else if (node->isMemberOf(float_only_operator_set)) {
       // Check scalar operands of float-only ops.
       if (!v->type()->cast<FloatType>()) {
         return false;
       }
+    } else if (node->isMemberOf(int_only_operator_set)) {
+      if (!v->type()->cast<IntType>()) {
+        return false;
+      }
     }
   }

10 changes: 8 additions & 2 deletions torch/csrc/jit/tensorexpr/eval.h
@@ -422,8 +422,14 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {

     if (expr_type == IRNodeType::kLshift || expr_type == IRNodeType::kRshift) {
       switch (lhs_v.dtype().scalar_type()) {
-        case ScalarType::Int:
-          value_ = shift_binary_op<int>(lhs_v, rhs_v, expr_type);
+#define TYPE_CASE(Type, Name)                                \
+  case ScalarType::Name:                                     \
+    value_ = shift_binary_op<Type>(lhs_v, rhs_v, expr_type); \
+    break;
+        AT_FORALL_INT_TYPES(TYPE_CASE);
+#undef TYPE_CASE
+        case ScalarType::Bool:
+          value_ = shift_binary_op<unsigned char>(lhs_v, rhs_v, expr_type);
           break;
         default:
           throw unsupported_dtype();
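For reference, AT_FORALL_INT_TYPES invokes TYPE_CASE once per integral scalar
type, so the switch above expands to roughly the following (a sketch; the
authoritative type list lives in c10/core/ScalarType.h):

case ScalarType::Byte:
  value_ = shift_binary_op<uint8_t>(lhs_v, rhs_v, expr_type);
  break;
case ScalarType::Char:
  value_ = shift_binary_op<int8_t>(lhs_v, rhs_v, expr_type);
  break;
case ScalarType::Short:
  value_ = shift_binary_op<int16_t>(lhs_v, rhs_v, expr_type);
  break;
case ScalarType::Int:
  value_ = shift_binary_op<int>(lhs_v, rhs_v, expr_type);
  break;
case ScalarType::Long:
  value_ = shift_binary_op<int64_t>(lhs_v, rhs_v, expr_type);
  break;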
66 changes: 24 additions & 42 deletions torch/csrc/jit/tensorexpr/ir.h
@@ -179,69 +179,51 @@ class Mod : public BinaryOpNode<Mod> {
       : BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
 };

-class And : public BinaryOpNode<And> {
+template <typename Op>
+class BitwiseOpNode : public BinaryOpNode<Op> {
  public:
-  And(const Expr* lhs, const Expr* rhs)
-      : BinaryOpNode(lhs, rhs, IRNodeType::kAnd) {
-    if (!lhs->dtype().is_integral()) {
+  BitwiseOpNode(const Expr* lhs, const Expr* rhs, IRNodeType type)
+      : BinaryOpNode<Op>(lhs, rhs, type) {}
+
+  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
+    if (!lhs.dtype().is_integral()) {
       throw unsupported_dtype();
     }
-    if (lhs->dtype() != rhs->dtype()) {
-      throw malformed_input("bad dtype in And");
+    if (lhs.dtype() != rhs.dtype()) {
+      throw malformed_input("lhs/rhs dtype mismatch");
     }
+    return BinaryOpNode<Op>::make(lhs, rhs);
   }
 };

-class Or : public BinaryOpNode<Or> {
+class And : public BitwiseOpNode<And> {
+ public:
+  And(const Expr* lhs, const Expr* rhs)
+      : BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {}
+};
+
+class Or : public BitwiseOpNode<Or> {
  public:
   Or(const Expr* lhs, const Expr* rhs)
-      : BinaryOpNode(lhs, rhs, IRNodeType::kOr) {
-    if (!lhs->dtype().is_integral()) {
-      throw unsupported_dtype();
-    }
-    if (lhs->dtype() != rhs->dtype()) {
-      throw malformed_input("bad dtype in Or");
-    }
-  }
+      : BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {}
 };

-class Xor : public BinaryOpNode<Xor> {
+class Xor : public BitwiseOpNode<Xor> {
  public:
   Xor(const Expr* lhs, const Expr* rhs)
-      : BinaryOpNode(lhs, rhs, IRNodeType::kXor) {
-    if (!lhs->dtype().is_integral()) {
-      throw unsupported_dtype();
-    }
-    if (lhs->dtype() != rhs->dtype()) {
-      throw malformed_input("bad dtype in Xor");
-    }
-  }
+      : BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {}
 };

-class Lshift : public BinaryOpNode<Lshift> {
+class Lshift : public BitwiseOpNode<Lshift> {
  public:
   Lshift(const Expr* lhs, const Expr* rhs)
-      : BinaryOpNode(lhs, rhs, IRNodeType::kLshift) {
-    if (lhs->dtype().scalar_type() != ScalarType::Int) {
-      throw unsupported_dtype();
-    }
-    if (lhs->dtype() != rhs->dtype()) {
-      throw malformed_input("bad dtype in Lshift");
-    }
-  }
+      : BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {}
 };

-class Rshift : public BinaryOpNode<Rshift> {
+class Rshift : public BitwiseOpNode<Rshift> {
  public:
   Rshift(const Expr* lhs, const Expr* rhs)
-      : BinaryOpNode(lhs, rhs, IRNodeType::kRshift) {
-    if (lhs->dtype().scalar_type() != ScalarType::Int) {
-      throw unsupported_dtype();
-    }
-    if (lhs->dtype() != rhs->dtype()) {
-      throw malformed_input("bad dtype in Rshift");
-    }
-  }
+      : BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {}
 };

 class Max : public BinaryOpNode<Max> {
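With the checks in BitwiseOpNode::make, an unsupported shift now fails cleanly
before any node is allocated. A usage sketch (assuming the usual tensorexpr
helpers KernelScope, IntImm::make, and FloatImm::make):

KernelScope scope; // provides the KernelArena that owns all Expr nodes

ExprHandle i = IntImm::make(1);
ExprHandle f = FloatImm::make(1.0f);

// Fine: integral operands with matching dtypes.
ExprHandle ok = Lshift::make(i, IntImm::make(3));

// Throws unsupported_dtype from ::make. Because validation runs before the
// new-expression, no dangling pointer is left in the KernelArena and the
// exception can be caught without a later double free.
ExprHandle bad = Lshift::make(f, i);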
