From fc4b646e8db6a203446acbc2ae214129c64a55a8 Mon Sep 17 00:00:00 2001 From: pbialecki Date: Wed, 22 Nov 2017 17:14:23 +0100 Subject: [PATCH 1/8] - added size_splits to functional --- torch/functional.py | 44 ++++++++++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/torch/functional.py b/torch/functional.py index b1cbc27290b32..e37313bb9d46a 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -8,27 +8,47 @@ ] -def split(tensor, split_size, dim=0): - """Splits the tensor into chunks all of size :attr:`split_size` (if possible). +def split(tensor, split_size_or_sections, dim=0): + """Splits the tensor into chunks. + If ``split_size_or_sections`` is an integer type, then ``tensor`` will be + split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along a given dimension is not divisible by :attr`split_size`. - + If ``split_size_or_sections`` is a list, then ``tensor`` will be split + into ``len(split_size_or_sections)`` chunks with sizes in ``dim`` according + to ``split_size_or_sections``. Arguments: - tensor (Tensor): the tensor to split - split_size (int): size of a single chunk - dim (int): dimension along which to split the tensor + tensor (Tensor): tensor to split. + split_size_or_sections (int) or (list(int)): size of a single chunk or + list of sizes for each chunk + dim (int): dimension along which to split the tensor. """ if dim < 0: dim += tensor.dim() dim_size = tensor.size(dim) - num_splits = (dim_size + split_size - 1) // split_size - last_split_size = split_size - (split_size * num_splits - dim_size) - def get_split_size(i): - return split_size if i < num_splits - 1 else last_split_size - return tuple(tensor.narrow(int(dim), int(i * split_size), int(get_split_size(i))) for i - in _range(0, num_splits)) + splits = torch.IntTensor([split_size_or_sections]) + + if splits.dim() == 1: + split_size = split_size_or_sections + num_splits = (dim_size + split_size - 1) // split_size + last_split_size = split_size - (split_size * num_splits - dim_size) + + def get_split_size(i): + return split_size if i < num_splits - 1 else last_split_size + return tuple(tensor.narrow(int(dim), int(i * split_size), int(get_split_size(i))) for i + in _range(0, num_splits)) + + else: + if dim_size != torch.sum(splits): + raise ValueError("Sum of split sizes exceeds tensor dim") + split_indices = torch.cat((torch.zeros(1, 1).int(), splits), dim=1) + split_indices = torch.cumsum(split_indices, dim=1)[:, :-1].view(-1) + + return tuple( + tensor.narrow(int(dim), int(start), int(length)) + for start, length in zip(split_indices, split_size_or_sections)) def chunk(tensor, chunks, dim=0): From f7e49eac40e16270afca9aa3c0336c2364b7b2eb Mon Sep 17 00:00:00 2001 From: pbialecki Date: Sat, 25 Nov 2017 13:56:50 +0100 Subject: [PATCH 2/8] - merged ``split_sizes`` to ``split`` - removed ``split_sizes`` --- torch/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/functional.py b/torch/functional.py index e37313bb9d46a..4ab4254048ba6 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -18,6 +18,7 @@ def split(tensor, split_size_or_sections, dim=0): If ``split_size_or_sections`` is a list, then ``tensor`` will be split into ``len(split_size_or_sections)`` chunks with sizes in ``dim`` according to ``split_size_or_sections``. + Arguments: tensor (Tensor): tensor to split. split_size_or_sections (int) or (list(int)): size of a single chunk or @@ -29,7 +30,6 @@ def split(tensor, split_size_or_sections, dim=0): dim_size = tensor.size(dim) splits = torch.IntTensor([split_size_or_sections]) - if splits.dim() == 1: split_size = split_size_or_sections num_splits = (dim_size + split_size - 1) // split_size From 4750357ae23c5955150be48ec9fea7693c0eeed2 Mon Sep 17 00:00:00 2001 From: pbialecki Date: Sat, 25 Nov 2017 14:26:54 +0100 Subject: [PATCH 3/8] - flake8 (#3223) --- torch/functional.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torch/functional.py b/torch/functional.py index 4ab4254048ba6..2767eb5820883 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -14,10 +14,7 @@ def split(tensor, split_size_or_sections, dim=0): If ``split_size_or_sections`` is an integer type, then ``tensor`` will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along a given dimension - is not divisible by :attr`split_size`. - If ``split_size_or_sections`` is a list, then ``tensor`` will be split - into ``len(split_size_or_sections)`` chunks with sizes in ``dim`` according - to ``split_size_or_sections``. + is not divisible by ``split_size``. Arguments: tensor (Tensor): tensor to split. From 1e9fb1b4fd4a3e7bd687387850adab6d31fb8c09 Mon Sep 17 00:00:00 2001 From: pbialecki Date: Mon, 4 Dec 2017 22:16:05 +0100 Subject: [PATCH 4/8] flake8 (#3837) --- torch/functional.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/functional.py b/torch/functional.py index 2767eb5820883..a1cb6405b7532 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -8,7 +8,6 @@ ] - def split(tensor, split_size_or_sections, dim=0): """Splits the tensor into chunks. If ``split_size_or_sections`` is an integer type, then ``tensor`` will be From 7b082102aa1e886a97e4be02d895c2bdb4a7c045 Mon Sep 17 00:00:00 2001 From: pbialecki Date: Mon, 18 Dec 2017 17:33:31 +0100 Subject: [PATCH 5/8] - changed some operations to plain python (#3837) --- torch/functional.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torch/functional.py b/torch/functional.py index a1cb6405b7532..a49de9127da50 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -25,8 +25,7 @@ def split(tensor, split_size_or_sections, dim=0): dim += tensor.dim() dim_size = tensor.size(dim) - splits = torch.IntTensor([split_size_or_sections]) - if splits.dim() == 1: + if isinstance(split_size_or_sections, int): split_size = split_size_or_sections num_splits = (dim_size + split_size - 1) // split_size last_split_size = split_size - (split_size * num_splits - dim_size) @@ -37,9 +36,9 @@ def get_split_size(i): in _range(0, num_splits)) else: - if dim_size != torch.sum(splits): + if dim_size != sum(split_size_or_sections): raise ValueError("Sum of split sizes exceeds tensor dim") - split_indices = torch.cat((torch.zeros(1, 1).int(), splits), dim=1) + split_indices = torch.cat((torch.zeros(1, 1).int(), split_size_or_sections), dim=1) split_indices = torch.cumsum(split_indices, dim=1)[:, :-1].view(-1) return tuple( From 62bef8cfc144a5eead84a906de419fa4e882da33 Mon Sep 17 00:00:00 2001 From: pbialecki Date: Mon, 18 Dec 2017 18:50:24 +0100 Subject: [PATCH 6/8] - bugfixes on plain python operations (#3837) - added tests in test_split for variable sections splits (#3837) --- test/test_torch.py | 22 ++++++++++++++++++++++ torch/functional.py | 4 ++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index af9556bf6b1a9..903001a9dea07 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -4091,6 +4091,28 @@ def test_split(self): self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0) start = start + target_size[dim] + # Variable sections split + tensor = torch.randn(20, 10) + dim = 0 + split_sizes = [5, 5, 10] + target_sizes = ([[5, 10], [5, 10], [10, 10]]) + splits = tensor.split(split_sizes, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0) + start = start + target_size[dim] + + split_sizes = [2, 2, 6] + target_sizes = ([20, 2], [20, 2], [20, 6]) + dim = 1 + splits = tensor.split(split_sizes, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0) + start = start + target_size[dim] + def test_chunk(self): tensor = torch.rand(4, 7) num_chunks = 3 diff --git a/torch/functional.py b/torch/functional.py index a49de9127da50..55cc9b65643a5 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -38,8 +38,8 @@ def get_split_size(i): else: if dim_size != sum(split_size_or_sections): raise ValueError("Sum of split sizes exceeds tensor dim") - split_indices = torch.cat((torch.zeros(1, 1).int(), split_size_or_sections), dim=1) - split_indices = torch.cumsum(split_indices, dim=1)[:, :-1].view(-1) + split_indices = [0] + split_size_or_sections + split_indices = torch.cumsum(torch.Tensor(split_indices), dim=0) return tuple( tensor.narrow(int(dim), int(start), int(length)) From 78ae53e3727a48dbfdcecc05fbdf5a549a254aab Mon Sep 17 00:00:00 2001 From: pbialecki Date: Mon, 18 Dec 2017 18:55:27 +0100 Subject: [PATCH 7/8] - added doc string for split (#3837) --- torch/functional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/functional.py b/torch/functional.py index 55cc9b65643a5..faa54ef7db172 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -14,6 +14,8 @@ def split(tensor, split_size_or_sections, dim=0): split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along a given dimension is not divisible by ``split_size``. + If ``split_size_or_sections`` is a list, then ``tensor`` will be split + into chunks of the specified sizes. Arguments: tensor (Tensor): tensor to split. From 93ffe33f24d7efd4604206569725ac60b0885def Mon Sep 17 00:00:00 2001 From: pbialecki Date: Mon, 18 Dec 2017 18:59:49 +0100 Subject: [PATCH 8/8] - added doc string for split (#3837) --- torch/functional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/functional.py b/torch/functional.py index faa54ef7db172..7914565c55497 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -15,7 +15,8 @@ def split(tensor, split_size_or_sections, dim=0): Last chunk will be smaller if the tensor size along a given dimension is not divisible by ``split_size``. If ``split_size_or_sections`` is a list, then ``tensor`` will be split - into chunks of the specified sizes. + into ``len(split_size_or_sections)`` chunks with sizes in ``dim`` according + to ``split_size_or_sections``. Arguments: tensor (Tensor): tensor to split.