
Conversation

kurtamohler
Collaborator

Fixes #9382

@kurtamohler kurtamohler force-pushed the tensor-split-9382 branch 2 times, most recently from a86b8a9 to 10906bc on September 22, 2020 23:20
@kurtamohler
Collaborator Author

kurtamohler commented Sep 22, 2020

There are at least two questions to address before we can merge this PR:

  • Is torch.tensor_split an acceptable name, or is there a better name? Note Richard's points in #9382 (comment).
  • Should we add a deprecation warning to torch.chunk, since this new function is meant to be an improvement on its behavior? Or would there be merit to keeping both functions?

@kurtamohler kurtamohler requested a review from zou3519 September 22, 2020 23:25
@kurtamohler kurtamohler marked this pull request as ready for review September 22, 2020 23:25
@dr-ci

dr-ci bot commented Sep 23, 2020

💊 CI failures summary and remediations

As of commit 7cea15f (more details on the Dr. CI page):


Commit 7cea15f was recently pushed. Waiting for builds...



@ezyang ezyang added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Sep 23, 2020
@kurtamohler kurtamohler force-pushed the tensor-split-9382 branch 2 times, most recently from 32b3bd8 to f8ea56a on September 24, 2020 06:42
@mruberry
Collaborator

There are at least two questions to address before we can merge this PR:

  • Is torch.tensor_split an acceptable name, or is there a better name? Note Richard's points in #9382 (comment).
  • Should we add a deprecation warning to torch.chunk, since this new function is meant to be an improvement on its behavior? Or would there be merit to keeping both functions?

Good questions.

torch.tensor_split seems like the appropriate name to me (and we can alias torch.array_split to it in the future).

Let's not add the deprecation in this PR, but I agree we should consider it (for 1.8).
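
For context, a minimal sketch of the behavioral difference at stake (torch.tensor_split follows NumPy's array_split semantics; outputs shown assume the behavior this PR adds):

    import torch

    t = torch.arange(6)

    # torch.chunk uses a fixed chunk size of ceil(6 / 4) = 2, so asking for
    # 4 chunks yields only 3 -- the surprise reported in #9382:
    print([c.tolist() for c in torch.chunk(t, 4)])
    # [[0, 1], [2, 3], [4, 5]]

    # torch.tensor_split returns exactly 4 sections; the first 6 % 4 = 2
    # sections are one element longer:
    print([s.tolist() for s in torch.tensor_split(t, 4)])
    # [[0, 1], [2, 3], [4], [5]]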

@kurtamohler
Collaborator Author

kurtamohler commented Sep 24, 2020

I'm not sure why the JIT tests are failing on some machines and not others. They pass locally for me (except for the new failure I introduced to TestAutogradDeviceTypeCUDA.test_inplace_view_multi_output_safe_cuda). I'll probably have to ssh into the jobs to debug.

@mruberry
Collaborator

I'm not sure why the JIT tests are failing on some machines and not others. They pass locally for me (except for the new failure I introduced to TestAutogradDeviceTypeCUDA.test_inplace_view_multi_output_safe_cuda). I'll probably have to ssh into the jobs to debug.

It's because JIT fusion is enabled on some builds and not others, I think. method_tests is very hard to parse, so it's hard to know what to set to tell the tests what kind of graphs are expected. I think if you just change (True,) to (False,) things will be fine.

In the future we'll use OpInfos, not method_tests(), and we won't have this problem.
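
For reference, a rough sketch of the method_tests() entry format under discussion (this entry is illustrative, not necessarily the exact one in the PR; the trailing tuple holds the graph-check flag mruberry mentions):

    # illustrative method_tests()-style entry (format at the time of this PR):
    # (name, self_size, args, variant_name, (should_autodiff_node[, nonfusible_nodes, fusible_nodes]))
    S = 5  # symbolic size used throughout common_methods_invocations.py
    entry = ('tensor_split', (S, S, S), (3,), 'sections', (False,))  # (False,) disables the autodiff graph check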

@mruberry
Collaborator

The XLA variant of the test can be skipped for now by decorating it with @onlyOnCPUAndCUDA
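
That is, something along these lines (a sketch; the class and test names are illustrative, the decorator and helpers come from the device-type test framework as it existed at the time):

    from torch.testing._internal.common_utils import TestCase
    from torch.testing._internal.common_device_type import (
        instantiate_device_type_tests, onlyOnCPUAndCUDA)

    class TestTensorSplit(TestCase):
        @onlyOnCPUAndCUDA  # generate only CPU and CUDA variants; XLA is skipped
        def test_tensor_split(self, device):
            ...

    instantiate_device_type_tests(TestTensorSplit, globals())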

@kurtamohler kurtamohler force-pushed the tensor-split-9382 branch 2 times, most recently from 5c53652 to 24cef4d on September 30, 2020 14:55
@mruberry
Collaborator

mruberry commented Oct 1, 2020

Hey @kurtamohler, is this ready for review?

@kurtamohler
Collaborator Author

Hey @kurtamohler, is this ready for review?

Yeah it is. I don't think my changes caused any of the CI failures.

@mruberry mruberry Oct 6, 2020
Collaborator

"Splits ... all of which are views of :attr:`input`. This function is based on NumPy's array_split."

For the ellipses, can the one-sentence description incorporate the required indices_or_sections argument and the dim argument? Maybe something as simple as "into multiple subtensors along dimension :attr:`dim`, based on the indices or number of sections specified by :attr:`indices_or_sections`"?
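
For instance, the indices form of the argument (an illustration assuming NumPy array_split semantics):

    import torch

    x = torch.arange(8)

    # indices_or_sections as a list of indices: split before each index,
    # producing x[:3], x[3:6], and x[6:]
    print([s.tolist() for s in torch.tensor_split(x, [3, 6])])
    # [[0, 1, 2], [3, 4, 5], [6, 7]]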

Collaborator Author

Done

Collaborator

Maybe something like:

"... is an integer n, :attr:`input` is split into n sections along dimension :attr:`dim`. If :attr:`input` is divisible by n along dimension :attr:`dim`..."

Collaborator Author

Done

@mruberry
Collaborator

mruberry commented Oct 6, 2020

Hey @kurtamohler! Another solid PR. I've asked @zou3519 to comment on one file and made a few suggestions on the docs and tests.

Comment on lines +546 to +547
Contributor

@kurtamohler could you add some tests to test_vmap.py similar to test_split (pytorch/test/test_vmap.py, lines 1240 to 1259 in 67889db)?

    def test_split(self):
        test = self._vmap_view_test
        op = torch.split
        B0, B1, B2 = 7, 11, 13
        # tests for torch.split(self, split_size: int, dim)
        test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None))
        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0),
             in_dims=(2, None, None))
        test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
        # tests for torch.split(self, split_size: List[int], dim)
        test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None))
        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0),
             in_dims=(2, None, None))
        test(vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)),
             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

_vmap_view_test eventually calls into _vmap_test (pytorch/test/test_vmap.py, lines 650 to 655 in 67889db); the comments there might be helpful. Please shout if you have questions!

    # Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a
    # (slow) sequential map+stack fallback.
    #
    # check_view: Test if the first returned output is a view of the first input
    # check_propagates_grad: Test if the operation propagates gradients.
    def _vmap_test(self, op, inputs, in_dims=0, out_dims=0,

Collaborator Author

Done! I used test cases pretty similar to test_split's, but with a few differences.
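
A sketch of what such cases might look like, mirroring the quoted test_split (the shapes and section arguments here are illustrative, not necessarily the ones merged in the PR):

    # inside TestVmapOperators in test/test_vmap.py
    def test_tensor_split(self):
        test = self._vmap_view_test
        op = torch.tensor_split
        B0, B1, B2 = 7, 11, 13
        # tests for torch.tensor_split(self, indices_or_sections: int, dim)
        test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None))
        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0),
             in_dims=(2, None, None))
        # tests for torch.tensor_split(self, indices_or_sections: List[int], dim)
        test(op, (torch.rand(B0, 2, 1024), [50, 100, 378], -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), [100, 500], 1), in_dims=(1, None, None))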

@kurtamohler
Collaborator Author

Thanks for the comments @mruberry and @zou3519! Latest update should cover everything brought up so far.

@mruberry mruberry left a comment
Collaborator

Looks great to me!

@zou3519 would you like to review the updates to the vmap-related code?

@zou3519 zou3519 left a comment
Contributor

vmap related parts lgtm

@facebook-github-bot facebook-github-bot left a comment
Contributor

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@mruberry merged this pull request in ef4817f.

