Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement numba overload for POTRF, LAPACK cholesky routine #578

Merged
merged 9 commits into from
Jan 15, 2024
37 changes: 1 addition & 36 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.tensor.slinalg import Solve
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
Expand Down Expand Up @@ -809,41 +809,6 @@ def softplus(x):
return softplus


@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
lower = op.lower

out_dtype = node.outputs[0].type.numpy_dtype

if lower:
inputs_cast = int_to_float_fn(node.inputs, out_dtype)

@numba_njit
def cholesky(a):
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)

else:
# TODO: Use SciPy's BLAS/LAPACK Cython wrappers.

warnings.warn(
(
"Numba will use object mode to allow the "
"`lower` argument to `scipy.linalg.cholesky`."
),
UserWarning,
)

ret_sig = get_numba_type(node.outputs[0].type)

@numba_njit
def cholesky(a):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
return ret

return cholesky


@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
Expand Down
148 changes: 135 additions & 13 deletions pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular


_PTR = ctypes.POINTER
Expand All @@ -25,6 +25,15 @@
_ptr_int = _PTR(_int)


@numba.core.extending.register_jitable
def _check_finite_matrix(a, func_name):
for v in np.nditer(a):
if not np.isfinite(v.item()):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input to " + func_name
)


@intrinsic
def val_to_dptr(typingctx, data):
def impl(context, builder, signature, args):
Expand Down Expand Up @@ -177,6 +186,22 @@ def numba_xtrtrs(cls, dtype):

return functype(lapack_ptr)

@classmethod
def numba_xpotrf(cls, dtype):
"""
Called by scipy.linalg.cholesky
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO,
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # INFO
)
return functype(lapack_ptr)


def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False):
return linalg.solve_triangular(
Expand All @@ -190,13 +215,7 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):

_check_scipy_linalg_matrix(A, "solve_triangular")
_check_scipy_linalg_matrix(B, "solve_triangular")

dtype = A.dtype
if str(dtype).startswith("complex"):
raise ValueError(
"Complex inputs not currently supported by solve_triangular in Numba mode"
)

w_type = _get_underlying_float(dtype)
numba_trtrs = _LAPACK().numba_xtrtrs(dtype)

Expand Down Expand Up @@ -249,8 +268,8 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False):
)

if B_is_1d:
return B_copy[..., 0]
return B_copy
return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy, int_ptr_to_val(INFO)

return impl

Expand All @@ -262,19 +281,122 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite

dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by solve_triangular in Numba mode"
)

@numba_basic.numba_njit(inline="always")
def solve_triangular(a, b):
res = _solve_triangular(a, b, trans, lower, unit_diagonal)
if check_finite:
if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))):
raise ValueError(
"Non-numeric values (nan or inf) returned by solve_triangular"
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input A to solve_triangular"
)
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to solve_triangular"
)

res, info = _solve_triangular(a, b, trans, lower, unit_diagonal)
if info != 0:
raise np.linalg.LinAlgError(
"Singular matrix in input A to solve_triangular"
)
return res

return solve_triangular


def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
return linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
)


@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(A, "cholesky")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_potrf = _LAPACK().numba_xpotrf(dtype)

def impl(A, lower=0, overwrite_a=False, check_finite=True):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")

UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)

if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
A_copy = A

numba_potrf(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
INFO,
)

return A_copy, int_ptr_to_val(INFO)

return impl


@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
jessegrabowski marked this conversation as resolved.
Show resolved Hide resolved
"""
Overload scipy.linalg.cholesky with a numba function.

Note that np.linalg.cholesky is already implemented in numba, but it does not support additional keyword arguments.
In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
"""
lower = op.lower
overwrite_a = False
check_finite = op.check_finite
on_error = op.on_error

dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by cholesky in Numba mode"
)

@numba_basic.numba_njit(inline="always")
def nb_cholesky(a):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to cholesky"
)
res, info = _cholesky(a, lower, overwrite_a, check_finite)

if on_error == "raise":
if info > 0:
raise np.linalg.LinAlgError(
"Input to cholesky is not positive definite"
)
if info < 0:
raise ValueError(
'LAPACK reported an illegal value in input on entry to "POTRF."'
)
else:
if info != 0:
res = np.full_like(res, np.nan)

return res

return nb_cholesky


@numba_funcify.register(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype
Expand Down
13 changes: 9 additions & 4 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ class Cholesky(Op):
__props__ = ("lower", "destructive", "on_error")
gufunc_signature = "(m,m)->(m,m)"

def __init__(self, *, lower=True, on_error="raise"):
def __init__(self, *, lower=True, check_finite=True, on_error="raise"):
self.lower = lower
self.destructive = False
self.check_finite = check_finite
if on_error not in ("raise", "nan"):
raise ValueError('on_error must be one of "raise" or ""nan"')
self.on_error = on_error
Expand All @@ -70,7 +71,9 @@ def perform(self, node, inputs, outputs):
x = inputs[0]
z = outputs[0]
try:
z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
z[0] = scipy.linalg.cholesky(
x, lower=self.lower, check_finite=self.check_finite
).astype(x.dtype)
except scipy.linalg.LinAlgError:
if self.on_error == "raise":
raise
Expand Down Expand Up @@ -129,8 +132,10 @@ def conjugate_solve_triangular(outer, inner):
return [grad]


def cholesky(x, lower=True, on_error="raise"):
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
def cholesky(x, lower=True, on_error="raise", check_finite=False):
return Blockwise(
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
)(x)


class SolveBase(Op):
Expand Down
51 changes: 0 additions & 51 deletions tests/link/numba/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,57 +14,6 @@
rng = np.random.default_rng(42849)


@pytest.mark.parametrize(
"x, lower, exc",
[
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
True,
None,
),
(
set_test_value(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
True,
None,
),
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
False,
UserWarning,
),
],
)
def test_Cholesky(x, lower, exc):
g = slinalg.Cholesky(lower=lower)(x)

if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])

cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)


@pytest.mark.parametrize(
"A, x, lower, exc",
[
Expand Down