# Using and extending `genjax.vi`

This notebook shows how to use our library to solve a new inference task, beyond those considered in our experiments. It is intended to illustrate the usage of the library, but assumes some knowledge of variational inference.

We use a very simple toy example: inferring the likely bias of a coin from multiple observed coin flips. The inference problem comes from [Pyro's SVI Part I tutorial](https://pyro.ai/examples/svi_part_i.html#A-simple-example); Pyro is another PPL that we compare to in our experiments.

Variational inference is overkill for solving this problem—in fact, the exact Bayesian posterior is analytically tractable. But it will serve to demonstrate the basic features of the library. The more complex examples in our experiments follow the same basic structure, and so understanding the code in this notebook should make it possible to understand (and modify) those more complex examples.
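To make the "analytically tractable" remark concrete, here is a minimal sketch of the exact conjugate update in plain Python. It assumes a `Beta(10, 10)` prior and i.i.d. Bernoulli flips given the bias (as in the Pyro tutorial), with six heads and four tails; the helper names `beta_posterior` and `beta_mean_std` are illustrative and not part of the library.

```python
import math

def beta_posterior(prior_a, prior_b, flips):
    """Conjugate Beta-Bernoulli update: returns the posterior (alpha, beta)."""
    heads = sum(1 for f in flips if f)
    tails = len(flips) - heads
    return prior_a + heads, prior_b + tails

def beta_mean_std(a, b):
    """Mean and standard deviation of a Beta(a, b) distribution."""
    mean = a / (a + b)
    var = a * b / ((a + b) ** 2 * (a + b + 1.0))
    return mean, math.sqrt(var)

# Assumed data: six heads (True) and four tails (False).
flips = [True] * 6 + [False] * 4
post_a, post_b = beta_posterior(10.0, 10.0, flips)
post_mean, post_std = beta_mean_std(post_a, post_b)
print(f"exact posterior: Beta({post_a}, {post_b}), "
      f"mean {post_mean:.3f}, std {post_std:.3f}")
```

A closed-form answer like this is a useful reference point when checking whether a variational fit has converged.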

In [1]:
import jax
import genjax
import jax.numpy as jnp
import optax
from genjax import vi
from extras import beta_implicit, sim, density

key = jax.random.PRNGKey(314159)

## Implementing new models and guides

We begin by defining a model: a joint distribution over the unknown bias of the coin and the observed sequence of flips.

Models and variational families (guides) in our system are probabilistic programs. We write these using a modeling language which can be accessed via the `genjax.gen` decorator. 

In the code, random choices can be made using the syntax `dist(args) @ "choice_name"`, where `"choice_name"` is a unique name for the random variable being sampled. In the code below, our model defines a distribution over two random variables, and the variational family, or guide, defines a distribution over only one random variable.
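For readers unfamiliar with this use of `@`, the following toy sketch shows how Python's `__matmul__` hook can support a `dist(args) @ "choice_name"` syntax that samples a value and records it under a name. This is only an illustration of the idea: `ToyTrace` and `ToyChoice` are hypothetical classes, and nothing here reflects how GenJAX is actually implemented.

```python
import random

class ToyTrace:
    """Hypothetical record of named random choices (not a GenJAX type)."""

    def __init__(self, seed):
        self.rng = random.Random(seed)
        self.choices = {}

    def beta(self, a, b):
        return ToyChoice(self, lambda: self.rng.betavariate(a, b))

    def flip(self, p):
        return ToyChoice(self, lambda: self.rng.random() < p)

class ToyChoice:
    """A pending random choice; `choice @ "name"` samples and records it."""

    def __init__(self, trace, sampler):
        self.trace = trace
        self.sampler = sampler

    def __matmul__(self, name):
        # Names must be unique so each random variable can be addressed.
        if name in self.trace.choices:
            raise ValueError(f"choice name {name!r} is already used")
        value = self.sampler()
        self.trace.choices[name] = value
        return value

def toy_model(t):
    f = t.beta(10.0, 10.0) @ "latent_fairness"
    _ = t.flip(f) @ "obs"

trace = ToyTrace(seed=0)
toy_model(trace)
print(trace.choices)
```

The recorded dictionary plays the role of a map from choice names to sampled values, which is what makes it possible to address individual random variables by name.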

Although we don't show it here, deterministic (JAX traceable) code can be freely interwoven between random variable statements.

In [2]:
@genjax.gen
def model():
    f = genjax.tfp_beta(10.0, 10.0) @ "latent_fairness"
    _ = genjax.tfp_flip(f) @ "obs"
    
model

BuiltinGenerativeFunction(source=<function model at 0x15f587490>)
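As a reference for what "joint distribution" means here, the sketch below writes out the log joint density of the model by hand in plain Python. It assumes the flips are i.i.d. Bernoulli draws given the fairness; `log_beta_pdf` and `log_joint` are illustrative helpers, not library functions.

```python
import math

def log_beta_pdf(x, a, b):
    """Log density of Beta(a, b) at x in (0, 1)."""
    log_norm = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return (a - 1.0) * math.log(x) + (b - 1.0) * math.log1p(-x) - log_norm

def log_joint(fairness, flips, prior_a=10.0, prior_b=10.0):
    """log p(fairness, flips) = log prior + sum of Bernoulli log likelihoods."""
    log_prior = log_beta_pdf(fairness, prior_a, prior_b)
    log_lik = sum(
        math.log(fairness) if y else math.log1p(-fairness) for y in flips
    )
    return log_prior + log_lik

flips = [True] * 6 + [False] * 4
for f in (0.3, 0.5, 0.7):
    print(f, log_joint(f, flips))
```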

We then define a variational family, a parametric representation of the posterior distribution. Our goal will be to infer which parameters alpha and beta make this variational family as close as possible to the true posterior.


In [3]:
@genjax.gen
def guide(alpha, beta):
    beta_implicit(alpha, beta) @ "latent_fairness"

guide

BuiltinGenerativeFunction(source=<function guide at 0x15f587e20>)

To find the optimal parameters, we need to define a loss function that measures how well the guide matches the posterior, for any given alpha and beta. The standard choice of loss in variational inference is the ELBO, or evidence lower bound. 
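For reference, the ELBO for this problem can be written as follows, using $f$ for the latent fairness, $y$ for the observed flips, $p$ for the model, and $q$ for the guide (this notation is introduced here for illustration). The second equality is the standard identity that explains why maximizing the ELBO brings the guide closer to the posterior.

```latex
\mathrm{ELBO}(\alpha, \beta)
  = \mathbb{E}_{f \sim q(\cdot\,;\,\alpha, \beta)}
      \big[ \log p(f, y) - \log q(f; \alpha, \beta) \big]
  = \log p(y) - \mathrm{KL}\big( q(\cdot\,;\,\alpha, \beta) \,\|\, p(\cdot \mid y) \big)
  \le \log p(y)
```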

Now, we can construct this loss function in two ways: (a) using a standard library loss function (like `genjax.vi.elbo`) or (b) by writing our own version.

### Writing our own loss

Let's write our own version to get a feel for what that looks like in our system.

In [4]:
import adevjax
from genjax.typing import Tuple

def elbo(
    p: genjax.GenerativeFunction,
    q: genjax.GenerativeFunction,
    data: genjax.ChoiceMap,
):

    @adevjax.adev
    def elbo_loss(p_args: Tuple, q_args: Tuple):
        x, log_q = sim(q, q_args)
        x_y = x.safe_merge(data)
        log_p = density(p, x_y, p_args)
        return log_p - log_q

    return adevjax.E(elbo_loss)

A loss in GenJAX will generally be the expected value (`adevjax.E`) of a probabilistic process -- in this case, the process of simulating from Q and computing log P/Q.
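The "simulate from Q, then compute log P/Q" process can be sketched without the library as a plain Monte Carlo average. This is a minimal illustration in standard-library Python, assuming a `Beta(10, 10)` prior, i.i.d. Bernoulli flips, and a Beta guide; the function names are hypothetical, and the sketch estimates only the value of the objective, not its gradient.

```python
import math
import random

def log_beta_fn(a, b):
    """Log of the Beta function B(a, b)."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def log_beta_pdf(x, a, b):
    return (a - 1.0) * math.log(x) + (b - 1.0) * math.log1p(-x) - log_beta_fn(a, b)

def log_joint(f, heads, tails, prior_a=10.0, prior_b=10.0):
    return (log_beta_pdf(f, prior_a, prior_b)
            + heads * math.log(f) + tails * math.log1p(-f))

def elbo_estimate(rng, alpha, beta, heads, tails, num_samples=1000):
    """Average of log p(f, y) - log q(f) over samples f ~ q = Beta(alpha, beta)."""
    total = 0.0
    for _ in range(num_samples):
        f = rng.betavariate(alpha, beta)        # simulate from Q
        log_q = log_beta_pdf(f, alpha, beta)    # score the sample under Q
        log_p = log_joint(f, heads, tails)      # score sample + data under P
        total += log_p - log_q
    return total / num_samples

heads, tails = 6, 4
# Exact log evidence for the conjugate model, for comparison.
log_evidence = log_beta_fn(10.0 + heads, 10.0 + tails) - log_beta_fn(10.0, 10.0)
rng = random.Random(0)
print("ELBO with q = Beta(2, 2):  ", elbo_estimate(rng, 2.0, 2.0, heads, tails))
print("ELBO with q = Beta(16, 14):", elbo_estimate(rng, 16.0, 14.0, heads, tails))
print("log evidence:              ", log_evidence)
```

When the guide equals the exact posterior, every sample contributes exactly the log evidence, so the bound is tight; a poorly matched guide gives a strictly smaller value in expectation.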

To write the loss, we use the language of [ADEV, a new type of AD algorithm](https://arxiv.org/pdf/2212.06386.pdf). ADEV automates the construction of unbiased gradient estimators for us.

Now, given a `p`, a `q`, and some `data`, this function will return an `Expectation`: an expected-value objective for which we want unbiased gradient estimators.

Let's construct some `data` to use:

In [5]:
#####################
# Data Generation
#####################

data = []
for _ in range(6):
    data.append(True)
for _ in range(4):
    data.append(False)

data = jnp.array(data)

print(data)

[ True  True  True  True  True  True False False False False]


Now, we can build our objective:

In [6]:
objective = elbo(model, guide, genjax.choice_map({"obs": data}))
objective

Expectation(prog=ADEVProgram(source=<function elbo.<locals>.elbo_loss at 0x15f5cca60>))

And we can even construct and sample from a gradient estimator for the objective:

In [7]:
key, sub_key = jax.random.split(key)
_, q_grads = objective.grad_estimate(sub_key, ((), (1.0, 1.0)))
q_grads

(Array(-13.78883, dtype=float32, weak_type=True),
 Array(25.818943, dtype=float32, weak_type=True))

The `objective.grad_estimate` method takes arguments `(key: PRNGKey, loss_args: Tuple)` and returns an unbiased estimate of the gradient of our objective. 

We can use these gradient estimates for stochastic optimization of the guide's parameters (see below).
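To illustrate what "unbiased estimate of the gradient" means, here is a small self-contained sketch. It deliberately uses a different, simpler setting than the notebook: a reparameterization estimator for the gradient of $\mathbb{E}_{x \sim \mathcal{N}(\theta, 1)}[x^2]$, whose exact value is $2\theta$. It does not reproduce the strategy used by `beta_implicit` or by ADEV; the function name is hypothetical.

```python
import random

def grad_estimate(rng, theta):
    """One-sample estimate of d/dtheta E_{x ~ Normal(theta, 1)}[x^2]."""
    eps = rng.gauss(0.0, 1.0)
    x = theta + eps      # x ~ Normal(theta, 1), written as a function of theta
    return 2.0 * x       # d(x^2)/dtheta = 2x * dx/dtheta = 2x

rng = random.Random(0)
theta = 1.5
samples = [grad_estimate(rng, theta) for _ in range(20000)]
mean_grad = sum(samples) / len(samples)

# Individual estimates are noisy, but their average approaches 2 * theta.
print("single estimate:", samples[0])
print("average of 20000 estimates:", mean_grad)
print("exact gradient:", 2.0 * theta)
```

Each individual estimate is noisy, which is why two calls with different keys can return quite different values, but the noise averages out over many optimization steps.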

### Using a standard library loss

Of course, we can also use the standard library version.

In [8]:
objective = genjax.vi.elbo(model, guide, genjax.choice_map({"obs": data}))
objective

Expectation(prog=ADEVProgram(source=<function elbo.<locals>.elbo_loss at 0x15ffbf520>))

In [9]:
key, sub_key = jax.random.split(key)
_, q_grads = objective.grad_estimate(sub_key, ((), (1.0, 1.0)))
q_grads

(Array(7.713711, dtype=float32, weak_type=True),
 Array(-7.2344327, dtype=float32, weak_type=True))

Now, we'll use the loss as part of a training loop, where we compute gradient estimates and perform stochastic gradient ascent.

In [10]:
#####################
# Parameter updater
#####################

def svi_update(model, guide, optimizer):
    def _inner(key, data, params):
        data_chm = genjax.choice_map({"obs": data})
        objective = vi.elbo(model, guide, data_chm) # Here's our objective
        (loss, (_, params_grad)) = objective.value_and_grad_estimate(key, ((), params))
        params_grad = jax.tree_util.tree_map(lambda v: v * -1.0, params_grad)
        return params_grad, loss

    @jax.jit
    def updater(key, data, params, opt_state):
        params_grad, loss = _inner(key, data, params)
        updates, opt_state = optimizer.update(params_grad, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, loss, opt_state

    return updater
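The sign flip on the gradients in `_inner` is what turns a descent-style optimizer into gradient ascent on the objective. The toy sketch below shows the same pattern in plain Python, with a hand-written descent step standing in for the optimizer; `objective_grad` and `descent_step` are illustrative names only.

```python
def objective_grad(x):
    """Gradient of the toy objective f(x) = -(x - 3)^2, maximized at x = 3."""
    return -2.0 * (x - 3.0)

def descent_step(x, grad, lr=0.1):
    """A plain gradient-descent update, standing in for an optimizer."""
    return x - lr * grad

x = 0.0
for _ in range(200):
    g = objective_grad(x)
    # Negating the gradient makes the descent step climb the objective.
    x = descent_step(x, -g)

print(x)  # approaches the maximizer x = 3
```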

Below, we set up our parameters and the update process to prepare for training.

In [11]:
# setup the optimizer
adam = optax.adam(5e-4)
svi_updater = svi_update(model, guide, adam)

# initialize parameters
alpha = jnp.array(2.0)
beta = jnp.array(2.0)
params = (alpha, beta)
opt_state = adam.init(params)

# warm up JIT compiler
key = jax.random.PRNGKey(2)
_ = svi_updater(key, data, params, opt_state)

We run our update process for 5000 steps.

In [12]:
#####################
# Run gradient steps
#####################

for step in range(5000):
    key, sub_key = jax.random.split(key)
    # Use the fresh sub_key for this step; keep `key` only for further splitting.
    params, loss, opt_state = svi_updater(sub_key, data, params, opt_state)

Now, we can look at the trained parameters from our variational guide.

In [13]:
#####################
# Inferred parameters
#####################

alpha, beta = params

# here we use some facts about the Beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha / (alpha + beta)
# compute inferred standard deviation
factor = beta / (alpha * (1.0 + alpha + beta))
inferred_std = inferred_mean * jnp.sqrt(factor)
print(
    "\nBased on the data and our prior belief, the fairness "
    + "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std)
)


Based on the data and our prior belief, the fairness of the coin is 0.525 +- 0.207
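The "facts about the Beta distribution" used in the cell above are its mean and variance. The short derivation below shows that multiplying the mean by $\sqrt{\texttt{factor}}$ gives the standard deviation.

```latex
\mu = \frac{\alpha}{\alpha + \beta},
\qquad
\sigma^2 = \frac{\alpha \beta}{(\alpha + \beta)^2 (\alpha + \beta + 1)}

\mu \sqrt{\frac{\beta}{\alpha (1 + \alpha + \beta)}}
  = \frac{1}{\alpha + \beta} \sqrt{\frac{\alpha^2 \beta}{\alpha (\alpha + \beta + 1)}}
  = \sqrt{\frac{\alpha \beta}{(\alpha + \beta)^2 (\alpha + \beta + 1)}}
  = \sigma
```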


We can bundle this all up into a single function.

In [14]:
def run_experiment(key, data, loss_fn):
    def svi_update(model, guide, optimizer):
        def _inner(key, data, params):
            data_chm = genjax.choice_map({"obs": data})
            objective = loss_fn(model, guide, data_chm) # Here's our objective
            (loss, (_, params_grad)) = objective.value_and_grad_estimate(key, ((), params))
            params_grad = jax.tree_util.tree_map(lambda v: v * -1.0, params_grad)
            return params_grad, loss
    
        @jax.jit
        def updater(key, data, params, opt_state):
            params_grad, loss = _inner(key, data, params)
            updates, opt_state = optimizer.update(params_grad, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, loss, opt_state
    
        return updater

    # setup the optimizer
    adam = optax.adam(5e-4)
    svi_updater = svi_update(model, guide, adam)
    
    # initialize parameters
    alpha = jnp.array(15.0)
    beta = jnp.array(15.0)
    
    params = (alpha, beta)
    opt_state = adam.init(params)
    
    # warm up JIT compiler
    key = jax.random.PRNGKey(0)
    _ = svi_updater(key, data, params, opt_state)

    losses = []
    for step in range(5000):
        key, sub_key = jax.random.split(key)
        # Use the fresh sub_key for this step; keep `key` only for further splitting.
        params, loss, opt_state = svi_updater(sub_key, data, params, opt_state)
        losses.append(loss)

    alpha, beta = params
    
    # here we use some facts about the Beta distribution
    # compute the inferred mean of the coin's fairness
    inferred_mean = alpha / (alpha + beta)
    # compute inferred standard deviation
    factor = beta / (alpha * (1.0 + alpha + beta))
    inferred_std = inferred_mean * jnp.sqrt(factor)
    print(
        "\nBased on the data and our prior belief, the fairness "
        + "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std)
    )

Play around with this cell, and see if the inferences make sense! Note that the model beta prior `Beta(10.0, 10.0)` is symmetric, and peaked around `p = 0.5`. 

The prior is going to prevent the variational family from having inferred means far away from `p = 0.5`, while the data will try to pull the family away (depending on the number of `True` or `False` values observed).

In [15]:
data = []
for _ in range(8):
    data.append(True)
for _ in range(2):
    data.append(False)

data = jnp.array(data)

run_experiment(key, data, genjax.vi.elbo)


Based on the data and our prior belief, the fairness of the coin is 0.549 +- 0.090


## Using an alternative objective

Although the ELBO is a standard choice for variational inference, other objectives are possible, and our library makes it possible to define your own. This model is simple enough that many variational objectives will yield comparable inference results, but as an example, here is [the 2-particle IWELBO objective](https://arxiv.org/abs/1509.00519):

In [16]:
import jax.tree_util as jtu
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def iwelbo(
    p: genjax.GenerativeFunction,
    q: genjax.GenerativeFunction,
    data: genjax.ChoiceMap,
):
    @adevjax.adev
    def loss(p_args: Tuple, q_args: Tuple):
        s1 = sim(q, q_args)
        s2 = sim(q, q_args)
        latents, q_scores = jtu.tree_map(lambda v1, v2: jnp.hstack([v1, v2]), s1, s2)
        
        def score_against_model(proposal):
            observed = proposal.safe_merge(data)
            return density(p, observed, p_args)
            
        p_scores = jax.vmap(score_against_model)(latents)
        return logsumexp(p_scores - q_scores) - jnp.log(2)

    return adevjax.E(loss)
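To see the same computation outside the library, here is a minimal sketch of a K-particle IWELBO value estimate in standard-library Python. It assumes a `Beta(10, 10)` prior, i.i.d. Bernoulli flips, and a Beta guide; the helper names are hypothetical, and the sketch computes only the objective value, not its gradient.

```python
import math
import random

def log_beta_fn(a, b):
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def log_beta_pdf(x, a, b):
    return (a - 1.0) * math.log(x) + (b - 1.0) * math.log1p(-x) - log_beta_fn(a, b)

def log_joint(f, heads, tails, prior_a=10.0, prior_b=10.0):
    return (log_beta_pdf(f, prior_a, prior_b)
            + heads * math.log(f) + tails * math.log1p(-f))

def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def iwelbo_estimate(rng, alpha, beta, heads, tails, num_particles=2):
    """log of the average importance weight p(f, y) / q(f) over K particles."""
    log_weights = []
    for _ in range(num_particles):
        f = rng.betavariate(alpha, beta)
        log_weights.append(log_joint(f, heads, tails) - log_beta_pdf(f, alpha, beta))
    return logsumexp(log_weights) - math.log(num_particles)

heads, tails = 6, 4
log_evidence = log_beta_fn(10.0 + heads, 10.0 + tails) - log_beta_fn(10.0, 10.0)
rng = random.Random(0)
print("2-particle IWELBO sample:", iwelbo_estimate(rng, 2.0, 2.0, heads, tails))
print("log evidence:            ", log_evidence)
```

With `num_particles=1` the estimate reduces to a single-sample ELBO term, and with the exact posterior as the guide every particle has the same weight, so the estimate equals the log evidence.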

Note that after redefining the objective, we can retrain, without other changes to the code. This is in contrast to systems like Pyro, where defining new objective functions can require involved changes to library internals.

In [17]:
run_experiment(key, data, iwelbo)


Based on the data and our prior belief, the fairness of the coin is 0.543 +- 0.090


Pretty convenient! In this case, the change to the loss doesn't affect the result significantly, but it's better to have the option to explore different objective functions than not!

## Implementing new distributions with gradient strategies

At the top of this notebook, we had to import the Beta distribution (as `beta_implicit`) to use in our guide:

In [18]:
from extras import beta_implicit

Each primitive distribution in our library comes equipped with some default strategy for estimating gradients of expected values under the distribution in question. 

New primitives, with new gradient estimation strategies, can be added modularly; see the code in `./extras/beta_implicit.py` for an example of how this is done.