Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Tensor.new_empty_strided(sizes, strides, *, dtype, device, requires_grad) #47225

Closed
wants to merge 8 commits into from
9 changes: 9 additions & 0 deletions aten/src/ATen/native/TensorFactories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,15 @@ Tensor new_empty(
return at::empty(size, self.options().merge_in(options));
}

Tensor new_empty_strided(
const Tensor& self,
IntArrayRef size,
IntArrayRef stride,
const TensorOptions& options
) {
return at::empty_strided(size, stride, self.options().merge_in(options));
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eye ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Tensor eye(int64_t n, const TensorOptions& options) {
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1629,6 +1629,9 @@
#use_c10_dispatcher: full
variants: method

- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: method

- func: new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
variants: method
Expand Down
19 changes: 19 additions & 0 deletions test/test_tensor_creation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,25 @@ def test_empty_strided(self, device):
self.assertEqual(empty_strided.shape, as_strided.shape)
self.assertEqual(empty_strided.stride(), as_strided.stride())

def test_new_empty_strided(self, device):
def _test(sizes, strides, dtype):
x = torch.zeros(5, 5, dtype=dtype, device=device)
result = x.new_empty_strided(sizes, strides)
expected = torch.empty_strided(sizes, strides, dtype=x.dtype, device=x.device)
self.assertEqual(result.shape, expected.shape)
self.assertEqual(result.stride(), expected.stride())
self.assertEqual(result.dtype, expected.dtype)
self.assertEqual(result.device, expected.device)

_test([2, 3], [3, 1], torch.float)
_test([5, 3], [0, 1], torch.int)
_test([], [], torch.float)

# Some really weird cases
for shape in [(2, 3, 4), (0, 2, 0)]:
for strides in [(12, 4, 1), (2, 4, 6), (0, 0, 0)]:
_test(shape, strides, torch.float)

def test_strided_mismatched_stride_shape(self, device):
for shape, strides in [((1, ), ()), ((1, 2), (1, ))]:
with self.assertRaisesRegex(RuntimeError, "mismatch in length of strides and shape"):
Expand Down
22 changes: 22 additions & 0 deletions torch/_tensor_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,28 @@ def add_docstr_all(method, docstr):

""".format(**new_common_args))

add_docstr_all('new_empty_strided',
r"""
new_empty_strided(size, stride, dtype=None, device=None, requires_grad=False) -> Tensor

Returns a Tensor of size :attr:`size` and strides :attr:`stride` filled with
uninitialized data. By default, the returned Tensor has the same
:class:`torch.dtype` and :class:`torch.device` as this tensor.

Args:
{dtype}
{device}
{requires_grad}

Example::

>>> tensor = torch.ones(())
>>> tensor.new_empty_strided((2, 3), (3, 1))
tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30],
[ 3.0949e-41, 4.4842e-44, 0.0000e+00]])

""".format(**new_common_args))

add_docstr_all('new_ones',
r"""
new_ones(size, dtype=None, device=None, requires_grad=False) -> Tensor
Expand Down
1 change: 1 addition & 0 deletions torch/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def get_ignored_functions() -> Set[Callable]:
Tensor.new,
Tensor.new_tensor,
Tensor.new_empty,
Tensor.new_empty_strided,
Tensor.new_zeros,
Tensor.new_ones,
Tensor.new_full,
Expand Down