
Extend ExtractDiag gradient and numba implementation to higher dimensional inputs #389

Merged · 7 commits into pymc-devs:main from extract_diag_numba · Aug 24, 2023

Conversation

ricardoV94 (Member) commented Jul 18, 2023:

  • Extend ExtractDiag gradient to all cases
    • Fixes wrong gradient for negative offsets
  • Fix/extend ExtractDiag Numba implementation to all cases
  • Deprecate AllocDiag Op and use symbolic graph instead
    • Gets rid of wrong Numba and JAX implementations
    • Extra: do not raise in JAX dispatch of Arange when a non-constant input comes from a shape
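
A minimal sketch (not code from this PR) of the extended behavior described above: extracting a diagonal from a batched 3d input via `pt.diagonal` (which builds `ExtractDiag`) and differentiating through it.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")
# ExtractDiag on a batched input, with a non-default axis pair
diag = pt.diagonal(x, offset=1, axis1=1, axis2=2)
g = pytensor.grad(diag.sum(), x)

# mode="NUMBA" could be passed here to exercise the Numba backend instead
f = pytensor.function([x], [diag, g])

x_val = np.random.normal(size=(2, 4, 4))
diag_val, g_val = f(x_val)
np.testing.assert_allclose(diag_val, np.diagonal(x_val, offset=1, axis1=1, axis2=2))
```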

@ricardoV94 force-pushed the extract_diag_numba branch 3 times, most recently from 9aa2137 to 9669a7d on July 19, 2023 07:38
@ricardoV94 marked this pull request as draft August 9, 2023 09:41
ricardoV94 (Member, Author):

There seems to be a pre-existing bug with the gradient and offset. I'll try to push a fix in this PR as well.
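
A minimal sketch (not from the PR) of how such an offset bug can be surfaced: the gradient of a negative-offset diagonal should place ones exactly on that diagonal.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
g = pytensor.grad(pt.diagonal(x, offset=-1).sum(), x)
g_val = g.eval({x: np.zeros((4, 4))})

# ones at positions (1, 0), (2, 1), (3, 2): the offset=-1 diagonal
expected = np.zeros((4, 4))
expected[np.arange(1, 4), np.arange(3)] = 1.0
np.testing.assert_allclose(g_val, expected)
```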

@ricardoV94 added the bug and gradients labels Aug 9, 2023
@ricardoV94 changed the title from "Fully support ExtractDiag in numba" to "Extend ExtractDiag gradient and numba implementation to higher dimensional inputs" Aug 9, 2023
ricardoV94 (Member, Author):

Ah hold your horses, the Numba implementation of AllocDiag is also wrong.

ricardoV94 (Member, Author) commented Aug 9, 2023:

I deprecated AllocDiag in favor of a helper that returns the symbolic graph. The use of alloc_diag in ExtractDiag.grad can probably be improved to merge the two distinct set_subtensor operations that now exist.

The old implementation was likely slower (it had no C impl) and was wrong for the Numba/JAX backends, so I am happy to merge as is.

@ricardoV94 force-pushed the extract_diag_numba branch 6 times, most recently from 750e967 to 5c39aca on August 10, 2023 10:17
@ricardoV94 added the jax label Aug 10, 2023
@ricardoV94 marked this pull request as ready for review August 10, 2023 13:27
codecov-commenter:
Codecov Report

Merging #389 (a2d34df) into main (c6b0858) will decrease coverage by 0.06%.
The diff coverage is 75.00%.

Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #389      +/-   ##
==========================================
- Coverage   80.42%   80.36%   -0.06%     
==========================================
  Files         156      156              
  Lines       45483    45541      +58     
  Branches    11139    11151      +12     
==========================================
+ Hits        36579    36599      +20     
- Misses       6703     6741      +38     
  Partials     2201     2201              
Files Changed                                  Coverage Δ
pytensor/link/numba/dispatch/basic.py          87.23% <ø> (+0.18%) ⬆️
pytensor/link/numba/dispatch/tensor_basic.py   89.60% <61.76%> (-10.41%) ⬇️
pytensor/tensor/basic.py                       88.61% <79.54%> (-2.23%) ⬇️
pytensor/link/jax/dispatch/tensor_basic.py     91.89% <100.00%> (+4.50%) ⬆️

... and 2 files with indirect coverage changes

@ricardoV94 force-pushed the extract_diag_numba branch 2 times, most recently from a72fb9e to 004c8e7 on August 16, 2023 07:23
@@ -3605,12 +3623,49 @@ def __setstate__(self, state):
self.axis2 = 1


def alloc_diag(diag, offset=0, axis1=0, axis2=1):
Member:
Why is the high-level alloc_diag removed? Could it be an OpFromGraph instead? Could that cause less efficient rewrites and more complicated graphs?

ricardoV94 (Member, Author) replied Aug 16, 2023:

The reason was twofold:

  1. The JAX/Numba implementations were incorrect
  2. It is a relatively simple operation (set_subtensor and transpose) that should be efficient enough (and, if anything, benefit from rewrites we have in PyTensor, not be hurt by them). In the C backend this should certainly be faster because it avoids the Python switch overhead.

Making it a symbolic graph solves both problems. We can consider making it an OpFromGraph subclass if we see signs that this is slower. The dprint graph will be more complex for sure, but otherwise you are just hiding the computational graph inside an Op implementation anyway.

We might also want to dispatch more specific implementations in other backends, but this shouldn't count as a regression because the pre-existing ones were simply wrong.
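
A rough sketch of the approach, restricted to the 2d case and using a hypothetical helper name (the actual helper also handles batched inputs and arbitrary axis1/axis2):

```python
import pytensor.tensor as pt

def alloc_diag_2d(diag, offset=0):
    # Build the diagonal matrix symbolically (zeros + set_subtensor)
    # instead of through a blackbox AllocDiag Op.
    n = diag.shape[-1] + abs(offset)
    out = pt.zeros((n, n), dtype=diag.dtype)
    idx = pt.arange(diag.shape[-1])
    if offset >= 0:
        return pt.set_subtensor(out[idx, idx + offset], diag)
    return pt.set_subtensor(out[idx - offset, idx], diag)
```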

Member:
Can optimizers see inside OpFromGraph? I'd say cleaner dprints are desirable in general, all else equal.

ricardoV94 (Member, Author) replied Aug 23, 2023:
They can see inside. They can also be inlined with inline=True so that rewrites can work across the boundary.

The dprint will move that complexity to inner graphs; whether that's easier or not, I don't know.

There are, however, some issues with gradients of OpFromGraph #1, but we can definitely check if it works for this case. I am more inclined not to use it, though; the graph isn't really that bad.
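
For reference, a hedged sketch of the inline option under discussion (illustrative only, not code from this PR):

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
# inline=True expands the inner graph during compilation, so rewrites
# can work across the OpFromGraph boundary
double_exp = OpFromGraph([x], [2 * pt.exp(x)], inline=True)

y = pt.vector("y")
out = pt.log(double_exp(y))  # rewrites may now see log(2 * exp(y)) directly
```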

@ricardoV94 requested review from Armavica and michaelosthege and removed the request for michaelosthege August 18, 2023 15:37
michaelosthege (Member) left a comment:
I don't think I'm qualified to review this

jessegrabowski (Member) left a comment:
Tests seem quite exhaustive, so I don't think there will be any more bugs :D

I left some small comments. I'd say the early warning on not having gradients would be the closest thing to a real complaint from me; curious what you think.

Also, why was the change to jaxified arange also in this PR? I think you said in the last meeting but I forgot.

(Resolved review threads: one on pytensor/link/numba/dispatch/tensor_basic.py, two on pytensor/tensor/basic.py.)

(Two outdated review threads on pytensor/tensor/basic.py, resolved.)
ricardoV94 (Member, Author):

> Also, why was the change to jaxified arange also in this PR? I think you said in the last meeting but I forgot.

Check the last two commits. When we got rid of the blackbox AllocDiag Op, one of the JAX tests needed a constant shape input (or a specify_shape). After the last commit, that constraint could again be lifted.
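
A hedged sketch of that constraint: JAX needs concrete shapes, so an arange driven by a symbolic shape only compiles if the shape is constant or pinned, e.g. with specify_shape (running this requires jax to be installed):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
x_known = pt.specify_shape(x, (5,))  # pin the length so JAX sees a static shape
out = pt.arange(x_known.shape[0])

f = pytensor.function([x], out, mode="JAX")
print(f(np.zeros(5)))  # [0 1 2 3 4]
```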

@ricardoV94 force-pushed the extract_diag_numba branch 2 times, most recently from 4e527d4 to 80a2212 on August 23, 2023 12:30
jessegrabowski (Member) left a comment:
Looks good to go

(1, 2),
(2, 1),
(0, 2),
(2, 0),
Member:
(0, 2) and (2, 0) are the cases that would previously have raised, correct?

ricardoV94 (Member, Author) replied:
Before this PR, anything above a 2d input would have raised. The cases you listed weren't supported in the previous iteration of this PR.
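
A small sketch of one such axis pair on a 3d input (not the PR's test code):

```python
import numpy as np
import pytensor.tensor as pt

x = pt.tensor3("x")
d = pt.diagonal(x, axis1=0, axis2=2)  # axis pair (0, 2) on a 3d input

x_val = np.random.normal(size=(2, 3, 4))
np.testing.assert_allclose(d.eval({x: x_val}), np.diagonal(x_val, axis1=0, axis2=2))
```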

@ricardoV94 merged commit 28fc9ac into pymc-devs:main Aug 24, 2023
52 checks passed
@ricardoV94 deleted the extract_diag_numba branch August 24, 2023 21:19