Implement destructive in-place rewrites for Blockwise Ops #577
base: main
Conversation
Great! I think we made our lives easier going forward
Can the Numba Cholesky make use of overwrite?
I think so? I added a check to copy the input matrix or not. I need to test it more carefully to make sure it does what I want it to do.
Numba seems to be computing in integer? Anyway I guess this PR now depends on #578?
I had removed a
Codecov Report: All modified and coverable lines are covered by tests ✅
Additional details and impacted files:
@@ Coverage Diff @@
## main #577 +/- ##
==========================================
+ Coverage 80.92% 80.93% +0.01%
==========================================
Files 162 162
Lines 46524 46661 +137
Branches 11375 11403 +28
==========================================
+ Hits 37648 37766 +118
- Misses 6653 6667 +14
- Partials 2223 2228 +5
You marked that, but the commits seem "dirty". Do you mind if I squash merge?
I don't mind at all, but I'd rather you explained it to me so I can do it right next time. I read the link, but it suggested I write War and Peace in each commit message.
Fair enough. The biggest point is that each commit should be a self-contained logical change, so that you could in theory revert or check out any of them and the codebase would still make sense as is. Usually stuff like "fix typo", "run pre-commit", "add suggestions from review", and intermediate changes that were ultimately not needed disappear after "cleaning".
OK, let me take a stab at cleaning it up.
Advice: back up the branch before you try :)
Force-pushed from 6d93813 to ff364f1
I just did an interactive rebase and squashed things together. Is that all I needed to do?
Refactor cholesky destructive re-write to use `make_inplace` helper
Force-pushed from ff364f1 to 470ea60
Sorry for the delay, y'all are fast! Taking a look now.
pytensor/tensor/rewriting/basic.py (Outdated)
@@ -135,6 +135,18 @@ def alloc_like(
    return rval


def make_inplace(node, inplace_prop="inplace"):
Having a hard time understanding this function. What it looks like it does is:
- checks if the operator is wrapped in Blockwise, and pulls it out
- if props["overwrite_a"] is already true, returns false; otherwise sets it to true (mostly this step is the confusing one to me)
- reconstructs the op with "overwrite_a" true
- returns an Apply node for the reconstructed op
You got it. I think what's confusing (for me) is that the False that's returned if props["overwrite_a"] is already True is for the rewriter. I guess rewrites either return a new symbolic variable (if the rewrite is deemed necessary) or False if it's deemed unnecessary.
Yup. Return a list/dict of replacement variables, or None/False to indicate the rewrite doesn't apply.
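The return convention can be sketched with a toy stand-in (FakeCholesky and make_inplace_sketch are hypothetical names for illustration, not PyTensor's actual API):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class FakeCholesky:
    # Toy stand-in for an Op carrying an overwrite flag; not the real class.
    lower: bool = True
    overwrite_a: bool = False

def make_inplace_sketch(op):
    """Return a rebuilt op with overwrite_a=True, or False when the
    rewrite doesn't apply (flag already set), mirroring the convention
    described above."""
    if op.overwrite_a:
        return False          # signal to the rewriter: nothing to do
    return replace(op, overwrite_a=True)

new_op = make_inplace_sketch(FakeCholesky())
print(new_op.overwrite_a)                  # True
print(make_inplace_sketch(new_op))         # False: already in-place
```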
ah okay makes sense now
Works great for me (running the test you wrote)! Are you planning on doing all of them in this PR? The solves, eigs, etc.?
pytensor/tensor/slinalg.py (Outdated)
try:
    z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
Is it worth handling F_CONTIGUOUS too? Or is C_CONTIGUOUS way more common? Sorry if it's a dumb Q, I haven't fiddled with these flags in numpy before. Like, if it's not c_contiguous, would it be a bad idea to force it to be here?
C_CONTIGUOUS is the default for numpy. We have to set this flag because the actual Cholesky routine is written in Fortran, and it's really picky about its inputs. If they're not exactly right (including being in column-major order), then the inputs get copied and re-formatted before they're passed along to the routine, and the in-place operation doesn't happen.
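A quick numpy sanity check of the two layout flags (illustrative, not code from this PR):

```python
import numpy as np

a = np.arange(6.0).reshape(2, 3)   # numpy allocates C (row-major) by default
print(a.flags["C_CONTIGUOUS"])     # True
print(a.flags["F_CONTIGUOUS"])     # False

f = np.asfortranarray(a)           # column-major copy, the layout LAPACK wants
print(f.flags["F_CONTIGUOUS"])     # True
print(np.shares_memory(a, f))      # False: the conversion had to copy
```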
That may explain a comment I saw recently lying around some other scipy Op, saying it didn't seem to be respecting the flag.
Does Pytensor make sure that all the arrays that get passed around always remain contiguous?
No. There's an explicit Op for that. It's not guaranteed otherwise
> No. There's an explicit Op for that. It's not guaranteed otherwise

An op that does the equivalent of np.asfortranarray?
Found the same thing with fortran / c-contiguous arrays for sp.linalg.solve, btw. overwrite_a only overwrites if A is f-contiguous. Looks like this happens at the lapack call to getrf that does an LU decomposition, which must be fortran. When overwrite_b=True, b is always overwritten whether A is f- or c-contiguous.
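One way to probe the overwrite behaviour empirically (a sketch assuming scipy is installed; the matrices are made up for illustration, and whether the input is actually clobbered is deliberately printed rather than asserted, since it depends on layout and LAPACK version):

```python
import numpy as np
import scipy.linalg

a = np.asfortranarray([[4.0, 2.0], [2.0, 3.0]])  # f-contiguous, per the comment above
b = np.array([1.0, 2.0])
a_saved = a.copy()

x = scipy.linalg.solve(a, b, overwrite_a=True)
# Comparing against the saved copy reveals whether `a` was reused as scratch.
print("a modified in place:", not np.allclose(a, a_saved))
print("solution correct:", np.allclose(a_saved @ x, b))  # True regardless
```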
> An op that does the equivalent of np.asfortranarray?

No, I meant we have an Op to make an array c-contiguous.
pytensor/tensor/rewriting/basic.py (Outdated)
@@ -135,6 +135,18 @@ def alloc_like(
    return rval


def make_inplace(node, inplace_prop="inplace"):
    op = getattr(node.op, "core_op", node.op)
I missed this. Making the core_op inplace isn't enough, because that falls outside the scope of the pytensor safeguards.
In these cases the Blockwise itself needs to be inplace as well. I forgot Cholesky was Blockwised by default now.
I need to do both, or just the outer Blockwise?
Need to just think a bit about the API for Blockwise inplace. I would restrict the rewrite to the core Op case for now.
When a Blockwise is not needed it will be removed so the inplace is still useful, just not for the batched case
So I need to add a check to the rewrite that the Op is not a Blockwise?
Why wouldn't a blockwise solve(A, b, overwrite_a=True) be safe?
I didn't read carefully enough; I understand now that you are enumerating the cases. I don't think the case logic is too bad though. Here's a table:
| a_batched | b_batched | overwrite_a | overwrite_b | safe? |
|---|---|---|---|---|
| FALSE | FALSE | FALSE | FALSE | Yes |
| FALSE | FALSE | TRUE | FALSE | Yes |
| FALSE | FALSE | FALSE | TRUE | Yes |
| FALSE | FALSE | TRUE | TRUE | Yes |
| TRUE | FALSE | FALSE | FALSE | Yes |
| TRUE | FALSE | TRUE | FALSE | Yes |
| TRUE | FALSE | FALSE | TRUE | No |
| TRUE | FALSE | TRUE | TRUE | No |
| FALSE | TRUE | FALSE | FALSE | Yes |
| FALSE | TRUE | TRUE | FALSE | No |
| FALSE | TRUE | FALSE | TRUE | Yes |
| FALSE | TRUE | TRUE | TRUE | Yes |
| TRUE | TRUE | FALSE | FALSE | Yes |
| TRUE | TRUE | TRUE | FALSE | Yes |
| TRUE | TRUE | FALSE | TRUE | Yes |
| TRUE | TRUE | TRUE | TRUE | Yes |
I think it boils down to "overwrite things that are already batched, refuse to overwrite if it's being broadcast"
We might also be able to be intelligent about which flags are set based on the output shapes, if they are available.
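The summary rule ("refuse to overwrite an input that is being broadcast") can be written down directly. This is a hypothetical helper, not code from the PR, and it may disagree with a couple of corner cases in the table above:

```python
def overwrite_is_safe(input_batched: bool, other_batched: bool,
                      overwrite: bool) -> bool:
    """Overwriting an input is unsafe when that input is broadcast across
    batch dimensions contributed by another input: every core call would
    then write into the same underlying buffer."""
    if not overwrite:
        return True
    is_broadcast = other_batched and not input_batched
    return not is_broadcast

print(overwrite_is_safe(True, False, True))    # batched A, overwrite_a: safe
print(overwrite_is_safe(False, True, True))    # broadcast A, overwrite_a: unsafe
print(overwrite_is_safe(False, True, False))   # not overwriting: always safe
```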
Yeah what matters is not the batching but the broadcasting (which is caused by the batching of other inputs)
Since Blockwise is more general than linalg Ops, it may be worth thinking about a couple more Ops to see if there's a strategy that works reasonably well for most cases.
Also may be worth thinking a bit about the numba backend.
> I think it boils down to "overwrite things that are already batched, refuse to overwrite if it's being broadcast"
Yeah and for that we need a minimal API so Blockwise can talk with the core Op in a sense.
I think it's enough for Blockwise to say which inputs are safe to destroy, and the core Op can decide (but is not required) to destroy those, but no others.
I am hopeful that may not require too much code.
So basically what needs to happen is for blockwise to get some kind of "safe_to_destroy_map" property that can be referenced by rewrites?
I do think it fits in this PR but I'm not sure I'm the one to implement it. I can try of course.
Do these issues only come up when Blockwise isn't compiled out? Would it be possible to punt on it till later, or make it a separate issue, and go with Jesse's current approach of applying the rewrite on the core_op? If inplacing only worked on non-batched matrices, I think that would satisfy the Pareto principle pretty well.
Force-pushed from eed127f to bb90822
for candidate_input in candidate_inputs:
    # Note: It may still be working in place and not be detectable by this check
    assert not np.allclose(
        batch_inputs_test_val[candidate_input],
        batch_inputs_test_val_copy[candidate_input],
    )
We should test that the other inputs haven't changed.
Force-pushed from cdb78f6 to 8dc2a26
@@ -131,6 +146,12 @@ def conjugate_solve_triangular(outer, inner):
    else:
        return [grad]


def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
    if candidate_inputs == [0]:
As per the signature this should not return None. Maybe that should be the default? Or return self (so it's still an Op)?
In the rewrite we're not expecting None to be returned, only handling the NotImplementedError. I think there would be an error if None were returned, as we next try to access the .destroy_map property.
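One way to encode that suggestion (raise by default so the rewrite's NotImplementedError path fires, and never return None) is sketched below; the method name mirrors the diff, but the surrounding classes are invented scaffolding, not PyTensor code:

```python
class Op:
    def try_inplace_inputs(self, candidate_inputs: "list[int]") -> "Op":
        # Default: opt out by raising. Returning None would crash the
        # caller when it next accesses `.destroy_map` on the result.
        raise NotImplementedError(
            f"{type(self).__name__} does not support in-place rewrites"
        )

class Cholesky(Op):
    def __init__(self, overwrite_a=False):
        self.overwrite_a = overwrite_a

    def try_inplace_inputs(self, candidate_inputs):
        # Only the input matrix (position 0) can be destroyed.
        if candidate_inputs == [0]:
            return Cholesky(overwrite_a=True)
        raise NotImplementedError("can only destroy input 0")

print(Cholesky().try_inplace_inputs([0]).overwrite_a)  # True
```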
Description

All scipy.linalg functions offer an overwrite_a and/or overwrite_b argument that can enhance performance by re-using the input memory for the outputs. This PR implements rewrites that will set these flags to True at compile time.

These rewrites are also a nice example for the in-place docs here, so I'll update them with an example in a later commit to this PR.
Related Issue

Op #572

Checklist

Type of change