Skip to content

Commit

Permalink
fix FLAKE issues
Browse files Browse the repository at this point in the history
  • Loading branch information
huniu20 committed Apr 25, 2024
1 parent b3d0850 commit d02183a
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions test/test_functionalization.py
Expand Up @@ -560,8 +560,8 @@ def forward(self, arg0_1):
getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None
copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5); arg0_1 = getitem_5 = None
return (getitem, getitem_1)
""",
) # noqa: B950
""", # noqa: B950
)

def test_as_strided(self):
def f(x):
Expand Down Expand Up @@ -1698,8 +1698,8 @@ def forward(self, arg0_1):
as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]); view_copy_7 = None
add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1); as_strided_copy_3 = None
return add_2
""",
) # noqa: B950
""", # noqa: B950
)

reinplaced_logs = self.get_logs(
f, torch.ones(8, 2), reapply_views=True, run_reinplace=True
Expand Down Expand Up @@ -1968,8 +1968,8 @@ def forward(self, arg0_1, arg1_1, arg2_1):
copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1); arg1_1 = alias_copy_1 = None
copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4); arg2_1 = alias_copy_4 = None
return view_copy_5
""",
) # noqa: B950
""", # noqa: B950
)

reinplaced_logs = self.get_logs(
f,
Expand Down Expand Up @@ -2014,8 +2014,8 @@ def forward(self, arg0_1, arg1_1, arg2_1):
copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1); arg1_1 = alias_1 = None
copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4); arg2_1 = alias_4 = None
return view_5
""",
) # noqa: B950
""", # noqa: B950
)

def test_mutation_overlapping_mem(self):
def fn(x):
Expand Down Expand Up @@ -2062,8 +2062,8 @@ def forward(self, arg0_1, arg1_1, arg2_1):
copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None
copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None
return getitem
""",
) # noqa: B950
""", # noqa: B950
)

reinplaced_logs = self.get_logs(
f,
Expand All @@ -2090,8 +2090,8 @@ def forward(self, arg0_1, arg1_1, arg2_1):
copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None
copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None
return getitem
""",
) # noqa: B950
""", # noqa: B950
)

# This tests our python shims around C++ Functionalization: FunctionalTensor and FunctionalTensorMode
def test_python_functionalization(self):
Expand Down

0 comments on commit d02183a

Please sign in to comment.