Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ namespace c10 {
_(prim, fork) \
_(prim, RaiseException) \
_(aten, append) \
_(aten, __not__) \
_(aten, format) \
_(aten, __not__) \
_(aten, __is__) \
_(aten, __isnot__) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \
Expand Down
23 changes: 23 additions & 0 deletions aten/src/ATen/core/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,9 @@ struct CAFFE2_API IValue final {
template<typename T>
optional<T> toOptional();

// this is a shallow comparison of two IValues to test the object identity
bool isSameIdentity(IValue& rhs);

CAFFE2_API friend std::ostream& operator<<(
std::ostream& out,
const IValue& v);
Expand Down Expand Up @@ -623,4 +626,24 @@ inline optional<T> IValue::toOptional() {
return this->to<T>();
}

inline bool IValue::isSameIdentity(IValue& rhs) {
// We choose to not use memcmp for payload check due to potenntial random padding characters on union type

// Semantics:
// 1. None is None, False is False, and True is True are all true
// 2. If it is a reference type (i.e. is_intrusive_ptr), then is is True when the pointed-to object is the same.
// 3. False for all other comparisons.
if (this->isNone() && rhs.isNone()) {
return true;
} else if (this->isBool() && rhs.isBool()) {
// for bool type, do equality check
return this->toBool() == rhs.toBool();
} else {
// for objects holding in IValue, do shallow compare on pointer address to testify the identity
return this->is_intrusive_ptr && rhs.is_intrusive_ptr
&& this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
}
}


} // namespace c10
35 changes: 35 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4434,6 +4434,41 @@ def test_tensor_number_math(self):
def test_tensor_number_math_cuda(self):
self._test_tensor_number_math(device='cuda')

def test_not(self):
# test not operator in python
# TODO: add more tests when bool conversions ready
def test_not_op(a):
return not bool(a > 1)

self.checkScript(test_not_op, (torch.tensor(2), ), optimize=True)

def test_is_isnot(self):
# test is and is not operator in python
template = dedent('''
def func():
# type: () -> bool
return {lhs} {op} {rhs}
''')

def test(op, args):
code = template.format(lhs=args[0], rhs=args[1], op=op)
scope = {}
exec(code, globals(), scope)
cu = torch.jit.CompilationUnit(code)
self.assertEqual(
cu.func(),
scope['func'](),
"Failed with op: {}, lhs: {}, rhs: {}"
.format(op, args[0], args[1])
)

ops = ['is', 'is not']
type_literals = [True, False, None, [1, 1]]

# do literals product to try any types combinations
for op, lhs, rhs in product(ops, type_literals, type_literals):
test(op, [lhs, rhs])

def test_python_call(self):
def pyfunc(a):
return a * 3.0
Expand Down
24 changes: 22 additions & 2 deletions torch/csrc/jit/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -900,10 +900,30 @@ Operator( \
};
}),
Operator(
"aten::__not__(int self) -> int",
"aten::__not__(bool self) -> bool",
[](const Node* node) {
return [=](Stack& stack) {
push(stack, !pop(stack).toInt());
push(stack, !pop(stack).toBool());
return 0;
};
}),
Operator(
"aten::__is__(t1 self, t2 obj) -> bool",
[](const Node* node) {
return [=](Stack& stack) {
IValue self, obj;
pop(stack, self, obj);
push(stack, self.isSameIdentity(obj));
return 0;
};
}),
Operator(
"aten::__isnot__(t1 self, t2 obj) -> bool",
[](const Node* node) {
return [=](Stack& stack) {
IValue self, obj;
pop(stack, self, obj);
push(stack, !self.isSameIdentity(obj));
return 0;
};
}),
Expand Down
6 changes: 6 additions & 0 deletions torch/csrc/jit/script/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,10 @@ struct to_ir {
return aten::__and__;
case TK_OR:
return aten::__or__;
case TK_IS:
return aten::__is__;
case TK_ISNOT:
return aten::__isnot__;
case TK_NOT:
return aten::__not__;
default:
Expand Down Expand Up @@ -1665,6 +1669,8 @@ struct to_ir {
switch (tree->kind()) {
case '@':
case TK_POW:
case TK_IS:
case TK_ISNOT:
case TK_NOT:
case TK_NE:
case TK_EQ:
Expand Down
40 changes: 21 additions & 19 deletions torch/csrc/jit/script/lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,31 @@ namespace jit {
namespace script {

static const std::unordered_map<int, int> binary_prec = {
{TK_IF, 1},
{TK_AND, 2},
{TK_OR, 2},
{TK_IF, 1},
{TK_AND, 2},
{TK_OR, 2},
// reserve a level for unary not
{'<', 4},
{'>', 4},
{TK_EQ, 4},
{TK_LE, 4},
{TK_GE, 4},
{TK_NE, 4},
{'+', 5},
{'-', 5},
{'*', 6},
{'/', 6},
{'%', 6},
{'@', 6},
{TK_POW, 7},
{'<', 4},
{'>', 4},
{TK_IS, 4},
{TK_ISNOT, 4},
{TK_EQ, 4},
{TK_LE, 4},
{TK_GE, 4},
{TK_NE, 4},
{'+', 5},
{'-', 5},
{'*', 6},
{'/', 6},
{'%', 6},
{'@', 6},
{TK_POW, 7},
};

static const std::unordered_map<int, int> unary_prec = {
{TK_NOT, 3},
{'-', 8},
{'*', 8},
{TK_NOT, 3},
{'-', 8},
{'*', 8},
};

bool SharedParserData::isUnary(int kind, int* prec) {
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/script/lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ namespace script {
_(TK_WHILE, "while", "while") \
_(TK_EXPR_STMT, "expression statement", "") \
_(TK_RETURN, "return", "return") \
_(TK_IS, "is", "is") \
_(TK_ISNOT, "is not", "is not") \

This comment was marked as off-topic.

_(TK_NE, "ne", "!=") \
_(TK_EQ, "eq", "==") \
_(TK_LE, "le", "<=") \
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/script/tree_views.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ namespace script {
// | Le TK_LE
// | Ge TK_GE
// | Ne TK_NE
// | Is TK_IS
// | IsNot TK_ISNOT
// | Add '+'
// | Sub '-'
// | Mul '*'
// | Div '/'
// | Mod '%'
// | MatMult '@'
// | Pow TK_POW
// | UnaryOp(Expr expr)
Expand Down Expand Up @@ -226,6 +229,8 @@ struct Expr : public TreeView {
case TK_OR:
case '<':
case '>':
case TK_IS:
case TK_ISNOT:
case TK_EQ:
case TK_LE:
case TK_GE:
Expand Down Expand Up @@ -522,6 +527,8 @@ struct BinOp : public Expr {
case TK_OR:
case '<':
case '>':
case TK_IS:
case TK_ISNOT:
case TK_EQ:
case TK_LE:
case TK_GE:
Expand Down
2 changes: 2 additions & 0 deletions torch/jit/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ class ExprBuilder(Builder):
ast.Lt: '<',
ast.GtE: '>=',
ast.Gt: '>',
ast.Is: 'is',
ast.IsNot: 'is not',
}

@staticmethod
Expand Down