Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX support for SortOp #657

Merged
merged 2 commits into from
Mar 4, 2024
Merged

Conversation

HarshvirSandhu
Copy link
Contributor

Description

Implement JAX conversion for SortOp

Related Issue

Checklist

Type of change

  • Maintenance

@codecov-commenter
Copy link

codecov-commenter commented Mar 3, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 80.82%. Comparing base (082081a) to head (1ef5238).
Report is 26 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #657      +/-   ##
==========================================
+ Coverage   80.80%   80.82%   +0.02%     
==========================================
  Files         162      162              
  Lines       46743    46820      +77     
  Branches    11419    11438      +19     
==========================================
+ Hits        37770    37844      +74     
+ Misses       6731     6725       -6     
- Partials     2242     2251       +9     
Files Coverage Δ
pytensor/link/jax/dispatch/tensor_basic.py 92.37% <100.00%> (+0.40%) ⬆️

... and 24 files with indirect coverage changes

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, I left just two small modification suggestions for readability and more extensive testing

Comment on lines 213 to 214
def sort(arr, *args):
return jnp.sort(arr, *args)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def sort(arr, *args):
return jnp.sort(arr, *args)
def sort(arr, axis):
return jnp.sort(arr, axis=axis)

Comment on lines 221 to 226
def test_sort():
x = matrix("x")
out = pytensor.tensor.sort(x)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_sort():
x = matrix("x")
out = pytensor.tensor.sort(x)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr])
@pytest.mark.parametrize("axis", [None, -1])
def test_sort(axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = pytensor.tensor.sort(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr])

@ricardoV94 ricardoV94 merged commit d175203 into pymc-devs:main Mar 4, 2024
53 checks passed
@ricardoV94
Copy link
Member

Thanks @HarshvirSandhu

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

FEAT: JAX Conversion for the given SortOp
4 participants