In [1]:
import jax
import jax.numpy as jnp
import dataclasses

In [2]:
def some_numerical_function(x):
    """Return tanh(x) shifted up by 3.0 (works elementwise on arrays)."""
    squashed = jnp.tanh(x)
    return squashed + 3.0

In [3]:
# Display the three core JAX transformations (the output shows their signatures).
jax.jit, jax.vmap, jax.grad

(<function jax._src.api.jit(fun: 'Callable', *, static_argnums: 'Union[int, Iterable[int], None]' = None, static_argnames: 'Union[str, Iterable[str], None]' = None, device: 'Optional[xc.Device]' = None, backend: 'Optional[str]' = None, donate_argnums: 'Union[int, Iterable[int]]' = (), inline: 'bool' = False, keep_unused: 'bool' = False, abstracted_axes: 'Optional[Any]' = None) -> 'stages.Wrapped'>,
 <function jax._src.api.vmap(fun: 'F', in_axes: 'Union[int, Sequence[Any]]' = 0, out_axes: 'Any' = 0, axis_name: 'Optional[Hashable]' = None, axis_size: 'Optional[int]' = None, spmd_axis_name: 'Optional[Hashable]' = None) -> 'F'>,
 <function jax._src.api.grad(fun: 'Callable', argnums: 'Union[int, Sequence[int]]' = 0, has_aux: 'bool' = False, holomorphic: 'bool' = False, allow_int: 'bool' = False, reduce_axes: 'Sequence[AxisName]' = ()) -> 'Callable'>)

In [4]:
# make_jaxpr returns a wrapper; calling it traces the function and displays its jaxpr.
fn = jax.make_jaxpr(some_numerical_function)
fn(5.0)

{ lambda ; a:f32[]. let b:f32[] = tanh a; c:f32[] = add b 3.0 in (c,) }

In [5]:
# The same function wrapped with jax.jit; the result is a weakly-typed float32 scalar.
fn = jax.jit(some_numerical_function)
fn(5.0)

Array(3.9999092, dtype=float32, weak_type=True)

In [6]:
# Compose transformations: grad (derivative), vmap (batch over the leading axis), jit (compile).
batched_grad = jax.vmap(jax.grad(some_numerical_function))
fn = jax.jit(batched_grad)
fn(jnp.array([3.0, 5.0]))

Array([0.0098661 , 0.00018167], dtype=float32)

In [7]:
# The jaxpr of the derivative: the sub/mul/add_any ops below compute 1 - tanh(a)**2.
grad_fn = jax.grad(some_numerical_function)
jaxpr = jax.make_jaxpr(grad_fn)(5.0)
jaxpr

{ lambda ; a:f32[]. let
    b:f32[] = tanh a
    c:f32[] = sub 1.0 b
    _:f32[] = add b 3.0
    d:f32[] = mul 1.0 c
    e:f32[] = mul d b
    f:f32[] = add_any d e
  in (f,) }

## Pytrees

In [8]:
import genjax
from typing import Any


@dataclasses.dataclass
class SomeFoo(genjax.Pytree):
    """Minimal two-field pytree used to demonstrate flattening and jit.

    `y` may hold another SomeFoo (as in the nested example in this notebook) or None.
    """

    x: float
    y: Any

    def flatten(self):
        """Return both fields as the first tuple and an empty tuple as the second.

        Flattening a nested SomeFoo yields the `x` leaves of every level; the
        empty second tuple presumably means "no static/aux data" — confirm
        against genjax.Pytree.
        """
        children = (self.x, self.y)
        return children, ()


def some_numerical_function(foo):
    """Return tanh(foo.x) + 3.0 for any object with an `x` attribute (e.g. SomeFoo).

    NOTE(review): this redefines the scalar `some_numerical_function` from the
    top of the notebook; every cell after this one sees only this version.
    """

    def _read_x():
        # Closure over `foo`; presumably kept to show that nested Python
        # functions trace fine under jit.
        return foo.x

    value = _read_x()
    return jnp.tanh(value) + 3.0


# Nest one SomeFoo inside another; flattening yields the float leaves of both levels.
inner_foo = SomeFoo(5.0, None)
foo = SomeFoo(5.0, inner_foo)
leaves, form = jax.tree_util.tree_flatten(foo)  # `form` is the tree structure
leaves

[5.0, 5.0]

In [9]:
# A SomeFoo instance is a pytree, so it can be passed directly into a jitted function.
jitted_fn = jax.jit(some_numerical_function)
jitted_fn(SomeFoo(5.0, None))

Array(3.9999092, dtype=float32, weak_type=True)

In [10]:
# Importing Jax functions useful for tracing/interpreting.
import numpy as np
from functools import wraps

from jax import core
from jax import lax
from jax._src.util import safe_map


def interp(fn):
    """Wrap `fn` in a minimal jaxpr interpreter.

    Calling the returned function traces `fn` to a jaxpr with
    `jax.make_jaxpr`, then re-evaluates that jaxpr equation by equation by
    calling `bind` on each primitive with the actual argument values.

    Args:
        fn: The function to interpret. Its positional arguments may be
            pytrees; they are flattened to leaves before being bound to the
            jaxpr's input variables.

    Returns:
        A function taking the same positional arguments as `fn` that returns
        the jaxpr's outputs as a flat list (one entry per output variable),
        not the pytree structure `fn` itself returns.
    """
    # Literals are constants baked directly into the jaxpr. Fall back for JAX
    # versions where `jax.core` no longer exports `Literal`.
    literal_type = getattr(core, "Literal", None)
    if literal_type is None:
        from jax.extend.core import Literal as literal_type

    def eval_jaxpr(jaxpr, consts, *args):
        """Evaluate `jaxpr` on flat `args` plus its `consts`; return a list of outputs."""
        # Mapping from variable -> value
        env = {}

        def read(var):
            # Literals carry their own value; everything else was written to `env`.
            if isinstance(var, literal_type):
                return var.val
            return env[var]

        def write(var, val):
            env[var] = val

        # Bind args and consts to the environment (safe_map raises on a length mismatch).
        safe_map(write, jaxpr.invars, args)
        safe_map(write, jaxpr.constvars, consts)

        # Loop through equations and evaluate primitives using `bind`.
        for eqn in jaxpr.eqns:
            invals = safe_map(read, eqn.invars)
            outvals = eqn.primitive.bind(*invals, **eqn.params)
            # Normalize single-result primitives to a list so they zip with outvars.
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
            safe_map(write, eqn.outvars, outvals)
        # Read the final result of the jaxpr from the environment.
        return safe_map(read, jaxpr.outvars)

    @wraps(fn)
    def _inner(*args):
        closed_jaxpr = jax.make_jaxpr(fn)(*args)
        # make_jaxpr creates one invar per pytree leaf, so the runtime
        # arguments must be flattened the same way before binding.
        flat_args = jax.tree_util.tree_leaves(args)
        return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *flat_args)

    return _inner

In [11]:
# The interpreted function can itself be jit-compiled; outputs come back as a list.
add_five = interp(lambda x: x + 5.0)
jax.jit(add_five)(5.0)

[Array(10., dtype=float32, weak_type=True)]

In [12]:
@genjax.gen
def model():
    """Generative model with one standard-normal choice at address "x"; returns that draw."""
    return genjax.Normal(0.0, 1.0) @ "x"


SEED = 314159
key = jax.random.PRNGKey(SEED)
# simulate returns a key together with the trace (displayed below).
key, tr = model.simulate(key, ())
tr

BuiltinTrace(gen_fn=BuiltinGenerativeFunction(source=<function model at 0x16b0c42c0>), args=(), retval=Array(-0.10823099, dtype=float32), choices=Trie(inner={'x': DistributionTrace(gen_fn=_Normal(), args=(0.0, 1.0), value=Array(-0.10823099, dtype=float32), score=Array(-0.9247955, dtype=float32))}), cache=Trie(inner={}), score=Array(-0.9247955, dtype=float32))

In [18]:
def fn(index):
    """Write 3.0 into slot `index` of a length-5 vector of ones and read that slot back."""
    updated = jnp.ones(5).at[index].set(3.0)
    return updated[index]


# Run the functional index update/read under jit with index 3.
jitted_index_fn = jax.jit(fn)
jitted_index_fn(3)

Array(3., dtype=float32)

In [5]:
import jax
import jax.numpy as jnp
import genjax

# NOTE(review): presumably enables genjax's pretty-printing for displayed objects — confirm.
# The returned `console` is not referenced by any visible cell.
console = genjax.pretty()


@genjax.gen
def add_normal_noise(x):
    """Add two independent standard-normal draws to `x`.

    The draws are traced at addresses "noise1" and "noise2".
    NOTE(review): this function is redefined verbatim in a later cell; keep one copy.

    Returns:
        `x + noise1 + noise2`.
    """
    noise1 = genjax.trace("noise1", genjax.Normal)(0.0, 1.0)
    noise2 = genjax.trace("noise2", genjax.Normal)(0.0, 1.0)
    # Do not return `key`: it is not a parameter, so it would silently capture
    # the module-level PRNG key, and it is never used for sampling here.
    return x + noise1 + noise2


@genjax.gen
def my_map():
    """Map `add_normal_noise` over a vector of 100 ones and return the result.

    The mapped call is traced at address "map". `in_axes=(0,)` presumably maps
    over axis 0 of the single argument, as in `jax.vmap` — confirm against genjax.Map.
    NOTE(review): this function is redefined verbatim in a later cell; keep one copy.
    """
    mapped = genjax.Map(add_normal_noise, in_axes=(0,))
    arr = jnp.ones(100)
    # Return the traced result instead of discarding it, so the trace's retval is meaningful.
    return mapped(arr) @ "map"


key = jax.random.PRNGKey(314159)
# NOTE(review): a later copy of this cell binds `tr = genjax.simulate(...)` without
# unpacking a key; confirm which return convention this genjax version uses.
simulate_my_map = genjax.simulate(my_map)
key, tr = simulate_my_map(key, ())
tr

In [7]:
import jax
import jax.numpy as jnp
import genjax

# NOTE(review): repeated from an earlier cell; presumably enables genjax's pretty-printing
# for displayed objects — confirm. The returned `console` is not referenced by any visible cell.
console = genjax.pretty()


@genjax.gen
def add_normal_noise(x):
    """Add two independent standard-normal draws to `x`.

    The draws are traced at addresses "noise1" and "noise2".
    NOTE(review): duplicate of an earlier definition in this notebook; keep one copy.

    Returns:
        `x + noise1 + noise2`.
    """
    noise1 = genjax.trace("noise1", genjax.Normal)(0.0, 1.0)
    noise2 = genjax.trace("noise2", genjax.Normal)(0.0, 1.0)
    # Do not return `key`: it is not a parameter, so it would silently capture
    # the module-level PRNG key, and it is never used for sampling here.
    return x + noise1 + noise2


@genjax.gen
def my_map():
    """Map `add_normal_noise` over a vector of 100 ones and return the result.

    The mapped call is traced at address "map". `in_axes=(0,)` presumably maps
    over axis 0 of the single argument, as in `jax.vmap` — confirm against genjax.Map.
    NOTE(review): duplicate of an earlier definition in this notebook; keep one copy.
    """
    mapped = genjax.Map(add_normal_noise, in_axes=(0,))
    arr = jnp.ones(100)
    # Return the traced result instead of discarding it, so the trace's retval is meaningful.
    return mapped(arr) @ "map"


key = jax.random.PRNGKey(314159)
# NOTE(review): an earlier copy of this cell unpacks `key, tr = ...`; confirm which
# return convention this genjax version's simulate uses.
simulate_my_map = genjax.simulate(my_map)
tr = simulate_my_map(key, ())