The equivalent of LinearOperator in JAX #1733

Closed
pawel-czyz opened this issue Jul 11, 2023 · 2 comments

@pawel-czyz

This issue is related to TensorFlow Probability on JAX.

The tutorials for the multivariate Student's t and multivariate normal distributions use tf.linalg.LinearOperatorLowerTriangular.

However, this object does not exist in JAX. What is the correct way to build these distributions?
Passing a JAX array leads to an error:

TypeError: Expected argument `scale` to be instance of `LinearOperator`. Found: `[[ 0.6  0.   0. ]
 [ 0.2  0.5  0. ]
 [ 0.1 -0.3  0.4]]`.

Any help would be much appreciated!

Additional notes

@brianwa84
Contributor

If you've done `import tensorflow_probability.substrates.jax as tfp`, then you can find the linear operators under `tfp.tf2jax.linalg`. To list them: `jtf = tfp.tf2jax; dir(jtf.linalg)`

@pawel-czyz
Author

It resolved the issue, thank you very much! 🙂
