# Using and extending `genjax.vi`

This notebook is intended as a tutorial: a gentle guide to the usage of our system on new problems, which illustrates how several parts of the system work together.

In [1]:
import jax
import genjax
from genjax import vi
from extras import beta_implicit

key = jax.random.PRNGKey(314159)

## Implementing new models and guides

In [2]:
import jax.numpy as jnp
import optax

#####################
# Model & Guide
#####################

@genjax.gen
def model():
    f = genjax.beta(0.5, 0.5) @ "latent_fairness"
    _ = genjax.tfp_bernoulli(f) @ "obs"


@genjax.gen
def guide(log_alpha, log_beta):
    beta_implicit(jnp.exp(log_alpha), jnp.exp(log_beta)) @ "latent_fairness"

With a model and guide program defined, we can construct a loss function in two ways: (a) using a standard library loss function (like `genjax.vi.elbo`) or (b) by writing our own.

### Writing our own loss

Let's write our own to get a feel for what that looks like.

First, we'll define two utility interfaces which wrap functionality exposed by `GenerativeFunction` - the model type of Gen. These are just for our own convenience, to make the code a bit nicer to look at.

In [3]:
# Sample a trace.
def sim(g: genjax.GenerativeFunction, args):
    key = adevjax.reap_key() # gain access to a fresh PRNG key
    tr = g.simulate(key, args)
    return tr

# Score constraints.
def density(g, chm, args):
    _, score = g.assess(jax.random.PRNGKey(0), chm, args)
    return score

Now, the loss function.

In [4]:
import adevjax
from genjax.typing import Tuple

def elbo(
    p: genjax.GenerativeFunction,
    q: genjax.GenerativeFunction,
    data: genjax.ChoiceMap,
):
    @adevjax.adev
    def elbo_loss(p_args: Tuple, q_args: Tuple):
        tr = sim(q, q_args)
        q_score = tr.get_score()
        observed = tr.get_choices().safe_merge(data)
        p_score = density(p, observed, p_args)
        return p_score - q_score

    return adevjax.E(elbo_loss)

Now, given some data, this object is something that we can construct an unbiased gradient estimator for.

In [5]:
#####################
# Data Generation
#####################

data = []
for _ in range(1):
    data.append(True)
for _ in range(9):
    data.append(False)

data = jnp.array(data)

print(data)

[ True False False False False False False False False False]


In [6]:
objective = elbo(model, guide, genjax.choice_map({"obs": data}))
objective

Expectation(prog=ADEVProgram(source=<function elbo.<locals>.elbo_loss at 0x168b30f70>))

In [7]:
_, q_grads = objective.grad_estimate(jax.random.PRNGKey(1), ((), (1.0, 1.0)))
q_grads

(Array(-0.17768328, dtype=float32, weak_type=True),
 Array(-0.21889801, dtype=float32, weak_type=True))

That all works, just like you'd expect it to.

### Using a standard library loss

Of course, we can also use the standard library version.

In [8]:
objective = genjax.vi.elbo(model, guide, genjax.choice_map({"obs": data}))
objective

Expectation(prog=ADEVProgram(source=<function elbo.<locals>.elbo_loss at 0x1692eba30>))

In [9]:
_, q_grads = objective.grad_estimate(jax.random.PRNGKey(1), ((), (1.0, 1.0)))
q_grads

(Array(-1.8723825, dtype=float32, weak_type=True),
 Array(1.2360288, dtype=float32, weak_type=True))

The values line up given the same `PRNGKey`.

Now, we'll use the loss as part of a training loop, where we compute gradient estimates and perform stochastic gradient ascent.

In [10]:
#####################
# SVI Setup
#####################

def svi_update(model, guide, optimizer):
    def _inner(key, data, params):
        data_chm = genjax.choice_map({"obs": data})
        objective = vi.elbo(model, guide, data_chm) # Here's our objective
        (loss, (_, params_grad)) = objective.value_and_grad_estimate(key, ((), params))
        params_grad = jax.tree_util.tree_map(lambda v: v * -1.0, params_grad)
        return params_grad, loss

    @jax.jit
    def updater(key, data, params, opt_state):
        params_grad, loss = _inner(key, data, params)
        updates, opt_state = optimizer.update(params_grad, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, loss, opt_state

    return updater

Below, we setup our parameters, and update process, to prepare for training.

In [11]:
# setup the optimizer
adam = optax.adam(5e-3)
svi_updater = svi_update(model, guide, adam)

# initialize parameters
log_alpha = jnp.log(jnp.array(10.0))
log_beta = jnp.log(jnp.array(10.0))

params = (log_alpha, log_beta)
opt_state = adam.init(params)

# warm up JIT compiler
key = jax.random.PRNGKey(2)
_ = svi_updater(key, data, params, opt_state)

We run our update process for 2000 steps.

In [12]:
#####################
# Gradient Steps
#####################
losses = []
for step in range(2000):
    key, sub_key = jax.random.split(key)
    params, loss, opt_state = svi_updater(key, data, params, opt_state)
    losses.append(loss)

Now, we can look at the trained parameters from our variational guide.

In [13]:
#####################
# Inferred parameters
#####################

log_alpha, log_beta = params
alpha, beta = jnp.exp(log_alpha), jnp.exp(log_beta)

# here we use some facts about the Beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha / (alpha + beta)
# compute inferred standard deviation
factor = beta / (alpha * (1.0 + alpha + beta))
inferred_std = inferred_mean * jnp.sqrt(factor)
print(
    "\nBased on the data and our prior belief, the fairness "
    + "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std)
)


Based on the data and our prior belief, the fairness of the coin is 0.107 +- 0.130


We can bundle this all up into a single function.

In [14]:
def run_experiment(key, data):
    def svi_update(model, guide, optimizer):
        def _inner(key, data, params):
            data_chm = genjax.choice_map({"obs": data})
            objective = vi.elbo(model, guide, data_chm) # Here's our objective
            (loss, (_, params_grad)) = objective.value_and_grad_estimate(key, ((), params))
            params_grad = jax.tree_util.tree_map(lambda v: v * -1.0, params_grad)
            return params_grad, loss
    
        @jax.jit
        def updater(key, data, params, opt_state):
            params_grad, loss = _inner(key, data, params)
            updates, opt_state = optimizer.update(params_grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, loss, opt_state
    
        return updater

    # setup the optimizer
    adam = optax.adam(5e-3)
    svi_updater = svi_update(model, guide, adam)
    
    # initialize parameters
    log_alpha = jnp.log(jnp.array(10.0))
    log_beta = jnp.log(jnp.array(10.0))
    
    params = (log_alpha, log_beta)
    opt_state = adam.init(params)
    
    # warm up JIT compiler
    key = jax.random.PRNGKey(0)
    _ = svi_updater(key, data, params, opt_state)

    losses = []
    for step in range(2000):
        key, sub_key = jax.random.split(key)
        params, loss, opt_state = svi_updater(key, data, params, opt_state)
        losses.append(loss)

    log_alpha, log_beta = params
    alpha, beta = jnp.exp(log_alpha), jnp.exp(log_beta)
    
    # here we use some facts about the Beta distribution
    # compute the inferred mean of the coin's fairness
    inferred_mean = alpha / (alpha + beta)
    # compute inferred standard deviation
    factor = beta / (alpha * (1.0 + alpha + beta))
    inferred_std = inferred_mean * jnp.sqrt(factor)
    print(
        "\nBased on the data and our prior belief, the fairness "
        + "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std)
    )

Play around with this cell, and see if the inferences make sense!

In [15]:
data = []
for _ in range(10):
    data.append(True)
for _ in range(0):
    data.append(False)

data = jnp.array(data)

run_experiment(key, data)


Based on the data and our prior belief, the fairness of the coin is 0.835 +- 0.184
