Extend ExtractDiag gradient and numba implementation to higher dimensional inputs #389
Conversation

ricardoV94 commented Jul 18, 2023 (edited)
- Extend ExtractDiag gradient to all cases
- Fixes wrong gradient for negative offsets
- Fix/extend ExtractDiag Numba implementation to all cases
- Deprecate AllocDiag Op and use symbolic graph instead
- Gets rid of wrong Numba and JAX implementations
- Extra: do not raise in JAX dispatch of Arange, when a non-constant input comes from a shape
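The "deprecate AllocDiag Op and use symbolic graph instead" item can be illustrated with a NumPy reference sketch. This is a hypothetical helper (`alloc_diag_reference` is not a PyTensor function); it only mirrors what the new symbolic graph computes: allocate zeros, write the diagonal with a subtensor assignment, and move the two new axes into position.

```python
import numpy as np

def alloc_diag_reference(diag, offset=0, axis1=0, axis2=1):
    # Hypothetical NumPy mirror of alloc_diag: the inverse of np.diagonal.
    diag = np.asarray(diag)
    k = diag.shape[-1] + abs(offset)
    out = np.zeros(diag.shape[:-1] + (k, k), dtype=diag.dtype)
    idx = np.arange(diag.shape[-1])
    if offset >= 0:
        # Positive offsets write above the main diagonal.
        out[..., idx, idx + offset] = diag
    else:
        # Negative offsets write below it (note the sign on the row index).
        out[..., idx - offset, idx] = diag
    # Move the two freshly allocated axes to the requested positions.
    return np.moveaxis(out, (-2, -1), (axis1, axis2))
```

For 1-d inputs with the default axes this agrees with `np.diag(x, k=offset)`, which is what the PyTensor graph built from `set_subtensor` and transposes generalizes to batched inputs.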
Force-pushed 9aa2137 to 9669a7d
Force-pushed 9669a7d to c357794
There seems to be a pre-existing bug with the gradient and offset. I'll try to push a fix in this PR as well.
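The offset bug mentioned above can be made concrete with a NumPy sketch (hypothetical helper `diag_grad`, not the PR's actual code): the gradient of summing an extracted diagonal is a matrix of ones scattered back onto the *same* diagonal, and a negative offset must flip which index the offset is applied to, otherwise the ones land on the mirrored diagonal.

```python
import numpy as np

def diag_grad(shape, offset):
    # Gradient of np.diagonal(x, offset).sum() w.r.t. x:
    # ones scattered back onto the diagonal that was read.
    g = np.zeros(shape)
    n = len(np.diagonal(g, offset=offset))
    idx = np.arange(n)
    if offset >= 0:
        g[idx, idx + offset] = 1.0   # above the main diagonal
    else:
        g[idx - offset, idx] = 1.0   # below it; wrong sign here is the bug
    return g
```

Dropping the sign handling (always using `idx, idx + offset`) would index with negative offsets and silently place the gradient on the wrong diagonal, which is the kind of failure the comment refers to.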
Force-pushed c357794 to 7bafc53
Ah, hold your horses, the numba implementation of …
Force-pushed 7bafc53 to 49f7c24
I deprecated … The old implementation should be slower (no C impl) and was wrong for the Numba/JAX backends, so I am happy to merge as is.
Force-pushed 750e967 to 5c39aca
Force-pushed 5c39aca to a2d34df
Codecov Report

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #389      +/-   ##
==========================================
- Coverage   80.42%   80.36%    -0.06%
==========================================
  Files         156      156
  Lines       45483    45541      +58
  Branches    11139    11151      +12
==========================================
+ Hits        36579    36599      +20
- Misses       6703     6741      +38
  Partials     2201     2201
Force-pushed a72fb9e to 004c8e7
@@ -3605,12 +3623,49 @@ def __setstate__(self, state):
        self.axis2 = 1


def alloc_diag(diag, offset=0, axis1=0, axis2=1):
Why is the high-level alloc_diag removed? Could it be an OpFromGraph instead? Could it cause less efficient rewrites and more complicated graphs?
The reason was twofold:
- The JAX/Numba implementations were incorrect
- It is a relatively simple operation (set_subtensor and transpose) that should be efficient enough (and, if anything, benefit from the rewrites we have in PyTensor, not be hurt by them). In the C backend this should certainly be faster because it avoids the Python switch overhead.

Making it a symbolic graph solves both problems. We can consider making it an OpFromGraph subclass if we see signs that this is slower. The dprint graph will be more complex for sure, but otherwise you are just hiding the computational graph inside an Op implementation anyway.
We might also want to dispatch more specific implementations in other backends, but this shouldn't count as a regression because the pre-existing ones were simply wrong.
Can optimizers see inside OpFromGraph? I'd say cleaner dprints are desirable in general, all else equal.
They can see inside. They can also be inlined with inline=True so that rewrites can work across the boundary.
The dprint will move that complexity to inner graphs; whether that's easier or not I don't know.
There are, however, some issues with gradients of OpFromGraph (#1), but we can definitely check if it works for this case. I am more inclined not to use it though; the graph isn't really that bad.
I don't think I'm qualified to review this
Tests seem quite exhaustive, so I don't think there will be any more bugs :D
I left some small comments. I'd say the early warning on not having gradients would be the closest thing to a real complaint from me; curious what you think.
Also, why is the change to the jaxified arange also in this PR? I think you said in the last meeting, but I forgot.
Check the last two commits. When we got rid of the blackbox …
Force-pushed 4e527d4 to 80a2212
Force-pushed 80a2212 to 0507287
Force-pushed 0507287 to 9c5a2df
Looks good to go
    (1, 2),
    (2, 1),
    (0, 2),
    (2, 0),
(0, 2) and (2, 0) are the cases that would previously have raised, correct?
Before this PR, any input above 2 dimensions would have raised. The cases you listed weren't supported in the previous iteration of this PR.
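What "higher dimensional inputs with arbitrary axis pairs" means can be checked against NumPy, whose np.diagonal already supports axis1/axis2 on N-d arrays and appends the extracted diagonal as the last axis. This is only an illustration of the semantics the extended ExtractDiag matches, not the PR's code:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# Extract the diagonal over axes (0, 2) of a 3-d input.
# Remaining axis (length 3) stays; the diagonal (length min(2, 4) = 2)
# is appended as the last axis, so the result has shape (3, 2).
d = np.diagonal(x, axis1=0, axis2=2)

# Axis pairs like (0, 2) or (2, 0) on >2-d inputs are exactly the
# cases that used to raise; offsets mirror when the axes are swapped.
d_swapped = np.diagonal(x, axis1=2, axis2=0)
```

With offset == 0 the swapped axis pair reads the same elements, but a positive offset on (0, 2) selects the diagonal that a negative offset selects on (2, 0), which is why the test matrix above includes both orderings.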