Skip to content

Commit

Permalink
Add support for torch.tensor_split to accept a tensor for indices a…
Browse files Browse the repository at this point in the history
…rgument (#49169)

Summary:
Pull Request resolved: #49169

Trying to solve PR request #47479.
This diff tries to overload method `torch.tensor_split` to also accept a tensor for argument `split_size_or_sections` which currently accepts a python list or int. The motivation is to avoid converting a tensor to a list so that when tracing a model/module the tensor operations can be recorded.

Implementation is following the diff that originally added the `tensor_split` method D24166164 (ef4817f).

Test Plan:
```
buck test caffe2/test:torch -- tensor_split
```
https://www.internalfb.com/intern/testinfra/testconsole/testrun/5910974550563805/

```
buck test caffe2/test:others -- tensor_split
```
https://www.internalfb.com/intern/testinfra/testconsole/testrun/1688849905082678/

Reviewed By: mruberry

Differential Revision: D25440885

fbshipit-source-id: 6705dc551279e3a5eb1e5ec1ede2728eab85ffb1
  • Loading branch information
Edson Romero authored and facebook-github-bot committed Dec 21, 2020
1 parent 96aed20 commit 5c3788d
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 42 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/NamedRegistrations.cpp
Expand Up @@ -462,6 +462,7 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("tanh_", CppFunction::makeFallthrough());
m.impl("tensor_split.indices", CppFunction::makeFallthrough());
m.impl("tensor_split.sections", CppFunction::makeFallthrough());
m.impl("tensor_split.tensor_indices_or_sections", CppFunction::makeFallthrough());
m.impl("threshold", CppFunction::makeFallthrough());
m.impl("threshold.out", CppFunction::makeFallthrough());
m.impl("threshold_", CppFunction::makeFallthrough());
Expand Down
26 changes: 24 additions & 2 deletions aten/src/ATen/native/TensorShape.cpp
Expand Up @@ -515,7 +515,7 @@ std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) {
}

std::vector<Tensor> tensor_split(const Tensor& self, int64_t sections, int64_t dim) {
TORCH_CHECK(self.dim() > 0, "expected at least a 1-dimensional tensor");
TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
int64_t dim_ = maybe_wrap_dim(dim, self.dim());
TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections);
std::vector<Tensor> splits(sections);
Expand All @@ -531,7 +531,7 @@ std::vector<Tensor> tensor_split(const Tensor& self, int64_t sections, int64_t d
}

std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim) {
TORCH_CHECK(self.dim() > 0, "expected at least a 1-dimensional tensor");
TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
int64_t dim_ = maybe_wrap_dim(dim, self.dim());
int64_t num_indices = indices.size();
std::vector<Tensor> splits(num_indices + 1);
Expand All @@ -545,6 +545,28 @@ std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_
return splits;
}

std::vector<Tensor> tensor_split(const Tensor& self, const Tensor& tensor_indices_or_sections, int64_t dim) {
TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
auto split_device = tensor_indices_or_sections.device();
TORCH_CHECK(split_device == kCPU,
"tensor_split expected tensor_indices_or_sections to be on cpu, but it's on ", split_device);
auto split_dtype = tensor_indices_or_sections.scalar_type();
TORCH_CHECK(split_dtype == at::kLong,
"tensor_split expected tensor_indices_or_sections to have dtype of long, but got ", split_dtype);
auto split_dim = tensor_indices_or_sections.dim();
TORCH_CHECK(split_dim == 1 || split_dim == 0,
"tensor_split expected tensor_indices_or_sections to be a zero-dimensional or one-dimensional tensor, but got a tensor with ", split_dim, " dims");

if (split_dim == 0) {
int64_t sections = tensor_indices_or_sections.item<int64_t>();
return self.tensor_split(sections, dim);
} else {
auto indices_data = tensor_indices_or_sections.data_ptr<int64_t>();
std::vector<int64_t> indices(indices_data, indices_data + tensor_indices_or_sections.numel());
return self.tensor_split(indices, dim);
}
}

std::vector<Tensor> unsafe_chunk(const Tensor& self, int64_t chunks, int64_t dim) {
TORCH_CHECK(self.dim() > 0,
"chunk expects at least a 1-dimensional tensor");
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -1086,6 +1086,10 @@
use_c10_dispatcher: full
variants: function, method

- func: tensor_split.tensor_indices_or_sections(Tensor(a) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]
use_c10_dispatcher: full
variants: function, method

- func: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
use_c10_dispatcher: full
variants: function, method
Expand Down
48 changes: 36 additions & 12 deletions test/test_view_ops.py
Expand Up @@ -1143,12 +1143,16 @@ def test_tensor_split_sections(self, device, dtype):
for dim in range(-a.dim(), a.dim()):
for sections in range(1, 2 * a.size(dim)):
msg = f'input_size {input_size}, sections {sections}, dim {dim}'
result = torch.tensor_split(a, sections, dim)
for result_item in result:
self.assertEqual(result_item.device, torch.device(device), msg=msg)
self.assertEqual(result_item.dtype, dtype, msg=msg)
result1 = torch.tensor_split(a, sections, dim)
result2 = torch.tensor_split(a, torch.tensor(sections, dtype=torch.int64), dim)
for r1, r2 in zip(result1, result2):
self.assertEqual(r1.device, torch.device(device), msg=msg)
self.assertEqual(r1.dtype, dtype, msg=msg)
self.assertEqual(r2.device, torch.device(device), msg=msg)
self.assertEqual(r2.dtype, dtype, msg=msg)
result_n = np.array_split(a_n, sections, dim)
self.assertEqual(result_n, result, msg=msg)
self.assertEqual(result_n, result1, msg=msg)
self.assertEqual(result_n, result2, msg=msg)

@onlyOnCPUAndCUDA
# Skip BFloat16 since numpy does not support it
Expand Down Expand Up @@ -1181,23 +1185,31 @@ def test_tensor_split_indices(self, device, dtype):
a_n = a.cpu().numpy()
for dim in range(-a.dim(), a.dim()):
for indices in indices_args:
result = torch.tensor_split(a, indices, dim)
result_1 = torch.tensor_split(a, indices, dim)
result_2 = torch.tensor_split(a, torch.tensor(indices, dtype=torch.int64), dim)

msg = f'input_size {input_size}, indices {indices}, dim {dim}'
for result_item in result:
self.assertEqual(result_item.device, torch.device(device), msg=msg)
self.assertEqual(result_item.dtype, dtype, msg=msg)
for r1, r2 in zip(result_1, result_2):
self.assertEqual(r1.device, torch.device(device), msg=msg)
self.assertEqual(r1.dtype, dtype, msg=msg)
self.assertEqual(r2.device, torch.device(device), msg=msg)
self.assertEqual(r2.dtype, dtype, msg=msg)

result_n = np.array_split(a_n, indices, dim)
self.assertEqual(result_n, result, msg=msg)
self.assertEqual(result_n, result_1, msg=msg)
self.assertEqual(result_n, result_2, msg=msg)

@onlyOnCPUAndCUDA
def test_tensor_split_errors(self, device):
S = 10
test_cases = [
# input size, sections or indices, dim, error type, error message, numpy error type
[(S,), 10, 1, IndexError, r'Dimension out of range', IndexError],
[(), 10, 0, RuntimeError, r'expected at least a 1-dimensional tensor', IndexError],
[(), 10, 0, RuntimeError, r'tensor_split expected at least a 1-dimensional tensor, '
+ 'but got a tensor with 0 dims', IndexError],
[(S,), (10,), 1, IndexError, r'Dimension out of range', IndexError],
[(), (10,), 0, RuntimeError, r'expected at least a 1-dimensional tensor', IndexError],
[(), (10,), 0, RuntimeError, r'tensor_split expected at least a 1-dimensional tensor, '
+ 'but got a tensor with 0 dims', IndexError],
[(S,), 0, 0, RuntimeError, r'number of sections must be larger than 0, got 0', ValueError],
[(S,), -1, 0, RuntimeError, r'number of sections must be larger than 0, got -1', ValueError],
]
Expand All @@ -1206,9 +1218,21 @@ def test_tensor_split_errors(self, device):
msg = f'input_size {input_size}, sections_or_indices {sections_or_indices}, dim {dim}'
with self.assertRaisesRegex(err, err_msg, msg=msg):
torch.tensor_split(a, sections_or_indices, dim)
with self.assertRaisesRegex(err, err_msg, msg=msg):
torch.tensor_split(a, torch.tensor(sections_or_indices), dim)
with self.assertRaises(numpy_err, msg=msg):
np.array_split(a.cpu().numpy(), sections_or_indices, dim)

# addtional tests for tensor_split with tensor_indices_or_sections
with self.assertRaisesRegex(RuntimeError,
r'tensor_split expected tensor_indices_or_sections to have dtype of long, but got Float'):
torch.tensor_split(a, torch.tensor(1.1), dim)

with self.assertRaisesRegex(RuntimeError,
r'tensor_split expected tensor_indices_or_sections to be a'
+ ' zero-dimensional or one-dimensional tensor, but got a tensor with 2 dims'):
torch.tensor_split(torch.rand(S, device=device), torch.tensor(((1,),)), 0)

def test_resize_all_dtypes_and_devices(self, device):
shape = (2, 2)
for dt in torch.testing.get_all_dtypes():
Expand Down
59 changes: 31 additions & 28 deletions torch/_torch_docs.py
Expand Up @@ -1252,19 +1252,22 @@ def merge_dicts(*dicts):
Args:
input (Tensor): the tensor to split
indices_or_sections (int or (list(int))):
If :attr:`indices_or_sections` is an integer ``n``, :attr:`input` is split
into ``n`` sections along dimension :attr:`dim`. If :attr:`input` is divisible
by ``n`` along dimension :attr:`dim`, each section will be of equal size,
:code:`input.size(dim) / n`. If :attr:`input` is not divisible by ``n``, the
sizes of the first :code:`int(input.size(dim) % n)` sections will have size
:code:`int(input.size(dim) / n) + 1`, and the rest will have size
:code:`int(input.size(dim) / n)`.
If :attr:`indices_or_sections` is a list of ints, :attr:`input` is split along
dimension :attr:`dim` at each of the indices in the list. For instance,
:code:`indices_or_sections=[2, 3]` and :code:`dim=0` would result in the tensors
:code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`.
indices_or_sections (Tensor, int or list or tuple of ints):
If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor
with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`.
If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each
section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input`
is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)`
sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will
have size :code:`int(input.size(dim) / n)`.
If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long
tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices
in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0`
would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`.
If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional
long tensor on the CPU.
dim (int, optional): dimension along which to split the tensor. Default: ``0``
Expand Down Expand Up @@ -5889,7 +5892,7 @@ def merge_dicts(*dicts):
[2, 6]],
[[1, 5],
[3, 7]]])
[3, 7]]])
""".format(**common_args))

add_docstr(torch.swapaxes, r"""
Expand Down Expand Up @@ -5919,7 +5922,7 @@ def merge_dicts(*dicts):
[2, 6]],
[[1, 5],
[3, 7]]])
[3, 7]]])
""".format(**common_args))

add_docstr(torch.narrow,
Expand Down Expand Up @@ -6486,15 +6489,15 @@ def merge_dicts(*dicts):
r"""
float_power(input, exponent, *, out=None) -> Tensor
Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision.
If neither input is complex returns a ``torch.float64`` tensor,
Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision.
If neither input is complex returns a ``torch.float64`` tensor,
and if one or more inputs is complex returns a ``torch.complex128`` tensor.
.. note::
This function always computes in double precision, unlike :func:`torch.pow`,
.. note::
This function always computes in double precision, unlike :func:`torch.pow`,
which implements more typical :ref:`type promotion <type-promotion-doc>`.
This is useful when the computation needs to be performed in a wider or more precise dtype,
or the results of the computation may contain fractional values not representable in the input dtypes,
This is useful when the computation needs to be performed in a wider or more precise dtype,
or the results of the computation may contain fractional values not representable in the input dtypes,
like when an integer base is raised to a negative integer exponent.
Args:
Expand Down Expand Up @@ -9845,21 +9848,21 @@ def merge_dicts(*dicts):
add_docstr(torch.tile, r"""
tile(input, reps) -> Tensor
Constructs a tensor by repeating the elements of :attr:`input`.
Constructs a tensor by repeating the elements of :attr:`input`.
The :attr:`reps` argument specifies the number of repetitions
in each dimension.
If :attr:`reps` specifies fewer dimensions than :attr:`input` has, then
ones are prepended to :attr:`reps` until all dimensions are specified.
For example, if :attr:`input` has shape (8, 6, 4, 2) and :attr:`reps`
is (2, 2), then :attr:`reps` is treated as (1, 1, 2, 2).
is (2, 2), then :attr:`reps` is treated as (1, 1, 2, 2).
Analogously, if :attr:`input` has fewer dimensions than :attr:`reps`
specifies, then :attr:`input` is treated as if it were unsqueezed at
dimension zero until it has as many dimensions as :attr:`reps` specifies.
Analogously, if :attr:`input` has fewer dimensions than :attr:`reps`
specifies, then :attr:`input` is treated as if it were unsqueezed at
dimension zero until it has as many dimensions as :attr:`reps` specifies.
For example, if :attr:`input` has shape (4, 2) and :attr:`reps`
is (3, 3, 2, 2), then :attr:`input` is treated as if it had the
shape (1, 1, 4, 2).
is (3, 3, 2, 2), then :attr:`input` is treated as if it had the
shape (1, 1, 4, 2).
.. note::
Expand Down
21 changes: 21 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Expand Up @@ -271,6 +271,20 @@ def __init__(self,
# outside a function's domain.
self._domain_eps = 1e-5

def sample_inputs_tensor_split(op_info, device, dtype, requires_grad):
return (SampleInput(make_tensor((S, S, S), device, dtype,
low=None, high=None,
requires_grad=requires_grad),
args=(torch.tensor([1, 2, 3]),),),
SampleInput(make_tensor((S, S, S), device, dtype,
low=None, high=None,
requires_grad=requires_grad),
args=(torch.tensor(1),),),
SampleInput(make_tensor((S, S, S), device, dtype,
low=None, high=None,
requires_grad=requires_grad),
args=(torch.tensor([1, 2, 3]),),
kwargs=dict(dim=1)),)

def sample_inputs_addmm(op_info, device, dtype, requires_grad):
return (SampleInput((make_tensor((S, S), device, dtype,
Expand Down Expand Up @@ -868,6 +882,13 @@ def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False):
device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
active_if=(IS_MACOS or IS_WINDOWS)),
)),
OpInfo('tensor_split',
dtypes=all_types_and_complex_and(torch.bool),
dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_tensor_out=False,
test_inplace_grad=False,
sample_inputs_func=sample_inputs_tensor_split,),
UnaryUfuncInfo('exp2',
ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2),
dtypes=all_types_and(torch.bool, torch.half),
Expand Down

0 comments on commit 5c3788d

Please sign in to comment.