-
Notifications
You must be signed in to change notification settings - Fork 135
Use lapack func instead of scipy.linalg.cholesky
#1487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
* Now skips 2D checks in perform * Updated the default arguments for `check_finite` to false to match documentation * Add benchmark test case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice start, left some comments
Thank you for the feedback. I moved it all to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking really great! I see a few more places we can improve the performance (avoiding copy + cleaning up the code) then I think it'll be ready to merge
eye = np.eye(1, dtype=x.dtype) | ||
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (eye,)) | ||
c, _ = potrf(eye, lower=False, overwrite_a=False, clean=True) | ||
out[0] = np.empty_like(x, dtype=c.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
eye = np.eye(1, dtype=x.dtype) | |
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (eye,)) | |
c, _ = potrf(eye, lower=False, overwrite_a=False, clean=True) | |
out[0] = np.empty_like(x, dtype=c.dtype) | |
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,)) | |
out[0] = np.empty_like(x, dtype=potrf.dtype) |
Since you just want the dtype, you can check potrf.dtype
, you don't need to actually call the routine. It's also fine to pass in the empty x
in this case, since it's only used for its dtype attribute (the data or lack thereof doesn't matter)
out[0] = np.empty_like(x, dtype=c.dtype) | ||
return | ||
|
||
x1 = np.asarray_chkfinite(x) if self.check_finite else x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't like the introduction of a new variable. We know for sure x
is a numpy array inside this function, so the only relevant code in np.asarray_chkfinite(x)
is the if not np.isfinite(x).all():
part. We should just directly do that
# Quick return for square empty array | ||
if x.size == 0: | ||
eye = np.eye(1, dtype=x.dtype) | ||
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (eye,)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Get potrf
once before this size check and re-use it in all branches
|
||
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS | ||
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it | ||
if self.overwrite_a and x.flags["C_CONTIGUOUS"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
save a variable c_contiguous_input
and use that for checks later (instead of maintaining two arrays)
check_finite
to false to match documentationDescription
Adds
_cholesky
method toslinalg.Cholesky
to replace thescipy.linalg.cholesky
wrapper. It is almost identical to the corresponding scipy function but it skips the 2d check and the batching wrapper.Previous performance (with
check_finite=False
):After changes:
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1487.org.readthedocs.build/en/1487/