# The Semi-Supervised VAE

# Introduction

Most of the models we've covered in the tutorials are unsupervised:

- [the VAE](http://pyro.ai/examples/vae.html)
- [the DMM](http://pyro.ai/examples/dmm.html)
- [AIR](http://pyro.ai/examples/air.html)

We've also covered a simple supervised model:

- [Bayesian Regression](http://pyro.ai/examples/bayesian_regression.html)

The semi-supervised setting, where some of the data is labeled and some is not, represents an interesting intermediate case. It is also of great practical importance, since we are often faced with the situation where we have large amounts of unlabeled data and precious few labeled datapoints. We would clearly like to be able to leverage the unlabeled data to improve our models of the labeled data.

The semi-supervised setting is also a problem setting that pairs well with generative models, where missing data can be accounted for quite naturally&mdash;at least conceptually.
As we will see, in restricting our attention to semi-supervised generative models, there will be no shortage of different model variants and possible inference strategies. 
Although we'll only be able to explore a few of these variants in detail, the reader is likely to come away from the tutorial with a greater appreciation for how useful the abstractions and modularity offered by probabilistic programming can be in practice.

So let's go about building a generative model. We have a dataset 
$\mathcal{D}$ of size $N$,

$$ \mathcal{D} = \{ ({\bf x}_i, {\bf y}_i) \} $$

where the $\{ {\bf x}_i \}$ are always observed and the labels $\{ {\bf y}_i \}$ are only observed for some subset of the data. Since we want to be able to model complex variations in the data, we're going to make this a latent variable model with a local latent variable ${\bf z}_i$ private to each pair $({\bf x}_i, {\bf y}_i)$. Even with this set of choices, a number of model variants are possible: we're primarily going to focus on the model variant depicted in Figure 1 (this is model M2 in reference [1]).

<figure><img src="ss_vae_m2.png" style="width: 180px;"><figcaption> <font size="+1"><b>Figure 1</b>: our semi-supervised generative model </font>(c.f. model M2 in reference [1])</figcaption></figure>

For convenience&mdash;and since we're going to model MNIST in our experiments below&mdash;let's suppose the $\{ {\bf x}_i \}$ are images and the $\{ {\bf y}_i \}$ are digit labels. In this model setup, the latent random variable ${\bf z}_i$ and the (partially observed) digit label _jointly_ generate the observed image. 
Let's sidestep asking when we expect this particular factorization of $({\bf x}_i, {\bf y}_i, {\bf z}_i)$ to be appropriate, since the answer to that question will depend in large part on the dataset in question (among other things). Let's instead highlight some of the ways that inference in this model will be challenging as well as some of the solutions that we'll be exploring in the rest of the tutorial.

## The challenges of inference

For concreteness we're going to continue to assume that the $\{ {\bf y}_i \}$ are discrete labels; we will also assume that the $\{ {\bf z}_i \}$ are continuous.

- If we apply the general recipe for stochastic variational inference to our model (e.g. see [SVI Part I](http://pyro.ai/examples/svi_part_i.html)) we're going to be sampling the discrete (and thus non-reparameterizable) variable ${\bf y}_i$ whenever it's unobserved. As discussed in [SVI Part III](http://pyro.ai/examples/svi_part_iii.html), this will generally lead to high-variance gradient estimates.
- One way to ameliorate this problem&mdash;and one we'll explore below&mdash;is to forego sampling and instead sum out all ten values of the class label ${\bf y}_i$ when we calculate the ELBO for an unlabeled datapoint ${\bf x}_i$. This is somewhat expensive but can help us reduce the variance of our gradient estimator.
- Recall that the role of the guide is to 'fill in' _latent_ random variables. Concretely, one component of our guide will be of the form $q_\phi({\bf y} | {\bf x})$. Any unlabeled datapoints $\{ {\bf x}_i \}$ will have their corresponding labels $\{ {\bf y}_i \}$ 'filled in' by $q_\phi(\cdot | {\bf x})$. Crucially, this means that the only term in the ELBO that will depend on $q_\phi(\cdot | {\bf x})$ is the term that involves a sum over _unlabeled_ datapoints. Consequently, our classifier $q_\phi(\cdot | {\bf x})$&mdash;which in many cases will be the primary object of interest&mdash;will not be learning from the labeled datapoints (at least not directly).
- This seems like a potential problem. Luckily, various fixes are possible. Below we'll follow the approach in reference [1], which involves introducing an additional loss function for the classifier to ensure that the classifier learns directly from the labeled data.

We have our work cut out for us so let's get to work!

# First Variant: Naive model, naive estimator

We consider the probabilistic generative model M2 (from reference [1]) that corresponds to:


<table>
    <col width="300">
    <col width="500">
    <tr>
        <td>$ p({\bf y}) = Cat({\bf y}~|~{\bf \pi})$</td>
        <td > categorical (or multinomial) distribution for the class label; <br> in the context of the MNIST dataset, this is which digit an image ${\bf x}$ corresponds to</td>
    </tr>
    <tr>
        <td>$ p({\bf z}) = \mathcal{N}({\bf z}~|~{\bf 0,I})$ </td>
        <td> standard normal prior; <br> in the context of the MNIST dataset, this is the latent handwriting style </td>
    </tr>
    <tr>
        <td>$ p_{\theta}({\bf x}~|~{\bf z,y}) = Bernoulli\left({\bf x}~|~\mu\left({\bf z,y}\right)\right)$ </td>
        <td> parametrized Bernoulli likelihood function; <br> $\mu\left({\bf z,y}\right)$ is given by a neural network `nn_mu_x` in the code below; <br> in the context of the MNIST dataset, ${\bf x}$ is the vector of flattened image pixels</td>
    </tr>
</table>
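To make the generative story concrete, here is a minimal NumPy sketch of ancestral sampling from this model: draw ${\bf y}$ uniformly, draw ${\bf z}$ from a standard normal, then draw each pixel of ${\bf x}$ from a Bernoulli whose mean comes from a decoder. The toy sizes and the single-linear-layer decoder `W_dec` are assumptions standing in for `nn_mu_x`; this is an illustration, not the tutorial's Pyro code.

```python
import numpy as np

# Toy sizes (assumption); the tutorial uses a 50-dim z and 784-pixel images.
NUM_CLASSES, LATENT_DIM, IMAGE_DIM = 10, 4, 6

rng = np.random.default_rng(0)
# Hypothetical decoder weights standing in for the network nn_mu_x.
W_dec = rng.normal(size=(LATENT_DIM + NUM_CLASSES, IMAGE_DIM))

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def sample_model(batch_size, rng):
    # y ~ Cat(pi) with a uniform pi, returned as one-hot rows
    labels = rng.integers(NUM_CLASSES, size=batch_size)
    ys = np.eye(NUM_CLASSES)[labels]
    # z ~ N(0, I)
    zs = rng.normal(size=(batch_size, LATENT_DIM))
    # x ~ Bernoulli(mu(z, y)); here mu is a linear layer followed by a sigmoid
    mu = sigmoid(np.concatenate([zs, ys], axis=1) @ W_dec)
    xs = (rng.random(size=mu.shape) < mu).astype(np.float64)
    return xs, ys, zs

xs, ys, zs = sample_model(batch_size=3, rng=rng)
print(xs.shape, ys.shape, zs.shape)  # (3, 6) (3, 10) (3, 4)
```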

We construct the guide $q_{\phi}(\cdot)$ as the following recognition model (again from reference [1]):

<table>
    <col width="300">
    <col width="500">
    <tr>
        <td>$ q_{\phi}({\bf y}~|~{\bf x}) = Cat({\bf y}~|~{\bf \alpha}_{\phi}\left({\bf x}\right))$</td>
        <td > parametrized multinomial (or categorical) distribution; <br> ${\bf \alpha}_{\phi}\left({\bf x}\right)$ is given by a neural network `nn_alpha_y` in the code below </td>
    </tr>
    <tr>
        <td>$ q_{\phi}({\bf z}~|~{\bf x, y}) = \mathcal{N}({\bf z}~|~{\bf \mu}_{\phi}\left({\bf x, y}\right), diag\left({\bf \sigma^2_{\phi}\left(x, y\right)}\right))$</td>
        <td > parametrized normal distribution; <br> ${\bf \mu}_{\phi}\left({\bf x, y}\right)$ and ${\bf \sigma^2_{\phi}\left(x, y\right)}$ are given by a neural network `nn_mu_sigma_z` in the code below </td>
    </tr>
</table>
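The two guide factors differ in how they can be sampled, which is the crux of the inference challenges above. A minimal NumPy sketch, assuming the network outputs are given as arrays: ${\bf y}$ is drawn from $q_{\phi}({\bf y}~|~{\bf x})$ by inverting the categorical CDF (no reparameterization is possible), and ${\bf z}$ is then drawn from $q_{\phi}({\bf z}~|~{\bf x, y})$ with the reparameterization ${\bf z} = {\bf \mu} + {\bf \sigma} \odot \epsilon$. The function `encode_z` is a hypothetical stand-in for `nn_mu_sigma_z` with ${\bf x}$ held fixed.

```python
import numpy as np

def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def sample_guide(alpha_logits, encode_z, rng):
    """Draw y ~ q(y|x), then z ~ q(z|x,y), for a batch of images.

    alpha_logits: (batch, num_classes) classifier logits (stand-in for nn_alpha_y)
    encode_z:     function mapping one-hot ys to (mu, log_sigma) for q(z|x,y)
    """
    alpha = softmax(alpha_logits)                      # q(y|x) = Cat(alpha(x))
    num_classes = alpha.shape[1]
    # Categorical draw by inverting the CDF; this step has no reparameterization.
    u = rng.random(size=(alpha.shape[0], 1))
    labels = (np.cumsum(alpha, axis=1) < u).sum(axis=1)
    labels = np.minimum(labels, num_classes - 1)       # guard against round-off
    ys = np.eye(num_classes)[labels]
    # Reparameterized Gaussian draw: z = mu + sigma * eps with eps ~ N(0, I).
    mu, log_sigma = encode_z(ys)
    eps = rng.normal(size=mu.shape)
    zs = mu + np.exp(log_sigma) * eps
    return ys, zs

def encode_z(ys):
    # Toy q(z|x,y): mean 1 in every dimension and a tiny sigma.
    mu = ys @ np.ones((ys.shape[1], 2))
    log_sigma = np.full(mu.shape, -20.0)
    return mu, log_sigma

rng = np.random.default_rng(0)
logits = np.array([[0.0, 50.0, 0.0],
                   [50.0, 0.0, 0.0]])
ys, zs = sample_guide(logits, encode_z, rng)
```

With such sharply peaked logits the sampled labels are almost surely 1 and 0, and the near-zero sigma makes every entry of `zs` essentially equal to the mean.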


The model and guide, written in Pyro, are shown below. Note that:
1. The labels `ys` may not always be observed (`None` denotes an unobserved value). The labels are represented as one-hot vectors, i.e. each label is a vector of $10$ bits with exactly one bit set to $1$.
2. We denote the number of dimensions used for the latent variable ${\bf z}$ in the code by `latent_layer`.
3. The code below works when `xs` and `ys` are batches of images and labels respectively. The size of each batch is denoted by `batch_size` in the code.

```python
import torch
import pyro
from torch.autograd import Variable
import pyro.distributions as dist

# batch_size, latent_layer and the networks nn_mu_x, nn_alpha_y and
# nn_mu_sigma_z are assumed to be defined elsewhere (see the full code)

def model(xs, ys=None):
    # register the decoder network with Pyro so that its parameters are learned
    pyro.module("nn_mu_x", nn_mu_x)

    # sample the handwriting style from the constant prior distribution
    const_mu = Variable(torch.zeros([batch_size, latent_layer]))
    const_sigma = Variable(torch.ones([batch_size, latent_layer]))
    zs = pyro.sample("z", dist.normal, const_mu, const_sigma)

    # if the label y (which digit to write) is unobserved, sample it from the
    # constant (uniform) prior; otherwise, observe the value (i.e. score it
    # against the constant prior)
    alpha_prior = Variable(torch.ones([batch_size, 10]) / 10.)
    if ys is None:
        ys = pyro.sample("y", dist.categorical, alpha_prior)
    else:
        pyro.observe("y", dist.categorical, ys, alpha_prior)

    # finally, score the image (x) using the handwriting style (z) and
    # the class label y (which digit to write) against the
    # parametrized distribution p(x|y,z) = bernoulli(nn_mu_x(y,z))
    # where nn_mu_x is a neural network
    mu = nn_mu_x.forward([zs, ys])
    pyro.observe("x", dist.bernoulli, xs, mu)

def guide(xs, ys=None):
    # register the encoder networks with Pyro so that their parameters are learned
    pyro.module("nn_alpha_y", nn_alpha_y)
    pyro.module("nn_mu_sigma_z", nn_mu_sigma_z)

    # if the class label (the digit) is not observed, sample
    # (and score) the digit with the variational distribution
    # q(y|x) = categorical(alpha(x))
    if ys is None:
        alpha = nn_alpha_y.forward(xs)
        ys = pyro.sample("y", dist.categorical, alpha)

    # sample (and score) the latent handwriting-style with the variational
    # distribution q(z|x,y) = normal(mu(x,y),sigma(x,y))
    mu, sigma = nn_mu_sigma_z.forward([xs, ys])
    zs = pyro.sample("z", dist.normal, mu, sigma)
```


## Network definitions

In our experiments we use the same network configurations as in reference [1], i.e. a $50$-dimensional latent variable ${\bf z}$ and MLPs (multi-layer perceptrons, i.e. feed-forward networks) to parametrize the distributions in the model and the guide. Each MLP has one hidden layer with $500$ hidden units and softplus activation functions. We use softmax as the output activation for `nn_alpha_y`, sigmoid as the output activation for `nn_mu_x`, and exponentiation for the sigma part of the output of `nn_mu_sigma_z`.
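A minimal NumPy sketch of the forward passes of these three MLPs, to make the shapes and output activations explicit. The weight initialization and the choice to pass `(zs, ys)` as separate arguments (rather than as a list, as in the Pyro code) are assumptions; the real networks are PyTorch modules trained by Pyro.

```python
import numpy as np

IMAGE_DIM, NUM_CLASSES, LATENT_DIM, HIDDEN = 784, 10, 50, 500

def softplus(a):
    return np.logaddexp(0.0, a)          # log(1 + exp(a)), computed stably

def init_mlp(in_dim, out_dim, rng, scale=0.01):
    # One hidden layer of HIDDEN units; the initialization scale is an assumption.
    return {"W1": rng.normal(scale=scale, size=(in_dim, HIDDEN)), "b1": np.zeros(HIDDEN),
            "W2": rng.normal(scale=scale, size=(HIDDEN, out_dim)), "b2": np.zeros(out_dim)}

def mlp(params, inputs):
    hidden = softplus(inputs @ params["W1"] + params["b1"])
    return hidden @ params["W2"] + params["b2"]

rng = np.random.default_rng(0)
alpha_params = init_mlp(IMAGE_DIM, NUM_CLASSES, rng)
mu_x_params = init_mlp(LATENT_DIM + NUM_CLASSES, IMAGE_DIM, rng)
mu_sigma_z_params = init_mlp(IMAGE_DIM + NUM_CLASSES, 2 * LATENT_DIM, rng)

def nn_alpha_y(xs):
    logits = mlp(alpha_params, xs)
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)                   # softmax output

def nn_mu_x(zs, ys):
    logits = mlp(mu_x_params, np.concatenate([zs, ys], axis=1))
    return 1.0 / (1.0 + np.exp(-logits))                      # sigmoid output

def nn_mu_sigma_z(xs, ys):
    out = mlp(mu_sigma_z_params, np.concatenate([xs, ys], axis=1))
    mu, log_sigma = out[:, :LATENT_DIM], out[:, LATENT_DIM:]
    return mu, np.exp(log_sigma)                              # exp keeps sigma positive
```

Splitting a single output layer into `mu` and `log_sigma` halves mirrors the description of `nn_mu_sigma_z` as one network producing both quantities.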


## Data processing

For MNIST data, we flatten the pixels of each image ($28\times28$) into a tensor with $784$ pixel values. Then we normalize the pixel values by dividing each value by $255$ so that each value lies between $0$ and $1$. We transform each class label (originally an integer between $0$ and $9$) into a $10$-dimensional one-hot tensor.
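These preprocessing steps can be sketched in a few lines. This version assumes the images arrive as a NumPy `uint8` array of shape `(N, 28, 28)` and the labels as an integer array; the tutorial itself works with torchvision tensors.

```python
import numpy as np

def preprocess(images, labels, num_classes=10):
    """Flatten 28x28 uint8 images to 784-vectors in [0, 1] and one-hot the labels."""
    xs = images.reshape(images.shape[0], -1).astype(np.float32) / 255.0
    ys = np.eye(num_classes, dtype=np.float32)[labels]
    return xs, ys

# Two fake "images" (all white, all black) standing in for MNIST.
images = np.full((2, 28, 28), 255, dtype=np.uint8)
images[1] = 0
xs, ys = preprocess(images, np.array([3, 9]))
print(xs.shape, xs.max(), xs.min())  # (2, 784) 1.0 0.0
```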

We use the <a href="http://pytorch.org/docs/0.2.0/_modules/torchvision/datasets/mnist.html"> MNIST dataset </a> from the torchvision library. The test set consists of $10{,}000$ examples, and the default training set consists of $60{,}000$ examples. We use the first $50{,}000$ training examples for training (divided into supervised and unsupervised parts) and the remaining $10{,}000$ for validation. For our experiments, we use $4$ levels of supervision in the training set, i.e. we consider $3000$, $1000$, $600$ and $100$ supervised examples, selected randomly while ensuring that each class contributes the same number of examples.
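The split described above can be expressed at the level of indices. A minimal sketch: the function name `split_indices`, its defaults and the fixed `seed` are hypothetical, and it assumes `labels` is an integer NumPy array for the full $60{,}000$-example training set.

```python
import numpy as np

def split_indices(labels, num_train=50000, num_supervised=3000, num_classes=10, seed=0):
    """Return (supervised, unsupervised, validation) index arrays.

    The first num_train examples form the training set and the rest the
    validation set; num_supervised training examples are drawn at random,
    num_supervised // num_classes from each class.
    """
    rng = np.random.default_rng(seed)
    train_idx = np.arange(num_train)
    valid_idx = np.arange(num_train, len(labels))
    per_class = num_supervised // num_classes
    sup_idx = np.concatenate([
        rng.choice(train_idx[labels[:num_train] == c], size=per_class, replace=False)
        for c in range(num_classes)
    ])
    unsup_idx = np.setdiff1d(train_idx, sup_idx)   # labels of these are ignored
    return sup_idx, unsup_idx, valid_idx
```

For example, `split_indices(mnist_labels, num_supervised=100)` (with `mnist_labels` a hypothetical label array) would pick $10$ labeled images per class and leave $49{,}900$ training images unlabeled.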


## ELBO objective to be optimized

The loss expression (-ELBO) for this model has the following structure (equation (8) from reference [1]):

$$
    \mathcal{J} = \sum_{({\bf x,y}) \in \mathcal{D}_{supervised} } \mathcal{L}\big({\bf x,y}\big) + \sum_{{\bf x} \in \mathcal{D}_{unsupervised}} \mathcal{U}\left({\bf x}\right)
$$

To run this optimization in Pyro, we set up two losses (one for the supervised batches and one for the unsupervised ones) using Pyro's stochastic variational inference (SVI) class. Depending on the fraction of supervised data, we periodically take a step on one loss or the other. For example, if we have 1,000 labeled images and 49,000 unlabeled ones and each batch contains 100 images, then 10 of the 500 batches are supervised, so every 50th batch is a supervised batch and we take a step on the supervised loss for it. The code for this setup is given below:

```python
from pyro.infer import SVI
from pyro.optim import Adam

# hyper-parameters
learning_rate, beta_1, beta_2 = 0.0003, 0.9, 0.999

# setting up the optimizer
adam_params = {"lr": learning_rate, "betas": (beta_1, beta_2)}
optimizer = Adam(adam_params)

# set up the losses for inference: one for supervised batches
# (ys observed) and one for unsupervised batches (ys latent)
loss_observed = SVI(model, guide, optimizer, loss="ELBO")
loss_latent = SVI(model, guide, optimizer, loss="ELBO")
```
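The periodic interleaving of supervised and unsupervised steps described earlier can be sketched as a schedule generator. This is a hypothetical helper, not part of Pyro; the comments show where the corresponding SVI step would be taken.

```python
def batch_schedule(num_supervised, num_unsupervised, batch_size):
    """Yield "sup" or "unsup" for each batch of one epoch, placing the
    supervised batches at regular intervals (every `period`-th batch)."""
    sup_batches = num_supervised // batch_size
    unsup_batches = num_unsupervised // batch_size
    total = sup_batches + unsup_batches
    if sup_batches == 0:
        yield from ["unsup"] * total
        return
    period = total // sup_batches
    sup_seen = 0
    for i in range(total):
        if i % period == period - 1 and sup_seen < sup_batches:
            sup_seen += 1
            yield "sup"      # step on the supervised loss, e.g. loss_observed.step(xs, ys)
        else:
            yield "unsup"    # step on the unsupervised loss, e.g. loss_latent.step(xs)

schedule = list(batch_schedule(1000, 49000, 100))
print(len(schedule), schedule.count("sup"), schedule.index("sup"))  # 500 10 49
```

Spacing the supervised batches evenly, rather than running them all at the start of an epoch, keeps the labeled signal present throughout training.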


When we run this inference in Pyro, test performance suffers from the noise introduced by sampling the categorical variables (see the graphs below and Table 1 at the end of this tutorial). To handle this issue, we introduce a new feature in Pyro that explicitly enumerates the support of discrete latent variables during inference.


<figure>
    <table>
        <tr>
            <td> 
                <img src="exp_1_losses_24_3000.png?1"  style="width: 450px;">
            </td>
            <td> 
                <img src="exp_1_acc_24_3000.png?1" style="width: 450px;">
            </td>
        </tr>
    </table> 
    <figcaption> 
        <font size="+1"><center><b>Figure 2</b>: loss optimization and variation of accuracies (naive variant) for 3000/50000 supervised examples </center></font>
    </figcaption>
</figure>


# Second Variant: Same model, summing out discrete latents

We investigate the loss expression for this model (equations (7) and (8) from reference [1]) to understand the root cause of this noise:

$$
    \mathcal{J} = \sum_{({\bf x,y}) \in \mathcal{D}_{supervised} } \mathcal{L}\big({\bf x,y}\big) + \sum_{{\bf x} \in \mathcal{D}_{unsupervised}} \bigg[ \mathop{\mathbb{E}}_{q_{\phi}({\bf y}~|~{\bf x})} \big[\mathcal{L}\left({\bf x,y}\right) + \log\left( q_{\phi}({\bf y}~|~{\bf x}) \right)\big]\bigg]
$$
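Written out as an explicit sum over labels, the unlabeled term is $\mathcal{U}({\bf x}) = \sum_{\bf y} q_\phi({\bf y}|{\bf x})\left[\mathcal{L}({\bf x,y}) + \log q_\phi({\bf y}|{\bf x})\right]$. Because $\mathcal{J}$ is a loss (a negative ELBO), this term is minimized over $q_\phi$. The NumPy check below, with made-up values of $\mathcal{L}({\bf x,y})$ for three classes, confirms that with this sign the minimum is attained at $q_\phi({\bf y}|{\bf x}) \propto \exp(-\mathcal{L}({\bf x,y}))$, where $\mathcal{U}({\bf x}) = -\log\sum_{\bf y}\exp(-\mathcal{L}({\bf x,y}))$.

```python
import numpy as np

def unsupervised_loss(q, L):
    """U(x) = sum_y q(y|x) * [L(x, y) + log q(y|x)] for one datapoint.

    q: probabilities q(y|x) over the classes; L: per-class losses L(x, y).
    """
    q = np.asarray(q, dtype=float)
    L = np.asarray(L, dtype=float)
    nz = q > 0   # terms with q(y|x) = 0 contribute nothing (0 * log 0 := 0)
    return float(np.sum(q[nz] * (L[nz] + np.log(q[nz]))))

L = np.array([2.0, 0.5, 3.0])               # hypothetical L(x, y) for 3 classes
bound = -np.log(np.sum(np.exp(-L)))         # -log sum_y exp(-L(x, y))
q_star = np.exp(-L) / np.sum(np.exp(-L))    # q(y|x) proportional to exp(-L(x, y))
print(np.isclose(unsupervised_loss(q_star, L), bound))   # True
print(unsupervised_loss([1/3, 1/3, 1/3], L) >= bound)    # True
```

The gap $\mathcal{U}({\bf x}) - \left(-\log\sum_{\bf y}\exp(-\mathcal{L}({\bf x,y}))\right)$ is exactly the KL divergence from $q_\phi$ to that optimal distribution, which is why it is never negative.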


When the discrete latent labels ${\bf y}$ are not observed, the gradient of the loss $\mathcal{J}$ is estimated by sampling from $q_\phi({\bf y}|{\bf x})$. These sampled gradients can have very high variance, especially early in the learning process when the guessed labels are often incorrect. A common approach to reduce variance in this case is to sum out these discrete latent variables, replacing a Monte Carlo estimate of the expectation $\mathbb E_{q_\phi({\bf y}|{\bf x})}[F({\bf x,y})]$ (where $F({\bf x,y})$ stands for the bracketed per-label term above) with the explicit sum $\sum_{\bf y} q_\phi({\bf y}|{\bf x})F({\bf x,y})$, which can then be differentiated exactly. This sum is often implemented by hand, as in [1], but Pyro can automate this in many cases. To automatically sum out all discrete latent variables (here only ${\bf y}$), we pass the `enum_discrete=True` argument to `SVI()`:
```python
loss_latent = SVI(model, guide, optimizer, loss="ELBO", enum_discrete=True)
```
In this mode of operation, each step on this loss computes a gradient term for each of the ten latent states of ${\bf y}$. Although each step is thus $10\times$ more expensive, the lower-variance gradient estimates usually more than make up for the extra cost. In this tutorial we found that `enum_discrete=True` alone improved accuracy from around `20%` to a much better `90%` (when using $3000$ labeled examples out of $50000$ examples in the training set). See the graphs below and Table 1 for further results.
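A toy NumPy experiment illustrates the variance gap for a single image's classifier logits. It compares the exact gradient of $\sum_{\bf y} q_\phi({\bf y}|{\bf x})F({\bf x,y})$ with respect to the logits against the score-function (REINFORCE) estimate obtained by sampling ${\bf y}$. The logits `theta` and the values `F` are made up, and the score-function estimator is chosen here for illustration; this is not a description of Pyro's internal estimator.

```python
import numpy as np

K = 10
theta = np.linspace(-1.0, 1.0, K)     # hypothetical classifier logits for one image
F = 5.0 + 0.1 * np.arange(K)          # hypothetical per-label terms F(x, y)

q = np.exp(theta - theta.max())
q /= q.sum()                          # q(y|x) = softmax(theta)

# Enumeration: differentiate sum_y q(y|x) F(x, y) exactly (one term per label).
exact_grad = q * (F - np.dot(q, F))

# Sampling: score-function estimate F(x, y) * d/dtheta log q(y|x), with y ~ q(y|x).
# For a softmax, d/dtheta log q(y|x) = onehot(y) - q.
rng = np.random.default_rng(0)
num_samples = 100_000
ys = rng.choice(K, size=num_samples, p=q)
grads = F[ys, None] * (np.eye(K)[ys] - q)   # one gradient estimate per row

print(np.abs(grads.mean(axis=0) - exact_grad).max())  # small: the estimator is unbiased
print(grads.var(axis=0).sum())                        # large per-sample variance; 0 for enumeration
```

Averaged over many samples the estimate matches the enumerated gradient, but any single sample is far off; enumeration gets the exact value from ten evaluations.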

<figure>
    <table>
        <tr>
            <td> 
                <img src="exp_2_losses_56_3000.png?1"  style="width: 450px;">
            </td>
            <td> 
                <img src="exp_2_acc_56_3000.png?1" style="width: 450px;">
            </td>
        </tr>
    </table> 
    <figcaption> 
        <font size="+1"><center><b>Figure 3</b>: loss optimization and variation of accuracies (second variant) for 3000/50000 supervised examples </center></font>
    </figcaption>
</figure>


Beyond this tutorial, Pyro supports summing over arbitrarily many discrete latent variables. Beware that the cost of summing is exponential in the number of discrete variables, but it is cheap if multiple independent discrete variables are packed into a single tensor (as in this tutorial, where the discrete labels for the entire batch are packed into the single tensor ${\bf y}$).
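The arithmetic behind this can be checked on a toy problem: when the loss is a sum of per-variable terms and the variables are independent under $q$, the expectation over all $K^n$ joint assignments collapses to $n$ separate $K$-term sums. The sketch below, with made-up `q` and `f`, compares brute-force joint enumeration against the packed per-variable sum; it illustrates the math only, not how Pyro implements enumeration.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
n, K = 3, 4                                    # 3 independent labels with 4 classes each (toy sizes)
q = rng.dirichlet(np.ones(K), size=n)          # q_i(y_i): one distribution per variable
f = rng.normal(size=(n, K))                    # f_i(y_i): per-variable loss terms

# Brute force: sum over all K**n joint assignments of prod_i q_i(y_i) * sum_i f_i(y_i).
brute = 0.0
for ys in itertools.product(range(K), repeat=n):
    weight = np.prod([q[i, y] for i, y in enumerate(ys)])
    brute += weight * sum(f[i, y] for i, y in enumerate(ys))

# Packed: with independent variables and an additive loss, the same quantity
# needs only K terms per variable, evaluated side by side.
packed = np.sum(q * f)

print(K ** n, np.isclose(brute, packed))   # 64 True
```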

# Third Variant: Adding an extra loss term

As done in reference [1] (see equation (9) in [1]), we also add an extra loss term to remedy the issue of the classifier $q_{\phi}({\bf y}~|~ {\bf x})$ not contributing to the supervised loss. The new loss function now becomes:

$$
    \mathcal{J}^{\alpha} = \mathcal{J} + \alpha \sum_{({\bf x,y}) \in \mathcal{D}_{supervised} } \big[-\log\big(q_{\phi}({\bf y}~|~ {\bf x})\big)\big]
$$
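As a sketch of the arithmetic of the extra term (plain NumPy, not Pyro): for a batch of labeled images with one-hot labels, it is the multiplier times the summed negative log-probability that the classifier assigns to the true labels. Here `alpha` holds the classifier probabilities $q_{\phi}({\bf y}~|~{\bf x})$ and `aux_loss_multiplier` plays the role of the scaling factor $\alpha$ (note that the text uses $\alpha$ for both); the toy probabilities are made up.

```python
import numpy as np

def auxiliary_loss(alpha, ys, aux_loss_multiplier):
    """aux_loss_multiplier * sum over labeled points of -log q(y|x).

    alpha: (batch, num_classes) class probabilities q(y|x) from the classifier
    ys:    (batch, num_classes) one-hot observed labels
    """
    log_q_of_label = np.log(np.sum(alpha * ys, axis=1))   # log q(y_i | x_i) of the true label
    return aux_loss_multiplier * -np.sum(log_q_of_label)

alpha = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
ys = np.eye(3)[[0, 2]]
print(auxiliary_loss(alpha, ys, aux_loss_multiplier=2.0))  # equals 2 * -(log 0.7 + log 0.8)
```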

To add this loss in Pyro: 
1. We use a new model and guide pair (shown in the code below) that corresponds to scoring the observed label ${\bf y}$ for a given image ${\bf x}$ against the predictive distribution $q_{\phi}({\bf y}~|~ {\bf x})$.
2. We specify the scaling factor $\alpha$ (represented by the variable `aux_loss_multiplier`) by passing it as the `log_pdf_mask` argument of the `pyro.observe` statement.
3. We create a new auxiliary loss object of the class `SVI()` and take steps on this loss along with the supervised loss `loss_observed`.

We try different values of the hyper-parameter $\alpha$ to identify the one that leads to the best validation accuracy.


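```python
def model_classify(xs, ys):
    # register the classifier network with Pyro so that its parameters are learned
    pyro.module("nn_alpha_y", nn_alpha_y)
    # this is the extra term that yields the auxiliary loss: score the observed
    # labels ys against q(y|x) = categorical(alpha(x)), scaled by
    # aux_loss_multiplier (assumed to be defined elsewhere)
    alpha = nn_alpha_y.forward(xs)
    pyro.observe("y_aux", dist.categorical, ys, alpha, log_pdf_mask=aux_loss_multiplier)

def guide_classify(xs, ys):
    # there are no latent variables to fill in, so the guide is empty
    pass


loss_aux = SVI(model_classify, guide_classify, optimizer, loss="ELBO")
```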

When we run inference in Pyro with this added loss term, test performance is better than with either of the two previous variants; e.g. the accuracy for $3000$ labeled examples out of $50000$ examples improves from `90%` to `96%` (see the graphs below and Table 1 in the next section).


<figure>
    <table>
        <tr>
            <td> 
                <img src="exp_3_losses_112_3000.png"  style="width: 450px;">
            </td>
            <td> 
                <img src="exp_3_acc_112_3000.png" style="width: 450px;">
            </td>
        </tr>
    </table> 
    <figcaption> 
        <font size="+1"><center><b>Figure 4</b>: loss optimization and variation of accuracies (third variant) for 3000/50000 supervised examples </center></font>
    </figcaption>
</figure>

# Results

### Best accuracy w.r.t. the validation set

Supervised data size  | First variant   | Second variant | Third variant | Baseline classifier 
----------------------|-----------------|----------------|---------------|--------------------
100                   |  0.201(0.03)    |   0.225(0.02)  |  0.926(0.004) |  0.763(TODO)
600                   |  0.179(0.02)    |   0.694(0.03)  |  0.943(0.005) |  0.869(TODO)
1000                  |  0.201(0.02)    |   0.756(0.02)  |  0.948(0.004) |  0.875(TODO)
3000                  |  0.198(0.04)    |   0.887(0.02)  |  0.958(0.001) |  0.907(TODO)

### Accuracy after the last epoch

Supervised data size  | First variant   | Second variant | Third variant | Baseline classifier 
----------------------|-----------------|----------------|---------------|--------------------
100                   |  0.099(0.008)   |   0.225(0.02)  |  0.926(0.004) |  0.763(TODO)
600                   |  0.101(0.02)    |   0.694(0.03)  |  0.943(0.005) |  0.869(TODO)
1000                  |  0.201(0.02)    |   0.756(0.02)  |  0.948(0.004) |  0.875(TODO)
3000                  |  0.198(0.04)    |   0.887(0.02)  |  0.958(0.001) |  0.907(TODO)

Supervised data size  | baseline        | enum           | enum+aux
----------------------|-----------------|----------------|---------------
100                   |  0.0995(0.008)  |   0.1010(0.002)|  0.9253(0.004)
600                   |  0.0938(0.014)  |   0.2658(0.023)|  0.9423(0.004)
1000                  |  0.0843(0.027)  |   0.4746(0.018)|  0.9473(0.003)
3000                  |  0.1021(0.006)  |   0.8842(0.012)|  0.9583(0.001)


Table 1 shows the mean accuracy (with standard deviations in parentheses) across $5$ random selections of the supervised data. The table compares the three variants introduced in this tutorial with a baseline classifier trained naively on the supervised data alone. Using the abstractions in Pyro, we are able to reproduce results similar to those in reference [1] (column M2 in Table 1 of [1]).



# Final thoughts

We've seen that generative models offer a natural approach to semi-supervised machine learning. One of the most attractive features of generative models is that we can explore a large variety of models in a single unified setting. In this tutorial we've only been able to explore a small fraction of the possible model and inference setups. There is no reason to expect that one variant is best; depending on the dataset and application, there will be reason to prefer one over another. And there are a lot of variants (see Figure 5)!

<figure><img src="ss_vae_zoo.png" style="width: 300px;"><figcaption><center> <font size="+1"><b>Figure 5</b>: A zoo of semi-supervised generative models </font></center></figcaption></figure>

Some of these variants clearly make more sense than others, but a priori it's difficult to know which ones are worth trying out. This is especially true once we open the door to more complicated setups, like the two models at the bottom of the figure, which include an always latent random variable ${\bf \tilde{y}}$ in addition to the partially observed label ${\bf y}$. (Incidentally, this class of models&mdash;see reference [2] for similar variants&mdash;offers another potential solution to the 'no training' problem that we identified above.)

The reader probably doesn't need any convincing that a systematic exploration of even a fraction of these options would be incredibly time-consuming and error-prone if each model and each inference procedure were coded up from scratch. It's only with the modularity and abstraction made possible by a probabilistic programming system that we can hope to explore the landscape of generative models with any kind of nimbleness&mdash;and reap any awaiting rewards.

See the full code on [Github](https://github.com/uber/pyro/blob/dev/examples/ssvae.py).

## References

[1] `Semi-supervised Learning with Deep Generative Models`,<br/>&nbsp;&nbsp;&nbsp;&nbsp;
Diederik P. Kingma, Danilo J. Rezende, Shakir Mohamed, Max Welling

[2] `Learning Disentangled Representations with Semi-Supervised Deep Generative Models`,
<br/>&nbsp;&nbsp;&nbsp;&nbsp;
N. Siddharth, Brooks Paige, Jan-Willem Van de Meent, Alban Desmaison, Frank Wood, <br/>&nbsp;&nbsp;&nbsp;&nbsp;
Noah D. Goodman, Pushmeet Kohli, Philip H.S. Torr