From 84d64d72d66b9cccbc58066c731a18bf8493db05 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Wed, 8 Nov 2023 11:10:57 -0800 Subject: [PATCH] Persist copy_ in training graph for inputs that don't require grad (#111046) In this PR, we try to keep the input mutations in the forward graph IFF input mutation is data mutation and not metadata mutation and doesn't require grad. This is for optimizing inductor training graphs. (For more details: https://github.com/pytorch/pytorch/issues/109240) We keep the input mutation in the graph by wrapping the original callable in a wrapper function where in the end we add input.copy_(updated_input) call which is then traced via make_fx. Previously, this was only enabled for forward-only path but unconditionally disabled for joint graph. Another caveat is that when we are tracing through tensor subclasses, we won't allow any input mutations to be preserved in the graph. The reason is that it makes the code logic quite ugly for no obvious performance improvement. Most of the changes in this PR are mechanical and I didn't have to make any change to the partitioner. Previously forward/backward heavily relied on metadata field `num_mutated_inps` to figure out whether something is returned as extra output or not. But now since we keep some mutations in the graph, we need to propogate something similar to `num_mutated_inps - num_graph_handled_inps`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/111046 Approved by: https://github.com/ezyang, https://github.com/bdhirsh --- test/functorch/test_aotdispatch.py | 193 ++++++++++++++++++++++++++ test/inductor/test_torchinductor.py | 137 +++++++++++++++++++ torch/_functorch/aot_autograd.py | 201 +++++++++++++++++++++------- torch/_inductor/compile_fx.py | 4 +- torch/_inductor/pattern_matcher.py | 1 + 5 files changed, 483 insertions(+), 53 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 4b7d77806460e..d571338cab458 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -19,6 +19,7 @@ skipIfRocm, ) from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode +import copy import torch import torch.nn as nn import torch.utils._pytree as pytree @@ -2192,6 +2193,198 @@ def f(a): self.assertEqual(inp_ref.grad, inp_test.grad) + def test_buffer_copied_in_graph(self): + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buf", torch.zeros(1)) + self.w1 = torch.nn.Parameter(torch.zeros(1)) + self.w2 = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + self.buf.add_(1) + return (self.w1 * x * self.w2).sum() + self.buf.sum() + + model_for_eager = MyModel() + model_for_compile = copy.deepcopy(model_for_eager) + + fw_graph_cell = [None] + compiled_f = aot_module( + model_for_compile, + fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)), + bw_compiler=nop, + keep_inference_input_mutations=True, + ) + inp_ref = torch.ones(1, requires_grad=True) + inp_test = torch.ones(1, requires_grad=True) + + out_ref = model_for_eager(inp_ref.clone()) + out_test = compiled_f(inp_test.clone()) + + self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\ +def forward(self, primals_1, primals_2, primals_3, primals_4): + add = torch.ops.aten.add.Tensor(primals_3, 1) + mul = torch.ops.aten.mul.Tensor(primals_1, primals_4) + mul_1 = torch.ops.aten.mul.Tensor(mul, primals_2) + sum_1 = torch.ops.aten.sum.default(mul_1); mul_1 = None + sum_2 = torch.ops.aten.sum.default(add) + add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None + copy_ = torch.ops.aten.copy_.default(primals_3, add); primals_3 = add = None + return [add_1, primals_1, primals_2, primals_4, mul]""") + + self.assertEqual(out_ref, out_test) + + out_ref.sum().backward() + out_test.sum().backward() + + eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] + compile_grads = [p.grad for _, p in model_for_compile.named_parameters()] + + self.assertEqual(eager_grads, compile_grads) + self.assertEqual(inp_ref.grad, inp_test.grad) + + def test_buffer_copied_in_graph_with_different_shapes(self): + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buf", torch.ones(4, 4)) + self.w = torch.nn.Parameter(torch.Tensor([[4, 5], [1, 2], [6, 7], [8, 9]])) + + def forward(self, x): + self.buf.add_(1) + return (self.w @ x).sum() + self.buf.sum() + + model_for_eager = MyModel() + model_for_compile = copy.deepcopy(model_for_eager) + + fw_graph_cell = [None] + compiled_f = aot_module( + model_for_compile, + fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)), + bw_compiler=nop, + keep_inference_input_mutations=True, + ) + inp_ref = torch.ones(2, 4, requires_grad=True) + inp_test = torch.ones(2, 4, requires_grad=True) + + out_ref = model_for_eager(inp_ref.clone()) + out_test = compiled_f(inp_test.clone()) + + self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\ +def forward(self, primals_1, primals_2, primals_3): + add = torch.ops.aten.add.Tensor(primals_2, 1) + mm = torch.ops.aten.mm.default(primals_1, primals_3) + sum_1 = torch.ops.aten.sum.default(mm); mm = None + sum_2 = torch.ops.aten.sum.default(add) + add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None + copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = None + return [add_1, primals_1, primals_3]""") + self.assertEqual(out_ref, out_test) + + out_ref.sum().backward() + out_test.sum().backward() + + eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] + compile_grads = [p.grad for _, p in model_for_compile.named_parameters()] + + self.assertEqual(eager_grads, compile_grads) + + self.assertEqual(inp_ref.grad, inp_test.grad) + + def test_buffer_batch_norm(self): + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.m = torch.nn.BatchNorm1d(100) + + def forward(self, x): + return self.m(x) + + model_for_eager = MyModel() + model_for_compile = copy.deepcopy(model_for_eager) + + fw_graph_cell = [None] + bw_graph_cell = [None] + compiled_f = aot_module( + model_for_compile, + fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)), + bw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=bw_graph_cell)), + keep_inference_input_mutations=True, + ) + inp_ref = torch.ones(20, 100, requires_grad=True) + inp_test = torch.ones(20, 100, requires_grad=True) + + out_ref = model_for_eager(inp_ref.clone()) + out_test = compiled_f(inp_test.clone()) + + self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\ +def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6): + add = torch.ops.aten.add.Tensor(primals_5, 1) + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(primals_6, primals_1, primals_2, primals_3, primals_4, True, 0.1, 1e-05); primals_2 = None + getitem = _native_batch_norm_legit_functional[0] + getitem_1 = _native_batch_norm_legit_functional[1] + getitem_2 = _native_batch_norm_legit_functional[2] + getitem_3 = _native_batch_norm_legit_functional[3] + getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None + copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3); primals_3 = None + copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4); primals_4 = None + copy__2 = torch.ops.aten.copy_.default(primals_5, add); primals_5 = add = None + return [getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4]""") # noqa: B950 + + self.assertEqual(out_ref, out_test) + + out_ref.sum().backward() + out_test.sum().backward() + + eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] + compile_grads = [p.grad for _, p in model_for_compile.named_parameters()] + self.assertEqual(eager_grads, compile_grads) + + self.assertExpectedInline(bw_graph_cell[0].code.strip(), """\ +def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4, tangents_1): + native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(tangents_1, primals_6, primals_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); tangents_1 = primals_6 = primals_1 = getitem_3 = getitem_4 = getitem_1 = getitem_2 = None + getitem_5 = native_batch_norm_backward[0] + getitem_6 = native_batch_norm_backward[1] + getitem_7 = native_batch_norm_backward[2]; native_batch_norm_backward = None + return [getitem_6, getitem_7, None, None, None, getitem_5]""") # noqa: B950 + + self.assertEqual(inp_ref.grad, inp_test.grad) + + def test_new_inp_requires_grad_now(self): + def f(x, y): + return x.add_(y) + + fw_graph_cell = [None] + bw_graph_cell = [None] + compiled_f = aot_function( + f, + fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)), + bw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=bw_graph_cell)), + keep_inference_input_mutations=True, + ) + + inp_ref = (torch.ones(20, 100, requires_grad=False), torch.ones(20, 100, requires_grad=True)) + inp_test = (torch.ones(20, 100, requires_grad=False), torch.ones(20, 100, requires_grad=True)) + + out_ref = f(*inp_ref) + out_test = compiled_f(*inp_test) + + # There is no copy_ method + self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\ +def forward(self, primals_1, primals_2): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + add = torch.ops.aten.add.Tensor(clone, primals_2); clone = primals_2 = None + return [add, add]""") # noqa: B950 + + self.assertEqual(out_ref, out_test) + + out_ref.sum().backward() + out_test.sum().backward() + + self.assertExpectedInline(bw_graph_cell[0].code.strip(), """\ +def forward(self, tangents_1): + return [None, tangents_1]""") # noqa: B950 + def test_real_weights_in_symbolic_mode(self): from functorch.experimental import functionalize diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 48f774615f6fc..676fac0f0190c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -7,6 +7,7 @@ import importlib import itertools import math +import operator import os import random import re @@ -2647,6 +2648,142 @@ def forward(self, x): expected = mod(x) self.assertTrue(torch.allclose(res, expected)) + def test_buffer_copied_in_graph(self): + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buf", torch.zeros(1)) + self.w1 = torch.nn.Parameter(torch.zeros(1)) + self.w2 = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + self.buf.add_(1) + return (self.w1 * x * self.w2).sum() + self.buf.sum() + + model_for_eager = MyModel() + model_for_compile = copy.deepcopy(model_for_eager) + + eager_version_counters = [ + buffer._version for _, buffer in model_for_eager.named_buffers() + ] + compile_version_counters = [ + buffer._version for _, buffer in model_for_compile.named_buffers() + ] + + compiled_f = torch.compile(model_for_compile, backend="inductor") + + inp_ref = torch.ones(1, requires_grad=True) + inp_test = torch.ones(1, requires_grad=True) + + out_ref = model_for_eager(inp_ref.clone()) + out_test = compiled_f(inp_test.clone()) + + eager_version_counters_after = [ + buffer._version for _, buffer in model_for_eager.named_buffers() + ] + compile_version_counters_after = [ + buffer._version for _, buffer in model_for_compile.named_buffers() + ] + + eager_delta = list( + map(operator.sub, eager_version_counters_after, eager_version_counters) + ) + compile_delta = list( + map(operator.sub, compile_version_counters_after, compile_version_counters) + ) + + self.assertEqual(eager_delta, compile_delta) + + def test_buffer_copied_in_graph_with_different_shapes(self): + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buf", torch.ones(4, 4)) + self.w = torch.nn.Parameter( + torch.Tensor([[4, 5], [1, 2], [6, 7], [8, 9]]) + ) + + def forward(self, x): + self.buf.add_(1) + return (self.w @ x).sum() + self.buf.sum() + + model_for_eager = MyModel() + model_for_compile = copy.deepcopy(model_for_eager) + + eager_version_counters = [ + buffer._version for _, buffer in model_for_eager.named_buffers() + ] + compile_version_counters = [ + buffer._version for _, buffer in model_for_compile.named_buffers() + ] + + compiled_f = torch.compile(model_for_compile, backend="inductor") + + inp_ref = torch.ones(2, 4, requires_grad=True) + inp_test = torch.ones(2, 4, requires_grad=True) + + out_ref = model_for_eager(inp_ref.clone()) + out_test = compiled_f(inp_test.clone()) + + eager_version_counters_after = [ + buffer._version for _, buffer in model_for_eager.named_buffers() + ] + compile_version_counters_after = [ + buffer._version for _, buffer in model_for_compile.named_buffers() + ] + + eager_delta = list( + map(operator.sub, eager_version_counters_after, eager_version_counters) + ) + compile_delta = list( + map(operator.sub, compile_version_counters_after, compile_version_counters) + ) + + self.assertEqual(eager_delta, compile_delta) + + def test_buffer_batch_norm(self): + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.m = torch.nn.BatchNorm1d(100) + + def forward(self, x): + return self.m(x) + + model_for_eager = MyModel() + model_for_compile = copy.deepcopy(model_for_eager) + + eager_version_counters = [ + buffer._version for _, buffer in model_for_eager.named_buffers() + ] + compile_version_counters = [ + buffer._version for _, buffer in model_for_compile.named_buffers() + ] + + compiled_f = torch.compile(model_for_compile, backend="inductor") + + inp_ref = torch.ones(20, 100, requires_grad=True) + inp_test = torch.ones(20, 100, requires_grad=True) + + out_ref = model_for_eager(inp_ref.clone()) + out_test = compiled_f(inp_test.clone()) + + eager_version_counters_after = [ + buffer._version for _, buffer in model_for_eager.named_buffers() + ] + compile_version_counters_after = [ + buffer._version for _, buffer in model_for_compile.named_buffers() + ] + + eager_delta = list( + map(operator.sub, eager_version_counters_after, eager_version_counters) + ) + compile_delta = list( + map(operator.sub, compile_version_counters_after, compile_version_counters) + ) + + self.assertEqual(eager_delta, compile_delta) + def test_adaptive_avg_pool_with_output_size_0(self): m1 = nn.AdaptiveAvgPool1d(0) self.common(m1, (torch.randn(1, 2),)) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index e2d8d9a4ab3c3..06e8c893fdb82 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -464,6 +464,12 @@ class OutputAliasInfo: requires_grad: bool +class MutationType(Enum): + NOT_MUTATED = 1 + MUTATED_IN_GRAPH = 2 + MUTATED_OUT_GRAPH = 3 + + # This class tells us info about user inputs. @dataclass(frozen=True) class InputAliasInfo: @@ -472,6 +478,7 @@ class InputAliasInfo: mutates_metadata: bool mutations_hidden_from_autograd: bool requires_grad: bool + mutation_type: MutationType @dataclasses.dataclass @@ -522,6 +529,7 @@ def __post_init__(self): # sanity assert to make sure we don't leak memory assert is_fake(self.original_subclass) + # This class encapsulates all aliasing + mutation info we need about the forward graph # See a more detailed overview of the edge case handling at # https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit @@ -580,6 +588,9 @@ class ViewAndMutationMeta: # TODO: we should kill this # (need to default it to not break internal) is_train: bool = False + # We're plumbing this requires_subclass_dispatch here is because it's painful to support input mutations + # on subclasses, and that info isn't easily available. + requires_subclass_dispatch: bool = False num_symints_saved_for_bw: Optional[int] = None @@ -592,14 +603,31 @@ class ViewAndMutationMeta: def __post_init__(self): mutated_inp_indices = [ - i for i, m in enumerate(self.input_info) if m.mutates_metadata or m.mutates_data + i for i, m in enumerate(self.input_info) + if m.mutation_type in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH) ] # pre-compute the indices of the inputs that are mutated. # When keep_input_mutations is set, we don't need to worry about our epilogue # handling data-only mutations, because we keep them directly in the graph. - mutated_inp_runtime_indices = [ - i for i, m in enumerate(self.input_info) if m.mutates_metadata or (not self.keep_input_mutations and m.mutates_data) + + # TODO (tmanlaibaatar) Ideally input mutation type should be calculated + # based on requires_subclass_dispatch argument but this is not easy to do because you would + # have to pass around this argument multiple level down. + if not self.requires_subclass_dispatch: + mutated_inp_runtime_indices = [ + i for i, m in enumerate(self.input_info) + if (m.mutation_type == MutationType.MUTATED_OUT_GRAPH) + ] + else: + mutated_inp_runtime_indices = mutated_inp_indices + + mutated_graph_handled_indices = [ + i for i, m in enumerate(self.input_info) + if m.mutation_type == MutationType.MUTATED_IN_GRAPH and not self.requires_subclass_dispatch ] + self.mutated_graph_handled_indices = mutated_graph_handled_indices + self.num_mutated_graph_handled_indices = len(self.mutated_graph_handled_indices) + aliased_out_indices = [ i for i, m in enumerate(self.output_info) @@ -614,6 +642,13 @@ def __post_init__(self): # It contains the index of every element # of input_info that corresponds to a mutation (data or metadata or both) self.mutated_inp_runtime_indices = mutated_inp_runtime_indices + self.num_mutated_inp_runtime_indices = len(self.mutated_inp_runtime_indices) + + assert ( + self.num_mutated_graph_handled_indices + self.num_mutated_inp_runtime_indices == + len(mutated_inp_indices) + ) + # This is pre-computed for perf. # It contains the index of every element # of output_info that corresponds to an alias (either of an input or intermediate) @@ -652,6 +687,7 @@ def __post_init__(self): self.num_mutated_data_inputs = len( [x for x in self.input_info if x.mutates_data] ) + self.num_mutated_metadata_inputs = len( [ x @@ -687,7 +723,7 @@ def __post_init__(self): self.num_outputs_rng_offset = 1 if self.is_rng_op_functionalized else 0 # Our forward() returns both (mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints) - self.num_forward_returns = self.num_mutated_inputs + self.num_outputs + self.num_intermediate_bases + self.num_forward_returns = self.num_mutated_inp_runtime_indices + self.num_outputs + self.num_intermediate_bases # In case of functionalization of rng ops, the fw_module returns one # additional output for rng offset. This rng offset is used right # away to advance the rng state, and is not passed on to the raw @@ -1035,6 +1071,7 @@ def run_functionalized_fw_and_collect_metadata( keep_input_mutations: bool, # TODO: refactor to kill this flag is_train: bool = False, + requires_subclass_dispatch: bool = False, ) -> ViewAndMutationMeta: memo = {} @@ -1095,12 +1132,15 @@ def inner(*flat_args): mutates_metadata = False mutations_hidden_from_autograd = False + requires_grad = isinstance(f_arg, torch.Tensor) and f_arg.requires_grad + input_info.append(InputAliasInfo( is_leaf=isinstance(arg, torch.Tensor) and safe_is_leaf(arg), mutates_data=mutates_data, mutates_metadata=mutates_metadata, mutations_hidden_from_autograd=mutations_hidden_from_autograd, - requires_grad=isinstance(f_arg, torch.Tensor) and f_arg.requires_grad + requires_grad=requires_grad, + mutation_type=_get_mutation_type(keep_input_mutations, is_train, mutates_data, mutates_metadata, requires_grad) )) # If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate, @@ -1430,6 +1470,7 @@ def view_avoid_dupes_with_primals(t): subclass_tangent_meta=create_subclass_meta(traced_tangents), is_train=is_train, grad_enabled_mutation=grad_enabled_mutation, + requires_subclass_dispatch=requires_subclass_dispatch, ) return metadata @@ -1605,9 +1646,9 @@ class AOTConfig: def maybe_to_fresh_input(idx, t, meta): if not isinstance(t, Tensor): return t - if idx in meta.mutated_inp_indices: + if idx in meta.mutated_inp_runtime_indices: # We only need to bother cloning mutated inputs that participate in autograd. - mutated_inp_idx = meta.mutated_inp_indices.index(idx) + mutated_inp_idx = meta.mutated_inp_runtime_indices.index(idx) if meta.input_info[idx].requires_grad and meta.input_info[idx].mutates_data: # Make sure the primal we pass to autograd.grad() # sees the tensor before the mutation @@ -1671,7 +1712,7 @@ def inner_fn(*args): mutated_inputs_to_return = [ x for (i, x) in enumerate(args_maybe_cloned) - if meta.input_info[i].mutates_metadata or meta.input_info[i].mutates_data + if i in meta.mutated_inp_runtime_indices ] intermediate_bases = [] @@ -1686,7 +1727,8 @@ def inner_fn(*args): # Also return a boolean mask specifying which outputs to this function will be used as tangents mutated_inputs_grad_mask = [ - meta.input_info[meta.mutated_inp_indices[i]].mutates_data and meta.input_info[meta.mutated_inp_indices[i]].requires_grad + meta.input_info[meta.mutated_inp_runtime_indices[i]].mutates_data and + meta.input_info[meta.mutated_inp_runtime_indices[i]].requires_grad for (i, x) in enumerate(mutated_inputs_to_return) ] @@ -1834,7 +1876,7 @@ def functionalized_f_helper(*args): # Run the joint f_outs = fn(*f_args) - if aot_config.keep_inference_input_mutations and not trace_joint: + if aot_config.keep_inference_input_mutations: # Note: This is a bit annoying. There's a layering issue here, where: # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs. # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs. @@ -1859,16 +1901,15 @@ def functionalized_f_helper(*args): # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base. # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry # about synthetic bases. - for i, (inpt_old, inpt_f) in enumerate(zip(args, f_args)): + for i, (inpt_old, inpt_f) in enumerate(zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])): if not isinstance(inpt_f, torch.Tensor): continue assert is_fun(inpt_f) inpt_new = from_fun(inpt_f) - if meta.input_info[i].mutates_data and not meta.input_info[i].mutates_metadata: + if meta.input_info[i].mutation_type == MutationType.MUTATED_IN_GRAPH: # We found an input that had a (data-only) mutation. # Since keep_input_mutations is set, we need to faithfully apply a copy_() # so the compiler will see the input mutation in the graph. - assert inpt_new is not inpt_old if meta.input_info[i].mutations_hidden_from_autograd: with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(inpt_old): inpt_old.copy_(inpt_new) @@ -2016,7 +2057,8 @@ def aot_dispatch_base_graph( fn_to_trace, flat_args, meta=fw_metadata, aot_config=aot_config, trace_joint=False) fn_to_trace, updated_flat_args_subclasses_desugared, maybe_subclass_meta = aot_dispatch_subclass( - fn_to_trace, updated_flat_args, is_joint_structure=False, meta=fw_metadata, fw_only=flat_fn) + fn_to_trace, updated_flat_args, is_joint_structure=False, meta=fw_metadata, fw_only=flat_fn + ) fw_module = create_graph( fn_to_trace, @@ -2229,6 +2271,25 @@ def compute_overlapping_inputs(fwd_inputs, aliased_input_indices): actual_aliased_indices.add(j_) return actual_aliased_indices + +def _check_if_mutation_can_be_in_graph(keep_input_mutations: bool, is_train: bool, mutates_data, mutates_metadata, requires_grad): + if keep_input_mutations: + if is_train: + return mutates_data and not mutates_metadata and not requires_grad + return not mutates_metadata + return False + + +def _get_mutation_type(keep_input_mutations: bool, is_train: bool, mutates_data, mutates_metadata, requires_grad): + if (not mutates_data) and (not mutates_metadata): + return MutationType.NOT_MUTATED + + if _check_if_mutation_can_be_in_graph(keep_input_mutations, is_train, mutates_data, mutates_metadata, requires_grad): + return MutationType.MUTATED_IN_GRAPH + + return MutationType.MUTATED_OUT_GRAPH + + # Note [Handling mutations on an input that aliases other inputs] # The easiest example to show-case this edge case is here: # @@ -2467,7 +2528,7 @@ def remove_dupe_metadata( num_data_mutations = len([x for x in m.input_info if x.mutates_data]) other_traced_tangents = m.traced_tangents[num_data_mutations:] inp_traced_tangents = m.traced_tangents[:num_data_mutations] - filtered_inp_traced_tangents = [x for i, x in enumerate(inp_traced_tangents) if keep_arg_mask[m.mutated_inp_indices[i]]] + filtered_inp_traced_tangents = [x for i, x in enumerate(inp_traced_tangents) if keep_arg_mask[m.mutated_inp_runtime_indices[i]]] traced_tangents = filtered_inp_traced_tangents + other_traced_tangents return ViewAndMutationMeta( @@ -2491,7 +2552,7 @@ def remove_dupe_metadata( subclass_inp_meta=None, subclass_fw_graph_out_meta=None, subclass_tangent_meta=None, - is_train=m.is_train, + is_train=m.is_train ) # Given our ViewAndMutation metadata, this fn constructs a new set of metadata, @@ -2533,19 +2594,26 @@ def create_synthetic_base_metadata( any_leaf = any(m.input_info[x].is_leaf for x in outer_indices) all_leaf = all(m.input_info[x].is_leaf for x in outer_indices) assert any_leaf == all_leaf + + mutates_data = True if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_data + mutates_metadata = False if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_metadata + requires_grad = any(m.input_info[x].requires_grad for x in outer_indices) + inpt_info = InputAliasInfo( # If len(outer_indices) > 1, then this input is a synthetic base. # The invariant is that to the rest of aot autograd, synthetic bases only show up if # one of their aliases gets a data mutation. And if any of their aliases get metadata # mutations, they will be hidden from the rest of aot autograd. - mutates_data=True if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_data, - mutates_metadata=False if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_metadata, + mutates_data=mutates_data, + mutates_metadata=mutates_metadata, mutations_hidden_from_autograd=all(m.input_info[x].mutations_hidden_from_autograd for x in outer_indices), is_leaf=any_leaf, - requires_grad=any(m.input_info[x].requires_grad for x in outer_indices) + requires_grad=requires_grad, + mutation_type=_get_mutation_type(m.keep_input_mutations, m.is_train, mutates_data, mutates_metadata, requires_grad) ) input_infos.append(inpt_info) + # Find any inputs that fulfill the following criteria: # (1) They are part of a synthetic base (because they alias another input, # and at least one input experiences a data mutation) @@ -2599,7 +2667,7 @@ def create_synthetic_base_metadata( subclass_inp_meta=None, subclass_fw_graph_out_meta=None, subclass_tangent_meta=None, - is_train=m.is_train, + is_train=m.is_train ), outer_aliased_arg_idx_with_metadata_mutations # MOTIVATION: @@ -3060,27 +3128,43 @@ def runtime_wrapper(*args): ) num_mutated_inps = runtime_metadata.num_mutated_inputs + num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices num_metadata_mutated_inps = runtime_metadata.num_mutated_metadata_inputs num_intermediate_bases = runtime_metadata.num_intermediate_bases if keep_input_mutations: - assert ( - len(all_outs) - == num_metadata_mutated_inps + runtime_metadata.num_outputs + num_intermediate_bases - ) - assert ( - len(runtime_metadata.mutated_inp_runtime_indices) == num_metadata_mutated_inps - ) + if not trace_joint: + assert ( + len(all_outs) + == num_metadata_mutated_inps + runtime_metadata.num_outputs + num_intermediate_bases + ) + assert ( + runtime_metadata.num_mutated_inp_runtime_indices == num_metadata_mutated_inps + ) + else: + num_graph_handled = runtime_metadata.num_mutated_graph_handled_indices + # autograd.Function requires us to return the mutated inputs as extra outputs to the autograd.Function.forward + if num_graph_handled > 0: + all_outs = all_outs[:-num_graph_handled] + assert ( + len(all_outs) + == num_mutated_runtime_inps + runtime_metadata.num_outputs + num_intermediate_bases + ) + assert ( + runtime_metadata.num_mutated_inp_runtime_indices == num_mutated_runtime_inps + ) + else: assert ( len(all_outs) == num_mutated_inps + runtime_metadata.num_outputs + num_intermediate_bases ) assert ( - len(runtime_metadata.mutated_inp_runtime_indices) == num_mutated_inps + runtime_metadata.num_mutated_inp_runtime_indices == num_mutated_inps ) + # Step 3: After running the compiled fw, apply updates to mutated inputs - num_mutations_to_apply = len(runtime_metadata.mutated_inp_runtime_indices) + num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices if num_mutations_to_apply > 0: updated_inputs = all_outs[: num_mutations_to_apply] fw_outs = all_outs[num_mutations_to_apply :] @@ -3144,8 +3228,8 @@ def runtime_wrapper(*args): else: fw_outs_no_intermediate_bases = fw_outs intermediate_bases = [] - assert len(fw_outs_no_intermediate_bases) == len(runtime_metadata.output_info) + assert len(fw_outs_no_intermediate_bases) == len(runtime_metadata.output_info) fw_outs_including_aliases = [] for i, (o, info) in enumerate(zip( fw_outs_no_intermediate_bases, runtime_metadata.output_info @@ -3538,6 +3622,7 @@ def metadata_fn(*primals): metadata_fn, keep_input_mutations=meta.keep_input_mutations, is_train=meta.is_train, + requires_subclass_dispatch=True, )(*primals_unwrapped) subclass_meta.fw_metadata = meta_updated @@ -3579,7 +3664,8 @@ def aot_dispatch_autograd_graph(flat_fn, flat_args: List[Any], aot_config: AOTCo ) subclass_tracing_info = aot_dispatch_subclass( - joint_fn_to_trace, updated_joint_inputs, is_joint_structure=True, meta=fw_metadata, fw_only=flat_fn) + joint_fn_to_trace, updated_joint_inputs, is_joint_structure=True, meta=fw_metadata, fw_only=flat_fn + ) joint_fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn updated_joint_inputs = subclass_tracing_info.plain_tensor_args @@ -3588,7 +3674,7 @@ def aot_dispatch_autograd_graph(flat_fn, flat_args: List[Any], aot_config: AOTCo fx_g = create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config) # There should be *NO* mutating ops in the graph at this point. - assert_functional_graph(fx_g.graph) + assert_functional_graph(fx_g.graph, allow_input_mutations=aot_config.keep_inference_input_mutations) # Redundant with the check above, but worth having in case tracing introduced # a fake tensor. Unlikely. @@ -3622,7 +3708,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig, with track_graph_compiling(aot_config, "joint"): # See Note: [Partitioner handling for Subclasses, Part 1] num_inner_fwd_outputs = ( - inner_meta.num_mutated_inputs + inner_meta.num_mutated_inp_runtime_indices + inner_meta.num_outputs + inner_meta.num_intermediate_bases + inner_meta.num_outputs_rng_offset @@ -3830,6 +3916,12 @@ def _compiled_autograd_key(ctx): @staticmethod def forward(ctx, *deduped_flat_tensor_args): args = deduped_flat_tensor_args + + marked_dirty_inps = [] + for i in fw_metadata.mutated_graph_handled_indices: + ctx.mark_dirty(deduped_flat_tensor_args[i]) + marked_dirty_inps.append(deduped_flat_tensor_args[i]) + if CompiledFunction.metadata.is_rng_op_functionalized: # Add the seed and offset to args seed, offset = CUDARngStateHelper.get_torch_state_as_tuple() @@ -3850,6 +3942,7 @@ def forward(ctx, *deduped_flat_tensor_args): num_intermediate_bases = CompiledFunction.metadata.num_intermediate_bases num_symints_saved_for_bw = CompiledFunction.num_symints_saved_for_bw num_mutated_inputs = CompiledFunction.metadata.num_mutated_inputs + num_mutated_runtime_inps = CompiledFunction.metadata.num_mutated_inp_runtime_indices num_mutated_metadata_only_inputs = ( CompiledFunction.metadata.num_mutated_metadata_only_inputs ) @@ -3895,26 +3988,27 @@ def forward(ctx, *deduped_flat_tensor_args): if CompiledFunction.metadata.num_unsafe_view_outputs > 0: for idx in CompiledFunction.metadata.unsafe_view_out_indices: - raw_return_idx = num_mutated_inputs + idx + raw_return_idx = num_mutated_runtime_inps + idx o = raw_returns[raw_return_idx] raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view(o, o.shape) if num_outputs_aliased > 0: for idx in CompiledFunction.metadata.aliased_out_indices: - raw_return_idx = num_mutated_inputs + idx + raw_return_idx = num_mutated_runtime_inps + idx raw_returns[raw_return_idx] = TensorAlias(raw_returns[raw_return_idx]) if config.debug_assert: - intermediates_raw = raw_returns[num_mutated_inputs + num_outputs:] + intermediates_raw = raw_returns[num_mutated_runtime_inps + num_outputs:] assert not any(isinstance(x, TensorAlias) for x in intermediates_raw) # invariant: intermediate bases always require gradients, so we don't have to # consider marking them as non-differentiable. - raw_returns_not_including_intermediate_bases = raw_returns[:num_mutated_inputs + num_outputs] - + raw_returns_not_including_intermediate_bases = raw_returns[:num_mutated_runtime_inps + num_outputs] raw_returns_meta = ( - [x for x in CompiledFunction.metadata.input_info if x.mutates_data or x.mutates_metadata] - + CompiledFunction.metadata.output_info + [ + x for x in CompiledFunction.metadata.input_info + if x.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] + CompiledFunction.metadata.output_info ) fw_outs_not_requiring_grad = [ @@ -3931,7 +4025,7 @@ def forward(ctx, *deduped_flat_tensor_args): fw_outs[num_forward_returns:num_forward], return_new_outs=False ) - return tuple(raw_returns) + return tuple(raw_returns) + tuple(marked_dirty_inps) @staticmethod def backward(ctx, *flat_args): @@ -3946,27 +4040,31 @@ def backward(ctx, *flat_args): # and we filter them out here before passing the remaining grad_outputs into the compiled backward. num_mutated_inps = CompiledFunction.metadata.num_mutated_inputs num_intermediate_bases = CompiledFunction.metadata.num_intermediate_bases + num_graph_handled_inputs = CompiledFunction.metadata.num_mutated_graph_handled_indices + num_mutated_runtime_inps = CompiledFunction.metadata.num_mutated_inp_runtime_indices expected_grad_outs = ( - CompiledFunction.metadata.num_outputs + num_mutated_inps + num_intermediate_bases + CompiledFunction.metadata.num_outputs + num_mutated_runtime_inps + num_intermediate_bases ) + if num_graph_handled_inputs > 0: + flat_args = flat_args[:-num_graph_handled_inputs] assert len(flat_args) == expected_grad_outs out_info = CompiledFunction.metadata.output_info + num_mutated_inps_returned = CompiledFunction.metadata.num_mutated_inp_runtime_indices + inp_tangents, out_tangents, intermediate_base_tangents = ( - flat_args[0:num_mutated_inps], - flat_args[num_mutated_inps:num_mutated_inps + CompiledFunction.metadata.num_outputs], - flat_args[num_mutated_inps + CompiledFunction.metadata.num_outputs:], + flat_args[0:num_mutated_inps_returned], + flat_args[num_mutated_inps_returned:num_mutated_inps_returned + CompiledFunction.metadata.num_outputs], + flat_args[num_mutated_inps_returned + CompiledFunction.metadata.num_outputs:], ) # input_info contains info on *every* input, # But in the backward(), we are only given grad outputs for every mutated input # We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad - mutated_inp_indices = CompiledFunction.metadata.mutated_inp_indices input_info = CompiledFunction.metadata.input_info - assert len(inp_tangents) == len(mutated_inp_indices) inp_tangents_filtered = [ x - for x, info_idx in zip(inp_tangents, mutated_inp_indices) + for x, info_idx in zip(inp_tangents, CompiledFunction.metadata.mutated_inp_runtime_indices) if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad ] # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates @@ -4128,7 +4226,7 @@ def backward(ctx, *args): runtime_metadata=fw_metadata, indices_of_inps_to_detach=_indices_of_inps_to_detach, trace_joint=True, - keep_input_mutations=False, + keep_input_mutations=aot_config.keep_inference_input_mutations, disable_amp=disable_amp ) @@ -4274,7 +4372,7 @@ def convert(idx, x): with patch("torch.cuda.set_rng_state", lambda *args: None): fw_metadata = run_functionalized_fw_and_collect_metadata( flat_fn, - keep_input_mutations=aot_config.keep_inference_input_mutations and not needs_autograd, + keep_input_mutations=aot_config.keep_inference_input_mutations, is_train=needs_autograd, )(*fake_flat_args) @@ -4367,7 +4465,6 @@ def convert(idx, x): compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata) if aot_config.is_export: - mutated_user_inp_locs = [ idx - aot_config.num_params_buffers for idx in fw_metadata.mutated_inp_indices @@ -4966,7 +5063,7 @@ def fn_to_trace(*args): fn_to_trace, full_args, decompositions=decompositions, - num_params_buffers=len(params_and_buffers_flat), + num_params_buffers=params_len, no_tangents=True, ) if trace_joint: diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 9ddc5e64d8988..874f37a2a5408 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1116,7 +1116,9 @@ def fw_compiler_base( context = torch._guards.TracingContext.get() # See Note [User Outputs in the inductor graph] if context is not None and context.fw_metadata and not is_inference: - original_output_start_index = context.fw_metadata.num_mutated_inputs + original_output_start_index = ( + context.fw_metadata.num_mutated_inp_runtime_indices + ) else: original_output_start_index = 0 diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index c418d7512aedb..4873df92ba8cc 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -1251,6 +1251,7 @@ def record_joint_graph(joint_graph, inputs, **kwargs): lambda g, i: make_boxed_func(g), partition_fn=record_joint_graph, decompositions=select_decomp_table(), + keep_inference_input_mutations=True, enable_log=False, )(*args) assert gm