Skip to content

Commit

Permalink
Replace use of deprecated kind keyword in jax.numpy.sort
Browse files Browse the repository at this point in the history
``kind`` is being replaced by ``stable`` in NumPy 2.0, and jax.numpy is in the process of deprecating the old argument.

PiperOrigin-RevId: 631326978
  • Loading branch information
vanderplas authored and tensorflower-gardener committed May 7, 2024
1 parent 94f592a commit 3e65280
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions tensorflow_probability/python/internal/backend/numpy/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,12 @@ def _argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None):
values = np.negative(values)
else:
raise ValueError('Unrecognized direction: {}.'.format(direction))
return np.argsort(
values, axis, kind='stable' if stable else 'quicksort').astype(np.int32)
try:
# stable keyword introduced in NumPy 2.0.
return np.argsort(values, axis, stable=stable).astype(np.int32)
except TypeError:
return np.argsort(
values, axis, kind='stable' if stable else 'quicksort').astype(np.int32)


def _histogram_fixed_width(values, value_range, nbins=100, dtype=np.int32,
Expand Down Expand Up @@ -103,7 +107,11 @@ def _sort(values, axis=-1, direction='ASCENDING', name=None): # pylint: disable
values = np.negative(values)
else:
raise ValueError('Unrecognized direction: {}.'.format(direction))
result = np.sort(values, axis, kind='stable')
try:
# NumPy 2.0
result = np.sort(values, axis, stable=True)
except TypeError:
result = np.sort(values, axis, kind='stable')
if direction == 'DESCENDING':
return np.negative(result)
return result
Expand Down

0 comments on commit 3e65280

Please sign in to comment.