New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add batched version of trtrs #18025
Add batched version of trtrs #18025
Conversation
156e143
to
e3421eb
Compare
- Remove single batch TH/THC implementations
e3421eb
to
645c07c
Compare
Seems like you don't need a magma_queue_t object
c9baeb2
to
33ed8fa
Compare
result = at::_cholesky_solve_helper(self, A, upper); | ||
Tensor result_tmp; | ||
result_tmp = at::_cholesky_solve_helper(self, A, upper); | ||
result.resize_as_(result_tmp).copy_(result_tmp); | ||
return result; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch!
@pytorchbot retest this please |
@ifedan Is this good to go? |
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ifedan is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ifedan is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: - Remove single batch TH/THC implementations - Remove `_batch_trtrs_lower` from `multivariate_normal` - Add tests for batched behavior - Modify trtrs_backward to accommodate for batched case - Modify docs In a future PR, this will be renamed to `triangular_solve`. Pull Request resolved: pytorch/pytorch#18025 Differential Revision: D14523004 Pulled By: ifedan fbshipit-source-id: 11c6a967d107f969b60e5a5c73ce6bb8099ebbe1
_batch_trtrs_lower
frommultivariate_normal
In a future PR, this will be renamed to
triangular_solve
.