functionalization: add a copy() native function
Pull Request resolved: #76083

Approved by: https://github.com/albanD
bdhirsh authored and pytorchmergebot committed Apr 25, 2022
1 parent 7d44b36 commit 5da76ac
Showing 4 changed files with 25 additions and 17 deletions.
19 changes: 19 additions & 0 deletions aten/src/ATen/native/Copy.cpp
@@ -2,6 +2,7 @@

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/quantized/Copy.h>
@@ -242,6 +243,24 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
return self;
}

Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
// copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but:
// (1) It isn't exposed to the frontend (no python bindings)
// (2) It isn't exposed to the backend (it's a composite that decomposes into to() and expand_as() calls).
// Note: This implementation doesn't currently preserve the strides of `self`.
// That might be fine for functorch (which already doesn't preserve strides in vmap),
// but it's worth looking into whether or not this implementation will be problematic for LazyTensor/XLA.
auto intermediate = src.to(self, non_blocking);
// Unfortunately, copy()'s decomposition involves view ops.
// To preserve the functionalization pass semantics of "maybe reapply views",
// we need to manually do that here.
if (at::functionalization::impl::getFunctionalizationReapplyViewsTLS()) {
return intermediate.expand(self.sizes());
} else {
return at::expand_copy(intermediate, self.sizes());
}
}

Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
{
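A minimal sketch (not part of this commit, assuming standard torch semantics) of the decomposition the comment above describes: `self.copy_(src)` produces the same values as `src.to(self).expand_as(self)`, modulo the strides of `self`.

import torch

self_t = torch.empty(2, 3, dtype=torch.float64)
src = torch.arange(3, dtype=torch.float32)      # broadcastable against self_t

# In-place reference behaviour: copy_() broadcasts src and converts its dtype.
expected = self_t.clone().copy_(src)

# The functional decomposition used by copy(): convert dtype/device, then broadcast.
functional = src.to(self_t).expand_as(self_t)

assert torch.equal(expected, functional)
# Caveat from the comment above: `functional` is a broadcast view (stride 0 in dim 0),
# so the original strides of `self_t` are not preserved.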
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1373,6 +1373,9 @@

- func: conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int groups=1, int[3] dilation=1) -> Tensor

- func: copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
variants: function

- func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
variants: method
device_check: NoCheck
1 change: 1 addition & 0 deletions tools/autograd/gen_python_functions.py
@@ -152,6 +152,7 @@
"_reshape_alias",
"replace_", # only used by the functionalization pass, doesn't need to be exposed to python
"zero", # only used by the functionalization pass, doesn't need to be exposed to python
"copy", # only used by the functionalization pass
]

SKIP_PYTHON_BINDINGS = list(
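A hedged sketch of what this skip-list entry means in practice: no `torch.copy` / `Tensor.copy` Python binding is generated, but (assuming it behaves like the other binding-skipped ops such as `zero`) the registered ATen operator is still reachable through the dispatcher namespace.

import torch

a = torch.zeros(2, 3)
b = torch.ones(3)

# No torch.copy / Tensor.copy binding exists, but the operator itself is registered,
# so it can still be invoked via torch.ops.
out = torch.ops.aten.copy(a, b)

assert torch.equal(out, torch.ones(2, 3))   # src broadcast and written into a fresh result
assert torch.equal(a, torch.zeros(2, 3))    # `a` is untouched: copy() is out-of-place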
19 changes: 2 additions & 17 deletions torchgen/gen_functionalization_type.py
@@ -381,21 +381,7 @@ def emit_inplace_functionalization_body(
for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
]

# Note [functionalizating copy_() and not preserving strides]
# copy_() can't be functionalized, since there doesn't exist an out-of-place variant.
# We could add one, but that would be sub-optimal for functorch: copy() would need to allocate a fresh tensor.
# This may seem like a large hack for one optimization, but copy_() is one of the most common inplace operators.
# Instead, we can replace `self.copy_(src)` with `src.to(self).expand_as(self)`.
# This maintains the exact same semantics, EXCEPT that we don't preserve the strides from `self`.
# This seems like a reasonable tradeoff, for a few reasons:
# - mutation removal is only used by functorch, and not by Vulkan or XLA. Functorch already doesn't preserve strides.
# - There are actually a few other places where the functionalization pass currently doesn't support strides:
# calls to slice/diagonal_scatter don't currently preserve the strides of their inputs (but maybe we should fix this).
if str(f.func.name) == "copy_":
functional_call_str = """\
auto tmp_intermediate = at::_ops::to_other::call(src_, self_, non_blocking, false, c10::nullopt);
tmp_output = at::_ops::expand_copy::call(tmp_intermediate, self_.sizes(), false);"""
elif functional_op is None:
if functional_op is None:
# We can't functionalize this inplace op, since we don't know what the corresponding functional op is.
warn_str = f"""Note: the functionalization pass encountered an operator ({str(f.func.name)}) that it could not \
functionalize, because it couldn't find an out-of-place equivalent of the operator to call. \
@@ -423,7 +409,6 @@ def emit_inplace_functionalization_body(
unwrapped_args_ctx, functional_sig.arguments(), method=False
)
]
functional_call_str = f"tmp_output = at::_ops::{functional_op.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});" # noqa: B950

if f.func.is_out_fn():
mutable_input_post_processing = "\n".join(
@@ -467,7 +452,7 @@ def emit_inplace_functionalization_body(
{return_type} tmp_output;
{{
at::AutoDispatchSkipFunctionalize guard;
{functional_call_str}
tmp_output = at::_ops::{functional_op.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
}}
{mutable_input_post_processing}
{return_str(f)};
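With the hand-written copy_() special case removed, the generic codegen path above treats copy_ like any other in-place op and calls its new out-of-place counterpart. A hedged sketch of observing this, using today's torch.func / make_fx APIs (which post-date this commit):

import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    tmp = torch.zeros_like(x)
    tmp.copy_(y)          # in-place op on an intermediate tensor
    return tmp + x

gm = make_fx(functionalize(f))(torch.zeros(2, 3), torch.ones(3))
# The traced graph contains no mutation of the intermediate: the copy_ has been
# replaced by a functional equivalent (aten.copy, or its to()/expand decomposition,
# depending on decomposition settings).
print(gm.graph)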
