# An Introductory Tutorial On the Wasserstein Auto-encoder

-------

## Authors
Joel Dapello<br>
Michael Sedelmeyer<br>
Wenjun Yan

-------

<a id="top"></a>
## Contents

Table of contents with markdown hyperlinks to each section of the notebook

1. [Motivation and background](#intro)

1. [Conceptual foundations](#concepts)

1. [Mathematics and algorithms](#details)

1. [Comparing results on MNIST](#mnist)

1. [Comparing results on FashionMNIST](#fmnist)

1. [Conclusions and further analysis](#conclusion)

1. [References and further reading](#sources)


- [Appendices: PyTorch Implementation](#appendix)
    - [Appendix A: Auto-encoder](#ae)
    - [Appendix B: Variational auto-encoder](#vae)
    - [Appendix C: Wasserstein auto-encoder](#wae)
    - [Appendix D: Plotting functions](#plots)

----------

<a id="intro"></a>
## Motivation and background
[return to top](#top)

Explicitly cite the primary paper we are referencing and provide a high-level motivation for WAE

**images to include:**
1. Side-by-side plot, similar to the one in the WAE paper, comparing AE, VAE, and WAE reconstructions (this chart may be better suited for the next section, depending on how we summarize the motivation for WAE in this section)

<a id="concepts"></a>
## Conceptual foundations
[return to top](#top)

Provide and illustrate the conceptual foundations and intuition for AE vs VAE vs WAE

This should probably be written linearly, using the AE as our base example, and then adding elements of VAE, and then WAE, demonstrating the evolution of our chosen method.

**images to include:**
1. Simplified CNN diagram showing the generic encoder network, latent-space bottleneck, and decoder network as a left-to-right process flow (similar to either the fully connected or the CNN-representative diagram)
2. A further simplified version of the CNN diagram, with emphasis on the mechanics of the latent space of each method (i.e. similar to [this sort of image](http://kvfrans.com/content/images/2016/08/vae.jpg), but with MNIST digit images at either end) 

<a id="details"></a>
## Mathematics and algorithms
[return to top](#top)

In this section we provide the mathematical details and algorithmic differences among the three methods, paying extra attention to WAE and how it differs from VAE.

**latex to include:**
1. notational algorithms
1. loss function detail
1. mathematical representation of the reparameterization trick

**images to include:**
1. A small graphical representation of the reparameterization trick (small and simple node/edge plot)
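
The loss-function and reparameterization-trick items listed above center on two formulas: the sampling step $z = \mu_{\phi}(x) + \sigma_{\phi}(x) \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, and the closed-form KL term of the VAE loss, $-\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$. Below is a minimal NumPy sketch of both, assuming a diagonal Gaussian encoder that outputs a mean `mu` and a log-variance `logvar`, and assuming a squared-error reconstruction cost; the function names are hypothetical and are not taken from the notebook's PyTorch code.

```python
import numpy as np

def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I).

    All randomness sits in eps, so z is a deterministic, differentiable
    function of the encoder outputs mu and logvar (log sigma^2)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps

def gaussian_kl(mu, logvar):
    """Per-example KL(N(mu, diag(sigma^2)) || N(0, I)) in closed form."""
    return -0.5 * np.sum(1.0 + logvar - mu ** 2 - np.exp(logvar), axis=1)

def vae_loss(x, x_recon, mu, logvar):
    """Hypothetical VAE objective: squared-error reconstruction plus KL,
    averaged over the mini-batch (the squared-error cost is an assumption)."""
    recon = np.sum((x - x_recon) ** 2, axis=1)
    return np.mean(recon + gaussian_kl(mu, logvar))

# Usage: a batch of 4 codes in a 2-D latent space
rng = np.random.default_rng(0)
mu = np.zeros((4, 2))
logvar = np.zeros((4, 2))
z = reparameterize(mu, logvar, rng)
print(z.shape)  # (4, 2)
```

Because only `eps` is random, an autograd framework such as PyTorch can backpropagate through `z` into the encoder parameters that produce `mu` and `logvar`.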


<a id="mnist"></a>
## Comparing results on MNIST
[return to top](#top)

In this section we specify the parameters used in our models and provide plots, metrics, and a written interpretation of the training results and latent-space representations produced by each algorithm on MNIST.

**images/tables to include:**
1. Sample of 5 original MNIST images and corresponding decoded images for AE, VAE, and WAE on separate rows
1. Latent space linear interpolation results of each model, pixel space vs AE vs VAE vs WAE on separate rows
1. t-SNE or PCA representation of pixel space vs latent space for each model to demonstrate differences
1. table summarizing comparative loss (and if possible FID results)
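
The interpolation and projection plots listed above each reduce to a small operation. Below is a minimal NumPy sketch, assuming images are flattened to 784-pixel vectors; random arrays stand in for real MNIST digits, and `encoder`/`decoder` in the comments are hypothetical trained models. PCA is shown only because it fits in a few lines of NumPy; scikit-learn's `TSNE` would be the usual choice for the t-SNE variant.

```python
import numpy as np

def linear_interpolation(a, b, steps):
    """Return `steps` evenly spaced points from a to b (endpoints included),
    computed as (1 - t) * a + t * b for t in [0, 1]."""
    t = np.linspace(0.0, 1.0, steps)[:, None]
    return (1.0 - t) * a[None, :] + t * b[None, :]

def pca_2d(X):
    """Project the rows of X onto their top two principal components via SVD."""
    Xc = X - X.mean(axis=0)
    _, _, vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ vt[:2].T

# Pixel-space interpolation between two flattened 28x28 images (random stand-ins)
rng = np.random.default_rng(0)
img_a, img_b = rng.random(784), rng.random(784)
path = linear_interpolation(img_a, img_b, steps=8)
print(path.shape)  # (8, 784)

# For AE/VAE/WAE, interpolate the latent codes instead and decode each row, e.g.
# decoded = decoder(linear_interpolation(encoder(x_a), encoder(x_b), 8))
# where `encoder` and `decoder` are hypothetical trained model functions.

coords = pca_2d(rng.random((100, 784)))
print(coords.shape)  # (100, 2)
```

Running the same interpolation in pixel space and in each model's latent space is what makes the rows of the planned figure directly comparable.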


<a id="fmnist"></a>
## Comparing results on FashionMNIST
[return to top](#top)

Same as above for MNIST

**images/tables to include:**
1. same as above for MNIST, but probably smaller and with fewer examples if results demonstrate similar characteristics

<a id="conclusion"></a>
## Conclusions and further analysis
[return to top](#top)

Here we summarize our conclusions from the MNIST and FashionMNIST experiments and describe other datasets we may want to use for comparison (e.g., celebrity face images, which lie near a low-dimensional manifold, and RNA expression data, to investigate a novel application of WAE).

<a id="sources"></a>
## References and further reading
[return to top](#top)

Cite the papers, repos, datasets, and blogs we used in our analysis, as well as any other resources we want to direct our readers toward

1. VAE paper
1. WAE paper
1. PyTorch resources and implementations of the VAE
1. AE paper?
1. MNIST
1. FashionMNIST

<a id="appendix"></a>
## Appendices: PyTorch Implementation
[return to top](#top)

- The appendices are where we lay out and run our PyTorch code; each model has its own sub-appendix.
- We should save our most important plots as PNG files (stored on GitHub) so we can display them via markdown image links at the appropriate locations in our paper.

```python
# Import libraries
# Set parameter args
# Load train and test data sets
```

<a id="ae"></a>
### Appendix A: Auto-encoder 
[return to top](#top)

<a id="vae"></a>
### Appendix B: Variational auto-encoder
[return to top](#top)

<a id="wae"></a>
### Appendix C: Wasserstein auto-encoder
[return to top](#top)

<a id="plots"></a>
### Appendix D: Plotting functions
[return to top](#top)

**Algorithm 1:** Wasserstein auto-encoder with GAN-based latent penalty (WAE-GAN)

**Require:** Regularization coefficient $\lambda > 0$.

> Initialize the parameters for the encoder $Q_{\phi}$, decoder $G_{\theta}$, and latent discriminator $D_{\gamma}$.

> **while** $(\phi, \theta)$ not converged **do**

>> Sample $\{x_1, \dotsc , x_n\}$ from the training set

>> Sample $\{z_1, \dotsc , z_n\}$ from the prior $P_z$

>> Sample $\tilde{z}_i$ from $Q_{\phi}(Z\vert x_i)$ for $i=1, \dotsc , n$

>> Update $D_{\gamma}$ by ascending:
>> $$\frac{\lambda}{n}\sum_{i=1}^n \left[\log D_{\gamma}(z_i) + \log\left(1-D_{\gamma}(\tilde{z}_i)\right)\right]$$

>> Update $Q_{\phi}$ and $G_{\theta}$ by descending:
>> $$\frac{1}{n}\sum_{i=1}^n \left[c\left(x_i, G_{\theta}(\tilde{z}_i)\right) - \lambda \log D_{\gamma}(\tilde{z}_i)\right]$$

> **end while**
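
The two update steps of the GAN-based loop above are mini-batch averages that can be written as plain functions. Below is a minimal NumPy sketch of both quantities, assuming the discriminator outputs probabilities in $(0, 1)$ and taking $c(x, y) = \lVert x - y \rVert_2^2$, the squared Euclidean cost used in the WAE paper's experiments. The function names are hypothetical; a PyTorch version would compute the same expressions on tensors so that autograd can produce the gradients.

```python
import numpy as np

def discriminator_objective(d_prior, d_encoded, lam):
    """Quantity D_gamma ascends: (lam / n) * sum_i [log D(z_i) + log(1 - D(z~_i))].

    d_prior holds D_gamma(z_i) on prior samples; d_encoded holds D_gamma(z~_i)
    on encoded samples."""
    return lam * np.mean(np.log(d_prior) + np.log(1.0 - d_encoded))

def encoder_decoder_loss(x, x_recon, d_encoded, lam):
    """Quantity Q_phi and G_theta descend:
    (1 / n) * sum_i [c(x_i, G(z~_i)) - lam * log D(z~_i)],
    with c assumed to be the squared Euclidean distance."""
    cost = np.sum((x - x_recon) ** 2, axis=1)
    return np.mean(cost - lam * np.log(d_encoded))

# Toy batch: a discriminator that cannot tell prior samples from codes outputs 0.5
d = np.full(16, 0.5)
print(discriminator_objective(d, d, lam=1.0))  # 2 * log(0.5), about -1.386
```

In practice the discriminator would usually output logits, and the $\log D$ terms would be computed with a numerically stable loss such as PyTorch's `torch.nn.functional.binary_cross_entropy_with_logits`; taking `np.log` of probabilities, as in this sketch, breaks down when an output is exactly 0 or 1.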

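**Algorithm 2:** Wasserstein auto-encoder with MMD-based latent penalty (WAE-MMD)

**Require:** Regularization coefficient $\lambda > 0$ and a characteristic positive-definite kernel $k$.

> Initialize the parameters for the encoder $Q_{\phi}$ and decoder $G_{\theta}$.

> **while** $(\phi, \theta)$ not converged **do**

>> Sample $\{x_1, \dotsc , x_n\}$ from the training set

>> Sample $\{z_1, \dotsc , z_n\}$ from the prior $P_z$

>> Sample $\tilde{z}_i$ from $Q_{\phi}(Z\vert x_i)$ for $i=1, \dotsc , n$

>> Update $Q_{\phi}$ and $G_{\theta}$ by descending:
>> $$\frac{1}{n}\sum_{i=1}^n c\left(x_i, G_{\theta}(\tilde{z}_i)\right) + \frac{\lambda}{n(n-1)}\sum_{\ell \neq j} k(z_\ell, z_j) + \frac{\lambda}{n(n-1)}\sum_{\ell \neq j} k(\tilde{z}_\ell, \tilde{z}_j) - \frac{2\lambda}{n^2}\sum_{\ell, j} k(z_\ell, \tilde{z}_j)$$

> **end while**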
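
The WAE paper's MMD-based variant (WAE-MMD) replaces the adversarial discriminator with a kernel estimate of the maximum mean discrepancy between prior samples $z_i$ and encoded samples $\tilde{z}_i$, added to the reconstruction cost with weight $\lambda$. Below is a minimal NumPy sketch of that objective. It assumes the inverse multiquadratics kernel $k(u, v) = C / (C + \lVert u - v \rVert_2^2)$ with a single fixed scale $C = 2 d_z \sigma_z^2$ (the expected squared distance between two independent draws from an isotropic Gaussian prior $\mathcal{N}(0, \sigma_z^2 I)$), a standard normal prior, and the squared Euclidean cost; the single fixed scale is our simplifying assumption, and all function names are hypothetical.

```python
import numpy as np

def imq_kernel(a, b, scale):
    """Inverse multiquadratics kernel k(u, v) = C / (C + ||u - v||^2) for every pair of rows."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return scale / (scale + sq_dists)

def mmd_penalty(z_prior, z_encoded, scale):
    """Unbiased MMD^2 estimate between prior samples z_i and encoded samples z~_i."""
    n = z_prior.shape[0]
    off_diag = ~np.eye(n, dtype=bool)  # mask selecting pairs with l != j
    k_pp = imq_kernel(z_prior, z_prior, scale)[off_diag].sum() / (n * (n - 1))
    k_ee = imq_kernel(z_encoded, z_encoded, scale)[off_diag].sum() / (n * (n - 1))
    k_pe = imq_kernel(z_prior, z_encoded, scale).sum() / n ** 2
    return k_pp + k_ee - 2.0 * k_pe

def wae_mmd_loss(x, x_recon, z_prior, z_encoded, lam, scale):
    """Squared-Euclidean reconstruction cost plus lam times the MMD penalty."""
    cost = np.mean(np.sum((x - x_recon) ** 2, axis=1))
    return cost + lam * mmd_penalty(z_prior, z_encoded, scale)

rng = np.random.default_rng(0)
d_z = 2
scale = 2.0 * d_z  # C = 2 * d_z * sigma_z^2 with sigma_z = 1
z_prior = rng.standard_normal((200, d_z))
near = rng.standard_normal((200, d_z))         # codes distributed like the prior
far = rng.standard_normal((200, d_z)) + 10.0   # codes far from the prior
print(mmd_penalty(z_prior, near, scale) < mmd_penalty(z_prior, far, scale))  # True
```

Unlike WAE-GAN, this penalty has no extra network to train, so each iteration is a single descent step on $(\phi, \theta)$; the trade-off is the $O(n^2)$ kernel matrices per mini-batch.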