Skip to content

Commit

Permalink
update jaxpr doc and tests with single-operand cond
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed May 14, 2020
1 parent f90bd4f commit de03c99
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 43 deletions.
62 changes: 30 additions & 32 deletions docs/jaxpr.rst
Expand Up @@ -202,37 +202,35 @@ Higher-order primitives
jaxpr includes several higher-order primitives. They are more complicated because
they include sub-jaxprs.

Cond
^^^^

JAX traces through normal Python conditionals. To capture a conditional expression
for dynamic execution, one must use the :py:func:`jax.lax.cond` constructor
with the following signature::
Conditionals
^^^^^^^^^^^^

lax.cond(pred : bool, true_op: A, true_body: A -> B, false_op: C, false_body: C -> B) -> B
JAX traces through normal Python conditionals. To capture a
conditional expression for dynamic execution, one must use the
:py:func:`jax.lax.cond` constructor, which has the signature::

For example
lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B

For example:

>>> from jax import lax
>>>
>>> def func7(arg):
... return lax.cond(arg >= 0.,
... arg,
... lambda xtrue: xtrue + 3.,
... arg,
... lambda xfalse: xfalse - 3.)
... lambda xfalse: xfalse - 3.,
... arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a.
let b = ge a 0.0
c = cond[ false_jaxpr={ lambda ; a.
let b = sub a 3.0
in (b,) }
linear=(False, False)
linear=(False,)
true_jaxpr={ lambda ; a.
let b = add a 3.0
in (b,) } ] b a a
in (b,) } ] b a
in (c,) }


Expand All @@ -245,51 +243,51 @@ The cond primitive has a number of parameters:
machinery to encode which of the input parameters are used linearly in the
conditional.

The above instance of the cond primitive takes 3 operands.
The first one (``b``) is the predicate, then ``a` is the ``true_op`` (``arg``, to be
passed to ``true_jaxpr``) and also ``a`` is the ``false_op``
(``arg``, to be passed to ``false_jaxpr``).
The above instance of the cond primitive takes 2 operands.
The first one (``b``) is the predicate, then ``a` is the operand (``arg``)
to be passed to ``true_jaxpr`` and ``false_jaxpr``.

The following example shows a more complicated situation when the input
to the branch functionals is a tuple, and the `false` branch functional
contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`

>>> def func8(arg1, arg2): # arg2 is a pair
... return lax.cond(arg1 >= 0.,
... arg2,
... lambda xtrue: xtrue[0],
... arg2,
... lambda xfalse: jnp.ones(1) + xfalse[1])
... lambda xfalse: jnp.ones(1) + xfalse[1],
... arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda e ; a b c.
let d = ge a 0.0
f = cond[ false_jaxpr={ lambda ; c a b.
let d = add c b
in (d,) }
linear=(False, False, False, False, False)
true_jaxpr={ lambda ; a b.
let
in (a,) } ] d b c e b c
linear=(False, False, False)
true_jaxpr={ lambda ; a_ a b.
let
in (a,) } ] d e b c
in (f,) }

The top-level jaxpr has one `constvar` ``e`` (corresponding to ``jnp.ones(1)`` from the
body of the ``false_jaxpr``) and three input variables ``a b c`` (corresponding to ``arg1``
and the two elements of ``arg2``; note that ``arg2`` has been flattened).
The ``true_jaxpr`` has two input variables (corresponding to the two elements of ``arg2``
that is passed to ``true_jaxpr``).
The ``false_jaxpr`` has three input variables (``c`` corresponding to the constant for
``jnp.ones(1)``, and ``a b`` for the two elements of ``arg2`` that are passed
to ``false_jaxpr``).
The ``true_jaxpr`` has three input variables. The first (``a_``) is an
unused argument matching the constant first argument ``c`` of
``false_jaxpr`` (required for the jaxpr signatures to match). The
subsequent two correspond to the two elements of ``arg2`` that is
passed to ``true_jaxpr``.

The actual operands to the cond primitive are: ``d b c e b c``, which correspond in order to:
The actual operands to the cond primitive are: ``d e b c`` ``d b c e b c``, which correspond in order to:

* 1 operand for the predicate,
* 2 operands for ``true_jaxpr``, i.e., ``b`` and ``c``, which are input vars,
corresponding to ``arg2`` for the top-level jaxpr,
* 1 constant for ``false_jaxpr``, i.e., ``e``, which is a consvar for the top-level jaxpr,
* 2 operands for ``true_jaxpr``, i.e., ``b`` and ``c``, which are the input vars
corresponding to ``arg2`` for the top-level jaxpr.
* 1 constant (only used by ``false_jaxpr``, but passed to both),
i.e., ``e``, which is a constvar for the top-level jaxpr
* 2 operands passed to both jaxprs, i.e., ``b`` and ``c``, which are
input vars, corresponding to ``arg2`` for the top-level jaxpr.

While
^^^^^
Expand Down
20 changes: 9 additions & 11 deletions tests/api_test.py
Expand Up @@ -1740,10 +1740,9 @@ def func6(first):

def func7(arg):
return lax.cond(arg >= 0.,
arg,
lambda xtrue: xtrue + 3.,
arg,
lambda xfalse: xfalse - 3.)
lambda xfalse: xfalse - 3.,
arg)

jaxpr = api.make_jaxpr(func7)(5.)
self.assertMultiLineStrippedEqual("""
Expand All @@ -1752,19 +1751,18 @@ def func7(arg):
c = cond[ false_jaxpr={ lambda ; a.
let b = sub a 3.0
in (b,) }
linear=(False, False)
linear=(False,)
true_jaxpr={ lambda ; a.
let b = add a 3.0
in (b,) } ] b a a
in (b,) } ] b a
in (c,) }
""", str(jaxpr))

def func8(arg1, arg2): # arg2 is a pair
return lax.cond(arg1 >= 0.,
arg2,
lambda xtrue: xtrue[0],
arg2,
lambda xfalse: jnp.ones(1) + xfalse[1])
lambda xfalse: jnp.ones(1) + xfalse[1],
arg2)

jaxpr = api.make_jaxpr(func8)(5., (jnp.zeros(1), 2.))
self.assertMultiLineStrippedEqual("""
Expand All @@ -1773,10 +1771,10 @@ def func8(arg1, arg2): # arg2 is a pair
f = cond[ false_jaxpr={ lambda ; c a b.
let d = add c b
in (d,) }
linear=(False, False, False, False, False)
true_jaxpr={ lambda ; a b.
linear=(False, False, False)
true_jaxpr={ lambda ; a_ a b.
let
in (a,) } ] d b c e b c
in (a,) } ] d e b c
in (f,) }
""", str(jaxpr))

Expand Down

0 comments on commit de03c99

Please sign in to comment.