From de03c99b52723e0e86062258e7144f8a49494c6f Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Mon, 11 May 2020 16:48:30 -0700 Subject: [PATCH] update jaxpr doc and tests with single-operand cond --- docs/jaxpr.rst | 62 +++++++++++++++++++++++------------------------ tests/api_test.py | 20 +++++++-------- 2 files changed, 39 insertions(+), 43 deletions(-) diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst index cec0bae27a4e..b033b5455c75 100644 --- a/docs/jaxpr.rst +++ b/docs/jaxpr.rst @@ -202,26 +202,24 @@ Higher-order primitives jaxpr includes several higher-order primitives. They are more complicated because they include sub-jaxprs. -Cond -^^^^ - -JAX traces through normal Python conditionals. To capture a conditional expression -for dynamic execution, one must use the :py:func:`jax.lax.cond` constructor -with the following signature:: +Conditionals +^^^^^^^^^^^^ - lax.cond(pred : bool, true_op: A, true_body: A -> B, false_op: C, false_body: C -> B) -> B +JAX traces through normal Python conditionals. To capture a +conditional expression for dynamic execution, one must use the +:py:func:`jax.lax.cond` constructor, which has the signature:: -For example + lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B +For example: >>> from jax import lax >>> >>> def func7(arg): ... return lax.cond(arg >= 0., -... arg, ... lambda xtrue: xtrue + 3., -... arg, -... lambda xfalse: xfalse - 3.) +... lambda xfalse: xfalse - 3., +... arg) ... >>> print(make_jaxpr(func7)(5.)) { lambda ; a. @@ -229,10 +227,10 @@ For example c = cond[ false_jaxpr={ lambda ; a. let b = sub a 3.0 in (b,) } - linear=(False, False) + linear=(False,) true_jaxpr={ lambda ; a. let b = add a 3.0 - in (b,) } ] b a a + in (b,) } ] b a in (c,) } @@ -245,10 +243,9 @@ The cond primitive has a number of parameters: machinery to encode which of the input parameters are used linearly in the conditional. -The above instance of the cond primitive takes 3 operands. -The first one (``b``) is the predicate, then ``a` is the ``true_op`` (``arg``, to be -passed to ``true_jaxpr``) and also ``a`` is the ``false_op`` -(``arg``, to be passed to ``false_jaxpr``). +The above instance of the cond primitive takes 2 operands. +The first one (``b``) is the predicate, then ``a` is the operand (``arg``) +to be passed to ``true_jaxpr`` and ``false_jaxpr``. The following example shows a more complicated situation when the input to the branch functionals is a tuple, and the `false` branch functional @@ -256,10 +253,9 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar` >>> def func8(arg1, arg2): # arg2 is a pair ... return lax.cond(arg1 >= 0., -... arg2, ... lambda xtrue: xtrue[0], -... arg2, -... lambda xfalse: jnp.ones(1) + xfalse[1]) +... lambda xfalse: jnp.ones(1) + xfalse[1], +... arg2) ... >>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.))) { lambda e ; a b c. @@ -267,29 +263,31 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar` f = cond[ false_jaxpr={ lambda ; c a b. let d = add c b in (d,) } - linear=(False, False, False, False, False) - true_jaxpr={ lambda ; a b. - let - in (a,) } ] d b c e b c + linear=(False, False, False) + true_jaxpr={ lambda ; a_ a b. + let + in (a,) } ] d e b c in (f,) } The top-level jaxpr has one `constvar` ``e`` (corresponding to ``jnp.ones(1)`` from the body of the ``false_jaxpr``) and three input variables ``a b c`` (corresponding to ``arg1`` and the two elements of ``arg2``; note that ``arg2`` has been flattened). -The ``true_jaxpr`` has two input variables (corresponding to the two elements of ``arg2`` -that is passed to ``true_jaxpr``). The ``false_jaxpr`` has three input variables (``c`` corresponding to the constant for ``jnp.ones(1)``, and ``a b`` for the two elements of ``arg2`` that are passed to ``false_jaxpr``). +The ``true_jaxpr`` has three input variables. The first (``a_``) is an +unused argument matching the constant first argument ``c`` of +``false_jaxpr`` (required for the jaxpr signatures to match). The +subsequent two correspond to the two elements of ``arg2`` that is +passed to ``true_jaxpr``. -The actual operands to the cond primitive are: ``d b c e b c``, which correspond in order to: +The actual operands to the cond primitive are: ``d e b c`` ``d b c e b c``, which correspond in order to: * 1 operand for the predicate, - * 2 operands for ``true_jaxpr``, i.e., ``b`` and ``c``, which are input vars, - corresponding to ``arg2`` for the top-level jaxpr, - * 1 constant for ``false_jaxpr``, i.e., ``e``, which is a consvar for the top-level jaxpr, - * 2 operands for ``true_jaxpr``, i.e., ``b`` and ``c``, which are the input vars - corresponding to ``arg2`` for the top-level jaxpr. + * 1 constant (only used by ``false_jaxpr``, but passed to both), + i.e., ``e``, which is a constvar for the top-level jaxpr + * 2 operands passed to both jaxprs, i.e., ``b`` and ``c``, which are + input vars, corresponding to ``arg2`` for the top-level jaxpr. While ^^^^^ diff --git a/tests/api_test.py b/tests/api_test.py index 6fba84d9a51f..e41554c1677d 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1740,10 +1740,9 @@ def func6(first): def func7(arg): return lax.cond(arg >= 0., - arg, lambda xtrue: xtrue + 3., - arg, - lambda xfalse: xfalse - 3.) + lambda xfalse: xfalse - 3., + arg) jaxpr = api.make_jaxpr(func7)(5.) self.assertMultiLineStrippedEqual(""" @@ -1752,19 +1751,18 @@ def func7(arg): c = cond[ false_jaxpr={ lambda ; a. let b = sub a 3.0 in (b,) } - linear=(False, False) + linear=(False,) true_jaxpr={ lambda ; a. let b = add a 3.0 - in (b,) } ] b a a + in (b,) } ] b a in (c,) } """, str(jaxpr)) def func8(arg1, arg2): # arg2 is a pair return lax.cond(arg1 >= 0., - arg2, lambda xtrue: xtrue[0], - arg2, - lambda xfalse: jnp.ones(1) + xfalse[1]) + lambda xfalse: jnp.ones(1) + xfalse[1], + arg2) jaxpr = api.make_jaxpr(func8)(5., (jnp.zeros(1), 2.)) self.assertMultiLineStrippedEqual(""" @@ -1773,10 +1771,10 @@ def func8(arg1, arg2): # arg2 is a pair f = cond[ false_jaxpr={ lambda ; c a b. let d = add c b in (d,) } - linear=(False, False, False, False, False) - true_jaxpr={ lambda ; a b. + linear=(False, False, False) + true_jaxpr={ lambda ; a_ a b. let - in (a,) } ] d b c e b c + in (a,) } ] d e b c in (f,) } """, str(jaxpr))