# The Semi-Supervised VAE

## Introduction

Most of the models we've covered in the tutorials are unsupervised:

- [the Variational Autoencoder (VAE)](http://pyro.ai/examples/vae.html)
- [the DMM](http://pyro.ai/examples/dmm.html)
- [AIR](http://pyro.ai/examples/air.html)

We've also covered a simple supervised model:

- [Bayesian Regression](http://pyro.ai/examples/bayesian_regression.html)

The semi-supervised setting, where some of the data is labeled and some is not, represents an interesting intermediate case. It is also of great practical importance, since we are often faced with the situation where we have large amounts of unlabeled data and precious few labeled datapoints. Being able to leverage the unlabeled data to improve our models of the labeled data is something we'd clearly like to be able to do. 

The semi-supervised setting is also a problem setting that pairs well with generative models, where missing data can be accounted for quite naturally&mdash;at least conceptually.
As we will see, in restricting our attention to semi-supervised generative models, there will be no shortage of different model variants and possible inference strategies. 
Although we'll only be able to explore a few of these variants in detail, the reader is likely to come away from the tutorial with a greater appreciation for how useful the abstractions and modularity offered by probabilistic programming can be in practice.

So let's go about building a generative model. We have a dataset 
$\mathcal{D}$ of size $N$,

$$ \mathcal{D} = \{ ({\bf x}_i, {\bf y}_i) \} $$

where the $\{ {\bf x}_i \}$ are always observed and the labels $\{ {\bf y}_i \}$ are only observed for some subset of the data. Since we want  to be able to model complex variations in the data, we're going to make this a latent variable model with a local latent variable ${\bf z}_i$ private to each pair $({\bf x}_i, {\bf y}_i)$. Even with this set of choices, a number of model variants are possible: we're going to focus on the model variant depicted in Figure 1 (this is model M2 in reference [1]).

<figure><img src="ss_vae_m2.png" style="width: 180px;"><figcaption> <font size="+1"><b>Figure 1</b>: our semi-supervised generative model </font>(cf. model M2 in reference [1])</figcaption></figure>

For convenience&mdash;and since we're going to model MNIST in our experiments below&mdash;let's suppose the $\{ {\bf x}_i \}$ are images and the $\{ {\bf y}_i \}$ are digit labels. In this model setup, the latent random variable ${\bf z}_i$ and the (partially observed) digit label _jointly_ generate the observed image. 
Let's sidestep asking when we expect this particular factorization of $({\bf x}_i, {\bf y}_i, {\bf z}_i)$ to be appropriate, since the answer to that question will depend in large part on the dataset in question (among other things). Let's instead highlight some of the ways that inference in this model will be challenging as well as some of the solutions that we'll be exploring in the rest of the tutorial.

## The challenges of inference

For concreteness we're going to continue to assume that the $\{ {\bf y}_i \}$ are discrete labels; we will also assume that the $\{ {\bf z}_i \}$ are continuous.

- If we apply the general recipe for stochastic variational inference to our model (e.g. see [SVI Part I](http://pyro.ai/examples/svi_part_i.html)) we're going to be sampling the discrete (and thus non-reparameterizable) variable ${\bf y}_i$ whenever it's unobserved. As discussed in [SVI Part III](http://pyro.ai/examples/svi_part_iii.html) this will generally lead to high-variance gradient estimates.
- One way to ameliorate this problem (and one we'll explore below) is to forgo sampling and instead sum out all ten values of the class label ${\bf y}_i$ when we calculate the ELBO for an unlabeled datapoint ${\bf x}_i$. This is somewhat expensive but can help us reduce the variance of our gradient estimator.
- Recall that the role of the guide is to 'fill in' _latent_ random variables. Concretely, one component of our guide will be of the form $q_\phi({\bf y} | {\bf x})$. Any unlabeled datapoints $\{ {\bf x}_i \}$ will have their corresponding labels $\{ {\bf y}_i \}$ 'filled in' by $q_\phi(\cdot | {\bf x})$. Crucially, this means that the only term in the ELBO that will depend on $q_\phi(\cdot | {\bf x})$ is the term that involves a sum over _unlabeled_ datapoints. As a result our classifier $q_\phi(\cdot | {\bf x})$, which in many cases will be the primary object of interest, will not be learning from the labeled datapoints (at least not directly).
- This seems like a potential problem. Luckily, various fixes are possible. Below we'll follow the approach in reference [1], which involves introducing an additional loss function for the classifier to ensure that the classifier learns directly from the labeled data.

We have our work cut out for us so let's get started!

## First Variant: Standard objective function, naive estimator

As discussed in the introduction, we're considering the model depicted in Figure 1. In more detail, the model has the following structure:

<table>
    <col width="300">
    <col width="500">
    <tr>
        <td>$ p({\bf y}) = Cat({\bf y}~|~{\bf \pi})$</td>
        <td> multinomial (or categorical) prior for the class label</td>
    </tr>
    <tr>
        <td>$ p({\bf z}) = \mathcal{N}({\bf z}~|~{\bf 0,I})$ </td>
        <td> unit normal prior for the latent code $\bf z$</td>
    </tr>
    <tr>
        <td>$ p_{\theta}({\bf x}~|~{\bf z,y}) = Bernoulli\left({\bf x}~|~\mu\left({\bf z,y}\right)\right)$ </td>
        <td> parameterized Bernoulli likelihood function; <br> $\mu\left({\bf z,y}\right)$ corresponds to `nn_mu_x` in the code</td>
    </tr>
</table>

We structure the components of our guide $q_{\phi}(\cdot)$ as follows:

<table>
    <col width="300">
    <col width="500">
    <tr>
        <td>$ q_{\phi}({\bf y}~|~{\bf x}) = Cat({\bf y}~|~{\bf \alpha}_{\phi}\left({\bf x}\right))$</td>
        <td > parameterized multinomial (or categorical) distribution; <br> ${\bf \alpha}_{\phi}\left({\bf x}\right)$ corresponds to `nn_alpha_y` in the code</td>
    </tr>
    <tr>
        <td>$ q_{\phi}({\bf z}~|~{\bf x, y}) = \mathcal{N}({\bf z}~|~{\bf \mu}_{\phi}\left({\bf x, y}\right), {\bf \sigma^2_{\phi}\left(x, y\right)})$</td>
        <td > parameterized normal distribution; <br> ${\bf \mu}_{\phi}\left({\bf x, y}\right)$ and ${\bf \sigma^2_{\phi}\left(x, y\right)}$ correspond to `nn_mu_sigma_z` in the code </td>
    </tr>
</table>

These choices reproduce the structure of model M2 and its corresponding inference network in reference [1].

We translate this model and guide pair into Pyro code below. Note that:
1. The labels `ys`, which are represented with a one-hot encoding, are only partially observed (`None` denotes unobserved values).
2. `model()` handles both the observed and unobserved case.
3. The code assumes that `xs` and `ys` are mini-batches of images and labels, respectively, with the size of each batch denoted by `batch_size`. 

```python
import torch
from torch.autograd import Variable

import pyro
import pyro.distributions as dist

# batch_size, z_dim and the networks nn_mu_x, nn_alpha_y and
# nn_mu_sigma_z are assumed to be defined elsewhere

def model(xs, ys=None):
    # sample the handwriting style from the prior
    prior_mu = Variable(torch.zeros([batch_size, z_dim]))
    prior_sigma = Variable(torch.ones([batch_size, z_dim]))
    zs = pyro.sample("z", dist.normal, prior_mu, prior_sigma)

    # if the label y is not observed, sample it from the prior;
    # otherwise, observe the given value
    alpha_prior = Variable(torch.ones([batch_size, 10]) / (10.))
    if ys is None:
        ys = pyro.sample("y", dist.categorical, alpha_prior)
    else:
        pyro.observe("y", dist.categorical, ys, alpha_prior)

    # finally, score the image x against the
    # parameterized distribution p(x|y,z) = bernoulli(nn_mu_x(y,z))
    mu = nn_mu_x.forward([zs, ys])
    pyro.observe("x", dist.bernoulli, xs, mu)

def guide(xs, ys=None):
    # if the class label is not observed, sample it
    # from the variational distribution
    # q(y|x) = categorical(alpha(x))
    if ys is None:
        alpha = nn_alpha_y.forward(xs)
        ys = pyro.sample("y", dist.categorical, alpha)

    # sample the latent z from the variational
    # distribution q(z|x,y) = normal(mu(x,y),sigma(x,y))
    mu, sigma = nn_mu_sigma_z.forward([xs, ys])
    zs = pyro.sample("z", dist.normal, mu, sigma)
```

### Network definitions

In our experiments we use the same network configurations as used in reference [1]. The encoder and decoder networks have one hidden layer with $500$ hidden units and softplus activation functions. We use softmax as the activation function for the output of `nn_alpha_y`, sigmoid as the output activation function for `nn_mu_x` and exponentiation for the sigma part of the output of `nn_mu_sigma_z`. The latent dimension is 50.
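To make the activation choices above concrete, here is a minimal, framework-free sketch of the three nonlinearities, plus the exponentiation used for sigma. The function names and the scalar and list inputs are illustrative assumptions; the tutorial's networks apply the corresponding tensor operations instead.

```python
import math

def softplus(v):
    """Smooth version of ReLU, log(1 + exp(v)), written to avoid overflow."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def sigmoid(v):
    """Squash a real number into [0, 1] without overflowing for large |v|."""
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)

def softmax(scores):
    """Map a list of scores to non-negative weights that sum to one."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# hypothetical pre-activation outputs, one per network head
class_probs = softmax([0.0] * 10)  # output of nn_alpha_y: a distribution over digits
pixel_mean = sigmoid(0.0)          # output of nn_mu_x: a Bernoulli mean per pixel
sigma = math.exp(-1.0)             # sigma head of nn_mu_sigma_z: always positive
```

Each output activation enforces the constraint its distribution needs: class probabilities sum to one, Bernoulli means lie between zero and one, and the standard deviation is positive.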


### MNIST Pre-processing

We normalize the pixel values to the range $[0.0, 1.0]$. We use the [MNIST data loader](http://pytorch.org/docs/0.2.0/_modules/torchvision/datasets/mnist.html) from the torchvision library. The testing set consists of $10000$ examples. The default training set consists of $60000$ examples. We use the first $50000$ examples for training (divided into supervised and unsupervised parts) and the remaining $10000$ images for validation. For our experiments, we use $4$ configurations of supervision in the training set, i.e. we consider $3000$, $1000$, $600$ and $100$ supervised examples selected randomly (while ensuring that each class is balanced).
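The class-balanced selection of supervised examples could be sketched as follows. This is an illustration under stated assumptions, not the tutorial's data-loading code: `split_balanced` is a hypothetical helper, labels are plain integers rather than one-hot vectors, and a toy label list stands in for the MNIST training labels.

```python
import random
from collections import defaultdict

def split_balanced(labels, num_supervised, num_classes=10, seed=0):
    """Return (supervised_idx, unsupervised_idx) with equally many labeled examples per class."""
    if num_supervised % num_classes != 0:
        raise ValueError("num_supervised must be divisible by num_classes")
    per_class = num_supervised // num_classes
    rng = random.Random(seed)

    # group example indices by their class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # draw the same number of examples at random from every class
    supervised = []
    for c in range(num_classes):
        if len(by_class[c]) < per_class:
            raise ValueError("not enough examples of class %d" % c)
        supervised.extend(rng.sample(by_class[c], per_class))

    # everything that was not selected is treated as unlabeled
    chosen = set(supervised)
    unsupervised = [i for i in range(len(labels)) if i not in chosen]
    return sorted(supervised), unsupervised

# toy stand-in for the training labels
toy_labels = [i % 10 for i in range(1000)]
sup_idx, unsup_idx = split_balanced(toy_labels, num_supervised=100)
```

Fixing the seed keeps the split reproducible across runs, which matters when comparing model variants at the same level of supervision.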

### The objective function

The objective function for this model has two terms (cf. Eqn. 8 in reference [1]):

$$\mathcal{J} = \!\!\sum_{({\bf x,y}) \in \mathcal{D}_{supervised} } \!\!\!\!\!\!\!\!\mathcal{L}\big({\bf x,y}\big) +\!\!\! \sum_{{\bf x} \in \mathcal{D}_{unsupervised}} \!\!\!\!\!\!\!\mathcal{U}\left({\bf x}\right)
$$
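For reference, the two terms can be written out in the notation of the model and guide defined above. This is a sketch that assumes the sign convention of reference [1], in which $\mathcal{L}$ and $\mathcal{U}$ are negated lower bounds and $\mathcal{J}$ is minimized; it is a statement about that convention, not about how the code computes its loss. For a labeled pair the bound is

$$ -\mathcal{L}({\bf x},{\bf y}) = \mathbb{E}_{q_\phi({\bf z}|{\bf x},{\bf y})}\left[\log p_\theta({\bf x}|{\bf z},{\bf y}) + \log p({\bf y}) + \log p({\bf z}) - \log q_\phi({\bf z}|{\bf x},{\bf y})\right] $$

and for an unlabeled datapoint the label is treated as latent and averaged over using the classifier:

$$ -\mathcal{U}({\bf x}) = \sum_{\bf y} q_\phi({\bf y}|{\bf x})\left(-\mathcal{L}({\bf x},{\bf y})\right) + \mathcal{H}\left(q_\phi(\cdot|{\bf x})\right) $$

Here $\mathcal{H}$ denotes entropy. Note that $q_\phi({\bf y}|{\bf x})$ appears only in $\mathcal{U}$, which is why the classifier receives no direct training signal from labeled data under this objective.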

To implement this in Pyro, we set up a single instance of the `SVI` class. The two different terms in the objective function will emerge automatically depending on whether we pass the `step` method labeled or unlabeled data. We will alternate taking steps with labeled and unlabeled mini-batches, with the number of steps taken for each type of mini-batch depending on the total fraction of data that is labeled. For example, if we have 1,000 labeled images and 49,000 unlabeled ones, then we'll take 49 steps with unlabeled mini-batches for each labeled mini-batch. The code for this setup is given below:

```python
from pyro.infer import SVI
from pyro.optim import Adam

# hyper-parameters
learning_rate, beta_1, beta_2 = 0.0003, 0.9, 0.999

# setup the optimizer
adam_params = {"lr": learning_rate, "betas": (beta_1, beta_2)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss="ELBO")
```
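The alternation between labeled and unlabeled mini-batches described above can be sketched as a simple schedule. The helper `step_schedule`, the batch size of 100, and the choice to put the labeled step first in each cycle are all assumptions made for illustration; the full example may order its steps differently.

```python
def step_schedule(num_labeled, num_unlabeled, batch_size):
    """Return one epoch of steps as a list of booleans (True = labeled mini-batch).

    Assumes both dataset sizes are multiples of batch_size and that there are
    at least as many unlabeled mini-batches as labeled ones.
    """
    labeled_batches = num_labeled // batch_size
    unlabeled_batches = num_unlabeled // batch_size
    if labeled_batches == 0 or unlabeled_batches < labeled_batches:
        raise ValueError("need at least one labeled batch and no fewer unlabeled batches")

    ratio = unlabeled_batches // labeled_batches
    schedule = []
    for _ in range(labeled_batches):
        schedule.append(True)             # one labeled step ...
        schedule.extend([False] * ratio)  # ... followed by `ratio` unlabeled steps
    # unlabeled batches left over when the ratio is not a whole number
    schedule.extend([False] * (unlabeled_batches - ratio * labeled_batches))
    return schedule

# 1,000 labeled and 49,000 unlabeled images, as in the example in the text
schedule = step_schedule(num_labeled=1000, num_unlabeled=49000, batch_size=100)

# In a training loop one would then call the step method with labels when the
# entry is True (e.g. svi.step(xs, ys)) and without labels otherwise (svi.step(xs)).
```

With these numbers every cycle consists of one labeled step followed by 49 unlabeled ones, matching the 49-to-1 ratio given in the text.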

When we run this inference in Pyro, the performance during testing is degraded by the noise inherent in the sampling of the categorical variables, as can be seen in Table 1. 

## Interlude: Summing out discrete latents

As highlighted in the introduction, when the discrete latent labels ${\bf y}$ are not observed, the ELBO gradient estimates rely on sampling from $q_\phi({\bf y}|{\bf x})$. These gradient estimates can be very high-variance, especially early in the learning process when the guessed labels are often incorrect. A common approach to reduce variance in this case is to sum out discrete latent variables, replacing the Monte Carlo expectation 

$$\mathbb E_{{\bf y}\sim q_\phi(\cdot|{\bf x})}\nabla\operatorname{ELBO}$$

with an explicit sum 

$$\sum_{\bf y} q_\phi({\bf y}|{\bf x})\nabla\operatorname{ELBO}$$

This sum is usually implemented by hand, as in [1], but Pyro can automate this in many cases. To automatically sum out all discrete latent variables (here only ${\bf y}$), we simply pass the `enum_discrete=True` argument to `SVI()`:
```python
svi = SVI(model, guide, optimizer, loss="ELBO", enum_discrete=True)
```
In this mode of operation, each `svi.step(...)` computes a gradient term for each of the ten latent states of $y$. Although each step is thus $10\times$ more expensive, we'll see that the lower-variance gradient estimate outweighs the additional cost.

Beyond the scope of the model in this tutorial, Pyro supports summing over arbitrarily many discrete latent variables. Beware that the cost of summing is exponential in the number of discrete variables, but is cheap(er) if multiple independent discrete variables are packed into a single tensor (as in this tutorial, where the discrete labels for the entire mini-batch are packed into the single tensor ${\bf y}$).
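The difference between sampling and summing out can be seen in a toy calculation that involves no neural networks at all. In the sketch below, `per_label_value` is a hypothetical stand-in for the per-label quantity inside the expectation (in the tutorial, the ELBO gradient contribution for a given ${\bf y}$), and `probs` plays the role of $q_\phi({\bf y}|{\bf x})$ early in training, when it is close to uniform.

```python
import random
import statistics

def enumerated_expectation(probs, f):
    """Exact expectation of f(y) under probs, summing over every value of y."""
    return sum(p * f(y) for y, p in enumerate(probs))

def sampled_expectation(probs, f, rng):
    """Single-sample Monte Carlo estimate of the same expectation."""
    y = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return f(y)

def per_label_value(y):
    """Hypothetical per-label quantity: only the 'correct' label (3) contributes."""
    return 1.0 if y == 3 else 0.0

probs = [0.1] * 10  # near-uniform guess over the ten digit classes
rng = random.Random(0)

exact = enumerated_expectation(probs, per_label_value)
estimates = [sampled_expectation(probs, per_label_value, rng) for _ in range(1000)]
spread = statistics.pstdev(estimates)  # standard deviation of the one-sample estimates
```

The enumerated value is the same every time it is computed, whereas each single-sample estimate is either 0 or 1 and only matches the exact value on average. That trade, ten evaluations per datapoint in exchange for removing the sampling noise over ${\bf y}$, is what `enum_discrete=True` makes.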

## Second Variant: Standard objective function, better estimator

Merge with above

comment about the results

## Third Variant: Adding a term to the objective

alpha * log q(y|x)
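Written in the notation used earlier, and assuming the minimization convention of reference [1], the extended objective adds a classification loss on the labeled data to $\mathcal{J}$. The following is a sketch of its form; how the constant $\alpha$ is normalized (a sum versus an average over labeled datapoints) is left open here:

$$ \mathcal{J}^{\alpha} = \mathcal{J} + \alpha \!\!\sum_{({\bf x,y}) \in \mathcal{D}_{supervised}} \!\!\!\!\!\! \left[ -\log q_\phi({\bf y}|{\bf x}) \right] $$

The extra term is an ordinary cross-entropy loss for the classifier $q_\phi({\bf y}|{\bf x})$, so with $\alpha > 0$ the classifier learns directly from labeled datapoints.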

(1) scaling/annealing:
log_pdf_mask (dmm reference)

(2) how to add this loss - extra model and guide + new SVI object

TODO: decide if we want to refer to the other way in the code 


## Results

3 tables 

Quick comment about comparison -- reproduce the results

For the best - loss / accuracy plots

STRETCH: Plot accuracy mean vs fraction of supervised data (need ~8 data points)

## Final thoughts

We've seen that generative models offer a natural approach to semi-supervised machine learning. One of the most attractive features of generative models is that we can explore a large variety of models in a single unified setting. In this tutorial we've only been able to explore a small fraction of the possible model and inference setups. There is no reason to expect that one variant is best; depending on the dataset and application, there will be reason to prefer one over another. And there are a lot of variants (see Figure 2)!

<figure><img src="ss_vae_zoo.png" style="width: 300px;"><figcaption> <font size="+1"><b>Figure 2</b>: A zoo of semi-supervised generative models </font></figcaption></figure>

Some of these variants clearly make more sense than others, but a priori it's difficult to know which ones are worth trying out. This is especially true once we open the door to more complicated setups, like the two models at the bottom of the figure, which include an always latent random variable ${\bf \tilde{y}}$ in addition to the partially observed label ${\bf y}$. (Incidentally, this class of models&mdash;see reference [2] for similar variants&mdash;offers another potential solution to the 'no training' problem that we identified above.)

The reader probably doesn't need any convincing that a systematic exploration of even a fraction of these options would be incredibly time-consuming and error-prone if each model and each inference procedure were coded up from scratch. It's only with the modularity and abstraction made possible by a probabilistic programming system that we can hope to explore the landscape of generative models with any kind of nimbleness, and reap any awaiting rewards.

See the full code on [GitHub](https://github.com/uber/pyro/blob/dev/examples/ssvae.py).

## References

[1] `Semi-supervised Learning with Deep Generative Models`,<br/>&nbsp;&nbsp;&nbsp;&nbsp;
Diederik P. Kingma, Danilo J. Rezende, Shakir Mohamed, Max Welling

[2] `Learning Disentangled Representations with Semi-Supervised Deep Generative Models`,
<br/>&nbsp;&nbsp;&nbsp;&nbsp;
N. Siddharth, Brooks Paige, Jan-Willem Van de Meent, Alban Desmaison, Frank Wood, <br/>&nbsp;&nbsp;&nbsp;&nbsp;
Noah D. Goodman, Pushmeet Kohli, Philip H.S. Torr