diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 41fd7049c969..8094f6d3574c 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -286,7 +286,66 @@ def new_fn(args):
 
 
 def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
-    joint_forward_backward = create_joint_forward_backward(flat_fn)
+    # Deduplicate inputs. Suppose you have:
+    #
+    #   [a, b, a, c]
+    #
+    # We want:
+    #
+    #   remove_dupe_args([a, b, a, c]) == [a, b, c]
+    #   add_dupe_args([a, b, c]) == [a, b, a, c]
+    #
+    # This is done via (respectively):
+    #
+    #   keep_arg_mask = [True, True, False, True]  # what to keep
+    #   add_dupe_map = {  # how to get args from the deduped list
+    #       0: 0,
+    #       1: 1,
+    #       2: 0,
+    #       3: 2,
+    #   }
+    #
+    # Whether to use flat_args or deduped_flat_args? flat_fn takes flat_args,
+    # and the autograd.Function must take deduped_flat_args; everything
+    # else is just getting the types right.
+
+    seen_args = {}
+    keep_arg_mask = []
+    dropped_args = False
+    add_dupe_map = {}
+    duped_arg_len = len(flat_args)
+
+    j = 0  # index into deduped_flat_args
+    for i, t in enumerate(flat_args):
+        if t in seen_args:
+            keep_arg_mask.append(False)
+            dropped_args = True
+            add_dupe_map[i] = seen_args[t]
+            continue
+        keep_arg_mask.append(True)
+        seen_args[t] = j
+        add_dupe_map[i] = j
+        j += 1
+
+    # NB: Hot path, avoid set lookups here
+    def remove_dupe_args(args):
+        if not dropped_args:
+            return args
+        r = []
+        for t, keep in zip(args, keep_arg_mask):
+            if keep:
+                r.append(t)
+        return r
+
+    def add_dupe_args(args):
+        r = []
+        for i in range(duped_arg_len):
+            r.append(args[add_dupe_map[i]])
+        return r
+
+    deduped_flat_args = remove_dupe_args(flat_args)
+
+    joint_forward_backward = create_joint_forward_backward(lambda *args: flat_fn(*add_dupe_args(args)))
 
     out = flat_fn(*flat_args)
     out = pytree.tree_map(
@@ -299,7 +358,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
     else:
         _num_outs = 1
 
-    joint_inputs = (flat_args, out)
+    joint_inputs = (deduped_flat_args, out)
 
     if config.use_functionalize:
         # Trace once without decompositions, into a graph of ATen ops.
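A minimal standalone sketch of the dedupe bookkeeping above, for reference. build_dedupe_maps is a hypothetical helper, not part of the patch; the real loop keys seen_args by the tensor object itself, while this sketch keys by id() so it runs on plain Python objects:

    def build_dedupe_maps(flat_args):
        seen_args = {}      # id(arg) -> index in the deduped list
        keep_arg_mask = []  # which original positions survive dedup
        add_dupe_map = {}   # original position -> deduped position
        j = 0
        for i, t in enumerate(flat_args):
            if id(t) in seen_args:
                keep_arg_mask.append(False)
                add_dupe_map[i] = seen_args[id(t)]
                continue
            keep_arg_mask.append(True)
            seen_args[id(t)] = j
            add_dupe_map[i] = j
            j += 1
        return keep_arg_mask, add_dupe_map

    a, b, c = object(), object(), object()
    keep_arg_mask, add_dupe_map = build_dedupe_maps([a, b, a, c])
    assert keep_arg_mask == [True, True, False, True]
    assert add_dupe_map == {0: 0, 1: 1, 2: 0, 3: 2}

    deduped = [t for t, keep in zip([a, b, a, c], keep_arg_mask) if keep]   # remove_dupe_args
    assert deduped == [a, b, c]
    assert [deduped[add_dupe_map[i]] for i in range(4)] == [a, b, a, c]     # add_dupe_args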
@@ -333,10 +392,10 @@ def fake_fn(primals, tangents):
             bw_module.print_readable()
 
         with track_graph_compiling("forward"):
-            compiled_fw_func = aot_config.fw_compiler(fw_module, flat_args)
+            compiled_fw_func = aot_config.fw_compiler(fw_module, deduped_flat_args)
 
         if config.debug_partitioner:
-            fw_outs = call_func_with_args(compiled_fw_func, flat_args)
+            fw_outs = call_func_with_args(compiled_fw_func, deduped_flat_args)
             activation_sizes = 0
             for out in fw_outs[_num_outs:]:
                 if isinstance(out, torch.Tensor):
@@ -350,9 +409,9 @@ class CompiledFunction(torch.autograd.Function):
 
         @staticmethod
        @disable_torchdynamo
-        def forward(ctx, *flat_tensor_args):
+        def forward(ctx, *deduped_flat_tensor_args):
             fw_outs = call_func_with_args(
-                CompiledFunction.compiled_fw, flat_tensor_args
+                CompiledFunction.compiled_fw, deduped_flat_tensor_args
             )
             num_outs = CompiledFunction.num_outs
             ctx.save_for_backward(*fw_outs[num_outs:])
@@ -375,7 +434,11 @@ def backward(ctx, *flat_args):
             return tuple(out)
 
-    return CompiledFunction.apply
+    @wraps(CompiledFunction.apply)
+    def compiled_function(*args):
+        return CompiledFunction.apply(*remove_dupe_args(args))
+
+    return compiled_function
 
 
 @dynamo_timed
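A toy sketch of the calling convention the wrapper above sets up, with hypothetical stand-ins (make_compiled_function and inner are not functorch APIs): the user-facing callable still accepts the original, possibly duplicated argument list, while the inner function, playing the role of CompiledFunction.apply, only ever sees deduped arguments and relies on add_dupe_args to rebuild the full list for flat_fn:

    from functools import wraps

    def make_compiled_function(flat_fn, remove_dupe_args, add_dupe_args):
        # inner stands in for CompiledFunction.apply: defined over deduped args only.
        def inner(*deduped_args):
            return flat_fn(*add_dupe_args(deduped_args))

        @wraps(inner)
        def compiled_function(*args):
            # Drop duplicates before handing off, mirroring
            # CompiledFunction.apply(*remove_dupe_args(args)) above.
            return inner(*remove_dupe_args(args))

        return compiled_function

    # f(x, y, z) called as f(a, b, a): dedupe to [a, b], rebuild to [a, b, a].
    f = lambda x, y, z: (x + z, y)
    g = make_compiled_function(
        f,
        remove_dupe_args=lambda args: [args[0], args[1]],
        add_dupe_args=lambda args: [args[0], args[1], args[0]],
    )
    assert g(1, 2, 1) == (2, 2)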
@@ -407,44 +470,9 @@ def create_aot_dispatcher_function(
 
     def process_inputs(flat_args):
         if mode:
-            seen_args = set()
-
-            def convert(x):
-                # HACK HACK HACK
-                # preserve the same behavior of the non-fake tensor branch
-                # of creating a unique tensor impl for each input,
-                # instead of memoizing the conversion. this has the same
-                # problem of models that resize their inputs described below,
-                # but fixes an issue with tied parameters.
-                # TODO: more full fix
-                if id(x) in seen_args:
-                    with torch.utils._mode_utils.no_dispatch():
-                        x = x.detach().requires_grad_(x.requires_grad)
-                seen_args.add(id(x))
-                return mode.from_tensor(x)
-
-            fake_flat_tensor_args = pytree.tree_map_only(Tensor, convert, flat_args)
+            return pytree.tree_map_only(Tensor, mode.from_tensor, flat_args)
         else:
-            # The detach().requires_grad_() pattern can cause some subtle bugs.
-            # These will be fixed once FakeTensor is always-on for AOTAutograd.
-            #
-            # For models that might resize their inputs, the input tensors
-            # must have allow_tensor_metadata_change() set to true.
-            # detach() returns a view tensor, but with that field set to false.
-            #
-            # Specifically, this breaks quantized models
-            # (resnet50_quantized_qat and mobilenet_v2_quantized_qat)
-            # because they use a "running-mean" style op that requires
-            # resizing the running counter buffers stored on the module.
-            def make_input(x):
-                return x.detach().requires_grad_(x.requires_grad)
-
-            fake_flat_tensor_args = pytree.tree_map_only(
-                Tensor,
-                make_input,
-                flat_args,
-            )
-        return fake_flat_tensor_args
+            return flat_args
 
     fake_flat_tensor_args = process_inputs(flat_args)
 
diff --git a/functorch/test/test_pythonkey.py b/functorch/test/test_pythonkey.py
index 5bb5e02a3f57..d2595a020aed 100644
--- a/functorch/test/test_pythonkey.py
+++ b/functorch/test/test_pythonkey.py
@@ -369,6 +369,32 @@ def compiler(fx_g, _):
         out.sum().backward()
         self.assertEqual(count, [(['forward'], 4), (['inference'], 4), (['backward'], 8)])
 
+    def test_dupe_arg(self):
+        def f(x, y):
+            return x + y
+
+        x = torch.randn(3, 3, requires_grad=True)
+        self.verify_aot_autograd(f, [x, x])
+
+    def test_resize_input(self):
+        def f(x, y):
+            y.resize_(4)
+            y.zero_()
+            self.assertEqual(x.shape, (4,))
+            return y
+
+        # NB: don't use verify_aot_autograd as the inputs get
+        # mutated and I don't trust verify to do it right
+
+        compiled_f = aot_function(f, nop)
+        ref_x = torch.randn(0)
+        ref_out = f(ref_x, ref_x)
+
+        test_x = torch.randn(0)
+        test_out = compiled_f(test_x, test_x)
+
+        self.assertEqual(ref_out, test_out)
+
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
     def test_batch_norm_amp(self):
         device = "cuda"
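A usage sketch mirroring the new test_dupe_arg test, assuming the functorch.compile entry points already used by test_pythonkey.py (aot_function and the nop compiler); it checks that passing the same tensor in both positions gives the same output and gradient as eager:

    import torch
    from functorch.compile import aot_function, nop

    def f(x, y):
        return x + y

    x = torch.randn(3, 3, requires_grad=True)
    compiled_f = aot_function(f, nop)

    ref = f(x, x)
    out = compiled_f(x, x)              # the same tensor object passed twice
    assert torch.allclose(ref, out)

    ref.sum().backward()
    ref_grad = x.grad.clone()           # all 2s: x contributes through both slots
    x.grad = None
    out.sum().backward()
    assert torch.allclose(x.grad, ref_grad)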