From ec7551d1b783e284cedddeb9aeabb285e653c480 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 9 Apr 2024 23:24:09 +0000 Subject: [PATCH] UFMT formatting on test/export (#123520) Partially addresses https://github.com/pytorch/pytorch/issues/123062 Ran lintrunner on: test/export Detail: ```Shell $ lintrunner -a --take UFMT --all-files ok No lint issues. Successfully applied all patches. ``` Co-authored-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/123520 Approved by: https://github.com/shink, https://github.com/ezyang --- .lintrunner.toml | 9 - test/export/test_db.py | 8 +- test/export/test_export.py | 423 ++++++++++++++++++++------------- test/export/test_pass_infra.py | 17 +- test/export/test_passes.py | 202 +++++++++++----- test/export/test_serialize.py | 124 ++++++---- test/export/test_unflatten.py | 60 +++-- test/export/test_upgrade.py | 71 +++--- test/export/test_verifier.py | 23 +- 9 files changed, 591 insertions(+), 346 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 94699c59e4770..946f3411fc982 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1144,15 +1144,6 @@ exclude_patterns = [ 'test/distributed/test_pg_wrapper.py', 'test/distributed/test_store.py', 'test/expect/__init__.py', - 'test/export/test_db.py', - 'test/export/test_export.py', - 'test/export/test_funtionalized_assertions.py', - 'test/export/test_pass_infra.py', - 'test/export/test_passes.py', - 'test/export/test_serialize.py', - 'test/export/test_upgrade.py', - 'test/export/test_verifier.py', - 'test/export/test_unflatten.py', 'test/functorch/attn_ft.py', 'test/functorch/attn_positional.py', 'test/functorch/common_utils.py', diff --git a/test/export/test_db.py b/test/export/test_db.py index f846567c19b55..2abce16a8b0c0 100644 --- a/test/export/test_db.py +++ b/test/export/test_db.py @@ -12,12 +12,13 @@ from torch.export import export from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, + IS_WINDOWS, parametrize, run_tests, TestCase, - IS_WINDOWS ) + @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") class ExampleTests(TestCase): @@ -60,7 +61,9 @@ def test_exportdb_supported(self, name: str, case: ExportCase) -> None: def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None: model = case.model # pyre-ignore - with self.assertRaises((torchdynamo.exc.Unsupported, AssertionError, RuntimeError)): + with self.assertRaises( + (torchdynamo.exc.Unsupported, AssertionError, RuntimeError) + ): inputs = normalize_inputs(case.example_inputs) exported_model = export( model, @@ -77,6 +80,7 @@ def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None: for rewrite_case in get_rewrite_cases(case) ] if exportdb_not_supported_rewrite_cases: + @parametrize( "name,rewrite_case", exportdb_not_supported_rewrite_cases, diff --git a/test/export/test_export.py b/test/export/test_export.py index 7167102fd08d3..b07c423b94f29 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3,8 +3,8 @@ import copy import dataclasses import io -import re import logging +import re import unittest import warnings from contextlib import contextmanager @@ -27,24 +27,25 @@ ) from torch._subclasses import FakeTensorMode from torch.export import Dim, dynamic_dim, export, unflatten -from torch.export.graph_signature import InputKind from torch.export._trace import ( _export, _export_to_torch_ir, DEFAULT_EXPORT_DYNAMO_CONFIG, ) +from torch.export.graph_signature import InputKind from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.testing import FileCheck from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_device_type import onlyCPU, onlyCUDA from torch.testing._internal.common_utils import ( - run_tests, - TestCase as TorchTestCase, + find_library_location, IS_FBCODE, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, - find_library_location, + run_tests, + TestCase as TorchTestCase, ) from torch.utils._pytree import ( LeafSpec, @@ -55,7 +56,6 @@ treespec_dumps, treespec_loads, ) -from torch.fx.experimental.symbolic_shapes import ShapeEnv try: from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -91,11 +91,13 @@ tags=torch.Tag.pt2_compliant_tag, ) + @torch.library.impl("testlib::returns_tensor_symint", "cpu") @torch.library.impl_abstract("testlib::returns_tensor_symint") def returns_tensor_symint_impl(x): return x, x.shape[0] + @torch.library.impl("testlib::foo", "cpu") @torch._dynamo.disable def foo_impl(x, z): @@ -103,22 +105,27 @@ def foo_impl(x, z): z.add_(5) return x, z, x + z + @torch.library.impl_abstract("testlib::foo") def foo_abstract(x, z): return x, z, x + z + @torch.library.impl("testlib::foo_mutated", "CompositeImplicitAutograd") def foo_mutated(x): a, b, c = torch.ops.testlib.foo(x, x.cos()) return a, a.cos() + @torch.library.impl("testlib::foo_functional", "CompositeImplicitAutograd") def foo_functional(x): a, b, c = torch.ops.testlib.foo(x.cos(), x.cos()) return a.cos() + NON_STRICT_SUFFIX = "_non_strict" + def is_non_strict_test(test_name): return test_name.endswith(NON_STRICT_SUFFIX) @@ -188,6 +195,7 @@ def forward(self, x: torch.Tensor): # Being able to export means shape is preserved as static export(branch_on_shape, inp) + @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestExport(TestCase): @@ -250,8 +258,7 @@ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d( - in_channels=3, out_channels=32 - , kernel_size=3, padding=1 + in_channels=3, out_channels=32, kernel_size=3, padding=1 ) self.relu = torch.nn.ReLU() self.maxpool = torch.nn.MaxPool2d(kernel_size=3) @@ -347,7 +354,10 @@ def forward(self, x, ys, zs, c): torch.ones(4), ) with self.assertRaisesRegex( - RuntimeError, escape("Expected input at *args[1][0].shape[0] to be equal to 6, but got 5") + RuntimeError, + escape( + "Expected input at *args[1][0].shape[0] to be equal to 6, but got 5" + ), ): ep_ns.module()(*bad_runtime_inp1) @@ -358,7 +368,8 @@ def forward(self, x, ys, zs, c): torch.ones(6), ) with self.assertRaisesRegex( - RuntimeError, escape("Expected input at *args[3].shape[0] to be equal to 4, but got 6") + RuntimeError, + escape("Expected input at *args[3].shape[0] to be equal to 4, but got 6"), ): ep_ns.module()(*bad_runtime_inp2) @@ -431,7 +442,7 @@ def forward(self, x): x = x + x return x - ep1 = export(M1(), (torch.randn(3, 3), )) + ep1 = export(M1(), (torch.randn(3, 3),)) expected_result = [ ("linear_1", "builtin_function_or_method.linear"), ("linear_1", "builtin_function_or_method.linear"), @@ -490,7 +501,6 @@ def bar(self, x): def forward(self, x): return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) - example_inputs = (torch.randn(1, 3, 3, 3),) m = CondBranchClassMethod() m.eval() @@ -503,7 +513,10 @@ def forward(self, x): torch_fn = node.meta.get("torch_fn") print(torch_fn) actual_torch_fns.append(torch_fn) - exp_torch_fns = [("cos_1", "method_descriptor.cos"), ("sin_1", "method_descriptor.sin")] + exp_torch_fns = [ + ("cos_1", "method_descriptor.cos"), + ("sin_1", "method_descriptor.sin"), + ] self.assertEqual(actual_torch_fns, exp_torch_fns) def test_derived_dim_basic(self): @@ -708,7 +721,9 @@ def forward(self, x, y, z): ): ep.module()(torch.randn(6), torch.randn(7), torch.randn(5)) - self.assertEqual(ep.module()(torch.randn(6), torch.randn(7), torch.randn(8)).size()[0], 6) + self.assertEqual( + ep.module()(torch.randn(6), torch.randn(7), torch.randn(8)).size()[0], 6 + ) def test_derived_dim_out_of_order_repeat_derived(self): dimy = torch.export.Dim("dimy", min=5, max=7) @@ -826,7 +841,9 @@ def forward(self, x, y, z): ): ep.module()(torch.randn(6), torch.randn(7), torch.randn(5)) - self.assertEqual(ep.module()(torch.randn(6), torch.randn(7), torch.randn(8)).size()[0], 6) + self.assertEqual( + ep.module()(torch.randn(6), torch.randn(7), torch.randn(8)).size()[0], 6 + ) def test_derived_dim_out_of_order_simplified_repeat_non_derived(self): class Foo(torch.nn.Module): @@ -871,7 +888,8 @@ def test_static_dim_constraints(self): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.l = torch.nn.Linear(6,4) + self.l = torch.nn.Linear(6, 4) + def forward(self, x, y, z): x0 = self.l(x) + y[1:] return x0, z * 2.0 @@ -887,7 +905,7 @@ def forward(self, x, y, z): ({0: dx, 1: 6}, {0: dy, 1: 4}, {0: dz, 1: 3}), ((dx, None), (dy, 4), (dz, 3)), ((None, 6), (5, None), (None, None)), - ((4, 6), {0: None, 1: 4}, {0: None, 1: 3}) + ((4, 6), {0: None, 1: 4}, {0: None, 1: 3}), ]: ep = export(foo, inputs, dynamic_shapes=dynamic_shapes) self.assertEqual(foo(*inputs), ep.module()(*inputs)) @@ -902,17 +920,17 @@ def forward(self, x, y, z): with self.assertRaisesRegex( ( torch.fx.experimental.symbolic_shapes.ConstraintViolationError, - torch._dynamo.exc.UserError + torch._dynamo.exc.UserError, ), - "Static shape constraint of 5 does not match input size of 4, for .*" + "Static shape constraint of 5 does not match input size of 4, for .*", ): _ = export(foo, inputs, dynamic_shapes=((5, None), None, None)) with self.assertRaisesRegex( ( torch.fx.experimental.symbolic_shapes.ConstraintViolationError, - torch._dynamo.exc.UserError + torch._dynamo.exc.UserError, ), - "Static shape constraint of 9 does not match input size of 6, for .*" + "Static shape constraint of 9 does not match input size of 6, for .*", ): _ = export(foo, inputs, dynamic_shapes=((dx, 9), (dy, 4), (3, 3))) @@ -923,16 +941,11 @@ def forward(self, x): return x * 2 dx = Dim("dx", min=1, max=2) - ep = export( - Foo(), - (torch.randn(2, 2), ), - dynamic_shapes=({0: dx, 1: None}, ) - ) + ep = export(Foo(), (torch.randn(2, 2),), dynamic_shapes=({0: dx, 1: None},)) ep.module()(torch.randn(1, 2)) ep.module()(torch.randn(2, 2)) with self.assertRaisesRegex( - RuntimeError, - "Expected input at .* to be <= 2, but got 3" + RuntimeError, "Expected input at .* to be <= 2, but got 3" ): ep.module()(torch.randn(3, 2)) vr = list(ep.range_constraints.values())[0] @@ -949,7 +962,7 @@ def forward(self, x, y): ep = export( Bar(), (torch.randn(2, 2), torch.randn(3, 2)), - dynamic_shapes=({0: dx, 1: None}, {0: dx+1, 1: None}) + dynamic_shapes=({0: dx, 1: None}, {0: dx + 1, 1: None}), ) ep.module()(torch.randn(1, 2), torch.randn(2, 2)) range_lower_bounds = sorted(vr.lower for vr in ep.range_constraints.values()) @@ -1163,9 +1176,7 @@ def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs): def test_unbacked_slice(self): class M(torch.nn.Module): - def forward( - self, scores, score_thr, topk: torch.Tensor, results=None - ): + def forward(self, scores, score_thr, topk: torch.Tensor, results=None): valid_mask = scores > score_thr scores = scores[valid_mask] valid_idxs = torch.nonzero(valid_mask).to(scores.device) @@ -1433,6 +1444,7 @@ def forward(self, inputs): class Foo(torch.nn.Module): def forward(self, kjt) -> torch.Tensor: return kjt.values() + 0, kjt.offsets() + 0 + foo = Foo() kjt = KeyedJaggedTensor( values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), @@ -1450,13 +1462,14 @@ def forward(self, kjt) -> torch.Tensor: ) self.assertEqual( [out.shape for out in efoo.module()(*inputs)], - [out.shape for out in foo(*inputs)] + [out.shape for out in foo(*inputs)], ) # pass dynamic shapes of inputs [distinct, error] class Foo(torch.nn.Module): def forward(self, x, y): return torch.matmul(x, y) + foo = Foo() inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4)) batch, M, K1, K2, N = dims("batch", "M", "K1", "K2", "N") @@ -1593,9 +1606,7 @@ class MyDataClass: flat, spec = tree_flatten(dt) self.assertEqual( spec, - TreeSpec( - MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()] - ), + TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]), ) self.assertEqual(flat, [3, 4]) @@ -1779,8 +1790,16 @@ def forward(self, x): inp_test = ((torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),) - self.assertTrue(torch.allclose(ep.module()(*inp_test)[0], ep_rexported.module()(*inp_test)[0])) - self.assertTrue(torch.allclose(ep.module()(*inp_test)[1], ep_rexported.module()(*inp_test)[1])) + self.assertTrue( + torch.allclose( + ep.module()(*inp_test)[0], ep_rexported.module()(*inp_test)[0] + ) + ) + self.assertTrue( + torch.allclose( + ep.module()(*inp_test)[1], ep_rexported.module()(*inp_test)[1] + ) + ) def test_module_with_dict_container_inp_out(self): class MyLinear(torch.nn.Module): @@ -1830,10 +1849,14 @@ def forward(self, x): ) self.assertTrue( - torch.allclose(ep.module()(*inp_test)["a"], ep_rexported.module()(*inp_test)["a"]) + torch.allclose( + ep.module()(*inp_test)["a"], ep_rexported.module()(*inp_test)["a"] + ) ) self.assertTrue( - torch.allclose(ep.module()(*inp_test)["b"], ep_rexported.module()(*inp_test)["b"]) + torch.allclose( + ep.module()(*inp_test)["b"], ep_rexported.module()(*inp_test)["b"] + ) ) def test_args_type_checked(self): @@ -1929,7 +1952,11 @@ def forward(self, x, y): return y + n fn = Module() - error = ValueError if is_non_strict_test(self._testMethodName) else torch._dynamo.exc.TorchRuntimeError + error = ( + ValueError + if is_non_strict_test(self._testMethodName) + else torch._dynamo.exc.TorchRuntimeError + ) with self.assertRaisesRegex( error, "Constraining SymFloat or Symbool is nyi", @@ -2082,7 +2109,7 @@ def forward(self, x, y): @testing.expectedFailureNonStrict # non-strict does not add deferred runtime assertions @testing.expectedFailureSerDerPreDispatch # .item call becomes aten.item in predispatch IR - @testing.expectedFailurePreDispatchRunDecomp # assert name is still referring to item + @testing.expectedFailurePreDispatchRunDecomp # assert name is still referring to item def test_automatic_constrain_size(self): class M(torch.nn.Module): def forward(self, x, y): @@ -2091,7 +2118,10 @@ def forward(self, x, y): ep = export(M(), (torch.tensor(1), torch.ones(4, 5))) - with self.assertRaisesRegex(RuntimeError, r"_local_scalar_dense is outside of inline constraint \[0, 9223372036854775806\]"): + with self.assertRaisesRegex( + RuntimeError, + r"_local_scalar_dense is outside of inline constraint \[0, 9223372036854775806\]", + ): _ = ep.module()(torch.tensor(-1), torch.randn(4, 5)) self.assertTrue( @@ -2139,7 +2169,7 @@ def forward(self, a, b, alpha: int): @testing.expectedFailureNonStrict @testing.expectedFailureSerDerPreDispatch # .item() becomes aten.item in predispatch IR - @testing.expectedFailurePreDispatchRunDecomp # Assert message is still using the old node name, so it shoudl fail + @testing.expectedFailurePreDispatchRunDecomp # Assert message is still using the old node name, so it shoudl fail def test_export_with_inline_constraints(self): class Module(torch.nn.Module): def forward(self, x): @@ -2272,10 +2302,13 @@ def forward(self, x, y): foo, (tensor_inp, 5), dynamic_shapes=dynamic_shapes ) self.assertTrue( - torch.allclose(exported.module()(torch.ones(8, 5), 5), foo(torch.ones(8, 5), 5)) + torch.allclose( + exported.module()(torch.ones(8, 5), 5), foo(torch.ones(8, 5), 5) + ) ) with self.assertRaisesRegex( - RuntimeError, escape("Expected input at *args[1] to be equal to 5, but got 6") + RuntimeError, + escape("Expected input at *args[1] to be equal to 5, but got 6"), ): _ = exported.module()(torch.ones(8, 5), 6) @@ -2283,7 +2316,8 @@ def forward(self, x, y): foo, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes ) with self.assertRaisesRegex( - RuntimeError, escape("Expected input at *args[1] to be equal to 5.0, but got 6.0") + RuntimeError, + escape("Expected input at *args[1] to be equal to 5.0, but got 6.0"), ): _ = exported.module()(torch.ones(7, 5), 6.0) @@ -2394,7 +2428,9 @@ def forward(self, x): exported.module(), (inp,), dynamic_shapes=({0: dim0_x},) ) self.assertTrue( - torch.allclose(Foo()(torch.ones(7, 5)), reexported.module()(torch.ones(7, 5))) + torch.allclose( + Foo()(torch.ones(7, 5)), reexported.module()(torch.ones(7, 5)) + ) ) # can't retrace with invalid inputs with respect to the original ExportedProgram @@ -2403,7 +2439,8 @@ def forward(self, x): Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}} ) with self.assertRaisesRegex( - RuntimeError, escape("Expected input at *args[0].shape[0] to be >= 3, but got 2") + RuntimeError, + escape("Expected input at *args[0].shape[0] to be >= 3, but got 2"), ): torch.export.export(exported_v2.module(), (torch.randn(2, 2),)) @@ -2435,13 +2472,17 @@ def false_fn(x): Foo(), (inp,), ) - self.assertTrue(torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))) + self.assertTrue( + torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4))) + ) def test_cond_buffers(self): class M(torch.nn.Module): def __init__(self): super().__init__() - self.register_parameter("param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False)) + self.register_parameter( + "param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False) + ) self.register_buffer("buffer", torch.ones(2, 3) + 1) def true_fn(self, x): @@ -2463,8 +2504,7 @@ def forward(self, x): if not isinstance(gm, torch.fx.GraphModule): continue self.assertEqual( - len([node for node in gm.graph.nodes if node.op == "placeholder"]), - 1 + len([node for node in gm.graph.nodes if node.op == "placeholder"]), 1 ) # map_fn references module outside the module hierarchy @@ -2473,10 +2513,13 @@ def test_map_buffers(self): class M1(torch.nn.Module): def __init__(self): super().__init__() - self.register_parameter("param", torch.nn.Parameter(torch.tensor(5), requires_grad=False)) + self.register_parameter( + "param", torch.nn.Parameter(torch.tensor(5), requires_grad=False) + ) self.register_buffer("buffer", torch.tensor(6) + 1) m1 = M1() + def map_fn(x, y): z = x + y + m1.param + m1.buffer z.add_(4) @@ -2496,8 +2539,7 @@ def forward(self, xs, y): if not isinstance(gm, torch.fx.GraphModule): continue self.assertEqual( - len([node for node in gm.graph.nodes if node.op == "placeholder"]), - 2 + len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2 ) @testing.expectedFailureSerDer # We don't preserve metadata on graph module @@ -2519,14 +2561,15 @@ def forward(self, x): stateful_module = exported.module() self.assertTrue(len(stateful_module.meta["input_shape_constraints"]), 1) - re_exported = export( - stateful_module, (inp,), dynamic_shapes=({0: dim0_x},) - ) + re_exported = export(stateful_module, (inp,), dynamic_shapes=({0: dim0_x},)) self.assertTrue( len(re_exported.graph_module.meta["input_shape_constraints"]) == 1 ) self.assertTrue( - torch.allclose(exported.module()(torch.ones(7, 5)), re_exported.module()(torch.ones(7, 5))) + torch.allclose( + exported.module()(torch.ones(7, 5)), + re_exported.module()(torch.ones(7, 5)), + ) ) re_exported_v2 = export(exported.module(), (inp,)) @@ -2534,7 +2577,10 @@ def forward(self, x): len(re_exported_v2.graph_module.meta["input_shape_constraints"]) == 0 ) self.assertTrue( - torch.allclose(exported.module()(torch.ones(7, 5)), re_exported_v2.module()(torch.ones(7, 5))) + torch.allclose( + exported.module()(torch.ones(7, 5)), + re_exported_v2.module()(torch.ones(7, 5)), + ) ) def test_constrain_as_size_error(self): @@ -2551,7 +2597,9 @@ def forward(self, x): error_msg = r"Could not guard on data-dependent expression" else: error = torch._dynamo.exc.UserError - error_msg = r"Tried to use data-dependent value in the subsequent computation" + error_msg = ( + r"Tried to use data-dependent value in the subsequent computation" + ) with self.assertRaisesRegex(error, error_msg): _ = export(f, (torch.tensor(6),)) @@ -2632,7 +2680,7 @@ class Input: torch._export.utils.register_dataclass_as_pytree_node( Input, - serialized_type_name="test_preserve_shape_dynamism_for_unused_inputs.Input" + serialized_type_name="test_preserve_shape_dynamism_for_unused_inputs.Input", ) class Module(torch.nn.Module): @@ -2684,9 +2732,7 @@ def forward(self, x): return x + x exported_program = export(MyModule(), (torch.rand(2, 3),), {}) - with self.assertRaisesRegex( - ValueError, "Trying to flatten user inputs" - ): + with self.assertRaisesRegex(ValueError, "Trying to flatten user inputs"): exported_program.module()(torch.rand(2, 3), torch.rand(2, 3)) @testing.expectedFailureSerDerPreDispatch # linear shouldn't decompose @@ -2769,9 +2815,9 @@ def forward(self, x): f = Foo() ep = export(f, (torch.tensor([3]),)) - FileCheck().check_count("torch.ops.aten._assert_async.msg", 2, exactly=True).run( - ep.graph_module.code - ) + FileCheck().check_count( + "torch.ops.aten._assert_async.msg", 2, exactly=True + ).run(ep.graph_module.code) def test_non_arg_name_dynamic_shapes_api(self): class Foo(torch.nn.Module): @@ -2845,7 +2891,9 @@ def forward(self, a, b, kw1, kw2): test_inp = (torch.randn(4, 4), torch.randn(7, 4)) test_kwargs = {"kw2": torch.ones(4, 4), "kw1": torch.zeros(9, 4)} # This should work even if the kwarg order are flipped. - self.assertEqual(ep.module()(*test_inp, **test_kwargs), foo(*test_inp, **test_kwargs)) + self.assertEqual( + ep.module()(*test_inp, **test_kwargs), foo(*test_inp, **test_kwargs) + ) def test_non_arg_name_dynamic_shapes_api_with_container_type(self): class Foo(torch.nn.Module): @@ -2907,13 +2955,17 @@ def forward(self, x): Foo(), (inp,), dynamic_shapes=({0: torch.export.Dim("dim", min=3)},), - pre_dispatch=True + pre_dispatch=True, ).module() - with self.assertRaisesRegex(RuntimeError, escape("Expected input at *args[0].shape[0]")): + with self.assertRaisesRegex( + RuntimeError, escape("Expected input at *args[0].shape[0]") + ): gm(torch.randn(2, 2)) - with self.assertRaisesRegex(RuntimeError, escape("Expected input at *args[0].shape[0]")): + with self.assertRaisesRegex( + RuntimeError, escape("Expected input at *args[0].shape[0]") + ): torch.export.export(gm, (torch.randn(2, 2),)) ep = torch.export.export( @@ -3048,7 +3100,7 @@ def forward( training, momentum, eps, - **kwargs + **kwargs, ): return self.op( input, @@ -3059,7 +3111,7 @@ def forward( training, momentum, eps, - **kwargs + **kwargs, ) input = torch.randn(5, 5, 5) @@ -3091,7 +3143,9 @@ def forward( ) ep.run_decompositions(decomp_table=torch._decomp.decomposition_table) self.assertEqual( - ep.module()(input, weight, bias, running_mean, running_var, training, momentum, eps), + ep.module()( + input, weight, bias, running_mean, running_var, training, momentum, eps + ), output, ) @@ -3228,10 +3282,10 @@ def forward(self, x, y): def test_export_input_mutation_bug(self): class M(torch.nn.Module): def forward(self, x): - x[:, :2, :] = x[:,:2,:] + 1 + x[:, :2, :] = x[:, :2, :] + 1 return x - inputs = (torch.ones(4,4,4),) + inputs = (torch.ones(4, 4, 4),) ep = torch.export.export(M(), inputs) m = ep.module() @@ -3244,8 +3298,10 @@ def forward(self, x): ep = torch.export.export(m, inputs) - inputs = (torch.randn(4,4,4),) - self.assertEqual(ep.module()(*copy.deepcopy(inputs)), M()(*copy.deepcopy(inputs))) + inputs = (torch.randn(4, 4, 4),) + self.assertEqual( + ep.module()(*copy.deepcopy(inputs)), M()(*copy.deepcopy(inputs)) + ) def test__scaled_dot_product_flash_attention(self): class Module(torch.nn.Module): @@ -3518,6 +3574,7 @@ class Foo(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(4, 4) + def forward(self, x): x = self.linear(x) x *= 2.0 @@ -3525,22 +3582,22 @@ def forward(self, x): ep = export( Foo(), - (torch.randn(4, 4), ), + (torch.randn(4, 4),), ) # check correct lines are in stack trace - trace_mul = [ - node for node in ep.graph.nodes - if node.name == "mul" - ][0].meta.get("stack_trace", "") + trace_mul = [node for node in ep.graph.nodes if node.name == "mul"][0].meta.get( + "stack_trace", "" + ) self.assertTrue( re.search(r"test_export.py.*in forward\n.*x \*= 2.0", trace_mul) ) trace_addmm = [ - node for node in ep.graph.nodes - if node.name in ["addmm", "linear"] + node for node in ep.graph.nodes if node.name in ["addmm", "linear"] ][0].meta.get("stack_trace", "") self.assertTrue( - re.search(r"test_export.py.*in forward\n.*x = self.linear\(x\)", trace_addmm) + re.search( + r"test_export.py.*in forward\n.*x = self.linear\(x\)", trace_addmm + ) ) def test_sym_stack_trace(self): @@ -3554,23 +3611,19 @@ def forward(self, x, y): ep = export( Foo(), (torch.randn(4, 4), torch.tensor(5)), - dynamic_shapes={ - "x": (Dim("dx0"), Dim("dx1")), - "y": None - } + dynamic_shapes={"x": (Dim("dx0"), Dim("dx1")), "y": None}, ) # stack trace for sym call constrain_range trace_constrain_range = [ # different names for serdes/pre-dispatch - node for node in ep.graph.nodes - if node.name in [ - "sym_constrain_range_for_size", - "sym_constrain_range_for_size_default" - ] + node + for node in ep.graph.nodes + if node.name + in ["sym_constrain_range_for_size", "sym_constrain_range_for_size_default"] ][0].meta.get("stack_trace", None) self.assertTrue( re.search( r"torch/__init__.py.*in _constrain_as_size\n.*torch.sym_constrain_range_for_size", - trace_constrain_range + trace_constrain_range, ) ) @@ -3681,7 +3734,7 @@ def forward(self, x, y): {}, dynamic_shapes=None, pre_dispatch=True, - strict=False + strict=False, ) class Model(torch.nn.Module): @@ -3710,30 +3763,38 @@ def true_fn(x, y): {}, dynamic_shapes=None, pre_dispatch=True, - strict=False + strict=False, ) - self.assertExpectedInline(str(exported_program.graph_module.code.strip()), """\ + self.assertExpectedInline( + str(exported_program.graph_module.code.strip()), + """\ def forward(self, b_pred, b_t, x, y): true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None getitem = conditional[0]; conditional = None - return (getitem,)""") # noqa: B950 + return (getitem,)""", + ) # noqa: B950 - self.assertExpectedInline(str(exported_program.graph_module.true_graph_0.code.strip()), """\ + self.assertExpectedInline( + str(exported_program.graph_module.true_graph_0.code.strip()), + """\ def forward(self, arg1_1, arg0_1, arg2_1): submod_3 = self.submod_1 add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, arg1_1, arg0_1, arg2_1); submod_3 = arg1_1 = arg0_1 = arg2_1 = None - return (add_1,)""") + return (add_1,)""", + ) - self.assertExpectedInline(str(exported_program.graph_module.true_graph_0.submod_1.code.strip()), """\ + self.assertExpectedInline( + str(exported_program.graph_module.true_graph_0.submod_1.code.strip()), + """\ def forward(self, arg1_1, arg0_1, arg2_1): sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None add = torch.ops.aten.add.Tensor(sub, arg0_1); sub = arg0_1 = None add_1 = torch.ops.aten.add.Tensor(add, arg2_1); add = arg2_1 = None - return add_1""") - + return add_1""", + ) def test_non_persistent_buffer(self): class MyModule(torch.nn.Module): @@ -3795,7 +3856,9 @@ def __init__(self): def forward(self, x): return self.foo + x + self.bar + self.baz - fake_mode = torch._subclasses.FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[])) + fake_mode = torch._subclasses.FakeTensorMode( + shape_env=ShapeEnv(tracked_fakes=[]) + ) with fake_mode: m = MyModule() inp = torch.randn(4, 4) @@ -3811,7 +3874,9 @@ def __init__(self): def forward(self, x): return self.foo + x - fake_mode = torch._subclasses.FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[])) + fake_mode = torch._subclasses.FakeTensorMode( + shape_env=ShapeEnv(tracked_fakes=[]) + ) m = MyModule() with fake_mode: inp = torch.randn(4, 4) @@ -3828,7 +3893,9 @@ def __init__(self): def forward(self, x): return self.foo + x - fake_mode = torch._subclasses.FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[])) + fake_mode = torch._subclasses.FakeTensorMode( + shape_env=ShapeEnv(tracked_fakes=[]) + ) with fake_mode: m = MyModule() inp = torch.randn(4, 4) @@ -3912,7 +3979,9 @@ def forward(self, x, z): self.assertTrue(torch.allclose(legit_eager, legit_export)) ep = ep.run_decompositions() - x_new_export, z_new_export, legit_export = ep.module()(*inps_for_export_with_decomp) + x_new_export, z_new_export, legit_export = ep.module()( + *inps_for_export_with_decomp + ) self.assertTrue(torch.allclose(x_new_eager, x_new_export)) self.assertTrue(torch.allclose(z_new_eager, z_new_export)) self.assertTrue(torch.allclose(legit_eager, legit_export)) @@ -3928,23 +3997,28 @@ def forward(self, x): inps = (torch.ones(5),) ep = torch.export.export(M(), inps) - self.assertExpectedInline(str(ep.graph_module.code.strip()), """\ + self.assertExpectedInline( + str(ep.graph_module.code.strip()), + """\ def forward(self, x): cos = torch.ops.aten.cos.default(x) auto_functionalized = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos); x = cos = None getitem_3 = auto_functionalized[3]; auto_functionalized = None cos_1 = torch.ops.aten.cos.default(getitem_3) - return (getitem_3, getitem_3, cos_1)""") + return (getitem_3, getitem_3, cos_1)""", + ) ep = torch.export._trace._export(M(), inps, pre_dispatch=True) - self.assertExpectedInline(str(ep.graph_module.code.strip()), """\ + self.assertExpectedInline( + str(ep.graph_module.code.strip()), + """\ def forward(self, x): cos = torch.ops.aten.cos.default(x) auto_functionalized = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos); x = cos = None getitem_3 = auto_functionalized[3]; auto_functionalized = None cos_1 = torch.ops.aten.cos.default(getitem_3) - return (getitem_3, getitem_3, cos_1)""") - + return (getitem_3, getitem_3, cos_1)""", + ) def test_custom_op_auto_warn_pre_dispatch(self): class M(torch.nn.Module): @@ -3957,20 +4031,26 @@ def forward(self, x): inps = (torch.ones(5),) ep = torch.export.export(M(), inps) - self.assertExpectedInline(str(ep.graph_module.code.strip()), """\ + self.assertExpectedInline( + str(ep.graph_module.code.strip()), + """\ def forward(self, x): cos = torch.ops.aten.cos.default(x) cos_1 = torch.ops.aten.cos.default(x); x = None auto_functionalized = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.testlib.foo.default, x = cos, z = cos_1); cos = cos_1 = None getitem_3 = auto_functionalized[3]; auto_functionalized = None cos_2 = torch.ops.aten.cos.default(getitem_3); getitem_3 = None - return (cos_2,)""") + return (cos_2,)""", + ) ep = torch.export._trace._export(M(), inps, pre_dispatch=True) - self.assertExpectedInline(str(ep.graph_module.code.strip()), """\ + self.assertExpectedInline( + str(ep.graph_module.code.strip()), + """\ def forward(self, x): foo_functional = torch.ops.testlib.foo_functional.default(x); x = None - return (foo_functional,)""") + return (foo_functional,)""", + ) # original input names aren't retraceable: # compilation will succeed, but names won't match forward() signature. @@ -3979,19 +4059,15 @@ def test_placeholder_naming_collisions(self): # test collisions between nested user inputs class Foo(torch.nn.Module): def forward(self, x, x_foo, x_foo_0): - return x['foo'][0] + x_foo[0] + x_foo_0 + return x["foo"][0] + x_foo[0] + x_foo_0 inputs = ( - {'foo': [torch.randn(4, 4)]}, - (torch.randn(4, 4), ), + {"foo": [torch.randn(4, 4)]}, + (torch.randn(4, 4),), torch.randn(4, 4), ) ep = export(Foo(), inputs) - expected_names = [ - "x_foo_0", - "x_foo_0_1", - "x_foo_0_2" - ] + expected_names = ["x_foo_0", "x_foo_0_1", "x_foo_0_2"] real_names = [spec.arg.name for spec in ep.graph_signature.input_specs] self.assertEqual(expected_names, real_names) @@ -4000,12 +4076,13 @@ class Foo(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.randn(4)) - self.register_buffer('alpha', torch.randn(4), persistent=True) - self.register_buffer('beta', torch.randn(4), persistent=False) + self.register_buffer("alpha", torch.randn(4), persistent=True) + self.register_buffer("beta", torch.randn(4), persistent=False) self.gamma = torch.randn(4) + def forward(self, p, b_alpha, b, c_gamma): - p = p['param'] + self.param - b = self.alpha + self.beta + b_alpha + b['beta'] + p = p["param"] + self.param + b = self.alpha + self.beta + b_alpha + b["beta"] c = self.gamma + c_gamma return p, b, c @@ -4017,16 +4094,18 @@ def forward(self, p, b_alpha, b, c_gamma): ) ep = export(Foo(), inputs) expected_names = [ # user inputs should be prioritized, unprefixed - ('p_param_1', InputKind.PARAMETER), - ('b_alpha_1', InputKind.BUFFER), - ('b_beta_1', InputKind.BUFFER), - ('c_gamma_1', InputKind.CONSTANT_TENSOR), - ('p_param', InputKind.USER_INPUT), - ('b_alpha', InputKind.USER_INPUT), - ('b_beta', InputKind.USER_INPUT), - ('c_gamma', InputKind.USER_INPUT) + ("p_param_1", InputKind.PARAMETER), + ("b_alpha_1", InputKind.BUFFER), + ("b_beta_1", InputKind.BUFFER), + ("c_gamma_1", InputKind.CONSTANT_TENSOR), + ("p_param", InputKind.USER_INPUT), + ("b_alpha", InputKind.USER_INPUT), + ("b_beta", InputKind.USER_INPUT), + ("c_gamma", InputKind.USER_INPUT), + ] + real_names = [ + (spec.arg.name, spec.kind) for spec in ep.graph_signature.input_specs ] - real_names = [(spec.arg.name, spec.kind) for spec in ep.graph_signature.input_specs] self.assertEqual(expected_names, real_names) # test collisions between user inputs & call_function nodes @@ -4047,6 +4126,7 @@ def forward(self, mul, add, add_1): real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes] self.assertEqual(expected_names_and_ops, real_names_and_ops) + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestOneOffModelExportResult(TestCase): def test_scaled_dot_product_attention_cpu(self): @@ -4063,6 +4143,7 @@ def test_scaled_dot_product_attention_cpu(self): torch/_decomp/decompositions.py along with the kernel changes, so all of the downstream backends are not being affected. """ + class ScaledDotProductAttention(torch.nn.Module): def __init__(self): super().__init__() @@ -4072,23 +4153,24 @@ def forward(self, q, k, v): q, k, v, None, dropout_p=0.0, is_causal=True ) return attn_output + q = torch.randn(1, 1, 8, 8, device="cpu") k = torch.randn(1, 1, 8, 8, device="cpu") v = torch.randn(1, 1, 8, 8, device="cpu") from torch.nn.attention import SDPBackend + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]): ep = torch.export.export(ScaledDotProductAttention(), (q, k, v)) print(ep.graph) ep.run_decompositions() print(ep.graph) - -# self.assertExpectedInline(ep.graph_module.code.strip(), """\ -# def forward(self, arg0_1, arg1_1, arg2_1): -# _scaled_dot_product_flash_attention_for_cpu = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default(arg0_1, arg1_1, arg2_1, 0.0, True); arg0_1 = arg1_1 = arg2_1 = None -# getitem = _scaled_dot_product_flash_attention_for_cpu[0]; _scaled_dot_product_flash_attention_for_cpu = None -# return (getitem,)""") + # self.assertExpectedInline(ep.graph_module.code.strip(), """\ + # def forward(self, arg0_1, arg1_1, arg2_1): + # _scaled_dot_product_flash_attention_for_cpu = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default(arg0_1, arg1_1, arg2_1, 0.0, True); arg0_1 = arg1_1 = arg2_1 = None + # getitem = _scaled_dot_product_flash_attention_for_cpu[0]; _scaled_dot_product_flash_attention_for_cpu = None + # return (getitem,)""") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, @@ -4102,6 +4184,7 @@ def test_scaled_dot_product_attention_cuda(self): backend relies on this export result so if this test fails, feel free to change it to the latest export() result. """ + class ScaledDotProductAttention(torch.nn.Module): def __init__(self): super().__init__() @@ -4111,16 +4194,20 @@ def forward(self, q, k, v): q, k, v, None, dropout_p=0.0, is_causal=True ) return attn_output - q = torch.randn(1, 16, 16, 64, dtype = torch.bfloat16, device="cuda") - k = torch.randn(1, 16, 16, 64, dtype = torch.bfloat16, device="cuda") - v = torch.randn(1, 16, 16, 64, dtype = torch.bfloat16, device="cuda") + + q = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda") + k = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda") + v = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda") ep = torch.export.export(ScaledDotProductAttention(), (q, k, v)) - self.assertExpectedInline(ep.graph_module.code.strip(), """\ + self.assertExpectedInline( + ep.graph_module.code.strip(), + """\ def forward(self, q, k, v): _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention.default(q, k, v, 0.0, True, scale = 0.125); q = k = v = None getitem = _scaled_dot_product_flash_attention[0]; _scaled_dot_product_flash_attention = None - return (getitem,)""") + return (getitem,)""", + ) def test_int_list_output(self): class M(torch.nn.Module): @@ -4149,7 +4236,10 @@ def forward(self, x, y): self.assertEqual(res[0], torch.tensor(20)) self.assertEqual(res[1], 5) - with self.assertRaisesRegex(RuntimeError, escape("Expected input at *args[1] to be equal to 5, but got 20")): + with self.assertRaisesRegex( + RuntimeError, + escape("Expected input at *args[1] to be equal to 5, but got 20"), + ): res = ep.module()(torch.tensor(4), 20) class F(torch.nn.Module): @@ -4192,7 +4282,6 @@ def forward( # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L732 return scaled_dot_product_attention(query, key, value) - cache = torch.randn(1, 128, 16, 128, dtype=torch.float16) query = torch.randn(1, 1, 16, 128, dtype=torch.float16) start_pos = torch.tensor([0]) @@ -4253,6 +4342,7 @@ def forward(self, x): def test_logging_logger(self): logger = logging.getLogger(__name__) + class M(torch.nn.Module): def forward(self, x): logger.log("start") @@ -4367,7 +4457,6 @@ def __init__(self): def forward(self, x): return x + self.param - class Foo(torch.nn.Module): def __init__(self): super().__init__() @@ -4393,7 +4482,9 @@ def forward(self, x): if nn_module_stack_1 is None: self.assertTrue(nn_module_stack_2 is None) else: - for v1, v2 in zip(nn_module_stack_1.values(), nn_module_stack_2.values()): + for v1, v2 in zip( + nn_module_stack_1.values(), nn_module_stack_2.values() + ): self.assertEqual(v1, v2) @@ -4405,9 +4496,9 @@ def setUp(self): elif IS_SANDCASTLE or IS_MACOS: raise unittest.SkipTest("non-portable load_library call used in test") elif IS_WINDOWS: - lib_file_path = find_library_location('torchbind_test.dll') + lib_file_path = find_library_location("torchbind_test.dll") else: - lib_file_path = find_library_location('libtorchbind_test.so') + lib_file_path = find_library_location("libtorchbind_test.so") torch.ops.load_library(str(lib_file_path)) def test_lift_custom_obj(self): @@ -4439,12 +4530,16 @@ def forward(self, x): custom_node.meta["val"] = torch.ones(4, 4) # Copy over an nn_module_stack as they are required. custom_node.meta["nn_module_stack"] = node.meta["nn_module_stack"] - custom_node.meta["torch_fn"] = ("custom_op", "torch.ops._TorchScriptTesting.take_an_instance.default") + custom_node.meta["torch_fn"] = ( + "custom_op", + "torch.ops._TorchScriptTesting.take_an_instance.default", + ) arg0, _ = node.args node.args = (arg0, custom_node) from torch._export.passes.lift_constants_pass import lift_constants_pass - from torch._export.serde.serialize import serialize, deserialize + from torch._export.serde.serialize import deserialize, serialize + constants = lift_constants_pass(ep.graph_module, ep.graph_signature, {}) for k, v in constants.items(): assert k not in ep.constants @@ -4454,8 +4549,9 @@ def forward(self, x): for node in deserialized_ep.graph.nodes: if ( - node.op == "call_function" and - node.target == torch.ops._TorchScriptTesting.take_an_instance.default + node.op == "call_function" + and node.target + == torch.ops._TorchScriptTesting.take_an_instance.default ): arg = node.args[0] self.assertTrue(arg.op == "placeholder") @@ -4467,5 +4563,6 @@ def forward(self, x): ep = torch.export.export(M(), (torch.ones(3),), strict=False) -if __name__ == '__main__': + +if __name__ == "__main__": run_tests() diff --git a/test/export/test_pass_infra.py b/test/export/test_pass_infra.py index 0b1e2c4cc04c8..fecae442256df 100644 --- a/test/export/test_pass_infra.py +++ b/test/export/test_pass_infra.py @@ -8,7 +8,7 @@ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse from torch.export import export from torch.fx.passes.infra.pass_base import PassResult -from torch.testing._internal.common_utils import run_tests, TestCase, IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") @@ -64,7 +64,9 @@ def false_fn(x, y): x = torch.tensor([2]) y = torch.tensor([5]) mod = M() - _ = export(mod, (torch.tensor(True), x, y))._transform_do_not_use(_ExportPassBaseDeprecatedDoNotUse()) + _ = export(mod, (torch.tensor(True), x, y))._transform_do_not_use( + _ExportPassBaseDeprecatedDoNotUse() + ) def test_node_name_stability(self) -> None: # Tests that graph nodes stay the same for nodes that are not touched @@ -77,12 +79,14 @@ def __init__(self): self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) # Define two buffers - self.register_buffer('my_buffer1', torch.tensor(3.0)) - self.register_buffer('my_buffer2', torch.tensor(4.0)) + self.register_buffer("my_buffer1", torch.tensor(3.0)) + self.register_buffer("my_buffer2", torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method - output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 + output = ( + x1 + self.my_parameter + ) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) @@ -184,5 +188,6 @@ def replace_pass(gm): old_signature = ep_before.graph_signature self.assertNotEqual(sig.user_outputs, old_signature.user_outputs) -if __name__ == '__main__': + +if __name__ == "__main__": run_tests() diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 4b5f909c8a91b..c7240ec0ee2a8 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -42,20 +42,21 @@ from torch.export._remove_effect_tokens_pass import _remove_effect_tokens from torch.fx.passes.infra.partitioner import Partition from torch.fx.passes.operator_support import OperatorSupport -from torch.library import impl, _scoped_library +from torch.library import _scoped_library, impl from torch.testing import FileCheck from torch.testing._internal.common_utils import ( find_library_location, - run_tests, - TestCase, - skipIfTorchDynamo, IS_FBCODE, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, + run_tests, + skipIfTorchDynamo, + TestCase, ) from torch.utils import _pytree as pytree + def count_call_function(graph: torch.fx.Graph, target: torch.ops.OpOverload) -> int: count = 0 for node in graph.nodes: @@ -71,9 +72,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: class _AtenAddOperatorSupport(OperatorSupport): def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: - return node.op == "call_function" and node.target in { - torch.ops.aten.add.Tensor - } + return node.op == "call_function" and node.target in {torch.ops.aten.add.Tensor} def _to_partition_names(partitions: List[Partition]) -> List[Set[str]]: @@ -87,6 +86,7 @@ def _get_output_names(gm: torch.fx.GraphModule) -> List[str]: # args = args[0] return [str(arg) for arg in args] + def _set_grad_enabled_tests(): from torch.export._trace import _export @@ -133,12 +133,21 @@ def _get_predispatch_module(mod, args, ambient_grad_enabled=True): return _export(mod, args, pre_dispatch=True).module() return { - "ctx_manager" : (_get_predispatch_module(SetGradCtxManager(), (x,)), (x,)), - "ctx_manager_under_no_grad" : (_get_predispatch_module(SetGradCtxManager(), (x,), False), (x,)), - "ctx_manager_multi_dep" : (_get_predispatch_module(SetGradCtxManagerMultiDep(), (x,)), (x,)), - "ctx_manager_multi_dep_no_grad" : (_get_predispatch_module(SetGradCtxManagerMultiDep(), (x,), False), (x,)), - "op" : (_get_predispatch_module(SetGradOp(), (x,)), (x,)), - "op_under_no_grad" : (_get_predispatch_module(SetGradOp(), (x,), False), (x,)) + "ctx_manager": (_get_predispatch_module(SetGradCtxManager(), (x,)), (x,)), + "ctx_manager_under_no_grad": ( + _get_predispatch_module(SetGradCtxManager(), (x,), False), + (x,), + ), + "ctx_manager_multi_dep": ( + _get_predispatch_module(SetGradCtxManagerMultiDep(), (x,)), + (x,), + ), + "ctx_manager_multi_dep_no_grad": ( + _get_predispatch_module(SetGradCtxManagerMultiDep(), (x,), False), + (x,), + ), + "op": (_get_predispatch_module(SetGradOp(), (x,)), (x,)), + "op_under_no_grad": (_get_predispatch_module(SetGradOp(), (x,), False), (x,)), } @@ -170,13 +179,17 @@ def _get_predispatch_module(mod, args): def _insert_dilimiter_nodes(gm: torch.fx.GraphModule, step: int = 1): insert_locs = [] - for i, node in enumerate(nodes_filter(gm.graph.nodes, lambda n: n.op == "call_function")): + for i, node in enumerate( + nodes_filter(gm.graph.nodes, lambda n: n.op == "call_function") + ): if i % step == 0: insert_locs.append(node) for i, node in enumerate(insert_locs): with gm.graph.inserting_before(node): - gm.graph.call_function(torch._C._set_grad_enabled, (True if i % 2 == 0 else False,), {}) + gm.graph.call_function( + torch._C._set_grad_enabled, (True if i % 2 == 0 else False,), {} + ) return gm x = torch.randn(2, 2) @@ -185,10 +198,10 @@ def _insert_dilimiter_nodes(gm: torch.fx.GraphModule, step: int = 1): multi_dep = _get_predispatch_module(MultiDep(), (x, x.sin())) multi_dep1 = _get_predispatch_module(MultiDep(), (x, x.sin())) return { - 'simple_step1': (_insert_dilimiter_nodes(simple1, 1), (x,)), - 'simple_step2': (_insert_dilimiter_nodes(simple, 2), (x,)), - 'multi_dep_step2': (_insert_dilimiter_nodes(multi_dep, 2), (x, x.sin())), - 'multi_dep_step3': (_insert_dilimiter_nodes(multi_dep1, 3), (x, x.sin())), + "simple_step1": (_insert_dilimiter_nodes(simple1, 1), (x,)), + "simple_step2": (_insert_dilimiter_nodes(simple, 2), (x,)), + "multi_dep_step2": (_insert_dilimiter_nodes(multi_dep, 2), (x, x.sin())), + "multi_dep_step3": (_insert_dilimiter_nodes(multi_dep1, 3), (x, x.sin())), } @@ -207,9 +220,9 @@ def setUp(self): elif IS_MACOS: raise unittest.SkipTest("non-portable load_library call used in test") else: - lib_file_path = find_library_location('libtorchbind_test.so') + lib_file_path = find_library_location("libtorchbind_test.so") if IS_WINDOWS: - lib_file_path = find_library_location('torchbind_test.dll') + lib_file_path = find_library_location("torchbind_test.dll") torch.ops.load_library(str(lib_file_path)) def tearDown(self): @@ -230,10 +243,15 @@ def forward(self, x): dim1_x = torch.export.Dim("dim1_x", min=2, max=6) ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}}) - with self.assertRaisesRegex(RuntimeError, escape("Expected input at *args[0].shape[1] to be <= 6, but got 7")): + with self.assertRaisesRegex( + RuntimeError, + escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"), + ): ep.module()(torch.zeros(2, 7, 3)) - self.assertEqual(ep.module()(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3))) + self.assertEqual( + ep.module()(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)) + ) def test_runtime_assert_multiple_dims(self) -> None: class M(torch.nn.Module): @@ -253,10 +271,16 @@ def forward(self, x, y): M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}} ) - with self.assertRaisesRegex(RuntimeError, escape("Expected input at *args[0].shape[1] to be <= 6, but got 7")): + with self.assertRaisesRegex( + RuntimeError, + escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"), + ): ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) - with self.assertRaisesRegex(RuntimeError, escape("Expected input at *args[1].shape[0] to be >= 3, but got 2")): + with self.assertRaisesRegex( + RuntimeError, + escape("Expected input at *args[1].shape[0] to be >= 3, but got 2"), + ): ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) def test_runtime_assert_some_dims_not_specified(self) -> None: @@ -277,12 +301,16 @@ def forward(self, x, y): M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None} ) - with self.assertRaisesRegex(RuntimeError, escape("Expected input at *args[0].shape[1] to be <= 6, but got 7")): + with self.assertRaisesRegex( + RuntimeError, + escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"), + ): ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) # y is specialized to 5 with self.assertRaisesRegex( - RuntimeError, escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2") + RuntimeError, + escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"), ): ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) @@ -304,14 +332,17 @@ def forward(self, x, y): y = torch.zeros(5, 5, 5) dim1_y = torch.export.Dim("dim1_y", min=3, max=6) - ep = torch.export.export(M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}) + ep = torch.export.export( + M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}} + ) with self.assertRaisesRegex(RuntimeError, escape("shape[1] to be equal to 2")): ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) # y is specialized to 5 with self.assertRaisesRegex( - RuntimeError, escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2") + RuntimeError, + escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"), ): ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) @@ -348,10 +379,14 @@ def forward(self, x): x = torch.zeros(4, 2, 3) foo = Module() - ep = export(foo, (x,))._transform_do_not_use(ReplaceViewOpsWithViewCopyOpsPass()) + ep = export(foo, (x,))._transform_do_not_use( + ReplaceViewOpsWithViewCopyOpsPass() + ) # After this pass, there shouldn't be any view nodes in the graph self.assertTrue(count_call_function(ep.graph, torch.ops.aten.view.default) == 0) - self.assertTrue(count_call_function(ep.graph, torch.ops.aten.view_copy.default) > 0) + self.assertTrue( + count_call_function(ep.graph, torch.ops.aten.view_copy.default) > 0 + ) def test_views_op_having_view_copy(self) -> None: schemas = torch._C._dispatch_get_registrations_for_dispatch_key("") @@ -414,7 +449,10 @@ def forward(self, x): mod = M() ep = export(mod, (x,)) - with self.assertRaisesRegex(RuntimeError, r"_local_scalar_dense is outside of inline constraint \[2, 5\]."): + with self.assertRaisesRegex( + RuntimeError, + r"_local_scalar_dense is outside of inline constraint \[2, 5\].", + ): ep.module()(torch.tensor([6])) new_inp = torch.tensor([5]) @@ -445,12 +483,14 @@ def forward(self, x): self.assertEqual(num_scalar_tensor, 2) with self.assertRaisesRegex( - RuntimeError, r"nonzero.shape\[0\] is outside of inline constraint \[3, 5\]." + RuntimeError, + r"nonzero.shape\[0\] is outside of inline constraint \[3, 5\].", ): ep.module()(torch.tensor([1, 1, 0, 0, 0])) with self.assertRaisesRegex( - RuntimeError, r"nonzero.shape\[0\] is outside of inline constraint \[3, 5\]." + RuntimeError, + r"nonzero.shape\[0\] is outside of inline constraint \[3, 5\].", ): ep.module()(torch.ones(6)) @@ -482,8 +522,9 @@ def false_fn(x, y): mod = M() ep = export(mod, (torch.tensor(True), x, y)) - - with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint \\[2, 5\\]."): + with self.assertRaisesRegex( + RuntimeError, "is outside of inline constraint \\[2, 5\\]." + ): ep.module()(torch.tensor(False), torch.tensor([6]), torch.tensor([6])) def test_functionalize_inline_constraints(self) -> None: @@ -538,7 +579,9 @@ def forward(self, x): def test_predispatceh_set_grad(self): mod, args = self.SET_GRAD_ENABLED_TESTS["op"] - self.assertExpectedInline(mod.code.strip("\n"), """\ + self.assertExpectedInline( + mod.code.strip("\n"), + """\ def forward(self, arg_0): x, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec) add = torch.ops.aten.add.Tensor(x, 1); x = None @@ -548,9 +591,12 @@ def forward(self, arg_0): add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(False, submod_4, sum_1); submod_4 = sum_1 = None sub = torch.ops.aten.sub.Tensor(add_1, 1) return pytree.tree_unflatten((add_1, sub), self._out_spec) - """) + """, + ) mod, args = self.SET_GRAD_ENABLED_TESTS["op_under_no_grad"] - self.assertExpectedInline(mod.code.strip("\n"), """\ + self.assertExpectedInline( + mod.code.strip("\n"), + """\ def forward(self, arg_0): x, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec) add = torch.ops.aten.add.Tensor(x, 1); x = None @@ -560,10 +606,13 @@ def forward(self, arg_0): add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(False, submod_4, sum_1); submod_4 = sum_1 = None sub = torch.ops.aten.sub.Tensor(add_1, 1) return pytree.tree_unflatten((add_1, sub), self._out_spec) - """) + """, + ) mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager"] - self.assertExpectedInline(mod.code.strip("\n"), """\ + self.assertExpectedInline( + mod.code.strip("\n"), + """\ def forward(self, arg_0): x, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec) add = torch.ops.aten.add.Tensor(x, 1); x = None @@ -573,9 +622,12 @@ def forward(self, arg_0): add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(False, submod_3, sum_1); submod_3 = sum_1 = None sub = torch.ops.aten.sub.Tensor(add_1, 1) return pytree.tree_unflatten((add_1, sub), self._out_spec) - """) + """, + ) mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager_under_no_grad"] - self.assertExpectedInline(mod.code.strip("\n"), """\ + self.assertExpectedInline( + mod.code.strip("\n"), + """\ def forward(self, arg_0): x, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec) add = torch.ops.aten.add.Tensor(x, 1); x = None @@ -585,9 +637,12 @@ def forward(self, arg_0): submod_6 = self.submod_3 sub = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_6, add_1); submod_6 = None return pytree.tree_unflatten((add_1, sub), self._out_spec) - """) + """, + ) mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager_multi_dep"] - self.assertExpectedInline(mod.code.strip("\n"), """\ + self.assertExpectedInline( + mod.code.strip("\n"), + """\ def forward(self, arg_0): x, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec) add = torch.ops.aten.add.Tensor(x, 1); x = None @@ -602,9 +657,12 @@ def forward(self, arg_0): sub = torch.ops.aten.sub.Tensor(add_1, 1) sub_1 = torch.ops.aten.sub.Tensor(add_2, 1) return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec) - """) # noqa: B950 + """, # noqa: B950 + ) mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager_multi_dep_no_grad"] - self.assertExpectedInline(mod.code.strip("\n"), """\ + self.assertExpectedInline( + mod.code.strip("\n"), + """\ def forward(self, arg_0): x, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec) add = torch.ops.aten.add.Tensor(x, 1); x = None @@ -619,13 +677,16 @@ def forward(self, arg_0): sub = wrap_with_set_grad_enabled_1[0] sub_1 = wrap_with_set_grad_enabled_1[1]; wrap_with_set_grad_enabled_1 = None return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec) - """) # noqa: B950 + """, # noqa: B950 + ) def test_sequential_split(self): for gm, args in self.SEQUENTIAL_SPLIT_INLINE_TESTS.values(): set_grad_counts = nodes_count(gm.graph.nodes, _is_set_grad_enabled_node) new_gm = sequential_split(gm, _is_set_grad_enabled_node) - new_set_grad_counts = nodes_count(new_gm.graph.nodes, _is_set_grad_enabled_sub_mod) + new_set_grad_counts = nodes_count( + new_gm.graph.nodes, _is_set_grad_enabled_sub_mod + ) self.assertEqual(set_grad_counts, new_set_grad_counts) self.assertEqual(gm(*args), new_gm(*args)) @@ -634,7 +695,9 @@ def test_sequential_split_graph(self): new_gm = sequential_split(gm, _is_set_grad_enabled_node) self.assertEqual(gm(*args), new_gm(*args)) - self.assertExpectedInline(new_gm.code.strip("\n"), """\ + self.assertExpectedInline( + new_gm.code.strip("\n"), + """\ def forward(self, arg_0, arg_1): x1, x2, = fx_pytree.tree_flatten_spec(([arg_0, arg_1], {}), self._in_spec) submod_1 = self.submod_1(x1, x2); x1 = x2 = None @@ -650,41 +713,53 @@ def forward(self, arg_0, arg_1): getitem_6 = submod_4[0] getitem_7 = submod_4[1]; submod_4 = None return pytree.tree_unflatten((getitem_4, getitem_5, getitem_6, getitem_7), self._out_spec) - """) - self.assertExpectedInline(new_gm.submod_1.code.strip("\n"), """\ + """, + ) + self.assertExpectedInline( + new_gm.submod_1.code.strip("\n"), + """\ def forward(self, x1, x2): _set_grad_enabled = torch._C._set_grad_enabled(True) add = torch.ops.aten.add.Tensor(x1, 1); x1 = None add_1 = torch.ops.aten.add.Tensor(x2, 1); x2 = None return (add, add_1) - """) - self.assertExpectedInline(new_gm.submod_2.code.strip("\n"), """\ + """, + ) + self.assertExpectedInline( + new_gm.submod_2.code.strip("\n"), + """\ def forward(self, add, add_1): _set_grad_enabled_1 = torch._C._set_grad_enabled(False) sin = torch.ops.aten.sin.default(add); add = None cos = torch.ops.aten.cos.default(add_1); add_1 = None return (sin, cos) - """) - self.assertExpectedInline(new_gm.submod_3.code.strip("\n"), """\ + """, + ) + self.assertExpectedInline( + new_gm.submod_3.code.strip("\n"), + """\ def forward(self, sin, cos): _set_grad_enabled_2 = torch._C._set_grad_enabled(True) add_2 = torch.ops.aten.add.Tensor(sin, 1); sin = None add_3 = torch.ops.aten.add.Tensor(cos, 1); cos = None return (add_2, add_3) - """) + """, + ) def test_inline_(self): for gm, args in self.SEQUENTIAL_SPLIT_INLINE_TESTS.values(): before_str = gm.print_readable(print_output=False) new_gm = sequential_split(gm, _is_set_grad_enabled_node) - nodes_map(new_gm.graph.nodes, lambda node: node_inline_(node) if node.op == "call_module" else node) + nodes_map( + new_gm.graph.nodes, + lambda node: node_inline_(node) if node.op == "call_module" else node, + ) after_inline_str = new_gm.print_readable(print_output=False) self.assertEqual(before_str, after_inline_str) self.assertEqual(gm(*args), new_gm(*args)) def test_remove_auto_functionalized_pass(self) -> None: with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib: - lib.define("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor") @impl(lib, "custom_mutator", "Meta") @@ -694,7 +769,6 @@ def custom_mutator_meta( ) -> torch.Tensor: return torch.empty_like(x) - @impl(lib, "custom_mutator", "CompositeExplicitAutograd") def custom_mutator( x: torch.Tensor, @@ -725,8 +799,9 @@ def forward(self, x): def test_remove_auto_functionalized_pass_tuple(self) -> None: with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib: - - lib.define("custom_mutator_tuple(Tensor x, Tensor(a!) y) -> (Tensor, Tensor)") + lib.define( + "custom_mutator_tuple(Tensor x, Tensor(a!) y) -> (Tensor, Tensor)" + ) @impl(lib, "custom_mutator_tuple", "Meta") def custom_mutator_tuple_meta( @@ -735,7 +810,6 @@ def custom_mutator_tuple_meta( ): return (torch.empty_like(x), torch.empty_like(x)) - @impl(lib, "custom_mutator_tuple", "CompositeExplicitAutograd") def custom_mutator_tuple( x: torch.Tensor, @@ -773,5 +847,5 @@ def forward(self, x): self.assertEqual(out_specs[2].arg.name, "getitem_1") # tuple return 2 -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index df5d227bec38d..8645eff8f6598 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -13,6 +13,7 @@ import torch import torch._dynamo as torchdynamo +import torch.export._trace import torch.utils._pytree as pytree from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel from torch._export.db.examples import all_examples @@ -27,7 +28,6 @@ from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.export import Dim, export, load, save -import torch.export._trace from torch.fx.experimental.symbolic_shapes import is_concrete_int from torch.testing._internal.common_utils import ( find_library_location, @@ -75,6 +75,7 @@ def forward(self, x): inp = (torch.ones(10),) with torch.no_grad(): from torch.export._trace import _export + ep = _export(Foo(), inp, pre_dispatch=True) buffer = io.BytesIO() @@ -90,6 +91,7 @@ def forward(self, x): def test_export_example_inputs_preserved(self): class MyModule(torch.nn.Module): """A test module with that has multiple args and uses kwargs""" + def __init__(self): super().__init__() self.p = torch.nn.Parameter(torch.ones(2, 3)) @@ -280,6 +282,7 @@ def test_kwargs_default(self) -> None: Tests that the kwargs default values are serialized even if they are not specified """ + class Foo(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: values = torch.randn(3, 2) @@ -312,9 +315,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: g = c.graph_module.graph self.assertLess( g.nodes[0].inputs[0].arg.as_tensor.name, - g.nodes[1].inputs[0].arg.as_tensor.name + g.nodes[1].inputs[0].arg.as_tensor.name, ) + @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") class TestDeserialize(TestCase): @@ -326,9 +330,9 @@ def setUp(self): elif IS_MACOS: raise unittest.SkipTest("non-portable load_library call used in test") else: - lib_file_path = find_library_location('libtorchbind_test.so') + lib_file_path = find_library_location("libtorchbind_test.so") if IS_WINDOWS: - lib_file_path = find_library_location('torchbind_test.dll') + lib_file_path = find_library_location("torchbind_test.dll") torch.ops.load_library(str(lib_file_path)) def _check_graph_nodes(self, gm1, gm2, _check_meta=True): @@ -358,10 +362,14 @@ def _check_graph_nodes(self, gm1, gm2, _check_meta=True): else: self.assertEqual(str(s1), str(s2)) self.assertEqual(val1.dtype, val2.dtype) - elif isinstance(val1, (list, tuple)) and isinstance(val2, (list, tuple)): + elif isinstance(val1, (list, tuple)) and isinstance( + val2, (list, tuple) + ): # Or both are fake tensors lists with one element and with the # same shape/dtype - for v1, v2 in zip(pytree.tree_leaves(val1), pytree.tree_leaves(val2)): + for v1, v2 in zip( + pytree.tree_leaves(val1), pytree.tree_leaves(val2) + ): if isinstance(v1, FakeTensor): self.assertEqual(v1.shape, v2.shape) self.assertEqual(v1.dtype, v2.dtype) @@ -388,10 +396,7 @@ def _check_graph_nodes(self, gm1, gm2, _check_meta=True): map_graph2 = getattr(gm2, node2.args[0].target) self._check_graph_nodes(map_graph1, map_graph2, False) - if ( - _check_meta and - node1.op not in ("get_attr", "placeholder", "output") - ): + if _check_meta and node1.op not in ("get_attr", "placeholder", "output"): # Check "nn_module_stack" metadata self.assertEqual( node1.meta.get("nn_module_stack", None), @@ -403,8 +408,17 @@ def _check_graph_nodes(self, gm1, gm2, _check_meta=True): node2.meta.get("source_fn_stack", None), ) - def check_graph(self, fn, inputs, dynamic_shapes=None, _check_meta=True, use_pre_dispatch=True, strict=True) -> None: + def check_graph( + self, + fn, + inputs, + dynamic_shapes=None, + _check_meta=True, + use_pre_dispatch=True, + strict=True, + ) -> None: """Export a graph, serialize it, deserialize it, and compare the results.""" + def _check_graph(pre_dispatch): if pre_dispatch: ep = torch.export._trace._export( @@ -413,7 +427,7 @@ def _check_graph(pre_dispatch): {}, dynamic_shapes=dynamic_shapes, pre_dispatch=True, - strict=strict + strict=strict, ) else: ep = torch.export.export( @@ -421,12 +435,14 @@ def _check_graph(pre_dispatch): copy.deepcopy(inputs), {}, dynamic_shapes=dynamic_shapes, - strict=strict + strict=strict, ) ep.graph.eliminate_dead_code() serialized_artifact = serialize(ep, opset_version={"aten": 0}) - deserialized_ep = deserialize(serialized_artifact, expected_opset_version={"aten": 0}) + deserialized_ep = deserialize( + serialized_artifact, expected_opset_version={"aten": 0} + ) deserialized_ep.graph.eliminate_dead_code() orig_outputs = ep.module()(*copy.deepcopy(inputs)) @@ -444,7 +460,9 @@ def _check_graph(pre_dispatch): self.assertTrue(torch.allclose(orig, loaded)) else: self.assertEqual(orig, loaded) - self._check_graph_nodes(ep.graph_module, deserialized_ep.graph_module, _check_meta) + self._check_graph_nodes( + ep.graph_module, deserialized_ep.graph_module, _check_meta + ) if use_pre_dispatch: _check_graph(pre_dispatch=True) @@ -518,6 +536,7 @@ def test_multi_return(self) -> None: """ Test multiple return from a single node (ex. layer_norm has 2 outputs) """ + class MyModule(torch.nn.Module): def __init__(self): super().__init__() @@ -625,6 +644,7 @@ def forward(self, x): def test_cond(self): from functorch.experimental.control_flow import cond + inputs = torch.ones(4, 3), torch.zeros(4, 3) class M(torch.nn.Module): @@ -634,6 +654,7 @@ def t(x, y): def f(x, y): return x - y + return cond(x[0][0] > 4, t, f, [x, y]) self.check_graph(M(), inputs) @@ -655,20 +676,32 @@ def forward(self, xs, y): def test_tensor_tensor_list(self): try: from torch.library import Library + lib = Library("_export", "FRAGMENT") # noqa: TOR901 lib.define( "_test_tensor_tensor_list_output(Tensor x, Tensor y) -> (Tensor, Tensor[])", - tags=torch.Tag.pt2_compliant_tag) + tags=torch.Tag.pt2_compliant_tag, + ) def _test_tensor_tensor_list_output(x, y): return y, [x] - lib.impl("_test_tensor_tensor_list_output", _test_tensor_tensor_list_output, "CPU") - lib.impl("_test_tensor_tensor_list_output", _test_tensor_tensor_list_output, "Meta") + lib.impl( + "_test_tensor_tensor_list_output", + _test_tensor_tensor_list_output, + "CPU", + ) + lib.impl( + "_test_tensor_tensor_list_output", + _test_tensor_tensor_list_output, + "Meta", + ) class M(torch.nn.Module): def forward(self, x, y): - a, b = torch.ops._export._test_tensor_tensor_list_output.default(x, y) + a, b = torch.ops._export._test_tensor_tensor_list_output.default( + x, y + ) return a + b[0] self.check_graph(M(), (torch.rand(3, 2), torch.rand(3, 2))) @@ -697,7 +730,7 @@ def forward(self, x): ret = torch.sym_ite(b, x.shape[0], x.shape[1]) return ret - dynamic_shapes = {'x': {0: Dim("dim0"), 1: Dim("dim1")}} + dynamic_shapes = {"x": {0: Dim("dim0"), 1: Dim("dim1")}} self.check_graph(Foo(), (torch.ones(4, 5),), dynamic_shapes=dynamic_shapes) @parametrize( @@ -808,6 +841,7 @@ def forward(self, x): instantiate_parametrized_tests(TestDeserialize) + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") class TestSchemaVersioning(TestCase): def test_error(self): @@ -820,12 +854,14 @@ def forward(self, x): serialized_program = ExportedProgramSerializer().serialize(ep) serialized_program.exported_program.schema_version.major = -1 - with self.assertRaisesRegex(SerializeError, r"Serialized schema version .* does not match our current"): + with self.assertRaisesRegex( + SerializeError, r"Serialized schema version .* does not match our current" + ): ExportedProgramDeserializer().deserialize( serialized_program.exported_program, serialized_program.state_dict, serialized_program.constants, - serialized_program.example_inputs + serialized_program.example_inputs, ) @@ -850,19 +886,18 @@ def test_model_op_namespace_version_missing_from_deserializer_do_not_raises(self compiler_opset_version = {"aten": 3} model_opset_version = {"aten": 3, "custom": 4} deserializer = ExportedProgramDeserializer(compiler_opset_version) - with self.assertLogs(level='WARN') as log: + with self.assertLogs(level="WARN") as log: deserializer._validate_model_opset_version(model_opset_version) - self.assertIn("Compiler doesn't have a version table for op namespace", log.output[0]) + self.assertIn( + "Compiler doesn't have a version table for op namespace", log.output[0] + ) + # We didn't set up kwargs input yet -unittest.expectedFailure( - TestDeserialize.test_exportdb_supported_case_fn_with_kwargs -) +unittest.expectedFailure(TestDeserialize.test_exportdb_supported_case_fn_with_kwargs) # Failed to produce a graph during tracing. Tracing through 'f' must produce a single graph. -unittest.expectedFailure( - TestDeserialize.test_exportdb_supported_case_scalar_output -) +unittest.expectedFailure(TestDeserialize.test_exportdb_supported_case_scalar_output) @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") @@ -959,10 +994,12 @@ def forward(self, x): f.seek(0) # Modify the version - with zipfile.ZipFile(f, 'a') as zipf: - zipf.writestr('version', "-1.1") + with zipfile.ZipFile(f, "a") as zipf: + zipf.writestr("version", "-1.1") - with self.assertRaisesRegex(RuntimeError, r"Serialized version .* does not match our current"): + with self.assertRaisesRegex( + RuntimeError, r"Serialized version .* does not match our current" + ): f.seek(0) load(f) @@ -985,6 +1022,7 @@ def forward(self, x): inp = (torch.tensor(1),) self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp))) + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") class TestSerializeCustomClass(TestCase): def setUp(self): @@ -995,9 +1033,9 @@ def setUp(self): elif IS_MACOS: raise unittest.SkipTest("non-portable load_library call used in test") else: - lib_file_path = find_library_location('libtorchbind_test.so') + lib_file_path = find_library_location("libtorchbind_test.so") if IS_WINDOWS: - lib_file_path = find_library_location('torchbind_test.dll') + lib_file_path = find_library_location("torchbind_test.dll") torch.ops.load_library(str(lib_file_path)) def test_custom_class(self): @@ -1021,7 +1059,10 @@ def forward(self, x): (custom_obj,), ) custom_node.meta["val"] = torch.ones(4, 4) - custom_node.meta["torch_fn"] = ("take_an_instance", "take_an_instance") + custom_node.meta["torch_fn"] = ( + "take_an_instance", + "take_an_instance", + ) arg0, _ = node.args node.args = (arg0, custom_node) @@ -1035,8 +1076,9 @@ def forward(self, x): for node in deserialized_ep.graph.nodes: if ( - node.op == "call_function" and - node.target == torch.ops._TorchScriptTesting.take_an_instance.default + node.op == "call_function" + and node.target + == torch.ops._TorchScriptTesting.take_an_instance.default ): arg = node.args[0] self.assertTrue(isinstance(arg, torch._C.ScriptObject)) @@ -1048,7 +1090,9 @@ def test_custom_class_containing_fake_tensor(self): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.custom_obj = torch.classes._TorchScriptTesting._ContainsTensor(torch.rand(2, 3)) + self.custom_obj = torch.classes._TorchScriptTesting._ContainsTensor( + torch.rand(2, 3) + ) def forward(self, x): return x + self.custom_obj.get() @@ -1065,5 +1109,5 @@ def forward(self, x): self.assertTrue(isinstance(ep.constants["custom_obj"].get(), FakeTensor)) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index 30d0a79d16bd5..c9c60d8cc89dd 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -4,42 +4,41 @@ import unittest from contextlib import contextmanager from dataclasses import dataclass -from typing import List, Any from re import escape +from typing import Any, List import torch import torch._dynamo as torchdynamo from functorch.experimental.control_flow import cond, map from torch import Tensor +from torch._export.utils import ( + get_buffer, + get_param, + is_buffer, + is_param, + register_dataclass_as_pytree_node, +) +from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch.export import ( Constraint, Dim, dynamic_dim, export, - unflatten, FlatArgsAdapter, + unflatten, ) -from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch.export._trace import DEFAULT_EXPORT_DYNAMO_CONFIG -from torch._export.utils import ( - get_buffer, - get_param, - is_buffer, - is_param, - register_dataclass_as_pytree_node, -) -from torch.export import Constraint, Dim, export from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck from torch.testing._internal.common_utils import ( - run_tests, - TestCase, + find_library_location, IS_FBCODE, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, - find_library_location, + run_tests, skipIfTorchDynamo, + TestCase, ) from torch.utils._pytree import ( LeafSpec, @@ -252,7 +251,7 @@ def forward(self, x, y): inps, {}, preserve_module_call_signature=("foo.nested",), - strict=strict + strict=strict, ) unflattened = unflatten(export_module) self.compare_outputs(export_module.module(), unflattened, inps) @@ -305,7 +304,9 @@ def forward(self, x): export_module = torch.export.export(Mod(), (torch.randn((2, 3)),)) unflattened = unflatten(export_module) - self.compare_outputs(export_module.module(), unflattened, (torch.randn((2, 3)),)) + self.compare_outputs( + export_module.module(), unflattened, (torch.randn((2, 3)),) + ) def test_unflatten_wrong_input(self): class Mod(torch.nn.Module): @@ -327,11 +328,17 @@ def forward(self, x): return a export_module = torch.export.export(Mod(), (torch.randn((2, 3)),)) - with self.assertRaisesRegex(RuntimeError, escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6")): + with self.assertRaisesRegex( + RuntimeError, + escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"), + ): export_module.module()(torch.randn(6, 6)) unflattened = unflatten(export_module) - with self.assertRaisesRegex(RuntimeError, escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6")): + with self.assertRaisesRegex( + RuntimeError, + escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"), + ): unflattened(torch.randn(6, 6)) def test_unflatten_with_inplace_compile(self): @@ -524,7 +531,9 @@ def forward(self, x): for sub_node in transpose_module.graph.nodes: if sub_node.op == "placeholder" or sub_node.op == "get_attr": call_module_input_order.append(sub_node.op) - self.assertEqual(call_module_input_order, ["placeholder", "get_attr", "get_attr"]) + self.assertEqual( + call_module_input_order, ["placeholder", "get_attr", "get_attr"] + ) def test_unflatten_constant_tensor(self): class SubMod(torch.nn.Module): @@ -546,7 +555,9 @@ def forward(self, x): export_module = torch.export.export(Mod(), (torch.randn((2, 3)),)) unflattened = unflatten(export_module) - self.compare_outputs(export_module.module(), unflattened, (torch.randn((2, 3)),)) + self.compare_outputs( + export_module.module(), unflattened, (torch.randn((2, 3)),) + ) @skipIfTorchDynamo("custom objects not supported in dynamo yet") def test_unflatten_constant_obj(self): @@ -580,10 +591,14 @@ def forward(self, x): return x + self.submod(x) with enable_torchbind_tracing(): - export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=False) + export_module = torch.export.export( + Mod(), (torch.randn((2, 3)),), strict=False + ) unflattened = unflatten(export_module) - self.compare_outputs(export_module.module(), unflattened, (torch.randn((2, 3)),)) + self.compare_outputs( + export_module.module(), unflattened, (torch.randn((2, 3)),) + ) def test_nested_leaf_non_strict(self): class Leaf(torch.nn.Module): @@ -656,5 +671,6 @@ def forward(self, x): fqn_list, ) + if __name__ == "__main__": run_tests() diff --git a/test/export/test_upgrade.py b/test/export/test_upgrade.py index 104ae92dca44b..3913370f9e46f 100644 --- a/test/export/test_upgrade.py +++ b/test/export/test_upgrade.py @@ -4,14 +4,10 @@ import torch import torch._dynamo as torchdynamo -from torch.export import export from torch._export.serde.serialize import GraphModuleOpUpgrader from torch._export.serde.upgrade import get_target_version, get_upgraders -from torch.testing._internal.common_utils import ( - run_tests, - TestCase, - IS_WINDOWS, -) +from torch.export import export +from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase TEST_UPGRADERS = { "aten::div__Scalar_mode_0_3": ( @@ -32,8 +28,7 @@ def gelu_0_9(self: Tensor) -> Tensor: } TEST_UPGRADERS_ENTRY_MAP = { - "div__Scalar_mode_0_3": - """ + "div__Scalar_mode_0_3": """ from typing import Any, Optional def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Optional[str]=None) -> torch.Tensor: return self.divide_(other, rounding_mode=rounding_mode)""" @@ -44,7 +39,7 @@ def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Opti torch._C._UpgraderEntry( 4, "div__Scalar_mode_0_3", - "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)" + "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)", ) ] } @@ -52,27 +47,43 @@ def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Opti def count_op(graph, target_str): return len( - [n for n in graph.nodes if isinstance(n.target, torch._ops.OpOverload) and n.target.name() == target_str]) + [ + n + for n in graph.nodes + if isinstance(n.target, torch._ops.OpOverload) + and n.target.name() == target_str + ] + ) @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") class TestUpgrade(TestCase): def test_get_upgraders(self): - with patch.object(torch._C, "_get_upgraders_entry_map", return_value=TEST_UPGRADERS_ENTRY_MAP), \ - patch.object(torch._C, "_get_operator_version_map", return_value=TEST_OP_VERSION_MAP): + with patch.object( + torch._C, "_get_upgraders_entry_map", return_value=TEST_UPGRADERS_ENTRY_MAP + ), patch.object( + torch._C, "_get_operator_version_map", return_value=TEST_OP_VERSION_MAP + ): op_upgraders = get_upgraders() - self.assertEqual(op_upgraders, { - "div__Scalar_mode_0_3": ( - "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)", - """ + self.assertEqual( + op_upgraders, + { + "div__Scalar_mode_0_3": ( + "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)", + """ from typing import Any, Optional def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Optional[str]=None) -> torch.Tensor: return self.divide_(other, rounding_mode=rounding_mode)""", - )}) + ) + }, + ) def test_get_upgraders_missing_from_entry_map_raises(self): - with patch.object(torch._C, "_get_upgraders_entry_map", return_value={}), \ - patch.object(torch._C, "_get_operator_version_map", return_value=TEST_OP_VERSION_MAP): + with patch.object( + torch._C, "_get_upgraders_entry_map", return_value={} + ), patch.object( + torch._C, "_get_operator_version_map", return_value=TEST_OP_VERSION_MAP + ): with self.assertRaises(RuntimeError): get_upgraders() @@ -93,21 +104,25 @@ def test_get_target_version_invalid_format_throws_exception(self): def test_creates_upgrader_pass(self): compiler_opset_version = {"aten": 4} model_opset_version = {"aten": 3} - upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS) + upgrader = GraphModuleOpUpgrader( + compiler_opset_version, model_opset_version, TEST_UPGRADERS + ) self.assertEqual(len(upgrader.upgrader_passes), 1) def test_div_upgrader_replaces_op_with_old_version(self): class Foo(torch.nn.Module): def forward(self, a: torch.Tensor, b): - return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode='trunc') + return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode="trunc") fn = Foo() - inputs = (torch.ones([2, 3]) * 4, 2.) + inputs = (torch.ones([2, 3]) * 4, 2.0) ep = export(fn, inputs, []) compiler_opset_version = {"aten": 4} model_opset_version = {"aten": 3} - upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS) + upgrader = GraphModuleOpUpgrader( + compiler_opset_version, model_opset_version, TEST_UPGRADERS + ) upgraded = ep._transform_do_not_use(*upgrader.upgrader_passes) upgraded.graph_module.print_readable() @@ -120,15 +135,17 @@ def forward(self, a: torch.Tensor, b): def test_div_upgrader_pass_return_new_op_after_retrace(self): class Foo(torch.nn.Module): def forward(self, a: torch.Tensor, b): - return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode='trunc') + return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode="trunc") fn = Foo() - inputs = (torch.ones([2, 3]) * 4, 2.) + inputs = (torch.ones([2, 3]) * 4, 2.0) ep = export(fn, inputs) compiler_opset_version = {"aten": 4} model_opset_version = {"aten": 3} - upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS) + upgrader = GraphModuleOpUpgrader( + compiler_opset_version, model_opset_version, TEST_UPGRADERS + ) count = count_op(ep.graph, "aten::div.Scalar_mode") self.assertEqual(count, 1) @@ -146,5 +163,5 @@ def forward(self, a: torch.Tensor, b): self.assertEqual(decomposed_op_count, 1) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/export/test_verifier.py b/test/export/test_verifier.py index 89a69de7c9ea5..c85e90f1b435f 100644 --- a/test/export/test_verifier.py +++ b/test/export/test_verifier.py @@ -5,11 +5,12 @@ from functorch.experimental import control_flow from torch import Tensor from torch._dynamo.eval_frame import is_dynamo_supported -from torch.export import export from torch._export.verifier import SpecViolationError, Verifier +from torch.export import export from torch.export.exported_program import InputKind, InputSpec, TensorArgument -from torch.testing._internal.common_utils import run_tests, TestCase, IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase + @unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported") class TestVerifier(TestCase): @@ -66,9 +67,7 @@ def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x - y - return control_flow.cond( - x.shape[0] > 2, true_fn, false_fn, [x, y] - ) + return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y]) f = Foo() @@ -87,9 +86,7 @@ def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x - y - return control_flow.cond( - x.shape[0] > 2, true_fn, false_fn, [x, y] - ) + return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y]) f = Foo() @@ -118,7 +115,9 @@ def test_ep_verifier_invalid_param(self) -> None: class M(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.register_parameter(name="a", param=torch.nn.Parameter(torch.randn(100))) + self.register_parameter( + name="a", param=torch.nn.Parameter(torch.randn(100)) + ) def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y + self.a @@ -127,9 +126,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # Parameter doesn't exist in the state dict ep.graph_signature.input_specs[0] = InputSpec( - kind=InputKind.PARAMETER, - arg=TensorArgument(name="p_a"), - target="bad_param" + kind=InputKind.PARAMETER, arg=TensorArgument(name="p_a"), target="bad_param" ) with self.assertRaisesRegex(SpecViolationError, "not in the state dict"): ep._validate() @@ -220,5 +217,5 @@ def forward(self, x1, x2): ep._validate() -if __name__ == '__main__': +if __name__ == "__main__": run_tests()