Skip to content

Commit

Permalink
Add tensor_split function, based on numpy.array_split (#45168)
Browse files Browse the repository at this point in the history
Summary:
Fixes #9382

Pull Request resolved: #45168

Reviewed By: ngimel

Differential Revision: D24166164

Pulled By: mruberry

fbshipit-source-id: 795459821e52885bc99623a01a2abec060995ce6
  • Loading branch information
kurtamohler authored and facebook-github-bot committed Oct 8, 2020
1 parent b2bff9e commit ef4817f
Show file tree
Hide file tree
Showing 16 changed files with 251 additions and 2 deletions.
18 changes: 18 additions & 0 deletions aten/src/ATen/BatchingRegistrations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,22 @@ std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int6
return result;
}

// Batching rule for the `tensor_split.sections` overload.
// Maps `self` to its physical representation, translates the logical `dim`
// to the matching physical dim, splits the physical tensor into `sections`
// pieces, and then passes the pieces to makeLogicalFromPhysicalListInplace
// before returning them.
std::vector<Tensor> tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto physical_dim = physical_view.getPhysicalDim(dim);
  auto pieces = at::tensor_split(physical_view.tensor(), sections, physical_dim);
  // The list is updated in place, so `pieces` itself is what gets returned.
  physical_view.makeLogicalFromPhysicalListInplace(pieces);
  return pieces;
}

// Batching rule for the `tensor_split.indices` overload.
// Maps `self` to its physical representation, translates the logical `dim`
// to the matching physical dim, splits the physical tensor at `indices`,
// and then passes the pieces to makeLogicalFromPhysicalListInplace before
// returning them.
std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  const auto physical_dim = physical_view.getPhysicalDim(dim);
  auto pieces = at::tensor_split(physical_view.tensor(), indices, physical_dim);
  // The list is updated in place, so `pieces` itself is what gets returned.
  physical_view.makeLogicalFromPhysicalListInplace(pieces);
  return pieces;
}

Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
// NB: unsqueeze has some special handling of its `dim` argument so we can't call
Expand Down Expand Up @@ -527,6 +543,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {

// view operations
m.impl("chunk", chunk_batching_rule);
m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
m.impl("diagonal", diagonal_batching_rule);
m.impl("expand", expand_batching_rule);
m.impl("expand_as", native::expand_as); // composite wrt autograd
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/core/NamedRegistrations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("tanh", CppFunction::makeFallthrough());
m.impl("tanh.out", CppFunction::makeFallthrough());
m.impl("tanh_", CppFunction::makeFallthrough());
m.impl("tensor_split.indices", CppFunction::makeFallthrough());
m.impl("tensor_split.sections", CppFunction::makeFallthrough());
m.impl("threshold", CppFunction::makeFallthrough());
m.impl("threshold.out", CppFunction::makeFallthrough());
m.impl("threshold_", CppFunction::makeFallthrough());
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ _(aten, tan) \
_(aten, tanh) \
_(aten, tensor) \
_(aten, tensordot) \
_(aten, tensor_split) \
_(aten, th_addmm) \
_(aten, th_clone) \
_(aten, th_norm) \
Expand Down
31 changes: 31 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,37 @@ std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) {
}
}

std::vector<Tensor> tensor_split(const Tensor& self, int64_t sections, int64_t dim) {
TORCH_CHECK(self.dim() > 0, "expected at least a 1-dimensional tensor");
int64_t dim_ = maybe_wrap_dim(dim, self.dim());
TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections);
std::vector<Tensor> splits(sections);
int64_t min_split_size = self.size(dim_) / sections;
int64_t num_splits_one_extra = self.size(dim_) % sections;
int64_t start_idx = 0;
for (int64_t split_idx = 0; split_idx < sections; split_idx++) {
int64_t split_size = (split_idx < num_splits_one_extra) ? (min_split_size + 1) : min_split_size;
splits[split_idx] = at::slice(self, dim_, start_idx, start_idx + split_size);
start_idx += split_size;
}
return splits;
}

std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim) {
TORCH_CHECK(self.dim() > 0, "expected at least a 1-dimensional tensor");
int64_t dim_ = maybe_wrap_dim(dim, self.dim());
int64_t num_indices = indices.size();
std::vector<Tensor> splits(num_indices + 1);
int64_t start_idx = 0;
for (int64_t split_idx = 0; split_idx < num_indices; split_idx++) {
int64_t end_idx = indices[split_idx];
splits[split_idx] = at::slice(self, dim_, start_idx, end_idx);
start_idx = end_idx;
}
splits[num_indices] = at::slice(self, dim_, start_idx, self.size(dim_));
return splits;
}

std::vector<Tensor> unsafe_chunk(const Tensor& self, int64_t chunks, int64_t dim) {
TORCH_CHECK(self.dim() > 0,
"chunk expects at least a 1-dimensional tensor");
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,14 @@
variants: function, method
device_guard: False

- func: tensor_split.sections(Tensor(a) self, int sections, int dim=0) -> Tensor(a)[]
use_c10_dispatcher: full
variants: function, method

- func: tensor_split.indices(Tensor(a) self, int[] indices, int dim=0) -> Tensor(a)[]
use_c10_dispatcher: full
variants: function, method

- func: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
use_c10_dispatcher: full
variants: function, method
Expand Down
1 change: 1 addition & 0 deletions docs/source/tensors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: symeig
.. automethod:: t
.. automethod:: t_
.. automethod:: tensor_split
.. automethod:: to
.. automethod:: to_mkldnn
.. automethod:: take
Expand Down
1 change: 1 addition & 0 deletions docs/source/torch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Indexing, Slicing, Joining, Mutating Ops
stack
t
take
tensor_split
transpose
unbind
unsqueeze
Expand Down
2 changes: 1 addition & 1 deletion test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4833,7 +4833,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_',
'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot'] + separate_complex_tests
'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split'] + separate_complex_tests

# TODO(@anjali411): add tests for 'sub', 'div'
# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition - @anjali411
Expand Down
6 changes: 6 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15424,6 +15424,12 @@ def fn(x):
'test_split_size_list',
'test_split_size_list_dim',
'test_split_size_list_dim_neg0',
'test_tensor_indices_sections',
'test_tensor_indices_sections_dim',
'test_tensor_indices_sections_dim_neg0',
'test_tensor_split_sections',
'test_tensor_split_sections_dim',
'test_tensor_split_sections_dim_neg0',
}

EXCLUDE_PYTHON_PRINT = {
Expand Down
88 changes: 87 additions & 1 deletion test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
do_test_dtypes, IS_SANDCASTLE, load_tests, slowTest,
skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext,
skipIfRocm, torch_to_numpy_dtype_dict, skipIfNoSciPy, IS_MACOS, IS_PPC,
wrapDeterministicFlagAPITest)
wrapDeterministicFlagAPITest, make_tensor)
from multiprocessing.reduction import ForkingPickler
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, skipCUDAIfNotRocm, onlyCUDA, onlyCPU, \
Expand Down Expand Up @@ -8346,6 +8346,92 @@ def test_contiguous(self, device):
x.set_(x.storage(), 0, x.size(), stride)
self.assertTrue(x.is_contiguous())

@onlyOnCPUAndCUDA
# Skip BFloat16 since numpy does not support it
@dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False))
def test_tensor_split_sections(self, device, dtype):
    # Compares torch.tensor_split(a, sections, dim) against
    # np.array_split for every dim of each input and for every section
    # count in [1, 2 * size(dim)), and checks that each returned piece
    # keeps the input's device and dtype.
    input_sizes = [
        (0,),
        (10,),
        (10, 0),
        (0, 10),
        (4, 10),
        (12, 3),
    ]
    for input_size in input_sizes:
        a_base = make_tensor(input_size, device, dtype, low=-9, high=9)
        # Also run on the transposed input when it is 2-D. The check is for
        # exactly 2 dims because Tensor.t() rejects tensors with more than
        # 2 dims; the previous `> 2` guard never selected any of the inputs
        # above, so the transposed case was silently skipped.
        for a in [a_base, a_base.t()] if a_base.dim() == 2 else [a_base]:
            a_n = a.cpu().numpy()
            for dim in range(-a.dim(), a.dim()):
                for sections in range(1, 2 * a.size(dim)):
                    msg = f'input_size {input_size}, sections {sections}, dim {dim}'
                    result = torch.tensor_split(a, sections, dim)
                    for result_item in result:
                        self.assertEqual(result_item.device, torch.device(device), msg=msg)
                        self.assertEqual(result_item.dtype, dtype, msg=msg)
                    result_n = np.array_split(a_n, sections, dim)
                    self.assertEqual(result_n, result, msg=msg)

@onlyOnCPUAndCUDA
# Skip BFloat16 since numpy does not support it
@dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False))
def test_tensor_split_indices(self, device, dtype):
    # Compares torch.tensor_split(a, indices, dim) against np.array_split
    # for every dim of each input and for each index tuple below (empty,
    # in-range, out-of-range, negative and unsorted), and checks that each
    # returned piece keeps the input's device and dtype.
    input_sizes = [
        (0,),
        (10,),
        (10, 0),
        (0, 10),
        (4, 10),
        (12, 3),
    ]
    indices_args = [
        (),
        (0,),
        (3,),
        (10,),
        (-1,),
        (-10,),
        (2, -1),
        (3, 4, 10),
        (0, -1, 0, 10),
        (1, 5, 2, 8),
    ]
    for input_size in input_sizes:
        a_base = make_tensor(input_size, device, dtype, low=-9, high=9)
        # Also run on the transposed input when it is 2-D. The check is for
        # exactly 2 dims because Tensor.t() rejects tensors with more than
        # 2 dims; the previous `> 2` guard never selected any of the inputs
        # above, so the transposed case was silently skipped.
        for a in [a_base, a_base.t()] if a_base.dim() == 2 else [a_base]:
            a_n = a.cpu().numpy()
            for dim in range(-a.dim(), a.dim()):
                for indices in indices_args:
                    result = torch.tensor_split(a, indices, dim)
                    msg = f'input_size {input_size}, indices {indices}, dim {dim}'
                    for result_item in result:
                        self.assertEqual(result_item.device, torch.device(device), msg=msg)
                        self.assertEqual(result_item.dtype, dtype, msg=msg)
                    result_n = np.array_split(a_n, indices, dim)
                    self.assertEqual(result_n, result, msg=msg)

@onlyOnCPUAndCUDA
def test_tensor_split_errors(self, device):
    # For each invalid call, torch.tensor_split must raise the listed error
    # type with the listed message, and np.array_split must raise the
    # listed numpy error type for the same arguments.
    S = 10
    # Columns: input size, sections or indices, dim,
    #          torch error type, torch error message, numpy error type
    test_cases = (
        ((S,), 10, 1, IndexError, r'Dimension out of range', IndexError),
        ((), 10, 0, RuntimeError, r'expected at least a 1-dimensional tensor', IndexError),
        ((S,), (10,), 1, IndexError, r'Dimension out of range', IndexError),
        ((), (10,), 0, RuntimeError, r'expected at least a 1-dimensional tensor', IndexError),
        ((S,), 0, 0, RuntimeError, r'number of sections must be larger than 0, got 0', ValueError),
        ((S,), -1, 0, RuntimeError, r'number of sections must be larger than 0, got -1', ValueError),
    )
    for case in test_cases:
        input_size, sections_or_indices, dim, torch_err, torch_err_msg, numpy_err = case
        a = torch.randn(input_size, device=device)
        msg = f'input_size {input_size}, sections_or_indices {sections_or_indices}, dim {dim}'
        with self.assertRaisesRegex(torch_err, torch_err_msg, msg=msg):
            torch.tensor_split(a, sections_or_indices, dim)
        with self.assertRaises(numpy_err, msg=msg):
            np.array_split(a.cpu().numpy(), sections_or_indices, dim)

def test_index(self, device):

def consec(size, start=1):
Expand Down
21 changes: 21 additions & 0 deletions test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,27 @@ def wrapped(*args, **kwargs):
test(op, (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)),
check_propagates_grad=False)

def test_tensor_split(self):
    # Runs torch.tensor_split through self._vmap_view_test for both the
    # integer (sections) and list (indices) forms of the second argument,
    # at one, two and three levels of vmap nesting.
    check = self._vmap_view_test
    split = torch.tensor_split
    B0, B1, B2 = 7, 11, 13
    # Only the tensor argument carries a batch dim; the other two do not.
    split_over_first = vmap(split, in_dims=(0, None, None))

    def split_into_four(t):
        return split(t, 4, 1)

    def split_at_points(t):
        return split(t, [4, 8, 9, 34, 29], 1)

    # tests for torch.tensor_split(self, indices_or_sections: int, dim)
    check(split, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None))
    check(split, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None))
    check(split_over_first, (torch.rand(B1, 1023, B0, 5), 256, 0),
          in_dims=(2, None, None))
    check(vmap(vmap(split_into_four, in_dims=2)),
          (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

    # tests for torch.tensor_split(self, indices_or_sections: List[int], dim)
    check(split, (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1), in_dims=(0, None, None))
    check(split, (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1), in_dims=(1, None, None))
    check(vmap(split, in_dims=(0, None, None)),
          (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0),
          in_dims=(2, None, None))
    check(vmap(vmap(split_at_points, in_dims=2)),
          (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

def test_split(self):
test = self._vmap_view_test
op = torch.split
Expand Down
1 change: 1 addition & 0 deletions tools/autograd/gen_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union({
'chunk', 'detach', 'contiguous', 'reshape', 'reshape_as',
'expand_as', 'view_as', 'real', 'imag', 'narrow', 'movedim',
'tensor_split'
})

def format_return_type(returns):
Expand Down
7 changes: 7 additions & 0 deletions torch/_tensor_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4045,6 +4045,13 @@ def callable(a, b) -> number
See :func:`torch.unsafe_split`
""")

add_docstr_all('tensor_split',
r"""
tensor_split(indices_or_sections, dim=0) -> List of Tensors
See :func:`torch.tensor_split`
""")

add_docstr_all('stft',
r"""
stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor
Expand Down
61 changes: 61 additions & 0 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,67 @@ def merge_dicts(*dicts):
""".format(**common_args))

add_docstr(torch.tensor_split,
r"""
tensor_split(input, indices_or_sections, dim=0) -> List of Tensors
Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`,
along dimension :attr:`dim` according to the indices or number of sections specified
by :attr:`indices_or_sections`. This function is based on NumPy's
:func:`numpy.array_split`.
Args:
input (Tensor): the tensor to split
indices_or_sections (int or (list(int))):
If :attr:`indices_or_sections` is an integer ``n``, :attr:`input` is split
into ``n`` sections along dimension :attr:`dim`. If :attr:`input` is divisible
by ``n`` along dimension :attr:`dim`, each section will be of equal size,
:code:`input.size(dim) / n`. If :attr:`input` is not divisible by ``n``, the
first :code:`int(input.size(dim) % n)` sections will have size
:code:`int(input.size(dim) / n) + 1`, and the rest will have size
:code:`int(input.size(dim) / n)`.
If :attr:`indices_or_sections` is a list of ints, :attr:`input` is split along
dimension :attr:`dim` at each of the indices in the list. For instance,
:code:`[2, 3]` and :code:`dim=0` would result in the following tensors:
- :code:`input[:2]`
- :code:`input[2:3]`
- :code:`input[3:]`
dim (int, optional): dimension along which to split the tensor. Default: ``0``
Example::
>>> x = torch.arange(8)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))
>>> x = torch.arange(7)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
>>> torch.tensor_split(x, (1, 6))
(tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6]))
>>> x = torch.arange(14).reshape(2, 7)
>>> x
tensor([[ 0, 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12, 13]])
>>> torch.tensor_split(x, 3, dim=1)
(tensor([[0, 1, 2],
[7, 8, 9]]),
tensor([[ 3, 4],
[10, 11]]),
tensor([[ 5, 6],
[12, 13]]))
>>> torch.tensor_split(x, (1, 6), dim=1)
(tensor([[0],
[7]]),
tensor([[ 1, 2, 3, 4, 5],
[ 8, 9, 10, 11, 12]]),
tensor([[ 6],
[13]]))
""")

add_docstr(torch.chunk,
r"""
chunk(input, chunks, dim=0) -> List of Tensors
Expand Down
1 change: 1 addition & 0 deletions torch/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.tan: lambda input, out=None: -1,
torch.tanh: lambda input, out=None: -1,
torch.tensordot: lambda a, b, dims=2: -1,
torch.tensor_split: lambda input, indices_or_sections, dim=0: -1,
torch.threshold: lambda input, threshold, value, inplace=False: -1,
torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1,
torch.trace: lambda input: -1,
Expand Down
4 changes: 4 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,10 @@ def method_tests():
('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), '', (True,)),
('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3), 0],), 'size_0', (True, )),
('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'dim', (True, ), [1]),
('tensor_split', (S, S, S), (3,), 'sections', (False,)),
('tensor_split', (S, S, S), (3, 1), 'sections_dim', (False,), [1]),
('tensor_split', (S, S, S), ([2, 4],), 'indices', (False,)),
('tensor_split', (S, S, S), ([2, 4], 1), 'indices_dim', (False,), [1]),
('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', (), [0]),
('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', (), [0]),
('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', (), [0]),
Expand Down

0 comments on commit ef4817f

Please sign in to comment.