New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Removed overhead from reshape() call if tensor doesn't need to be changed #61466
Conversation
…nged [ghstack-poisoned]
💊 CI failures summary and remediationsAs of commit 17fedac (more details on the Dr. CI page and at hud.pytorch.org/pr/61466):
🕵️ 1 new failure recognized by patternsThe following CI failures do not appear to be due to upstream breakages: pytorch_macos_10_13_py3_test (1/1)Step: "Test" (full log | diagnosis details | 🔁 rerun)
|
Job | Step | Action |
---|---|---|
Test tools / test | Install dependencies | 🔁 rerun |
Preview docs built from this PR
This comment was automatically generated by Dr. CI (expand for details).
Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions to the (internal) Dr. CI Users group.
This PR addresses #55126 and improves the performance of Before the change on a test dataset, Further investigation is required to understand why Before the change:
After the changes we can see good improvements where
|
cc @albanD, likely pytorch/tools/autograd/derivatives.yaml Lines 1136 to 1140 in 24a8915
.view , it will be impossible to backward through reshape op.
|
…d to be changed" [ghstack-poisoned]
…nged ghstack-source-id: 1342caf0a4703cc7e07da2f18c4ceebb51e581ff Pull Request resolved: #61466
# _unsafe_reshape_alias is used in cases where reshape can be performed as | ||
# a view operation and we want to preserve gradients. | ||
|
||
- func: _unsafe_reshape_alias(Tensor(a) self, int[] size, int[] strides) -> Tensor(a) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you call this one "_unsafe" ?
It should be an actual view op and be added to the list in tools/autograd/gen_view_or_inplace.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_unsafe_reshape_alias
essentially does the same thing as view
does but without any of the checks to avoid recalculating shape and computing strides.
The original implementation I did to skip the view
overhead lost the autograd support (it just called alias_with_sizes_and_strides
directly from the reshape
operator instead of dispatching). So this is a first attempt at preserving autograd while still removing the unnecessary overhead of calculating strides (see #55126 for the motivation).
Is there a doc I can look at to find out more about view ops versus regular ops?
I don't expect this function to be called by end users or outside of the reshape
operator since it's skipping bounds checking.
…d to be changed" Trying a different approach to achieve a performance gain while retaining the ability to do autograd on the results when possible. ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f45dd86b070> x.view(-1); setup: at::Tensor x=torch::empty({2,2} ); Median: 583.27 ns IQR: 39.04 ns (563.48 to 602.52) 852 measurements, 10000 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f45dd936d30> x.reshape(-1); setup: at::Tensor x=torch::empty({2,2} ); Median: 383.82 ns IQR: 19.56 ns (372.88 to 392.45) 130 measurements, 100000 runs per measurement, 1 thread] 56896 41036 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f45b138cd00> 2000 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1860 ???:torch::autograd::VariableType::(anonymous namespace)::_unsafe_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1660 ???:at::_ops::_unsafe_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1600 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1240 ???:at::_ops::_unsafe_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&) 520 ???:at::shouldRunRecordFunction(bool*) 360 ???:c10::TensorImpl::~TensorImpl() ... -1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>) -1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -1640 ???:c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_() -1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&) -1740 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -2080 ???:torch::autograd::make_variable_differentiable_view(at::Tensor const&, c10::optional<torc ... ewInfo>, c10::optional<torch::autograd::ViewInfo>, bool, torch::autograd::CreationMeta, bool) -2180 ???:torch::autograd::as_view(at::Tensor const&, at::Tensor const&, bool, bool, std::function<at::Tensor (at::Tensor const&)>, torch::autograd::CreationMeta, bool) -2240 ???:torch::autograd::DifferentiableViewMeta::DifferentiableViewMeta(c10::TensorImpl*, c10::o ... ad::ViewInfo>, c10::optional<torch::autograd::ViewInfo>, bool, torch::autograd::CreationMeta) -2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) Total: -15860 ``` [ghstack-poisoned]
…nged ghstack-source-id: cb653b74e22d5b89dc8a6a715d95c3187cfbcaed Pull Request resolved: #61466
Tidied up the changes per the comments from @albanD and @ngimel. Created a new view operator I've tested that this works as expected for handling/propagating gradients and the performance is now within 24% of To further speed up
|
…d to be changed" Trying a different approach to achieve a performance gain while retaining the ability to do autograd on the results when possible. ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f45dd86b070> x.view(-1); setup: at::Tensor x=torch::empty({2,2} ); Median: 583.27 ns IQR: 39.04 ns (563.48 to 602.52) 852 measurements, 10000 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f45dd936d30> x.reshape(-1); setup: at::Tensor x=torch::empty({2,2} ); Median: 383.82 ns IQR: 19.56 ns (372.88 to 392.45) 130 measurements, 100000 runs per measurement, 1 thread] 56896 41036 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f45b138cd00> 2000 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1860 ???:torch::autograd::VariableType::(anonymous namespace)::_unsafe_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1660 ???:at::_ops::_unsafe_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1600 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1240 ???:at::_ops::_unsafe_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&) 520 ???:at::shouldRunRecordFunction(bool*) 360 ???:c10::TensorImpl::~TensorImpl() ... -1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>) -1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -1640 ???:c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_() -1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&) -1740 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -2080 ???:torch::autograd::make_variable_differentiable_view(at::Tensor const&, c10::optional<torc ... ewInfo>, c10::optional<torch::autograd::ViewInfo>, bool, torch::autograd::CreationMeta, bool) -2180 ???:torch::autograd::as_view(at::Tensor const&, at::Tensor const&, bool, bool, std::function<at::Tensor (at::Tensor const&)>, torch::autograd::CreationMeta, bool) -2240 ???:torch::autograd::DifferentiableViewMeta::DifferentiableViewMeta(c10::TensorImpl*, c10::o ... ad::ViewInfo>, c10::optional<torch::autograd::ViewInfo>, bool, torch::autograd::CreationMeta) -2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) Total: -15860 ``` [ghstack-poisoned]
…nged ghstack-source-id: e037b70866e2178513f424fbea4927f989553b02 Pull Request resolved: #61466
Tidied up the overview message with the latest details (see the first message), and made changes per comments in discussion with @albanD. Originally we had discussed simply delegating to I ended up implementing a private/internal operator that is used only by |
…d to be changed" ## Goal Per #55126 the performance of `reshape` is worse than `alias` in cases where they are performing the same operation (i.e. where reshape is returning a view) because `reshape` delegates to `view` and duplicates some of the operations (specifically `infer_size_dv` and `computeStride`). The goal of this pull-request is to reduce or remove the additional overhead that `reshape` has. ### Proposed Implementation Instead of using `view` we implement a private/internal operator (`_reshape_alias`) that `reshape` dispatches to which skips the relevant checks. This is functionally equivalent to `as_strided` however it is a lot simpler because it's specialized to this use-case, and importantly the `backward` implementation is a lot faster. Note that we have to dispatch (`reshape` is a composite operator) because `reshape` can return either a view or a copy of the Tensor depending on the parameters, and this complicates implementing a derivative/backward for `reshape`. ### Why not `as_strided`? Using `as_strided` directly slows down autograd. If we use a custom function equivalent to `_reshape_alias` but with a simpler backward function then `view` has the same performance as `reshape`. If we delegate to `as_strided` it is about 56% slower (and this holds against our custom function). This is also the reason we make an internal operator named `_reshape_alias` instead of exposing a new operator since this should only be used in the `reshape` case and it is effectively a more limited version of `view`, `alias`, and `as_strided`. ## Benchmarks In a micro-benchmark for `backward` running: ```cpp // Setup at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); // Benchmark loop // `reshape(-1)` replaced with a call to view(-1) for view baseline x.pow(4).reshape(-1).mean().backward(); ``` I also benchmarked simple operations without gradients using: ```cpp // Setup at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); // Benchmark loop x.reshape(-1) // replaced with a call to view(-1) for view baseline ``` Baselined to `view`: * Original `reshape`: `+3.3%` (without gradients `+20.8%`) * Using `as_strided`: `+55.1%` (without gradients `+1.0%`) * Using custom `_reshape_view`: `-1.0%` (without gradients `+6.2%`) In absolute terms (note the percentages above were generated comparing between runs/tests rather than to a single baseline): * Original `view`: `53.66 us` (without gradients `582.78 ns`) * Original `reshape`: `55.46 us` (without gradients `704.24 ns`) * Using `as_strided`: `83.24 us` (without gradients `576.49 ns`) * Using custom `_reshape_view`: `53.13 us` (without gradients `536.01 ns`) Note that these benchmarks perform a backwards operation as well. When compared without using gradient computation at all the performance differneces are more pronounced as this takes up more of the time. ### Original performance <details> <summary>Benchmark results</summary> ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f0e4d393160> x.pow(4).view(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 53.66 us IQR: 2.70 us (52.54 to 55.24) 884 measurements, 100 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f0e2ebd4fa0> x.pow(4).reshape(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 55.46 us IQR: 2.61 us (54.39 to 57.01) 889 measurements, 100 runs per measurement, 1 thread] 2276116 2286256 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f0e5b2e3e20> 2640 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&) 1920 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1040 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long>&&) 980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&) 720 ???:__tls_get_addr 520 ???:at::shouldRunRecordFunction(bool*) 520 ???:__memcpy_avx_unaligned_erms 200 ???:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10:: ... g>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) 100 ???:c10::TensorImpl::strides() const 100 ???:c10::TensorImpl::sizes() const 100 ???:at::(anonymous namespace)::manager() 77 /tmp/benchmark_utils_jit_build__1626465284__8a34e7ff-cd37-4a82-be28-7f19e081e771/timer_cpp_7815557938202456331/timer_src.cpp:main 40 ???:c10::TensorImpl::numel() const -77 /tmp/benchmark_utils_jit_build__1626465284__8a34e7ff-cd37-4a82-be28-7f19e081e771/timer_cpp_8055217880649990171/timer_src.cpp:main -260 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) Total: 10140 ``` ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f850dd66c10> x.view(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 582.78 ns IQR: 33.80 ns (573.80 to 607.61) 833 measurements, 10000 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f850de31e20> x.reshape(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 704.24 ns IQR: 24.42 ns (697.20 to 721.62) 679 measurements, 10000 runs per measurement, 1 thread] 56896 67036 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f84e1930bb0> 2640 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&) 1920 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1040 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long>&&) 980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&) 720 ???:__tls_get_addr 520 ???:at::shouldRunRecordFunction(bool*) 520 ???:__memcpy_avx_unaligned_erms 200 ???:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10:: ... g>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) 100 ???:c10::TensorImpl::strides() const 100 ???:c10::TensorImpl::sizes() const 100 ???:at::(anonymous namespace)::manager() 76 /tmp/benchmark_utils_jit_build__1626466038__15fbbac0-2072-4459-8f8e-08121a905b99/timer_cpp_547407365342278353/timer_src.cpp:main 40 ???:c10::TensorImpl::numel() const -76 /tmp/benchmark_utils_jit_build__1626466038__15fbbac0-2072-4459-8f8e-08121a905b99/timer_cpp_3457873755756181226/timer_src.cpp:main -260 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) Total: 10140 ``` </details> ### Using `as_strided` <details> <summary>Benchmark results</summary> ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f8b13bb5b50> x.pow(4).view(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 53.37 us IQR: 3.15 us (51.73 to 54.88) 936 measurements, 100 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f8af55f8490> x.pow(4).reshape(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 83.24 us IQR: 4.05 us (81.20 to 85.25) 609 measurements, 100 runs per measurement, 1 thread] 2267916 2525061 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f8af55f8e50> 31930 ???:_int_free 15940 ???:malloc 11595 ???:_int_malloc 10100 ???:torch::autograd::generated::details::as_strided_backward(at::Tensor, at::TensorGeometry, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 9360 ???:__tls_get_addr 8280 ???:free 8100 ???:torch::autograd::VariableType::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 4520 ???:c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_() 4080 ???:operator new(unsigned long) ... -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&) -1220 ???:torch::autograd::generated::ViewBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) -1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>) -1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&) -2560 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&) -2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) -4860 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) Total: 257145 ``` ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f93176a0160> x.view(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 570.55 ns IQR: 32.69 ns (552.87 to 585.56) 874 measurements, 10000 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f92f8f29490> x.reshape(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 576.49 ns IQR: 37.95 ns (559.51 to 597.46) 861 measurements, 10000 runs per measurement, 1 thread] 56896 58556 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f932556ca60> 2140 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1940 ???:torch::autograd::VariableType::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 1880 ???:torch::ADInplaceOrView::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 1720 ???:at::_ops::as_strided::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1400 ???:at::native::as_strided_tensorimpl(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 1260 ???:at::_ops::as_strided::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)'2 1260 ???:at::_ops::as_strided::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&) ... -620 ???:at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, c10::ArrayRef<long ... ::ArrayRef<long>)> const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) const -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2 -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&) -1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>) -1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&) -1740 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) Total: 1660 ``` </details> ### Using custom function (`_reshape_alias`) <details> <summary>Benchmark results</summary> ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f16861d6b50> x.pow(4).view(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 53.50 us IQR: 2.64 us (52.32 to 54.96) 906 measurements, 100 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f1667b2ed60> x.pow(4).reshape(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 53.13 us IQR: 3.40 us (51.72 to 55.13) 914 measurements, 100 runs per measurement, 1 thread] 2269736 2273236 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f1693f8dc10> 5060 ???:torch::autograd::VariableType::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 2000 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1780 ???:torch::ADInplaceOrView::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1660 ???:at::_ops::_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1600 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)'2 1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1220 ???:torch::autograd::generated::AliasToShapeBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) ... -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2 -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&) -1220 ???:torch::autograd::generated::ViewBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) -1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>) -1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&) -2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) -4860 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) Total: 3500 ``` ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f5287adfb20> x.view(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 505.10 ns IQR: 20.04 ns (500.41 to 520.45) 944 measurements, 10000 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f526951b430> x.reshape(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 536.01 ns IQR: 17.81 ns (531.34 to 549.16) 916 measurements, 10000 runs per measurement, 1 thread] 56896 60376 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f5295896c10> 2000 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1860 ???:torch::autograd::VariableType::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1780 ???:torch::ADInplaceOrView::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1660 ???:at::_ops::_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1600 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)'2 1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&) ... -620 ???:at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, c10::ArrayRef<long ... ::ArrayRef<long>)> const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) const -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2 -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&) -1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>) -1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&) -1740 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) Total: 3480 ``` </details> [ghstack-poisoned]
…nged ghstack-source-id: 6b500b14d77fd90e90f180168de5772f28b4fadb Pull Request resolved: #61466
…d to be changed" ## Goal Per #55126 the performance of `reshape` is worse than `alias` in cases where they are performing the same operation (i.e. where reshape is returning a view) because `reshape` delegates to `view` and duplicates some of the operations (specifically `infer_size_dv` and `computeStride`). The goal of this pull-request is to reduce or remove the additional overhead that `reshape` has. ### Proposed Implementation Instead of using `view` we implement a private/internal operator (`_reshape_alias`) that `reshape` dispatches to which skips the relevant checks. This is functionally equivalent to `as_strided` however it is a lot simpler because it's specialized to this use-case, and importantly the `backward` implementation is a lot faster. Note that we have to dispatch (`reshape` is a composite operator) because `reshape` can return either a view or a copy of the Tensor depending on the parameters, and this complicates implementing a derivative/backward for `reshape`. ### Why not `as_strided`? Using `as_strided` directly slows down autograd. If we use a custom function equivalent to `_reshape_alias` but with a simpler backward function then `view` has the same performance as `reshape`. If we delegate to `as_strided` it is about 56% slower (and this holds against our custom function). This is also the reason we make an internal operator named `_reshape_alias` instead of exposing a new operator since this should only be used in the `reshape` case and it is effectively a more limited version of `view`, `alias`, and `as_strided`. ## Benchmarks In a micro-benchmark for `backward` running: ```cpp // Setup at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); // Benchmark loop // `reshape(-1)` replaced with a call to view(-1) for view baseline x.pow(4).reshape(-1).mean().backward(); ``` I also benchmarked simple operations without gradients using: ```cpp // Setup at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); // Benchmark loop x.reshape(-1) // replaced with a call to view(-1) for view baseline ``` Baselined to `view`: * Original `reshape`: `+3.3%` (without gradients `+20.8%`) * Using `as_strided`: `+55.1%` (without gradients `+1.0%`) * Using custom `_reshape_view`: `-1.0%` (without gradients `+6.2%`) In absolute terms (note the percentages above were generated comparing between runs/tests rather than to a single baseline): * Original `view`: `53.66 us` (without gradients `582.78 ns`) * Original `reshape`: `55.46 us` (without gradients `704.24 ns`) * Using `as_strided`: `83.24 us` (without gradients `576.49 ns`) * Using custom `_reshape_view`: `53.13 us` (without gradients `536.01 ns`) Note that these benchmarks perform a backwards operation as well. When compared without using gradient computation at all the performance differneces are more pronounced as this takes up more of the time. ### Original performance <details> <summary>Benchmark results</summary> ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f0e4d393160> x.pow(4).view(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 53.66 us IQR: 2.70 us (52.54 to 55.24) 884 measurements, 100 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f0e2ebd4fa0> x.pow(4).reshape(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 55.46 us IQR: 2.61 us (54.39 to 57.01) 889 measurements, 100 runs per measurement, 1 thread] 2276116 2286256 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f0e5b2e3e20> 2640 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&) 1920 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1040 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long>&&) 980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&) 720 ???:__tls_get_addr 520 ???:at::shouldRunRecordFunction(bool*) 520 ???:__memcpy_avx_unaligned_erms 200 ???:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10:: ... g>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) 100 ???:c10::TensorImpl::strides() const 100 ???:c10::TensorImpl::sizes() const 100 ???:at::(anonymous namespace)::manager() 77 /tmp/benchmark_utils_jit_build__1626465284__8a34e7ff-cd37-4a82-be28-7f19e081e771/timer_cpp_7815557938202456331/timer_src.cpp:main 40 ???:c10::TensorImpl::numel() const -77 /tmp/benchmark_utils_jit_build__1626465284__8a34e7ff-cd37-4a82-be28-7f19e081e771/timer_cpp_8055217880649990171/timer_src.cpp:main -260 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) Total: 10140 ``` ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f850dd66c10> x.view(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 582.78 ns IQR: 33.80 ns (573.80 to 607.61) 833 measurements, 10000 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f850de31e20> x.reshape(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 704.24 ns IQR: 24.42 ns (697.20 to 721.62) 679 measurements, 10000 runs per measurement, 1 thread] 56896 67036 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f84e1930bb0> 2640 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&) 1920 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1040 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long>&&) 980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&) 720 ???:__tls_get_addr 520 ???:at::shouldRunRecordFunction(bool*) 520 ???:__memcpy_avx_unaligned_erms 200 ???:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10:: ... g>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) 100 ???:c10::TensorImpl::strides() const 100 ???:c10::TensorImpl::sizes() const 100 ???:at::(anonymous namespace)::manager() 76 /tmp/benchmark_utils_jit_build__1626466038__15fbbac0-2072-4459-8f8e-08121a905b99/timer_cpp_547407365342278353/timer_src.cpp:main 40 ???:c10::TensorImpl::numel() const -76 /tmp/benchmark_utils_jit_build__1626466038__15fbbac0-2072-4459-8f8e-08121a905b99/timer_cpp_3457873755756181226/timer_src.cpp:main -260 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) Total: 10140 ``` </details> ### Using `as_strided` <details> <summary>Benchmark results</summary> ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f8b13bb5b50> x.pow(4).view(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 53.37 us IQR: 3.15 us (51.73 to 54.88) 936 measurements, 100 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f8af55f8490> x.pow(4).reshape(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 83.24 us IQR: 4.05 us (81.20 to 85.25) 609 measurements, 100 runs per measurement, 1 thread] 2267916 2525061 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f8af55f8e50> 31930 ???:_int_free 15940 ???:malloc 11595 ???:_int_malloc 10100 ???:torch::autograd::generated::details::as_strided_backward(at::Tensor, at::TensorGeometry, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 9360 ???:__tls_get_addr 8280 ???:free 8100 ???:torch::autograd::VariableType::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 4520 ???:c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_() 4080 ???:operator new(unsigned long) ... -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&) -1220 ???:torch::autograd::generated::ViewBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) -1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>) -1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&) -2560 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&) -2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) -4860 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) Total: 257145 ``` ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f93176a0160> x.view(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 570.55 ns IQR: 32.69 ns (552.87 to 585.56) 874 measurements, 10000 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f92f8f29490> x.reshape(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 576.49 ns IQR: 37.95 ns (559.51 to 597.46) 861 measurements, 10000 runs per measurement, 1 thread] 56896 58556 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f932556ca60> 2140 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1940 ???:torch::autograd::VariableType::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 1880 ???:torch::ADInplaceOrView::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 1720 ???:at::_ops::as_strided::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1400 ???:at::native::as_strided_tensorimpl(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 1260 ???:at::_ops::as_strided::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)'2 1260 ???:at::_ops::as_strided::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) 980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&) ... -620 ???:at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, c10::ArrayRef<long ... ::ArrayRef<long>)> const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) const -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2 -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&) -1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>) -1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&) -1740 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) Total: 1660 ``` </details> ### Using custom function (`_reshape_alias`) <details> <summary>Benchmark results</summary> ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f16861d6b50> x.pow(4).view(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 53.50 us IQR: 2.64 us (52.32 to 54.96) 906 measurements, 100 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f1667b2ed60> x.pow(4).reshape(-1).mean().backward(); setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true)); Median: 53.13 us IQR: 3.40 us (51.72 to 55.13) 914 measurements, 100 runs per measurement, 1 thread] 2269736 2273236 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f1693f8dc10> 5060 ???:torch::autograd::VariableType::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 2000 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1780 ???:torch::ADInplaceOrView::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1660 ???:at::_ops::_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1600 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)'2 1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1220 ???:torch::autograd::generated::AliasToShapeBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) ... -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2 -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&) -1220 ???:torch::autograd::generated::ViewBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) -1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>) -1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&) -2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) -4860 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) Total: 3500 ``` ``` [<torch.utils.benchmark.utils.common.Measurement object at 0x7f5287adfb20> x.view(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 505.10 ns IQR: 20.04 ns (500.41 to 520.45) 944 measurements, 10000 runs per measurement, 1 thread] [<torch.utils.benchmark.utils.common.Measurement object at 0x7f526951b430> x.reshape(-1); setup: at::Tensor x=torch::empty({2,2}); Median: 536.01 ns IQR: 17.81 ns (531.34 to 549.16) 916 measurements, 10000 runs per measurement, 1 thread] 56896 60376 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f5295896c10> 2000 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>) 1860 ???:torch::autograd::VariableType::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1780 ???:torch::ADInplaceOrView::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1660 ???:at::_ops::_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 1600 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&) 1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>) 1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)'2 1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>) 980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&) ... -620 ???:at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, c10::ArrayRef<long ... ::ArrayRef<long>)> const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) const -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2 -780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&) -1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>) -1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&) -1740 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) -2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>) Total: 3480 ``` </details> [ghstack-poisoned]
…nged ghstack-source-id: ea9ba4cb058a47c1c9e3963b4a18f76738e5fb47 Pull Request resolved: #61466
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Just small details
at::DimVector inferred_size = at::infer_size_dv(size, self.numel()); | ||
auto stride = at::detail::computeStride(self.sizes(), | ||
self.strides(), | ||
inferred_size); | ||
self.strides(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the indentation was actually better before no?
@@ -92,6 +92,7 @@ | |||
'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retains_grad', 'set_', | |||
'_fw_primal', 'fake_quantize_per_tensor_affine_cachemask', | |||
'fake_quantize_per_channel_affine_cachemask', | |||
'_reshape_alias', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this is not necessary as the leading _ means that it won't be dumped in the main namespace.
So if you wanted to write the c++ tests your added in python, you could remove that.
Whichever way you prefer works.
@laurencer has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
@laurencer merged this pull request in adb73d3. |
Stack from ghstack:
Goal
Per #55126 the performance of
reshape
is worse thanalias
in cases where they are performing the same operation (i.e. where reshape is returning a view) becausereshape
delegates toview
and duplicates some of the operations (specificallyinfer_size_dv
andcomputeStride
).The goal of this pull-request is to reduce or remove the additional overhead that
reshape
has.Proposed Implementation
Instead of using
view
we implement a private/internal operator (_reshape_alias
) thatreshape
dispatches to which skips the relevant checks. This is functionally equivalent toas_strided
however it is a lot simpler because it's specialized to this use-case, and importantly thebackward
implementation is a lot faster.Note that we have to dispatch (
reshape
is a composite operator) becausereshape
can return either a view or a copy of the Tensor depending on the parameters, and this complicates implementing a derivative/backward forreshape
.Why not
as_strided
?Using
as_strided
directly slows down autograd. If we use a custom function equivalent to_reshape_alias
but with a simpler backward function thenview
has the same performance asreshape
. If we delegate toas_strided
it is about 56% slower (and this holds against our custom function).This is also the reason we make an internal operator named
_reshape_alias
instead of exposing a new operator since this should only be used in thereshape
case and it is effectively a more limited version ofview
,alias
, andas_strided
.Benchmarks
In a micro-benchmark for
backward
running:I also benchmarked simple operations without gradients using:
Baselined to
view
:reshape
:+3.3%
(without gradients+20.8%
)as_strided
:+55.1%
(without gradients+1.0%
)_reshape_view
:-1.0%
(without gradients+6.2%
)In absolute terms (note the percentages above were generated comparing between runs/tests rather than to a single baseline):
view
:53.66 us
(without gradients582.78 ns
)reshape
:55.46 us
(without gradients704.24 ns
)as_strided
:83.24 us
(without gradients576.49 ns
)_reshape_view
:53.13 us
(without gradients536.01 ns
)Note that these benchmarks perform a backwards operation as well. When compared without using gradient computation at all the performance differneces are more pronounced as this takes up more of the time.
Original performance
Benchmark results
Using
as_strided
Benchmark results
Using custom function (
_reshape_alias
)Benchmark results
Differential Revision: D29792126