116 changes: 72 additions & 44 deletions functorch/_src/aot_autograd.py
@@ -286,7 +286,66 @@ def new_fn(args):


def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
    joint_forward_backward = create_joint_forward_backward(flat_fn)
    # Deduplicate inputs. Suppose you have:
Collaborator:

I don't think this should be done here. It should be done either in torchdynamo, or in aot_function (or create_aot_dispatcher_function) imo.

Contributor Author:

If you do it in aot_function you will miss the analogous bug with modules. create_aot_dispatcher_function is plausible, but you technically only need to do this in the autograd case so... meh?

See PR comment about torchdynamo.
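
The "analogous bug with modules" refers to tied parameters. A minimal, self-contained Python sketch (separate from this diff, with illustrative names) of the two ways duplicate flat arguments arise: calling with the same tensor twice, and a module whose two attributes refer to the same Parameter.

import torch
import torch.nn as nn

def f(x, y):
    return x + y

a = torch.randn(3, requires_grad=True)
dupe_args = [a, a]                 # f(x, x): the same tensor passed twice
assert dupe_args[0] is dupe_args[1]

class Tied(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(3))
        self.w2 = self.w1          # tied: both attributes are one Parameter

    def forward(self, x):
        return x * self.w1 + x * self.w2

mod = Tied()
flat_params = [mod.w1, mod.w2]     # a per-attribute flattening sees the dupe
assert flat_params[0] is flat_params[1]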

    #
    #   [a, b, a, c]
    #
    # We want:
    #
    #   remove_dupe_args([a, b, a, c]) == [a, b, c]
    #   add_dupe_args([a, b, c]) == [a, b, a, c]
    #
    # This is done via (respectively):
    #
    #   keep_arg_mask = [True, True, False, True]  # what to keep
    #   add_dupe_map = {  # how to get args from the deduped list
    #       0: 0,
    #       1: 1,
    #       2: 0,
    #       3: 2,
    #   }
    #
    # Whether to use flat_args or deduped_flat_args? flat_fn takes flat_args,
    # and the autograd.Function must take deduped_flat_args; everything
    # else is just getting the types right.

    seen_args = {}
    keep_arg_mask = []
    dropped_args = False
    add_dupe_map = {}
    duped_arg_len = len(flat_args)

    j = 0  # index into deduped_flat_args
    for i, t in enumerate(flat_args):
        if t in seen_args:
            keep_arg_mask.append(False)
            dropped_args = True
            add_dupe_map[i] = seen_args[t]
            continue
        keep_arg_mask.append(True)
        seen_args[t] = j
        add_dupe_map[i] = j
        j += 1

    # NB: Hot path, avoid set lookups here
    def remove_dupe_args(args):
        if not dropped_args:
            return args
        r = []
        for t, keep in zip(args, keep_arg_mask):
            if keep:
                r.append(t)
        return r

    def add_dupe_args(args):
        r = []
        for i in range(duped_arg_len):
            r.append(args[add_dupe_map[i]])
        return r

    deduped_flat_args = remove_dupe_args(flat_args)

    joint_forward_backward = create_joint_forward_backward(lambda *args: flat_fn(*add_dupe_args(args)))

    out = flat_fn(*flat_args)
    out = pytree.tree_map(
@@ -299,7 +358,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
    else:
        _num_outs = 1

    joint_inputs = (flat_args, out)
    joint_inputs = (deduped_flat_args, out)

    if config.use_functionalize:
        # Trace once without decompositions, into a graph of ATen ops.
@@ -333,10 +392,10 @@ def fake_fn(primals, tangents):
            bw_module.print_readable()

    with track_graph_compiling("forward"):
        compiled_fw_func = aot_config.fw_compiler(fw_module, flat_args)
        compiled_fw_func = aot_config.fw_compiler(fw_module, deduped_flat_args)

    if config.debug_partitioner:
        fw_outs = call_func_with_args(compiled_fw_func, flat_args)
        fw_outs = call_func_with_args(compiled_fw_func, deduped_flat_args)
        activation_sizes = 0
        for out in fw_outs[_num_outs:]:
            if isinstance(out, torch.Tensor):
@@ -350,9 +409,9 @@ class CompiledFunction(torch.autograd.Function):

        @staticmethod
        @disable_torchdynamo
        def forward(ctx, *flat_tensor_args):
        def forward(ctx, *deduped_flat_tensor_args):
            fw_outs = call_func_with_args(
                CompiledFunction.compiled_fw, flat_tensor_args
                CompiledFunction.compiled_fw, deduped_flat_tensor_args
            )
            num_outs = CompiledFunction.num_outs
            ctx.save_for_backward(*fw_outs[num_outs:])
@@ -375,7 +434,11 @@ def backward(ctx, *flat_args):

            return tuple(out)

    return CompiledFunction.apply
    @wraps(CompiledFunction.apply)
    def compiled_function(*args):
        return CompiledFunction.apply(*remove_dupe_args(args))

    return compiled_function
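
The dedup mapping described in the comment block above can be exercised in isolation. A standalone sketch that mirrors (rather than reuses) the helpers in this diff, rebuilt for the [a, b, a, c] example with the round-trip invariants checked:

import torch

a, b, c = torch.randn(1), torch.randn(1), torch.randn(1)
flat_args = [a, b, a, c]

seen, keep_arg_mask, add_dupe_map, j = {}, [], {}, 0
for i, t in enumerate(flat_args):
    if t in seen:                  # tensor hashing is identity-based
        keep_arg_mask.append(False)
        add_dupe_map[i] = seen[t]
        continue
    keep_arg_mask.append(True)
    seen[t] = j
    add_dupe_map[i] = j
    j += 1

def remove_dupe_args(args):
    return [t for t, keep in zip(args, keep_arg_mask) if keep]

def add_dupe_args(args):
    return [args[add_dupe_map[i]] for i in range(len(flat_args))]

deduped = remove_dupe_args(flat_args)
assert all(x is y for x, y in zip(deduped, [a, b, c]))                 # dupes dropped
assert all(x is y for x, y in zip(add_dupe_args(deduped), flat_args))  # original order restored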


@dynamo_timed
@@ -407,44 +470,9 @@ def create_aot_dispatcher_function(

    def process_inputs(flat_args):
        if mode:
            seen_args = set()

            def convert(x):
                # HACK HACK HACK
                # preserve the same behavior of the non-fake tensor branch
                # of creating a unique tensor impl for each input,
                # instead of memoizing the conversion. this has the same
                # problem of models that resize their inputs described below,
                # but fixes an issue with tied parameters.
                # TODO: more full fix
                if id(x) in seen_args:
                    with torch.utils._mode_utils.no_dispatch():
                        x = x.detach().requires_grad_(x.requires_grad)
                seen_args.add(id(x))
                return mode.from_tensor(x)

            fake_flat_tensor_args = pytree.tree_map_only(Tensor, convert, flat_args)
            return pytree.tree_map_only(Tensor, mode.from_tensor, flat_args)
        else:
            # The detach().requires_grad_() pattern can cause some subtle bugs.
            # These will be fixed once FakeTensor is always-on for AOTAutograd.
            #
            # For models that might resize their inputs, the input tensors
            # must have allow_tensor_metadata_change() set to true.
            # detach() returns a view tensor, but with that field set to false.
            #
            # Specifically, this breaks quantized models
            # (resnet50_quantized_qat and mobilenet_v2_quantized_qat)
            # because they use a "running-mean" style op that requires
            # resizing the running counter buffers stored on the module.
            def make_input(x):
                return x.detach().requires_grad_(x.requires_grad)

            fake_flat_tensor_args = pytree.tree_map_only(
                Tensor,
                make_input,
                flat_args,
            )
            return fake_flat_tensor_args
            return flat_args

    fake_flat_tensor_args = process_inputs(flat_args)
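
The removed comment above explains why the detach().requires_grad_() copies broke models that resize their inputs. A minimal sketch of that failure mode, independent of AOTAutograd and assuming the allow_tensor_metadata_change() behavior the comment describes:

import torch

x = torch.randn(0)
x.resize_(4)                 # a plain tensor may change its metadata in place
assert x.shape == (4,)

y = torch.randn(0)
y_copy = y.detach()          # detach() locks metadata changes on the result
try:
    y_copy.resize_(4)        # expected to fail per the comment above
except RuntimeError as err:
    print("resize_ on a detach()'d tensor raised:", err)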

26 changes: 26 additions & 0 deletions functorch/test/test_pythonkey.py
@@ -369,6 +369,32 @@ def compiler(fx_g, _):
        out.sum().backward()
        self.assertEqual(count, [(['forward'], 4), (['inference'], 4), (['backward'], 8)])

    def test_dupe_arg(self):
        def f(x, y):
            return x + y

        x = torch.randn(3, 3, requires_grad=True)
        self.verify_aot_autograd(f, [x, x])

    def test_resize_input(self):
        def f(x, y):
            y.resize_(4)
            y.zero_()
            self.assertEqual(x.shape, (4,))
            return y

        # NB: don't use verify_aot_autograd as the inputs get
        # mutated and I don't trust verify to do it right

        compiled_f = aot_function(f, nop)
        ref_x = torch.randn(0)
        ref_out = f(ref_x, ref_x)

        test_x = torch.randn(0)
        test_out = compiled_f(test_x, test_x)

        self.assertEqual(ref_out, test_out)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    def test_batch_norm_amp(self):
        device = "cuda"