
Commit

fix bug
huniu20 committed Apr 14, 2024
1 parent 4c08778 commit b3d0850
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions test/test_functionalization.py
@@ -551,8 +551,7 @@ def f(x):
 def forward(self, arg0_1):
-    _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(
-        arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0)
+    _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0)
     getitem = _fused_moving_avg_obs_fq_helper_functional[0]
     getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1]
     getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2]
@@ -1690,8 +1689,7 @@ def forward(self, arg0_1):
     view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]); add = None
     as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1])
     view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None
-    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(
-        view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None
+    as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None
     view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None
     view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4])
     as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None
@@ -1946,8 +1944,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
     repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
     view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None
     empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
-    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(
-        view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None
+    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None
     getitem = _native_batch_norm_legit_functional[0]
     getitem_1 = _native_batch_norm_legit_functional[1]
     getitem_2 = _native_batch_norm_legit_functional[2]
@@ -1993,8 +1990,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
     repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20])
     view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None
     empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
-    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(
-        view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None
+    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None
     getitem = _native_batch_norm_legit_functional[0]
     getitem_1 = _native_batch_norm_legit_functional[1]
     getitem_2 = _native_batch_norm_legit_functional[2]
@@ -2057,8 +2053,7 @@ def f(x, running_mean, running_var):
 def forward(self, arg0_1, arg1_1, arg2_1):
     empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
-    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(
-        arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None
+    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None
     getitem = _native_batch_norm_legit_functional[0]
     getitem_1 = _native_batch_norm_legit_functional[1]
     getitem_2 = _native_batch_norm_legit_functional[2]
@@ -2086,8 +2081,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
 def forward(self, arg0_1, arg1_1, arg2_1):
     empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
-    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(
-        arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None
+    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, True, 0.1, 1e-05); arg0_1 = None
     getitem = _native_batch_norm_legit_functional[0]
     getitem_1 = _native_batch_norm_legit_functional[1]
     getitem_2 = _native_batch_norm_legit_functional[2]
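For context: the expected strings in this test file are compared verbatim against the printed FX graph by assertExpectedInline, so their formatting, including line wrapping, must match the single-line statements the graph printer emits. Below is a minimal sketch of that exact-match behavior; the test class and graph text are hypothetical stand-ins, not taken from this commit.

# Minimal sketch (not from this commit): torch.testing's TestCase exposes
# assertExpectedInline, which compares strings character for character, so
# a wrapped call in the expected string fails against a single-line printout.
from torch.testing._internal.common_utils import TestCase, run_tests

class ExpectedInlineSketch(TestCase):
    def test_exact_match(self):
        # Hypothetical stand-in for a printed FX graph; the real tests pass
        # the functionalized graph's printout here.
        actual = ("def forward(self, arg0_1):\n"
                  "    getitem = arg0_1[0]\n"
                  "    return getitem")
        self.assertExpectedInline(actual, """def forward(self, arg0_1):
    getitem = arg0_1[0]
    return getitem""")

if __name__ == "__main__":
    run_tests()

Re-running such a test with EXPECTTEST_ACCEPT=1 rewrites the inline expected strings to whatever the code actually produces, which is the usual way these expected graphs are regenerated.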
