torch.Assert: make it torch.jit.script'able (#47399)

Summary:
Pull Request resolved: #47399

Currently torch.Assert is not scriptable, which makes it not very useful for production code. According to jamesr66a, moving it to a C++ op will help with scriptability. This PR implements that change.

Note: with the current code, the assert is scriptable, but it becomes a no-op after being scripted. Suggestions on how to address that are welcome (can be in a future PR).
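
For illustration, here is a minimal sketch of what this change enables, mirroring the new `test_assert_scriptable` test below (the module name and inputs are arbitrary):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        torch._assert(x.sum() > 0, "foo")
        return x

# after this change, a module that calls torch._assert can be scripted
ms = torch.jit.script(M())
ms(torch.ones(4, 4))  # condition holds, runs cleanly
# a failing condition raises torch.jit.Error from the scripted assert:
# ms(torch.zeros(4, 4))
```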

Test Plan:
```
python test/test_utils.py TestAssert.test_assert_scriptable
python test/test_utils.py TestAssert.test_assert_true
python test/test_fx.py TestFX.test_symbolic_trace_assert
```

Imported from OSS

Reviewed By: eellison

Differential Revision: D24740727

fbshipit-source-id: c7888e769c921408a3020ca8332f4dae33f2bc0e
vkuzo authored and facebook-github-bot committed Nov 13, 2020
1 parent a8ca042 commit b787e74
Showing 5 changed files with 52 additions and 13 deletions.
10 changes: 7 additions & 3 deletions test/test_fx.py
```diff
@@ -746,20 +746,24 @@ def test_construct_root_dict(self):
         self.assertEqual(out, ref_out)
 
     def test_symbolic_trace_assert(self):
-        message = "assert_foobar"
 
         class AssertsTensorShape(torch.nn.Module):
            def forward(self, x):
-                torch._assert(x.shape[1] > 4, message)
+                torch._assert(x.shape[1] > 4, "assert_foobar")
                 return x
 
         m = AssertsTensorShape()
         # verify traceability
         traced = symbolic_trace(m)
         # verify assertion on traced model works correctly at runtime
         traced(torch.rand(4, 5))
-        with self.assertRaisesRegex(AssertionError, message):
+        with self.assertRaisesRegex(AssertionError, "assert_foobar"):
             traced(torch.rand(4, 3))
+        # verify the symbolically traced module is scriptable
+        ms = torch.jit.script(m)
+        with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
+            ms(torch.rand(4, 3))
+
 
     def test_copy_no_remap(self):
         traced = symbolic_trace(SimpleTest())
```
20 changes: 20 additions & 0 deletions test/test_utils.py
```diff
@@ -637,9 +637,29 @@ def test_import_hipify(self):
 class TestAssert(TestCase):
     def test_assert_true(self):
         # verify assertions work as expected
+        # bool argument
         torch._assert(True, "foo")
         with self.assertRaisesRegex(AssertionError, "bar"):
             torch._assert(False, "bar")
+        # tensor argument
+        torch._assert(torch.tensor([True], dtype=torch.bool), "foo")
+        with self.assertRaisesRegex(AssertionError, "bar"):
+            torch._assert(torch.tensor([False], dtype=torch.bool), "bar")
+
+    def test_assert_scriptable(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                torch._assert(x.sum() > 0, "foo")
+                return x
+
+        m = M()
+        # scriptable
+        ms = torch.jit.script(m)
+        # data can be passed without errors
+        x = torch.randn(4, 4).fill_(1.0)
+        ms(x)
+        with self.assertRaisesRegex(torch.jit.Error, "foo"):
+            ms(torch.tensor([False], dtype=torch.bool))
 
 
 if __name__ == '__main__':
```
24 changes: 14 additions & 10 deletions torch/__init__.py
```diff
@@ -544,6 +544,20 @@ def manager_path():
 del ComplexFloatStorageBase
 del QUInt4x2StorageBase
 
+################################################################################
+# Define _assert
+################################################################################
+
+# needs to be before the submodule imports to avoid circular dependencies
+def _assert(condition, message):
+    r"""A wrapper around Python's assert which is symbolically traceable.
+    """
+    from .overrides import has_torch_function, handle_torch_function
+
+    if type(condition) is not torch.Tensor and has_torch_function((condition,)):
+        return handle_torch_function(_assert, (condition,), condition, message)
+    assert condition, message
+
 ################################################################################
 # Import most common subpackages
 ################################################################################
@@ -618,13 +632,3 @@ def compiled_with_cxx11_abi():
 # class usage. We add these lines here to preserve backward compatbility.
 quantized_lstm = torch.ops.aten.quantized_lstm
 quantized_gru = torch.ops.aten.quantized_gru
-
-
-def _assert(condition, message):
-    r"""A wrapper around Python's assert which is symbolically traceable.
-    """
-    from .overrides import has_torch_function, handle_torch_function
-
-    if type(condition) is not torch.Tensor and has_torch_function((condition,)):
-        return handle_torch_function(_assert, (condition,), condition, message)
-    assert condition, message
```
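
The `has_torch_function` check above is what makes `_assert` symbolically traceable: under `torch.fx` tracing, the condition is a `Proxy` that implements `__torch_function__`, so the call is recorded as a graph node instead of being evaluated eagerly. A rough sketch of that behavior, based on the `test_symbolic_trace_assert` test (the module name and message are illustrative):

```python
import torch
from torch.fx import symbolic_trace

class AssertsTensorShape(torch.nn.Module):
    def forward(self, x):
        # under symbolic tracing, x.shape[1] > 4 is a Proxy, so _assert
        # dispatches through handle_torch_function and is recorded
        torch._assert(x.shape[1] > 4, "expected more than 4 columns")
        return x

traced = symbolic_trace(AssertsTensorShape())
print(traced.graph)       # should contain a call_function node for torch._assert
traced(torch.rand(4, 5))  # the recorded assert still runs on the traced module
```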
10 changes: 10 additions & 0 deletions torch/csrc/jit/frontend/builtin_functions.cpp
```diff
@@ -63,6 +63,15 @@ def _assert_int_or_pair(vals: List[int], name: str, message: str):
 def list_with_default(out_size: List[int], defaults: List[int]):
     assert len(defaults) > len(out_size)
     return out_size
+def _assert(condition : bool, message : str):
+    assert condition, message
 )SCRIPT";
+
+// an additional overload for Tensor variant of _assert
+const auto aten_ops_additional =
+    R"SCRIPT(
+def _assert(condition : Tensor, message : str):
+    assert bool(condition), message
+)SCRIPT";
 
 // Implementations of historic symbol behaviors are defined here
@@ -215,6 +224,7 @@ struct BuiltinFunctionRegistry {
   }
 
   loadSource(aten_ops, "aten");
+  loadSource(aten_ops_additional, "aten");
 
   // Loads functions implementing historic behavior, see note [Versioned
   // Symbols]
```
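
The builtin is defined as embedded TorchScript source: one overload for a `bool` condition and one for a `Tensor` condition that coerces via `bool(condition)`. A minimal sketch of how both overloads would resolve in a scripted function, assuming they behave as defined above:

```python
import torch

@torch.jit.script
def checked(x: torch.Tensor) -> torch.Tensor:
    # bool condition -> the 'condition : bool' overload
    torch._assert(x.dim() == 2, "expected a 2-D tensor")
    # Tensor condition -> the 'condition : Tensor' overload (bool(condition))
    torch._assert((x == x).all(), "expected no NaNs")
    return x

checked(torch.ones(2, 3))
```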
1 change: 1 addition & 0 deletions torch/jit/_builtins.py
```diff
@@ -67,6 +67,7 @@
     (math.degrees, "aten::degrees"),
     (math.radians, "aten::radians"),
     (math.ldexp, "aten::ldexp"),
+    (torch._assert, "aten::_assert"),
     (torch.autograd.grad, "aten::grad"),
     (torch.autograd.backward, "aten::backward"),
     (torch._C._infer_size, "aten::_infer_size"),
```
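
This entry maps the Python function `torch._assert` to the `aten::_assert` builtin so the script compiler can resolve calls to it. A quick, hypothetical way to inspect the mapping, assuming the internal helper `_find_builtin` (defined in this same file) keeps its current name:

```python
import torch
from torch.jit._builtins import _find_builtin

# expected to print the builtin symbol the compiler resolves torch._assert to
print(_find_builtin(torch._assert))
```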
