# The Semi-Supervised VAE

# Introduction

Most of the models we've covered in the tutorials are unsupervised:

- [the VAE](http://pyro.ai/examples/vae.html)
- [the DMM](http://pyro.ai/examples/dmm.html)
- [AIR](http://pyro.ai/examples/air.html)

We've also covered a simple supervised model:

- [Bayesian Regression](http://pyro.ai/examples/bayesian_regression.html)

The semi-supervised setting, where some of the data is labeled and some is not, is an interesting intermediate case. It also pairs well with generative models, which can account for missing data quite naturally, at least conceptually.
As we will see, even when we restrict our attention to semi-supervised generative models, there is no shortage of model variants and possible inference strategies.
Although we'll only be able to explore a few of these variants in detail, the reader is likely to come away from the tutorial with a greater appreciation for how useful the abstractions and modularity offered by probabilistic programming can be in practice.

So let's go about building a generative model. We have a dataset

$$ \mathcal{D} = \{ ({\bf x}_i, {\bf y}_i) \} $$

where the $\{ {\bf x}_i \}$ are always observed and the labels $\{ {\bf y}_i \}$ are only observed for some subset of the data. Since we want to be able to model complex variations in the data, we're going to make this a latent variable model with a local latent variable ${\bf z}_i$ private to each pair $({\bf x}_i, {\bf y}_i)$. Even with this set of choices, a number of model variants are possible: let's consider the two depicted in Figure 1.


<figure><img src="ss_vae_zoo.png" style="width: 400px;"><figcaption> <font size="+1"><b>Figure 1</b>: two semi-supervised generative models</font></figcaption></figure>

For convenience, and since we're going to model MNIST in our experiments below, let's suppose the $\{ {\bf x}_i \}$ are images and the $\{ {\bf y}_i \}$ are digit labels. In the model on the left, the latent random variable ${\bf z}_i$ and the (partially observed) digit label jointly generate the observed image. In the model on the right, the latent random variable ${\bf z}_i$ generates both the observed image _and_ the digit label.
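
One way to make the two pictures concrete is to write down the joint density each one implies. The factorizations below are a sketch under two assumptions that the description above suggests but does not state outright: in the left-hand model ${\bf y}_i$ and ${\bf z}_i$ are independent a priori, and in the right-hand model ${\bf x}_i$ and ${\bf y}_i$ are conditionally independent given ${\bf z}_i$.

Left-hand model:

$$ p({\bf x}_i, {\bf y}_i, {\bf z}_i) = p({\bf y}_i) \, p({\bf z}_i) \, p({\bf x}_i | {\bf y}_i, {\bf z}_i) $$

Right-hand model:

$$ p({\bf x}_i, {\bf y}_i, {\bf z}_i) = p({\bf z}_i) \, p({\bf x}_i | {\bf z}_i) \, p({\bf y}_i | {\bf z}_i) $$

In either case, a datapoint whose label is missing contributes through the marginal in which the discrete label is summed out and the continuous latent is integrated out:

$$ p({\bf x}_i) = \sum_{{\bf y}_i} \int p({\bf x}_i, {\bf y}_i, {\bf z}_i) \, d{\bf z}_i $$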

Let's sidestep asking which of the two models we expect to be better, since the answer to that question will depend in large part on the dataset in question. Let's instead highlight some of the ways in which each model makes inference challenging.

## The challenges of inference

For concreteness we're going to continue to assume that the $\{ {\bf y}_i \}$ are discrete labels; we will also assume that the $\{ {\bf z}_i \}$ are continuous. 

If we apply the general recipe for doing stochastic variational inference to the model on the left (e.g. see [SVI Part I](http://pyro.ai/examples/svi_part_i.html)), we're going to be sampling the discrete variable ${\bf y}_i$ whenever it's unobserved. For reasons discussed in [SVI Part III](http://pyro.ai/examples/svi_part_iii.html), this will generally lead to high-variance gradient estimates. In the model on the right, while we would still need to sample ${\bf y}_i$ whenever it's unobserved, the gradient estimator would be less prone to high variance, since no other random variables depend on ${\bf y}_i$.
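
To get a feel for why sampling a discrete label can hurt, here is a toy comparison in plain Python (not Pyro). It treats the label as a single categorical variable over ten classes and uses made-up per-class values `f` as a stand-in for whatever the objective evaluates to under each label. Both `f` and the function names are assumptions for illustration only. The sketch compares single-sample score-function estimates of the gradient with respect to the logits against the exact gradient obtained by summing over all ten labels.

```python
import math
import random


def softmax(logits):
    """Convert logits to probabilities in a numerically stable way."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def exact_gradient(logits, f):
    """Gradient of E_{y ~ Categorical(softmax(logits))}[f(y)], summing over every y."""
    probs = softmax(logits)
    expected_f = sum(p * fy for p, fy in zip(probs, f))
    # d/d logit_k of sum_y p_y f_y equals p_k * (f_k - E[f]).
    return [p * (fy - expected_f) for p, fy in zip(probs, f)]


def score_function_sample(logits, f, rng):
    """One single-sample score-function estimate of the same gradient."""
    probs = softmax(logits)
    y = rng.choices(range(len(probs)), weights=probs)[0]
    # d/d logit_k of log p(y) equals 1[y == k] - p_k.
    return [f[y] * ((1.0 if k == y else 0.0) - p) for k, p in enumerate(probs)]


rng = random.Random(0)
num_classes = 10
logits = [0.0] * num_classes                # uniform over ten digit labels
f = [float(k) for k in range(num_classes)]  # hypothetical objective values

exact = exact_gradient(logits, f)
samples = [score_function_sample(logits, f, rng) for _ in range(20000)]

# Per-coordinate mean and variance of the sampled estimates.
mean = [sum(s[k] for s in samples) / len(samples) for k in range(num_classes)]
var = [
    sum((s[k] - mean[k]) ** 2 for s in samples) / len(samples)
    for k in range(num_classes)
]

print("exact gradient:      ", [round(g, 3) for g in exact])
print("mean of estimates:   ", [round(g, 3) for g in mean])
print("variance of estimates:", [round(v, 3) for v in var])
```

In this toy setup the average of many sampled estimates approaches the exact gradient, but each individual estimate is spread widely around it. Summing over the labels removes that source of variance, at a cost that grows with the number of classes.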

# First Variant: Naive model, naive estimator

# Interlude: Summing out discretes

# Second Variant: Same model, better estimator

# Third Variant: Adding a loss term

# Results

# Final thoughts

See the full code on [GitHub](https://github.com/uber/pyro/blob/dev/examples/ssvae.py).

## References

[1] `Semi-supervised Learning with Deep Generative Models`,<br/>&nbsp;&nbsp;&nbsp;&nbsp;
Diederik P. Kingma, Danilo J. Rezende, Shakir Mohamed, Max Welling