In [None]:
# Ask the inline matplotlib backend to emit figures as SVG.
%config InlineBackend.figure_format = 'svg'

In [None]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import genjax
from genjax import GenerativeFunction, ChoiceMap, Selection, trace

# Plain white seaborn theme for every figure drawn below.
sns.set_theme(style="white")

# Pretty printing: keep the handle returned by genjax.pretty, configured
# for an 80-column display.
console = genjax.pretty(width=80)

# Reproducibility: a single root PRNG key. Later cells split a fresh
# sub-key off `key` before each random operation.
key = jax.random.PRNGKey(314159)

In [None]:
@genjax.gen
def g(x):
    """Make one `genjax.bernoulli` choice at address "m0" and return it.

    `x` is forwarded unchanged as the sole argument of the distribution.
    The address is attached with the explicit ("unsweetened")
    `genjax.trace(address, gen_fn)(*args)` form.
    """
    return genjax.trace("m0", genjax.bernoulli)(x)  # unsweetened


@genjax.gen
def h(x):
    """Invoke `g` as a sub-call at address "sub" and return its value.

    Uses the "sweetened" syntax, where `callee(args) @ address` attaches
    the address to the call.
    """
    return g(x) @ "sub"  # sweetened

In [None]:
# Split so that the carried `key` is advanced and only `sub_key` is consumed.
key, sub_key = jax.random.split(key)
# Run `h` forward with the argument tuple (0.3,) and keep the resulting trace.
tr = genjax.simulate(h)(sub_key, (0.3,))
print(tr.retval)
# Displayed cell value: exp of the trace score.
# NOTE(review): presumably the score is a log-density, so this is a
# probability — confirm against the GenJAX version in use.
np.exp(tr.score)


In [None]:
# Trace `genjax.simulate(h)` into a jaxpr so the lowered program can be inspected.
# NOTE(review): this passes `key` itself instead of a freshly split sub-key,
# unlike the other cells. make_jaxpr only traces the callable, so presumably
# no randomness is consumed here — confirm before copying this pattern.
jaxpr = jax.make_jaxpr(genjax.simulate(h))(key, (0.3,))
jaxpr

In [None]:
@genjax.gen
def h(x):
    """Make two `genjax.bernoulli(x)` choices and return their sum.

    The first choice lives at address "m0" and the second at address "m1".
    Note that this rebinds the name `h`, replacing the definition from the
    earlier cell that delegated to `g`.
    """
    choice_at_m0 = genjax.bernoulli(x) @ "m0"
    choice_at_m1 = genjax.bernoulli(x) @ "m1"
    return choice_at_m0 + choice_at_m1


# Fresh sub-key for this simulation; `key` is advanced for later cells.
key, sub_key = jax.random.split(key)
tr = genjax.simulate(h)(sub_key, (0.3,))
# Build a selection for address "m1" and use it to filter the trace's choices.
selection = genjax.select("m1")
selected = tr.get_choices().filter(selection)
# NOTE(review): the cell displays `selection` (the query), not `selected`
# (the filtered choices) — confirm which one was meant to be shown.
selection

In [7]:
# Two branches for a branching submodel.
@genjax.gen
def model_y(x, coefficients):
    """Low-noise branch: sample around a quadratic polynomial in `x`.

    The polynomial is the sum of `[1, x, x**2] * coefficients`. A
    `genjax.tfp_normal` choice with that value as first argument and 0.3 as
    second argument is made at address "value" and returned.
    """
    monomials = jnp.array([1.0, x, x**2])
    center = jnp.sum(monomials * coefficients)
    return genjax.tfp_normal(center, 0.3) @ "value"


@genjax.gen
def outlier_model(x, coefficients):
    """High-noise branch: same quadratic polynomial as `model_y`, wider spread.

    The polynomial is the sum of `[1, x, x**2] * coefficients`. A
    `genjax.tfp_normal` choice with that value as first argument and 30.0 as
    second argument is made at address "value" and returned.
    """
    monomials = jnp.array([1.0, x, x**2])
    center = jnp.sum(monomials * coefficients)
    return genjax.tfp_normal(center, 30.0) @ "value"


# The branching submodel: a Switch over the two branch models. It is called
# below as switch(index, x, coefficients).
# NOTE(review): presumably index 0 selects `model_y` and index 1 selects
# `outlier_model`, following argument order — confirm.
switch = genjax.Switch(model_y, outlier_model)

# A mapped kernel function which calls the branching submodel.
@genjax.gen(genjax.Map, in_axes=(0, None))
def kernel(x, coefficients):
    """Per-point kernel: flip an outlier flag, then sample from a branch.

    Wrapped with `genjax.Map` using `in_axes=(0, None)`. A
    `genjax.bernoulli(0.1)` choice is made at address "outlier"; it is cast
    to an integer and passed to `switch` as the branch index, and the
    switch call is addressed at "y". Returns the value from the switch call.
    """
    outlier_flag = genjax.bernoulli(0.1) @ "outlier"
    # The flag is converted to an integer before it is used as the index.
    branch_index = jnp.asarray(outlier_flag, dtype=int)
    return switch(branch_index, x, coefficients) @ "y"


@genjax.gen
def model(xs):
    """Full model: shared polynomial coefficients plus a mapped kernel.

    A `genjax.mv_normal` choice with a length-3 zero vector and twice the
    3x3 identity matrix as arguments is made at address "alpha". Those
    coefficients and `xs` are passed to `kernel` at address "ys", and the
    kernel's result is returned.
    """
    zero_vector = np.zeros(3)
    doubled_identity = 2.0 * np.identity(3)
    coefficients = genjax.mv_normal(zero_vector, doubled_identity) @ "alpha"
    return kernel(xs, coefficients) @ "ys"


In [8]:
# Inputs 0.0, 0.5, ..., 9.5 (20 points).
data = jnp.arange(0, 10, 0.5)
# Fresh sub-key for this simulation; `key` is advanced for later cells.
key, sub_key = jax.random.split(key)
# JIT-compile the model's simulate method and run it on the inputs.
tr = jax.jit(model.simulate)(sub_key, (data,))
# Displayed cell value: the result of calling `strip()` on the trace.
tr.strip()




├── [1m:ys[0m
│   └── [1m(Vector)[0m
│       ├── [1m:outlier[0m
│       │   └──  bool[20]
│       └── [1m:y[0m
│           └── [1m(Switch, i32[20])[0m
│               ├── 
│               │   └── [1m:value[0m
│               │       └──  f32[20]
│               └── 
│                   └── [1m:value[0m
│                       └──  f32[20]
└── [1m:alpha[0m
    └──  f32[3]

In [9]:
# Show the return value recorded in the simulated trace.
tr.get_retval()


[1;35mArray[0m[1m([0m[1m[[0m [1;36m-2.9941075[0m ,  [1;36m-2.7835548[0m ,  [1;36m-1.1057514[0m ,   [1;36m0.32127208[0m,
       [1;36m-50.07349[0m   ,   [1;36m5.217543[0m  ,   [1;36m8.73091[0m   ,  [1;36m11.922469[0m  ,
        [1;36m16.383581[0m  ,  [1;36m21.863268[0m  ,  [1;36m27.054882[0m  ,  [1;36m32.550755[0m  ,
        [1;36m38.68741[0m   ,  [1;36m45.95885[0m   ,  [1;36m52.816547[0m  ,  [1;36m60.52383[0m   ,
        [1;36m69.170555[0m  ,  [1;36m78.65796[0m   ,  [1;36m87.955605[0m  ,  [1;36m97.73689[0m   [1m][0m,      [33mdtype[0m=[35mfloat32[0m[1m)[0m

In [12]:
# Pull the choices out of the trace and index into them by address path.
chm = tr.get_choices()
# Other address paths worth inspecting here: chm["ys", "y", "value"] (one
# level deeper) and chm["ys"] (one level shallower).
values = chm["ys", "y"]

values




└── [1m(Switch, i32[20])[0m
    ├── 
    │   └── [1mBuiltinTrace[0m
    │       ├── gen_fn
    │       │   └── [1mBuiltinGenerativeFunction[0m
    │       │       └── source
    │       │           └── <function model_y>
    │       ├── args
    │       │   └── [1mtuple[0m
    │       │       ├──  f32[20]
    │       │       └──  f32[20,3]
    │       ├── retval
    │       │   └──  f32[20]
    │       ├── choices
    │       │   └── [1mTrie[0m
    │       │       └── [1m:value[0m
    │       │           └── [1mDistributionTrace[0m
    │       │               ├── gen_fn
    │       │               │   └── [1mTFPDistribution[0m
    │       │               │       └── distribution
    │       │               │           └── (const) <class 
    │       │               │               'tensorflow_probability.substrates.jax.distributions.normal.Normal'>
    │       │               ├── args
    │       │               │   └── [1mtuple[0m
    │       │               │    

In [13]:
# Nine hand-picked inputs and noiseless quadratic targets x**2 + 2x + 1.5.
x = np.array([0.3, 0.7, 1.1, 1.4, 2.3, 2.5, 3.0, 4.0, 5.0])
y = 2.0 * x + 1.5 + x**2
# Overwrite the third target in place with 50.0, far from the quadratic's
# value at x = 1.1 (about 4.91), so the data contains one gross outlier.
y[2] = 50.0

# Constrain the ("y", "value") choice under the "ys" address to the targets.
# The inner choice map is wrapped in a vector choice map because "ys" is
# produced by the Map-wrapped `kernel`.
observations = genjax.choice_map(
    {"ys": genjax.vector_choice_map(genjax.choice_map({("y", "value"): y}))}
)
# Fresh sub-key for the importance call; `key` is advanced for later cells.
key, sub_key = jax.random.split(key)
# NOTE(review): `w` is presumably the importance weight and `tr` the
# constrained trace — confirm against the GenJAX version in use.
(w, tr) = model.importance(sub_key, observations, (x,))