vmap support for torch.index_fill #91177

Closed · 1 task · Tracked by #1088
zou3519 opened this issue Dec 20, 2022 · 1 comment
Labels: good first issue · module: functorch (Pertaining to torch.func or pytorch/functorch) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

zou3519 (Contributor) commented Dec 20, 2022

Tasks

Add vmap support for the following PyTorch operation. That is, it needs a batching rule.

  • torch.index_fill

Expected behavior

Currently, calling vmap over it raises a warning indicating that the batching rule is not implemented:

```python
import torch
from functorch import vmap

x = torch.randn(4, 3, 3)
index = torch.tensor([0, 2])
# Map over dim 0 of x only; dim, index and value are shared by every sample.
z = vmap(torch.index_fill, (0, None, None, None))(x, 1, index, -1)
#  UserWarning: There is a performance drop because we have not yet implemented the batching rule
```

We expect no warning to be raised.
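Whether or not a dedicated batching rule exists, vmap must return the same values as applying `torch.index_fill` to each sample separately and stacking the results. The sketch below checks that equivalence for the shapes used above; it assumes PyTorch 2.0 or later, where functorch's `vmap` is also available as `torch.func.vmap`.

```python
import torch
from torch.func import vmap  # same vmap as functorch.vmap in PyTorch >= 2.0

x = torch.randn(4, 3, 3)
index = torch.tensor([0, 2])

# Batched call: only x is mapped over; dim=1, index and value=-1 are shared.
batched = vmap(torch.index_fill, (0, None, None, None))(x, 1, index, -1)

# Reference: run the unbatched op on each (3, 3) sample and stack the results.
looped = torch.stack([torch.index_fill(x[i], 1, index, -1) for i in range(x.shape[0])])

# Per-sample dim 1 is dim 2 of the batched tensor, so columns 0 and 2 are -1.
print(torch.equal(batched, looped), bool((batched[:, :, [0, 2]] == -1).all()))
```

A correct batching rule should leave both checks true while removing the performance warning.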

Read this!

See this note for more context https://github.com/pytorch/pytorch/blob/master/functorch/writing_batching_rules.md

If you're new to developing PyTorch and/or function transforms, we recommend reading through https://github.com/pytorch/pytorch/wiki/Core-Frontend-Onboarding#unit-8-function-transforms-optional

The batching rule will likely be similar to existing rules in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/functorch/BatchRulesScatterOps.cpp
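To illustrate the logic a batching rule has to capture, here is a rough Python sketch of a batched `index_fill` that avoids a per-sample loop. The function `index_fill_batched` is hypothetical (not a PyTorch API), and the real rule would be written in C++ alongside the rules linked above. It assumes the batch dimension of `x` is at position 0, each sample has at least one dimension, `value` is a Python scalar, and `index` holds non-negative positions and is either shared (`(K,)`) or per-sample (`(B, K)`).

```python
import torch

def index_fill_batched(x, dim, index, value):
    """Hypothetical sketch: batched index_fill via a broadcastable boolean mask.

    x:     tensor of shape (B, *sample_shape), batch dim at position 0
    dim:   dimension of a single sample (negative values allowed)
    index: LongTensor of shape (K,) shared by all samples, or (B, K) per sample
    value: Python scalar written at the selected positions
    """
    B = x.shape[0]
    sample_ndim = x.dim() - 1
    dim = dim % sample_ndim        # normalize a negative dim against the sample's rank
    size = x.shape[dim + 1]        # +1 skips the batch dim

    if index.dim() == 1:           # shared index: same positions for every sample
        index = index.expand(B, -1)

    # mask[b, j] is True when position j along `dim` is selected for sample b.
    mask = torch.zeros(B, size, dtype=torch.bool, device=x.device)
    mask[torch.arange(B, device=x.device).unsqueeze(1), index] = True

    # Reshape to (B, 1, ..., size, ..., 1) so the mask broadcasts against x.
    view_shape = [B] + [1] * sample_ndim
    view_shape[dim + 1] = size
    return x.masked_fill(mask.view(view_shape), value)


x = torch.randn(4, 3, 3)
shared = torch.tensor([0, 2])
per_sample = torch.tensor([[0], [1], [2], [0]])

out_shared = index_fill_batched(x, 1, shared, -1)
out_per_sample = index_fill_batched(x, -1, per_sample, 5.0)
```

When only `x` is batched, an even simpler option is a single call `torch.index_fill(x, dim + 1, index, value)` (with `dim` normalized to be non-negative), since a shared index applies identically to every sample; the mask approach is shown because it also covers a batched `index`.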

Testing

pytest test/functorch/test_vmap.py -v -k "op_has_batch_rule and index_fill" should pass (it is currently marked as an expected failure).

cc @Chillee @samdow @soumith

zou3519 added the triaged and module: functorch labels Dec 20, 2022

qqaatw (Collaborator) commented Dec 23, 2022

I'm trying to implement this.
