# The Semi-Supervised VAE

# Introduction

Most of the models we've covered in the tutorials are unsupervised:

- [the VAE](http://pyro.ai/examples/vae.html)
- [the DMM](http://pyro.ai/examples/dmm.html)
- [AIR](http://pyro.ai/examples/air.html)

We've also covered a simple supervised model:

- [Bayesian Regression](http://pyro.ai/examples/bayesian_regression.html)

The semi-supervised setting, where some of the data is labeled and some is not, represents an interesting intermediate case. It is also of great practical importance, since we are often faced with the situation where we have large amounts of unlabeled data and precious few labeled datapoints. We would clearly like to leverage the unlabeled data to improve our models of the labeled data.

The semi-supervised setting is also a problem setting that pairs well with generative models, where missing data can be accounted for quite naturally&mdash;at least conceptually.
As we will see, in restricting our attention to semi-supervised generative models, there will be no shortage of different model variants and possible inference strategies. 
Although we'll only be able to explore a few of these variants in detail, the reader is likely to come away from the tutorial with a greater appreciation for how useful the abstractions and modularity offered by probabilistic programming can be in practice.

So let's go about building a generative model. We have a dataset $\mathcal{D}$ of size $N$,

$$ \mathcal{D} = \{ ({\bf x}_i, {\bf y}_i) \}_{i=1}^{N} $$

where the $\{ {\bf x}_i \}$ are always observed and the labels $\{ {\bf y}_i \}$ are only observed for some subset of the data. Since we want to be able to model complex variations in the data, we're going to make this a latent variable model with a local latent variable ${\bf z}_i$ private to each pair $({\bf x}_i, {\bf y}_i)$. Even with this set of choices, a number of model variants are possible: we're primarily going to focus on the model variant depicted in Figure 1 (this is model M2 in reference [1]).

<figure><img src="ss_vae_m2.png" style="width: 180px;"><figcaption> <font size="+1"><b>Figure 1</b>: our semi-supervised generative model </font>(c.f. model M2 in reference [1])</figcaption></figure>

For convenience&mdash;and since we're going to model MNIST in our experiments below&mdash;let's suppose the $\{ {\bf x}_i \}$ are images and the $\{ {\bf y}_i \}$ are digit labels. In this model setup, the latent random variable ${\bf z}_i$ and the (partially observed) digit label _jointly_ generate the observed image. 
Let's sidestep asking when we expect this particular factorization of $({\bf x}_i, {\bf y}_i, {\bf z}_i)$ to be appropriate, since the answer to that question will depend in large part on the dataset in question (among other things). Let's instead highlight some of the ways that inference in this model will be challenging as well as some of the solutions that we'll be exploring in the rest of the tutorial.
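
As a concrete illustration of the factorization in Figure 1, here is a minimal NumPy sketch of ancestral sampling from an M2-style model, followed by hiding most of the labels to mimic the dataset $\mathcal{D}$. It assumes a uniform prior $p({\bf y})$ over the ten digits and a standard normal prior $p({\bf z})$, as in [1]. Everything else is a hypothetical stand-in chosen for illustration: the latent dimension, the Bernoulli pixel likelihood, the linear decoder with random weights `W` and `b` (in practice $p_\theta({\bf x}|{\bf y},{\bf z})$ would be a neural network), and the helpers `sample_m2` and `hide_labels`, which use `-1` to mark an unobserved label.

```python
import numpy as np

NUM_CLASSES = 10  # digit labels 0-9, as for MNIST
Z_DIM = 4         # hypothetical latent dimension, kept small for illustration
X_DIM = 784       # 28 x 28 images, flattened


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def sample_m2(n, W, b, rng):
    """Ancestral sampling from an M2-style model with a toy linear decoder (hypothetical)."""
    y = rng.integers(NUM_CLASSES, size=n)                    # y ~ p(y), uniform over the digits
    z = rng.standard_normal((n, Z_DIM))                      # z ~ p(z) = Normal(0, I)
    h = np.concatenate([np.eye(NUM_CLASSES)[y], z], axis=1)  # y and z jointly feed the decoder
    probs = sigmoid(h @ W.T + b)                             # Bernoulli means of p_theta(x | y, z)
    x = (rng.random((n, X_DIM)) < probs).astype(np.float32)  # x ~ Bernoulli(probs)
    return x, y


def hide_labels(y, n_labeled, rng):
    """Keep n_labeled randomly chosen labels and mark the rest as unobserved (-1)."""
    y_partial = np.full_like(y, -1)
    keep = rng.choice(len(y), size=n_labeled, replace=False)
    y_partial[keep] = y[keep]
    return y_partial


rng = np.random.default_rng(0)
W = 0.1 * rng.standard_normal((X_DIM, NUM_CLASSES + Z_DIM))  # random stand-in for trained weights
b = np.zeros(X_DIM)

x, y = sample_m2(1000, W, b, rng)
y_partial = hide_labels(y, n_labeled=100, rng=rng)
print(x.shape, (y_partial >= 0).sum())  # prints (1000, 784) 100
```

The structurally important line is the `np.concatenate` call: the label and the latent code enter the decoder together, which is what it means for ${\bf y}_i$ and ${\bf z}_i$ to jointly generate ${\bf x}_i$.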

## The challenges of inference

For concreteness we're going to continue to assume that the $\{ {\bf y}_i \}$ are discrete labels; we will also assume that the $\{ {\bf z}_i \}$ are continuous.

- If we apply the general recipe for stochastic variational inference to our model (see e.g. [SVI Part I](http://pyro.ai/examples/svi_part_i.html)), we're going to be sampling the discrete (and thus non-reparameterizable) variable ${\bf y}_i$ whenever it's unobserved. As discussed in [SVI Part III](http://pyro.ai/examples/svi_part_iii.html), this will generally lead to high-variance gradient estimates.
- One way to ameliorate this problem&mdash;and one we'll explore below&mdash;is to forgo sampling and instead sum out all ten values of the class label ${\bf y}_i$ when we calculate the ELBO for an unlabeled datapoint ${\bf x}_i$. This is somewhat expensive but can help us reduce the variance of our gradient estimator.
- Recall that the role of the guide is to 'fill in' _latent_ random variables. Concretely, one component of our guide will be of the form $q_\phi({\bf y} | {\bf x})$. Any unlabeled datapoints $\{ {\bf x}_i \}$ will have their corresponding labels $\{ {\bf y}_i \}$ 'filled in' by $q_\phi(\cdot | {\bf x})$. Crucially, this means that the only term in the ELBO that will depend on $q_\phi(\cdot | {\bf x})$ is the term that involves a sum over _unlabeled_ datapoints. As a result, our classifier $q_\phi(\cdot | {\bf x})$&mdash;which in many cases will be the primary object of interest&mdash;will not be learning from the labeled datapoints (at least not directly).
- This seems like a potential problem. Luckily, various fixes are possible. Below we'll follow the approach in reference [1], which involves introducing an additional loss function for the classifier to ensure that the classifier learns directly from the labeled data.
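
The last two points can be made precise. As a sketch following reference [1] (written here as lower bounds to be maximized, and assuming the guide factorizes as $q_\phi({\bf z}|{\bf x},{\bf y})\,q_\phi({\bf y}|{\bf x})$), each labeled pair contributes the usual ELBO

$$ \log p_\theta({\bf x},{\bf y}) \ge \mathbb{E}_{q_\phi({\bf z}|{\bf x},{\bf y})}\big[\log p_\theta({\bf x}|{\bf y},{\bf z}) + \log p({\bf y}) + \log p({\bf z}) - \log q_\phi({\bf z}|{\bf x},{\bf y})\big] \equiv \mathcal{L}({\bf x},{\bf y}) $$

while each unlabeled datapoint contributes a bound in which the label is filled in by the classifier:

$$ \log p_\theta({\bf x}) \ge \sum_{\bf y} q_\phi({\bf y}|{\bf x})\,\mathcal{L}({\bf x},{\bf y}) + \mathcal{H}\big[q_\phi(\cdot|{\bf x})\big] \equiv \mathcal{U}({\bf x}) $$

Writing $\mathcal{D}_l$ and $\mathcal{D}_u$ (our notation) for the labeled and unlabeled parts of $\mathcal{D}$, the total objective is $\sum_{({\bf x},{\bf y})\in\mathcal{D}_l}\mathcal{L}({\bf x},{\bf y}) + \sum_{{\bf x}\in\mathcal{D}_u}\mathcal{U}({\bf x})$, and $q_\phi({\bf y}|{\bf x})$ appears only in the second sum. The fix in [1] adds a classification term weighted by a hyperparameter $\alpha > 0$:

$$ \mathcal{J}^\alpha = \sum_{({\bf x},{\bf y})\in\mathcal{D}_l}\Big[\mathcal{L}({\bf x},{\bf y}) + \alpha \log q_\phi({\bf y}|{\bf x})\Big] + \sum_{{\bf x}\in\mathcal{D}_u}\mathcal{U}({\bf x}) $$

([1] states the objective as a loss to be minimized and writes the classification term as an expectation over the labeled data; the version above differs only in sign convention and in how $\alpha$ is scaled.)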

We have our work cut out for us so let's get to work!

# First Variant: Naive model, naive estimator

# Interlude: Summing out discrete latents

When the discrete latent labels ${\bf y}$ are not observed, the gradient of the ELBO is estimated by sampling from $q_\phi({\bf y}|{\bf x})$. These sampled gradients can be very high-variance, especially early in the learning process when the guessed labels are often incorrect. A common approach to reduce variance in this case is to sum out these discrete latent variables, replacing a Monte Carlo estimate of the expectation $\mathbb E_{{\bf y}\sim q_\phi(\cdot|{\bf x})}\nabla\operatorname{ELBO}$ with the exact sum $\sum_{\bf y} q_\phi({\bf y}|{\bf x})\nabla\operatorname{ELBO}$. This sum is often implemented by hand, as in [1], but Pyro can automate it in many cases. To automatically sum out all discrete latent variables (here only ${\bf y}$), we pass the `enum_discrete=True` argument to `SVI()`:
```python
from pyro.infer import SVI
inference = SVI(model, guide, optim, loss="ELBO", enum_discrete=True)
```
In this mode of operation, each `inference.step(...)` computes a gradient term for each of the ten possible values of ${\bf y}$. Although each step is thus $10\times$ more expensive, the reduction in gradient variance usually more than compensates for the extra per-step cost. In this tutorial we found that `enum_discrete=True` alone improved accuracy from a very poor 10-20% to a much better 80-90% (both after convergence).
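
To see why enumeration helps, the following NumPy sketch compares the two gradient estimators for a single unlabeled ${\bf x}$, taking gradients with respect to the logits of $q_\phi({\bf y}|{\bf x})$. It is a simplification under stated assumptions: the per-label ELBO values `f` and the `logits` are hypothetical fixed numbers (in the real model the per-label terms also depend on the parameters), and the sampled estimator is a plain single-sample score-function (REINFORCE) estimator without baselines, which is not necessarily what Pyro does internally.

```python
import numpy as np


def softmax(logits):
    e = np.exp(logits - logits.max())  # subtract the max for numerical stability
    return e / e.sum()


def enumerated_grad(logits, f):
    """Exact gradient of sum_y q(y) f(y) w.r.t. the logits of q (one f evaluation per label).

    Since d q(y) / d logits = q(y) * (onehot(y) - q), the gradient is q * (f - E_q[f]).
    """
    q = softmax(logits)
    return q * (f - q @ f)


def sampled_grads(logits, f, num_samples, rng):
    """Independent single-sample score-function (REINFORCE) estimates of the same gradient.

    Row k is f(y_k) * d log q(y_k) / d logits = f(y_k) * (onehot(y_k) - q), with y_k ~ q.
    """
    q = softmax(logits)
    ys = rng.choice(len(q), size=num_samples, p=q)
    onehots = np.eye(len(q))[ys]
    return f[ys][:, None] * (onehots - q)


rng = np.random.default_rng(0)
logits = rng.standard_normal(10)  # hypothetical classifier logits for q_phi(y | x), one unlabeled x
f = np.linspace(-5.0, 5.0, 10)    # hypothetical per-label ELBO values, held fixed here

exact = enumerated_grad(logits, f)  # deterministic (zero variance), but 10 evaluations of f
samples = sampled_grads(logits, f, num_samples=100_000, rng=rng)  # 1 evaluation of f per row

print("max |sample mean - exact|:", np.abs(samples.mean(axis=0) - exact).max())
print("summed variance of a single sampled gradient:", samples.var(axis=0).sum())
```

Both estimators have the same expectation, so the average of many sampled gradients approaches the enumerated one. The difference is that each sampled gradient is noisy, while the enumerated gradient is exact at the price of evaluating the per-label term once for every label, which is where the $10\times$ cost comes from.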

# Second Variant: Same model, better estimator

# Third Variant: Adding a loss term

# Results

# Final thoughts

We've seen that generative models offer a natural approach to semi-supervised machine learning. One of the most attractive features of generative models is that we can explore a large variety of models in a single unified setting. In this tutorial we've only been able to explore a small fraction of the possible model and inference setups. There is no reason to expect that one variant is best; depending on the dataset and application, there will be reason to prefer one over another. And there are a lot of variants (see Figure 2)!

<figure><img src="ss_vae_zoo.png" style="width: 300px;"><figcaption> <font size="+1"><b>Figure 2</b>: A zoo of semi-supervised generative models </font></figcaption></figure>

Some of these variants clearly make more sense than others, but a priori it's difficult to know which ones are worth trying out. This is especially true once we open the door to more complicated setups, like the two models at the bottom of the figure, which include an always-latent random variable ${\bf \tilde{y}}$ in addition to the partially observed label ${\bf y}$. (Incidentally, this class of models&mdash;see reference [2] for similar variants&mdash;offers another potential solution to the problem identified above, namely that the classifier $q_\phi({\bf y}|{\bf x})$ receives no direct training signal from the labeled data.)

The reader probably doesn't need any convincing that a systematic exploration of even a fraction of these options would be incredibly time-consuming and error-prone if each model and each inference procedure were coded up from scratch. It's only with the modularity and abstraction made possible by a probabilistic programming system that we can hope to explore the landscape of generative models with any kind of nimbleness&mdash;and reap any awaiting rewards.

See the full code on [GitHub](https://github.com/uber/pyro/blob/dev/examples/ssvae.py).

## References

[1] `Semi-supervised Learning with Deep Generative Models`,<br/>&nbsp;&nbsp;&nbsp;&nbsp;
Diederik P. Kingma, Danilo J. Rezende, Shakir Mohamed, Max Welling

[2] `Learning Disentangled Representations with Semi-Supervised Deep Generative Models`,<br/>&nbsp;&nbsp;&nbsp;&nbsp;
N. Siddharth, Brooks Paige, Jan-Willem Van de Meent, Alban Desmaison, Frank Wood, <br/>&nbsp;&nbsp;&nbsp;&nbsp;
Noah D. Goodman, Pushmeet Kohli, Philip H.S. Torr