Add vmap support for torch.index_fill #91364
Conversation
❌ 2 failures as of commit c05b326.
Force-pushed from c871aa8 to b0cbf0f.
Thanks for the Pull Request, @qqaatw. torch.index_fill is a tricky one to handle.
Your code looks correct to me. For some of the cases, vmap is supposed to "eliminate" the for-loop, but it looks like we're still doing a for-loop here. It should be possible to avoid the for-loop by modifying the index and then doing a single call to index_fill (see my inline comments); please let me know your thoughts.
```cpp
std::tuple<Tensor,optional<int64_t>> index_fill__int_scalar_batch_rule(
    Tensor & self, optional<int64_t> self_bdim,
    int dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Scalar & value) {
  return index_fill_int_scalar_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, true);
}

std::tuple<Tensor,optional<int64_t>> index_fill__int_tensor_batch_rule(
    Tensor & self, optional<int64_t> self_bdim,
    int dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Tensor & value, optional<int64_t> value_bdim) {
  return index_fill_int_tensor_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, value_bdim, true);
}
```
The signature for inplace batching rules is actually

```cpp
void index_fill__int_scalar_batch_rule(...) {
  ...
}
```

Since the operation is in-place, we just return `self` and `self_bdim`. This happens somewhere in the codegen for vmap (I can link it if you're interested). The codegen ends up ignoring the return value of this function. To make it clearer, we should change the signature to return void.
This makes sense. I'm indeed interested in the codegen part, can you please point out the link?
I should probably toss this into a guide somewhere, but:
- here is the codegen
- it generates a `{operator}_generated_plumbing` function for each PyTorch ATen operator. You can see the output in the local build/aten/src/ATen/VmapGeneratedPlumbing.h file after you build PyTorch (example)
- the VMAP_SUPPORT(operator, batch_rule) macro is just `{operator}_generated_plumbing<decltype(batch_rule), batch_rule>`
- you'll notice that the plumbing doesn't actually use the return value of `batch_rule`, so it can be whatever. To avoid confusion, we should return nothing from index_fill__int_tensor_batch_rule
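To illustrate that last point, here is a hypothetical Python sketch (not the actual codegen output; the real plumbing is C++ emitted by the vmap codegen, and all names here are made up): the plumbing calls the batch rule purely for its side effects on `self` and discards whatever the batch rule returns.

```python
# Hypothetical sketch of the generated in-place plumbing. The batch rule
# mutates self_ in place; its return value, if any, is discarded, and the
# plumbing hands back the original tensor and batch dim unchanged.
def inplace_plumbing(batch_rule, self_, self_bdim, *args):
    batch_rule(self_, self_bdim, *args)  # return value is ignored
    return self_, self_bdim

# A toy in-place "batch rule" that mutates its input and returns a value
# the plumbing will never use.
def toy_batch_rule(self_, self_bdim, value):
    self_.append(value)
    return "this return value is never used"

out, out_bdim = inplace_plumbing(toy_batch_rule, [1, 2], 0, 3)
```

This is why the batch rule's declared return type can be `void`: nothing downstream observes it.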
Thanks for the info! Based on what I found, since the in-place plumbing doesn't use any return value from the batch_rule, it seems the batch dims are expected to be unchanged after the batch_rule runs. As a result, should we move the batch dims back before returning from the batch_rule?
From the implementation of `_index_put_impl__batch_rule`, for example, it seems the batch dims are not moved back.
When we move the batch dims to the front, usually we use at::movedim (or equivalent). This produces a new Tensor that is a view of the original one. In this case, there's no need to move the batch dims back -- the original tensor still has the batch dim in the correct position.
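A NumPy analogy of the point above (an assumption: `np.moveaxis` plays the role of at::movedim here): the moved array is a view, so an in-place write through it lands in the original array, which still has its batch dim in the original position.

```python
import numpy as np

# x has its "batch dim" at position 1; move it to the front for the batch rule.
x = np.zeros((4, 3))
moved = np.moveaxis(x, 1, 0)   # shape (3, 4); a view, not a copy
moved[0, :] = 7.0              # in-place write through the view

# No need to move the batch dim back: x still has the batch dim at position 1,
# and the write is visible through it (x[:, 0] is now all 7).
```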
```cpp
for (const auto i : c10::irange(0, batch_size)) {
  const auto& self_slice = self_.select(0, i);
  const auto& index_slice = index_.select(0, i);
  const auto& value_slice = value_.select(0, i);
  self_slice.index_fill_(
      dim,
      index_slice,
      value_slice
  );
}
```
I haven't thought about this case as much, but can we do something similar (the arange + single index_fill_) in the out-of-place case here?
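For reference, the arange idea can be sketched in NumPy (shapes and values are hypothetical; the real batch rule operates on ATen tensors): offset each batch's indices into a flattened view, then do a single fill instead of a per-batch loop.

```python
import numpy as np

B, N = 3, 5
self_ = np.zeros((B, N))
index_ = np.array([[0, 2], [1, 3], [4, 0]])  # batched indices, batch dim 0
value = 7.0

# Offset batch i's indices by i * N so they address the flattened (B*N,) view.
flat_index = (index_ + np.arange(B)[:, None] * N).ravel()
self_.reshape(-1)[flat_index] = value        # one fill, no for-loop
```

Because `self_` is contiguous, `reshape(-1)` is a view, so the single fancy-indexed assignment updates the original batched array in place.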
I think we can. But since `value` can only be a 1-element tensor when fed into `index_fill_`, this path is used only when `value` is not batched.
Another thing is that the test framework doesn't include a test sample with a tensor `value`, i.e. they're all scalar `value`s, so currently we don't test this batch rule. I'll open a PR to add one later. (#91534)
pytorch/torch/testing/_internal/common_methods_invocations.py, lines 4195 to 4197 in 5030929:

```python
elif fill:
    # A weird number to catch errors
    args.append(make_arg((1,)).item())
```
Thank you for checking. My suggestion would be to get this PR merged first (with this for-loop, after index_fill_int_scalar_batch_rule_impl is in a good state), and then we can work out how to improve the for-loop in index_fill_int_tensor_batch_rule_impl in a follow-up (if you're interested).
ok, sounds reasonable.
Happy New Year @qqaatw. My apologies for the delayed reply, I was out for the past couple of days. I left some more comments in the PR, it's looking pretty good so far! index_fill is one of the more complicated batching rules; thank you for taking this on.
```cpp
// If self.dim() is 0 or 1, the batch dim is certainly 0, and we must apply batched indices to each row.
index_ = reshape_dim_into(0, 0, index_);
self_.unsqueeze_(-1).index_fill_(dim + 1, index_, value).squeeze_(-1);
```
Is this case just to improve performance? If so, I'd prefer to remove it so that all of our code goes through the above case: it makes the code a bit simpler to reason about if there's just a single case for out-of-place.
No, it's not just a performance improvement (or maybe I'm missing something). In the case of self.dim() == 0 or 1, the only dim is the batch dim, so we need to add another dim at the end in order to apply the batched index to it.
For example: if `self` is `[1, 2]` and a batched `index` (where the batch dim is 0) is `[[0], [0]]`, we should fill `value` into self[0] and self[1] instead of only self[0].
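A NumPy sketch of this example (an analogy; the actual code calls unsqueeze_/index_fill_/squeeze_ on ATen tensors): because each logical element is 0-d, a trailing dim is added so the flattened batched index can be applied to every row.

```python
import numpy as np

self_ = np.array([1.0, 2.0])               # batch dim 0; each logical tensor is 0-d
index_ = np.array([[0], [0]]).reshape(-1)  # batched index flattened, like reshape_dim_into
value = 9.0

expanded = self_[:, None]                  # unsqueeze(-1): shape (2, 1), a view
expanded[:, index_] = value                # fill along the new last dim for every row
# After the implicit squeeze, self_ holds the filled values for both batch elements.
```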
Hello, Happy New Year @zou3519. No worries about the delay, and thank you for the comments!
Force-pushed from 3e7a655 to 37c2bd4.
Hello @zou3519, the current state should be OK, thanks for your suggestions. I'm happy to open follow-up PRs.
Thank you! I'll take a look. I might not be very responsive this week, so my apologies in advance.
I read through the code, and I'm not completely sure that the out-of-place path where index is batched is correct in the edge cases. In particular, the OpInfo testing appears to be missing SampleInputs for those edge cases.
I've suggested some edge cases that we should manually test via vmap_opinfo_test. If your code passes the tests, then I am happy; if not, then my suggestion might be relevant.
```cpp
std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
    Tensor & self, optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Scalar & value,
    const bool inplace) {
  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  Tensor self_ = moveBatchDimToFront(self, self_bdim);
  Tensor index_ = moveBatchDimToFront(index, index_bdim);
  dim = maybe_wrap_dim(dim, self_logical_rank);

  if (inplace && !self_bdim.has_value()) {
    vmapIncompatibleInplaceError("index_fill_");
  }
```
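As a side note, the two helpers used at the top of this function can be mirrored in Python roughly as follows (a hypothetical sketch; the real implementations live in ATen/functorch, and the exact error behavior here is an assumption):

```python
def rank_without_batch_dim(physical_rank, bdim):
    # The logical rank excludes the batch dim when one is present.
    return physical_rank - 1 if bdim is not None else physical_rank

def maybe_wrap_dim(dim, logical_rank):
    # Map a possibly-negative dim into [0, rank); 0-d tensors wrap as if rank 1.
    rank = max(logical_rank, 1)
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} out of range for rank {rank}")
    return dim % rank
```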
Some style guide nits: the indents should be at different levels, otherwise it gets a bit difficult to read. Ditto for the other functions you added:

```cpp
std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
    Tensor & self, optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Scalar & value,
    const bool inplace) {
  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  Tensor self_ = moveBatchDimToFront(self, self_bdim);
  ...
}
```
@zou3519 Done. Thank you for all your help on this PR!
@zou3519 @Chillee I have opened another PR, #91534, that adds a test sample for index_fill to the OpInfo. But there are too many failures, and it's hard to xfail all of them. The reason for those failures, I guess, is that most implementations are not correctly taking the tensor `value` into account.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: GraphQL query error. Details for Dev Infra team: raised by workflow job.
The failing tests seem to be unrelated to the changes this PR brings.
@zou3519 @Chillee @kshitij12345 Hello, can you help merge this? The test failures seem to be unrelated!
@pytorchbot rebase
@pytorchbot successfully started a rebase job. Check the current status here.
Rebase failed due to Command …
Raised by https://github.com/pytorch/pytorch/actions/runs/4036185466
@qqaatw can you please rebase on latest? Thank you, and sorry it slipped through the notifications :)
@kshitij12345 merged. I didn't rebase as there are conflicts on multiple commits. Thanks!
@qqaatw will take a look at the comment on Monday :)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, first few of them are: linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-build / build. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -f "Unrelated failures in libtorch CI"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…fill batch rule" A follow-up PR for #91364 (comment) [ghstack-poisoned]
…rule (#99229) A follow-up PR for #91364 (comment) Pull Request resolved: #99229 Approved by: https://github.com/kshitij12345
Fixes #91177