
Add forward AD for linalg.det and simplify its backward #79487


Closed
wants to merge 14 commits

Conversation

lezcano
Collaborator

@lezcano lezcano commented Jun 14, 2022

Stack from ghstack:

This PR is in preparation for implementing `logdet` and `slogdet` as
structured kernels and implementing them with more efficient derivatives.

We implement forward AD for det. We also simplify the implementation of
the backward and leave a note on how to implement it properly for
singular matrices. We leave that for future work.

Note (by looking at the OpInfo) that the current implementation passes
the same tests as the one before. We skip forward-over-backward in
the singular case, as that was not working in the gradgrad case
either.
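For reference (an editor's sketch, not part of this PR's tests): forward AD for `det` follows from the identity d det(A) = det(A) * tr(A^{-1} dA) for invertible A. A minimal check of that identity against PyTorch's forward-mode API; the sizes and dtype are arbitrary choices for illustration:

```python
import torch
import torch.autograd.forward_ad as fwAD

# Sketch: check d det(A) = det(A) * tr(A^{-1} dA) for a random invertible A
# and a random tangent dA.
A = torch.randn(4, 4, dtype=torch.float64)
dA = torch.randn(4, 4, dtype=torch.float64)

with fwAD.dual_level():
    dual_A = fwAD.make_dual(A, dA)
    _, jvp = fwAD.unpack_dual(torch.linalg.det(dual_A))

expected = torch.linalg.det(A) * torch.trace(torch.linalg.solve(A, dA))
print(torch.allclose(jvp, expected))  # should print True up to numerical error
```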

[ghstack-poisoned]
@facebook-github-bot
Contributor

facebook-github-bot commented Jun 14, 2022

🔗 Helpful links

❌ 2 New Failures

As of commit 833bd72 (more details on the Dr. CI page):

  • 2/2 failures introduced in this PR

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build periodic / linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck / test (default, 1, 2, linux.4xlarge.nvidia.gpu) (1/2)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-06-24T13:39:59.6722078Z RuntimeError: test_nn failed!
2022-06-24T13:39:58.6179622Z Generated XML report: test-reports/python-unittest/test_nn/TEST-TestModuleGlobalHooks-20220624120625.xml
2022-06-24T13:39:58.8304336Z Generated XML report: test-reports/python-unittest/test_nn/TEST-TestNN-20220624120625.xml
2022-06-24T13:39:59.0589040Z Generated XML report: test-reports/python-unittest/test_nn/TEST-TestNNDeviceTypeCUDA-20220624120625.xml
2022-06-24T13:39:59.0622240Z Generated XML report: test-reports/python-unittest/test_nn/TEST-TestNNInit-20220624120625.xml
2022-06-24T13:39:59.0629551Z Generated XML report: test-reports/python-unittest/test_nn/TEST-TestStateDictHooks-20220624120625.xml
2022-06-24T13:39:59.6713097Z Traceback (most recent call last):
2022-06-24T13:39:59.6713997Z   File "test/run_test.py", line 945, in <module>
2022-06-24T13:39:59.6717533Z     main()
2022-06-24T13:39:59.6718039Z   File "test/run_test.py", line 923, in main
2022-06-24T13:39:59.6721485Z     raise RuntimeError(err_message)
2022-06-24T13:39:59.6722078Z RuntimeError: test_nn failed!
2022-06-24T13:40:00.2238527Z ##[error]Process completed with exit code 1.
2022-06-24T13:40:00.2278880Z Prepare all required actions
2022-06-24T13:40:00.2279333Z Getting action download info
2022-06-24T13:40:00.4830548Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-06-24T13:40:00.4830858Z with:
2022-06-24T13:40:00.4831302Z   github-token: ***
2022-06-24T13:40:00.4831522Z env:
2022-06-24T13:40:00.4831762Z   GIT_DEFAULT_BRANCH: master
2022-06-24T13:40:00.4832030Z   GPU_FLAG: --gpus all
2022-06-24T13:40:00.4832258Z ##[endgroup]

See GitHub Actions build periodic / linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck / test (default, 2, 2, linux.4xlarge.nvidia.gpu) (2/2)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-06-24T12:55:05.9006884Z RuntimeError: CUDA driver error: out of memory
2022-06-24T12:55:05.9002282Z   File "/opt/conda/lib/python3.7/site-packages/torch/autograd/gradcheck.py", line 1578, in gradgradcheck
2022-06-24T12:55:05.9002739Z     check_forward_ad=check_fwd_over_rev, check_backward_ad=check_rev_over_rev)
2022-06-24T12:55:05.9003273Z   File "/opt/conda/lib/python3.7/site-packages/torch/autograd/gradcheck.py", line 1418, in gradcheck
2022-06-24T12:55:05.9003656Z     return _gradcheck_helper(**args)
2022-06-24T12:55:05.9004164Z   File "/opt/conda/lib/python3.7/site-packages/torch/autograd/gradcheck.py", line 1427, in _gradcheck_helper
2022-06-24T12:55:05.9004538Z     func_out = func(*tupled_inputs)
2022-06-24T12:55:05.9005019Z   File "/opt/conda/lib/python3.7/site-packages/torch/autograd/gradcheck.py", line 1570, in new_func
2022-06-24T12:55:05.9005380Z     allow_unused=True)
2022-06-24T12:55:05.9006035Z   File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 296, in grad
2022-06-24T12:55:05.9006494Z     allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass
2022-06-24T12:55:05.9006884Z RuntimeError: CUDA driver error: out of memory
2022-06-24T12:55:05.9007088Z 
2022-06-24T12:55:06.1132460Z   test_fn_fwgrad_bwgrad_float_power_cuda_complex128 (__main__.TestGradientsCUDA) ... ERROR (0.214s)
2022-06-24T12:55:06.1133505Z     test_fn_fwgrad_bwgrad_float_power_cuda_complex128 errored - num_retries_left: 0
2022-06-24T12:55:06.6627709Z   test_fn_fwgrad_bwgrad_float_power_cuda_float64 (__main__.TestGradientsCUDA) ... ERROR (0.549s)
2022-06-24T12:55:06.6665984Z     test_fn_fwgrad_bwgrad_float_power_cuda_float64 errored - num_retries_left: 3
2022-06-24T12:55:06.6668355Z Traceback (most recent call last):
2022-06-24T12:55:06.6669485Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_utils.py", line 1806, in wrapper
2022-06-24T12:55:06.6670253Z     method(*args, **kwargs)
2022-06-24T12:55:06.6671410Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_utils.py", line 1806, in wrapper
2022-06-24T12:55:06.6672234Z     method(*args, **kwargs)

This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


lezcano added a commit that referenced this pull request Jun 14, 2022

ghstack-source-id: 9f4033c
Pull Request resolved: #79487
@lezcano lezcano requested review from soulitzer, nikitaved and albanD and removed request for nikitaved, albanD, soulitzer, ngimel, bdhirsh, IvanYashchuk and mruberry June 14, 2022 00:21
@lezcano lezcano added the module: derivatives, topic: improvements, module: forward ad, module: autograd and release notes: autograd labels, and removed the module: autograd label Jun 14, 2022
lezcano added 2 commits June 19, 2022 15:38
Comment on lines 4056 to 4058
// The proper way of doing this is doing `auto mask = det == 0.;` and then
// if any determinant is zero, use an SVD decomposition to compute the
// derivative in those inputs (not all inputs). The derivative may be then
Collaborator

But this is what is done in the current code. What is being removed now?

Collaborator Author

Not quite. The point here is that:

  • If there is just one singular value that is zero, then the derivative is not zero
  • If there is more than one singular value that is zero, the derivative is zero
  • The SVD returns the singular values in decreasing order, so if the det is zero, we can just remove the smallest singular value and compute the derivative as (det(U) * det(V).conj() * s_1 * ... * s_{n-1})u^Tv

The last point is the relevant one to perform this optimisation.
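A minimal numerical illustration of that last point (an editor's sketch under stated assumptions, not the PR's code): restricted to real square matrices with exactly one zero singular value, the derivative of `det` reduces to a rank-one matrix built from the SVD.

```python
import torch

def det_grad_one_zero_sv(A):
    # Assumes A is real, square, and has exactly one zero singular value.
    U, S, Vh = torch.linalg.svd(A)
    # Singular values come back in descending order, so S[-1] is the zero one.
    prod_rest = S[:-1].prod()
    sign = torch.linalg.det(U) * torch.linalg.det(Vh)
    # d det / dA = det(U) * det(V) * (s_1 * ... * s_{n-1}) * u_n v_n^T
    return sign * prod_rest * torch.outer(U[:, -1], Vh[-1, :])

# Finite-difference check in a random direction on a rank-deficient matrix.
A = torch.randn(5, 4, dtype=torch.float64) @ torch.randn(4, 5, dtype=torch.float64)
E = torch.randn(5, 5, dtype=torch.float64)
eps = 1e-6
fd = (torch.linalg.det(A + eps * E) - torch.linalg.det(A - eps * E)) / (2 * eps)
print(fd.item(), (det_grad_one_zero_sv(A) * E).sum().item())  # should agree to several digits
```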

Collaborator
@nikitaved nikitaved Jun 20, 2022

OK, makes sense. Why not implement it in this PR, though?

Collaborator Author

I realised it's not that easy. It's easy to compute the derivative, but the second derivative will not be correct.

I am not implementing it in this PR as I need to move on to other things. I've left a detailed comment on what the issue is.

Collaborator Author

Fwiw, the second derivative of torch.prod for a vector with more than one zero is not correct afaik, so it's fine if it doesn't work for torch.det either.
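A quick way to probe that claim (an editor's sketch, not from the PR) is to run `gradgradcheck` on `torch.prod` with an input containing more than one exact zero; whether it passes or raises tells you whether the analytical second derivative matches finite differences at that point:

```python
import torch
from torch.autograd import gradgradcheck

# Input with two exact zeros, the case discussed above.
x = torch.tensor([1.0, 0.0, 0.0, 2.0], dtype=torch.float64, requires_grad=True)
try:
    print("gradgradcheck passed:", gradgradcheck(torch.prod, (x,)))
except RuntimeError as err:  # GradcheckError is a RuntimeError subclass
    print("gradgradcheck failed:", err)
```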

Collaborator Author

Sorry, forgot to push the changes. Now they're in.

lezcano added 3 commits June 20, 2022 14:47
lezcano added 2 commits June 23, 2022 00:56
Comment on lines +4099 to 4101
// compute higher derivatives of `x.prod()` and `x` has more than one zero.
return at::linalg_solve(A.mH(), d);
}
Collaborator

Do you know which samples gradgradcheck fails on?

Collaborator Author

I filed a follow-up PR #80217 to unblock this one (as this one is already working and those modifications may introduce some weird errors on some platforms).
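For context, the quoted `at::linalg_solve(A.mH(), d)` corresponds to the non-singular case, where the reverse-mode formula is grad_A = grad * det(A) * A^{-T} (for real A) and the inverse-transpose is applied via a solve. A minimal Python sketch, assuming `d` stands for the scaled cotangent times the identity (an assumption about the surrounding C++, not a quote of it), and restricted to real matrices to avoid conjugation details:

```python
import torch

def det_backward_nonsingular(grad, det, A):
    # grad: cotangent of det (a scalar); det: det(A); A: real invertible matrix.
    # Computes grad * det(A) * A^{-T} via a single solve against A.mH (== A.T here).
    d = (grad * det) * torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    return torch.linalg.solve(A.mH, d)

A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
det = torch.linalg.det(A)
det.backward()
manual = det_backward_nonsingular(torch.ones((), dtype=torch.float64), det.detach(), A.detach())
print(torch.allclose(A.grad, manual))  # expected: True for a well-conditioned A
```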

Comment on lines +4095 to +4098
// The issue with this code is that the derivative given by autograd of
// prod_safe_zeros_backward is not the second derivative of the product.
// It is not clear to me how to implement the second derivative of the
// product efficiently. Note that this is also currently a problem when we
Collaborator

What is the issue with the prod?

Collaborator Author

It's a non-issue, sorry. This comment is removed in #80217

Collaborator
@nikitaved nikitaved left a comment

Thank you, Mario. LGTM! Let's see the output, and that of the next one in the stack.

Collaborator
@albanD albanD left a comment

SGTM

@lezcano
Collaborator Author

lezcano commented Jun 24, 2022

@pytorchbot merge -g

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here


@lezcano
Collaborator Author

lezcano commented Jun 24, 2022

The error, while it touches the backward of conv, seems completely unrelated to this PR and looks like a flaky test. cc @janeyx99 @jbschlosser (note that the previous CI run was green and there were no changes in this PR between that run and this one).

I'm going YOLO and trying a merge anyway. Wish me the best.

@lezcano
Collaborator Author

lezcano commented Jun 24, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

facebook-github-bot pushed a commit that referenced this pull request Jun 27, 2022
Add forward AD for linalg.det and simplify its backward (#79487)


Pull Request resolved: #79487
Approved by: https://github.com/nikitaved, https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/42a2359612743f791fd2c0684b75f2ff450bc0b3

Reviewed By: seemethere

Differential Revision: D37423853

Pulled By: seemethere

fbshipit-source-id: 627829937609ae977c7c539a5aa6cf009512e162
@facebook-github-bot facebook-github-bot deleted the gh/Lezcano/92/head branch June 28, 2022 14:16
Labels
ciflow/periodic, ciflow/trunk, cla signed, Merged, module: derivatives, module: forward ad, open source, release notes: autograd, topic: improvements
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants