In [1]:
# Original tutorial https://pyro.ai/examples/ss-vae.html

#### The Challenges of Inference
For concreteness we’re going to continue to assume that the partially-observed {yi} are discrete labels; we will also assume that the {zi} are continuous.

If we apply the general recipe for stochastic variational inference to our model (see SVI Part I) we would be sampling the discrete (and thus non-reparameterizable) variable yi whenever it’s unobserved. As discussed in SVI Part III this will generally lead to high-variance gradient estimates.
A common way to ameliorate this problem, and one that we’ll explore below, is to forgo sampling and instead sum out all ten values of the class label yi when we calculate the ELBO for an unlabeled datapoint xi. This is more expensive per step, but it can reduce the variance of our gradient estimator and thereby let us take fewer steps.
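
To make the trade-off concrete, here is a minimal pure-Python sketch (no Pyro involved) that compares summing out a ten-valued label against estimating the same expectation from a single sample. The classifier probabilities `q` and the per-class values `f_values` are made-up numbers standing in for qϕ(y|x) and the per-class ELBO integrand; they are assumptions for illustration only.

```python
import random

def exact_expectation(q, f):
    """Sum out the discrete variable: E_q[f(y)] = sum_y q(y) * f(y)."""
    return sum(q_y * f(y) for y, q_y in enumerate(q))

def sampled_expectation(q, f, rng):
    """Single-sample Monte Carlo estimate of E_q[f(y)]."""
    y = rng.choices(range(len(q)), weights=q, k=1)[0]
    return f(y)

def variance(values):
    """Population variance of a list of numbers."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)

# Hypothetical classifier output q(y|x) over the ten digit classes (sums to 1).
q = [0.30, 0.05, 0.05, 0.20, 0.05, 0.05, 0.10, 0.05, 0.10, 0.05]
# Hypothetical per-class value of the ELBO integrand for one unlabeled image.
f_values = [-90.0, -140.0, -130.0, -100.0, -150.0,
            -135.0, -110.0, -145.0, -120.0, -125.0]

def f(y):
    return f_values[y]

rng = random.Random(0)
exact = exact_expectation(q, f)  # ten evaluations of f, no sampling noise
samples = [sampled_expectation(q, f, rng) for _ in range(2000)]
sample_mean = sum(samples) / len(samples)

print("exact (summed out):", exact)
print("mean of single-sample estimates:", sample_mean)
print("variance of single-sample estimates:", variance(samples))
```

The summed-out value costs ten evaluations of `f` instead of one, but it is deterministic given `q`, whereas each single-sample estimate fluctuates around it.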
Recall that the role of the guide is to ‘fill in’ latent random variables. Concretely, one component of our guide will be a digit classifier qϕ(y|x) that will randomly ‘fill in’ a label yi given an image xi. For a labeled datapoint yi is observed, so the guide never samples it. Crucially, this means that the only term in the ELBO that will depend on qϕ(⋅|x) is the term that involves a sum over unlabeled datapoints. As a result our classifier qϕ(⋅|x), which in many cases will be the primary object of interest, will not be learning from the labeled datapoints (at least not directly).
This seems like a potential problem. Luckily, various fixes are possible. Below we’ll follow the approach in reference [1], which involves introducing an additional objective function for the classifier to ensure that the classifier learns directly from the labeled data.
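
As a rough illustration of what such an additional objective can look like, the sketch below computes a weighted negative log-likelihood of the true labels under the classifier, using only labeled datapoints. This is one common form of classification loss, not necessarily the exact objective of reference [1]; the function names, the `aux_weight` parameter, and the example logits are all hypothetical.

```python
import math

def softmax(logits):
    """Convert raw classifier scores into probabilities q(y|x)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def aux_classification_loss(logits_batch, labels, aux_weight=1.0):
    """Weighted mean of -log q(y_i | x_i) over labeled datapoints.

    `aux_weight` is a hypothetical multiplier that balances this term
    against the ELBO-based loss.
    """
    losses = []
    for logits, y in zip(logits_batch, labels):
        q = softmax(logits)
        losses.append(-math.log(q[y]))
    return aux_weight * sum(losses) / len(losses)

# Hypothetical classifier logits for two labeled images (3 classes for brevity).
logits_batch = [[2.0, 0.0, 0.0], [0.0, 0.0, 3.0]]

# Labels that agree with the classifier give a small loss ...
loss_good = aux_classification_loss(logits_batch, [0, 2])
# ... labels that disagree give a larger one, so minimizing this term
# pushes q(y|x) toward the observed labels.
loss_bad = aux_classification_loss(logits_batch, [1, 0])

print("loss with matching labels:", loss_good)
print("loss with mismatched labels:", loss_bad)
```

Because this term is evaluated on (xi, yi) pairs with observed labels, its gradient reaches the classifier parameters directly, which the labeled part of the ELBO does not.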

In [None]:
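import pyro
import pyro.distributions as dist

# NOTE: `model` and `guide` take `self` because they are methods of the SS-VAE module
def model(self, xs, ys=None):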
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("ss_vae", self)
    batch_size = xs.size(0)

    # inform Pyro that the variables in the batch of xs, ys are conditionally independent
    with pyro.plate("data"):

        # sample the handwriting style from the constant prior distribution
        prior_loc = xs.new_zeros([batch_size, self.z_dim])
        prior_scale = xs.new_ones([batch_size, self.z_dim])
        zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))

        # if the label y (which digit to write) is unsupervised (ys is None), sample it
        # from the constant prior; otherwise, observe the value (i.e. score it against
        # the constant prior)
        alpha_prior = xs.new_ones([batch_size, self.output_size]) / (1.0 * self.output_size)
        ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

        # finally, score the image (x) using the handwriting style (z) and
        # the class label y (which digit to write) against the
        # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
        # where `decoder` is a neural network
        loc = self.decoder.forward([zs, ys])
        pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)

def guide(self, xs, ys=None):
    with pyro.plate("data"):
        # if the class label (the digit) is not supervised, sample
        # (and score) the digit with the variational distribution
        # q(y|x) = categorical(alpha(x))
        if ys is None:
            alpha = self.encoder_y.forward(xs)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha))

        # sample (and score) the latent handwriting-style with the variational
        # distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
        loc, scale = self.encoder_z.forward([xs, ys])
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))

In [None]:
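from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate

# To sum out all the discrete variables, mark the guide's discrete sites for enumeration.
# Assumption: `model` and `guide` are the bound methods of an SS-VAE instance and `optimizer` is a pyro.optim optimizer.
svi = SVI(model, config_enumerate(guide), optimizer, loss=TraceEnum_ELBO(max_plate_nesting=1))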
