
Add vmap support for torch.index_fill #91364

Closed
wants to merge 15 commits

Conversation

qqaatw
Collaborator

@qqaatw qqaatw commented Dec 23, 2022

Fixes #91177

@qqaatw qqaatw requested a review from zou3519 as a code owner December 23, 2022 17:18
@pytorch-bot

pytorch-bot bot commented Dec 23, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91364

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 Failures

As of commit c05b326:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Contributor

@zou3519 zou3519 left a comment

Thanks for the Pull Request, @qqaatw. torch.index_fill is a tricky one to handle.

Your code looks correct to me. For some of the cases, vmap is supposed to "eliminate" the for-loop, but it looks like we're still doing a for-loop here. It should be possible to avoid the for-loop by modifying the index and then doing a single call to index_fill (see my inline comments); please let me know your thoughts.
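(For illustration, a minimal Python sketch of the "modify the index, then one index_fill" idea for a scalar fill value with batch dim 0; the shapes and variable names are made up for this example and are not the PR's implementation.)

import torch

# Illustrative shapes: batched self (B, N), batched index (B, K), scalar value.
B, N, K = 3, 5, 2
self_b = torch.zeros(B, N)
index_b = torch.randint(0, N, (B, K))
value = 7.0

# Reference semantics: one index_fill_ per batch element (the for-loop).
ref = self_b.clone()
for i in range(B):
    ref[i].index_fill_(0, index_b[i], value)

# Loop-free version: flatten the batch into dim 0 and offset each sample's
# indices by i * N, so a single index_fill_ covers every sample.
out = self_b.clone()
flat = out.view(B * N)                      # view; shares storage with `out`
offset = torch.arange(B).unsqueeze(1) * N   # shape (B, 1)
flat.index_fill_(0, (index_b + offset).reshape(-1), value)

assert torch.equal(out, ref)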

Comment on lines 1150 to 1164
std::tuple<Tensor,optional<int64_t>> index_fill__int_scalar_batch_rule(
    Tensor & self, optional<int64_t> self_bdim,
    int dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Scalar & value) {
  return index_fill_int_scalar_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, true);
}

std::tuple<Tensor,optional<int64_t>> index_fill__int_tensor_batch_rule(
    Tensor & self, optional<int64_t> self_bdim,
    int dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Tensor & value, optional<int64_t> value_bdim) {
  return index_fill_int_tensor_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, value_bdim, true);
}
Contributor

The signature for inplace batching rules is actually

void index_fill__int_scalar_batch_rule(...) {
  ...
}

Since the operation is in-place, we just return self and self_bdim. This happens somewhere in the codegen for vmap (I can link it if you're interested). The codegen ends up ignoring the return value of this function. To make it clearer, we should change the signature to return void.

Collaborator Author

This makes sense. I'm indeed interested in the codegen part; could you point me to the link?

Contributor

I should probably toss this into a guide somewhere, but:

  • here is the codegen
  • it generates a {operator}_generated_plumbing for each PyTorch ATen operator. You can see the output in the local build/aten/src/ATen/VmapGeneratedPlumbing.h file after you build PyTorch. example
  • the VMAP_SUPPORT(operator, batch_rule) macro is just {operator}_generated_plumbing<decltype(batch_rule), batch_rule>.
  • you'll notice that the plumbing doesn't actually use the return value of batch_rule, so it can be whatever. To avoid confusion, we should return nothing from index_fill__int_tensor_batch_rule

Collaborator Author

Thanks for the info! Based on what I found, since the in-place plumbing doesn't use any return value from the batch_rule, it seems the batch dims are expected to be unchanged after the batch_rule runs. Should we therefore move the batch dims back before returning from the batch_rule?

From the implementation of _index_put_impl__batch_rule, for example, it seems the batch dims are not moved back.

Contributor

When we move the batch dims to the front, usually we use at::movedim (or equivalent). This produces a new Tensor that is a view of the original one. In this case, there's no need to move the batch dims back -- the original tensor still has the batch dim in the correct position.
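(A quick Python illustration of that point: torch.movedim returns a view, so an in-place write through the moved tensor lands in the original and nothing needs to be moved back. The concrete shapes are only for demonstration.)

import torch

a = torch.zeros(2, 3)
b = torch.movedim(a, 1, 0)              # view of `a` with shape (3, 2); no copy
b[0, 1] = 5.0                           # in-place write through the view
print(a[1, 0])                          # tensor(5.) -- the original sees the change
print(b.data_ptr() == a.data_ptr())     # True: both share the same storage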

@bdhirsh bdhirsh added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Dec 27, 2022
@zou3519 zou3519 self-requested a review December 28, 2022 19:56
Comment on lines 1155 to 1164
for (const auto i : c10::irange(0, batch_size)) {
  const auto& self_slice = self_.select(0, i);
  const auto& index_slice = index_.select(0, i);
  const auto& value_slice = value_.select(0, i);
  self_slice.index_fill_(
      dim,
      index_slice,
      value_slice
  );
}
Contributor

I haven't thought about this case as much, but can we do something (the arange + single index_fill_) in the out-of-place case here?

Collaborator Author

@qqaatw qqaatw Dec 29, 2022

I think we can. But since value can only be a 1-element tensor when fed into index_fill_, this path is used only when value is not batched.

Another thing is that the test framework doesn't include a test sample with a tensor value, i.e. they're all scalar values, so currently we don't test this batch rule. I'll open a PR to add one later. (#91534)

elif fill:
    # A weird number to catch errors
    args.append(make_arg((1,)).item())
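(For reference, a small Python check of the semantics vmap has to reproduce in the batched-value case, which is the one the for-loop currently handles; this assumes a PyTorch build that already includes this PR, and the shapes and values are illustrative.)

import torch

B, N, K = 3, 4, 2
self_b = torch.zeros(B, N)
index_b = torch.randint(0, N, (B, K))
value_b = torch.arange(1., B + 1.)   # one 0-dim fill value per sample

# Per-sample reference: each sample uses its own fill value, which is why a
# single index_fill_ over the whole batch cannot cover this case.
ref = torch.stack([
    self_b[i].index_fill(0, index_b[i], value_b[i]) for i in range(B)
])

out = torch.vmap(lambda s, idx, v: s.index_fill(0, idx, v))(self_b, index_b, value_b)
assert torch.equal(out, ref)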

Contributor

Thank you for checking. My suggestion would be to get this PR merged first (with this for-loop, and after index_fill_int_scalar_batch_rule_impl is in a good state), and then we can work out how to improve the for-loop in index_fill_int_tensor_batch_rule_impl in a follow-up (if you're interested).

Collaborator Author

ok, sounds reasonable.

Contributor

@zou3519 zou3519 left a comment

Happy New Year, @qqaatw. My apologies for the delayed reply; I was out for the past couple of days. I left some more comments on the PR; it's looking pretty good so far! index_fill is one of the more complicated batching rules; thank you for taking this on.

Comment on lines 1115 to 1117
// If self.dim() is 0 or 1, the batch dim is certainly 0, and we must apply batched indices to each row.
index_ = reshape_dim_into(0, 0, index_);
self_.unsqueeze_(-1).index_fill_(dim + 1, index_, value).squeeze_(-1);
Contributor

Is this case just to improve performance? If so, I'd prefer to remove it so that all of our code goes through the above case: it makes the code a bit simpler to reason about if there's just a single case for out-of-place

Collaborator Author

No, it's not just a performance improvement (or maybe I'm missing something). When self.dim() is 0 or 1, the only dim is the batch dim, so we need to add another dim at the end in order to apply the batched index to it.

For example:

If self is [1, 2] and a batched index, whose batch dim is 0, is [[0], [0]], we should fill value into self[0] and self[1] instead of only self[0].
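(A tiny Python version of that example, mirroring the unsqueeze + index_fill_ + squeeze pattern from the snippet above; the fill value 5. is made up for illustration.)

import torch

self_b = torch.tensor([1., 2.])      # batched self; each sample is logically a scalar
index_b = torch.tensor([[0], [0]])   # batched index; batch dim is 0

# Per sample, filling index 0 of a scalar replaces the scalar itself, so every
# element of the batched self should end up filled.
out = self_b.clone()
out.unsqueeze_(-1).index_fill_(1, index_b.reshape(-1), 5.).squeeze_(-1)
print(out)   # tensor([5., 5.]): both self[0] and self[1] are filled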


Collaborator Author

@qqaatw qqaatw left a comment

Hello, Happy New Year @zou3519. No worries about the delay, and thank you for the comments!


@qqaatw
Collaborator Author

qqaatw commented Jan 9, 2023

Hello @zou3519, the current state should be ok, thanks for your suggestions. I'm happy to open follow-up PRs for index_fill_int_tensor_batch_rule_impl improvements.

@zou3519
Contributor

zou3519 commented Jan 10, 2023

Thank you! I'll take a look. I might not be very responsive this week, so my apologies in advance

@zou3519 zou3519 self-requested a review January 10, 2023 15:14
Contributor

@zou3519 zou3519 left a comment

I read through the code, and I'm not completely sure that the out-of-place path where index is batched is correct on the edge cases. In particular, the OpInfo testing appears to be missing SampleInputs for those edge cases.

I've suggested some edge cases that we should manually test via vmap_opinfo_test. If your code passes the tests, then I am happy; if not, my suggestion might be relevant.

@qqaatw qqaatw requested a review from zou3519 January 15, 2023 16:21
Comment on lines 1059 to 1074
std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
Tensor & self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor & index, optional<int64_t> index_bdim,
const Scalar & value,
const bool inplace) {
const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
Tensor self_ = moveBatchDimToFront(self, self_bdim);
Tensor index_ = moveBatchDimToFront(index, index_bdim);
dim = maybe_wrap_dim(dim, self_logical_rank);

if (inplace && !self_bdim.has_value()) {
vmapIncompatibleInplaceError("index_fill_");
}

Contributor

@zou3519 zou3519 Jan 18, 2023

Some style guide nits: the indents should be at different levels, otherwise it gets a bit difficult to read. Ditto for the other functions you added:

std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
    Tensor & self, optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Scalar & value,
    const bool inplace) {
  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  Tensor self_ = moveBatchDimToFront(self, self_bdim);
  ...
}

@qqaatw qqaatw requested a review from Chillee as a code owner January 19, 2023 08:19
@qqaatw
Collaborator Author

qqaatw commented Jan 19, 2023

@zou3519 Done. Thank you for all your help on this PR!

@zou3519 @Chillee I have opened another PR #91534 that adds a test sample of index_fill to the OpInfo. But there are too many failures to xfail them all. The reason for those failures, I guess, is that most implementations are not correctly taking the value gradient calculation into account. Could you offer some suggestions there? If there are other parts of functorch that I could contribute to, please let me know; I'm happy and interested in contributing more. I'm also on PyTorch's Slack channel if you prefer to talk there to prevent mail flooding. Thanks.

@qqaatw
Collaborator Author

qqaatw commented Jan 19, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 19, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: GraphQL query (args {'name': 'pytorch', 'owner': 'pytorch', 'number': 91364}) failed: [{'message': 'Something went wrong while executing your query. Please include 0401:674C:11ECF17:24D7532:63C91988 when reporting this issue.'}]

Details for Dev Infra team Raised by workflow job

@qqaatw
Collaborator Author

qqaatw commented Jan 20, 2023

The failing tests seem to be unrelated to the changes this PR brings.

cc @Chillee @kshitij12345

@qqaatw
Collaborator Author

qqaatw commented Jan 29, 2023

@zou3519 @Chillee @kshitij12345 Hello, can you help merge this? The test failures seem to be unrelated!

@kshitij12345
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/91364/head returned non-zero exit code 1

Rebasing (1/12)
Auto-merging test/functorch/test_vmap.py
CONFLICT (content): Merge conflict in test/functorch/test_vmap.py
error: could not apply 94d71204b9... Add vmap support for torch.index_fill
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
Could not apply 94d71204b9... Add vmap support for torch.index_fill

Raised by https://github.com/pytorch/pytorch/actions/runs/4036185466

@kshitij12345
Collaborator

@qqaatw can you please rebase on the latest viable/strict branch and push the code? After that we can merge it.

Thank you and sorry it slipped through the notifications :)

@qqaatw
Collaborator Author

qqaatw commented Jan 29, 2023

@kshitij12345 Merged instead; I didn't rebase as there are conflicts on multiple commits.
Can you please also provide some suggestions on this? #91364 (comment)

Thanks!

@kshitij12345
Collaborator

@qqaatw will take a look at the comment on Monday :)

@kshitij12345
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-build / build

Details for Dev Infra team Raised by workflow job

@kshitij12345
Collaborator

@pytorchbot merge -f "Unrelated failures in libtorch CI"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

qqaatw added a commit that referenced this pull request Apr 17, 2023
…fill batch rule"


A follow-up PR for #91364 (comment)


[ghstack-poisoned]
qqaatw added a commit that referenced this pull request Apr 17, 2023
…fill batch rule"


A follow-up PR for #91364 (comment)


[ghstack-poisoned]
qqaatw added a commit that referenced this pull request Apr 18, 2023
…fill batch rule"


A follow-up PR for #91364 (comment)


cc zou3519 Chillee samdow soumith kshitij12345 janeyx99

[ghstack-poisoned]
qqaatw added a commit that referenced this pull request Apr 18, 2023
…fill batch rule"


A follow-up PR for #91364 (comment)


cc zou3519 Chillee samdow soumith kshitij12345 janeyx99

[ghstack-poisoned]
qqaatw added a commit that referenced this pull request Apr 19, 2023
…fill batch rule"


A follow-up PR for #91364 (comment)


cc zou3519 Chillee samdow soumith kshitij12345 janeyx99

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Apr 20, 2023
Labels
ciflow/trunk (Trigger trunk jobs on your pull request), Merged, open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

vmap support for torch.index_fill
6 participants