Add tensor_split function, based on numpy.array_split #45168
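For context, the new function follows numpy.array_split semantics: `indices_or_sections` may be either a number of sections or a list of split indices, and every returned piece is a view of the input. A small illustration of the indices form (outputs shown as comments):

```python
import torch

t = torch.arange(8)

# Split before indices 2 and 5 along dim 0, numpy.array_split-style:
torch.tensor_split(t, [2, 5])
# (tensor([0, 1]), tensor([2, 3, 4]), tensor([5, 6, 7]))
```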
Conversation
(force-pushed from a86b8a9 to 10906bc)
There are at least two questions that need to be dealt with before we can merge this PR: …
(force-pushed from 10906bc to fc138b1)
💊 CI failures summary and remediations — as of commit 7cea15f (more details on the Dr. CI page): commit 7cea15f was recently pushed; waiting for builds. (This comment was automatically generated by Dr. CI and has been revised 51 times.)
(force-pushed from 32b3bd8 to f8ea56a)
Good questions.
Let's not add the deprecation in this PR, but I agree we should consider it (for 1.8).
(force-pushed from f8ea56a to 5a61fea)
I'm not sure why the JIT tests are failing on some machines and not others. They pass locally for me (except for the new failure I introduced to …).
It's because jit fusion is enabled on some builds and not others, I think, and method_tests is very hard to parse, so it's hard to know what to set to tell the tests what kind of graphs are expected. I think if you just set … In the future we'll use OpInfos, not method_tests(), and we won't have this problem.
(force-pushed from 5a61fea to 810b20d)
The XLA variant of the test can be skipped for now by decorating it with …
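The specific decorator was elided from the thread above. As a purely generic stand-in (an assumption, not the PyTorch/XLA decorator the reviewer meant), a plain unittest skip looks like:

```python
import unittest


class TestTensorSplitXLA(unittest.TestCase):
    # Stand-in decorator: the actual XLA-specific skip decorator referenced
    # above was elided from the thread, so plain unittest.skip is used here.
    @unittest.skip("tensor_split not yet supported by the XLA backend")
    def test_tensor_split(self):
        ...
```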
(force-pushed from 5c53652 to 24cef4d)
Hey @kurtamohler, is this ready for review?
Yeah it is. I don't think my changes caused any of the CI failures.
torch/_torch_docs.py
Outdated
"Splits ... all of which are views of :attr:`input`. This function is based on NumPy's array_split."
For the ellipses, can the one-sentence description incorporate the required indices_or_sections argument and the dim argument? Maybe something as simple as "into multiple subtensors along dimension :attr:`dim` based on the indices or number of sections specified by :attr:`indices_or_sections`"?
Done
torch/_torch_docs.py
Outdated
Maybe something like:
"... is an integer n, :attr:`input` is split into n sections along dimension :attr:`dim`. If :attr:`input` is divisible by n
along dimension :attr:`dim`..."
Done
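To make the documented rule concrete: when the size of :attr:`dim` is not divisible by n, the first size % n sections each get one extra element (the numpy.array_split behavior). For example:

```python
import torch

t = torch.arange(7)

# 7 % 3 = 1, so the first section gets one extra element:
torch.tensor_split(t, 3)
# (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))

# Evenly divisible case: three equal sections.
torch.tensor_split(torch.arange(6), 3)
# (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))
```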
Hey @kurtamohler! Another solid PR. I've asked @zou3519 to comment on one file and made a few suggestions on the docs and tests.
@kurtamohler could you add some tests to test_vmap.py similar to the following (lines 1240 to 1259 in 67889db):
```python
def test_split(self):
    test = self._vmap_view_test
    op = torch.split
    B0, B1, B2 = 7, 11, 13
    # tests for torch.split(self, split_size: int, dim)
    test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None))
    test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None))
    test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0),
         in_dims=(2, None, None))
    test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
         (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
    # tests for torch.split(self, split_size: List[int], dim)
    test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None))
    test(op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None))
    test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0),
         in_dims=(2, None, None))
    test(vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)),
         (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
```
_vmap_view_test eventually calls into (lines 650 to 655 in 67889db):
```python
# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a
# (slow) sequential map+stack fallback.
#
# check_view: Test if the first returned output is a view of the first input
# check_propagates_grad: Test if the operation propagates gradients.
def _vmap_test(self, op, inputs, in_dims=0, out_dims=0,
```
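For intuition, the "sequential map+stack fallback" mentioned in that comment amounts to something like the sketch below (`slow_vmap_reference` is a hypothetical name, not the actual helper in test_vmap.py, and it only handles ops that return a single tensor):

```python
import torch

def slow_vmap_reference(op, x, in_dim=0, out_dim=0):
    # Hypothetical sketch: apply op to each slice of x along in_dim, then
    # stack the per-slice results along out_dim. vmap(op)'s output should
    # match this for single-tensor-output ops.
    results = [op(x.select(in_dim, i)) for i in range(x.size(in_dim))]
    return torch.stack(results, dim=out_dim)
```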
Done! I used pretty similar test cases to test_split, but with a few differences.
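As an illustration only, a tensor_split analogue of the test_split cases above might look like this (hypothetical cases modeled on that test; the PR's actual additions may differ):

```python
# Hypothetical sketch, not the exact cases added in the PR.
def test_tensor_split(self):
    test = self._vmap_view_test
    op = torch.tensor_split
    B0, B1 = 7, 11
    # tests for torch.tensor_split(self, sections: int, dim)
    test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None))
    test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None))
    # tests for torch.tensor_split(self, indices: List[int], dim)
    test(op, (torch.rand(B0, 2, 1024), [50, 100, 378], -1), in_dims=(0, None, None))
    test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [256, 512], 0),
         in_dims=(2, None, None))
```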
(force-pushed from 24cef4d to 6e76477)
(force-pushed from 6e76477 to 7cea15f)
Looks great to me!
@zou3519 would you like to review the updates to the vmap-related code?
vmap related parts lgtm
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fixes #9382