Skip to content

Commit

Permalink
added missing diag rules and fix to unary rule
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Aug 16, 2023
1 parent ed01d51 commit 884cc3f
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
6 changes: 6 additions & 0 deletions cola/linalg/diag_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def diag(A: Identity, k=0, **kwargs):
else:
return A.xnp.zeros(A.shape[0] - k, A.dtype)

@dispatch
def diag(A: Diagonal, k=0, **kwargs):
if k == 0:
return A.v
else:
return A.xnp.zeros(A.shape[0] - k, A.dtype)

@dispatch
def diag(A: Sum, k=0, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion cola/linalg/logdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def slogdet(A: LinearOperator, **kwargs) -> Array:
return slogdet(cola.decompositions.cholesky_decomposed(A), **kws)
elif method in ('iterative', 'approx') or (method == 'auto' and
(np.prod(A.shape) > 1e6 and kws['tol'] >= 3e-2)):
return 1., stochastic_lanczos_quad(A, A.xnp.log, **kws)
one = A.xnp.array(1., dtype=A.dtype, device=A.device)
return one, stochastic_lanczos_quad(A, A.xnp.log, **kws)
else:
raise ValueError(f"Unknown method {method} or CoLA didn't fit any selection criteria")

Expand Down
5 changes: 3 additions & 2 deletions cola/linalg/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,13 @@ def apply_unary(f: Callable, A: Diagonal, **kwargs):

@dispatch
def apply_unary(f: Callable, A: BlockDiag, **kwargs):
fAs = [apply_unary(f, a, **kwargs) for a in A.blocks]
fAs = [apply_unary(f, a, **kwargs) for a in A.Ms]
return BlockDiag(*fAs, multiplicities=A.multiplicities)

@dispatch
def apply_unary(f: Callable, A: Identity, **kwargs):
return f(1.)*A
one = A.xnp.array(1., dtype=A.dtype, device=A.device)
return f(one)*A

@dispatch
def apply_unary(f: Callable, A: ScalarMul, **kwargs):
Expand Down

0 comments on commit 884cc3f

Please sign in to comment.