
added size_splits to functional #3837


Merged: 9 commits merged into pytorch:master on Jan 4, 2018

Conversation

@ptrblck (Collaborator) commented Nov 22, 2017

This pull request addresses issue #3223.

The split function splits tensors into equally sized chunks.
split_sizes lets the user define a list of sizes, one per chunk.

tf.split combines both functionalities in one function. Maybe this is also desired for PyTorch?

split_sizes seems to be a bit slower (6.991s vs. 6.704s).
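
For illustration, a rough sketch of how the combined API reads once both modes live in one function, mirroring the torch.split(tensor, split_size_or_sections, dim) signature discussed later in the thread (shapes in the comments are for this example only):

import torch

x = torch.randn(10, 4)

# An int splits into equal chunks of that size along dim (the last chunk may be smaller).
a, b = torch.split(x, 5, dim=0)             # shapes: (5, 4) and (5, 4)

# A list gives one chunk per entry; the entries must sum to x.size(dim).
p, q, r = torch.split(x, [2, 3, 5], dim=0)  # shapes: (2, 4), (3, 4), (5, 4)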

@colesbury (Member)

Yeah, I think a single function would be nicer. Both tf.split and numpy.split have that sort of API.

@ptrblck (Collaborator, Author) commented Nov 24, 2017

Thanks for the feedback, I will merge it into torch.split then.

@ptrblck (Collaborator, Author) commented Nov 25, 2017

I merged both functions now.

Before that, I timed all functions on my machine (tensor = torch.randn(200, 10, 2, 2) split into 40 chunks of size 5 in dim=0):

  • merged function with split_size_or_sections = [5] * 40: ~8.1586s
  • merged function with split_size_or_sections = 5: ~7.4982s
  • native split with split_size=5: ~7.2688s

What are your thoughts? Any suggestions on the code / naming / documentation?
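
For reference, a minimal sketch of how such a comparison could be reproduced with timeit; the exact benchmark script is not part of this thread, and the iteration count here is arbitrary:

import timeit
import torch

tensor = torch.randn(200, 10, 2, 2)

# Same workload as above: 40 chunks of size 5 along dim=0.
t_sections = timeit.timeit(lambda: torch.split(tensor, [5] * 40, dim=0), number=100000)
t_int      = timeit.timeit(lambda: torch.split(tensor, 5, dim=0), number=100000)
print(t_sections, t_int)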

@ezyang (Contributor) commented Nov 27, 2017

@pytorchbot test this please

@flennerhag

@ptrblck great initiative, I've been using my own wrapper for a while. It would be nice to have it in the codebase.

You could simplify the code quite a bit, though, by using plain Python operations instead of invoking torch overhead, e.g. something like:

def split(tensor, sizes, dim=0):
    if dim < 0:
        dim += tensor.dim()

    if isinstance(sizes, int):
        # original code ...
        return chunks

    if tensor.size(dim) != sum(sizes):
        raise ValueError("Sizes do not match tensor size in dim")

    # narrow(dim, start, length): each section starts at the running offset
    offsets = [0]
    for size in sizes[:-1]:
        offsets.append(offsets[-1] + size)
    return tuple(tensor.narrow(dim, offset, size)
                 for offset, size in zip(offsets, sizes))

Should be slightly faster too.
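
For example, the sketch above could be exercised like this (list-of-sizes branch only, since the int branch is elided):

import torch

x = torch.randn(10, 3)
parts = split(x, [2, 3, 5], dim=0)
print([p.size(0) for p in parts])  # [2, 3, 5]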

@soumith (Member) commented Dec 18, 2017

@ptrblck as soon as you add unit tests for the list-of-splits case, I can merge this in.

@ptrblck (Collaborator, Author) commented Dec 18, 2017

@flennerhag Thanks for the suggestions! I changed some PyTorch calls to plain Python operations.
@soumith I also added some tests in test_split. If that's not sufficient, I can add more test cases.
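
Illustrative only; the actual assertions in test_split may differ. A minimal check for the list-of-sizes case might look like:

import torch

x = torch.randn(12, 3)
sections = [2, 4, 6]
chunks = torch.split(x, sections, dim=0)

assert len(chunks) == len(sections)
assert [c.size(0) for c in chunks] == sections
assert torch.equal(torch.cat(chunks, dim=0), x)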

@soumith (Member) commented Jan 4, 2018

@pytorchbot test this please

@soumith merged commit 7c729e6 into pytorch:master on Jan 4, 2018

@soumith (Member) commented Jan 4, 2018

Thanks a lot @ptrblck!
