In [1]:
# Tutorial found here: https://pyro.ai/examples/svi_part_iii.html
"""
 In particular as long as we use a TraceGraph_ELBO loss,
 Pyro will keep track of the dependency structure within the execution traces
 of the model and guide and construct a surrogate objective that has all the unnecessary terms removed:
 
 svi = SVI(model, guide, optimizer, TraceGraph_ELBO())

 Note that leveraging this dependency information takes extra computations,
 so TraceGraph_ELBO should only be used in the case where your model
 has non-reparameterizable random variables; in most applications Trace_ELBO suffices.
"""

'\n In particular as long as we use a TraceGraph_ELBO loss,\n Pyro will keep track of the dependency structure within the execution traces\n of the model and guide and construct a surrogate objective that has all the unnecessary terms removed:\n \n svi = SVI(model, guide, optimizer, TraceGraph_ELBO())\n\n Note that leveraging this dependency information takes extra computations,\n so TraceGraph_ELBO should only be used in the case where your model\n has non-reparameterizable random variables; in most applications Trace_ELBO suffices.\n'

In [2]:
"""
To take advantage of Rao-Blackwellization
the user needs to explicitly mark the conditional independence.

# mark conditional independence
# (assumed to be along the rightmost tensor dimension)
with pyro.plate("foo", data.size(-1)):
    ks = pyro.sample("k", dist.Categorical(probs))
    pyro.sample("obs", dist.Normal(locs[ks], scale),
                obs=data)

"""

'\nTo take advantage of Rao-Blackwellization\nthe user needs to explicitly mark the conditional independence.\n\n# mark conditional independence\n# (assumed to be along the rightmost tensor dimension)\nwith pyro.plate("foo", data.size(-1)):\n    ks = pyro.sample("k", dist.Categorical(probs))\n    pyro.sample("obs", dist.Normal(locs[ks], scale),\n                obs=data)\n\n'

In [3]:
"""
Reducing Variance with Data-Dependent Baselines

There are several ways the user can instruct Pyro to use baselines in
the context of stochastic variational inference. Since baselines can
be attached to any non-reparameterizable random variable, the current
baseline interface is at the level of the pyro.sample statement.
In particular the baseline interface makes use of an argument baseline
, which is a dictionary that specifies baseline options. 
Note that it only makes sense to specify baselines for sample
statements within the guide (and not in the model).

z = pyro.sample("z", dist.Bernoulli(...),
                infer=dict(baseline={'use_decaying_avg_baseline': True,
                                     'baseline_beta': 0.95}))
                                     
We can also use a neural network to compute the baseline e.g.: 

class BaselineNN(nn.Module):
    def __init__(self, dim_input, dim_hidden):
        super(BaselineNN, self).__init__()
        self.linear = nn.Linear(dim_input, dim_hidden)
        # ... finish initialization ...

    def forward(self, x):
        hidden = self.linear(x)
        # ... do more computations ...
        return baseline
        
        
Then, assuming the BaselineNN object baseline_module
has been initialized somewhere else, in the guide we’ll have something like

def guide(x):  # here x is the current mini-batch of data
    pyro.module("my_baseline", baseline_module)
    # ... other computations ...
    z = pyro.sample("z", dist.Bernoulli(...),
                    infer=dict(baseline={'nn_baseline': baseline_module,
                                         'nn_baseline_input': x}))
                                         
Note that the baseline module needs to be registered with Pyro with a pyro.module
call so that Pyro is aware of the trainable parameters within the module.
"""

'\nReducing Variance with Data-Dependent Baselines\n\nThere are several ways the user can instruct Pyro to use baselines in\nthe context of stochastic variational inference. Since baselines can\nbe attached to any non-reparameterizable random variable, the current\nbaseline interface is at the level of the pyro.sample statement.\nIn particular the baseline interface makes use of an argument baseline\n, which is a dictionary that specifies baseline options. \nNote that it only makes sense to specify baselines for sample\nstatements within the guide (and not in the model).\n\nz = pyro.sample("z", dist.Bernoulli(...),\n                infer=dict(baseline={\'use_decaying_avg_baseline\': True,\n                                     \'baseline_beta\': 0.95}))\n                                     \nWe can also use a neural network to compute the baseline e.g.: \n\nclass BaselineNN(nn.Module):\n    def __init__(self, dim_input, dim_hidden):\n        super(BaselineNN, self).__init__()\n        

In [4]:
import os
import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
# Pyro also has a reparameterized Beta distribution so we import
# the non-reparameterized version to make our point
from pyro.distributions.testing.fakes import NonreparameterizedBeta
import pyro.optim as optim
from pyro.infer import SVI, TraceGraph_ELBO
import sys

# enable validation (e.g. validate parameters of distributions)
assert pyro.__version__.startswith('1.0.0')
pyro.enable_validation(True)

In [5]:
def param_abs_error(name, target):
    return torch.sum(torch.abs(target - pyro.param(name))).item()


In [6]:
class BernoulliBetaExample(object):
    def __init__(self, max_steps):
        # the maximum number of inference steps we do
        self.max_steps = max_steps
        # the two hyperparameters for the beta prior
        self.alpha0 = 10.0
        self.beta0 = 10.0
        # the dataset consists of six 1s and four 0s
        self.data = torch.zeros(10)
        self.data[0:6] = torch.ones(6)
        self.n_data = self.data.size(0)
        # compute the alpha parameter of the exact beta posterior
        self.alpha_n = self.data.sum() + self.alpha0
        # compute the beta parameter of the exact beta posterior
        self.beta_n = - self.data.sum() + torch.tensor(self.beta0 + self.n_data)
        # initial values of the two variational parameters
        self.alpha_q_0 = 15.0
        self.beta_q_0 = 15.0

    def model(self, use_decaying_avg_baseline):
        # sample `latent_fairness` from the beta prior
        f = pyro.sample("latent_fairness", dist.Beta(self.alpha0, self.beta0))
        # use plate to indicate that the observations are
        # conditionally independent given f and get vectorization
        with pyro.plate("data_plate"):
            # observe all ten datapoints using the bernoulli likelihood
            pyro.sample("obs", dist.Bernoulli(f), obs=self.data)

    def guide(self, use_decaying_avg_baseline):
        # register the two variational parameters with pyro
        alpha_q = pyro.param("alpha_q", torch.tensor(self.alpha_q_0),
                             constraint=constraints.positive)
        beta_q = pyro.param("beta_q", torch.tensor(self.beta_q_0),
                            constraint=constraints.positive)
        # sample f from the beta variational distribution
        baseline_dict = {'use_decaying_avg_baseline': use_decaying_avg_baseline,
                         'baseline_beta': 0.90}
        # note that the baseline_dict specifies whether we're using
        # decaying average baselines or not
        pyro.sample("latent_fairness", NonreparameterizedBeta(alpha_q, beta_q),
                    infer=dict(baseline=baseline_dict))

    def do_inference(self, use_decaying_avg_baseline, tolerance=0.80):
        # clear the param store in case we're in a REPL
        pyro.clear_param_store()
        # setup the optimizer and the inference algorithm
        optimizer = optim.Adam({"lr": .0005, "betas": (0.93, 0.999)})
        svi = SVI(self.model, self.guide, optimizer, loss=TraceGraph_ELBO())
        print("Doing inference with use_decaying_avg_baseline=%s" % use_decaying_avg_baseline)

        # do up to this many steps of inference
        for k in range(self.max_steps):
            svi.step(use_decaying_avg_baseline)
            if k % 100 == 0:
                print('.', end='')
                sys.stdout.flush()

            # compute the distance to the parameters of the true posterior
            alpha_error = param_abs_error("alpha_q", self.alpha_n)
            beta_error = param_abs_error("beta_q", self.beta_n)

            # stop inference early if we're close to the true posterior
            if alpha_error < tolerance and beta_error < tolerance:
                break

        print("\nDid %d steps of inference." % k)
        print(("Final absolute errors for the two variational parameters " +
               "were %.4f & %.4f") % (alpha_error, beta_error))

In [7]:
bbe = BernoulliBetaExample(max_steps=max_steps)
bbe.do_inference(use_decaying_avg_baseline=True)
bbe.do_inference(use_decaying_avg_baseline=False)

Doing inference with use_decaying_avg_baseline=True
..
Did 106 steps of inference.
Final absolute errors for the two variational parameters were 0.7987 & 0.7732
Doing inference with use_decaying_avg_baseline=False
........
Did 724 steps of inference.
Final absolute errors for the two variational parameters were 0.7974 & 0.7822
