[Bug] Logdet #41
Comments
Ah, this seems to be about the trace estimation in the […]. The issue may be when n (if A is an n x n matrix) is not a multiple of the trace evaluator batch size; the arrays in different iterations will then have different sizes. The easiest fix would probably be to construct […]. Also, for GP applications in which you only need unbiased estimates of the MLL gradients, you might want to consider SLQ (i.e. `'iterative-stochastic'`), which you can also access (it will be selected by `auto`) if the matrix is large enough and you set a large […].
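The batching problem described in this comment can be illustrated in plain JAX. The sketch below is not CoLA's code: `batched_exact_trace`, `matvec`, and `bs` are hypothetical names, and it assumes the trace is computed exactly by probing with unit vectors in batches of `bs`. Padding the last batch with zero columns keeps every `jax.lax.fori_loop` iteration at the same shape even when `n` is not a multiple of `bs`.

```python
import jax
import jax.numpy as jnp


def batched_exact_trace(matvec, n, bs):
    """Hypothetical sketch: tr(A) = sum_i e_i^T A e_i, computed bs probes at a time.

    matvec maps an (n, bs) array X to A @ X. n and bs are static Python ints.
    """
    num_batches = -(-n // bs)  # ceil(n / bs)
    rows = jnp.arange(n)

    def body(k, acc):
        idx = k * bs + jnp.arange(bs)  # probe indices for batch k; may exceed n - 1
        # Column c is e_{idx[c]}; an index >= n matches no row, so the padded
        # columns of the last batch are all zero and contribute nothing.
        E = (rows[:, None] == idx[None, :]).astype(jnp.float32)
        AE = matvec(E)  # shape (n, bs) on every iteration
        return acc + jnp.sum(E * AE)  # sum over the batch of e_i^T A e_i

    return jax.lax.fori_loop(0, num_batches, body, jnp.float32(0.0))


A = jax.random.normal(jax.random.PRNGKey(0), (7, 7))
trace_fn = jax.jit(lambda M: batched_exact_trace(lambda X: M @ X, 7, 3))
print(trace_fn(A), jnp.trace(A))  # 7 is not a multiple of 3
```

The alternative, a smaller final batch, would give the loop body a second input shape, which a single traced loop body cannot accommodate.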
Oh, actually it's not that. On closer inspection, it's the slicing of the Lanczos tridiagonal matrix in `apply_unary`: https://github.com/wilson-labs/cola/blob/main/cola/linalg/unary.py#L33
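Assuming the slice keeps the first `k` rows and columns of the tridiagonal matrix, where `k` is the number of Lanczos iterations actually performed, its shape depends on runtime data, which `jax.jit` cannot trace. A common workaround, sketched below in plain JAX (not CoLA's actual fix; `quad_form_e1`, `alpha`, `beta`, and `k` are hypothetical names), is to keep the full static size `m` and neutralize the unused block: set its diagonal to 1 and cut its coupling to the valid block. The padded matrix is then block diagonal and e1 lies in the valid block, so e1^T f(T) e1 is unchanged.

```python
import jax
import jax.numpy as jnp


def quad_form_e1(alpha, beta, k, fn):
    """Hypothetical sketch: e1^T fn(T_k) e1 without slicing T[:k, :k].

    alpha: (m,) diagonal and beta: (m - 1,) off-diagonal of an m-step Lanczos
    tridiagonal matrix, of which only the leading k x k block is valid.
    k may be a traced value, so no array shape depends on it.
    """
    m = alpha.shape[0]
    i = jnp.arange(m)
    # Padded diagonal entries become 1, so fn(1) = 0 for fn = log instead of
    # log(0) = -inf, which would turn 0 * (-inf) into NaN below.
    alpha = jnp.where(i < k, alpha, 1.0)
    # beta[j] couples rows j and j + 1; keep it only inside the valid block.
    beta = jnp.where(i[:-1] < k - 1, beta, 0.0)
    T = jnp.diag(alpha) + jnp.diag(beta, 1) + jnp.diag(beta, -1)
    evals, evecs = jnp.linalg.eigh(T)
    # e1^T fn(T) e1 = sum_i fn(lambda_i) * (first component of eigenvector i)^2
    return jnp.sum(fn(evals) * evecs[0, :] ** 2)


f = jax.jit(lambda a, b, k: quad_form_e1(a, b, k, jnp.log))
print(f(jnp.array([4.0, 5.0, 6.0, 7.0, 8.0, 9.0]), jnp.ones(5), 4))
```

The price is an eigendecomposition of size `m` rather than `k`, which is usually negligible for the small tridiagonal matrices Lanczos produces.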
Refactoring Lanczos and Arnoldi to return LinearOperators as output: `Q,T,info = lanczos(A,**kwargs)`, where `Q` and `T` are linear operators, and likewise for Arnoldi. Removed `lanczos_parts` and related functions. Added a decorator for `parametrize` to catch and pass `NumpyNotImplemented` exceptions for the NumPy backend.

Refactor to make `logdet`, `unary`, etc. jittable and vmappable (see #41).

---------

Co-authored-by: AndPotap <apotapczynski@gmail.com>
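The first comment suggests SLQ (stochastic Lanczos quadrature), and this commit aims to make `logdet` jittable and vmappable. The sketch below shows SLQ log-determinant estimation in plain JAX using only static shapes, so it compiles under `jax.jit` and maps over probe vectors with `jax.vmap`. It is an illustration under stated assumptions, not CoLA's implementation: `lanczos_tridiag` and `slq_logdet` are hypothetical names, `A` is assumed symmetric positive definite and available only as a matvec, and the number of Lanczos steps `m` is assumed small enough that no breakdown occurs.

```python
import jax
import jax.numpy as jnp


def lanczos_tridiag(matvec, v0, m):
    """Hypothetical sketch: m Lanczos steps from v0 with full reorthogonalization.

    Returns the diagonal (m,) and off-diagonal (m - 1,) of the tridiagonal T.
    All arrays have static shapes, so this works under jit and vmap.
    Assumes no breakdown (every beta_j > 0).
    """
    n = v0.shape[0]
    q0 = v0 / jnp.linalg.norm(v0)
    # Rows 1..m start as zeros, so reorthogonalizing against the whole array
    # only projects out basis vectors that have already been computed.
    Q = jnp.zeros((m + 1, n), v0.dtype).at[0].set(q0)

    def step(carry, j):
        Q, q_prev, beta_prev = carry
        q = Q[j]
        w = matvec(q) - beta_prev * q_prev
        alpha = jnp.dot(q, w)
        w = w - alpha * q
        w = w - Q.T @ (Q @ w)  # full reorthogonalization
        beta = jnp.linalg.norm(w)
        Q = Q.at[j + 1].set(w / beta)
        return (Q, q, beta), (alpha, beta)

    init = (Q, jnp.zeros_like(q0), jnp.zeros((), v0.dtype))
    _, (alphas, betas) = jax.lax.scan(step, init, jnp.arange(m))
    return alphas, betas[:-1]


def slq_logdet(matvec, n, key, num_probes=32, m=20):
    """Hypothetical sketch: estimate log det A for SPD A via SLQ.

    Rademacher probes satisfy z^T z = n, and Gauss quadrature from m Lanczos
    steps started at z gives z^T log(A) z ≈ n * e1^T log(T) e1.
    """
    Z = 2.0 * jax.random.bernoulli(key, 0.5, (num_probes, n)).astype(jnp.float32) - 1.0

    def one_probe(z):
        alpha, beta = lanczos_tridiag(matvec, z, m)
        T = jnp.diag(alpha) + jnp.diag(beta, 1) + jnp.diag(beta, -1)
        evals, evecs = jnp.linalg.eigh(T)
        return n * jnp.sum(jnp.log(evals) * evecs[0, :] ** 2)

    # Hutchinson: log det A = tr(log A) = E[z^T log(A) z]
    return jnp.mean(jax.vmap(one_probe)(Z))


d = jnp.linspace(1.0, 3.0, 50)  # eigenvalues of a diagonal test matrix
est = jax.jit(lambda d, key: slq_logdet(lambda x: d * x, 50, key))(d, jax.random.PRNGKey(0))
print(est, jnp.sum(jnp.log(d)))
```

Storing the Krylov basis in a preallocated `(m + 1, n)` array trades memory for fixed shapes; a production implementation might instead drop reorthogonalization or use selective reorthogonalization.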
🐛 Bug
Issue with log determinant JIT compilation on large matrices (> 1e6). Perhaps an issue with the `iterative` method, which I believe is triggered above 1e6. I worked around this issue by specifying the `method="dense"` kwarg and seem to have no issues there.

To reproduce
https://github.com/JaxGaussianProcesses/GPJax/actions/runs/6099328110/job/16551211143?pr=370
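For comparison with the `method="dense"` workaround described in the issue: a dense log determinant of a symmetric positive definite matrix involves only fixed-shape operations, so it traces under `jax.jit` without trouble. The plain-JAX sketch below (`dense_logdet_spd` is a hypothetical helper, not CoLA's dense code path) uses a Cholesky factorization A = L L^T, for which log det A = 2 Σ_i log L_ii.

```python
import jax
import jax.numpy as jnp


@jax.jit
def dense_logdet_spd(A):
    """log det A for symmetric positive definite A via Cholesky (hypothetical helper)."""
    L = jnp.linalg.cholesky(A)  # A = L @ L.T with L lower triangular
    return 2.0 * jnp.sum(jnp.log(jnp.diag(L)))


B = jax.random.normal(jax.random.PRNGKey(0), (5, 5))
A = B @ B.T + 5.0 * jnp.eye(5)  # SPD test matrix
print(dense_logdet_spd(A), jnp.linalg.slogdet(A)[1])
```

The dense path costs O(n^3) time and O(n^2) memory, which is why iterative methods become attractive once matrices are large.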
System information
`main` branch.

Additional context