
[mta] Backward of unary foreach functions #89591

Closed
wants to merge 9 commits from the foreach/unary-bwd branch

Conversation

@crcrpar (Collaborator) commented Nov 23, 2022

As per the title, this PR defines the backward of unary foreach functions.

This doesn't implement forward-mode automatic differentiation as the current codegen doesn't seem to handle ArrayRef<Tensor>.

Rel:

cc @mcarilli @ngimel

@pytorch-bot bot commented Nov 23, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89591

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 01d6f9e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the release notes: foreach_frontend (release notes category) label on Nov 23, 2022
@crcrpar crcrpar marked this pull request as ready for review November 30, 2022 00:37
@crcrpar (Collaborator, Author) commented Nov 30, 2022

Regarding in-place ops, I locally confirmed that the version counter seems to be bumped appropriately:

In [8]: tensors = [torch.randn(3).requires_grad_() for _ in range(2)]
In [9]: print([t._version for t in tensors])
[0, 0]
In [10]: torch._foreach_exp_(tensors)
In [11]: print([t._version for t in tensors])
[1, 1]

@drisspg added the module: mta (multi-tensor apply kernels and foreach functions) and triaged labels on Dec 1, 2022
@ngimel (Collaborator) commented Dec 5, 2022

So the backward of foreach will always fall back to the slow version, and codegen will make sure that we properly error out if versions don't match? That looks fine, but I'll let @albanD take a look.

@albanD (Collaborator) commented Dec 28, 2022

regarding inplace ops, I locally confirmed the version seems to be bumped appropriately

I think this is only because you're on CPU: the slow code calls the single-Tensor op, which properly bumps the version counter.
Running the same thing on CUDA won't work, I think.
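
A minimal sketch of the concern (hypothetical session, assuming a CUDA device is available; it only illustrates what to check, not an actual run from this thread):

import torch

# Hypothetical check of the fast path on CUDA. The slow CPU path above bumps
# _version via the single-Tensor ops; the worry is that the fused CUDA kernels
# would not bump it without an explicit increment.
tensors = [torch.randn(3, device="cuda") for _ in range(2)]
print([t._version for t in tensors])   # [0, 0]
torch._foreach_exp_(tensors)
print([t._version for t in tensors])   # should become [1, 1] once versions are bumped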

@albanD (Collaborator) left a review comment:

This is a lot of code duplication.
Should we just have the codegen take care of it, the same way we reuse the foo formula for foo_ when no explicit formula for foo_ is provided? We could codegen a slow _foreach_foo formula when it is not provided. WDYT?

torch/_torch_docs.py:
Only CPU and CUDA are supported. Forward-mode AD is not supported.

Args:
self (list of Tensors): Input list of Tensors. Each Tensor can have an arbitrary shape, dtype, device, and strides.
Collaborator:

It works for different dtypes and devices? Why do we collect grads to zero out by device/dtype here, then:

per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
?

Collaborator (Author):

To reduce the number of CUDA kernels launched as much as possible, I guess.

Collaborator:

I'm not sure I follow? How do the two approaches launch a different number of kernels?

Collaborator (Author):

If we pass a list of tensors of different dtypes and/or devices to a foreach function, it'll call the corresponding ATen native function once per tensor, i.e. launch len(self) CUDA kernels. With the grouping, that can be reduced to len(per_device_and_dtype_grads) calls, assuming all the tensors are CUDA tensors.
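
A minimal sketch of the grouping being discussed (illustrative only, not the actual optimizer code; `params` is a hypothetical list of parameters with .grad already populated):

from collections import defaultdict
import torch

# Bucket gradients by (device, dtype) so each bucket can be handed to a single
# foreach call, instead of paying one kernel launch per tensor on the slow path.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
for p in params:
    if p.grad is not None:
        per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad)

for device, per_dtype in per_device_and_dtype_grads.items():
    for dtype, grads in per_dtype.items():
        # All tensors in `grads` share a device and dtype, so the foreach fast
        # path can apply within the bucket.
        torch._foreach_zero_(grads)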

Collaborator:

@albanD _for_each ops support args on different devices and of different dtypes, but they'll fall back to the slow implementation. If, instead of just calling for_each, we pre-sort the args, we can call for_each with a subset of tensors that will go to the fast path. Really, this sorting should be done in for_each itself, though, and not in the optimizer.

Collaborator:

Ok!
Should we really advertise this then, as it will not do what you expect these functions to do?

torch/_torch_docs.py:
~~~~~~~~~~~~~~~~~~

.. warning::
This API is in beta and subject to future changes.
Collaborator:

nit: beta or prototype?

tools/autograd/derivatives.yaml:
@crcrpar force-pushed the foreach/unary-bwd branch 3 times, most recently from 246cad6 to a0df3ef, on January 16, 2023 09:03
@crcrpar (Collaborator, Author) commented Jan 16, 2023

Added codegen support and a version bump.

@crcrpar crcrpar requested review from albanD and removed request for soulitzer, ngimel and mruberry January 18, 2023 00:23
@albanD (Collaborator) left a review comment:

Thanks for the update!
The codegen part all looks good! Only questions about the InplaceOrView key side.

aten/src/ATen/native/cuda/ForeachFunctors.cuh:
@@ -73,6 +73,7 @@ template <typename scalar_t, template<class> class Op> void foreach_unary_op_(Te
/* r_args_depth */ 1,
/* res_arg_index */ 0>(),
Op<opmath_t>());
maybe_increment_version(tensors);
Collaborator:

Do I misremember our discussion where you showed me an example where the version was bumped properly? Why does this need to be added?

Collaborator (Author):

That happened when the input tensors are CPU tensors, not when the fast path is chosen.
Even functions registered in native_functions.yaml with the CUDA key can go into the slow path, which just calls the ATen native function. Therefore I decided to put this manual version bump here.

Collaborator:

Oh, that's right!
I'm still not sure why the InplaceOrView kernel doesn't get generated automatically to do that already, but we can look into that later.

Collaborator (Author):

I thought that, because of this check in pytorch/torchgen/model.py (lines 1388 to 1405 in b33d9e2):

if self.name.name.inplace:
self_a = self.arguments.self_arg
assert (
self_a
and self_a.argument.annotation
and self_a.argument.annotation.is_write
)
if self_a.argument.type == BaseType(BaseTy.Tensor):
# All inplace ops with an ordinary `Tensor self` argument should return self,
# to allow for method chaining.
assert (
len(self.returns) == 1
and self.returns[0].annotation == self_a.argument.annotation
)
else:
# You can't method chain on non-tensor self arguments though (like a List[Tensor])
# so in all other cases we expect the return type to be none.
assert len(self.returns) == 0
inplace functions that modify (and return) multiple Tensors don't get caught by:
for r in cpp.return_names(f):
inplace_view_body.append(f"increment_version({r});")
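
A hedged sketch of what a fix in torchgen could look like (hypothetical, not the actual implementation; names beyond those quoted above are illustrative): since unary in-place foreach ops return nothing, the loop over return names emits no increment_version calls, so the mutated input list would have to be bumped explicitly.

return_names = list(cpp.return_names(f))
if return_names:
    # Ordinary in-place ops return `self`, so bump the version of each
    # returned tensor (this mirrors the generated code quoted above).
    for r in return_names:
        inplace_view_body.append(f"increment_version({r});")
else:
    # In-place ops whose `self` is a TensorList return nothing; bump every
    # tensor in the mutated input list instead.
    inplace_view_body.append(
        "for (const auto& t : self) { increment_version(t); }"
    )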

crcrpar added a commit to crcrpar/pytorch that referenced this pull request Jan 19, 2023
following pytorch#89591 (comment)

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@albanD (Collaborator) left a review comment:

Looks good!
Sorry for all the back and forth, this is definitely a challenging one...

@crcrpar (Collaborator, Author) commented Jan 20, 2023

@pytorchbot merge

(knowing this won’t work this time but want to trigger more jobs)

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 20, 2023
@pytorchmergebot commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot commented:

Merge failed

Reason: 2 mandatory check(s) failed (Rule superuser). The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

crcrpar added a commit to crcrpar/pytorch that referenced this pull request Jan 20, 2023
following pytorch#89591 (comment)

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar added a commit to crcrpar/pytorch that referenced this pull request Jan 20, 2023
following pytorch#89591 (comment)

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@crcrpar force-pushed the foreach/unary-bwd branch 2 times, most recently from b0c89f2 to 42ee873, on January 21, 2023 02:25
crcrpar added a commit to crcrpar/pytorch that referenced this pull request Jan 21, 2023
following pytorch#89591 (comment)

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@albanD (Collaborator) commented Jan 21, 2023

There is a @pytorchbot label command if you need to add a label (like the ciflow/trunk one) but are not allowed to do it directly.

with a new CodeTemplate of DERIVATIVE_SINGLE_FOREACH

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Pushing the heavy lifting to torchgen

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
following pytorch#89591 (comment)

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@crcrpar (Collaborator, Author) commented Jan 23, 2023

@pytorchbot merge

@pytorchmergebot commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Feb 20, 2023
…possible) (#93901)

## summary
- increment tensor versions in inplace foreach functions
- add a logic to take care of `ArrayRef<Scalar>`

rel: #58833, #89591

Pull Request resolved: #93901
Approved by: https://github.com/albanD
jhavukainen pushed a commit to kulinseth/pytorch that referenced this pull request Mar 15, 2024
…possible) (pytorch#93901)

## summary
- increment tensor versions in inplace foreach functions
- add a logic to take care of `ArrayRef<Scalar>`

rel: pytorch#58833, pytorch#89591

Pull Request resolved: pytorch#93901
Approved by: https://github.com/albanD
Labels: ciflow/trunk (trigger trunk jobs), Merged, module: mta (multi-tensor apply kernels and foreach functions), open source, release notes: foreach_frontend, triaged

6 participants