From 8e78044003422f5f8f193eb32155bda75572ae83 Mon Sep 17 00:00:00 2001
From: Hu Niu
Date: Fri, 12 Apr 2024 17:36:17 +0800
Subject: [PATCH 1/4] 1. Enable EFMT on test/test_functionalization.py

---
 .lintrunner.toml               |   1 -
 test/test_functionalization.py | 845 ++++++++++++++++++++++++---------
 2 files changed, 617 insertions(+), 229 deletions(-)

diff --git a/.lintrunner.toml b/.lintrunner.toml
index e3349ac71de38..7ec16f83a8372 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -1320,7 +1320,6 @@ exclude_patterns = [
     'test/test_function_schema.py',
     'test/test_functional_autograd_benchmark.py',
     'test/test_functional_optim.py',
-    'test/test_functionalization.py',
     'test/test_functionalization_of_rng_ops.py',
     'test/test_futures.py',
     'test/test_fx.py',
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index ac2443823f2ff..e79ab0910ef12 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -1,31 +1,46 @@
 # Owner(s): ["module: codegen"]
 
-import torch
+import unittest
 from contextlib import nullcontext
-from torch.testing._internal.common_utils import (
-    TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, IS_WINDOWS,
-    xfail_inherited_tests
+
+import torch
+from torch._dispatch.python import (
+    enable_crossref_functionalize,
+    enable_python_dispatcher,
+)
+from torch._subclasses.functional_tensor import (
+    dispatch_functionalize,
+    FunctionalTensor,
+    FunctionalTensorMode,
 )
-from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode, dispatch_functionalize
-from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs
-from torch.utils._pytree import tree_map_only
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.passes.reinplace import reinplace
-from torch._dispatch.python import enable_crossref_functionalize, enable_python_dispatcher
 from torch.multiprocessing.reductions import StorageWeakRef
+from torch.testing._internal.common_utils import (
+    IS_WINDOWS,
+    run_tests,
+    skipIfTorchDynamo,
+    TEST_WITH_TORCHDYNAMO,
+    TestCase,
+    xfail_inherited_tests,
+)
+from torch.testing._internal.logging_tensor import capture_logs, LoggingTensor
 from torch.utils import _pytree as pytree
+from torch.utils._pytree import tree_map_only
 
-import unittest
 
 def are_aliased(x, y):
     x_storage = StorageWeakRef(x.storage())
     y_storage = StorageWeakRef(y.storage())
     return x_storage == y_storage
 
+
 # We can unify testing and use functionalize() here instead
 # if/when functorch moves into core.
 # This is basically a crappy version of `functionalize()`.
-def _functionalize(f, *, reapply_views: bool, crossref: bool, skip_input_mutations: bool = False):
+def _functionalize(
+    f, *, reapply_views: bool, crossref: bool, skip_input_mutations: bool = False
+):
     def to_fun(t: torch.Tensor):
         func_t = torch._to_functional_tensor(t)
         func_t.requires_grad = t.requires_grad
@@ -54,34 +69,47 @@ def wrapped(*inputs):
                     if inpt_new.shape == inpt.shape:
                         inpt.copy_(inpt_new)
             tree_map_only(torch.Tensor, torch._sync, out)
-            out_unwrapped = tree_map_only(torch.Tensor, torch._from_functional_tensor, out)
+            out_unwrapped = tree_map_only(
+                torch.Tensor, torch._from_functional_tensor, out
+            )
             return out_unwrapped
 
     return wrapped
 
-@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457")
+
+@unittest.skipIf(
+    TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457"
+)
 class TestFunctionalization(TestCase):
 
     crossref = False
 
     def get_logs(self, func, *inpts, reapply_views=False, run_reinplace=False):
         inpts_clone = tree_map_only(torch.Tensor, torch.clone, inpts)
-        traced_f = make_fx(_functionalize(func, reapply_views=reapply_views, crossref=self.crossref))(*inpts)
+        traced_f = make_fx(
+            _functionalize(func, reapply_views=reapply_views, crossref=self.crossref)
+        )(*inpts)
         if run_reinplace:
             traced_f = reinplace(traced_f, *inpts_clone)
         return traced_f.code
 
-    def assert_functionalization(self, func, *inpts, reapply_views=False, mutated_input_metadata=False):
+    def assert_functionalization(
+        self, func, *inpts, reapply_views=False, mutated_input_metadata=False
+    ):
         clones1 = tree_map_only(torch.Tensor, torch.clone, inpts)
         clones2 = tree_map_only(torch.Tensor, torch.clone, inpts)
         clones3 = tree_map_only(torch.Tensor, torch.clone, inpts)
 
         # Compare outputs (and mutated inputs), with and without functionalization.
         out_ref = func(*inpts)
-        out_functional = _functionalize(func, reapply_views=reapply_views, crossref=self.crossref)(*clones1)
+        out_functional = _functionalize(
+            func, reapply_views=reapply_views, crossref=self.crossref
+        )(*clones1)
 
         # The reinplacing pass is only valid to run with reapply_views=True.
- functional_func = make_fx(_functionalize(func, reapply_views=True, crossref=self.crossref))(*clones2) + functional_func = make_fx( + _functionalize(func, reapply_views=True, crossref=self.crossref) + )(*clones2) reinplace_func = reinplace(functional_func, *clones2) # NOTE: for now, need to pass in fresh inputs here, because make_fx @@ -95,22 +123,38 @@ def assert_functionalization(self, func, *inpts, reapply_views=False, mutated_in flat_inpts = pytree.tree_leaves(inpts) flat_clones1 = pytree.tree_leaves(clones1) flat_clones3 = pytree.tree_leaves(clones3) - for inpt, input_clone, input_clone3 in zip(flat_inpts, flat_clones1, flat_clones3): - self.assertEqual(inpt, input_clone) # input mutations should still occur + for inpt, input_clone, input_clone3 in zip( + flat_inpts, flat_clones1, flat_clones3 + ): + self.assertEqual( + inpt, input_clone + ) # input mutations should still occur self.assertEqual(inpt, input_clone3) # Handle tests with multi-tensor outputs if isinstance(out_ref, tuple): - out_refs, out_functionals, out_reinplaces = list(out_ref), list(out_functional), list(out_reinplace) + out_refs, out_functionals, out_reinplaces = ( + list(out_ref), + list(out_functional), + list(out_reinplace), + ) else: - out_refs, out_functionals, out_reinplaces = [out_ref], [out_functional], [out_reinplace] + out_refs, out_functionals, out_reinplaces = ( + [out_ref], + [out_functional], + [out_reinplace], + ) - for out_ref_, out_functional_, out_reinplace_ in zip(out_refs, out_functionals, out_reinplaces): + for out_ref_, out_functional_, out_reinplace_ in zip( + out_refs, out_functionals, out_reinplaces + ): self.assertEqual(out_ref_, out_functional_) self.assertEqual(out_ref_, out_reinplace_) def test_save_for_backwards_segfault(self): - inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True) + inp = torch._to_functional_tensor( + LoggingTensor(torch.randn(2, 2)) + ).requires_grad_(True) inp.exp() def test_multiple_views_of_same_base(self): @@ -123,6 +167,7 @@ def f(x): # z should have been updated too. z2 = z + 1 return z2 + self.assert_functionalization(f, torch.ones(4)) def test_freeze(self): @@ -143,7 +188,9 @@ def f(x): y.copy_(x) return y - r = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2)) + r = _functionalize(f, reapply_views=True, crossref=self.crossref)( + torch.ones(2, 2) + ) self.assertEqual(r.stride(), (5, 1)) def test_set_(self): @@ -155,7 +202,7 @@ def f(x): # We should probaby get the crossref test to work, # but fixing it for Storage() objects is annoying. 
r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2)) - self.assertEqual(str(r.device), 'cpu') + self.assertEqual(str(r.device), "cpu") def test_advanced_indexing(self): def f(): @@ -178,8 +225,11 @@ def f(input): def g(x): loss = f(x).sum() - from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks import torch.fx.traceback as fx_traceback + from torch._functorch.aot_autograd import ( + setup_stacktrace_preservation_hooks, + ) + setup_stacktrace_preservation_hooks([loss.grad_fn]) with fx_traceback.preserve_node_meta(): loss.backward() @@ -187,7 +237,9 @@ def g(x): with torch.autograd.detect_anomaly(check_nan=False): logs = self.get_logs(g, torch.ones(16, 64, 128, 128, requires_grad=True)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -217,7 +269,8 @@ def forward(self, arg0_1): view_copy_11 = torch.ops.aten.view_copy.default(view_copy_8, [16, 64, 128, 128]); view_copy_8 = None detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_11); view_copy_11 = None return detach_copy_1 - """) # noqa: B950 + """, + ) # noqa: B950 def test_simple(self): def f(x): @@ -227,9 +280,12 @@ def f(x): y.add_(tmp) z = x * x return y + self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -242,10 +298,15 @@ def forward(self, arg0_1): mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1) copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None return view_copy_2 - """) + """, + ) - reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(4, 2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -258,7 +319,8 @@ def forward(self, arg0_1): mul = torch.ops.aten.mul.Tensor(view_1, view_1) copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None return view_2 - """) + """, + ) def test_simple_out(self): def f(x): @@ -269,9 +331,12 @@ def f(x): torch.add(y, tmp, out=z) w = z * z return w + self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -282,10 +347,15 @@ def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None mul = torch.ops.aten.mul.Tensor(add, add); add = None return mul - """) + """, + ) - reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(4, 2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -296,7 +366,8 @@ def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(view, ones); view = ones = None mul = torch.ops.aten.mul.Tensor(add, add); add = None return mul - """) + """, + ) def test_multi_out(self): def f(x): @@ -306,9 +377,12 @@ def f(x): out_max = torch.empty(4) torch.aminmax(x, dim=0, out=(out_max, out_min)) return out_max + self.assert_functionalization(f, torch.arange(8, dtype=torch.float32)) logs = self.get_logs(f, torch.arange(8, dtype=torch.float32)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -319,10 +393,18 @@ def forward(self, arg0_1): getitem = 
aminmax[0] getitem_1 = aminmax[1]; aminmax = None return getitem - """) + """, + ) - reinplaced_logs = self.get_logs(f, torch.arange(8, dtype=torch.float32), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, + torch.arange(8, dtype=torch.float32), + reapply_views=True, + run_reinplace=True, + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -333,7 +415,8 @@ def forward(self, arg0_1): getitem = aminmax[0] getitem_1 = aminmax[1]; aminmax = None return getitem - """) + """, + ) def test_tensor_ctr(self): def f(x): @@ -346,7 +429,9 @@ def f(x): self.assert_functionalization(f, inpt) logs = self.get_logs(f, inpt) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -358,10 +443,13 @@ def forward(self, arg0_1): view_copy_1 = torch.ops.aten.view_copy.default(add, [3]); add = None view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [-1]) return view_copy_1 - """) + """, + ) reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -373,7 +461,8 @@ def forward(self, arg0_1): view_1 = torch.ops.aten.view.default(view, [3]); view = None view_2 = torch.ops.aten.view.default(view_1, [-1]) return view_1 - """) + """, + ) def test_advanced_indexing_correct_strides(self): def f(a): @@ -383,6 +472,7 @@ def f(a): c = torch.ones_like(b, dtype=torch.bool) d = b.masked_fill_(c, 0) return d + self.assert_functionalization(f, torch.ones(2, 2), reapply_views=True) def test_tensor_list_mixed_functional_nonfunctional(self): @@ -393,8 +483,11 @@ def f(x): functional_tensor = torch.ones(2, dtype=torch.long) out = x[functional_tensor, nonfunctional_tensor] return out + out = f(torch.ones(2, 2)) - out_functional = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2)) + out_functional = _functionalize(f, reapply_views=True, crossref=self.crossref)( + torch.ones(2, 2) + ) self.assertEqual(out, out_functional) def test_inplace_on_non_view(self): @@ -405,9 +498,12 @@ def f(x): y = x.view(4, 2) x.add_(tmp) return y + self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -418,10 +514,15 @@ def forward(self, arg0_1): copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = None view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None return view_copy_1 - """) + """, + ) - reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(4, 2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -432,16 +533,21 @@ def forward(self, arg0_1): copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = None view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None return view_1 - """) + """, + ) # Some ops that are mutable are neither inplace nor out= ops. # They also need special handling. 
def test_mutable_op_not_inplace_or_other(self): def f(x): - return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0) + return torch._fused_moving_avg_obs_fq_helper( + x, x, x, x, x, x, x, 1.0, 0, 1, 0 + ) logs = self.get_logs(f, torch.ones(1)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -455,16 +561,20 @@ def forward(self, arg0_1): getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5); arg0_1 = getitem_5 = None return (getitem, getitem_1) - """) # noqa: B950 + """, + ) # noqa: B950 def test_as_strided(self): def f(x): y = x.as_strided((2,), (2,), 1) y.add_(1) return x + self.assert_functionalization(f, torch.ones(9)) logs = self.get_logs(f, torch.ones(9)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -475,11 +585,16 @@ def forward(self, arg0_1): as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1) copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None return as_strided_scatter - """) + """, + ) # NB: even with reapply_views=True, we expect to see scatter op - reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=False) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(2, 2), reapply_views=True, run_reinplace=False + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -490,32 +605,40 @@ def forward(self, arg0_1): as_strided_1 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [2], 1) copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None return as_strided_scatter - """) + """, + ) def test_tensor_list_composite(self): def f(x): # Test an op with TensorList input y = torch.block_diag(x, x) return y + self.assert_functionalization(f, torch.ones(2, 2)) logs = self.get_logs(f, torch.ones(2, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ def forward(self, arg0_1): block_diag = torch.ops.aten.block_diag.default([arg0_1, arg0_1]); arg0_1 = None return block_diag - """) + """, + ) def test_cat(self): def f(x): out = torch.empty(0) torch.cat((x,), out=out) return out + self.assert_functionalization(f, torch.ones(2, 2)) logs = self.get_logs(f, torch.ones(2, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -523,10 +646,15 @@ def forward(self, arg0_1): empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False) cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None return cat - """) + """, + ) - reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(2, 2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -534,8 +662,8 @@ def forward(self, arg0_1): empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False) cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None return cat - """) - + """, + ) def test_diagonal(self): def f(x): @@ -545,9 +673,12 @@ def f(x): y.add_(tmp) z = x * x return z + self.assert_functionalization(f, torch.ones(2, 2)) logs = self.get_logs(f, torch.ones(2, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + 
"""\ @@ -560,10 +691,15 @@ def forward(self, arg0_1): diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter); diagonal_scatter = None mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None return mul - """) + """, + ) - reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(2, 2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -575,7 +711,8 @@ def forward(self, arg0_1): diagonal_1 = torch.ops.aten.diagonal.default(clone); clone = None mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None return mul - """) + """, + ) def test_diagonal_mutated_input(self): def f(x): @@ -584,10 +721,13 @@ def f(x): y = x.diagonal() y.add_(tmp) return x + x = torch.ones(2, 2) self.assert_functionalization(f, x) logs = self.get_logs(f, torch.ones(2, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -599,11 +739,16 @@ def forward(self, arg0_1): diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None return diagonal_scatter - """) + """, + ) # NB: even with reapply_views=True, we expect to see scatter op - reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=False) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(2, 2), reapply_views=True, run_reinplace=False + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -615,7 +760,8 @@ def forward(self, arg0_1): diagonal_1 = torch.ops.aten.diagonal.default(diagonal_scatter) copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None return diagonal_scatter - """) + """, + ) def test_channels_last_contiguous(self): def f(x): @@ -624,13 +770,17 @@ def f(x): y = x.diagonal() y.add_(tmp) return x + x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2) self.assert_functionalization(f, x) logs = self.get_logs(f, x).strip() # There should be no clone in the graph - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ def forward(self, arg0_1): - return arg0_1""") + return arg0_1""", + ) def test_split(self): def f(x): @@ -641,9 +791,12 @@ def f(x): y3.add_(tmp) z = x * x return y3 + self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -666,11 +819,16 @@ def forward(self, arg0_1): mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None return diagonal_copy_1 - """) # noqa: B950 + """, + ) # noqa: B950 # NB: even with reapply_views=True, we expect to see scatter op - reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(4, 2), reapply_views=True, run_reinplace=False + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -693,7 +851,8 @@ def forward(self, arg0_1): mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None return diagonal_1 - """) # noqa: B950 + """, + ) # noqa: B950 def test_split_with_sizes(self): def f(x): @@ -704,9 
+863,12 @@ def f(x): y3.add_(tmp) z = x * x return y3 + self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -729,11 +891,16 @@ def forward(self, arg0_1): mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None return diagonal_copy_1 - """) # noqa: B950 + """, + ) # noqa: B950 # NB: even with reapply_views=True, we expect to see scatter op - reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(4, 2), reapply_views=True, run_reinplace=False + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -756,7 +923,8 @@ def forward(self, arg0_1): mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None return diagonal_1 - """) # noqa: B950 + """, + ) # noqa: B950 def test_slice(self): def f(x): @@ -765,9 +933,12 @@ def f(x): y = x[0:2] y.add_(tmp) return x + self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -783,11 +954,16 @@ def forward(self, arg0_1): slice_copy_1 = torch.ops.aten.slice_copy.Tensor(transpose_copy_3, 0, 0, 2); transpose_copy_3 = None transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None return transpose_copy_4 - """) # noqa: B950 + """, + ) # noqa: B950 # NB: even with reapply_views=True, we expect to see scatter op - reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(4, 2), reapply_views=True, run_reinplace=False + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -803,7 +979,8 @@ def forward(self, arg0_1): slice_2 = torch.ops.aten.slice.Tensor(transpose_3, 0, 0, 2); transpose_3 = None transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None return transpose_4 - """) # noqa: B950 + """, + ) # noqa: B950 def test_view_inplace(self): def f(x): @@ -813,9 +990,12 @@ def f(x): y = x[0] y.add_(tmp) return x + self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -831,11 +1011,16 @@ def forward(self, arg0_1): select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0); transpose_copy_3 = None transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None return transpose_copy_4 - """) # noqa: B950 + """, + ) # noqa: B950 # NB: even with reapply_views=True, we expect to see scatter op - reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(4, 2), reapply_views=True, run_reinplace=False + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -851,7 +1036,8 @@ def forward(self, arg0_1): select_1 = torch.ops.aten.select.int(transpose_3, 0, 0); transpose_3 = None transpose_4 = 
torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None return transpose_4 - """) # noqa: B950 + """, + ) # noqa: B950 def test_unbind(self): def f(x): @@ -861,9 +1047,12 @@ def f(x): y, _ = x.unbind(0) y.add_(tmp) return x + self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -883,11 +1072,16 @@ def forward(self, arg0_1): getitem_3 = unbind_copy_1[1]; unbind_copy_1 = None transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None return transpose_copy_4 - """) # noqa: B950 + """, + ) # noqa: B950 # NB: even with reapply_views=True, we expect to see scatter op - reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(4, 2), reapply_views=True, run_reinplace=False + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -907,7 +1101,8 @@ def forward(self, arg0_1): getitem_3 = unbind_1[1]; unbind_1 = None transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None return transpose_4 - """) # noqa: B950 + """, + ) # noqa: B950 def test_optional_tensor_list(self): def f(x): @@ -918,9 +1113,12 @@ def f(x): values = torch.arange(4, dtype=y.dtype) y.index_put_((indices,), values, accumulate=False) return y + self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -933,7 +1131,8 @@ def forward(self, arg0_1): view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [8]) copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None return view_copy_2 - """) # noqa: B950 + """, + ) # noqa: B950 def test_scalars(self): def f(x): @@ -944,9 +1143,12 @@ def f(x): z = 2 * y z.div_(1) return z + self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -960,7 +1162,8 @@ def forward(self, arg0_1): div = torch.ops.aten.div.Tensor(mul, 1); mul = None copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None return div - """) + """, + ) @skipIfTorchDynamo("Test does not work with TorchDynamo") def test_metadata_change(self): @@ -970,9 +1173,12 @@ def f(x): y = x.clone() out = y.ge_(0) return out + self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -981,10 +1187,15 @@ def forward(self, arg0_1): ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None return _to_copy - """) + """, + ) - reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(2, 2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -993,7 +1204,8 @@ def forward(self, arg0_1): ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None return _to_copy - """) # noqa: B950 + 
""", + ) # noqa: B950 @skipIfTorchDynamo("Test does not work with TorchDynamo") def test_metadata_change_out_op(self): @@ -1002,7 +1214,9 @@ def f(t, y): return torch.add(t, y, out=out_1) inpt1, inpt2 = torch.tensor([1]), torch.tensor([1]) - inpt1_func, inpt2_func = torch._to_functional_tensor(inpt1), torch._to_functional_tensor(inpt2) + inpt1_func, inpt2_func = torch._to_functional_tensor( + inpt1 + ), torch._to_functional_tensor(inpt2) out_ref = f(inpt1, inpt2) torch._enable_functionalization(reapply_views=True) @@ -1012,22 +1226,25 @@ def f(t, y): torch._disable_functionalization() self.assertEqual(out_ref, torch._from_functional_tensor(out_functional)) - def test_only_one_view(self): def f(x): # This tests that we don't have any unnecessary views in the trace. # If the input wasn't mutated, we don't need to regenerate it, # so there should be a total of 1 op in the output trace. return x.view(4, 2) + logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ def forward(self, arg0_1): view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None return view_copy - """) + """, + ) def test_everything(self): def f(x): @@ -1043,9 +1260,12 @@ def f(x): z2.add_(tmp) z4 = z0[0] + z2.reshape(4) return z2 + self.assert_functionalization(f, torch.ones(4, 2)) logs = self.get_logs(f, torch.ones(4, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -1096,10 +1316,15 @@ def forward(self, arg0_1): view_copy_13 = torch.ops.aten.view_copy.default(getitem_4, [4]); getitem_4 = None add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13); select_copy_1 = view_copy_13 = None return getitem_2 - """) # noqa: B950 + """, + ) # noqa: B950 - reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(4, 2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -1141,7 +1366,8 @@ def forward(self, arg0_1): select_1 = torch.ops.aten.select.int(view_9, 0, 0); view_9 = None add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None return getitem_2 - """) + """, + ) def test_reapply_views_simple(self): def f(x): @@ -1150,9 +1376,12 @@ def f(x): y.add_(tmp) z = x * x return y + self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True) logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -1165,7 +1394,8 @@ def forward(self, arg0_1): mul = torch.ops.aten.mul.Tensor(view_1, view_1) copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None return view_2 - """) + """, + ) def test_aliases_maintained_after_pass_when_reapplying_views(self): def f(x): @@ -1203,7 +1433,9 @@ def f(x): # to() is a composite op that noops when the dtype/shape match, so nothing gets logged. 
# self.assert_functionalization(f, torch.ones(2)) logs = self.get_logs(f, torch.ones(2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -1217,10 +1449,15 @@ def forward(self, arg0_1): diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None return diagonal_copy_2 - """) + """, + ) - reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -1232,12 +1469,15 @@ def forward(self, arg0_1): add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None return diagonal_2 - """) + """, + ) # Test 2: copy_() with same dtype, different shape self.assert_functionalization(f, torch.ones(1)) logs = self.get_logs(f, torch.ones(1)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -1251,10 +1491,15 @@ def forward(self, arg0_1): diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None return diagonal_copy_2 - """) + """, + ) - reinplaced_logs = self.get_logs(f, torch.ones(1), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(1), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -1266,12 +1511,15 @@ def forward(self, arg0_1): add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None return diagonal_2 - """) + """, + ) # Test 3: copy_() with different dtype, same shape self.assert_functionalization(f, torch.ones(2, dtype=torch.long)) logs = self.get_logs(f, torch.ones(2, dtype=torch.long)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -1285,10 +1533,15 @@ def forward(self, arg0_1): diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None return diagonal_copy_2 - """) # noqa: B950 + """, + ) # noqa: B950 - reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -1300,12 +1553,15 @@ def forward(self, arg0_1): add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None return diagonal_2 - """) # noqa: B950 + """, + ) # noqa: B950 # Test 4: copy_() with different dtype, different shape self.assert_functionalization(f, torch.ones(1, dtype=torch.long)) logs = self.get_logs(f, torch.ones(1, dtype=torch.long)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -1319,10 +1575,15 @@ def 
forward(self, arg0_1): diagonal_scatter_1 = torch.ops.aten.diagonal_scatter.default(diagonal_scatter, add); diagonal_scatter = add = None diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None return diagonal_copy_2 - """) # noqa: B950 + """, + ) # noqa: B950 - reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -1334,7 +1595,8 @@ def forward(self, arg0_1): add = torch.ops.aten.add_.Tensor(diagonal_1, arg0_1); diagonal_1 = arg0_1 = None diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None return diagonal_2 - """) # noqa: B950 + """, + ) # noqa: B950 def test_expand_symint(self): # Once some existing SymInt bugs are ironed out, we should update @@ -1344,14 +1606,17 @@ def f(x): self.assert_functionalization(f, torch.ones(2, 2)) logs = self.get_logs(f, torch.ones(2, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ def forward(self, arg0_1): expand_copy = torch.ops.aten.expand_copy.default(arg0_1, [2, 2]); arg0_1 = None return expand_copy - """) + """, + ) def test_fill_(self): def f(x): @@ -1362,7 +1627,9 @@ def f(x): self.assert_functionalization(f, torch.ones(2, 2)) logs = self.get_logs(f, torch.ones(2, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -1373,10 +1640,15 @@ def forward(self, arg0_1): diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill); add = fill = None diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) return diagonal_scatter - """) + """, + ) - reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(2, 2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -1386,7 +1658,8 @@ def forward(self, arg0_1): fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = None diagonal_1 = torch.ops.aten.diagonal.default(add) return add - """) + """, + ) def test_resize_smaller(self): def f(w): @@ -1401,7 +1674,9 @@ def f(w): self.assert_functionalization(f, torch.ones(8, 2)) logs = self.get_logs(f, torch.ones(8, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -1424,10 +1699,15 @@ def forward(self, arg0_1): as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]); view_copy_7 = None add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1); as_strided_copy_3 = None return add_2 - """) # noqa: B950 + """, + ) # noqa: B950 - reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(8, 2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -1449,7 +1729,8 @@ def forward(self, arg0_1): as_strided_3 = torch.ops.aten.as_strided.default(view_7, [3, 3], [3, 1]); view_7 = None add_2 = torch.ops.aten.add_.Tensor(as_strided_3, 1) return as_strided_3 - """) + """, + ) def test_resize_same_size_diff_rank(self): def f(x): @@ -1478,7 +1759,9 @@ def f(x): self.assert_functionalization(f, torch.ones(8, 2)) logs = 
self.get_logs(f, torch.ones(8, 2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -1491,10 +1774,15 @@ def forward(self, arg0_1): view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [25]) add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1) return (view_copy_1, add_1) - """) + """, + ) - reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(8, 2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -1507,7 +1795,8 @@ def forward(self, arg0_1): view_2 = torch.ops.aten.view.default(view_1, [25]) add_1 = torch.ops.aten.add.Tensor(view_1, 1) return (view_1, add_1) - """) + """, + ) def test_resize_larger_invalid(self): def f(x): @@ -1524,8 +1813,9 @@ def f(x): return y, out with self.assertRaisesRegex( - RuntimeError, - r'Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass'): + RuntimeError, + r"Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass", + ): self.assert_functionalization(f, torch.ones(8, 2)) def test_nested_functions_propagate_updates(self): @@ -1558,9 +1848,12 @@ def f(x, y): # Make sure that functionalization ran the "+" kernel # with a functional + non-functional tensor, and wrapped the output appropriately. - self.assertExpectedInline('\n'.join(logs), """\ + self.assertExpectedInline( + "\n".join(logs), + """\ $2: f32[4] = torch._ops.aten.add.Tensor($0, $1) -$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""") +$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""", + ) def test_mixed_wrappers_invalid(self): x1_not_functional = torch.ones(4) @@ -1577,9 +1870,12 @@ def f(x): tmp = torch.zeros(10) tmp[5].fill_(1) return tmp + self.assert_functionalization(f, torch.ones(2)) logs = self.get_logs(f, torch.ones(2)) - self.assertExpectedInline(logs, """\ + self.assertExpectedInline( + logs, + """\ @@ -1590,10 +1886,15 @@ def forward(self, arg0_1): select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5); zeros = fill = None select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5) return select_scatter - """) # noqa: B950 + """, + ) # noqa: B950 - reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True) - self.assertExpectedInline(reinplaced_logs, """\ + reinplaced_logs = self.get_logs( + f, torch.ones(2), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -1603,23 +1904,39 @@ def forward(self, arg0_1): fill = torch.ops.aten.fill_.Scalar(select, 1); select = None select_1 = torch.ops.aten.select.int(zeros, 0, 5) return zeros - """) - + """, + ) def test_instance_norm(self): size = 100 def f(x, running_mean, running_var): with enable_python_dispatcher(): - return torch.instance_norm(x, None, None, running_mean, running_var, - use_input_stats=True, momentum=0.1, eps=1e-5, cudnn_enabled=False) - self.assert_functionalization(f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size)) + return torch.instance_norm( + x, + None, + None, + running_mean, + running_var, + use_input_stats=True, + momentum=0.1, + eps=1e-5, + cudnn_enabled=False, + ) + + self.assert_functionalization( + f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size) + ) # On Windows, for instance_norm, the alias_copy's are reordered to come right before they need to 
be used # whereas on other platforms, the alias_copy's are before the view_copy's. # e.g., the alias_copy after the getitem_4 assignment would be moved to be right before the copy assignment. if not IS_WINDOWS: - logs = self.get_logs(f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size)) - self.assertExpectedInline(logs, """\ + logs = self.get_logs( + f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size) + ) + self.assertExpectedInline( + logs, + """\ @@ -1652,13 +1969,20 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1); arg1_1 = alias_copy_1 = None copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4); arg2_1 = alias_copy_4 = None return view_copy_5 - """) # noqa: B950 + """, + ) # noqa: B950 reinplaced_logs = self.get_logs( - f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size), - reapply_views=True, run_reinplace=True + f, + torch.randn(20, size, 35, 45), + torch.zeros(size), + torch.ones(size), + reapply_views=True, + run_reinplace=True, ) - self.assertExpectedInline(reinplaced_logs, """\ + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -1691,8 +2015,8 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1); arg1_1 = alias_1 = None copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4); arg2_1 = alias_4 = None return view_5 - """) # noqa: B950 - + """, + ) # noqa: B950 def test_mutation_overlapping_mem(self): def fn(x): @@ -1702,19 +2026,29 @@ def fn(x): t3 = t2.abs_() return t3 - with self.assertRaisesRegex(RuntimeError, r'encountered a tensor being mutated that has internal overlap'): + with self.assertRaisesRegex( + RuntimeError, + r"encountered a tensor being mutated that has internal overlap", + ): x = torch.ones(1, 5) out = _functionalize(fn, reapply_views=True, crossref=False)(x) - def test_batch_norm(self): def f(x, running_mean, running_var): with enable_python_dispatcher(): - return torch.batch_norm(x, None, None, running_mean, running_var, True, 0.1, 1e-5, False) + return torch.batch_norm( + x, None, None, running_mean, running_var, True, 0.1, 1e-5, False + ) - self.assert_functionalization(f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100)) - logs = self.get_logs(f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100)) - self.assertExpectedInline(logs, """\ + self.assert_functionalization( + f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100) + ) + logs = self.get_logs( + f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100) + ) + self.assertExpectedInline( + logs, + """\ @@ -1729,12 +2063,20 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None return getitem - """) # noqa: B950 + """, + ) # noqa: B950 reinplaced_logs = self.get_logs( - f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100), reapply_views=True, run_reinplace=True + f, + torch.randn(20, 100, 35, 45), + torch.zeros(100), + torch.ones(100), + reapply_views=True, + run_reinplace=True, ) - self.assertExpectedInline(reinplaced_logs, """\ + self.assertExpectedInline( + reinplaced_logs, + """\ @@ -1749,7 +2091,8 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None return 
getitem - """) # noqa: B950 + """, + ) # noqa: B950 # This tests our python shims around C++ Functionalization: FunctionalTensor and FunctionalTensorMode def test_python_functionalization(self): @@ -1768,7 +2111,9 @@ def f_functionalized(x): # our FunctionalTensor will inherit the same keyset. # We don't have an easy way of directly mutating a tensor's keyset from python, # so globally disabling functionalization here is easier. - maybe_disable = torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)) + maybe_disable = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) with maybe_disable, FunctionalTensorMode(): x_wrapped = FunctionalTensor.to_functional(x) out_wrapped = f(x_wrapped) @@ -1781,14 +2126,17 @@ def f_functionalized(x): fx_g = make_fx(f_functionalized)(x) # NB: view_1 below is expected (though unused) due to view replay. AOTAutograd runs a # DCE pass that will remove nodes like this later on. - self.assertExpectedInline(fx_g.code.strip(), """\ + self.assertExpectedInline( + fx_g.code.strip(), + """\ def forward(self, x_1): view = torch.ops.aten.view.default(x_1, [-1]) mul = torch.ops.aten.mul.Tensor(x_1, 2); x_1 = None view_1 = torch.ops.aten.view.default(mul, [-1]) view_2 = torch.ops.aten.view.default(mul, [-1]); mul = None add = torch.ops.aten.add.Tensor(view_2, 1); view_2 = None - return add""") + return add""", + ) def test_python_functionalization_zero_tensor(self): def f(x): @@ -1796,14 +2144,21 @@ def f(x): out = x + y out.mul_(2) return out + x = torch.randn(4) out_ref = f(x) out_test = dispatch_functionalize(f)(x) - out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x) + out_test_cpp = _functionalize( + f, reapply_views=True, crossref=False, skip_input_mutations=True + )(x) self.assertEqual(out_ref, out_test) self.assertEqual(out_ref, out_test_cpp) fx_g = make_fx(dispatch_functionalize(f))(x) - fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x) + fx_g_cpp = make_fx( + _functionalize( + f, reapply_views=True, crossref=False, skip_input_mutations=True + ) + )(x) self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) def test_python_functionalization_is_conj(self): @@ -1834,7 +2189,6 @@ def f(x): self.assertEqual(out_ref[0], out_test_cpp[0]) self.assertEqual(out_ref[1], out_test_cpp[1]) - def test_python_functionalization_conj(self): def f(x): y = x.clone().conj() @@ -1844,12 +2198,20 @@ def f(x): x = torch.randn(4, dtype=torch.complex64) out_ref = f(x) out_test = dispatch_functionalize(f)(x) - out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x) + out_test_cpp = _functionalize( + f, reapply_views=True, crossref=False, skip_input_mutations=True + )(x) self.assertEqual(out_ref, out_test) self.assertEqual(out_test, out_test_cpp) fx_g = make_fx(dispatch_functionalize(f))(x) - fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x) - self.assertExpectedInline(fx_g.code.strip(), """\ + fx_g_cpp = make_fx( + _functionalize( + f, reapply_views=True, crossref=False, skip_input_mutations=True + ) + )(x) + self.assertExpectedInline( + fx_g.code.strip(), + """\ def forward(self, arg0_1): clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None _conj = torch.ops.aten._conj.default(clone); clone = None @@ -1861,7 +2223,8 @@ def forward(self, arg0_1): _conj_2 = torch.ops.aten._conj.default(_conj_1); 
_conj_1 = None clone_3 = torch.ops.aten.clone.default(_conj_2); _conj_2 = None view_as_real = torch.ops.aten.view_as_real.default(clone_3); clone_3 = None - return view_as_real""") + return view_as_real""", + ) self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) def test_python_functionalization_neg(self): @@ -1873,23 +2236,34 @@ def f(x): x = torch.randn(4) out_ref = f(x) out_test = dispatch_functionalize(f)(x) - out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x) + out_test_cpp = _functionalize( + f, reapply_views=True, crossref=False, skip_input_mutations=True + )(x) self.assertEqual(out_ref, out_test) self.assertEqual(out_ref, out_test_cpp) fx_g = make_fx(dispatch_functionalize(f))(x) - fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x) - self.assertExpectedInline(fx_g.code.strip(), """\ + fx_g_cpp = make_fx( + _functionalize( + f, reapply_views=True, crossref=False, skip_input_mutations=True + ) + )(x) + self.assertExpectedInline( + fx_g.code.strip(), + """\ def forward(self, arg0_1): _neg_view = torch.ops.aten._neg_view.default(arg0_1); arg0_1 = None clone = torch.ops.aten.clone.default(_neg_view); _neg_view = None add = torch.ops.aten.add.Tensor(clone, 1); clone = None - return add""") + return add""", + ) self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) def test_python_functionalization_lift_fresh_storage(self): unlifted = torch.tensor([0.0]) - maybe_disable = torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)) + maybe_disable = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) with maybe_disable, FunctionalTensorMode(): lifted = torch.ops.aten.lift_fresh.default(unlifted) @@ -1903,36 +2277,51 @@ def f(x): x = torch.randn(4) out_ref = f(x) out_test = dispatch_functionalize(f)(x) - out_test_cpp = _functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True)(x) + out_test_cpp = _functionalize( + f, reapply_views=True, crossref=False, skip_input_mutations=True + )(x) self.assertEqual(out_ref, out_test) self.assertEqual(out_ref, out_test_cpp) fx_g = make_fx(dispatch_functionalize(f))(x) - fx_g_cpp = make_fx(_functionalize(f, reapply_views=True, crossref=False, skip_input_mutations=True))(x) - self.assertExpectedInline(fx_g.code.strip(), """\ + fx_g_cpp = make_fx( + _functionalize( + f, reapply_views=True, crossref=False, skip_input_mutations=True + ) + )(x) + self.assertExpectedInline( + fx_g.code.strip(), + """\ def forward(self, arg0_1): _tensor_constant0 = self._tensor_constant0 lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None add = torch.ops.aten.add.Tensor(lift_fresh_copy, arg0_1); lift_fresh_copy = arg0_1 = None - return add""") + return add""", + ) self.assertEqual(fx_g_cpp.code.strip(), fx_g.code.strip()) -@xfail_inherited_tests([ - "test_as_strided", - "test_copy_", - "test_diagonal", - "test_diagonal_mutated_input", - "test_everything", - "test_fill_", - "test_slice", - "test_split", - "test_split_with_sizes", - "test_unbind", - "test_view_clone_view_inplace", - "test_view_inplace", -]) -@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesnt work well") + +@xfail_inherited_tests( + [ + "test_as_strided", + "test_copy_", + "test_diagonal", + "test_diagonal_mutated_input", + "test_everything", + "test_fill_", + "test_slice", + "test_split", + "test_split_with_sizes", + 
"test_unbind", + "test_view_clone_view_inplace", + "test_view_inplace", + ] +) +@unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesnt work well" +) class TestCrossRefFunctionalization(TestFunctionalization): crossref = True -if __name__ == '__main__': + +if __name__ == "__main__": run_tests() From 4c0877874e12ba3e2d12e9a1a607bf5ebf7ff362 Mon Sep 17 00:00:00 2001 From: Hu Niu Date: Sun, 14 Apr 2024 15:56:14 +0800 Subject: [PATCH 2/4] modify some lint issues --- test/test_functionalization.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/test/test_functionalization.py b/test/test_functionalization.py index e79ab0910ef12..0fe7f5f6fa3bc 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -81,7 +81,6 @@ def wrapped(*inputs): TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457" ) class TestFunctionalization(TestCase): - crossref = False def get_logs(self, func, *inpts, reapply_views=False, run_reinplace=False): @@ -552,7 +551,8 @@ def f(x): def forward(self, arg0_1): - _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0) + _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default( + arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0) getitem = _fused_moving_avg_obs_fq_helper_functional[0] getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1] getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2] @@ -1690,7 +1690,8 @@ def forward(self, arg0_1): view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]); add = None as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1]) view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None - as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None + as_strided_scatter = torch.ops.aten.as_strided_scatter.default( + view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]) as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None @@ -1945,7 +1946,8 @@ def forward(self, arg0_1, arg1_1, arg2_1): repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( + view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2] @@ -1991,7 +1993,8 @@ def forward(self, arg0_1, arg1_1, arg2_1): repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]); arg0_1 
= None empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( + view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2] @@ -2054,7 +2057,8 @@ def f(x, running_mean, running_var): def forward(self, arg0_1, arg1_1, arg2_1): empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( + arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2] @@ -2082,7 +2086,8 @@ def forward(self, arg0_1, arg1_1, arg2_1): def forward(self, arg0_1, arg1_1, arg2_1): empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( + arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2] From b3d0850b418d0476dbe6228f5986614a398ef1f8 Mon Sep 17 00:00:00 2001 From: Hu Niu Date: Sun, 14 Apr 2024 16:24:28 +0800 Subject: [PATCH 3/4] fix bug --- test/test_functionalization.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/test/test_functionalization.py b/test/test_functionalization.py index 0fe7f5f6fa3bc..3c1f3c823f284 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -551,8 +551,7 @@ def f(x): def forward(self, arg0_1): - _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default( - arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0) + _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0) getitem = _fused_moving_avg_obs_fq_helper_functional[0] getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1] getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2] @@ -1690,8 +1689,7 @@ def forward(self, arg0_1): view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]); add = None as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1]) view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None - as_strided_scatter = torch.ops.aten.as_strided_scatter.default( - view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = 
view_copy_3 = None + as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]) as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None @@ -1946,8 +1944,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( - view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2] @@ -1993,8 +1990,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( - view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2] @@ -2057,8 +2053,7 @@ def f(x, running_mean, running_var): def forward(self, arg0_1, arg1_1, arg2_1): empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( - arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2] @@ -2086,8 +2081,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): def forward(self, arg0_1, arg1_1, arg2_1): empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( - arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = 
_native_batch_norm_legit_functional[2] From d02183a698ba02c029f38b9c09151931c35a39ed Mon Sep 17 00:00:00 2001 From: hun Date: Fri, 26 Apr 2024 01:49:17 +0800 Subject: [PATCH 4/4] fix FLAKE issues --- test/test_functionalization.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/test_functionalization.py b/test/test_functionalization.py index 3c1f3c823f284..978b58b492c03 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -560,8 +560,8 @@ def forward(self, arg0_1): getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5); arg0_1 = getitem_5 = None return (getitem, getitem_1) - """, - ) # noqa: B950 + """, # noqa: B950 + ) def test_as_strided(self): def f(x): @@ -1698,8 +1698,8 @@ def forward(self, arg0_1): as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]); view_copy_7 = None add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1); as_strided_copy_3 = None return add_2 - """, - ) # noqa: B950 + """, # noqa: B950 + ) reinplaced_logs = self.get_logs( f, torch.ones(8, 2), reapply_views=True, run_reinplace=True @@ -1968,8 +1968,8 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1); arg1_1 = alias_copy_1 = None copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4); arg2_1 = alias_copy_4 = None return view_copy_5 - """, - ) # noqa: B950 + """, # noqa: B950 + ) reinplaced_logs = self.get_logs( f, @@ -2014,8 +2014,8 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1); arg1_1 = alias_1 = None copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4); arg2_1 = alias_4 = None return view_5 - """, - ) # noqa: B950 + """, # noqa: B950 + ) def test_mutation_overlapping_mem(self): def fn(x): @@ -2062,8 +2062,8 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None return getitem - """, - ) # noqa: B950 + """, # noqa: B950 + ) reinplaced_logs = self.get_logs( f, @@ -2090,8 +2090,8 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None return getitem - """, - ) # noqa: B950 + """, # noqa: B950 + ) # This tests our python shims around C++ Functionalization: FunctionalTensor and FunctionalTensorMode def test_python_functionalization(self):
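
Note on patches 3/4 and 4/4: the triple-quoted strings being re-wrapped in these hunks are the expected FX graph dumps that the test compares against generated code (via assertExpectedInline in this file), so they have to match that code exactly. Re-wrapping a long aten call inside one of them, as patch 2/4 did, presumably breaks that exact comparison; patch 3/4 reverts those wraps, and patch 4/4 instead silences the line-length lint by moving "# noqa: B950" from the call's closing parenthesis onto the line that closes the string argument. The sketch below illustrates only the resulting pattern; the test body and graph dump in it are hypothetical and not taken from the patch.

    def test_example_pattern(self):
        def f(x):
            return x.add(1)

        logs = self.get_logs(f, torch.ones(2))
        # The expected string must stay identical to the generated graph code,
        # so an over-long generated line is left unwrapped; the length-lint
        # suppression sits on the line that closes the string argument rather
        # than on the call's closing parenthesis.
        self.assertExpectedInline(
            logs,
            """\


def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
    return add
    """,  # noqa: B950
        )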