Avoid cloning gradient tensor in embedding backward pass #2526
Conversation
This pull request was exported from Phabricator. Differential Revision: D56420646
This will break the const signature, will it not? If I recall correctly, the reason was that we moved this memory-alignment handling into the operator, inside the CUDA kernel, and the signatures require the input to be const. If we don't do it inside the operator, we need to figure out a way to adjust the meta kernel, which doesn't have access to the data pointers to check the alignment.
No, I don't think so. You're right that
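For context on the alignment discussion above, a check on the raw data pointer might look like the following illustrative C++ sketch. The helper name and the 16-byte alignment value are assumptions for illustration, not FBGEMM's actual API; the point is that a meta tensor has no storage, so such a check cannot run in the meta kernel.

```cpp
#include <cstdint>
#include <ATen/ATen.h>

// Illustrative helper (assumed name, not FBGEMM's actual API): returns true if
// the tensor's data pointer is aligned to `alignment` bytes. A meta tensor has
// no storage, so data_ptr() is unavailable there, which is the concern raised
// in the comment above.
inline bool is_aligned(const at::Tensor& t, std::size_t alignment = 16) {
  return reinterpret_cast<std::uintptr_t>(t.data_ptr()) % alignment == 0;
}
```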
Cool. I'm only on my phone and still on leave, so could you check with someone else to approve it?
@pytorchbot merge |
This PR needs to be approved by an authorized maintainer before merge. |
Summary: I found a memory spike during the embedding backward kernel `split_embedding_backward_codegen_rowwise_adagrad_unweghted_exact_cuda`, which was traced to the code below making a clone of the gradient tensor. This logic didn't seem to be present in the original code: https://github.com/pytorch/FBGEMM/pull/2347/files#diff-944ab49dcbcf54826cc3e1eab5e3c0c787b5a195f602c2d3052adae14c506d78. Reviewed By: ezyang Differential Revision: D56420646
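To illustrate the idea in the summary, here is a minimal C++ sketch of avoiding an unconditional clone of the gradient tensor. The function name, the 16-byte alignment constant, and the fallback path are assumptions for illustration, not the actual FBGEMM change.

```cpp
#include <cstdint>
#include <ATen/ATen.h>

// Illustrative sketch (assumed names/values, not the actual FBGEMM change):
// only materialize a copy of grad_output when it is non-contiguous or its
// data pointer is misaligned, instead of cloning it on every backward call.
at::Tensor prepare_grad_output(const at::Tensor& grad_output,
                               std::size_t alignment = 16) {
  const bool aligned =
      reinterpret_cast<std::uintptr_t>(grad_output.data_ptr()) % alignment == 0;
  if (grad_output.is_contiguous() && aligned) {
    return grad_output;  // fast path: reuse the incoming tensor, no extra allocation
  }
  // Slow path: a fresh contiguous copy (typically well-aligned by the allocator).
  return grad_output.clone(at::MemoryFormat::Contiguous);
}
```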
Force-pushed from b93d4fb to 3f16122
This pull request was exported from Phabricator. Differential Revision: D56420646
This pull request has been merged in a75037b.