Skip to content

Commit

Permalink
Provide JAX Ops from Optional tfp dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 27, 2023
1 parent 8ac8342 commit 9ada945
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ jobs:
# PyTensor next, pip installs a lower version of numpy via the PyPI.
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
pip install -e ./
mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
Expand Down
36 changes: 35 additions & 1 deletion pytensor/link/jax/dispatch/scalar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import functools
import typing
from typing import Callable, Optional

import jax
import jax.numpy as jnp
Expand All @@ -18,7 +20,21 @@
Second,
Sub,
)
from pytensor.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi
from pytensor.scalar.math import Erf, Erfc, Erfcinv, Erfcx, Erfinv, Iv, Log1mexp, Psi


def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable:
try:
import tensorflow_probability.substrates.jax.math as tfp_jax_math
except ModuleNotFoundError:
raise NotImplementedError(
f"No JAX implementation for Op {op.name}. "
"Implementation is available if TensorFlow Probability is installed"
)

if jax_op_name is None:
jax_op_name = op.name
return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name))


def check_if_inputs_scalars(node):
Expand Down Expand Up @@ -211,6 +227,24 @@ def erfinv(x):
return erfinv


@jax_funcify.register(Erfcx)
@jax_funcify.register(Erfcinv)
def jax_funcify_from_tfp(op, **kwargs):
tfp_jax_op = try_import_tfp_jax_op(op)

return tfp_jax_op


@jax_funcify.register(Iv)
def jax_funcify_Iv(op, **kwargs):
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")

def iv(v, x):
return ive(v, x) / jnp.exp(-jnp.abs(jnp.real(x)))

return iv


@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
def log1mexp(x):
Expand Down
29 changes: 29 additions & 0 deletions tests/link/jax/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.scalar.basic import Composite
from pytensor.tensor import as_tensor
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import (
cosh,
erf,
erfc,
erfcinv,
erfcx,
erfinv,
iv,
log,
log1mexp,
psi,
Expand All @@ -28,6 +32,14 @@
from pytensor.link.jax.dispatch import jax_funcify


try:
pass

TFP_INSTALLED = True
except ModuleNotFoundError:
TFP_INSTALLED = False


def test_second():
a0 = scalar("a0")
b = scalar("b")
Expand Down Expand Up @@ -134,6 +146,23 @@ def test_erfinv():
compare_jax_and_py(fg, [0.95])


@pytest.mark.parametrize(
"op, test_values",
[
(erfcx, (0.7,)),
(erfcinv, (0.7,)),
(iv, (0.3, 0.7)),
],
)
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
def test_tfp_ops(op, test_values):
inputs = [as_tensor(test_value).type() for test_value in test_values]
output = op(*inputs)

fg = FunctionGraph(inputs, [output])
compare_jax_and_py(fg, test_values)


def test_psi():
x = scalar("x")
out = psi(x)
Expand Down

0 comments on commit 9ada945

Please sign in to comment.