# The Semi-Supervised VAE

The semi-supervised setting represents an interesting intermediate case where some of the data is labeled and some is not. It is also of great practical importance, since we often have very little labeled data and much more unlabeled data. We’d clearly like to leverage labeled data to improve our models of the unlabeled data.

The semi-supervised setting is also well suited to generative models, where missing data can be accounted for quite naturally—at least conceptually. As we will see, in restricting our attention to semi-supervised generative models, there will be no shortage of different model variants and possible inference strategies. Although we’ll only be able to explore a few of these variants in detail, hopefully you will come away from the tutorial with a greater appreciation for the abstractions and modularity offered by probabilistic programming.

So let’s go about building a generative model. We have a dataset $D$ with $N$ datapoints,

$$ D = \{(x_i, y_i)\}_{i=1}^{N} $$

where the $\{x_i\}$ are always observed and the labels $\{y_i\}$ are only observed for some subset of the data. Since we want to be able to model complex variations in the data, we’re going to make this a latent variable model with a local latent variable $z_i$ private to each pair $(x_i, y_i)$. Even with this set of choices, a number of model variants are possible: we’re going to focus on the model variant depicted in Figure 1.

<img src="assets/ss_vae_m2.png"  width="180" height="200">

For convenience, and since we’re going to model MNIST in our experiments below, let’s suppose the $\{x_i\}$ are images and the $\{y_i\}$ are digit labels. In this model setup, the latent random variable $z_i$ and the (partially observed) digit label jointly generate the observed image. The $z_i$ represent everything but the digit label, such as handwriting style or position. Let’s sidestep asking when we expect this particular factorization of $(x_i, y_i, z_i)$ to be appropriate, since the answer to that question will depend in large part on the dataset in question (among other things). Let’s instead highlight some of the ways that inference in this model will be challenging, as well as some of the solutions that we’ll be exploring in the rest of the tutorial.

# The Challenge of Inference

For concreteness we’re going to continue to assume that the partially observed labels $\{y_i\}$ are discrete; we will also assume that the $\{z_i\}$ are continuous.

- If we apply the general recipe for stochastic variational inference to our model (see SVI Part I), we would be sampling the discrete (and thus non-reparameterizable) variable $y_i$ whenever it’s unobserved. As discussed in SVI Part III, this will generally lead to high-variance gradient estimates.

- A common way to ameliorate this problem (and one that we’ll explore below) is to forgo sampling and instead sum out all ten values of the class label $y_i$ when we calculate the ELBO for an unlabeled datapoint $x_i$. This is more expensive per step, but can help us reduce the variance of our gradient estimator and thereby take fewer steps.

- Recall that the role of the guide is to ‘fill in’ latent random variables. Concretely, one component of our guide will be a digit classifier $q_\phi(y|x)$ that will randomly ‘fill in’ labels $y_i$ given an image $x_i$. Crucially, this means that the only term in the ELBO that will depend on $q_\phi(\cdot|x)$ is the term that involves a sum over unlabeled datapoints, because for labeled datapoints $y_i$ is observed and the guide never samples it. This means that our classifier $q_\phi(\cdot|x)$, which in many cases will be the primary object of interest, will not be learning from the labeled datapoints (at least not directly).

- This seems like a potential problem. Luckily, various fixes are possible. Below we’ll follow the approach in reference [1], which involves introducing an additional objective function for the classifier to ensure that the classifier learns directly from the labeled data.


```python
import os

import numpy as np
import torch
from pyro.contrib.examples.util import MNIST
import torch.nn as nn
import torchvision.transforms as transforms

import pyro
import pyro.distributions as dist
import pyro.contrib.examples.util  # patches torchvision
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
```


```python
assert pyro.__version__.startswith('1.8.4')
pyro.distributions.enable_validation(False)
pyro.set_rng_seed(0)

smoke_test = 'CI' in os.environ
```

```python
# `model` and `guide` are written as methods of the SS-VAE nn.Module (hence `self`),
# which owns the networks `encoder_y`, `encoder_z` and `decoder`
def model(self, xs, ys=None):
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("ss_vae", self)
    batch_size = xs.size(0)

    # inform Pyro that variables in the batch of xs, ys are conditionally independent
    with pyro.plate("data"):
        # sample the handwriting style from the constant prior distribution
        prior_loc = xs.new_zeros([batch_size, self.z_dim])
        prior_scale = xs.new_ones([batch_size, self.z_dim])
        zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))

        # if the label y (which digit to write) is unsupervised, sample it from the
        # constant prior; otherwise, observe the value (i.e. score it against the constant prior)
        alpha_prior = xs.new_ones([batch_size, self.output_size]) / (1.0 * self.output_size)
        ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

        # finally, score the image (x) using the handwriting style (z) and
        # the class label y (which digit to write) against the
        # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z)),
        # where 'decoder' is a neural network
        loc = self.decoder([zs, ys])
        pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)

def guide(self, xs, ys=None):
    with pyro.plate("data"):
        # if the class label (the digit) is not supervised, sample
        # (and score) the digit with the variational distribution
        # q(y|x) = categorical(alpha(x))
        if ys is None:
            alpha = self.encoder_y(xs)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha))

        # sample (and score) the latent handwriting style with the variational
        # distribution q(z|x,y) = normal(loc(x,y), scale(x,y))
        loc, scale = self.encoder_z([xs, ys])
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))
```

## Network Definitions
In our experiments we use the same network configurations as used in reference [1]. The encoder and decoder networks have one hidden layer with  hidden units and softplus activation functions. We use softmax as the activation function for the output of `encoder_y`, sigmoid as the output activation function for `decoder`, and exponentiation for the scale part of the output of `encoder_z`. The latent dimension is 50.
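As a rough illustration of the architecture just described, the PyTorch sketch below builds the three networks. The class names `EncoderY`, `EncoderZ` and `Decoder` and the hidden width `HIDDEN_DIM` are hypothetical (the width is not given in the text). Only the following details come from the description above: the softplus hidden activations, the softmax, sigmoid and exponential output choices, and the latent dimension of 50. The networks take list inputs so that they match the `self.encoder_z([xs, ys])` and `self.decoder([zs, ys])` calls in the model and guide.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# 784 = 28 x 28 MNIST pixels, 10 digit classes, latent dimension 50 (from the text).
# HIDDEN_DIM is a hypothetical placeholder: the hidden-layer width is not stated.
X_DIM, Y_DIM, Z_DIM, HIDDEN_DIM = 784, 10, 50, 256


class EncoderY(nn.Module):
    """q(y|x): one softplus hidden layer, softmax output over the digit classes."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(X_DIM, HIDDEN_DIM)
        self.out = nn.Linear(HIDDEN_DIM, Y_DIM)

    def forward(self, xs):
        h = F.softplus(self.hidden(xs))
        return F.softmax(self.out(h), dim=-1)


class EncoderZ(nn.Module):
    """q(z|x,y): takes [xs, ys], returns (loc, scale) with scale = exp(raw output)."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(X_DIM + Y_DIM, HIDDEN_DIM)
        self.loc = nn.Linear(HIDDEN_DIM, Z_DIM)
        self.log_scale = nn.Linear(HIDDEN_DIM, Z_DIM)

    def forward(self, inputs):
        xs, ys = inputs
        h = F.softplus(self.hidden(torch.cat([xs, ys], dim=-1)))
        return self.loc(h), torch.exp(self.log_scale(h))


class Decoder(nn.Module):
    """p(x|y,z): takes [zs, ys], returns per-pixel Bernoulli means via a sigmoid."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(Z_DIM + Y_DIM, HIDDEN_DIM)
        self.out = nn.Linear(HIDDEN_DIM, X_DIM)

    def forward(self, inputs):
        zs, ys = inputs
        h = F.softplus(self.hidden(torch.cat([zs, ys], dim=-1)))
        return torch.sigmoid(self.out(h))


# Shape check on a random batch of 8 flattened "images".
xs = torch.rand(8, X_DIM)
alpha = EncoderY()(xs)                                # (8, 10), rows sum to 1
ys = F.one_hot(alpha.argmax(dim=-1), Y_DIM).float()   # one-hot labels, (8, 10)
loc, scale = EncoderZ()([xs, ys])                     # each (8, 50), scale > 0
x_mean = Decoder()([loc, ys])                         # (8, 784), values in (0, 1)
```

Instances of these classes would be attached to the SS-VAE module as `encoder_y`, `encoder_z` and `decoder`. That wiring is an assumption of this sketch, not a description of the reference implementation.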

## MNIST Pre-Processing
We normalize the pixel values to the range [0.0, 1.0]. We use the MNIST data loader from the torchvision library. The testing set consists of 10000 examples. The default training set consists of 60000 examples. We use the first 50000 examples for training (divided into supervised and unsupervised parts) and the remaining 10000 images for validation. For our experiments, we use 4 configurations of supervision in the training set, i.e. we consider 3000, 1000, 600, and 100 supervised examples selected randomly (while ensuring that each class is balanced).
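The NumPy sketch below shows one way to perform this split. It assumes that the 60000 training labels are available as an integer array and that "balanced" means an equal number of supervised examples per class (each of 3000, 1000, 600 and 100 is divisible by 10). The helper names `normalize_pixels` and `split_indices` are hypothetical and are not part of torchvision or Pyro.

```python
import numpy as np


def normalize_pixels(images):
    """Scale uint8 pixel values from [0, 255] to floats in [0.0, 1.0]."""
    return np.asarray(images, dtype=np.float32) / 255.0


def split_indices(labels, num_supervised, num_train=50000, num_classes=10, seed=0):
    """Split the training set into supervised / unsupervised / validation indices.

    The first `num_train` examples are for training. A class-balanced random subset
    of `num_supervised` of them keeps its labels, and the rest are treated as unlabeled.
    The remaining examples form the validation set.
    """
    labels = np.asarray(labels)
    if num_supervised % num_classes != 0:
        raise ValueError("num_supervised must be divisible by num_classes")
    per_class = num_supervised // num_classes
    rng = np.random.default_rng(seed)

    train_idx = np.arange(num_train)
    train_labels = labels[:num_train]
    sup = np.concatenate([
        rng.choice(train_idx[train_labels == c], size=per_class, replace=False)
        for c in range(num_classes)
    ])
    sup = np.sort(sup)
    unsup = np.setdiff1d(train_idx, sup)          # labels hidden during training
    valid = np.arange(num_train, len(labels))     # the last 10000 examples
    return sup, unsup, valid


# Example with synthetic labels (real code would use the MNIST training labels).
fake_labels = np.arange(60000) % 10
sup, unsup, valid = split_indices(fake_labels, num_supervised=100)
print(len(sup), len(unsup), len(valid))               # 100 49900 10000
print(np.bincount(fake_labels[sup], minlength=10))    # [10 10 10 10 10 10 10 10 10 10]
print(normalize_pixels(np.array([0, 255], dtype=np.uint8)))   # [0. 1.]
```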

## The Objective Function
The objective function for this model has two terms:

  <img src="assets/1.png"  width="350" height="80">
 
To implement this in Pyro, we set up a single instance of the SVI class. The two different terms in the objective function will emerge automatically depending on whether we pass labeled or unlabeled data to the step method. We will alternate taking steps with labeled and unlabeled mini-batches, with the number of steps taken for each type of mini-batch depending on the total fraction of data that is labeled. For example, if we have 1,000 labeled images and 49,000 unlabeled ones, then we’ll take 49 steps with unlabeled mini-batches for each labeled mini-batch. (Note that there are different ways we could do this, but for simplicity we only consider this variant.) The code for this setup is given below:

```python
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

# set up the optimizer
adam_params = {"lr": 0.0003}
optimizer = Adam(adam_params)

# set up the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
```
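The alternation between labeled and unlabeled mini-batches can be sketched in plain Python. `batch_schedule` is a hypothetical helper and is not part of Pyro. It assumes that both kinds of mini-batch use the same batch size and spreads the labeled batches evenly through an epoch. A "labeled" entry would correspond to `svi.step(xs, ys)` and an "unlabeled" entry to `svi.step(xs)`.

```python
import math


def batch_schedule(num_labeled, num_unlabeled, batch_size):
    """Return one epoch's sequence of "labeled"/"unlabeled" mini-batch steps.

    Hypothetical helper: labeled batches are spread evenly, one every `period` steps.
    """
    sup_batches = math.ceil(num_labeled / batch_size)
    unsup_batches = math.ceil(num_unlabeled / batch_size)
    total = sup_batches + unsup_batches
    period = total // sup_batches   # e.g. 500 // 10 = 50: 1 labeled + 49 unlabeled
    schedule, sup_done = [], 0
    for step in range(total):
        if step % period == 0 and sup_done < sup_batches:
            schedule.append("labeled")     # would call svi.step(xs, ys)
            sup_done += 1
        else:
            schedule.append("unlabeled")   # would call svi.step(xs)
    return schedule


sched = batch_schedule(1000, 49000, batch_size=100)
print(sched.count("labeled"), sched.count("unlabeled"))   # 10 490
print(sched[:3], sched[49:51])   # ['labeled', 'unlabeled', 'unlabeled'] ['unlabeled', 'labeled']
```

With 1,000 labeled and 49,000 unlabeled images and batches of 100, this schedule gives one labeled step followed by 49 unlabeled steps, repeated ten times. That matches the 49:1 ratio in the example above.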

## Interlude: Summing Out Discrete Latents

As highlighted in the introduction, when the discrete latent labels $y$ are not observed, the ELBO gradient estimates rely on sampling from $q(y|x)$. These gradient estimates can be very high-variance, especially early in the learning process when the guessed labels are often incorrect. A common approach to reduce variance in this case is to sum out discrete latent variables, replacing the Monte Carlo expectation

  <img src="assets/2.png"  width="350" height="80">

with an explicit sum

  <img src="assets/3.png"  width="350" height="80">

This sum is usually implemented by hand, as in [1], but Pyro can automate this in many cases. To automatically sum out all discrete latent variables (here only $y$), we simply wrap the guide in config_enumerate():

```python
svi = SVI(model, config_enumerate(guide), optimizer, loss=TraceEnum_ELBO(max_plate_nesting=1))
```

In this mode of operation, each svi.step(...) computes a gradient term for each of the ten latent states of $y$. Although each step is thus 10x more expensive, we’ll see that the lower-variance gradient estimate outweighs the additional cost.
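The NumPy toy below illustrates the trade-off between sampling and summing. It compares single-sample Monte Carlo estimates of $\mathbb{E}_{y\sim q}[f(y)]$ with the exact sum over the ten classes. The probabilities `q` and values `f` are made up, and `f` merely stands in for a per-class term such as $\log p(x, y)$. The toy measures only the spread of the objective estimate. A score-function gradient estimator for the non-reparameterizable $y$ inherits this noise, since it weights each sampled $f(y)$ by $\nabla \log q(y)$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up classifier probabilities q(y|x) over 10 classes, and made-up per-class
# values f(y) standing in for a term such as log p(x, y).
q = rng.dirichlet(np.ones(10))
f = rng.normal(loc=-100.0, scale=20.0, size=10)

# Explicit sum over all ten classes: exact, no sampling noise, 10 evaluations of f.
exact = np.sum(q * f)

# Monte Carlo with one sample of y per estimate: 1 evaluation of f each, but noisy.
ys = rng.choice(10, size=5000, p=q)
single_sample_estimates = f[ys]

print(f"exact sum          : {exact:.2f}")
print(f"mean of MC draws   : {single_sample_estimates.mean():.2f}")
print(f"std dev of MC draws: {single_sample_estimates.std():.2f}")
```

The single-sample estimates are unbiased: their mean approaches the exact sum, but any individual estimate can be far off. The exact sum needs ten evaluations of `f` instead of one, which mirrors the 10x per-step cost mentioned above, and it removes this source of sampling noise entirely.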

Going beyond the particular model in this tutorial, Pyro supports summing over arbitrarily many discrete latent variables. Beware that the cost of summing is exponential in the number of discrete variables, but is cheap(er) if multiple independent discrete variables are packed into a single tensor (as in this tutorial, where the discrete labels for the entire mini-batch are packed into the single tensor $y$). To use this parallel form of config_enumerate(), we must inform Pyro that the items in a minibatch are indeed independent by wrapping our vectorized code in a with pyro.plate("name") block.
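The gap between exponential and cheap summation comes from independence. If the $N$ labels in a mini-batch are independent, the sum over all $K^N$ joint configurations factorizes into $N$ separate sums over $K$ values. The NumPy check below verifies this distributive identity for $N = 4$ and $K = 10$, using made-up positive per-datapoint factors `w`. It illustrates the principle that a plate declaration lets Pyro exploit; it does not reproduce Pyro's actual implementation.

```python
import itertools

import numpy as np

rng = np.random.default_rng(1)
N, K = 4, 10   # 4 independent datapoints, each with a 10-way discrete label

# Made-up positive factors w[i, k], e.g. the contribution of datapoint i when y_i = k.
w = rng.uniform(0.1, 1.0, size=(N, K))


def brute_force_sum(w):
    """Sum the product of factors over every joint labeling: K**N terms."""
    n, k = w.shape
    total = 0.0
    for ys in itertools.product(range(k), repeat=n):
        total += np.prod(w[np.arange(n), list(ys)])
    return total


joint = brute_force_sum(w)            # 10**4 = 10,000 joint configurations
factored = np.prod(w.sum(axis=1))     # 4 sums of 10 terms each
print(np.isclose(joint, factored))    # True
```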

## Second Variant: Standard Objective Function, Better Estimator

Now that we have the tools to sum out discrete latents, we can see if doing so helps our performance. First, as we can see from Figure 3, the test and validation accuracies now evolve much more smoothly over the course of training. More importantly, this single modification improved test accuracy from around 20% to about 90% for the case of 3000 labeled examples. See Table 1 for the full results. This is great, but can we do better?

## Third Variant: Adding a Term to the Objective

For the two variants we’ve explored so far, the classifier $q(y|x)$ doesn’t learn directly from labeled data. As we discussed in the introduction, this seems like a potential problem. One approach to addressing this problem is to add an extra term to the objective so that the classifier learns directly from labeled data. The modified objective function is given by:

  <img src="assets/4.png"  width="350" height="80">

To learn using this modified objective in Pyro we do the following:

- We use a new model and guide pair (see the code snippet below) that corresponds to scoring the observed label $y$ for a given image $x$ against the predictive distribution $q(y|x)$.

- We specify the scaling factor $\alpha'$ (`aux_loss_multiplier` in the code) in the `pyro.sample` call by making use of `poutine.scale`. Note that `poutine.scale` was used to similar effect in the Deep Markov Model to implement KL annealing.

- We create a new SVI object and use it to take gradient steps on the new objective term.



```python
def model_classify(self, xs, ys=None):
    pyro.module("ss_vae", self)
    with pyro.plate("data"):
        # this is the extra term that yields an auxiliary loss
        # that we do gradient descent on
        if ys is not None:
            alpha = self.encoder_y(xs)
            with pyro.poutine.scale(scale=self.aux_loss_multiplier):
                pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)

def guide_classify(self, xs, ys=None):
    # the guide is trivial, since there are no latent random variables
    pass

svi_aux = SVI(model_classify, guide_classify, optimizer, loss=Trace_ELBO())
```
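Because `guide_classify` has no latent variables, the auxiliary objective reduces to a scaled label log-likelihood: `aux_loss_multiplier` times $-\sum_i \log q_\phi(y_i|x_i)$. In other words, it is a weighted cross-entropy loss for the classifier. The PyTorch sketch below checks this equivalence on random stand-in probabilities, using `torch.distributions.OneHotCategorical` in place of Pyro's distribution of the same name. The value `aux_loss_multiplier = 10.0` is arbitrary, and the tensors are not produced by the real `encoder_y`.

```python
import torch
import torch.nn.functional as F
from torch.distributions import OneHotCategorical

torch.manual_seed(0)
batch_size, num_classes = 5, 10
aux_loss_multiplier = 10.0   # arbitrary illustrative value; the text tunes it on validation data

# Random stand-ins for encoder_y(xs) (class probabilities) and the observed one-hot labels.
alpha = F.softmax(torch.randn(batch_size, num_classes), dim=-1)
labels = torch.randint(0, num_classes, (batch_size,))
ys = F.one_hot(labels, num_classes).float()

# Loss contributed by the scaled, observed "y_aux" site (the trivial guide adds nothing):
# -aux_loss_multiplier * sum_i log q(y_i | x_i)
aux_loss = -aux_loss_multiplier * OneHotCategorical(probs=alpha).log_prob(ys).sum()

# The same quantity written as a scaled, summed cross-entropy on log-probabilities.
ce_loss = aux_loss_multiplier * F.cross_entropy(torch.log(alpha), labels, reduction="sum")

print(torch.allclose(aux_loss, ce_loss))   # True
```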

When we run inference in Pyro with the additional term in the objective, we outperform both previous inference setups. For example, the test accuracy for the case with 3000 labeled examples improves from 90% to 96% (see Figure 4 below and Table 1 in the next section). Note that we used validation accuracy to select the hyperparameter $\alpha'$.

 <img src="assets/5.png"  width="350" height="250">
 <img src="assets/6.png"  width="350" height="250">