functionalization: add a copy() native function #76083

Closed
wants to merge 9 commits
19 changes: 19 additions & 0 deletions aten/src/ATen/native/Copy.cpp
@@ -2,6 +2,7 @@

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
+#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/quantized/Copy.h>
@@ -242,6 +243,24 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
return self;
}

+Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
+// copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but:
+// (1) It isn't exposed to the frontend (no python bindings)
+// (2) It isn't exposed to the backend (it's a composite that decomposes into to() and expand_as() calls).
+// Note: This implementation doesn't currently preserve the strides of `self`.
+// That might be fine for functorch (which already doesn't preserve strides in vmap),
+// but it's worth looking into whether or not this implementation will be problematic for LazyTensor/XLA.
+auto intermediate = src.to(self, non_blocking);
+// Unfortunately, copy()'s decomposition involves view ops.
+// To preserve the functionalization pass semantics of "maybe reapply views",
+// we need to manually do that here.
+if (at::functionalization::impl::getFunctionalizationReapplyViewsTLS()) {
+return intermediate.expand(self.sizes());
+} else {
+return at::expand_copy(intermediate, self.sizes());
+}
+}
+
Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
{
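For reference, a minimal standalone sketch of what the decomposition above computes at the user level, assuming a libtorch build; the shapes, dtypes, and main() wrapper are purely illustrative and not part of this change:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor self = at::zeros({2, 3});           // destination: float32, shape [2, 3]
  at::Tensor src  = at::ones({3}, at::kDouble);  // source: float64, broadcastable shape [3]

  // Functional equivalent of self.copy_(src): convert src to self's dtype/device,
  // then broadcast to self's shape. Note that self's strides are not preserved.
  at::Tensor intermediate = src.to(self, /*non_blocking=*/false);
  at::Tensor result = intermediate.expand(self.sizes());

  std::cout << result.sizes() << " " << result.dtype() << std::endl;  // [2, 3], float
  return 0;
}

The only difference between the two branches in copy() above is aliasing: expand() returns a view of the intermediate, while expand_copy() materializes a fresh tensor with the same values.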
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1373,6 +1373,9 @@

- func: conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int groups=1, int[3] dilation=1) -> Tensor

+- func: copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
+variants: function
+
- func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
variants: method
device_check: NoCheck
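As a usage note, the schema above (function variant, no dispatch entries) should make torchgen emit a composite at::copy() in the C++ API, while the op stays hidden from Python per the gen_python_functions.py change below. A rough sketch of a call site, with shapes chosen only for illustration:

#include <ATen/ATen.h>

int main() {
  at::Tensor self = at::empty({2, 2});
  at::Tensor src  = at::randn({2, 2});

  // Out-of-place counterpart of self.copy_(src): returns a fresh tensor carrying
  // src's values in self's dtype/device/shape, leaving self untouched.
  at::Tensor out = at::copy(self, src, /*non_blocking=*/false);
  return 0;
}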
1 change: 1 addition & 0 deletions tools/autograd/gen_python_functions.py
@@ -152,6 +152,7 @@
"_reshape_alias",
"replace_", # only used by the functionalization pass, doesn't need to be exposed to python
"zero", # only used by the functionalization pass, doesn't need to be exposed to python
"copy", # only used by the functionalization pass
]

SKIP_PYTHON_BINDINGS = list(
19 changes: 2 additions & 17 deletions torchgen/gen_functionalization_type.py
@@ -381,21 +381,7 @@ def emit_inplace_functionalization_body(
for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
]

-# Note [functionalizating copy_() and not preserving strides]
-# copy_() can't be functionalized, since there doesn't exist an out-of-place variant.
-# We could add one, but that would be sub-optimal for functorch: copy() would need to allocate a fresh tensor.
-# This may seem like a large hack for one optimization, but copy_() is one of the most common inplace operators.
-# Instead, we can replace `self.copy_(src)` with `src.to(self).expand_as(self)`.
-# This maintains the exact same semantics, EXCEPT that we don't preserve the strides from `self`.
-# This seems like a reasonable tradeoff, for a few reasons:
-# - mutation removal is only used by functorch, and not by Vulkan or XLA. Functorch already doesn't preserve strides.
-# - There are actually a few other places where the functionalization pass currently doesn't support strides:
-# calls to slice/diagonal_scatter don't currently preserve the strides of their inputs (but maybe we should fix this).
-if str(f.func.name) == "copy_":
-functional_call_str = """\
-auto tmp_intermediate = at::_ops::to_other::call(src_, self_, non_blocking, false, c10::nullopt);
-tmp_output = at::_ops::expand_copy::call(tmp_intermediate, self_.sizes(), false);"""
-elif functional_op is None:
+if functional_op is None:
# We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
warn_str = f"""Note: the functionalization pass encountered an operator ({str(f.func.name)}) that it could not \
functionalize, because it couldn't find an out-of-place equivalent of the operator to call. \
@@ -423,7 +409,6 @@ def emit_inplace_functionalization_body(
unwrapped_args_ctx, functional_sig.arguments(), method=False
)
]
-functional_call_str = f"tmp_output = at::_ops::{functional_op.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});" # noqa: B950

if f.func.is_out_fn():
mutable_input_post_processing = "\n".join(
@@ -467,7 +452,7 @@
{return_type} tmp_output;
{{
at::AutoDispatchSkipFunctionalize guard;
-{functional_call_str}
+tmp_output = at::_ops::{functional_op.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
}}
{mutable_input_post_processing}
{return_str(f)};
Expand Down