BatchedTensor fallback: extended to support ops with multiple Tensor returns #42628
Conversation
This PR extends the BatchedTensor fallback to support operators with multiple Tensor returns. If an operator has multiple returns, we stack shards of each return to create the full outputs.

For example, let's consider

```
tensor = torch.randn(B0, 10)
vmap(torch.var_mean)(torch.randn(B0, 10))
```

`torch.var_mean` returns two Tensors. The fallback path essentially does the following:

- for each batch i, let ai, bi = torch.var_mean(tensor[i])
- return torch.stack([a0, a1, ..., a{B0}]), torch.stack([b0, b1, ..., b{B0}])

Test Plan:

- `pytest test/test_vmap.py -v`. Added a new test for an operator with multiple returns (torch.var_mean).

Differential Revision: [D22957095](https://our.internmc.facebook.com/intern/diff/D22957095)

ghstack-source-id: 729597c
Pull Request resolved: #42628
💊 CI failures summary and remediations

As of commit 9529b4e (more details on the Dr. CI page):

- ci.pytorch.org: 1 failed
aten/src/ATen/BatchedFallback.cpp (Outdated)

```
  TORCH_CHECK(num_returns >= 1,
              "Batching rule not implemented for ", schema, ". ",
-             "We do not yet support operations with multiple returns.");
+             "We do not support operations with no returns.");
```
shouldn't this be easy to support lol
Yeah, this isn't hard to support. However, it implies the operation does nothing (since it would have to have no returns and no alias annotations). I'll clarify that this means the fallback doesn't support that case.
```
  // [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3]
  // This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
  // more easily in the next step.
  std::vector<Tensor> output_shards(num_batches * num_returns);
```
MatrixRef may be of interest here!
I changed the indexing code to use MatrixRef! It's a nice way to abstract that behavior; thanks for pointing it out.
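For reference, here is a minimal sketch of how `at::MatrixRef` can express that indexing. The helper name `stack_shards` and its exact signature are hypothetical and only illustrate the idea; the code merged in this PR may differ.

```
#include <ATen/ATen.h>
#include <ATen/MatrixRef.h>
#include <vector>

// Hypothetical helper (not the PR's actual code): view the flat buffer
// [a0, a1, ..., a{B-1}, b0, b1, ..., b{B-1}, ...] as a
// num_returns x num_batches matrix and stack each row into one
// batched output tensor.
std::vector<at::Tensor> stack_shards(
    const std::vector<at::Tensor>& output_shards,
    int64_t num_batches,
    int64_t num_returns) {
  // MatrixRef(arr, stride) interprets `arr` as rows of `stride` elements,
  // so shards[ret] is the ArrayRef of that return's per-batch shards.
  at::MatrixRef<at::Tensor> shards(output_shards, num_batches);
  std::vector<at::Tensor> outputs;
  outputs.reserve(num_returns);
  for (int64_t ret = 0; ret < num_returns; ++ret) {
    outputs.push_back(at::stack(shards[ret]));
  }
  return outputs;
}
```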
…ple Tensor returns" This PR extends the BatchedTensor fallback to support operators with multiple Tensor returns. If an operator has multiple returns, we stack shards of each return to create the full outputs. For example, let's consider ``` tensor = torch.randn(B0, 10) vmap(torch.var_mean)(torch.randn(B0, 10)) ``` `torch.var_mean` returns two Tensors. The fallback path essentially does the following: - for each batch i, let ai, bi = torch.var_mean(tensor[i]) - return torch.stack([a0, a1, ..., a{B0}]), torch.stack([b0, b1, ..., b{B0}]) Test Plan: - `pytest test/test_vmap.py -v`. Added a new test for an operator with multiple returns (torch.var_mean). Differential Revision: [D22957095](https://our.internmc.facebook.com/intern/diff/D22957095) [ghstack-poisoned]
…ple Tensor returns" This PR extends the BatchedTensor fallback to support operators with multiple Tensor returns. If an operator has multiple returns, we stack shards of each return to create the full outputs. For example, let's consider ``` tensor = torch.randn(B0, 10) vmap(torch.var_mean)(torch.randn(B0, 10)) ``` `torch.var_mean` returns two Tensors. The fallback path essentially does the following: - for each batch i, let ai, bi = torch.var_mean(tensor[i]) - return torch.stack([a0, a1, ..., a{B0}]), torch.stack([b0, b1, ..., b{B0}]) Test Plan: - `pytest test/test_vmap.py -v`. Added a new test for an operator with multiple returns (torch.var_mean). Differential Revision: [D22957095](https://our.internmc.facebook.com/intern/diff/D22957095) [ghstack-poisoned]
…ple Tensor returns" This PR extends the BatchedTensor fallback to support operators with multiple Tensor returns. If an operator has multiple returns, we stack shards of each return to create the full outputs. For example, let's consider ``` tensor = torch.randn(B0, 10) vmap(torch.var_mean)(torch.randn(B0, 10)) ``` `torch.var_mean` returns two Tensors. The fallback path essentially does the following: - for each batch i, let ai, bi = torch.var_mean(tensor[i]) - return torch.stack([a0, a1, ..., a{B0}]), torch.stack([b0, b1, ..., b{B0}]) Test Plan: - `pytest test/test_vmap.py -v`. Added a new test for an operator with multiple returns (torch.var_mean). Differential Revision: [D22957095](https://our.internmc.facebook.com/intern/diff/D22957095) [ghstack-poisoned]
Stack from ghstack:
This PR extends the BatchedTensor fallback to support operators with
multiple Tensor returns. If an operator has multiple returns, we stack
shards of each return to create the full outputs.
For example, let's consider `vmap(torch.var_mean)(torch.randn(B0, 10))`, where the input has shape `(B0, 10)`.

`torch.var_mean` returns two Tensors. The fallback path essentially does the following:

- for each batch i, let ai, bi = torch.var_mean(tensor[i])
- return torch.stack([a0, a1, ..., a{B0}]), torch.stack([b0, b1, ..., b{B0}])
Test Plan:

- `pytest test/test_vmap.py -v`. Added a new test for an operator with multiple returns (torch.var_mean).
Differential Revision: D22957095
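To make the described behaviour concrete, below is a minimal standalone ATen sketch of the computation the fallback performs for an operator with two returns, using `at::var_mean`. It illustrates the per-batch loop and per-return stacking described above; it is not the fallback's actual C++ implementation.

```
#include <ATen/ATen.h>
#include <tuple>
#include <vector>

int main() {
  const int64_t B0 = 4;
  auto tensor = at::randn({B0, 10});

  // The op has two returns, so we collect two lists of shards:
  // one list per return, with one shard per batch.
  std::vector<at::Tensor> var_shards, mean_shards;
  for (int64_t i = 0; i < B0; ++i) {
    at::Tensor v, m;
    std::tie(v, m) = at::var_mean(tensor[i]);  // run the op on one batch slice
    var_shards.push_back(v);
    mean_shards.push_back(m);
  }

  // Stack the shards of each return to build the full batched outputs.
  auto var_out = at::stack(var_shards);    // shape [B0]
  auto mean_out = at::stack(mean_shards);  // shape [B0]

  // Sanity check: this matches reducing over dim 1 directly.
  std::vector<int64_t> dims = {1};
  at::Tensor expected_var, expected_mean;
  std::tie(expected_var, expected_mean) = at::var_mean(tensor, dims);
  TORCH_CHECK(at::allclose(var_out, expected_var));
  TORCH_CHECK(at::allclose(mean_out, expected_mean));
  return 0;
}
```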