Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor nlinalg.norm to match np.linalg.norm signature and functionaly #588

Merged
merged 2 commits into from
Feb 10, 2024

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jan 14, 2024

Description

Currently, pt.linalg.norm does not match the signature of np.linalg.norm -- it is missing the axis and keepdims arguments, and has no logic for handling complex-valued matrices. This PR fixes this by copy-pasting the np.linalg.norm code to pytensor.

In addition, some extra logic is needed to allow for tensor-valued inputs to norm. This part is still a WIP. The tests need to be refactored to test all the new combinations of keyword arguments as well.

This is a preliminary step to addressing pymc-devs/pymc#7101, because the correct transformation between unconstrained lower-triangular matrices to cholesky factorized matrices involves row-normalizing the unconstrained lower-triangular matrix; see here.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@codecov-commenter
Copy link

codecov-commenter commented Jan 14, 2024

Codecov Report

Attention: 16 lines in your changes are missing coverage. Please review.

Comparison is base (ee7c946) 80.83% compared to head (2812aca) 80.83%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #588      +/-   ##
==========================================
- Coverage   80.83%   80.83%   -0.01%     
==========================================
  Files         162      162              
  Lines       46743    46797      +54     
  Branches    11417    11434      +17     
==========================================
+ Hits        37786    37827      +41     
- Misses       6714     6722       +8     
- Partials     2243     2248       +5     
Files Coverage Δ
pytensor/tensor/nlinalg.py 94.51% <81.17%> (-3.58%) ⬇️

... and 2 files with indirect coverage changes

@jessegrabowski jessegrabowski marked this pull request as ready for review January 14, 2024 17:05
@jessegrabowski
Copy link
Member Author

I ended up just manually handing the broadcasting cases. Curious if my original approach could work somehow, because the solution I converged on duplicates a lot of the logic handled automatically by Blockwise.

I also had to make SVD Blockwise en passant, so that's a nice bonus from this PR.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes look amazing!

Left some comments, but I don't see anything that should be done structurally differently

pytensor/tensor/nlinalg.py Outdated Show resolved Hide resolved
pytensor/tensor/nlinalg.py Outdated Show resolved Hide resolved
pytensor/tensor/nlinalg.py Outdated Show resolved Hide resolved
pytensor/tensor/nlinalg.py Outdated Show resolved Hide resolved
pytensor/tensor/nlinalg.py Outdated Show resolved Hide resolved
pytensor/tensor/nlinalg.py Outdated Show resolved Hide resolved
pytensor/tensor/nlinalg.py Outdated Show resolved Hide resolved
tests/tensor/test_nlinalg.py Outdated Show resolved Hide resolved
tests/tensor/test_nlinalg.py Outdated Show resolved Hide resolved
tests/tensor/test_nlinalg.py Outdated Show resolved Hide resolved
@jessegrabowski
Copy link
Member Author

I reorganized this PR into two separate commits, one for SVD and one for norm. It should be good to go now.

@ricardoV94
Copy link
Member

Did we have any rewrites targeting SVD Ops? Now they we Blockwise they would need to be tweaked, but since tests pass I assume we don't have any

@jessegrabowski
Copy link
Member Author

No, it was kind of an orphan Op, with no rewrites or gradients.

@jessegrabowski jessegrabowski merged commit 28f2648 into pymc-devs:main Feb 10, 2024
53 checks passed
@jessegrabowski jessegrabowski deleted the norm-refactor branch February 11, 2024 01:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants