Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed overhead from reshape() call if tensor doesn't need to be changed #61466

Closed
wants to merge 6 commits into from
50 changes: 37 additions & 13 deletions aten/src/ATen/native/TensorShape.cpp
Expand Up @@ -1040,21 +1040,43 @@ Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
return at::_mkldnn_reshape(self, shape);
}

auto stride =
at::detail::computeStride(self.sizes(), self.strides(), shape);
// `computeStride` returns the proper strides to use if this
// `reshape` can be just a view.
//
// NB: Even though we have viewable geometry and the target strides here,
// we do not just call `as_strided` on `self` because the backward
// for `as_strided` is not as efficient as that of `view` (since the
// former is meant to handle general cases).
// `computeStride` returns the proper strides to use if this
// `reshape` can be just a view.
auto stride = at::detail::computeStride(self.sizes(), self.strides(), shape);

// NB: Even though we have viewable geometry and the target strides here,
// we do not just call `as_strided` on `self` because the backward
// for `as_strided` is not as efficient as that of `view` (since the
// former is meant to handle general cases).
//
// Similarly we don't call `view` because it duplicates some of the work
// we've already done, and instead call our internal/private operator
// `_reshape_alias` that essentially does the same thing as `view` and
// `as_strided` without any of the extra overhead.
if (stride.has_value()) {
return self.view(shape);
// Temporary check to revert to the old behavior/view in cases where the
// device is not supported (e.g. for XLA the operation is not supported
// so we use `view` instead).
//
// We need to do the checks here instead of in `native_functions.yaml`
// to preserve backwards compatibility.
if (! self.is_xla()) {
return self._reshape_alias(shape, stride.value());
} else {
return self.view(shape);
}
}
return at::_unsafe_view(self.clone(at::MemoryFormat::Contiguous), shape);
}

Tensor _reshape_alias(const Tensor& self, IntArrayRef sizes, IntArrayRef strides) {
  // Internal helper backing `reshape` when the result can be a view: by the
  // time we get here, `reshape` has already computed the target sizes and
  // strides (via `infer_size_dv` / `computeStride`), so we can alias the
  // storage directly instead of paying for `view`'s duplicate validation.
  auto aliased = alias_with_sizes_and_strides(self, sizes, strides);
  return aliased;
}

Tensor reshape_as(const Tensor& self, const Tensor& other) {
  // Reshape `self` so its dimensions match `other`'s; all of the actual
  // view-vs-copy logic is delegated to `reshape`.
  const auto target_shape = other.sizes();
  return self.reshape(target_shape);
}
Expand Down Expand Up @@ -2152,11 +2174,13 @@ Tensor numpy_T(const Tensor &self) {
return self.permute(transpose_dims);
}

Tensor view(const Tensor& self, IntArrayRef size) {
Tensor view(const Tensor& self,
IntArrayRef size) {

at::DimVector inferred_size = at::infer_size_dv(size, self.numel());
auto stride = at::detail::computeStride(self.sizes(),
self.strides(),
inferred_size);
self.strides(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the indentation was actually better before no?

inferred_size);
TORCH_CHECK(stride.has_value(), "view size is "
"not compatible with input tensor's size and stride (at least one dimension"
" spans across two contiguous subspaces). Use .reshape(...) instead.");
Expand Down
11 changes: 11 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -3443,6 +3443,17 @@
device_check: NoCheck
device_guard: False

# NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape.
# It is not user-facing, hence the leading underscore. Please don't use it
# anywhere else.
- func: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a)
variants: function, method
device_check: NoCheck
device_guard: False
dispatch:
CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: _reshape_alias
albanD marked this conversation as resolved.
Show resolved Hide resolved
# We don't need to support mkldnn since this is handled explicitly by the reshape operator.

- func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor
device_check: NoCheck
device_guard: False
Expand Down
25 changes: 25 additions & 0 deletions test/cpp/api/tensor.cpp
Expand Up @@ -1117,3 +1117,28 @@ TEST(TensorTest, StdDimension) {
ASSERT_EQ(torch::std(x, 0, /*unbiased=*/true).numel(), 3);
ASSERT_EQ(std::get<0>(torch::std_mean(x, 0, /*unbiased=*/true)).numel(), 3);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(TensorTest, ReshapeAlias) {
  // `_reshape_alias` is a private operator used internally by `reshape`;
  // its forward results must be indistinguishable from `as_strided`/`view`.
  auto input = torch::randn({3, 3});
  ASSERT_TRUE(torch::equal(
      torch::_reshape_alias(input, {2, 2}, {1, 2}),
      torch::as_strided(input, {2, 2}, {1, 2})));
  ASSERT_TRUE(torch::equal(
      torch::_reshape_alias(input, {9}, {1}),
      input.view({-1})));

  // The backward pass through `_reshape_alias` must match that of `view`:
  // run the same computation through both paths and compare gradients.
  auto via_view = torch::randn({3, 3}, torch::requires_grad(true));
  auto via_alias = torch::clone(via_view).detach().requires_grad_(true);
  (via_view * via_view).view({-1}).mean().backward();
  torch::_reshape_alias((via_alias * via_alias), {9}, {1}).mean().backward();
  ASSERT_TRUE(torch::equal(
      via_view.grad(),
      via_alias.grad()));
}
16 changes: 16 additions & 0 deletions test/test_view_ops.py
Expand Up @@ -1355,6 +1355,22 @@ def test_view(self, device):
self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1))
self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1))

@dtypes(*torch.testing.get_all_dtypes())
def test_reshape_view_semantics(self, device, dtype):
    """reshape() returns a view of the input when possible, else a copy."""
    tensor = make_tensor((15, 4), device, dtype)
    target = (20, 3)

    # Contiguous input: reshape should alias the original storage (a view).
    viewed = tensor.reshape(target)
    self.assertEqual(viewed.size(), target)
    self.assertEqual(tensor.storage().data_ptr(), viewed.storage().data_ptr())

    # Non-contiguous input (transpose breaks contiguity): reshape is forced
    # to materialize a fresh copy with its own storage.
    copied = tensor.transpose(0, 1).reshape(target)
    self.assertEqual(copied.size(), target)
    self.assertNotEqual(tensor.storage().data_ptr(), copied.storage().data_ptr())

def test_contiguous(self, device):
x = torch.randn(1, 16, 5, 5, device=device)
self.assertTrue(x.is_contiguous())
Expand Down
3 changes: 3 additions & 0 deletions tools/autograd/derivatives.yaml
Expand Up @@ -1141,6 +1141,9 @@
# making it impossible (hard) to detect when it is actually a view.
# - name: reshape(Tensor self, IntArrayRef shape)

- name: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a)
self: grad.reshape(self.sizes())

- name: round(Tensor self) -> Tensor
self: zeros_like(grad)

Expand Down
1 change: 1 addition & 0 deletions tools/autograd/gen_inplace_or_view_type.py
Expand Up @@ -58,6 +58,7 @@
# discrete anyways.
# FIXME: clone indices on construction.
'sparse_coo_tensor_with_dims_and_tensors': 'values',
'_reshape_alias': 'self',
}

for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE:
Expand Down
1 change: 1 addition & 0 deletions tools/autograd/gen_python_functions.py
Expand Up @@ -92,6 +92,7 @@
'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retains_grad', 'set_',
'_fw_primal', 'fake_quantize_per_tensor_affine_cachemask',
'fake_quantize_per_channel_affine_cachemask',
'_reshape_alias',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this is not necessary as the leading _ means that it won't be dumped in the main namespace.
So if you wanted to write the c++ tests your added in python, you could remove that.
Whichever way you prefer works.

]

SKIP_PYTHON_BINDINGS = list(map(lambda pattern: re.compile(rf'^{pattern}$'), _SKIP_PYTHON_BINDINGS))
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
Expand Up @@ -101,7 +101,7 @@
'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical',
'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'conj_physical_', '_neg_view'
'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'conj_physical_', '_neg_view', '_reshape_alias'
}

GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
Expand Down