AOTAutograd: avoid intermediate_base logic when all aliased outputs came from a multi_output_view #111411

Closed · wants to merge 5 commits
Changes from 2 commits
2 changes: 1 addition & 1 deletion aten/src/ATen/FunctionalStorageImpl.cpp
@@ -10,7 +10,7 @@ namespace at::functionalization {

ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
if (out_idx == this->out_index) return *this;
return ViewMeta(forward_fn, reverse_fn, out_idx);
return ViewMeta(forward_fn, reverse_fn, is_multi_output, out_idx);
}

// Note [Functionalization: Alias Removal Part 2]
7 changes: 6 additions & 1 deletion aten/src/ATen/FunctionalStorageImpl.h
@@ -31,16 +31,21 @@ struct ViewMeta {
ViewMeta(
std::function<Tensor(const Tensor&, int64_t)> forward,
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
bool is_multi_output = false,
int64_t out_idx = 0)
: forward_fn(std::move(forward)),
reverse_fn(std::move(reverse)),
out_index(out_idx) {}
out_index(out_idx),
is_multi_output(is_multi_output) {}

std::function<Tensor(const Tensor&, int64_t)> forward_fn;
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
// See Note [out_idx in ViewMeta]
int64_t out_index;

// Tells us if this is a multi-output view
bool is_multi_output;

// Returns a copy of the current ViewMeta, if out_idx matches the current
// out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
// functions, but a new out index.
3 changes: 2 additions & 1 deletion aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -135,7 +135,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
view_value.dtype(),
view_value.device()
),
value_(view_value)
value_(view_value),
is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
5 changes: 5 additions & 0 deletions aten/src/ATen/FunctionalTensorWrapper.h
@@ -126,6 +126,10 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
// replace_() swaps out the wrapped tensor, value_, with tmp.
void replace_(const Tensor& other);

bool is_multi_output_view() {
return is_multi_output_view_;
}

// See Note[resize_() in functionalization pass]
void maybe_replace_storage(const Tensor& other);

@@ -173,6 +177,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
Tensor value_;
int64_t level_;
bool has_metadata_mutation_ = false;
bool is_multi_output_view_ = false;

size_t generation_ = 0;
std::vector<at::functionalization::ViewMeta> view_metas_;
104 changes: 104 additions & 0 deletions test/functorch/test_aotdispatch.py
@@ -825,6 +825,110 @@ def forward(self, primals_1):
view = torch.ops.aten.view.default(mul, [-1]); mul = None
return [view]""")

def test_output_aliases_intermediate_multi_output_view(self):
# All aliased outs are from multi-output views, so AOTAutograd will hide the aliasing from autograd.
def f1(a):
out = torch.mul(a, 3)
return list(out.unbind(0))

inp = torch.ones(3, 3, requires_grad=True)
inp_ref = torch.ones(3, 3, requires_grad=True)
f1_compiled = aot_function(f1, nop)

out_ref = f1(inp_ref)
out_test = f1_compiled(inp)
# Assert that we get CompiledFunctionBackward in the backward graph,
# and not AsStridedBackward. No view-regeneration is necessary for this multi-output view case.
# See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
self.assertTrue(all('CompiledFunctionBackward' in str(o.grad_fn) for o in out_test))

sum(out_ref).sum().backward()
sum(out_test).sum().backward()
self.assertEqual(inp_ref.grad, inp.grad)

# All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd.
def f2(a):
out = torch.mul(a, 3)
return *list(out.unbind(0)), out

inp = torch.ones(3, 3, requires_grad=True)
inp_ref = torch.ones(3, 3, requires_grad=True)
f2_compiled = aot_function(f2, nop)

out_ref = f2(inp_ref)
out_test = f2_compiled(inp)
# Assert that we get CompiledFunctionBackward in the backward graph,
# and not AsStridedBackward. No view-regeneration is necessary for this multi-output view case.
# See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
self.assertTrue(all('CompiledFunctionBackward' in str(o.grad_fn) for o in out_test))

# The last output is not from a multi-output view, so autograd will let us mutate it.
out_ref[-1].mul_(2)
out_test[-1].mul_(2)
out_ref[-1].sum().backward()
out_test[-1].sum().backward()
self.assertEqual(inp_ref.grad, inp.grad)

# All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd.
def f3(a):
out = torch.mul(a, 3)
return *list(out.unbind(0)), out.view(out.shape)

inp = torch.ones(3, 3, requires_grad=True)
inp_ref = torch.ones(3, 3, requires_grad=True)
f3_compiled = aot_function(f3, nop)

out_ref = f3(inp_ref)
out_test = f3_compiled(inp)
# Assert that we get CompiledFunctionBackward in the backward graph,
# and not AsStridedBackward. No view-regeneration is necessary for this multi-output view case.
# See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
self.assertTrue(all('CompiledFunctionBackward' in str(o.grad_fn) for o in out_test))

# The last output is not from a multi-output view, so autograd will let us mutate it.
out_ref[-1].mul_(2)
out_test[-1].mul_(2)
out_ref[-1].sum().backward()
out_test[-1].sum().backward()
self.assertEqual(inp_ref.grad, inp.grad)

# There are 5 outputs that all alias each other.
# 3 of them come from multi-output views, but the other 2 are "ordinary" aliases.
# Therefore, AOTAutograd will not attempt the multi-output-view optimization,
# and apply the intermediate_base logic to all aliases.
# (In theory we could probably get AOTAutograd to only apply the intermediate base
# logic to the last 2 outputs and not the first 3. We should probably
# just do the graph partitioning defined in this doc instead though).
# https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit
def f4(a):
out = torch.mul(a, 3)
# also return the graph intermediate directly,
# which will force AOTAutograd to do the "intermediate base" logic.
# (Why? The user can mutate "out", which should change the autograd metadata
# of the other aliased outputs)
return *list(out.unbind(0)), out, out.view(out.shape)

inp = torch.ones(3, 3, requires_grad=True)
inp_ref = torch.ones(3, 3, requires_grad=True)
f4_compiled = aot_function(f4, nop)

out_ref = f4(inp_ref)
out_test = f4_compiled(inp)
# Mutate the last output of f4 (autograd will allow this, since it is not a multi-output view,
# as long as *only* the non-multi-output views participate in the backward)
# Note: We could probably try to hide **only** the multi-output views from autograd here
# and only do the intermediate base logic for the last two aliases.
# Longer term solution of graph partitioning is probably cleaner though (see the note).
out_ref[-1].mul_(2)
out_test[-1].mul_(2)

out_ref_sum = out_ref[-1] + out_ref[-2]
out_test_sum = out_test[-1] + out_test[-2]
out_ref_sum.sum().backward()
out_test_sum.sum().backward()
self.assertEqual(inp_ref.grad, inp.grad)


def test_output_aliases_intermediate_mutation_linear(self):
def f(x):
return (x + 1).view(-1)
78 changes: 77 additions & 1 deletion torch/_functorch/aot_autograd.py
@@ -1078,11 +1078,83 @@ def inner(*flat_args):

# Keep track of which outputs alias other outputs
out_tensor_alias_counts = collections.defaultdict(int)
# This tells us, for a given group of outputs that alias each other,
# whether they e.g. all came from an unbind call
num_aliased_tensors_that_are_multi_output_views = collections.defaultdict(int)
out_storage_to_tensors = collections.defaultdict(set)
for o in flat_f_outs:
if isinstance(o, torch.Tensor):
curr_storage = StorageWeakRef(o.untyped_storage())
out_tensor_alias_counts[curr_storage] += 1
# Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
# This is an optimization on top of the "alias of intermediates" logic,
# which you can read more about under Note [AOT Autograd: outputs aliasing inputs or intermediates!]
#
# What is this optimization? Consider the below case:
# def f(x):
# intermediate = x.mul(2)
# # x and intermediate here require grad
# o1, o2, ... o10 = intermediate.unbind(-1)
# return intermediate, o1, o2, ... o10
# Now, the "intermediate base" handling in AOTAutograd implies that we must do the following:
# (1) return "intermediate as an extra output of the compiled graph
# (2) regenerate each aliased output off of "intermediate", **outside** of the autograd.Function.
# The reason AOTAutograd ordinarily does this is for safety: the autograd engine needs to know
# that o1 through o10 are all aliased, and if we blindly return o1 through o10 from the autograd.Function,
# this information will be hidden.
# In particular, mutating one alias might require autograd to update autograd metadata on the other aliases
# (like their grad_fn, for example, when the autograd engine needs to do view-replay).
#
# However, intermediate_base logic can be bad for backward performance (we sometimes generate
# as_strided calls during the intermediate base logic, which can have a slow backward formula).
# Is it possible to find a set of conditions where it is **safe** to hide the output aliasing from autograd?
#
# For a set of outputs of the graph that alias each other, o_1...o_k, consider:
# (1) They came from the same multi-output view op, e.g. o_1, ..., o_k = intermediate.unbind(0)
bdhirsh marked this conversation as resolved.
# (2) If there are any other aliases of o_1 through o_k (in the example above, intermediate),
# **at most** 1 can escape from the graph (e.g. there is not some other graph input/output
Contributor

sorry, is this at most one can escape, ON TOP of the original multi-output view op's outputs o_1, ... ?

Contributor Author

yes - if I have K multi-output view aliases of some tensor (say they all came from an unbind call), then this comment is saying that at most K+1 aliases are allowed to escape the graph
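
A minimal sketch of that shape, just for concreteness (the function and shapes here are illustrative, not taken from this PR):

import torch

def f(a):
    intermediate = a.mul(2)              # requires grad
    o1, o2, o3 = intermediate.unbind(0)  # K = 3 multi-output view aliases
    # K + 1 = 4 aliases of the same storage escape the graph:
    # the three unbind outputs plus the intermediate itself.
    return intermediate, o1, o2, o3

outs = f(torch.ones(3, 3, requires_grad=True))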

# o_other, that aliases these outputs)
Contributor

I don't like this. I think condition (2) is too lax and it is making it hard for me to verify that this optimization is correct.

Here is a graph for which I think it is obviously correct to attach grad_fns to the outputs:

def f(x):
    y = x + 2
    return y.unbind()

Here is a graph where I am not sure:

def f(x):
    y = x + 2
    return y, y.unbind()

In the first graph, no outputs actually aliased each other (the unbind says they alias, but this is because autograd isn't smart enough to know that actually these views are guaranteed to reference disjoint spots of memory). In the second graph, the outputs DO alias, and so we have to make sure that if you mutate y, the grad_fns for y.unbind() get updated.

Maybe there is a way we can get the second case to work, but if the first case is good enough to solve the problem we're facing in prod, I would prefer to do it first, as it is much more obviously correct.

Contributor Author (@bdhirsh, Oct 21, 2023)

Yeah, after I made this PR I realized that the second case also shows up in prod sadly (returning both an intermediate, and its unbinded tensors).

So given that, I think our options are:

(1) convince ourselves that this is safe, even in the case where we return one extra alias (like the intermediate here).

(2) decide that this either violates safety, or this feels too dangerous to allow, and instead just add some one-off logic for the prod case to specifically handle regenerating unbind() at runtime.

Let me know what you think

# (3) o_1...o_k all require_grad, they all share the same ._base, and their ._base requires grad.
# This condition is important because it's what causes slowness in the intermediate_base
# codepath of aot_autograd. Ordinarily, o_1...o_k would all get a grad_fn, and
# aot_autograd's view-replay might give each output an AsStridedBackward as its grad_fn.
# "K" AsStridedBackward calls will be *much* slower than a single UnbindBackward.
# In this setup, is it possible to mutate one of the outputs o_i in a way that would affect the autograd meta
# of the other aliases?
#
# Claim: No! Consider a few examples (which I'm pretty sure cover all cases of mutation w.r.t. autograd):
# (a) What happens if we mutate any of o_1 through o_k directly?
# Autograd raises an error:
# "RuntimeError: Output 0 of UnbindBackward0 is a view and is being modified inplace. This view is
# the output of a function that returns multiple views. Such functions do not allow the output
# views to be modified inplace. You should replace the inplace operation by an out-of-place one."
# (b) What if we take a view of o_k and mutate it, o_k.view(o_k.shape).mul_(2)?
# Autograd raises the same error: the "multi-output-view"-ness of an alias propagates to future views.
# (c) What if we mutate o_k under no_grad?
# Autograd raises the same error
# (d) What if we detach and mutate, e.g. o_k.detach().mul_(2)?
# Autograd allows this, *but* it updates the grad_fn of every alias to be an error function,
# so the same error is raised as soon as any alias's grad_fn is accessed.
# (e) What if we try to mutate another alias of o_1...o_k, that was **not** created from a multi-output view?
# We promised that there is at most **one** such alias, e.g. intermediate in the example above.
# You can mutate intermediate, but in eager mode this will change the grad_fn of o_1...o_k
# to be error functions.
# Since intermediate was the *only* non-multi-output-alias, there are no other aliases
# of `intermediate` around that were produced by the compiled fn and have a valid grad_fn.
#
# Coming back to this optimization:
# Given that it is not possible for mutating one of these aliases to affect the autograd metadata of another alias
# without causing an error in eager mode, we will simply hide the aliasing from autograd during torch.compile
# if all of the above conditions are met.
# This has the slight downside that it's possible to write some "bad" code that autograd will raise an error on
# in eager mode but that torch.compile will fail to raise on; the benefit is that this code gets much better performance.
# NOTE: if and when we eventually update AOTAutograd to do the "view graph slicing" defined here:
# https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit,
# then this optimization will probably matter less and might be ok to remove.
is_cur_tensor_multi_out_view = isinstance(o, FunctionalTensor) \
and torch._functionalize_is_multi_output_view(o.elem)
if is_cur_tensor_multi_out_view:
num_aliased_tensors_that_are_multi_output_views[curr_storage] += 1
out_storage_to_tensors[curr_storage].add(o)

# maps the id of an intermediate base to its index in the output of the compiled forward
@@ -1125,7 +1197,11 @@ def inner(*flat_args):
and o.requires_grad
and o._base.requires_grad
):
if out_tensor_alias_counts[curr_storage] == 1:
num_aliased_outs = out_tensor_alias_counts[curr_storage]
num_multi_output_view_outs = num_aliased_tensors_that_are_multi_output_views[curr_storage]
num_aliased_outs_that_are_not_multi_output_views = num_aliased_outs - num_multi_output_view_outs
# Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
if out_tensor_alias_counts[curr_storage] == 1 or num_aliased_outs_that_are_not_multi_output_views <= 1:
Contributor

Can we add a check at line 1184, elif curr_storage in inp_storage_refs:? Sometimes we have multi_output_views that are alias_of_input; they would fall into that branch, and this would speed things up quite a bit.

Contributor Author

I agree that if we're going to apply this "hide multi-output view aliasing from autograd" to aliases of intermediates, we should also apply the same logic to aliases of inputs.

This is especially true because autograd's view-replay will generate equally inefficient backward code for the alias-of-input case, if our aliases are multi-output views (view-replay will always generate an as_strided).

Let me spend some time trying to add it and add better testing for it.
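
A minimal sketch of the alias-of-input case being described (illustrative only, not code from this PR):

import torch

def f(x):
    # Every output is a multi-output view of the graph *input* x, so these
    # hit the alias-of-input path rather than the alias-of-intermediate path
    # that this PR optimizes.
    return x.unbind(0)

outs = f(torch.ones(3, 3, requires_grad=True))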

# Note [Intermediate Bases Optimization]
# Normally if we have an output that aliases an intermediate,
# we need to add the extra "intermediate base" logic further down
26 changes: 26 additions & 0 deletions torch/csrc/autograd/python_torch_functions_manual.cpp
@@ -535,6 +535,27 @@ static PyObject* THPVariable__functionalize_enable_reapply_views(
END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__functionalize_is_multi_output_view(
PyObject* self,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser(
{"_functionalize_is_multi_output_view(Tensor t)"},
/*traceable=*/true);
ParsedArgs<1> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
auto t = r.tensor(0);
TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(t));
auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
if (t_impl->is_multi_output_view()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__disable_functionalization(
PyObject* self,
PyObject* args,
@@ -664,6 +685,11 @@ static PyMethodDef torch_functions_manual[] = {
THPVariable__functionalize_has_metadata_mutation),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_functionalize_is_multi_output_view",
castPyCFunctionWithKeywords(
THPVariable__functionalize_is_multi_output_view),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_functionalize_enable_reapply_views",
castPyCFunctionWithKeywords(
THPVariable__functionalize_enable_reapply_views),
4 changes: 3 additions & 1 deletion torchgen/gen_functionalization_type.py
@@ -377,6 +377,7 @@ def emit_view_functionalization_body(
"""

else:
is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
Contributor

nyeh, I think you have to be more selective about this. If hypothetically you have a multi-output view where the outputs alias each other, e.g., some sort of sliding window + unbind, then it wouldn't be right to do the logic you've introduced here, because if you did a mutation it would be necessary to error to prevent incorrect gradients from the other rules

Contributor Author

because if you did a mutation it would be necessary to error to prevent incorrect gradients from the other rules

I agree that this is more complicated in the overlapping case (Alban pointed out that even the unbind case can be overlapping, e.g. torch.ones(4).expand(10, 4).unbind(0)).

I would agree with this, but my thought was that autograd should never let this happen, because it doesn't allow you to change the grad_fn of multi-output views (it raises an error, or replaces the multi-output-view grad_fn with an error grad_fn).

I agree this does feel hairy though, so lmk if you think this isn't water-tight / we're risking some correctness issues slipping through.
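
To make the overlapping-unbind example concrete (a small sketch, not part of this PR's diff):

import torch

base = torch.ones(4).expand(10, 4)  # all 10 rows share the same 4 elements
views = base.unbind(0)              # 10 multi-output-view outputs
# Every output aliases the exact same storage location, so they fully overlap.
assert all(v.data_ptr() == views[0].data_ptr() for v in views)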

Contributor

You're basically saying that a user will never pass torch.compile a program that doesn't run in eager mode. If this is true, I agree that the user cannot have written a naughty program. But we definitely have users passing malformed programs to torch.compile without having checked the eager version first...
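
For reference, a sketch of the eager-mode behavior the claim leans on (the error text is the one quoted in the aot_autograd.py note above; whether the compiled program also raises depends on this optimization):

import torch

x = torch.ones(3, 3, requires_grad=True)
o1, o2, o3 = x.mul(2).unbind(0)
try:
    o1.mul_(2)  # in-place mutation of a multi-output view's output
except RuntimeError as e:
    # Eager autograd raises: "Output 0 of UnbindBackward0 is a view and is
    # being modified inplace. ..." Under torch.compile with the aliasing
    # hidden from autograd, this error may not be raised.
    print(e)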

return f"""
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
{unwrap_tensor_args_str}
Expand Down Expand Up @@ -415,7 +416,8 @@ def emit_view_functionalization_body(
}},
{reverse_lambda.decl()} {{
return {reverse_lambda.inner_call()}
}}
}},
/*is_multi_output=*/{str(is_multi_output_view).lower()}
);
auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
// See Note [Propagating strides in the functionalization pass]