Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] master from deepmind:master #320

Open
wants to merge 31 commits into
base: master
Choose a base branch
from

Conversation

pull[bot]
Copy link

@pull pull bot commented Mar 17, 2023

See Commits and Changes for more details.


Created by pull[bot]

Can you help keep this open source service alive? 💖 Please sponsor : )

PiperOrigin-RevId: 517425521
Change-Id: Iae8bc6bce806afbeccf87f35b790fc08a0762428
@pull pull bot added the ⤵️ pull label Mar 17, 2023
hawkinsp and others added 28 commits March 17, 2023 10:37
jax.xla was an accidental export from the jax namespace. jax.Device is the public name for JAX devices, as of JAX 0.4.3.

This is a trivial and safe change: jax.xla.Device and jax.Device are aliases.

PiperOrigin-RevId: 517449335
Change-Id: I48334cea855f8819d246f3f5ba75954b73781924
PiperOrigin-RevId: 518825517
Change-Id: I4bec530d53e4d93106da32dbe8f0026d5699ef85
PiperOrigin-RevId: 519995074
Change-Id: Id5956a9d1e9e40d906d1b3bc320f92408072cd77
PiperOrigin-RevId: 527367078
Change-Id: Ida4ba1153e872d0f6643f75da3040a08452502c8
NumPy 1.25 deprecates a number of function aliases (https://github.com/numpy/numpy/releases/tag/v1.25.0rc1)

This change replaces uses of the deprecated names with their recommended replacements:
* `np.round_` -> `np.round`
* `np.product` -> `np.prod`
* `np.cumproduct` -> `np.cumprod`
* `np.sometrue` -> `np.any`
* `np.alltrue` -> `np.all`

The deprecated functions will issue a `DeprecationWarning` under NumPy 1.25, and will be removed in NumPy 2.0.

PiperOrigin-RevId: 539973324
Change-Id: I68df241a9e3a04997787cd3b47f08acccb02e19f
…ubbed out).

With few exceptions (classes with callbacks mostly), all TensorFlow Probability (tfp) distributions and TensorFLow linear operators are already composite tensors, so this conversion function is no longer necessary. For out-of-tfp distributions, previous changes have already made the manual changes needed to make them be composite tensors.

PiperOrigin-RevId: 545736125
Change-Id: Iecb2d02a58db91a9800153b493a4d41eb2189e7b
… default value.

PiperOrigin-RevId: 547546481
Change-Id: I21ea27e026908bfd0390e5c772b3473fc46bb23d
PiperOrigin-RevId: 553134800
Change-Id: I9880dbd0ee617d37bb884cedabe591f44120e58b
`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation.

Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Change uses of `jnp.array` to `jnp.ndarray`.

PiperOrigin-RevId: 555454626
Change-Id: I089d1c5a0988f2b608fce41cc345c86f17b8957c
An upcoming change to JAX will teach pytype more accurate types for functions in the jax.numpy module. This reveals a number of type errors in downstream users of JAX. In particular, pytype is able to infer `jax.Array` accurately as a type in many more cases.

PiperOrigin-RevId: 557814902
Change-Id: Id3975ca545f6051ecc9526134251a87ca59b483a
… test.

The version pinning is a temporary fix; a longer term fix will require more
properly figuring out the different versions that need to be installed.

PiperOrigin-RevId: 558805522
Change-Id: Ib3ab357e2ee4712880c2345a8f6c707c8fb6af27
… not actually applied to constructing the default optimizer (in case one wasn't passed as an argument), which used hardcoded default values.

With this change it is now used for the default optimizer, and a new parameter dual_learning_rate is available to provide for the dual_optimizer.

PiperOrigin-RevId: 560774672
Change-Id: I2da2d8d241857fa3b411ed5669d92bd6a40d9311
… default value.

PiperOrigin-RevId: 561458449
Change-Id: I2c91d266ec6de36025c9cffaf6173f111a635a07
PiperOrigin-RevId: 568212531
Change-Id: Id5bfe34404ac2e0fab928219976131cf07bce91d
PiperOrigin-RevId: 568244870
Change-Id: I66cac69883a9857a3177a8710249e66987683941
PiperOrigin-RevId: 568398999
Change-Id: I4d41588ecb44c8804b9738aeb53322cab0be0c86
PiperOrigin-RevId: 568520617
Change-Id: I372051fa57315aa4710e842a0fc4582c685a78c6
Fixes Issue 292.

PiperOrigin-RevId: 571906732
Change-Id: I7ee0c4952fab2f3eec353e787caeffab799617a9
…to jax.Array

This change replaces uses of jax.random.KeyArray and jax.random.PRNGKeyArray in the context of type annotations with jax.Array, which is the correct annotation for JAX PRNG keys moving forward.

The purpose of this change is to remove references to KeyArray and PRNGKeyArray, which are deprecated (google/jax#17594) and will soon be removed from JAX. The design and thought process behind this is described in https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html.

Note that KeyArray and PRNGKeyArray have always been aliased to Any, so the new type annotation is far more specific than the old one.

PiperOrigin-RevId: 574195739
Change-Id: I703dbac5824d7497fa6228312a5722b96f0df665
PiperOrigin-RevId: 575132921
Change-Id: I3c535c9d5f766acd59ed19ff79885deb5d705e42
…nd="cpu").

An upcoming change to JAX will include non-local (addressable) CPU devices in jax.devices() when JAX is used multicontroller-style, where there are multiple Python processes.

This change preserves the current behavior by replacing uses of jax.devices("cpu"), which previously only returned local devices, with jax.local_devices("cpu"), which will return local devices both now and in the future.

This change is always be safe (i.e., it should always preserve the previous behavior) but it may sometimes be unnecessary if code is never used in a multicontroller setting.

PiperOrigin-RevId: 583072300
Change-Id: I128aa11d48a8dd89103dbec6f273e19963820b77
PiperOrigin-RevId: 595410959
Change-Id: Idf1c7d46fb3dfaccbc0e78c51cd91267f340a1e9
PiperOrigin-RevId: 595412439
Change-Id: I1367163d0c2cd08c22446bc9e439b946628d4e4c
PiperOrigin-RevId: 614766270
Change-Id: I8dd7b264f77c07583983663064cfec369d78c218
tensorflow_probability.substrates.jax and to use the JAX specific
BUILD target.

PiperOrigin-RevId: 618931679
Change-Id: Ic47d7ed13e46336fd4e593725be14b6ebe86b1f7
PiperOrigin-RevId: 625006545
Change-Id: Ib189d2bdd39687d8aaf6acb124ef72932613a474
…inference servers

According to the documentation [1]

when the thread pool size is smaller than the batch size, it is possible to hang when the batched handler is waiting to collect the next
example but all the threads are busy synchronously waiting for the results.

PiperOrigin-RevId: 628306415
Change-Id: I9c48a689d0e667577f361495524c8fd2b980653e
PiperOrigin-RevId: 630425642
Change-Id: I205ace49ab167b826267e92445b0077a4e465079
Jake VanderPlas and others added 2 commits May 20, 2024 04:57
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 635420186
Change-Id: Ie71a2deb905622b947a9b075ce55bcb1bff46462
PiperOrigin-RevId: 652787597
Change-Id: I7aff382d61475c35cc24b6bbc42d62b10ccebe76
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet