
Add numba overload for solve_triangular #423

Merged
merged 18 commits into pymc-devs:main on Sep 24, 2023

Conversation

jessegrabowski
Member

Motivation for these changes

The pytensor.tensor.slinalg module is not currently compatible with mode="NUMBA". This PR is a first step in an effort to fix that. It's marked as a draft because it 1) is not done, and 2) needs discussion/work.

Functions in slinalg don't have overloads in numba.np.linalg, so to implement these functions there needs to be an overload that calls the relevant C LAPACK functions. This involves some acrobatics with C pointers and typing, which I am absolutely not an expert at. Currently, I use dynamic pointers from ctypes, essentially just following numba/numba#5301. This works, but it means the resulting functions can't be cached, which will be a huge slowdown on complex graphs (I think).
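The pointer mechanics that numba/numba#5301 builds on can be sketched without any numba machinery: SciPy exposes each LAPACK routine as a PyCapsule in `scipy.linalg.cython_lapack.__pyx_capi__`, and a `ctypes.CFUNCTYPE` can wrap the raw address. A rough sketch (assumes the standard LAPACK build with 32-bit integers that SciPy wheels ship):

```python
import ctypes

import numpy as np
from scipy.linalg import cython_lapack

# SciPy exposes dtrtrs (triangular solve) as a PyCapsule
capsule = cython_lapack.__pyx_capi__["dtrtrs"]

ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
ctypes.pythonapi.PyCapsule_GetName.argtypes = [ctypes.py_object]
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]

addr = ctypes.pythonapi.PyCapsule_GetPointer(
    capsule, ctypes.pythonapi.PyCapsule_GetName(capsule)
)

_int_p = ctypes.POINTER(ctypes.c_int)
_dbl_p = ctypes.POINTER(ctypes.c_double)
# dtrtrs(uplo, trans, diag, n, nrhs, a, lda, b, ldb, info) -- everything by pointer
dtrtrs = ctypes.CFUNCTYPE(
    None,
    ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p,
    _int_p, _int_p, _dbl_p, _int_p, _dbl_p, _int_p, _int_p,
)(addr)

A = np.asfortranarray([[2.0, 0.0], [1.0, 3.0]])
B = np.asfortranarray([[2.0], [7.0]])  # trtrs wants B at least 2d
n, nrhs, info = ctypes.c_int(2), ctypes.c_int(1), ctypes.c_int(0)
dtrtrs(
    b"L", b"N", b"N",                        # lower, no transpose, non-unit diagonal
    ctypes.byref(n), ctypes.byref(nrhs),
    A.ctypes.data_as(_dbl_p), ctypes.byref(n),
    B.ctypes.data_as(_dbl_p), ctypes.byref(n),  # solution is written into B
    ctypes.byref(info),
)
```

After the call, `B` holds the solution of the lower-triangular system in place. Because the function pointer is obtained at runtime, anything jitted against it is exactly what numba cannot cache.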

A more complete approach would try to directly extend numba/numba/_lapack.c with some new pointers to the relevant scipy code. I'm not sure if it would be possible to have our own e.g. _lapack_extensions.c that could have #include "_lapack.c" at the top? The pattern in that file looks straightforward enough to copy, but it's been a long time since I did anything in C, and I'm not sure how importing across modules would work.

Also, to answer "why solve_triangular?": because it's a function we don't have now that depends on only a single LAPACK call. Once the pattern is ironed out, I'll do this for all the functions we currently have in slinalg, most importantly solve (yes, we have the np.linalg.solve overload, but it doesn't allow access to the specialized solvers for e.g. symmetric positive definite matrices, which matters a lot for PyMC).
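For reference, scipy already routes to these specialized LAPACK drivers via the `assume_a` argument of `scipy.linalg.solve`; this is the choice a numba overload would need to expose:

```python
import numpy as np
from scipy import linalg

# A symmetric positive definite system: assume_a="pos" routes scipy to the
# Cholesky-based ?posv driver instead of the generic LU (?gesv) path.
A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])

x_generic = np.linalg.solve(A, b)             # LU factorization
x_cholesky = linalg.solve(A, b, assume_a="pos")  # Cholesky factorization
```

Both return the same solution, but the Cholesky route does roughly half the work for SPD matrices, which is why access to it matters for PyMC.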

Implementation details

I followed the implementation of LAPACK overloads established in numba's numba/np/linalg.py. There's a class called _LAPACK that holds signatures for all the LAPACK functions to be implemented, then an overload function for each user-facing routine.

Checklist

Major / Breaking Changes

None

New features

Solve triangular matrices with numba!

Bugfixes

None

Documentation

Not yet

Maintenance

None

@codecov-commenter

codecov-commenter commented Aug 28, 2023

Codecov Report

Merging #423 (42b5b5f) into main (071eadd) will decrease coverage by 0.12%.
The diff coverage is 49.10%.


@@            Coverage Diff             @@
##             main     #423      +/-   ##
==========================================
- Coverage   80.75%   80.64%   -0.12%     
==========================================
  Files         159      160       +1     
  Lines       45849    46016     +167     
  Branches    11234    11263      +29     
==========================================
+ Hits        37026    37108      +82     
- Misses       6595     6671      +76     
- Partials     2228     2237       +9     
Files Changed                               Coverage Δ
pytensor/link/numba/dispatch/slinalg.py     48.79% <48.79%> (ø)
pytensor/link/numba/dispatch/__init__.py    100.00% <100.00%> (ø)

@ricardoV94 ricardoV94 added numba backend compatibility linalg Linear algebra enhancement New feature or request labels Aug 29, 2023
Add informative message to error raised by check_finite=True
@ricardoV94
Member

Is this ready for review or something important still missing?

@jessegrabowski
Member Author

jessegrabowski commented Aug 29, 2023

I'm still not 100% sold on how it's all implemented. I wanted someone to take a closer look at _lapack.c in the numba library and decide if we can do it more like that, or if this hackish way is acceptable.

@jessegrabowski jessegrabowski marked this pull request as ready for review August 29, 2023 14:47
@aseyboldt aseyboldt (Member) left a comment

I think there are maybe a few ways to make this a bit faster, but it looks good to me as it is. I'm not really sure why it would feel hackish? The only downside I can think of compared to compiling a separate extension module is that numba can't cache this due to the dynamic pointer.

pytensor/link/numba/dispatch/slinalg.py (outdated)

# Need to expand B here; I tried everywhere else and it doesn't work
if B_is_1d:
B_copy = _copy_to_fortran_order(np.expand_dims(B, -1))
Member

If the original B was 1d, I don't think we need the copy?

Member Author

In my testing, trtrs expects at least 2d everything. The docs say LDB >= 0, but when I was giving it 1d arrays I was getting back numerically incorrect results.

Member Author

After testing, you're right. I wasn't able to avoid the copy in the 2d case though. If I don't copy 2d B, numba flags this line:
B_NDIM = 1 if B_is_1d else int(B.shape[1])

Saying that it's considering a case where B is 3d. Not sure why it thinks that is possible. Does numba evaluate all if-else branches on all possible inputs?

if A.shape[0] != B.shape[0]:
raise linalg.LinAlgError("Dimensions of A and B do not conform")

A_copy = _copy_to_fortran_order(A)
Member

Can we avoid the copy if it is C-order by flipping the trans value? I think we could also have a special overload for when trans, lower and unit_diag are literals, and we statically know that A and B are C- or Fortran-contiguous.
I think that would really be only an optimization of the current code though; this here should be fine as well.

Member Author

Does setting an array to fortran contiguous actually transpose the matrix, or does it just re-order the pointers to the internal flat representation?

Member Author

After testing, we can avoid copying A in all cases.

Re: the other point, do you mean checking the values of trans, lower, and unit_diag inside the wrapper function, then returning a specialized impl function based on their values? Similar to how I'm doing dispatching to real/complex versions here?
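(On the transpose question: for numpy arrays, `.T` of a C-contiguous matrix is a Fortran-contiguous view over the same memory; only the strides change, no data moves, which is why flipping trans can substitute for a copy:)

```python
import numpy as np

A = np.arange(6, dtype=np.float64).reshape(2, 3)  # C-contiguous

# .T swaps the strides only: no data is moved or copied, and the result is
# Fortran-contiguous, i.e. exactly what LAPACK expects for the transpose
At = A.T
```

So passing a C-order `A` with the trans flag flipped is equivalent to passing a Fortran-order copy of `A`.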

@jessegrabowski
Member Author

It feels hackish because 1) we can't cache the functions (relevant for compile times, which you've pointed out are extremely long with numba), 2) we can't support complex inputs due to a weird technical reason rather than anything principled/fundamental, and 3) it's nowhere close to working within the "official" numba API, so I have no idea how future-proof it is. Complex inputs definitely worked last year, so something changed in the numba codebase to break that A_copy.view(w_type).ctypes work-around. No idea what else might change and break this code down the road.

Don't copy B matrix when B is array in overload func
@ricardoV94
Member

Some conflicts have cropped up

@jessegrabowski
Member Author

What do you mean?

@ricardoV94
Member

[screenshot of the merge-conflict notice]

@jessegrabowski
Member Author

Still says no conflicts for me. I'll update my fork and double check everything.

@ricardoV94
Member

ricardoV94 commented Sep 1, 2023

Still says no conflicts for me. I'll update my fork and double check everything.

Switch the "Squash and merge" button to "Rebase and merge", and you should see it.

@ricardoV94 ricardoV94 marked this pull request as draft September 1, 2023 16:31
@ricardoV94
Member

Marked as a draft, given the suggestion to upper-pin numba. Feel free to convert back to ready for review if you change your mind about it or once the pin is done!

@ricardoV94
Member

ricardoV94 commented Sep 1, 2023

@maresb is there a way to set an upper version limit on an optional dependency (that will also be respected by conda)?

@maresb
Contributor

maresb commented Sep 1, 2023

Sure, just add it under run_constrained in the feedstock. https://docs.conda.io/projects/conda-build/en/stable/resources/define-metadata.html#run-constrained

@jessegrabowski jessegrabowski marked this pull request as ready for review September 24, 2023 12:35
@jessegrabowski
Member Author

Following conversation with @ricardoV94 I'm merging this. We'll cross the bridge of code breaking when/if we get there.

@jessegrabowski jessegrabowski merged commit 2d94407 into pymc-devs:main Sep 24, 2023
53 checks passed
@jessegrabowski jessegrabowski deleted the numba_slinalg branch September 24, 2023 12:39