Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(deps): update dependency jax to >=0.4.31, <=0.4.31 #722

Closed
wants to merge 1 commit into from

Conversation

renovate[bot]
Copy link
Contributor

@renovate renovate bot commented Jun 10, 2024

Mend Renovate

This PR contains the following updates:

Package Change Age Adoption Passing Confidence
jax >=0.4.16, <=0.4.26 -> >=0.4.31, <=0.4.31 age adoption passing confidence

Release Notes

google/jax (jax)

v0.4.31

Compare Source

  • Deletion

    • xmap has been deleted. Please use {func}shard_map as the replacement.
  • Changes

    • The minimum CuDNN version is v9.1. This was true in previous releases also,
      but we now declare this version constraint formally.
    • The minimum Python version is now 3.10. 3.10 will remain the minimum
      supported version until July 2025.
    • The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum
      supported version until December 2024.
    • The minimum SciPy version is now 1.10. SciPy 1.10 will remain the minimum
      supported version until January 2025.
    • {func}jax.numpy.ceil, {func}jax.numpy.floor and {func}jax.numpy.trunc now return the output
      of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
    • libdevice.10.bc is no longer bundled with CUDA wheels. It must be
      installed either as a part of local CUDA installation, or via NVIDIA's CUDA
      pip wheels.
    • {class}jax.experimental.pallas.BlockSpec now expects block_shape to
      be passed before index_map. The old argument order is deprecated and
      will be removed in a future release.
    • Updated the repr of gpu devices to be more consistent
      with TPUs/CPUs. For example, cuda(id=0) will now be CudaDevice(id=0).
    • Added the device property and to_device method to {class}jax.Array, as
      part of JAX's Array API support.
  • Deprecations

    • Removed a number of previously-deprecated internal APIs related to
      polymorphic shapes. From {mod}jax.core: removed canonicalize_shape,
      dimension_as_value, definitely_equal, and symbolic_equal_dim.
    • HLO lowering rules should no longer wrap singleton ir.Values in tuples.
      Instead, return singleton ir.Values unwrapped. Support for wrapped values
      will be removed in a future version of JAX.
    • {func}jax.experimental.jax2tf.convert with native_serialization=False
      or enable_xla=False is now deprecated and this support will be removed in
      a future version.
      Native serialization has been the default since JAX 0.4.16 (September 2023).
    • The previously-deprecated function jax.random.shuffle has been removed;
      instead use jax.random.permutation with independent=True.

v0.4.30

Compare Source

  • Changes

    • JAX supports ml_dtypes >= 0.2. In 0.4.29 release, the ml_dtypes version was
      bumped to 0.4.0 but this has been rolled back in this release to give users
      of both TensorFlow and JAX more time to migrate to a newer TensorFlow
      release.
    • jax.experimental.mesh_utils can now create an efficient mesh for TPU v5e.
    • jax now depends on jaxlib directly. This change was enabled by the CUDA
      plugin switch: there are no longer multiple jaxlib variants. You can install
      a CPU-only jax with pip install jax, no extras required.
    • Added an API for exporting and serializing JAX functions. This used
      to exist in jax.experimental.export (which is being deprecated),
      and will now live in jax.export.
      See the documentation.
  • Deprecations

    • Internal pretty-printing tools jax.core.pp_* are deprecated, and will be removed
      in a future release.
    • Hashing of tracers is deprecated, and will lead to a TypeError in a future JAX
      release. This previously was the case, but there was an inadvertent regression in
      the last several JAX releases.
    • jax.experimental.export is deprecated. Use {mod}jax.export instead.
      See the migration guide.
    • Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays
      x and y, x.astype(y) will raise a warning. To silence it use x.astype(y.dtype).
    • jax.xla_computation is deprecated and will be removed in a future release.
      Please use the AOT APIs to get the same functionality as jax.xla_computation.
      • jax.xla_computation(fn)(*args, **kwargs) can be replaced with
        jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
      • You can also use .out_info property of jax.stages.Lowered to get the
        output information (like tree structure, shape and dtype).
      • For cross-backend lowering, you can replace
        jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with
        jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').

v0.4.29

Compare Source

  • Changes

    • We anticipate that this will be the last release of JAX and jaxlib
      supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
      plugin jaxlib (e.g. pip install jax[cuda12]).
    • JAX now requires ml_dtypes version 0.4.0 or newer.
    • Removed backwards-compatibility support for old usage of the
      jax.experimental.export API. It is not possible anymore to use
      from jax.experimental.export import export, and instead you should use
      from jax.experimental import export.
      The removed functionality has been deprecated since 0.4.24.
    • Added is_leaf argument to {func}jax.tree.all & {func}jax.tree_util.tree_all.
  • Deprecations

    • jax.sharding.XLACompatibleSharding is deprecated. Please use
      jax.sharding.Sharding.
    • jax.experimental.Exported.in_shardings has been renamed as
      jax.experimental.Exported.in_shardings_hlo. Same for out_shardings.
      The old names will be removed after 3 months.
    • Removed a number of previously-deprecated APIs:
      • from {mod}jax.core: non_negative_dim, DimSize, Shape
      • from {mod}jax.lax: tie_in
      • from {mod}jax.nn: normalize
      • from {mod}jax.interpreters.xla: backend_specific_translations,
        translations, register_translation, xla_destructure,
        TranslationRule, TranslationContext, XlaOp.
    • The tol argument of {func}jax.numpy.linalg.matrix_rank is being
      deprecated and will soon be removed. Use rtol instead.
    • The rcond argument of {func}jax.numpy.linalg.pinv is being
      deprecated and will soon be removed. Use rtol instead.
    • The deprecated jax.config submodule has been removed. To configure JAX
      use import jax and then reference the config object via jax.config.
    • {mod}jax.random APIs no longer accept batched keys, where previously
      some did unintentionally. Going forward, we recommend explicit use of
      {func}jax.vmap in such cases.
    • In {func}jax.scipy.special.beta, the x and y parameters have been
      renamed to a and b for consistency with other beta APIs.
  • New Functionality

    • Added {func}jax.experimental.Exported.in_shardings_jax to construct
      shardings that can be used with the JAX APIs from the HloShardings
      that are stored in the Exported objects.

v0.4.28

Compare Source

  • Bug fixes

    • Reverted a change to make_jaxpr that was breaking Equinox (#​21116).
  • Deprecations & removals

    • The kind argument to {func}jax.numpy.sort and {func}jax.numpy.argsort
      is now removed. Use stable=True or stable=False instead.
    • Removed get_compute_capability from the jax.experimental.pallas.gpu
      module. Use the compute_capability attribute of a GPU device, returned
      by {func}jax.devices or {func}jax.local_devices, instead.
    • The newshape argument to {func}jax.numpy.reshapeis being deprecated
      and will soon be removed. Use shape instead.
  • Changes

    • The minimum jaxlib version of this release is 0.4.27.

v0.4.27

Compare Source

  • New Functionality

    • Added {func}jax.numpy.unstack and {func}jax.numpy.cumulative_sum,
      following their addition in the array API 2023 standard, soon to be
      adopted by NumPy.
    • Added a new config option jax_cpu_collectives_implementation to select the
      implementation of cross-process collective operations used by the CPU backend.
      Choices available are 'none'(default), 'gloo' and 'mpi' (requires jaxlib 0.4.26).
      If set to 'none', cross-process collective operations are disabled.
  • Changes

    • {func}jax.pure_callback, {func}jax.experimental.io_callback
      and {func}jax.debug.callback now use {class}jax.Array instead
      of {class}np.ndarray. You can recover the old behavior by transforming
      the arguments via jax.tree.map(np.asarray, args) before passing them
      to the callback.
    • complex_arr.astype(bool) now follows the same semantics as NumPy, returning
      False where complex_arr is equal to 0 + 0j, and True otherwise.
    • core.Token now is a non-trivial class which wraps a jax.Array. It could
      be created and threaded in and out of computations to build up dependency.
      The singleton object core.token has been removed, users now should create
      and use fresh core.Token objects instead.
    • On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
      by default. This choice can improve runtime memory usage at a compile-time
      cost. Prior behavior, which produces a kernel call, can be recovered with
      jax.config.update('jax_threefry_gpu_kernel_lowering', True). If the new
      default causes issues, please file a bug. Otherwise, we intend to remove
      this flag in a future release.
  • Deprecations & Removals

    • Pallas now exclusively uses XLA for compiling kernels on GPU. The old
      lowering pass via Triton Python APIs has been removed and the
      JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.
    • {func}jax.numpy.clip has a new argument signature: a, a_min, and
      a_max are deprecated in favor of x (positional only), min, and
      max ({jax-issue}20550).
    • The device() method of JAX arrays has been removed, after being deprecated
      since JAX v0.4.21. Use arr.devices() instead.
    • The initial argument to {func}jax.nn.softmax and {func}jax.nn.log_softmax
      is deprecated; empty inputs to softmax are now supported without setting this.
    • In {func}jax.jit, passing invalid static_argnums or static_argnames
      now leads to an error rather than a warning.
    • The minimum jaxlib version is now 0.4.23.
    • The {func}jax.numpy.hypot function now issues a deprecation warning when
      passing complex-valued inputs to it. This will raise an error when the
      deprecation is completed.
    • Scalar arguments to {func}jax.numpy.nonzero, {func}jax.numpy.where, and
      related functions now raise an error, following a similar change in NumPy.
    • The config option jax_cpu_enable_gloo_collectives is deprecated.
      Use jax.config.update('jax_cpu_collectives_implementation', 'gloo') instead.
    • The jax.Array.device_buffer and jax.Array.device_buffers methods have
      been removed after being deprecated in JAX v0.4.22. Instead use
      {attr}jax.Array.addressable_shards and {meth}jax.Array.addressable_data.
    • The condition, x, and y parameters of jax.numpy.where are now
      positional-only, following deprecation of the keywords in JAX v0.4.21.
    • Non-array arguments to functions in {mod}jax.lax.linalg now must be
      specified by keyword. Previously, this raised a DeprecationWarning.
    • Array-like arguments are now required in several :func:jax.numpy APIs,
      including {func}~jax.numpy.apply_along_axis,
      {func}~jax.numpy.apply_over_axes, {func}~jax.numpy.inner,
      {func}~jax.numpy.outer, {func}~jax.numpy.cross,
      {func}~jax.numpy.kron, and {func}~jax.numpy.lexsort.
  • Bug fixes

    • {func}jax.numpy.astype will now always return a copy when copy=True.
      Previously, no copy would be made when the output array would have the same
      dtype as the input array. This may result in some increased memory usage.
      The default value is set to copy=False to preserve backwards compatibility.

Configuration

📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).

🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.

Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.

🔕 Ignore: Close this PR and you won't be reminded about this update again.


  • If you want to rebase/retry this PR, check this box

This PR was generated by Mend Renovate. View the repository job log.

@renovate renovate bot requested a review from a team June 10, 2024 18:02
@renovate renovate bot changed the title chore(deps): update dependency jax to >=0.4.29, <=0.4.29 chore(deps): update dependency jax to >=0.4.30, <=0.4.30 Jun 18, 2024
@renovate renovate bot changed the title chore(deps): update dependency jax to >=0.4.30, <=0.4.30 chore(deps): update dependency jax to >=0.4.31, <=0.4.31 Jul 30, 2024
@anakinxc anakinxc closed this Aug 26, 2024
@github-actions github-actions bot locked and limited conversation to collaborators Aug 26, 2024
@renovate renovate bot deleted the renovate/jax-0.x branch August 26, 2024 02:15
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant