
Implement destructive in-place rewrites for Blockwise Ops #577

Open
wants to merge 5 commits into base: main

Conversation

@jessegrabowski (Member) commented Jan 6, 2024

Description

Many scipy.linalg functions offer an overwrite_a and/or overwrite_b argument that can improve performance by reusing the input memory for the outputs. This PR implements rewrites that set these flags to True at compile time.

These rewrites are also a nice example for the in-place docs here, so I'll update them with an example in a later commit to this PR.
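For context, this is what the flag buys at the SciPy level (a minimal sketch; True is the expected result for an F-contiguous float64 input on a typical SciPy build, see the contiguity discussion further down):

import numpy as np
import scipy.linalg

a = np.asfortranarray(4.0 * np.eye(3))  # F-contiguous, positive definite
chol = scipy.linalg.cholesky(a, lower=True, overwrite_a=True)
print(np.shares_memory(chol, a))  # expected True: the input buffer was reused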

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94 (Member) left a comment

Great! I think we made our lives easier going forward

@ricardoV94 (Member)

Can the Numba Cholesky make use of overwrite?

@jessegrabowski (Member, Author) commented Jan 6, 2024

I think so? I added a check that decides whether to copy the input matrix. I need to test it more carefully to make sure it does what I want it to do.
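Something along these lines, presumably (a hypothetical sketch of the check, with np.linalg.cholesky standing in for the LAPACK call the numba implementation actually makes):

import numpy as np

def cholesky_maybe_inplace(a, overwrite_a=False):
    # Copy the input only when we are not allowed to destroy it
    a_work = a if overwrite_a else a.copy()
    a_work[:] = np.linalg.cholesky(a_work)
    return a_work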

@ricardoV94 (Member)

Numba seems to be computing in integers? Anyway, I guess this PR now depends on #578?

@jessegrabowski (Member, Author) commented Jan 7, 2024

I had removed a .astype(input.dtype) from the output of Cholesky().perform, which made the integer test fail. Everything should pass now. #578 is just a general numba speedup; the two shouldn't clash (I hope).

@codecov-commenter commented Jan 7, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (e180927) 80.92% compared to head (470ea60) 80.93%.
Report is 7 commits behind head on main.


@@            Coverage Diff             @@
##             main     #577      +/-   ##
==========================================
+ Coverage   80.92%   80.93%   +0.01%     
==========================================
  Files         162      162              
  Lines       46524    46661     +137     
  Branches    11375    11403      +28     
==========================================
+ Hits        37648    37766     +118     
- Misses       6653     6667      +14     
- Partials     2223     2228       +5     
Files Coverage Δ
pytensor/link/numba/dispatch/basic.py 86.18% <ø> (ø)
pytensor/tensor/rewriting/basic.py 94.13% <100.00%> (+0.08%) ⬆️
pytensor/tensor/rewriting/linalg.py 87.66% <100.00%> (+3.87%) ⬆️
pytensor/tensor/slinalg.py 93.79% <100.00%> (-0.85%) ⬇️

... and 13 files with indirect coverage changes

@ricardoV94 (Member)

If you are a pro: each commit corresponds to a relevant logical change

You marked that, but the commits seem "dirty". Do you mind if I squash merge?

@jessegrabowski (Member, Author)

I don't mind at all, but I'd rather you explained it to me so I can do it right next time.

I read the link, but it suggested I write War and Peace in each commit message.

@ricardoV94 (Member) commented Jan 8, 2024

I read the link, but it suggested I write War and Peace in each commit message.

Fair enough. The biggest point is that each commit should be a self-contained logical change, so that you could in theory revert or check out any of them and the codebase would still make sense as is. Usually stuff like "fix typo", "run pre-commit", "add suggestions from review", and intermediate changes that were ultimately not needed disappear after "cleaning".

@jessegrabowski (Member, Author)

OK let me take a stab at cleaning it up

@ricardoV94 (Member) commented Jan 8, 2024

OK let me take a stab at cleaning it up

Advice: back up the branch before you try :)

@jessegrabowski (Member, Author)

I just did an interactive rebase and squashed things together. Is that all I needed to do?

Refactor cholesky destructive re-write to use `make_inplace` helper
@bwengals commented Jan 9, 2024

Sorry for the delay, y'all are fast! Taking a look now.

@@ -135,6 +135,18 @@ def alloc_like(
    return rval


def make_inplace(node, inplace_prop="inplace"):

Having a hard time understanding this function. What it looks like it does is:

  • checks if the operator is wrapped in Blockwise, pulls it out
  • if props["overwrite_a"] is already true, returns false; otherwise sets it to true (this step is the confusing one to me)
  • reconstructs the op with "overwrite_a" true
  • returns an Apply node for the reconstructed op
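A minimal reconstruction matching that reading (a sketch: _props_dict() is PyTensor's standard props helper, but this is not necessarily the PR's exact code):

def make_inplace(node, inplace_prop="inplace"):
    # Unwrap the core op if this node's op is a Blockwise
    op = getattr(node.op, "core_op", node.op)
    props = op._props_dict()
    if props[inplace_prop]:
        # Already destructive: returning False tells the rewriter
        # there is nothing to replace
        return False
    # Rebuild the op with the overwrite flag enabled and re-apply it
    props[inplace_prop] = True
    inplace_op = type(op)(**props)
    return inplace_op.make_node(*node.inputs).outputs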

@jessegrabowski (Member, Author)

You got it. I think what's confusing (for me) is that the False returned when props["overwrite_a"] is already True is for the rewriter. I guess rewrites either return a new symbolic variable (if the rewrite is deemed necessary) or False if it's deemed unnecessary.

@ricardoV94 (Member)

Yup. Return a list/dict of replacement variables, or None/False to indicate the rewrite doesn't apply.
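For illustration, a toy node rewriter following that convention (a sketch using PyTensor's node_rewriter decorator and the overwrite_a prop this PR adds; not code from the PR):

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.slinalg import Cholesky


@node_rewriter([Cholesky])
def local_inplace_cholesky(fgraph, node):
    if node.op.overwrite_a:
        # Rewrite doesn't apply: signal with False (or None)
        return False
    # Rewrite applies: return the replacement variables
    new_op = Cholesky(lower=node.op.lower, overwrite_a=True)
    return new_op.make_node(*node.inputs).outputs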


Ah okay, makes sense now.

@bwengals commented Jan 9, 2024

Works great for me (running the test you wrote)! Are you planning on doing all of them in this PR? The solves, eigs, etc.?

try:
    z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)

Is it worth handling F_CONTIGUOUS too? Or is C_CONTIGUOUS way more common? Sorry if it's a dumb Q, I haven't fiddled with these flags in numpy before. Like, if it's not C-contiguous, would it be a bad idea to force it to be here?

@jessegrabowski (Member, Author)

C_CONTIGUOUS is the default for numpy. We have to set this flag because the actual Cholesky routine is written in Fortran, and it's really picky about its inputs. If they're not exactly right (including being in column-major order), the inputs get copied and re-formatted before they're passed along to the routine, and the in-place operation doesn't happen.
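A quick demonstration of that copy behaviour (a sketch; the printed values are the expected outcome on a typical SciPy build, not a guarantee):

import numpy as np
import scipy.linalg

a = np.random.rand(4, 4)
a = a @ a.T + 4 * np.eye(4)    # positive definite

a_c = np.ascontiguousarray(a)  # row-major (numpy default)
a_f = np.asfortranarray(a)     # column-major, what LAPACK wants

l_c = scipy.linalg.cholesky(a_c, lower=True, overwrite_a=True)
l_f = scipy.linalg.cholesky(a_f, lower=True, overwrite_a=True)

print(np.shares_memory(l_c, a_c))  # expected False: input was copied first
print(np.shares_memory(l_f, a_f))  # expected True: input memory was reused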

@ricardoV94 (Member)

That may explain a comment I saw recently lying around some other scipy Op, saying it didn't seem to be respecting the flag.

@jessegrabowski (Member, Author)

Does PyTensor make sure that all the arrays that get passed around always remain contiguous?

@ricardoV94 (Member)

No. There's an explicit Op for that. It's not guaranteed otherwise


No. There's an explicit Op for that. It's not guaranteed otherwise

An op that does the equivalent of np.asfortranarray?

Found the same thing with Fortran / C-contiguous arrays for sp.linalg.solve, btw: overwrite_a only overwrites if A is F-contiguous. Looks like this happens at the LAPACK call to getrf that does an LU decomposition, which must be Fortran-ordered. When overwrite_b=True, b is always overwritten whether A is F- or C-contiguous.
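A small check of that observation (a sketch; the behaviour shown is what the comment reports, not a SciPy guarantee):

import numpy as np
import scipy.linalg

a = np.random.rand(3, 3) + 3 * np.eye(3)  # C-contiguous by default
b = np.random.rand(3)

a_c = a.copy()                            # C-ordered working copy
scipy.linalg.solve(a_c, b, overwrite_a=True)
print(np.array_equal(a_c, a))             # expected True: C-ordered A was copied, not destroyed

a_f = np.asfortranarray(a)                # F-ordered working copy
scipy.linalg.solve(a_f, b, overwrite_a=True)
print(np.array_equal(a_f, a))             # expected False: F-ordered A now holds the LU factors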

@ricardoV94 (Member)

An op that does the equivalent of np.asfortranarray?

No, I meant we have an Op to make an array C-contiguous.

@@ -135,6 +135,18 @@ def alloc_like(
    return rval


def make_inplace(node, inplace_prop="inplace"):
    op = getattr(node.op, "core_op", node.op)
@ricardoV94 (Member)

I missed this. Making the core_op inplace isn't enough, because that falls outside the scope of the PyTensor safeguards.

In these cases the Blockwise itself needs to be inplace as well. I forgot Cholesky was Blockwised by default now.

@jessegrabowski (Member, Author)

Do I need to do both, or just the outer Blockwise?

@ricardoV94 (Member) commented Jan 9, 2024

Need to just think a bit about the API for Blockwise inplace. I would restrict the rewrite to the core Op case for now.

@ricardoV94 (Member)

When a Blockwise is not needed it will be removed, so the inplace is still useful, just not for the batched case.

@jessegrabowski (Member, Author)

So I need to add a check to the rewrite that the Op is not a Blockwise?

@jessegrabowski (Member, Author)

Why wouldn't a blockwise solve(A, b, overwrite_a=True) be safe?

I didn't read carefully enough; I understand now that you are enumerating the cases. I don't think the case logic is too bad though. Here's a table:

a_batched  b_batched  overwrite_a  overwrite_b  safe?
FALSE      FALSE      FALSE        FALSE        Yes
FALSE      FALSE      TRUE         FALSE        Yes
FALSE      FALSE      FALSE        TRUE         Yes
FALSE      FALSE      TRUE         TRUE         Yes
TRUE       FALSE      FALSE        FALSE        Yes
TRUE       FALSE      TRUE         FALSE        Yes
TRUE       FALSE      FALSE        TRUE         No
TRUE       FALSE      TRUE         TRUE         No
FALSE      TRUE       FALSE        FALSE        Yes
FALSE      TRUE       TRUE         FALSE        No
FALSE      TRUE       FALSE        TRUE         Yes
FALSE      TRUE       TRUE         TRUE         Yes
TRUE       TRUE       FALSE        FALSE        Yes
TRUE       TRUE       TRUE         FALSE        Yes
TRUE       TRUE       FALSE        TRUE         Yes
TRUE       TRUE       TRUE         TRUE         Yes

I think it boils down to "overwrite things that are already batched, refuse to overwrite if it's being broadcast"

We might also be able to be intelligent about which flags are set based on the output shapes, if they are available.
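A conservative sketch of that summary rule for a two-input op like solve (a hypothetical helper, not code from the PR; note it rejects the one mixed case the table above marks as safe):

def safe_overwrite_flags(a_batched: bool, b_batched: bool) -> dict:
    # An input that is broadcast against the other input's batch dimensions
    # is reused across batch iterations, so destroying it would corrupt
    # the later iterations.
    a_broadcast = b_batched and not a_batched
    b_broadcast = a_batched and not b_batched
    return {"overwrite_a": not a_broadcast, "overwrite_b": not b_broadcast}

For example, safe_overwrite_flags(True, False) allows overwrite_a but not overwrite_b, matching the TRUE/FALSE rows of the table.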

@ricardoV94 (Member)

Yeah, what matters is not the batching but the broadcasting (which is caused by the batching of the other inputs).

Since Blockwise is more general than linalg Ops, it may be worth thinking about a couple more Ops to see if there's a strategy that works reasonably well for most cases.

Also may be worth thinking a bit about the numba backend.

@ricardoV94 (Member)

I think it boils down to "overwrite things that are already batched, refuse to overwrite if it's being broadcast"

Yeah, and for that we need a minimal API so Blockwise can talk with the core Op, in a sense.

I think it's enough for Blockwise to say which inputs are safe to destroy, and the core Op can decide (but is not required) to destroy those, but no others.

I am hopeful that may not require too much code.

@jessegrabowski (Member, Author)

So basically what needs to happen is for Blockwise to get some kind of "safe_to_destroy_map" property that can be referenced by rewrites?

I do think it fits in this PR, but I'm not sure I'm the one to implement it. I can try, of course.


Do these issues only come up when Blockwise isn't compiled out? Would it be possible to punt on it till later, or make it a separate issue, and go with Jesse's current approach of applying the rewrite on the core_op? If inplacing only worked on non-batched matrices, I think that would satisfy the Pareto principle pretty well.

@ricardoV94 changed the title from "Implement destructive in-place rewrites for linear algebra functions" to "Implement destructive in-place rewrites for Blockwise Ops" on Jan 19, 2024
Comment on lines +365 to +370
for candidate_input in candidate_inputs:
    # Note: It may still be working in place and not be detectable by this check
    assert not np.allclose(
        batch_inputs_test_val[candidate_input],
        batch_inputs_test_val_copy[candidate_input],
    )
@ricardoV94 (Member)

We should test that the other inputs haven't changed.
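A sketch of that extra check, mirroring the snippet above (the variable names come from that snippet; the loop bound is an assumption):

for other_input in range(len(batch_inputs_test_val)):
    if other_input in candidate_inputs:
        continue
    # Non-candidate inputs must be left untouched by the inplace op
    np.testing.assert_allclose(
        batch_inputs_test_val[other_input],
        batch_inputs_test_val_copy[other_input],
    )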

@@ -131,6 +146,12 @@ def conjugate_solve_triangular(outer, inner):
        else:
            return [grad]

    def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
        if candidate_inputs == [0]:
@ricardoV94 (Member)

As per the signature this should not return None. Maybe that should be the default? Or return self (so it's still an Op)?

In the rewrite we're not expecting None to be returned; we only handle the NotImplementedError. I think there would be an error if None is returned, since we next try to access the .destroy_map property.
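For illustration, a self-contained sketch of the contract being discussed (all names hypothetical except try_inplace_inputs and destroy_map):

class ExampleCoreOp:
    # Hypothetical core Op that can only destroy input 0
    destroy_map: dict = {}

    def try_inplace_inputs(self, candidate_inputs):
        if candidate_inputs == [0]:
            new_op = ExampleCoreOp()
            new_op.destroy_map = {0: [0]}  # output 0 destroys input 0
            return new_op
        # Raise rather than return None: the rewrite only handles this
        raise NotImplementedError("Cannot destroy any candidate input")


def toy_rewrite(core_op, candidate_inputs):
    try:
        inplace_op = core_op.try_inplace_inputs(candidate_inputs)
    except NotImplementedError:
        return None  # rewrite doesn't apply
    # A None return from try_inplace_inputs would crash on the next line
    return inplace_op.destroy_map


print(toy_rewrite(ExampleCoreOp(), [0]))  # {0: [0]}
print(toy_rewrite(ExampleCoreOp(), [1]))  # None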

Development

Successfully merging this pull request may close these issues.

ENH: Add destructive rewrites for Cholesky Op
4 participants