Add linalg.block_diag and sparse equivalent (pymc-devs#576)
* Copy `block_diag` and support functions from `pymc.math`

* Evaluate output in sphinx code example

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>

* Test type equivalence with `isinstance` instead of `==`

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>

* Fix typo in test function

* Split `block_diag` into sparse and dense version

Closely follow scipy function signature for `block_diag`

* Use `as_sparse_or_tensor_variable` in `SparseBlockDiagonalMatrix` to allow sparse matrix inputs to `pytensor.sparse.block_diag`

* Test sparse and dense inputs to `pytensor.sparse.block_diag`

* Add numba overload for `pytensor.tensor.slinalg.block_diag`

* add jax overload for `pytensor.tensor.slinalg.block_diag`

* Move stand-alone `block_diag_grad` function into `grad` method

* Add `format` prop to `SparseBlockDiagonalMatrix`

* Use `compare_numba_and_py` in `numba/test_slinalg.py::test_block_diag`

* Add support for Blockwise to `slinalg.block_diag`

* Add gradient test

Remove `Matrix` from `BlockDiagonal` and `SparseBlockDiagonal` `Op` names

Correct errors in docstrings

Move input validation to a shared class method

* Remove `gufunc_signature` from `__props__`

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>

* Implement correct `__props__` for subclasses of `BaseBlockDiagonal`

---------

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
jessegrabowski and ricardoV94 committed Jan 7, 2024
1 parent 96f753b commit c4ae6e3
Showing 9 changed files with 348 additions and 13 deletions.
10 changes: 9 additions & 1 deletion pytensor/link/jax/dispatch/slinalg.py
@@ -1,7 +1,7 @@
import jax

from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular


@jax_funcify.register(Cholesky)
@@ -45,3 +45,11 @@ def solve_triangular(A, b):
)

return solve_triangular


@jax_funcify.register(BlockDiagonal)
def jax_funcify_BlockDiagonalMatrix(op, **kwargs):
def block_diag(*inputs):
return jax.scipy.linalg.block_diag(*inputs)

return block_diag
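
[Editor's note: a minimal sketch of exercising this overload end to end, assuming a working JAX install and PyTensor's registered "JAX" compilation mode; not part of the commit.]

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import block_diag

A = pt.matrix("A")
B = pt.matrix("B")

# Compiling with mode="JAX" dispatches BlockDiagonal through the overload above.
f = pytensor.function([A, B], block_diag(A, B), mode="JAX")
print(f(np.eye(2), 3 * np.eye(3)))  # 5x5 block diagonal array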
24 changes: 23 additions & 1 deletion pytensor/link/numba/dispatch/slinalg.py
@@ -9,7 +9,7 @@

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import SolveTriangular
from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular


_PTR = ctypes.POINTER
@@ -273,3 +273,25 @@ def solve_triangular(a, b):
return res

return solve_triangular


@numba_funcify.register(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype

# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_basic.numba_njit(inline="never")
def block_diag(*arrs):
shapes = np.array([a.shape for a in arrs], dtype="int")
out_shape = [int(s) for s in np.sum(shapes, axis=0)]
out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype)

r, c = 0, 0
for arr, shape in zip(arrs, shapes):
rr, cc = shape
out[r : r + rr, c : c + cc] = arr
r += rr
c += cc
return out

return block_diag
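
[Editor's note: the kernel writes each input block at a running (row, column) offset. A plain-NumPy restatement of the same logic for illustration; the helper name is ours, not part of the commit.]

import numpy as np

def block_diag_numpy(*arrs):
    # The output shape is the elementwise sum of the input shapes.
    shapes = np.array([a.shape for a in arrs])
    out = np.zeros(shapes.sum(axis=0), dtype=np.result_type(*arrs))
    r = c = 0
    for arr, (rr, cc) in zip(arrs, shapes):
        out[r : r + rr, c : c + cc] = arr  # place the block at the current offset
        r += rr
        c += cc
    return out

print(block_diag_numpy(np.eye(2), np.full((1, 2), 3.0)))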
96 changes: 87 additions & 9 deletions pytensor/sparse/basic.py
@@ -7,6 +7,7 @@
TODO: Automatic methods for determining best sparse format?
"""
from typing import Literal
from warnings import warn

import numpy as np
@@ -47,6 +48,7 @@
trunc,
)
from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.slinalg import BaseBlockDiagonal, _largest_common_dtype
from pytensor.tensor.type import TensorType
from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
@@ -60,7 +62,6 @@

sparse_formats = ["csc", "csr"]


"""
Types of sparse matrices to use for testing.
@@ -183,7 +184,6 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):

as_sparse = as_sparse_variable


as_sparse_or_tensor_variable = as_symbolic


@@ -1800,7 +1800,7 @@ def infer_shape(self, fgraph, node, shapes):
return r

def __str__(self):
return f"{self.__class__.__name__ }{{axis={self.axis}}}"
return f"{self.__class__.__name__}{{axis={self.axis}}}"


def sp_sum(x, axis=None, sparse_grad=False):
@@ -2775,19 +2775,14 @@ def comparison(self, x, y):

greater_equal_s_d = GreaterEqualSD()


eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)


neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)


lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)


gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)


le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)

ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
@@ -2992,7 +2987,7 @@ def __str__(self):
l = []
if self.inplace:
l.append("inplace")
return f"{self.__class__.__name__ }{{{', '.join(l)}}}"
return f"{self.__class__.__name__}{{{', '.join(l)}}}"

def make_node(self, x):
"""
@@ -3291,6 +3286,7 @@ class TrueDot(Op):
# Simplify code by splitting into DotSS and DotSD.

__props__ = ()

# The grad_preserves_dense attribute doesn't change the
# execution behavior. To let the optimizer merge nodes with
# different values of this attribute we shouldn't compare it
@@ -4260,3 +4256,85 @@ def grad(self, inputs, grads):


construct_sparse_from_list = ConstructSparseFromList()


class SparseBlockDiagonal(BaseBlockDiagonal):
__props__ = (
"n_inputs",
"format",
)

def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"):
super().__init__(n_inputs)
self.format = format

def make_node(self, *matrices):
matrices = self._validate_and_prepare_inputs(
matrices, as_sparse_or_tensor_variable
)
dtype = _largest_common_dtype(matrices)
out_type = matrix(format=self.format, dtype=dtype)

return Apply(self, matrices, [out_type])

def perform(self, node, inputs, output_storage, params=None):
dtype = node.outputs[0].type.dtype
output_storage[0][0] = scipy.sparse.block_diag(
inputs, format=self.format
).astype(dtype)


def block_diag(*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc"):
r"""
Construct a block diagonal matrix from a sequence of input matrices.

Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:

[[A, 0, 0],
 [0, B, 0],
 [0, 0, C]]

Parameters
----------
A, B, C ... : tensors
    Input matrices to form the block diagonal matrix. Each input must have exactly two dimensions.
    Note that the input matrices need not be sparse themselves, and will be automatically converted to the
    requested format if they are not.
format: str, optional
    The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.

Returns
-------
out: sparse matrix tensor
    Symbolic sparse matrix in the specified format.

Examples
--------
Create a sparse block diagonal matrix from two sparse 2x2 matrices:

.. code-block:: python

    import numpy as np
    from pytensor.sparse import block_diag
    from scipy.sparse import csr_matrix

    A = csr_matrix([[1, 2], [3, 4]])
    B = csr_matrix([[5, 6], [7, 8]])
    result_sparse = block_diag(A, B, format='csr')

    print(result_sparse)
    >>> SparseVariable{csr,int32}

    print(result_sparse.toarray().eval())
    >>> array([[1, 2, 0, 0],
               [3, 4, 0, 0],
               [0, 0, 5, 6],
               [0, 0, 7, 8]])
"""
if len(matrices) == 1:
    return matrices[0]

_sparse_block_diagonal = SparseBlockDiagonal(n_inputs=len(matrices), format=format)
return _sparse_block_diagonal(*matrices)
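
[Editor's note: as the docstring says, dense inputs are accepted alongside sparse ones. A short illustrative sketch, assuming scipy is installed; not part of the commit.]

import numpy as np
from scipy.sparse import csr_matrix
from pytensor.sparse import block_diag

A = csr_matrix([[1, 2], [3, 4]])   # sparse input
B = np.array([[5, 6], [7, 8]])     # dense input, handled via as_sparse_or_tensor_variable
X = block_diag(A, B, format="csc")
print(X.toarray().eval())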
19 changes: 19 additions & 0 deletions pytensor/tensor/basic.py
@@ -4279,6 +4279,25 @@ def take_along_axis(arr, indices, axis=0):
return arr[_make_along_axis_idx(arr.shape, indices, axis)]


def ix_(*args):
"""
PyTensor analog of `numpy.ix_`.
See `numpy.lib.index_tricks.ix_` for reference.
"""
out = []
nd = len(args)
for k, new in enumerate(args):
if new is None:
out.append(slice(None))
new = as_tensor(new)
if new.ndim != 1:
raise ValueError("Cross index must be 1 dimensional")
new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
out.append(new)
return tuple(out)


__all__ = [
"take_along_axis",
"expand_dims",
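
[Editor's note: `ix_` mirrors `numpy.ix_`: each 1-D index vector is reshaped so the vectors broadcast into an open mesh, which is what lets `BaseBlockDiagonal.grad` (below) carve each block out of the output gradient. A quick illustrative check, not part of the commit.]

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.basic import ix_

x = pt.arange(16).reshape((4, 4))
rows = pt.as_tensor([0, 1])
cols = pt.as_tensor([2, 3])

# Selects the 2x2 submatrix at rows 0-1, columns 2-3,
# matching x.eval()[np.ix_([0, 1], [2, 3])].
print(x[ix_(rows, cols)].eval())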
104 changes: 103 additions & 1 deletion pytensor/tensor/slinalg.py
@@ -1,6 +1,7 @@
import logging
import typing
import warnings
from functools import reduce
from typing import TYPE_CHECKING, Literal, Optional, Union

import numpy as np
@@ -23,7 +24,6 @@
if TYPE_CHECKING:
from pytensor.tensor import TensorLike


logger = logging.getLogger(__name__)


@@ -908,6 +908,107 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
)


def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])


class BaseBlockDiagonal(Op):
__props__ = ("n_inputs",)

def __init__(self, n_inputs):
input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)])
self.gufunc_signature = f"{input_sig}->(m,n)"

if n_inputs == 0:
raise ValueError("n_inputs must be greater than 0")
self.n_inputs = n_inputs

def grad(self, inputs, gout):
shapes = pt.stack([i.shape for i in inputs])
index_end = shapes.cumsum(0)
index_begin = index_end - shapes
slices = [
ptb.ix_(
pt.arange(index_begin[i, 0], index_end[i, 0]),
pt.arange(index_begin[i, 1], index_end[i, 1]),
)
for i in range(len(inputs))
]
return [gout[0][slc] for slc in slices]

def infer_shape(self, fgraph, nodes, shapes):
first, second = zip(*shapes)
return [(pt.add(*first), pt.add(*second))]

def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
if len(matrices) != self.n_inputs:
raise ValueError(
f"Expected {self.n_inputs} matri{'ces' if self.n_inputs > 1 else 'x'}, got {len(matrices)}"
)
matrices = list(map(as_tensor_func, matrices))
if any(mat.type.ndim != 2 for mat in matrices):
raise TypeError("All inputs must have dimension 2")
return matrices


class BlockDiagonal(BaseBlockDiagonal):
__props__ = ("n_inputs",)

def make_node(self, *matrices):
matrices = self._validate_and_prepare_inputs(matrices, pt.as_tensor)
dtype = _largest_common_dtype(matrices)
out_type = pytensor.tensor.matrix(dtype=dtype)
return Apply(self, matrices, [out_type])

def perform(self, node, inputs, output_storage, params=None):
dtype = node.outputs[0].type.dtype
output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)


def block_diag(*matrices: TensorVariable):
"""
Construct a block diagonal matrix from a sequence of input tensors.

Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:

[[A, 0, 0],
 [0, B, 0],
 [0, 0, C]]

Parameters
----------
A, B, C ... : tensors
    Input tensors to form the block diagonal matrix. The last two dimensions of each input are used, and all
    inputs should have at least 2 dimensions.

Returns
-------
out: tensor
    The block diagonal matrix formed from the input matrices.

Examples
--------
Create a block diagonal matrix from two 2x2 matrices:

.. code-block:: python

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diag

    A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
    B = pt.as_tensor_variable(np.array([[5, 6], [7, 8]]))
    result = block_diag(A, B)

    print(result.eval())
    >>> array([[1, 2, 0, 0],
               [3, 4, 0, 0],
               [0, 0, 5, 6],
               [0, 0, 7, 8]])
"""
_block_diagonal_matrix = Blockwise(BlockDiagonal(n_inputs=len(matrices)))
return _block_diagonal_matrix(*matrices)


__all__ = [
"cholesky",
"solve",
@@ -918,4 +1019,5 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_triangular",
"block_diag",
]
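
[Editor's note: because `block_diag` wraps `BlockDiagonal` in `Blockwise` with signature `(m0,n0),(m1,n1)->(m,n)`, leading batch dimensions are handled automatically, and the `grad` method can be spot-checked with `verify_grad`. A hedged sketch, not part of the commit.]

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import block_diag

# Batched inputs: Blockwise maps the core 2-D Op over the leading dimension.
A = pt.tensor3("A")
B = pt.tensor3("B")
out = block_diag(A, B)
print(out.eval({A: np.ones((4, 2, 2)), B: np.ones((4, 3, 3))}).shape)  # (4, 5, 5)

# Gradient spot-check, in the spirit of the commit's gradient test.
rng = np.random.default_rng(0)
pytensor.gradient.verify_grad(
    lambda a, b: block_diag(a, b),
    [rng.normal(size=(2, 2)), rng.normal(size=(3, 3))],
    rng=rng,
)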