-
Notifications
You must be signed in to change notification settings - Fork 149
Determinant of factorized matrices #1785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
ricardoV94
wants to merge
2
commits into
pymc-devs:main
Choose a base branch
from
ricardoV94:det_rewrites
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+162
−33
Draft
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,7 @@ | |
| node_rewriter, | ||
| ) | ||
| from pytensor.graph.rewriting.unify import OpPattern | ||
| from pytensor.scalar.basic import Abs, Log, Mul, Sign | ||
| from pytensor.scalar.basic import Abs, Exp, Log, Mul, Sign, Sqr | ||
| from pytensor.tensor.basic import ( | ||
| AllocDiag, | ||
| ExtractDiag, | ||
|
|
@@ -23,6 +23,7 @@ | |
| concatenate, | ||
| diag, | ||
| diagonal, | ||
| ones, | ||
| ) | ||
| from pytensor.tensor.blockwise import Blockwise | ||
| from pytensor.tensor.elemwise import DimShuffle, Elemwise | ||
|
|
@@ -46,9 +47,12 @@ | |
| ) | ||
| from pytensor.tensor.rewriting.blockwise import blockwise_of | ||
| from pytensor.tensor.slinalg import ( | ||
| LU, | ||
| QR, | ||
| BlockDiagonal, | ||
| Cholesky, | ||
| CholeskySolve, | ||
| LUFactor, | ||
| Solve, | ||
| SolveBase, | ||
| SolveTriangular, | ||
|
|
@@ -65,6 +69,10 @@ | |
| MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv) | ||
|
|
||
|
|
||
| def matrix_diagonal_product(x): | ||
| return pt.prod(diagonal(x, axis1=-2, axis2=-1), axis=-1) | ||
|
|
||
|
|
||
| def is_matrix_transpose(x: TensorVariable) -> bool: | ||
| """Check if a variable corresponds to a transpose of the last two axes""" | ||
| node = x.owner | ||
|
|
@@ -281,41 +289,39 @@ def cholesky_ldotlt(fgraph, node): | |
|
|
||
| @register_stabilize | ||
| @register_specialize | ||
| @node_rewriter([det]) | ||
| def local_det_chol(fgraph, node): | ||
| """ | ||
| If we have det(X) and there is already an L=cholesky(X) | ||
| floating around, then we can use prod(diag(L)) to get the determinant. | ||
| @node_rewriter([log]) | ||
| def local_log_prod_to_sum_log(fgraph, node): | ||
| """Rewrite log(prod(x)) as sum(log(x)), when x is known to be positive.""" | ||
| [p] = node.inputs | ||
| p_node = p.owner | ||
|
|
||
| """ | ||
| (x,) = node.inputs | ||
| for cl, xpos in fgraph.clients[x]: | ||
| if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky): | ||
| L = cl.outputs[0] | ||
| return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)] | ||
| if p_node is None: | ||
| return None | ||
|
|
||
| p_op = p_node.op | ||
|
|
||
| @register_canonicalize | ||
| @register_stabilize | ||
| @register_specialize | ||
| @node_rewriter([log]) | ||
| def local_log_prod_sqr(fgraph, node): | ||
| """ | ||
| This utilizes a boolean `positive` tag on matrices. | ||
| """ | ||
| (x,) = node.inputs | ||
| if x.owner and isinstance(x.owner.op, Prod): | ||
| # we cannot always make this substitution because | ||
| # the prod might include negative terms | ||
| p = x.owner.inputs[0] | ||
| if isinstance(p_op, Prod): | ||
| x = p_node.inputs[0] | ||
|
|
||
| # p is the matrix we're reducing with prod | ||
| if getattr(p.tag, "positive", None) is True: | ||
| return [log(p).sum(axis=x.owner.op.axis)] | ||
| # TODO: The product of diagonals of a Cholesky(A) are also strictly positive | ||
| if ( | ||
| x.owner is not None | ||
| and isinstance(x.owner.op, Elemwise) | ||
| and isinstance(x.owner.op.scalar_op, Abs | Sqr | Exp) | ||
| ) or getattr(x.tag, "positive", False): | ||
| return [log(x).sum(axis=p_node.op.axis)] | ||
|
|
||
| # TODO: have a reduction like prod and sum that simply | ||
| # returns the sign of the prod multiplication. | ||
|
|
||
| # Special case for log(abs(prod(x))) -> sum(log(abs(x))) that shows up in slogdet | ||
| elif isinstance(p_op, Elemwise) and isinstance(p_op.scalar_op, Abs): | ||
| [p] = p_node.inputs | ||
| p_node = p.owner | ||
| if p_node is not None and isinstance(p_node.op, Prod): | ||
| [x] = p.owner.inputs | ||
| return [log(abs(x)).sum(axis=p_node.op.axis)] | ||
|
|
||
|
|
||
| @register_specialize | ||
| @node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)]) | ||
|
|
@@ -442,6 +448,127 @@ def _find_diag_from_eye_mul(potential_mul_input): | |
| return eye_input, non_eye_inputs | ||
|
|
||
|
|
||
| @register_stabilize | ||
| @register_specialize | ||
| @node_rewriter([det]) | ||
| def det_of_matrix_factorized_elsewhere(fgraph, node): | ||
| """ | ||
| If we have det(X) or abs(det(X)) and there is already a nice decomposition(X) floating around, | ||
| use it to compute it more cheaply | ||
|
|
||
| """ | ||
| [det] = node.outputs | ||
| [x] = node.inputs | ||
|
|
||
| only_used_by_abs = all( | ||
| isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs) | ||
| for client, _ in fgraph.clients[det] | ||
| ) | ||
|
|
||
| new_det = None | ||
| for client, _ in fgraph.clients[x]: | ||
| core_op = client.op.core_op if isinstance(client.op, Blockwise) else client.op | ||
| match core_op: | ||
| case Cholesky(): | ||
| L = client.outputs[0] | ||
| new_det = matrix_diagonal_product(L) ** 2 | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. TODO: Add the positive tag here. Possibly also rewrite for log(x ** 2) -> log(x) * 2, when we know x is positive |
||
| case LU(): | ||
| U = client.outputs[-1] | ||
| new_det = matrix_diagonal_product(U) | ||
| case LUFactor(): | ||
| LU_packed = client.outputs[0] | ||
| new_det = matrix_diagonal_product(LU_packed) | ||
| case _: | ||
| if not only_used_by_abs: | ||
| continue | ||
| match core_op: | ||
| case SVD(): | ||
| lmbda = ( | ||
| client.outputs[1] | ||
| if core_op.compute_uv | ||
| else client.outputs[0] | ||
| ) | ||
| new_det = prod(lmbda, axis=-1) | ||
| case QR(): | ||
| R = client.outputs[-1] | ||
| # if mode == "economic", R may not be square and this rewrite could hide a shape error | ||
| # That's why it's tagged as `shape_unsafe` | ||
| new_det = matrix_diagonal_product(R) | ||
|
|
||
| if new_det is not None: | ||
| # found a match | ||
| break | ||
| else: # no-break (i.e., no-match) | ||
| return None | ||
|
|
||
| [det] = node.outputs | ||
| copy_stack_trace(det, new_det) | ||
| return [new_det] | ||
|
|
||
|
|
||
| @register_stabilize("shape_unsafe") | ||
| @register_specialize("shape_unsafe") | ||
| @node_rewriter(tracks=[det]) | ||
| def det_of_factorized_matrix(fgraph, node): | ||
| """Introduce special forms for det(decomposition(X)). | ||
|
|
||
| Some cases are only known up to a sign change such as det(QR(X)), | ||
| and are only introduced if the determinant is only ever used inside an abs | ||
| """ | ||
| [det] = node.outputs | ||
| [x] = node.inputs | ||
|
|
||
| only_used_by_abs = all( | ||
| isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs) | ||
| for client, _ in fgraph.clients[det] | ||
| ) | ||
|
|
||
| x_node = x.owner | ||
| if x_node is None: | ||
| return None | ||
|
|
||
| x_op = x_node.op | ||
| core_op = x_op.core_op if isinstance(x_op, Blockwise) else x_op | ||
|
|
||
| new_det = None | ||
| match core_op: | ||
| case Cholesky(): | ||
| new_det = matrix_diagonal_product(x) | ||
| case LU(): | ||
| if x is x_node.outputs[-2]: | ||
| # x is L | ||
| new_det = ones(x.shape[:-2], dtype=det.dtype) | ||
| elif x is x_node.outputs[-1]: | ||
| # x is U | ||
| new_det = matrix_diagonal_product(x) | ||
| case SVD(): | ||
| if not core_op.compute_uv or x is x_node.outputs[1]: | ||
| # x is lambda | ||
| new_det = prod(x, axis=-1) | ||
| elif only_used_by_abs: | ||
| # x is either U or Vt and only ever used inside an abs | ||
| new_det = ones(x.shape[:-2], dtype=det.dtype) | ||
| case QR(): | ||
| # if mode == "economic", Q/R may not be square and this rewrite could hide a shape error | ||
| # That's why it's tagged as `shape_unsafe` | ||
| if x is x_node.outputs[-1]: | ||
| # x is R | ||
| new_det = matrix_diagonal_product(x) | ||
| elif ( | ||
| only_used_by_abs | ||
| and core_op.mode in ("economic", "full") | ||
| and x is x_node.outputs[0] | ||
| ): | ||
| # x is Q and it's only ever used inside an abs | ||
| new_det = ones(x.shape[:-2], dtype=det.dtype) | ||
|
|
||
| if new_det is None: | ||
| return None | ||
|
|
||
| copy_stack_trace(det, new_det) | ||
| return [new_det] | ||
|
|
||
|
|
||
| @register_canonicalize("shape_unsafe") | ||
| @register_stabilize("shape_unsafe") | ||
| @node_rewriter([det]) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any Op that maps (-1, 1) to the same value is actually fine. At the very least it should include square as well