Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove incorrect solve usage in psd_solve_with_chol rewrite #575

Merged
merged 4 commits into from
Jan 7, 2024

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jan 5, 2024

Description

To invert a positive semi-definite matrix, it is faster to first cholesky factorize the matrix then use two triangular solves rather than directly using solve. Here's a lazy benchmark:

import numpy as np
from scipy import linalg

N = 10_000
Z = np.random.normal(size=(N, N))
X = Z @ Z.T

def direct_solve(X, N):
   return linalg.solve(X, np.eye(N), assume_a='pos')

def tri_solve(X, N):
   L = linalg.cholesky(X, lower=True)
   b = np.eye(N)
   Li_b = linalg.solve_triangular(L, b, lower=True)
   x = linalg.solve_triangular(L.T, Li_b, lower=False)
   return x

direct_time = %timeit -o direct_solve(X, N)
tri_time = %timeit -o tri_solve(X, N)

12.7 s ± 369 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
11.5 s ± 165 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

The triangular method is about 1s faster. I've read it also more numerically stable too, but I don't know how to demonstrate that.

Anyway, there's a rewrite to change pt.linalg.solve to something like the tri_solve function above if the A matrix has a "psd" tag. This rewrite is actually useless -- as far as I can tell, the psd is never used in the code base. But it is potentially useful in light of #573. It's also currently wrong: as pointed out in #382, solve(tri_A, b, assume_a='sym', lower=True) does NOT use the correct algorithm to solve a triangular matrix, and results in an incorrect computation.

This PR corrects the rewrite to use the correct computations, but does nothing else to make it used anywhere in the code base.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@jessegrabowski jessegrabowski added bug Something isn't working graph rewriting linalg Linear algebra labels Jan 5, 2024
@jessegrabowski
Copy link
Member Author

The rewrite also doesn't catch graphs that use pt.linalg.inv(X_psd), I guess it's a problem with the ordering of rewrites? Not sure how to adjust that.

@codecov-commenter
Copy link

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (e180927) 80.92% compared to head (2b823ee) 80.93%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #575      +/-   ##
==========================================
+ Coverage   80.92%   80.93%   +0.01%     
==========================================
  Files         162      162              
  Lines       46524    46524              
  Branches    11375    11375              
==========================================
+ Hits        37648    37653       +5     
+ Misses       6653     6649       -4     
+ Partials     2223     2222       -1     
Files Coverage Δ
pytensor/tensor/rewriting/linalg.py 87.16% <100.00%> (+3.37%) ⬆️

@jessegrabowski jessegrabowski merged commit 96f753b into pymc-devs:main Jan 7, 2024
53 checks passed
@jessegrabowski jessegrabowski deleted the fix-psd-solve branch January 7, 2024 00:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working graph rewriting linalg Linear algebra
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants