Skip to content

Commit

Permalink
[ONNX] Add logical_and, logical_or, logical_xor torch op support in p…
Browse files Browse the repository at this point in the history
…ytorch exporter (#50570)

Fixes #{}
Add logical_and, logical_or, logical_xor torch op support in pytorch exporter.
  • Loading branch information
hwangdeyu committed Jan 20, 2021
1 parent 566406c commit 8f9ee34
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
66 changes: 66 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3209,6 +3209,72 @@ def _test_compare_ops(self, model, num_inputs):
self.run_test(model, x_float)
self.run_test(model, x_int)

@skipIfUnsupportedMinOpsetVersion(9)
def test_logical_and(self):
class AndModel(torch.nn.Module):
def forward(self, x, y):
return torch.logical_and(x, y)

x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
self.run_test(AndModel(), input=(x, y))

x = torch.randint(10, (5, 5), dtype=torch.int32)
y = torch.randint(10, (5, 5), dtype=torch.int32)
self.run_test(AndModel(), input=(x, y))

x = torch.randint(10, (5, 5), dtype=torch.double)
y = torch.randint(10, (5, 5), dtype=torch.double)
self.run_test(AndModel(), input=(x, y))

x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
y = torch.randint(10, (2, 3, 5), dtype=torch.long)
self.run_test(AndModel(), input=(x, y))

@skipIfUnsupportedMinOpsetVersion(9)
def test_logical_or(self):
class OrModel(torch.nn.Module):
def forward(self, x, y):
return torch.logical_or(x, y)

x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
self.run_test(OrModel(), input=(x, y))

x = torch.randint(10, (5, 5), dtype=torch.int32)
y = torch.randint(10, (5, 5), dtype=torch.int32)
self.run_test(OrModel(), input=(x, y))

x = torch.randint(10, (5, 5), dtype=torch.double)
y = torch.randint(10, (5, 5), dtype=torch.double)
self.run_test(OrModel(), input=(x, y))

x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
y = torch.randint(10, (2, 3, 5), dtype=torch.long)
self.run_test(OrModel(), input=(x, y))

@skipIfUnsupportedMinOpsetVersion(9)
def test_logical_xor(self):
class XorModel(torch.nn.Module):
def forward(self, x, y):
return torch.logical_xor(x, y)

x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
y = torch.randint(0, 2, (5, 5), dtype=torch.bool)
self.run_test(XorModel(), input=(x, y))

x = torch.randint(10, (5, 5), dtype=torch.int32)
y = torch.randint(10, (5, 5), dtype=torch.int32)
self.run_test(XorModel(), input=(x, y))

x = torch.randint(10, (5, 5), dtype=torch.double)
y = torch.randint(10, (5, 5), dtype=torch.double)
self.run_test(XorModel(), input=(x, y))

x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
y = torch.randint(10, (2, 3, 5), dtype=torch.long)
self.run_test(XorModel(), input=(x, y))

def test_gt(self):
class GreaterModel(torch.nn.Module):
def forward(self, input, other):
Expand Down
15 changes: 15 additions & 0 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,21 @@ def __or_(g, input, other):
return g.op('Or', input, other)


@wrap_logical_op_with_cast_to_and_from('Bool')
def logical_and(g, input, other):
return g.op('And', input, other)


@wrap_logical_op_with_cast_to_and_from('Bool')
def logical_or(g, input, other):
return g.op('Or', input, other)


@wrap_logical_op_with_cast_to_and_from('Bool')
def logical_xor(g, input, other):
return g.op('Xor', input, other)


def __rshift_(g, self, other):
# make sure to cast other to self's type
# (when self is long, make sure that other is not float)
Expand Down

0 comments on commit 8f9ee34

Please sign in to comment.