-
Notifications
You must be signed in to change notification settings - Fork 5
/
linear_algebra.py
137 lines (94 loc) · 2.79 KB
/
linear_algebra.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import logging
from typing import Union, Optional
import jax.numpy as jnp
import jax.scipy.linalg as jsla
from . import dispatch, B, Numeric
from .custom import jax_register
from ..custom import (
toeplitz_solve,
i_toeplitz_solve,
s_toeplitz_solve,
i_s_toeplitz_solve,
expm,
i_expm,
s_expm,
i_s_expm,
logm,
i_logm,
s_logm,
i_s_logm,
)
from ..linear_algebra import _default_perm
from ..types import Int
from ..util import batch_computation
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Empty export list: `from <this module> import *` binds no names.
__all__ = []
@dispatch
def matmul(a: Numeric, b: Numeric, tr_a: bool = False, tr_b: bool = False):
    """Matrix-multiply `a` and `b`, optionally transposing either operand first.

    Args:
        a: Left operand.
        b: Right operand.
        tr_a: If true, multiply by `transpose(a)` instead of `a`.
        tr_b: If true, multiply by `transpose(b)` instead of `b`.

    Returns:
        The result of `jnp.matmul` applied to the (possibly transposed) operands.
    """
    left, right = a, b
    if tr_a:
        left = transpose(left)
    if tr_b:
        right = transpose(right)
    return jnp.matmul(left, right)
@dispatch
def einsum(equation: str, *elements: Numeric):
    """Evaluate the Einstein-summation `equation` over `elements` with `jnp.einsum`."""
    operands = tuple(elements)
    return jnp.einsum(equation, *operands)
@dispatch
def transpose(a: Numeric, perm: Optional[Union[tuple, list]] = None):
    """Permute the axes of `a`.

    Special cases:
        * A rank-0 input is returned unchanged.
        * A rank-1 input with no `perm` is lifted to a row, `a[None, :]`.

    Args:
        a: Input to transpose.
        perm: Axis permutation. If omitted (and `a` has rank >= 2), the
            permutation comes from `_default_perm(a)`. NOTE(review):
            presumably this swaps the last two axes; confirm in `_default_perm`.

    Returns:
        The transposed value.
    """
    rank = B.rank(a)
    if rank == 0:
        return a
    if perm is None:
        if rank == 1:
            return a[None, :]
        perm = _default_perm(a)
    return jnp.transpose(a, axes=perm)
@dispatch
def trace(a: Numeric, axis1: Int = -2, axis2: Int = -1):
    """Sum the main diagonal of `a` taken over axes `axis1` and `axis2`.

    By default the diagonal is taken over the last two axes.
    """
    return jnp.trace(a, axis2=axis2, axis1=axis1)
@dispatch
def svd(a: Numeric, compute_uv: bool = True):
    """Compute the reduced (``full_matrices=False``) singular value decomposition of `a`.

    Args:
        a: Matrix to decompose.
        compute_uv: Whether to return the singular vectors as well.

    Returns:
        If `compute_uv`, a tuple `(u, s, v)`. Here `v` is the conjugate transpose
        of the `vh` factor from `jnp.linalg.svd`, so `a` is approximately
        `u @ diag(s) @ conj(transpose(v))`. Otherwise, the unmodified result of
        `jnp.linalg.svd(..., compute_uv=False)`.
    """
    factors = jnp.linalg.svd(a, full_matrices=False, compute_uv=compute_uv)
    if not compute_uv:
        return factors
    u, s, vh = factors
    return u, s, jnp.conj(transpose(vh))
@dispatch
def eig(a: Numeric, compute_eigvecs: bool = True):
    """Compute the eigenvalues, and optionally the eigenvectors, of `a`.

    Args:
        a: Matrix to decompose.
        compute_eigvecs: Whether to also return the eigenvectors.

    Returns:
        `(vals, vecs)` from `jnp.linalg.eig` if `compute_eigvecs`; otherwise
        only the eigenvalues, from `jnp.linalg.eigvals`.
    """
    if compute_eigvecs:
        vals, vecs = jnp.linalg.eig(a)
        return vals, vecs
    # Don't compute eigenvectors only to discard them. `eigvals` skips that
    # work, and JAX's derivative rule for `eig` is only implemented for the
    # eigenvalues-only path, so this branch is also differentiable.
    # NOTE(review): confirm the differentiation behaviour against the JAX
    # version in use.
    return jnp.linalg.eigvals(a)
@dispatch
def solve(a: Numeric, b: Numeric):
    """Solve the linear system `a @ x = b` for `x` with `jnp.linalg.solve`."""
    x = jnp.linalg.solve(a, b)
    return x
@dispatch
def inv(a: Numeric):
    """Return the matrix inverse of `a`, computed by `jnp.linalg.inv`."""
    inverse = jnp.linalg.inv(a)
    return inverse
@dispatch
def det(a: Numeric):
    """Return the determinant of `a`, computed by `jnp.linalg.det`."""
    determinant = jnp.linalg.det(a)
    return determinant
@dispatch
def logdet(a: Numeric):
    """Return the log of the absolute value of the determinant of `a`.

    The sign reported by `jnp.linalg.slogdet` is discarded. The result
    therefore equals `log(det(a))` only when `det(a)` is positive.
    """
    _sign, log_abs_det = jnp.linalg.slogdet(a)
    return log_abs_det
# Build `_expm` by passing the custom `expm` and its companion functions
# (`i_expm`, `s_expm`, `i_s_expm`, imported from `..custom`) to `jax_register`.
# NOTE: this line must stay above the `@dispatch def expm` below. That
# definition rebinds the module-level name `expm` to the dispatched function,
# so moving this line after it would pass the wrong object.
_expm = jax_register(expm, i_expm, s_expm, i_s_expm)


@dispatch
def expm(a: Numeric):
    """Apply the `jax_register`-wrapped `_expm` to `a`.

    NOTE(review): presumably the matrix exponential; see `..custom.expm`.
    """
    return _expm(a)
# Build `_logm` by passing the custom `logm` and its companion functions
# (`i_logm`, `s_logm`, `i_s_logm`, imported from `..custom`) to `jax_register`.
# NOTE: this line must stay above the `@dispatch def logm` below. That
# definition rebinds the module-level name `logm` to the dispatched function.
_logm = jax_register(logm, i_logm, s_logm, i_s_logm)


@dispatch
def logm(a: Numeric):
    """Apply the `jax_register`-wrapped `_logm` to `a`.

    NOTE(review): presumably the matrix logarithm; see `..custom.logm`.
    """
    return _logm(a)
@dispatch
def _cholesky(a: Numeric):
    """Return the lower-triangular Cholesky factor of `a` from `jnp.linalg.cholesky`."""
    factor = jnp.linalg.cholesky(a)
    return factor
@dispatch
def cholesky_solve(a: Numeric, b: Numeric):
    """Solve `(a @ transpose(a)) @ x = b` given the lower-triangular factor `a`.

    First solves `a @ y = b` by forward substitution, then solves
    `transpose(a) @ x = y` by back substitution.
    """
    y = triangular_solve(a, b)
    return triangular_solve(transpose(a), y, lower_a=False)
@dispatch
def triangular_solve(a: Numeric, b: Numeric, lower_a: bool = True):
    """Solve `a @ x = b` for `x`, where `a` is triangular.

    Args:
        a: Triangular matrix; lower-triangular if `lower_a`, upper otherwise.
        b: Right-hand side.
        lower_a: Whether `a` is lower triangular.

    Returns:
        The solution from `jax.scipy.linalg.solve_triangular`, applied through
        `batch_computation`. NOTE(review): presumably any leading dimensions
        beyond the trailing rank-2 parts are batch dimensions; confirm in
        `batch_computation`.
    """

    def _solve_single(lhs, rhs):
        # trans="N": solve with `lhs` itself rather than its transpose.
        return jsla.solve_triangular(
            lhs, rhs, lower=lower_a, trans="N", check_finite=False
        )

    # Both operands are handed to `batch_computation` with rank 2.
    matrix_ranks = (2, 2)
    return batch_computation(_solve_single, (a, b), matrix_ranks)
# Build `_toeplitz_solve` by passing the custom `toeplitz_solve` and its
# companion functions (imported from `..custom`) to `jax_register`.
# NOTE: this statement must stay above the `@dispatch def toeplitz_solve`
# below, which rebinds the module-level name `toeplitz_solve`.
_toeplitz_solve = jax_register(
    toeplitz_solve, i_toeplitz_solve, s_toeplitz_solve, i_s_toeplitz_solve
)


@dispatch
def toeplitz_solve(a: Numeric, b: Numeric, c: Numeric):
    """Solve a Toeplitz system via the `jax_register`-wrapped `_toeplitz_solve`.

    NOTE(review): the roles of `a`, `b` and `c` (e.g. first column, first
    row, right-hand side) are defined by `..custom.toeplitz_solve`; confirm
    there.
    """
    return _toeplitz_solve(a, b, c)