# Jaxprs: limited higher-order features

```
jaxpr ::= { lambda Var* ; Var+.
            let Eqn*
            in  [Expr+] }
```
Where:

- The parameters of the jaxpr are shown as two lists of variables separated by `;`. The first list contains variables introduced to stand for constants that have been hoisted out of the traced function. These are called the constvars, and in a `jax.core.ClosedJaxpr` the `consts` field holds their values. The second list, called the invars, corresponds to the inputs of the traced Python function.
- `Eqn*` is a list of equations defining intermediate variables. Each equation binds one or more variables to the result of applying a primitive to some atomic expressions (variables or literals), and may only use input variables and variables defined by earlier equations.
- `Expr+` is a list of output atomic expressions (literals or variables) of the jaxpr.
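The three parts of the grammar can be read directly off the `ClosedJaxpr` object that `make_jaxpr` returns. The sketch below traces a function that closes over a hypothetical array constant `OFFSET` and prints the constvars, invars, outvars and consts. Whether a closed-over array is hoisted into a constvar or inlined can vary between JAX versions, so treat the printed lists as illustrative; the sketch only relies on `consts` lining up one-to-one with `constvars`.

```python
import jax
import jax.numpy as jnp

# OFFSET is a hypothetical constant for illustration; JAX may hoist it into a constvar.
OFFSET = jnp.arange(3.0)

def shifted(x):
    return x + OFFSET

closed = jax.make_jaxpr(shifted)(jnp.ones(3))
inner_jaxpr = closed.jaxpr  # the underlying (open) jaxpr

print("constvars:", inner_jaxpr.constvars)  # variables standing for hoisted constants
print("invars:   ", inner_jaxpr.invars)     # one variable per traced input
print("outvars:  ", inner_jaxpr.outvars)    # the output atoms listed after `in`
print("consts:   ", closed.consts)          # values bound to the constvars
```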

```python
from jax import make_jaxpr
import jax.numpy as jnp
```

```python
def func1(first, second):
    temp = first + jnp.sin(second) * 3.
    return jnp.sum(temp)
```

```python
print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
```

```
{ lambda  ; a b.
  let c = sin b
      d = mul c 3.0
      e = add a d
      f = reduce_sum[ axes=(0,) ] e
  in (f,) }
```

Here there are no constvars; `a` and `b` are the input variables, corresponding respectively to the `first` and `second` function parameters. The scalar literal `3.0` is kept inline. The `reduce_sum` primitive has the named parameter `axes`, in addition to the operand `e`.
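The same structure is available programmatically: each entry of `closed.jaxpr.eqns` records its primitive, its named parameters, and its input and output atoms. A minimal sketch that re-traces `func1` and walks its equations; newer JAX releases may attach extra parameters beyond `axes`, so it only looks up `axes` by name.

```python
import jax
import jax.numpy as jnp

def func1(first, second):
    temp = first + jnp.sin(second) * 3.
    return jnp.sum(temp)

closed = jax.make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8))

# Print each equation: primitive name, named parameters, operands and results.
for eqn in closed.jaxpr.eqns:
    print(eqn.primitive.name, dict(eqn.params), eqn.invars, "->", eqn.outvars)

# Primitive names in program order, and the reduction's `axes` parameter.
prim_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
reduce_eqn = next(e for e in closed.jaxpr.eqns if e.primitive.name == "reduce_sum")
print(prim_names)
print(reduce_eqn.params["axes"])
```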

Note that although executing a program that calls into JAX builds a jaxpr, Python-level control flow and Python-level functions execute normally: they run during tracing, and only the JAX operations they perform are recorded. So even if a Python program contains functions and control flow, the resulting jaxpr need not contain control flow or higher-order features.

```python
print(make_jaxpr(func1)(jnp.zeros(5), jnp.ones(5)))
```

```
{ lambda  ; a b.
  let c = sin b
      d = mul c 3.0
      e = add a d
      f = reduce_sum[ axes=(0,) ] e
  in (f,) }
```

```python
def func2(inner, first, second):
    temp = first + inner(second) * 3
    return jnp.sum(temp)
```

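```python
def inner(second):
    # second.shape is known at trace time, so this `if` is decided in Python
    # and only the taken branch is recorded in the jaxpr.
    if second.shape[0] > 4:
        return jnp.sin(second)
    else:
        assert False
```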

```python
def func3(first, second):
    return func2(inner, first, second)
```

```python
print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
```

```
{ lambda  ; a b.
  let c = sin b
      d = mul c 3.0
      e = add a d
      f = reduce_sum[ axes=(0,) ] e
  in (f,) }
```
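Because `inner` branches on `second.shape[0]`, a static property of the traced value, Python resolves the `if` while tracing and only the taken branch reaches the jaxpr. The sketch below repeats the definitions above so it runs on its own, checks that the traced program contains no `cond` equation, and shows that tracing with a 3-element input runs the `else` branch at trace time, so the `AssertionError` is raised by `make_jaxpr` itself (this relies on assertions being enabled, i.e. not running under `python -O`).

```python
import jax
import jax.numpy as jnp

def func2(inner, first, second):
    temp = first + inner(second) * 3
    return jnp.sum(temp)

def inner(second):
    # Static shape check: evaluated by Python during tracing.
    if second.shape[0] > 4:
        return jnp.sin(second)
    else:
        assert False

def func3(first, second):
    return func2(inner, first, second)

closed = jax.make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8))
prim_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prim_names)  # the Python `if` leaves no control-flow primitive behind

# With 3 elements the else branch runs while tracing and the assertion fires.
try:
    jax.make_jaxpr(func3)(jnp.zeros(3), jnp.ones(3))
    traced_short = True
except AssertionError:
    traced_short = False
print("short input traced:", traced_short)
```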


```python
from jax import lax, make_jaxpr

def one_of_three(index, arg):
    # Select one of three branch functions by index and apply it to arg.
    return lax.switch(index, [lambda x: x + 1.,
                              lambda x: x - 2.,
                              lambda x: x + 3.],
                      arg)
```


```python
print(make_jaxpr(one_of_three)(1, 5.))
```

```
{ lambda  ; a b.
  let c = convert_element_type[ new_dtype=int32
                                weak_type=False ] a
      d = clamp 0 c 2
      e = cond[ branches=( { lambda  ; a.
                             let b = add a 1.0
                             in (b,) }
                           { lambda  ; a.
                             let b = sub a 2.0
                             in (b,) }
                           { lambda  ; a.
                             let b = add a 3.0
                             in (b,) } )
                linear=(False,) ] d b
  in (e,) }
```


The `cond` primitive has a number of parameters:

- `branches` are jaxprs corresponding to the branch functions. In this example, each branch takes one input variable, corresponding to `x`.
- `linear` is a tuple of booleans, used internally by the automatic-differentiation machinery, that records which operands are used linearly in the conditional.

This instance of `cond` takes two operands. The first, `d`, is the branch index; the second, `b`, is the operand (`arg`) passed to whichever jaxpr in `branches` the index selects. Before the `cond`, `lax.switch` converts the index `a` to `int32` (giving `c`) and clamps it into the range [0, 2] (giving `d`), so an out-of-range index selects the first or last branch.
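The sketch below, repeating the `one_of_three` definition so it runs on its own, pulls the `cond` equation out of the traced program, counts its branches and operands, and evaluates the function with out-of-range indices to show the clamping. Newer JAX releases may store `cond` with a slightly different parameter set (for example without `linear`), so the sketch only reads `branches`.

```python
import jax
from jax import lax

def one_of_three(index, arg):
    return lax.switch(index, [lambda x: x + 1.,
                              lambda x: x - 2.,
                              lambda x: x + 3.],
                      arg)

closed = jax.make_jaxpr(one_of_three)(1, 5.)
cond_eqn = next(e for e in closed.jaxpr.eqns if e.primitive.name == "cond")
branches = cond_eqn.params["branches"]  # one closed jaxpr per branch function
print(len(branches), "branches;", len(cond_eqn.invars), "operands")

# Out-of-range indices are clamped: -1 selects branch 0, 7 selects branch 2.
results = [float(one_of_three(i, 5.)) for i in (-1, 0, 1, 2, 7)]
print(results)  # -> [6.0, 6.0, 3.0, 8.0, 8.0]
```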

