
[Bug] Logdet #41

Open
daniel-dodd opened this issue Sep 6, 2023 · 2 comments
Labels
bug Something isn't working

Comments


daniel-dodd commented Sep 6, 2023

🐛 Bug

Issue with jit compilation of the log determinant on large matrices (more than ~1e6 elements, I believe). Perhaps an issue with the iterative method, which I think is triggered above that size.

I worked around this issue by specifying the method="dense" kwarg, and see no issues there.
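For reference, a minimal sketch of that workaround. The inputs (K, noise) and the size n are illustrative, and the size threshold is my assumption; the relevant part is forcing method="dense":

import jax.numpy as jnp
from jax import jit
import cola

@jit
def logdet_dense(K, noise):
    # Dense + Diagonal mirrors the SumLinearOperator from the failing case below.
    sigma = cola.PSD(cola.ops.Dense(K) + cola.ops.Diagonal(noise))
    # Forcing the dense algorithm avoids the iterative path that breaks under jit.
    return cola.logdet(sigma, method="dense")

n = 1500  # assumed to be past the size where auto dispatch would go iterative
val = logdet_dense(2.0 * jnp.eye(n), jnp.full(n, 1e-3))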

To reproduce

# Jit compiling this function and giving an input that is larger than the ~1e6-element threshold
jit(lambda sigma: cola.logdet(sigma))(input_matrix_here)
# Here sigma is a SumLinearOperator of a Dense LinOp and a Diagonal.
# This may be an issue with SumLinearOperators.

https://github.com/JaxGaussianProcesses/GPJax/actions/runs/6099328110/job/16551211143?pr=370

--> 189     + cola.logdet(sigma)
    190     + diff.T @ cola.solve(sigma, diff)
    191 )

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/logdet.py:39, in logdet(A, **kwargs)
     17 @export
     18 def logdet(A: LinearOperator, **kwargs):
     19     r""" Computes logdet of a linear operator.
     20
     21     For large inputs (or with method='iterative'),
   (...)
     37         Array: logdet
     38     """
---> 39     _, ld = slogdet(A,**kwargs)
     40     return ld

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/plum/function.py:438, in Function.__call__(self, *args, **kw_args)
    436     method, return_type, loginfo = self.resolve_method(args, types)
    437 logging.info("%s",loginfo)
--> 438 return _convert(method(*args,**kw_args), return_type)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/logdet.py:96, in slogdet(A, **kwargs)
     93 elif 'exact' in method or not stochastic_faster:
     94     # TODO: explicit autograd rule for this case?
     95     logA = cola.linalg.log(A, tol=tol, method='iterative', **kws)
---> 96     trlogA = cola.linalg.trace(logA,method='exact',**kws)
     97 else:
     98     raise ValueError(f"Unknown method {method} or CoLA didn't fit any selection criteria")

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/plum/function.py:438, in Function.__call__(self, *args, **kw_args)
    436     method, return_type, loginfo = self.resolve_method(args, types)
    437 logging.info("%s",loginfo)
--> 438 return _convert(method(*args,**kw_args), return_type)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/diag_trace.py:137, in trace(A, **kwargs)
    117 r""" Compute the trace of a linear operator tr(A).
    118
    119 Uses either :math:`O(\tfrac{1}{\delta^2})` time stochastic estimation (Hutchinson estimator)
   (...)
    134 Returns:
    135     Array: trace"""
    136 assert A.shape[0] == A.shape[1], "Can't trace non square matrix"

[... intermediate frames omitted in the captured log ...]

--> 723   return getattr(self.aval,f"_{name}")(self,*args)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4153, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4150       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
   4152 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 4153 return _gather(arr,treedef,static_idx,dynamic_idx,indices_are_sorted,
   4154               unique_indices,mode,fill_value)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4162, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   4159 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4160             unique_indices, mode, fill_value):
   4161   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 4162   indexer = _index_to_gather(shape(arr),idx)  # shared with _scatter_update
   4163   y = arr
   4165   if fill_value is not None:

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4414, in _index_to_gather(x_shape, idx, normalize_indices)
   4405 if not all(_is_slice_element_none_or_constant(elt)
   4406            for elt in (start, stop, step)):
   4407   msg = ("Array slice indices must have static start/stop/step to be used "
   4408          "with NumPy indexing syntax. "
   4409          f"Found slice({start}, {stop}, {step}). "
   (...)
   4412          "dynamic_update_slice (JAX does not support dynamically sized "
   4413          "arrays within JIT compiled functions).")
-> 4414   raise IndexError(msg)
   4415 if not core.is_constant_dim(x_shape[x_axis]):
   4416   msg = ("Cannot use NumPy slice indexing on an array dimension whose "
   4417          f"size is not statically known ({x_shape[x_axis]}). "
   4418          "Try using lax.dynamic_slice/dynamic_update_slice")

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

Error: Process completed with exit code 1.
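For context, a minimal illustration of the JAX restriction that the traceback ends with; this is generic JAX behaviour, independent of cola:

import jax
import jax.numpy as jnp

@jax.jit
def bad_slice(x, k):
    # k is a traced value inside jit, so x[:k] raises the same IndexError:
    # slice start/stop/step must be static under jit.
    return x[:k].sum()

# bad_slice(jnp.arange(10.0), 5)  # IndexError: Array slice indices must have static ...

@jax.jit
def ok_slice(x, k):
    # lax.dynamic_slice accepts a traced start, but the slice *size* must be static.
    return jax.lax.dynamic_slice(x, (k,), (3,)).sum()

print(ok_slice(jnp.arange(10.0), 2))  # works: 2 + 3 + 4 = 9.0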

System information

Please complete the following information:

  • Latest PyPI version of cola for the above traceback.
  • I also had the same issue locally on my machine with the latest commits on the main branch.


daniel-dodd added the bug label Sep 6, 2023
Collaborator

mfinzi commented Sep 6, 2023

Ah, this seems to be about the trace estimation in the iterative-exact method for computing $\log \mathrm{det} A = \mathrm{Tr}(\log A)$. Looks to be coming from the slicing of I in https://github.com/wilson-labs/cola/blob/main/cola/algorithms/diagonal_estimation.py#L12 for the exact (deterministic) version of the trace estimator.

The issue may arise when n (if A is an n x n matrix) is not a multiple of the trace evaluator's batch size, in which case the arrays in different iterations will be different.

The easiest fix would probably be to construct I_chunk (and the other chunks) with explicit zero padding, rather than using any slicing; a rough sketch follows. I will investigate later this week.
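A rough sketch of that idea (not the actual cola code: exact_diag_chunked, matmat, and bs are hypothetical names, and the real estimator does more than this), just to show how zero padding keeps every chunk at a static shape:

import jax
import jax.numpy as jnp
from jax import lax

def exact_diag_chunked(matmat, n, bs=100):
    """Sketch: exact diag(A) via fixed-shape, zero-padded probe chunks.

    Every chunk of probe vectors has static shape (n, bs), so no
    remainder-sized slice is ever taken and the loop stays jit-friendly.
    matmat(X) should compute A @ X for an (n, bs) matrix X.
    """
    num_chunks = -(-n // bs)  # ceil(n / bs), a static Python int

    def body(k, diag):
        idx = k * bs + jnp.arange(bs)       # column indices, may overrun n
        valid = idx < n
        idx_c = jnp.where(valid, idx, 0)    # clamp the overrun indices to 0
        # zero-padded identity chunk: overrun columns stay all-zero
        I_chunk = jnp.zeros((n, bs)).at[idx_c, jnp.arange(bs)].set(valid.astype(jnp.float32))
        cols = matmat(I_chunk)              # (n, bs) block of A's columns
        vals = cols[idx_c, jnp.arange(bs)] * valid
        return diag.at[idx_c].add(vals)     # overrun entries just add 0 to diag[0]

    return lax.fori_loop(0, num_chunks, body, jnp.zeros(n))

# usage: diagonal of a small dense matrix, under jit
A = jnp.arange(9.0).reshape(3, 3)
print(jax.jit(lambda: exact_diag_chunked(lambda X: A @ X, 3, bs=2))())  # [0. 4. 8.]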

Also, for GP applications in which you only need unbiased estimates of the MLL gradients, you might want to consider SLQ (i.e. 'iterative-stochastic'), which you can also access (it will be selected by auto) if the matrix is large enough and you set a large vtol (the tolerance for the standard deviation of the unbiased estimator), such as vtol=1/5.
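A hedged sketch of what that call might look like, using the method name and vtol value mentioned above (whether the installed cola version accepts these exact kwargs at this call site is an assumption):

from jax import jit
import cola

# sigma: a large PSD LinearOperator, e.g. the Dense + Diagonal sum from the repro.
# SLQ gives an unbiased (noisy) logdet estimate, which is fine for MLL gradients.
slq_logdet = jit(lambda sigma: cola.logdet(sigma, method='iterative-stochastic', vtol=1 / 5))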

Collaborator

mfinzi commented Sep 7, 2023

Oh, actually it's not that; on closer inspection it's the slicing of the Lanczos tridiagonal matrix in apply_unary: https://github.com/wilson-labs/cola/blob/main/cola/linalg/unary.py#L33

mfinzi added a commit that referenced this issue Sep 21, 2023
Refactoring Lanczos and Arnoldi to return LinearOperators as output,
`Q,T,info = lanczos(A,**kwargs)`
where Q,T are linear operators, and likewise for arnoldi.
Removed lanczos_parts and related functions.

Added decorator for parametrize to catch and pass NumpyNotImplemented
exceptions for the numpy backend

Refactor to make logdet, unary, etc. jittable and vmappable

see #41

---------

Co-authored-by: AndPotap <apotapczynski@gmail.com>
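
For completeness, a brief usage sketch of the refactored interface described in the commit message above (the import path, the SelfAdjoint annotation, and the example operator are my assumptions, not taken from the commit):

import jax.numpy as jnp
import cola
from cola.algorithms import lanczos  # import path is an assumption

n = 500
A = cola.SelfAdjoint(cola.ops.Dense(jnp.eye(n) + 0.01 * jnp.ones((n, n))))

# Per the commit message, Q and T now come back as LinearOperators
# rather than raw arrays, and likewise for arnoldi.
Q, T, info = lanczos(A)
print(Q.shape, T.shape)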