aot autograd refactor: make all synthetic base logic layered in a single location #96235

Closed

wants to merge 5 commits
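For context on the expected graphs in the tests below: when two graph inputs alias the same storage and at least one of them is mutated, AOT Autograd replaces them with a single "synthetic base" argument, regenerates each alias as an as_strided view of that base inside the graph, and expresses the mutation functionally with as_strided_scatter; the updated base is then handed back so the mutation can be replayed on the caller's tensors. The snippet below is a minimal, hand-written sketch of that idea, not the code this PR refactors; the function name, shapes, strides, and offsets are invented for illustration.

```python
import torch

def traced_with_synthetic_base(base):
    # The two user-visible inputs aliased one storage, so they are replaced
    # by a single synthetic base and regenerated as views of it.
    a = base.as_strided((2,), (1,), 0)
    b = base.as_strided((2,), (1,), 2)
    # The user's in-place `a.add_(1)` is functionalized: the update is
    # scattered back into a fresh copy of the base.
    updated_a = a + 1
    new_base = torch.as_strided_scatter(base, updated_a, (2,), (1,), 0)
    # Downstream compute re-reads through views of the updated base.
    out = new_base.as_strided((2,), (1,), 0) + new_base.as_strided((2,), (1,), 2)
    # Return the updated base so the runtime wrapper can copy the
    # mutation back into the caller's original tensors.
    return new_base, out

storage = torch.zeros(4)
new_base, out = traced_with_synthetic_base(storage.clone())
storage.copy_(new_base)  # propagate the input mutation outside the graph
```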
44 changes: 20 additions & 24 deletions test/functorch/test_aotdispatch.py
@@ -1132,9 +1132,9 @@ def forward(self, primals_1):
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
- as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2); as_strided_scatter = None
- add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_5 = None
- return [as_strided_2, add_1]""") # noqa: B950
+ as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2)
+ add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
+ return [as_strided_scatter, add_1]""") # noqa: B950

@patch("functorch.compile.config.use_fake_tensor", True)
def test_input_mutation_aliases_other_input2(self):
@@ -1159,9 +1159,9 @@ def forward(self, primals_1):
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
- as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0); as_strided_scatter = None
- add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_5 = None
- return [as_strided_2, add_1]""") # noqa: B950
+ as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
+ add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
+ return [as_strided_scatter, add_1]""") # noqa: B950

@patch("functorch.compile.config.use_fake_tensor", True)
def test_input_mutation_aliases_and_output_alias(self):
@@ -1184,10 +1184,9 @@ def forward(self, primals_1):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
- as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
- as_strided_13 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
- view_1 = torch.ops.aten.view.default(as_strided_13, [4]); as_strided_13 = None
- return [as_strided_2, view_1]""") # noqa: B950
+ as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+ view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
+ return [as_strided_scatter, view_1]""") # noqa: B950

@patch("functorch.compile.config.use_fake_tensor", True)
def test_input_aliased_with_mutation_output_alias(self):
@@ -1215,11 +1214,10 @@ def forward(self, primals_1, primals_2):
as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
- as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
- as_strided_12 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
- view_1 = torch.ops.aten.view.default(as_strided_12, [-1]); as_strided_12 = None
- return [as_strided_2, add, view_1]""") # noqa: B950
+ as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+ view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None
+ return [as_strided_scatter, add, view_1]""") # noqa: B950

@patch("functorch.compile.config.use_fake_tensor", True)
def test_input_metadata_mutation_aliases(self):
@@ -1267,11 +1265,10 @@ def forward(self, primals_1, primals_2):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
mul = torch.ops.aten.mul.Tensor(as_strided, 2); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
- as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
- as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
+ as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None
add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
- return [as_strided_2, add, add_1]""") # noqa: B950
+ return [as_strided_scatter, add, add_1]""") # noqa: B950

@patch("functorch.compile.config.use_fake_tensor", True)
def test_input_mutation_aliases_bases_out_of_order(self):
@@ -1305,14 +1302,13 @@ def forward(self, primals_1, primals_2, primals_3):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
- as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
t_1 = torch.ops.aten.t.default(as_strided_5); as_strided_5 = None
add_2 = torch.ops.aten.add.Tensor(add_1, t_1); add_1 = None
- as_strided_18 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
- view_1 = torch.ops.aten.view.default(as_strided_18, [-1]); as_strided_18 = None
- return [as_strided_2, t_1, add_2, view_1]""") # noqa: B950
+ as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+ view_1 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None
+ return [as_strided_scatter, add_2, view_1, t_1]""") # noqa: B950

@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_synthetic_base_base_attribute_is_none(self):
@@ -1376,12 +1372,12 @@ def forward(self, primals_1, primals_2):
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
t = torch.ops.aten.t.default(view); view = None
- as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
- add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = None
+ as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+ add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = as_strided_2 = None
view_1 = torch.ops.aten.view.default(add, [-1])
t_1 = torch.ops.aten.t.default(t)
unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0)
- return [t, as_strided_2, view_1, t_1, unsqueeze, add]""") # noqa: B950
+ return [as_strided_scatter, t, view_1, t_1, unsqueeze, add]""") # noqa: B950

@patch("functorch.compile.config.use_fake_tensor", True)
def test_dynamic_shape_output_not_in_bw_graph(self):
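A pattern that recurs in the updated expected graphs above: the forward now keeps as_strided_scatter alive and returns the updated synthetic base itself, instead of freeing it after materializing one last as_strided view and returning that view. This suggests that any views of the base needed for replaying input mutations are regenerated outside the graph. The sketch below illustrates that wrapper-side step under that assumption; the function and variable names here are hypothetical, not part of the actual runtime wrapper.

```python
import torch

# Hypothetical wrapper-side step, assuming the compiled forward returns the
# updated synthetic base rather than per-input views of it.
def replay_input_mutations(aliased_inputs, updated_base, view_metas):
    # view_metas: one (size, stride, storage_offset) tuple per aliased input,
    # recorded when the synthetic base was created.
    for inp, (size, stride, offset) in zip(aliased_inputs, view_metas):
        inp.copy_(updated_base.as_strided(size, stride, offset))

base = torch.zeros(4)
a = base.as_strided((2,), (1,), 0)
b = base.as_strided((2,), (1,), 2)
# Pretend this came out of the compiled graph as the updated synthetic base.
updated_base = torch.as_strided_scatter(base, a + 1, (2,), (1,), 0)
replay_input_mutations([a, b], updated_base, [((2,), (1,), 0), ((2,), (1,), 2)])
```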