Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 155 additions & 28 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
node_rewriter,
)
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.scalar.basic import Abs, Log, Mul, Sign
from pytensor.scalar.basic import Abs, Exp, Log, Mul, Sign, Sqr
from pytensor.tensor.basic import (
AllocDiag,
ExtractDiag,
Expand All @@ -23,6 +23,7 @@
concatenate,
diag,
diagonal,
ones,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
Expand All @@ -46,9 +47,12 @@
)
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.slinalg import (
LU,
QR,
BlockDiagonal,
Cholesky,
CholeskySolve,
LUFactor,
Solve,
SolveBase,
SolveTriangular,
Expand All @@ -65,6 +69,10 @@
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)


def matrix_diagonal_product(x):
return pt.prod(diagonal(x, axis1=-2, axis2=-1), axis=-1)


def is_matrix_transpose(x: TensorVariable) -> bool:
"""Check if a variable corresponds to a transpose of the last two axes"""
node = x.owner
Expand Down Expand Up @@ -281,41 +289,39 @@ def cholesky_ldotlt(fgraph, node):

@register_stabilize
@register_specialize
@node_rewriter([det])
def local_det_chol(fgraph, node):
"""
If we have det(X) and there is already an L=cholesky(X)
floating around, then we can use prod(diag(L)) to get the determinant.
@node_rewriter([log])
def local_log_prod_to_sum_log(fgraph, node):
"""Rewrite log(prod(x)) as sum(log(x)), when x is known to be positive."""
[p] = node.inputs
p_node = p.owner

"""
(x,) = node.inputs
for cl, xpos in fgraph.clients[x]:
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky):
L = cl.outputs[0]
return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)]
if p_node is None:
return None

p_op = p_node.op

@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([log])
def local_log_prod_sqr(fgraph, node):
"""
This utilizes a boolean `positive` tag on matrices.
"""
(x,) = node.inputs
if x.owner and isinstance(x.owner.op, Prod):
# we cannot always make this substitution because
# the prod might include negative terms
p = x.owner.inputs[0]
if isinstance(p_op, Prod):
x = p_node.inputs[0]

# p is the matrix we're reducing with prod
if getattr(p.tag, "positive", None) is True:
return [log(p).sum(axis=x.owner.op.axis)]
# TODO: The product of diagonals of a Cholesky(A) are also strictly positive
if (
x.owner is not None
and isinstance(x.owner.op, Elemwise)
and isinstance(x.owner.op.scalar_op, Abs | Sqr | Exp)
) or getattr(x.tag, "positive", False):
return [log(x).sum(axis=p_node.op.axis)]

# TODO: have a reduction like prod and sum that simply
# returns the sign of the prod multiplication.

# Special case for log(abs(prod(x))) -> sum(log(abs(x))) that shows up in slogdet
elif isinstance(p_op, Elemwise) and isinstance(p_op.scalar_op, Abs):
[p] = p_node.inputs
p_node = p.owner
if p_node is not None and isinstance(p_node.op, Prod):
[x] = p.owner.inputs
return [log(abs(x)).sum(axis=p_node.op.axis)]


@register_specialize
@node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)])
Expand Down Expand Up @@ -442,6 +448,127 @@ def _find_diag_from_eye_mul(potential_mul_input):
return eye_input, non_eye_inputs


@register_stabilize
@register_specialize
@node_rewriter([det])
def det_of_matrix_factorized_elsewhere(fgraph, node):
"""
If we have det(X) or abs(det(X)) and there is already a nice decomposition(X) floating around,
use it to compute it more cheaply

"""
[det] = node.outputs
[x] = node.inputs

only_used_by_abs = all(
Copy link
Member Author

@ricardoV94 ricardoV94 Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any Op that that maps (-1, 1) to the same value is actually fine, At the very least should include square as well

isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs)
for client, _ in fgraph.clients[det]
)

new_det = None
for client, _ in fgraph.clients[x]:
core_op = client.op.core_op if isinstance(client.op, Blockwise) else client.op
match core_op:
case Cholesky():
L = client.outputs[0]
new_det = matrix_diagonal_product(L) ** 2
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Add the positive tag here.

Possibly also rewrite for log(x ** 2) -> log(x) * 2, when we know x is positive

case LU():
U = client.outputs[-1]
new_det = matrix_diagonal_product(U)
case LUFactor():
LU_packed = client.outputs[0]
new_det = matrix_diagonal_product(LU_packed)
case _:
if not only_used_by_abs:
continue
match core_op:
case SVD():
lmbda = (
client.outputs[1]
if core_op.compute_uv
else client.outputs[0]
)
new_det = prod(lmbda, axis=-1)
case QR():
R = client.outputs[-1]
# if mode == "economic", R may not be square and this rewrite could hide a shape error
# That's why it's tagged as `shape_unsafe`
new_det = matrix_diagonal_product(R)

if new_det is not None:
# found a match
break
else: # no-break (i.e., no-match)
return None

[det] = node.outputs
copy_stack_trace(det, new_det)
return [new_det]


@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter(tracks=[det])
def det_of_factorized_matrix(fgraph, node):
"""Introduce special forms for det(decomposition(X)).

Some cases are only known up to a sign change such as det(QR(X)),
and are only introduced if the determinant is only ever used inside an abs
"""
[det] = node.outputs
[x] = node.inputs

only_used_by_abs = all(
isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs)
for client, _ in fgraph.clients[det]
)

x_node = x.owner
if x_node is None:
return None

x_op = x_node.op
core_op = x_op.core_op if isinstance(x_op, Blockwise) else x_op

new_det = None
match core_op:
case Cholesky():
new_det = matrix_diagonal_product(x)
case LU():
if x is x_node.outputs[-2]:
# x is L
new_det = ones(x.shape[:-2], dtype=det.dtype)
elif x is x_node.outputs[-1]:
# x is U
new_det = matrix_diagonal_product(x)
case SVD():
if not core_op.compute_uv or x is x_node.outputs[1]:
# x is lambda
new_det = prod(x, axis=-1)
elif only_used_by_abs:
# x is either U or Vt and only ever used inside an abs
new_det = ones(x.shape[:-2], dtype=det.dtype)
case QR():
# if mode == "economic", Q/R may not be square and this rewrite could hide a shape error
# That's why it's tagged as `shape_unsafe`
if x is x_node.outputs[-1]:
# x is R
new_det = matrix_diagonal_product(x)
elif (
only_used_by_abs
and core_op.mode in ("economic", "full")
and x is x_node.outputs[0]
):
# x is Q and it's only ever used inside an abs
new_det = ones(x.shape[:-2], dtype=det.dtype)

if new_det is None:
return None

copy_stack_trace(det, new_det)
return [new_det]


@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([det])
Expand Down
12 changes: 7 additions & 5 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,16 @@ def test_local_det_chol():
det_X = pt.linalg.det(X)

f = function([X], [L, det_X])

nodes = f.maker.fgraph.toposort()
assert not any(isinstance(node, Det) for node in nodes)
assert not any(isinstance(node, Det) for node in f.maker.fgraph.apply_nodes)

# This previously raised an error (issue #392)
f = function([X], [L, det_X, X])
nodes = f.maker.fgraph.toposort()
assert not any(isinstance(node, Det) for node in nodes)
assert not any(isinstance(node, Det) for node in f.maker.fgraph.apply_nodes)

# Test graph that only has det_X
f = function([X], [det_X])
f.dprint()
assert not any(isinstance(node, Det) for node in f.maker.fgraph.apply_nodes)


def test_psd_solve_with_chol():
Expand Down
Loading