chore(deps): update dependency jax to >=0.4.31, <=0.4.31 #722
This PR contains the following updates:

Package: jax
Change: `>=0.4.16, <=0.4.26` -> `>=0.4.31, <=0.4.31`
Release Notes
google/jax (jax)
v0.4.31
Compare Source
Deletion

- xmap has been deleted. Please use `shard_map` as the replacement.

Changes
- The minimum CuDNN version is v9.1. This was true in previous releases also, but we now declare this version constraint formally.
- The minimum Python version is now 3.10, which will remain the minimum supported version until July 2025.
- The minimum NumPy version is now 1.24, which will remain the minimum supported version until December 2024.
- The minimum SciPy version is now 1.10, which will remain the minimum supported version until January 2025.
- {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return output of the same dtype as the input, i.e. they no longer upcast integer or boolean inputs to floating point.
- `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be installed either as part of a local CUDA installation, or via NVIDIA's CUDA pip wheels.
- `jax.experimental.pallas.BlockSpec` now expects `block_shape` to be passed before `index_map`. The old argument order is deprecated and will be removed in a future release (see the sketch after this list).
- The string representation of CUDA devices is now more consistent with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
- Added the `device` property and `to_device` method to {class}`jax.Array`, as part of JAX's Array API support.
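A minimal sketch of the new `BlockSpec` argument order, assuming jax >= 0.4.31; the block shape and index map below are purely illustrative:

```python
# Illustrative only: the new BlockSpec argument order (block_shape before index_map).
from jax.experimental import pallas as pl

block_shape = (128, 128)
index_map = lambda i, j: (i, j)

spec = pl.BlockSpec(block_shape, index_map)   # new order
# pl.BlockSpec(index_map, block_shape)        # old order: deprecated
```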
Deprecations
- Removed several previously-deprecated APIs related to polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`, `dimension_as_value`, `definitely_equal`, and `symbolic_equal_dim`.
- Lowering rules should no longer wrap singleton `ir.Value`s in tuples. Instead, return singleton `ir.Value`s unwrapped. Support for wrapped values will be removed in a future version of JAX.
- Calling `jax.experimental.jax2tf.convert` with `native_serialization=False` or `enable_xla=False` is now deprecated, and this support will be removed in a future version. Native serialization has been the default since JAX 0.4.16 (September 2023).
- `jax.random.shuffle` has been removed; instead use `jax.random.permutation` with `independent=True` (example below).
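A hedged sketch of the recommended replacement for the removed `jax.random.shuffle`; the key and array values are illustrative:

```python
# Illustrative replacement for the removed jax.random.shuffle.
import jax
import jax.numpy as jnp

key = jax.random.key(0)
x = jnp.arange(12).reshape(3, 4)

# independent=True permutes values along the given axis independently
# (per remaining index), which is the replacement the release notes recommend.
shuffled = jax.random.permutation(key, x, axis=0, independent=True)
```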
v0.4.30
Compare Source
Changes
- The ml_dtypes minimum version, which was bumped to 0.4.0 in the 0.4.29 release, has been rolled back in this release to give users of both TensorFlow and JAX more time to migrate to a newer TensorFlow release.
- `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
- Following the CUDA plugin switch, there are no longer multiple jaxlib variants. You can install a CPU-only jax with `pip install jax`, no extras required.
- The export and serialization APIs used to exist in `jax.experimental.export` (which is being deprecated), and will now live in `jax.export`. See the documentation (and the sketch after this list).
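A rough sketch of what the new `jax.export` entry point looks like; the function `f`, its example input, and the serialize/deserialize round trip below are illustrative assumptions based on the jax.export documentation, not on these notes:

```python
# Illustrative sketch: export, serialize, and re-import a jitted function
# via jax.export (check the jax.export documentation for the exact API).
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return 2.0 * x * x

exported = export.export(jax.jit(f))(jnp.ones((4,), jnp.float32))
blob = exported.serialize()                     # bytes suitable for storage
restored = export.deserialize(blob)
y = restored.call(jnp.ones((4,), jnp.float32))  # run the rehydrated function
```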
Deprecations
- The internal pretty-printing utilities `jax.core.pp_*` are deprecated, and will be removed in a future release.
- Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX release. This previously was the case, but there was an inadvertent regression in the last several JAX releases.
- `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead. See the migration guide.
- For arrays `x` and `y`, `x.astype(y)` (passing an array in place of a dtype) will now raise a warning. To silence it, use `x.astype(y.dtype)`.
- `jax.xla_computation` is deprecated and will be removed in a future release. Please use the AOT APIs to get the same functionality as `jax.xla_computation` (see the sketch after this list):
  - `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
  - You can also use the `.out_info` property of `jax.stages.Lowered` to get the output information (like tree structure, shape and dtype).
  - `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` can be replaced with `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
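A short sketch of the AOT replacement described above; `fn` and its argument are illustrative:

```python
# Illustrative AOT replacement for the deprecated jax.xla_computation.
import jax
import jax.numpy as jnp

def fn(x):
    return jnp.sin(x) + 1.0

lowered = jax.jit(fn).lower(jnp.ones((8,), jnp.float32))
hlo = lowered.compiler_ir('hlo')   # what jax.xla_computation used to return
print(lowered.out_info)            # output tree structure, shapes and dtypes
```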
v0.4.29
Compare Source
Changes
- This is anticipated to be the last release supporting a monolithic CUDA jaxlib. Future releases will use the CUDA plugin jaxlib (e.g. `pip install jax[cuda12]`).
- Removed deprecated backwards-compatibility support for the old `jax.experimental.export` API. It is no longer possible to use `from jax.experimental.export import export`; instead you should use `from jax.experimental import export`. The removed functionality has been deprecated since 0.4.24.
- Added an `is_leaf` argument to {func}`jax.tree.all` & {func}`jax.tree_util.tree_all` (example below).
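A small sketch of the new `is_leaf` argument; the pytree below is illustrative:

```python
# Illustrative use of the new is_leaf argument to jax.tree.all.
import jax

tree = {"a": [True, True], "b": (True, None)}

# Treat None as a leaf; since None is falsy, this returns False.
print(jax.tree.all(tree, is_leaf=lambda x: x is None))
# Without is_leaf, None is an empty pytree node, so this returns True.
print(jax.tree.all(tree))
```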
Deprecations
- `jax.sharding.XLACompatibleSharding` is deprecated. Please use `jax.sharding.Sharding`.
- `jax.experimental.Exported.in_shardings` has been renamed to `jax.experimental.Exported.in_shardings_hlo`. Same for `out_shardings`. The old names will be removed after 3 months.
- Removed a number of previously-deprecated APIs:
  - `jax.core`: `non_negative_dim`, `DimSize`, `Shape`
  - `jax.lax`: `tie_in`
  - `jax.nn`: `normalize`
  - `jax.interpreters.xla`: `backend_specific_translations`, `translations`, `register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`, `XlaOp`.
- The `tol` argument of {func}`jax.numpy.linalg.matrix_rank` is being deprecated and will soon be removed. Use `rtol` instead (example below).
- The `rcond` argument of {func}`jax.numpy.linalg.pinv` is being deprecated and will soon be removed. Use `rtol` instead.
- The `jax.config` submodule has been removed. To configure JAX, use `import jax` and then reference the config object via `jax.config`.
- `jax.random` APIs no longer accept batched keys, where previously some did unintentionally. Going forward, we recommend explicit use of {func}`jax.vmap` in such cases.
- In `jax.scipy.special.beta`, the `x` and `y` parameters have been renamed to `a` and `b` for consistency with other `beta` APIs.
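A brief sketch of the `rtol` migration mentioned above; the matrix is illustrative:

```python
# Illustrative migration from tol/rcond to rtol.
import jax.numpy as jnp

a = jnp.eye(3)
rank = jnp.linalg.matrix_rank(a, rtol=1e-5)  # was: tol=1e-5 (deprecated)
pinv = jnp.linalg.pinv(a, rtol=1e-5)         # was: rcond=1e-5 (deprecated)
```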
New Functionality
- Added `jax.experimental.Exported.in_shardings_jax` to construct shardings that can be used with the JAX APIs from the HloShardings that are stored in the `Exported` objects.
v0.4.28
Compare Source
Bug fixes
- Reverted a change to `make_jaxpr` that was breaking Equinox (#21116).

Deprecations & removals
- The `kind` argument to {func}`jax.numpy.sort` and {func}`jax.numpy.argsort` is now removed. Use `stable=True` or `stable=False` instead.
- Removed `get_compute_capability` from the `jax.experimental.pallas.gpu` module. Use the `compute_capability` attribute of a GPU device, returned by {func}`jax.devices` or {func}`jax.local_devices`, instead.
- The `newshape` argument to {func}`jax.numpy.reshape` is being deprecated and will soon be removed. Use `shape` instead (see the sketch after this list).
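A small sketch covering several of the items above (`shape` in place of `newshape`, `stable` in place of `kind`, and the `compute_capability` attribute); the arrays are illustrative:

```python
# Illustrative migrations for the removed/deprecated arguments above.
import jax
import jax.numpy as jnp

x = jnp.arange(6)
reshaped = jnp.reshape(x, shape=(2, 3))  # `shape` replaces the deprecated `newshape`
sorted_x = jnp.sort(x, stable=True)      # `stable` replaces the removed `kind`

dev = jax.devices()[0]
if dev.platform == "gpu":
    # Replaces jax.experimental.pallas.gpu.get_compute_capability
    print(dev.compute_capability)
```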
Changes
v0.4.27
Compare Source
New Functionality
- Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`, following their addition in the array API 2023 standard, soon to be adopted by NumPy (example below).
- Added the config option `jax_cpu_collectives_implementation` to select the implementation of cross-process collective operations used by the CPU backend. Available choices are `'none'` (default), `'gloo'` and `'mpi'` (requires jaxlib 0.4.26). If set to `'none'`, cross-process collective operations are disabled.
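A small sketch of the two new array-API functions; the input array is illustrative:

```python
# Illustrative use of the new array-API additions.
import jax.numpy as jnp

x = jnp.arange(6).reshape(3, 2)
rows = jnp.unstack(x, axis=0)         # tuple of three arrays of shape (2,)
csum = jnp.cumulative_sum(x, axis=0)  # running sum along axis 0
```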
Changes
- {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback` now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover the old behavior by transforming the arguments via `jax.tree.map(np.asarray, args)` before passing them to the callback (see the sketch after this list).
- `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
- `core.Token` is now a non-trivial class which wraps a `jax.Array`. It can be created and threaded in and out of computations to build up dependencies. The singleton object `core.token` has been removed; users should now create and use fresh `core.Token` objects instead.
- The Threefry GPU kernel lowering is no longer used by default. This choice can improve runtime memory usage at a compile-time cost. The prior behavior, which produces a kernel call, can be recovered with `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new default causes issues, please file a bug. Otherwise, we intend to remove this flag in a future release.
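A hedged sketch of the callback change above; `host_fn`, the result shape, and the input are illustrative:

```python
# Illustrative: callbacks now receive jax.Array arguments; convert explicitly
# if the host-side code needs NumPy semantics.
import jax
import jax.numpy as jnp
import numpy as np

def host_fn(x):
    return np.asarray(x) * 2.0  # x arrives as a jax.Array

out_shape = jax.ShapeDtypeStruct((3,), jnp.float32)
y = jax.pure_callback(host_fn, out_shape, jnp.ones(3, jnp.float32))
```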
Deprecations & Removals
- The previous Pallas lowering pass via Triton Python APIs has been removed and the `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
- `jax.numpy.clip` has a new argument signature: `a`, `a_min`, and `a_max` are deprecated in favor of `x` (positional only), `min`, and `max` ({jax-issue}`20550`); see the sketch after this list.
- The `device()` method of JAX arrays has been removed, after being deprecated since JAX v0.4.21. Use `arr.devices()` instead.
- The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` is deprecated; empty inputs to softmax are now supported without setting this.
- In `jax.jit`, passing invalid `static_argnums` or `static_argnames` now leads to an error rather than a warning.
- The `jax.numpy.hypot` function now issues a deprecation warning when passed complex-valued inputs. This will become an error when the deprecation is completed.
- Scalar arguments to `jax.numpy.nonzero`, {func}`jax.numpy.where`, and related functions now raise an error, following a similar change in NumPy.
- The config option `jax_cpu_enable_gloo_collectives` is deprecated. Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead.
- The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have been removed after being deprecated in JAX v0.4.22. Instead use {attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
- The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now positional-only, following deprecation of the keywords in JAX v0.4.21.
- Non-array arguments to functions in `jax.lax.linalg` now must be specified by keyword. Previously, this raised a DeprecationWarning.
- Array-like arguments are now required in several `jax.numpy` APIs, including {func}`~jax.numpy.apply_along_axis`, {func}`~jax.numpy.apply_over_axes`, {func}`~jax.numpy.inner`, {func}`~jax.numpy.outer`, {func}`~jax.numpy.cross`, {func}`~jax.numpy.kron`, and {func}`~jax.numpy.lexsort`.
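A brief sketch of two of the signature changes above; the values are illustrative:

```python
# Illustrative: new jnp.clip parameter names and positional-only jnp.where.
import jax.numpy as jnp

x = jnp.linspace(-2.0, 2.0, 5)
clipped = jnp.clip(x, min=-1.0, max=1.0)  # a_min/a_max are deprecated
masked = jnp.where(x > 0, x, 0.0)         # condition, x, y are positional-only
```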
Bug fixes
- `jax.numpy.astype` will now always return a copy when `copy=True`. Previously, no copy would be made when the output array would have the same dtype as the input array. This may result in some increased memory usage. The default value is set to `copy=False` to preserve backwards compatibility.

Configuration
📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).
🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.
♻ Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.
🔕 Ignore: Close this PR and you won't be reminded about this update again.
This PR was generated by Mend Renovate. View the repository job log.