From 9d2f5a278414aeaa6f3277c5b15aee4938601fa6 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Wed, 16 Nov 2022 08:51:30 +0000
Subject: [PATCH 1/2] [dynamo] Support if cond on NNModuleVariable (#89095)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89095
Approved by: https://github.com/yanboliang, https://github.com/mlazos
---
 test/dynamo/test_misc.py          | 28 ++++++++++++++++++++++++++++
 torch/_dynamo/symbolic_convert.py |  5 +++++
 2 files changed, 33 insertions(+)

diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index e27f7bc5198d..8f79f2476aee 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -2885,6 +2885,34 @@ def func(x, y):
         self.assertTrue(same(ref, res))
         self.assertTrue(same(x, x1))
 
+    def test_if_cond_nn_mod(self):
+        class MockModule(torch.nn.Module):
+            def __init__(self, output_relu=True):
+                super(MockModule, self).__init__()
+                self.relu = torch.nn.ReLU() if output_relu else None
+
+            def forward(self, x):
+                x = torch.sin(x)
+                if self.relu:
+                    x = self.relu(x)
+                return x
+
+        model = MockModule()
+        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+
+        x = torch.rand(4)
+        ref = model(x)
+        res = opt_model(x)
+        self.assertTrue(same(ref, res))
+
+        model = MockModule(output_relu=False)
+        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+
+        x = torch.rand(4)
+        ref = model(x)
+        res = opt_model(x)
+        self.assertTrue(same(ref, res))
+
 
 class CustomFunc(torch.autograd.Function):
     @staticmethod
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index d5c05f76efb0..d2bc5332719c 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -252,6 +252,11 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction):
                 + if_next
                 + if_jump
             )
+        elif isinstance(value, NNModuleVariable):
+            # Equivalent of "self.nn_module is not None"
+            if truth_fn(value):
+                push and self.push(value)
+                self.jump(inst)
         elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
             self
         ):

From 9d28775c1d28ab7c1dd93479a58bdafb9b626341 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Wed, 16 Nov 2022 09:45:49 +0000
Subject: [PATCH 2/2] Revert "Rewrite assert statement with torch._assert under config (#88246)"

This reverts commit 62ba15e10e875ce088dff26e872605ee70c8c04a.
Reverted https://github.com/pytorch/pytorch/pull/88246 on behalf of https://github.com/DanilBaibak due to breaking internal builds
---
 test/dynamo/test_repros.py        | 92 ------------------------------
 torch/_dynamo/config.py           |  3 -
 torch/_dynamo/symbolic_convert.py | 94 -------------------------------
 3 files changed, 189 deletions(-)

diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index e30a1275ed13..503231b4cb12 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -1938,98 +1938,6 @@ def fn(x):
         self.assertEqual(cnt.frame_count, 1)
         self.assertEqual(cnt.op_count, 1)
 
-    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
-    def test_rewrite_assert_with_msg(self):
-        def f(x):
-            b = x.sin()
-            assert x[0] == 3, "First dim need to be 3"
-            return x.cos() + b
-
-        args = (torch.Tensor([3, 4, 5]),)
-        cnt = torch._dynamo.testing.CompileCounter()
-
-        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
-        self.assertTrue(same(f(*args), opt_f(*args)))
-        self.assertEqual(cnt.op_count, 6)
-        self.assertEqual(cnt.frame_count, 1)
-
-        exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
-        self.assertTrue(same(exported(*args), f(*args)))
-
-        with self.assertRaisesRegex(AssertionError, ""):
-            exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
-
-    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
-    def test_not_rewrite_assert_for_other_errors(self):
-        def f(x):
-            b = x.sin()
-            if not x.sum() <= 3:
-                raise ValueError("input sum needs to be 3")
-            return x.cos() + b
-
-        args = (torch.Tensor([3, 4, 5]),)
-        opt_fn = torch._dynamo.optimize("eager")(f)
-        with self.assertRaisesRegex(ValueError, "input sum needs to be 3"):
-            opt_fn(*args)
-
-    # TODO (tmanlaibaatar) handle data-dependent fstring in assert statement.
-    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
-    def test_rewrite_assert_with_fstring_msg(self):
-        def f(x):
-            b = x.sin()
-            assert x[0] == 3, f"First dim need to be {x[0]}"
-            return x.cos() + b
-
-        args = (torch.Tensor([3, 4, 5]),)
-        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"):
-            exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
-
-    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
-    def test_rewrite_assert_without_msg(self):
-        def f(x):
-            b = x.sin()
-            assert x[0] == 3
-            return x.cos() + b
-
-        args = (torch.Tensor([3, 4, 5]),)
-        exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
-        self.assertTrue(same(exported(*args), f(*args)))
-
-        with self.assertRaisesRegex(AssertionError, ""):
-            exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
-
-    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
-    def test_rewrite_assert_noop(self):
-        def f(x):
-            b = x.sin()
-            assert True
-            assert x.dtype == torch.float32
-            return x.cos() + b
-
-        args = (torch.Tensor([3, 4, 5]),)
-        exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
-        self.assertTrue(same(exported(*args), f(*args)))
-
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
-        self.assertTrue(same(f(*args), opt_f(*args)))
-        # torch._assert shouldn't be in the graph
-        self.assertEqual(cnt.op_count, 3)
-        self.assertEqual(cnt.frame_count, 1)
-
-        exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
-        self.assertTrue(same(exported(*args), f(*args)))
-
-    @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", False)
-    def test_not_rewrite_assert(self):
-        def f(x):
-            b = x.sin()
-            assert x[0] == 3
-            return x.cos() + b
-
-        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"):
-            torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
-
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index 39a1a6433419..12088383e741 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -87,9 +87,6 @@
 # if an exception is encountered
 replay_record_enabled = False
 
-# Rewrite assert statement in python with torch._assert
-rewrite_assert_with_torch_assert = True
-
 # Show a warning on every graph break
 print_graph_breaks = False
 
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index d2bc5332719c..e64804cb68b2 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -53,7 +53,6 @@
     fake_tensors_available,
     graph_break_dup_warning_checker,
     istype,
-    proxy_args_kwargs,
 )
 from .variables.base import MutableLocal, typestr, VariableTracker
 from .variables.builder import VariableBuilder, wrap_fx_proxy
@@ -122,103 +121,10 @@ def impl(self: "InstructionTranslatorBase", inst: Instruction):
     return impl
 
 
-def _detect_and_normalize_assert_statement(
-    self: "InstructionTranslatorBase", truth_fn: typing.Callable, push: bool
-):
-    # Detect if this jump instruction is assert and normalize the assert
-    # by pushing dummy error message when nothing is given.
-    #
-    # Python 3.9 assertion is in following format:
-    # 18 POP_JUMP_IF_TRUE        28
-    # 20 LOAD_ASSERTION_ERROR
-    # 22 LOAD_CONST               3 ('Assert message') -> optional instruction
-    # 24 CALL_FUNCTION            1                    -> optional instruction
-    # 26 RAISE_VARARGS
-    #
-    # Python 3.8 assertion is in following format:
-    # 18 POP_JUMP_IF_TRUE        28
-    # 20 LOAD_GLOBAL              0 (Assertion type)
-    # 22 LOAD_CONST               3 ('Assert message') -> optional instruction
-    # 24 CALL_FUNCTION            1                    -> optional instruction
-    # 26 RAISE_VARARGS            1
-
-    if (truth_fn is not operator.truth) or push:
-        return False
-
-    current_instruction_pointer = self.instruction_pointer
-    inst = self.instructions[current_instruction_pointer]
-    # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
-    if sys.version_info < (3, 9):
-        if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
-            return False
-    else:
-        if inst.opname != "LOAD_ASSERTION_ERROR":
-            return False
-
-    current_instruction_pointer += 1
-
-    if current_instruction_pointer >= len(self.instructions):
-        return False
-
-    inst = self.instructions[current_instruction_pointer]
-    has_error_msg = False
-    # DETECT RAISE_VARARGS or LOAD CONST
-    if inst.opname == "LOAD_CONST":
-        if not isinstance(inst.argval, str):
-            return False
-        self.LOAD_CONST(inst)
-        has_error_msg = True
-
-        # if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION
-        current_instruction_pointer += 1
-        if current_instruction_pointer >= len(self.instructions):
-            return False
-        inst = self.instructions[current_instruction_pointer]
-        if inst.opname != "CALL_FUNCTION":
-            return False
-
-        # CALL_FUNCTION should be followed by RAISE_VARARGS
-        current_instruction_pointer += 1
-        if current_instruction_pointer >= len(self.instructions):
-            return False
-        inst = self.instructions[current_instruction_pointer]
-
-    if inst.opname != "RAISE_VARARGS":
-        return False
-
-    if not has_error_msg:
-        # Push dummy value instead of error message
-        self.push(ConstantVariable("assertion error"))
-
-    return True
-
-
 def generic_jump(truth_fn: typing.Callable, push: bool):
     def inner(self: "InstructionTranslatorBase", inst: Instruction):
         value: VariableTracker = self.pop()
         self.output.guards.update(value.guards)
-        if (
-            config.rewrite_assert_with_torch_assert
-            and _detect_and_normalize_assert_statement(self, truth_fn, push)
-        ):
-            error_msg: VariableTracker = self.pop()
-            self.output.guards.update(error_msg.guards)
-            # Skip over things like `assert True`
-            if value.is_python_constant() and bool(value.as_python_constant()):
-                self.jump(inst)
-                return
-
-            # Manually insert torch._assert instead of python assert and jump over
-            # assert related instructions as we don't need them anymore.
-            self.output.create_proxy(
-                "call_function",
-                torch._assert,
-                *proxy_args_kwargs((value, error_msg), {}),
-                current_tx=self,
-            )
-            self.jump(inst)
-            return
-
         if value.is_python_constant():
             if truth_fn(value.as_python_constant()):
                 push and self.push(value)