From 5c3788d5d76f64f6708e0b79f40b1cf45276625a Mon Sep 17 00:00:00 2001 From: Edson Romero Date: Sun, 20 Dec 2020 21:41:55 -0800 Subject: [PATCH] Add support for torch.tensor_split to accept a tensor for `indices` argument (#49169) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49169 Trying to solve PR request https://github.com/pytorch/pytorch/issues/47479. This diff tries to overload method `torch.tensor_split` to also accept a tensor for argument `split_size_or_sections` which currently accepts a python list or int. The motivation is to avoid converting a tensor to a list so that when tracing a model/module the tensor operations can be recorded. Implementation is following the diff that originally added the `tensor_split` method D24166164 (https://github.com/pytorch/pytorch/commit/ef4817fe5a16ba9969562911c5363736a1003bb0). Test Plan: ``` buck test caffe2/test:torch -- tensor_split ``` https://www.internalfb.com/intern/testinfra/testconsole/testrun/5910974550563805/ ``` buck test caffe2/test:others -- tensor_split ``` https://www.internalfb.com/intern/testinfra/testconsole/testrun/1688849905082678/ Reviewed By: mruberry Differential Revision: D25440885 fbshipit-source-id: 6705dc551279e3a5eb1e5ec1ede2728eab85ffb1 --- aten/src/ATen/core/NamedRegistrations.cpp | 1 + aten/src/ATen/native/TensorShape.cpp | 26 +++++++- aten/src/ATen/native/native_functions.yaml | 4 ++ test/test_view_ops.py | 48 +++++++++++---- torch/_torch_docs.py | 59 ++++++++++--------- .../_internal/common_methods_invocations.py | 21 +++++++ 6 files changed, 117 insertions(+), 42 deletions(-) diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp index 50d53d872e66..c4c1b1ecc9ba 100644 --- a/aten/src/ATen/core/NamedRegistrations.cpp +++ b/aten/src/ATen/core/NamedRegistrations.cpp @@ -462,6 +462,7 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("tanh_", CppFunction::makeFallthrough()); m.impl("tensor_split.indices", CppFunction::makeFallthrough()); m.impl("tensor_split.sections", CppFunction::makeFallthrough()); + m.impl("tensor_split.tensor_indices_or_sections", CppFunction::makeFallthrough()); m.impl("threshold", CppFunction::makeFallthrough()); m.impl("threshold.out", CppFunction::makeFallthrough()); m.impl("threshold_", CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index eda688ad6e1d..324985734f2d 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -515,7 +515,7 @@ std::vector chunk(const Tensor& self, int64_t chunks, int64_t dim) { } std::vector tensor_split(const Tensor& self, int64_t sections, int64_t dim) { - TORCH_CHECK(self.dim() > 0, "expected at least a 1-dimensional tensor"); + TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims"); int64_t dim_ = maybe_wrap_dim(dim, self.dim()); TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections); std::vector splits(sections); @@ -531,7 +531,7 @@ std::vector tensor_split(const Tensor& self, int64_t sections, int64_t d } std::vector tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim) { - TORCH_CHECK(self.dim() > 0, "expected at least a 1-dimensional tensor"); + TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims"); int64_t dim_ = maybe_wrap_dim(dim, self.dim()); int64_t num_indices = indices.size(); std::vector splits(num_indices + 1); @@ -545,6 +545,28 @@ std::vector tensor_split(const Tensor& self, IntArrayRef indices, int64_ return splits; } +std::vector tensor_split(const Tensor& self, const Tensor& tensor_indices_or_sections, int64_t dim) { + TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims"); + auto split_device = tensor_indices_or_sections.device(); + TORCH_CHECK(split_device == kCPU, + "tensor_split expected tensor_indices_or_sections to be on cpu, but it's on ", split_device); + auto split_dtype = tensor_indices_or_sections.scalar_type(); + TORCH_CHECK(split_dtype == at::kLong, + "tensor_split expected tensor_indices_or_sections to have dtype of long, but got ", split_dtype); + auto split_dim = tensor_indices_or_sections.dim(); + TORCH_CHECK(split_dim == 1 || split_dim == 0, + "tensor_split expected tensor_indices_or_sections to be a zero-dimensional or one-dimensional tensor, but got a tensor with ", split_dim, " dims"); + + if (split_dim == 0) { + int64_t sections = tensor_indices_or_sections.item(); + return self.tensor_split(sections, dim); + } else { + auto indices_data = tensor_indices_or_sections.data_ptr(); + std::vector indices(indices_data, indices_data + tensor_indices_or_sections.numel()); + return self.tensor_split(indices, dim); + } +} + std::vector unsafe_chunk(const Tensor& self, int64_t chunks, int64_t dim) { TORCH_CHECK(self.dim() > 0, "chunk expects at least a 1-dimensional tensor"); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 66698d6b67ad..3cde7e3605ce 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1086,6 +1086,10 @@ use_c10_dispatcher: full variants: function, method +- func: tensor_split.tensor_indices_or_sections(Tensor(a) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[] + use_c10_dispatcher: full + variants: function, method + - func: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor use_c10_dispatcher: full variants: function, method diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 15f1bcd8183f..18633bbfb8ac 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -1143,12 +1143,16 @@ def test_tensor_split_sections(self, device, dtype): for dim in range(-a.dim(), a.dim()): for sections in range(1, 2 * a.size(dim)): msg = f'input_size {input_size}, sections {sections}, dim {dim}' - result = torch.tensor_split(a, sections, dim) - for result_item in result: - self.assertEqual(result_item.device, torch.device(device), msg=msg) - self.assertEqual(result_item.dtype, dtype, msg=msg) + result1 = torch.tensor_split(a, sections, dim) + result2 = torch.tensor_split(a, torch.tensor(sections, dtype=torch.int64), dim) + for r1, r2 in zip(result1, result2): + self.assertEqual(r1.device, torch.device(device), msg=msg) + self.assertEqual(r1.dtype, dtype, msg=msg) + self.assertEqual(r2.device, torch.device(device), msg=msg) + self.assertEqual(r2.dtype, dtype, msg=msg) result_n = np.array_split(a_n, sections, dim) - self.assertEqual(result_n, result, msg=msg) + self.assertEqual(result_n, result1, msg=msg) + self.assertEqual(result_n, result2, msg=msg) @onlyOnCPUAndCUDA # Skip BFloat16 since numpy does not support it @@ -1181,13 +1185,19 @@ def test_tensor_split_indices(self, device, dtype): a_n = a.cpu().numpy() for dim in range(-a.dim(), a.dim()): for indices in indices_args: - result = torch.tensor_split(a, indices, dim) + result_1 = torch.tensor_split(a, indices, dim) + result_2 = torch.tensor_split(a, torch.tensor(indices, dtype=torch.int64), dim) + msg = f'input_size {input_size}, indices {indices}, dim {dim}' - for result_item in result: - self.assertEqual(result_item.device, torch.device(device), msg=msg) - self.assertEqual(result_item.dtype, dtype, msg=msg) + for r1, r2 in zip(result_1, result_2): + self.assertEqual(r1.device, torch.device(device), msg=msg) + self.assertEqual(r1.dtype, dtype, msg=msg) + self.assertEqual(r2.device, torch.device(device), msg=msg) + self.assertEqual(r2.dtype, dtype, msg=msg) + result_n = np.array_split(a_n, indices, dim) - self.assertEqual(result_n, result, msg=msg) + self.assertEqual(result_n, result_1, msg=msg) + self.assertEqual(result_n, result_2, msg=msg) @onlyOnCPUAndCUDA def test_tensor_split_errors(self, device): @@ -1195,9 +1205,11 @@ def test_tensor_split_errors(self, device): test_cases = [ # input size, sections or indices, dim, error type, error message, numpy error type [(S,), 10, 1, IndexError, r'Dimension out of range', IndexError], - [(), 10, 0, RuntimeError, r'expected at least a 1-dimensional tensor', IndexError], + [(), 10, 0, RuntimeError, r'tensor_split expected at least a 1-dimensional tensor, ' + + 'but got a tensor with 0 dims', IndexError], [(S,), (10,), 1, IndexError, r'Dimension out of range', IndexError], - [(), (10,), 0, RuntimeError, r'expected at least a 1-dimensional tensor', IndexError], + [(), (10,), 0, RuntimeError, r'tensor_split expected at least a 1-dimensional tensor, ' + + 'but got a tensor with 0 dims', IndexError], [(S,), 0, 0, RuntimeError, r'number of sections must be larger than 0, got 0', ValueError], [(S,), -1, 0, RuntimeError, r'number of sections must be larger than 0, got -1', ValueError], ] @@ -1206,9 +1218,21 @@ def test_tensor_split_errors(self, device): msg = f'input_size {input_size}, sections_or_indices {sections_or_indices}, dim {dim}' with self.assertRaisesRegex(err, err_msg, msg=msg): torch.tensor_split(a, sections_or_indices, dim) + with self.assertRaisesRegex(err, err_msg, msg=msg): + torch.tensor_split(a, torch.tensor(sections_or_indices), dim) with self.assertRaises(numpy_err, msg=msg): np.array_split(a.cpu().numpy(), sections_or_indices, dim) + # addtional tests for tensor_split with tensor_indices_or_sections + with self.assertRaisesRegex(RuntimeError, + r'tensor_split expected tensor_indices_or_sections to have dtype of long, but got Float'): + torch.tensor_split(a, torch.tensor(1.1), dim) + + with self.assertRaisesRegex(RuntimeError, + r'tensor_split expected tensor_indices_or_sections to be a' + + ' zero-dimensional or one-dimensional tensor, but got a tensor with 2 dims'): + torch.tensor_split(torch.rand(S, device=device), torch.tensor(((1,),)), 0) + def test_resize_all_dtypes_and_devices(self, device): shape = (2, 2) for dt in torch.testing.get_all_dtypes(): diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index b038a5f96c36..747693f53d4c 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1252,19 +1252,22 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to split - indices_or_sections (int or (list(int))): - If :attr:`indices_or_sections` is an integer ``n``, :attr:`input` is split - into ``n`` sections along dimension :attr:`dim`. If :attr:`input` is divisible - by ``n`` along dimension :attr:`dim`, each section will be of equal size, - :code:`input.size(dim) / n`. If :attr:`input` is not divisible by ``n``, the - sizes of the first :code:`int(input.size(dim) % n)` sections will have size - :code:`int(input.size(dim) / n) + 1`, and the rest will have size - :code:`int(input.size(dim) / n)`. - - If :attr:`indices_or_sections` is a list of ints, :attr:`input` is split along - dimension :attr:`dim` at each of the indices in the list. For instance, - :code:`indices_or_sections=[2, 3]` and :code:`dim=0` would result in the tensors - :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. dim (int, optional): dimension along which to split the tensor. Default: ``0`` @@ -5889,7 +5892,7 @@ def merge_dicts(*dicts): [2, 6]], [[1, 5], - [3, 7]]]) + [3, 7]]]) """.format(**common_args)) add_docstr(torch.swapaxes, r""" @@ -5919,7 +5922,7 @@ def merge_dicts(*dicts): [2, 6]], [[1, 5], - [3, 7]]]) + [3, 7]]]) """.format(**common_args)) add_docstr(torch.narrow, @@ -6486,15 +6489,15 @@ def merge_dicts(*dicts): r""" float_power(input, exponent, *, out=None) -> Tensor -Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. -If neither input is complex returns a ``torch.float64`` tensor, +Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. +If neither input is complex returns a ``torch.float64`` tensor, and if one or more inputs is complex returns a ``torch.complex128`` tensor. -.. note:: - This function always computes in double precision, unlike :func:`torch.pow`, +.. note:: + This function always computes in double precision, unlike :func:`torch.pow`, which implements more typical :ref:`type promotion `. - This is useful when the computation needs to be performed in a wider or more precise dtype, - or the results of the computation may contain fractional values not representable in the input dtypes, + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, like when an integer base is raised to a negative integer exponent. Args: @@ -9845,21 +9848,21 @@ def merge_dicts(*dicts): add_docstr(torch.tile, r""" tile(input, reps) -> Tensor -Constructs a tensor by repeating the elements of :attr:`input`. +Constructs a tensor by repeating the elements of :attr:`input`. The :attr:`reps` argument specifies the number of repetitions in each dimension. If :attr:`reps` specifies fewer dimensions than :attr:`input` has, then ones are prepended to :attr:`reps` until all dimensions are specified. For example, if :attr:`input` has shape (8, 6, 4, 2) and :attr:`reps` -is (2, 2), then :attr:`reps` is treated as (1, 1, 2, 2). +is (2, 2), then :attr:`reps` is treated as (1, 1, 2, 2). -Analogously, if :attr:`input` has fewer dimensions than :attr:`reps` -specifies, then :attr:`input` is treated as if it were unsqueezed at -dimension zero until it has as many dimensions as :attr:`reps` specifies. +Analogously, if :attr:`input` has fewer dimensions than :attr:`reps` +specifies, then :attr:`input` is treated as if it were unsqueezed at +dimension zero until it has as many dimensions as :attr:`reps` specifies. For example, if :attr:`input` has shape (4, 2) and :attr:`reps` -is (3, 3, 2, 2), then :attr:`input` is treated as if it had the -shape (1, 1, 4, 2). +is (3, 3, 2, 2), then :attr:`input` is treated as if it had the +shape (1, 1, 4, 2). .. note:: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index f234ddbd0170..b7c8ed9567a1 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -271,6 +271,20 @@ def __init__(self, # outside a function's domain. self._domain_eps = 1e-5 +def sample_inputs_tensor_split(op_info, device, dtype, requires_grad): + return (SampleInput(make_tensor((S, S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + args=(torch.tensor([1, 2, 3]),),), + SampleInput(make_tensor((S, S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + args=(torch.tensor(1),),), + SampleInput(make_tensor((S, S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + args=(torch.tensor([1, 2, 3]),), + kwargs=dict(dim=1)),) def sample_inputs_addmm(op_info, device, dtype, requires_grad): return (SampleInput((make_tensor((S, S), device, dtype, @@ -868,6 +882,13 @@ def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False): device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), )), + OpInfo('tensor_split', + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_tensor_out=False, + test_inplace_grad=False, + sample_inputs_func=sample_inputs_tensor_split,), UnaryUfuncInfo('exp2', ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2), dtypes=all_types_and(torch.bool, torch.half),