From b3d0850b418d0476dbe6228f5986614a398ef1f8 Mon Sep 17 00:00:00 2001 From: Hu Niu Date: Sun, 14 Apr 2024 16:24:28 +0800 Subject: [PATCH] fix bug --- test/test_functionalization.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/test/test_functionalization.py b/test/test_functionalization.py index 0fe7f5f6fa3b..3c1f3c823f28 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -551,8 +551,7 @@ def f(x): def forward(self, arg0_1): - _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default( - arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0) + _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0) getitem = _fused_moving_avg_obs_fq_helper_functional[0] getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1] getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2] @@ -1690,8 +1689,7 @@ def forward(self, arg0_1): view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]); add = None as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1]) view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None - as_strided_scatter = torch.ops.aten.as_strided_scatter.default( - view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None + as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]) as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None @@ -1946,8 +1944,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( - view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2] @@ -1993,8 +1990,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( - view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2] @@ -2057,8 +2053,7 @@ def f(x, running_mean, running_var): def forward(self, arg0_1, arg1_1, arg2_1): empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( - arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2] @@ -2086,8 +2081,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): def forward(self, arg0_1, arg1_1, arg2_1): empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default( - arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2]