aot autograd refactor: make all synthetic base logic layered in a single location

ghstack-source-id: 529ecb436f17194feddf3a88aada7946a6875571
Pull Request resolved: #96235
bdhirsh committed Mar 8, 2023
1 parent 0208ce7 commit 9aa531b
Showing 2 changed files with 271 additions and 129 deletions.
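For context: AOT Autograd's synthetic-base logic handles the case where two or more graph inputs alias the same storage. The aliased inputs are merged into a single "synthetic base" input, the original views are regenerated from that base with as_strided, and any input mutation is applied functionally with as_strided_scatter. A minimal sketch of the kind of program that exercises this path, loosely modeled on the tests below (the inputs and compiler stubs here are illustrative, not the exact test code):

    import torch
    from functorch.compile import aot_function, nop

    def f(a, b):
        a.add_(1)      # mutate one aliased input...
        return a + b   # ...then read through the other alias

    base = torch.ones(4)
    a = base.as_strided([2], [1], 0)  # elements 0..1 of base
    b = base.as_strided([2], [1], 2)  # elements 2..3 of base
    # a and b share storage, so AOT Autograd replaces them with one
    # synthetic base input and replays both views via as_strided.
    compiled_f = aot_function(f, nop)
    out = compiled_f(a, b)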
44 changes: 20 additions & 24 deletions test/functorch/test_aotdispatch.py
@@ -1118,9 +1118,9 @@ def forward(self, primals_1):
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
-as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2); as_strided_scatter = None
-add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_5 = None
-return [as_strided_2, add_1]""") # noqa: B950
+as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2)
+add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
+return [as_strided_scatter, add_1]""") # noqa: B950
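The expected graphs above lean on as_strided_scatter, the functional counterpart of mutating an as_strided view in place. A standalone illustration of its semantics, independent of this PR:

    import torch

    base = torch.zeros(4)
    src = torch.tensor([1., 2.])
    # Returns a copy of `base` where the region selected by
    # as_strided([2], [1], 2) has been overwritten with `src`.
    out = torch.as_strided_scatter(base, src, [2], [1], 2)
    # Equivalent to cloning and mutating the view in place:
    ref = base.clone()
    ref.as_strided([2], [1], 2).copy_(src)
    assert torch.equal(out, ref)  # tensor([0., 0., 1., 2.])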

@patch("functorch.compile.config.use_fake_tensor", True)
def test_input_mutation_aliases_other_input2(self):
@@ -1145,9 +1145,9 @@ def forward(self, primals_1):
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
-as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0); as_strided_scatter = None
-add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_5 = None
-return [as_strided_2, add_1]""") # noqa: B950
+as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
+add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
+return [as_strided_scatter, add_1]""") # noqa: B950

@patch("functorch.compile.config.use_fake_tensor", True)
def test_input_mutation_aliases_and_output_alias(self):
@@ -1170,10 +1170,9 @@ def forward(self, primals_1):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
-as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-as_strided_13 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
-view_1 = torch.ops.aten.view.default(as_strided_13, [4]); as_strided_13 = None
-return [as_strided_2, view_1]""") # noqa: B950
+as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
+return [as_strided_scatter, view_1]""") # noqa: B950

@patch("functorch.compile.config.use_fake_tensor", True)
def test_input_aliased_with_mutation_output_alias(self):
@@ -1201,11 +1200,10 @@ def forward(self, primals_1, primals_2):
as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
-as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
-as_strided_12 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
-view_1 = torch.ops.aten.view.default(as_strided_12, [-1]); as_strided_12 = None
-return [as_strided_2, add, view_1]""") # noqa: B950
+as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None
+return [as_strided_scatter, add, view_1]""") # noqa: B950

@patch("functorch.compile.config.use_fake_tensor", True)
def test_input_metadata_mutation_aliases(self):
@@ -1253,11 +1251,10 @@ def forward(self, primals_1, primals_2):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
mul = torch.ops.aten.mul.Tensor(as_strided, 2); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
-as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
-as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
+as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None
add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
-return [as_strided_2, add, add_1]""") # noqa: B950
+return [as_strided_scatter, add, add_1]""") # noqa: B950

@patch("functorch.compile.config.use_fake_tensor", True)
def test_input_mutation_aliases_bases_out_of_order(self):
@@ -1291,14 +1288,13 @@ def forward(self, primals_1, primals_2, primals_3):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
-as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
t_1 = torch.ops.aten.t.default(as_strided_5); as_strided_5 = None
add_2 = torch.ops.aten.add.Tensor(add_1, t_1); add_1 = None
-as_strided_18 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
-view_1 = torch.ops.aten.view.default(as_strided_18, [-1]); as_strided_18 = None
-return [as_strided_2, t_1, add_2, view_1]""") # noqa: B950
+as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+view_1 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None
+return [as_strided_scatter, add_2, view_1, t_1]""") # noqa: B950

@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_synthetic_base_base_attribute_is_none(self):
@@ -1362,12 +1358,12 @@ def forward(self, primals_1, primals_2):
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
t = torch.ops.aten.t.default(view); view = None
-as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None
-add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = None
+as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
+add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = as_strided_2 = None
view_1 = torch.ops.aten.view.default(add, [-1])
t_1 = torch.ops.aten.t.default(t)
unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0)
-return [t, as_strided_2, view_1, t_1, unsqueeze, add]""") # noqa: B950
+return [as_strided_scatter, t, view_1, t_1, unsqueeze, add]""") # noqa: B950

@patch("functorch.compile.config.use_fake_tensor", True)
def test_dynamic_shape_output_not_in_bw_graph(self):
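Across all of the hunks the visible pattern is the same: the old expected graphs freed the scatter result early (as_strided_scatter = None) and returned a regenerated view (as_strided_2) for the mutated input, while the new graphs return the updated synthetic base (as_strided_scatter) directly. That matches the commit title: with the synthetic-base logic layered in a single location, view regeneration can happen in one wrapper around the compiled function instead of being baked into every graph. A hedged sketch of what such a wrapper could look like; runtime_wrapper and view_metas are invented names for illustration, not AOTAutograd's actual internals:

    import torch

    def runtime_wrapper(compiled_fn, synthetic_base, view_metas):
        # Illustrative only. Assume the compiled graph returns the updated
        # base as its first output; replay each original input's
        # (size, stride, storage_offset) view on that base afterwards.
        outs = compiled_fn(synthetic_base)
        updated_base, rest = outs[0], outs[1:]
        views = [
            updated_base.as_strided(size, stride, offset)
            for (size, stride, offset) in view_metas
        ]
        return views, rest

    # Toy usage: a stand-in "compiled_fn" that just adds 1 to the base.
    base = torch.zeros(4)
    views, _ = runtime_wrapper(
        lambda b: [b + 1], base, [([2], [1], 0), ([2], [1], 2)]
    )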
