In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import torch
import pyro
display(pyro.__version__)

# Data for this lecture
import torchvision
# MNIST training split, stored under ~/datasets (downloaded on first use);
# every image is converted to a tensor by the ToTensor transform
mnist_data = torchvision.datasets.MNIST(
    '~/datasets', train=True, download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Preview the first ten samples, using each label as the panel title
fig, ax = plt.subplots(1, 10, figsize=(6, 1), tight_layout=True)
for panel, sample_idx in zip(ax, range(10)):
    image, label = mnist_data[sample_idx]
    # Index 0 selects the leading axis of the image tensor, leaving a 2-D array
    panel.imshow(image[0].numpy(), cmap=plt.cm.Greys_r)
    panel.axis('off')
    panel.set_title(label)
    
from torch.utils.data import DataLoader, SubsetRandomSampler

# Shuffle the dataset indices once (seeded) and use the first 80% for training,
# the remaining 20% for validation
np.random.seed(0)
idx = list(range(len(mnist_data)))
np.random.shuffle(idx)
split = int(0.8*len(idx))

# Both loaders share the batching options and only differ in the index subset
# their sampler draws from
loader_options = dict(batch_size=128, drop_last=False)
train_loader = DataLoader(mnist_data, sampler=SubsetRandomSampler(idx[:split]),
                          **loader_options)
valid_loader = DataLoader(mnist_data, sampler=SubsetRandomSampler(idx[split:]),
                          **loader_options)

# Latent Variable Models (LVM)

## Intuition: Modeling with latent variables

Let's say we have a dataset $X = \{ \textbf{x}_1, \textbf{x}_2, \ldots, \textbf{x}_N\}$ with $\dim (\textbf{x}) =D$ and we want to model the generative distribution $p(\textbf{x})$

Each sample has $D$ components or attributes (e.g. the pixels of an image): These are the **observed variables** 

To model $p(\textbf{x})$ we may expand the joint between the attributes using the rules of probability

$$
p(x_1, x_2, \ldots, x_D) = p(x_D|x_{D-1}, \ldots, x_1) \cdot p(x_{D-1}|x_{D-2}, \ldots, x_1) \cdots p(x_3|x_2, x_1) \cdot p(x_2|x_1) \cdot p(x_1)
$$

which is known as a **fully observed model**. Unless we introduce independence between some of the variables the above representation is impractical for high dimensional problems (e.g. images)

One alternative is to assume that 

> what we observe is correlated due to *hidden causes*

These hidden causes are represented as **latent variables** and models with latent variables are called **Latent Variable Models** (LVMs)

Mathematically, we impose that the observed variables are conditionally independent given the latent variables $\textbf{z}$, this is

$$
p(x_1, x_2, \ldots, x_D|\textbf{z}) = p(x_D|\textbf{z}) \cdot p(x_{D-1}|\textbf{z}) \cdots p(x_3|\textbf{z}) \cdot p(x_2|\textbf{z}) \cdot p(x_1|\textbf{z})
$$

where in general $\dim(\textbf{z})\ll\dim(\textbf{x})$

The following figure shows the graphical model of a fully observed model with five observed variables and an LVM with two latent variables

<img src="images/LVM.png" width="700">

For the LVM we can write the marginal as

$$
\begin{align}
p(\textbf{x}) &= \int_\textbf{z} p(\textbf{x}, \textbf{z}) \,d\textbf{z} \nonumber \\
&= \int_\textbf{z} p(\textbf{x}|\textbf{z}) p(\textbf{z}) \,d\textbf{z} \nonumber
\end{align}
$$

Did we gain anything?  YES

> This strategy allows us to model a complex $p(x)$ by proposing a simple $p(z)$ (easy to sample from) and a transformation $p(x|z)$ 

The integral above is intractable for non-linear transformations (neural networks), in that case we resort to approximate inference

This lecture is focused on LVMs for continuous data. First we will review an example with a tractable posterior (PCA) and then the more modern LVM based on neural networks: The Variational Autoencoder

## A short review of PCA

Principal Component Analysis (PCA) is an algorithm to reduce the dimensionality of continuous data

For a dataset $X = (x_1, x_2, \ldots, x_N) \in \mathbb{R}^{N \times D}$, in PCA we 

1. Compute the covariance matrix $C = \frac{1}{N} X^T X$ (with $X$ centered, i.e. after subtracting the mean of each column)
1. Solve the eigenvalue problem $(C - \lambda I)W = 0$

This comes from the following objective

$$
\max_W W^T C W, \text{s.t.} ~ W^T W = I,
$$

i.e. PCA finds an **orthogonal transformation** $W$ that **maximizes the variance** of the projected data $XW$

By reducing the amount of columns of $W$ we reduce the dimensionality of $XW$


### Example: Classical PCA for MNIST using PyTorch

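Implementation of PCA using the eigendecomposition of a symmetric matrix. Note that [`symeig`](https://pytorch.org/docs/stable/generated/torch.symeig.html#torch.symeig) has been deprecated and removed from recent PyTorch releases in favor of [`torch.linalg.eigh`](https://pytorch.org/docs/stable/generated/torch.linalg.eigh.html)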

In [None]:
class PCA:
    """Classical PCA fitted through an eigendecomposition of the scatter matrix.

    Parameters
    ----------
    data : torch.Tensor
        2-D tensor with one sample per row.
    K : int
        Number of principal components to keep.

    After construction the instance holds `data_mean` (the per-column mean of
    `data`) and `W`, whose K columns are the eigenvectors with the largest
    eigenvalues, ordered so that column 0 is the first (largest-variance)
    principal component.
    """
    def __init__(self, data, K=2):
        self.data_mean = torch.mean(data, dim=0)
        # Broadcasting subtracts the mean from every row
        data_centered = data - self.data_mean
        # Unnormalized scatter matrix; dividing by N would rescale the
        # eigenvalues but leave the eigenvectors untouched
        C = torch.matmul(data_centered.T, data_centered)
        # torch.symeig is no longer available in recent PyTorch releases;
        # torch.linalg.eigh is its replacement. UPLO='U' reads the upper
        # triangle of C, which was the symeig default.
        # Eigenvalues (and the matching columns of W) are in ascending order
        _, W = torch.linalg.eigh(C, UPLO='U')
        # Keep the K last columns and reverse them so that column 0 is the
        # leading component (consistent with the 'PC 1', 'PC 2' plot labels)
        self.W = torch.flip(W[:, -K:], dims=[1])
    
    def encode(self, x):
        """Project the rows of the 2-D tensor `x` onto the K principal components."""
        return torch.mm(x - self.data_mean, self.W)

    def decode(self, z):
        """Map latent coordinates `z` (one row per sample) back to data space."""
        return self.data_mean + torch.mm(z, self.W.T)

In this example the $28\times28$ observed dimensions of the MNIST images are projected by PCA to two continuous latent variables 

We can then inspect the latent space and images reconstructed from it

In [None]:
# Flatten every image to a 784-vector and divide the pixel values by 255
images = mnist_data.data.reshape(-1, 28*28)/255.
pca = PCA(images, K=2)
Z = pca.encode(images)

# Latent space: one scatter call per digit so that each class gets its own
# legend entry and its own color from the default color cycle.
# (`cmap` is not passed: without `c` matplotlib ignores it and may warn)
fig, ax = plt.subplots(figsize=(6, 4), tight_layout=True)
for digit in range(10):
    mask = mnist_data.targets == digit
    ax.scatter(Z[mask, 0].numpy(), Z[mask, 1].numpy(),
               s=5, alpha=0.5, label=str(digit))
ax.legend()
ax.set_xlabel('PC 1'); ax.set_ylabel('PC 2');
ax.set_title('Latent space')

# Top row: first ten originals. Bottom row: their reconstructions from 2 components
fig, ax = plt.subplots(2, 10, figsize=(8, 2), tight_layout=True)
reconstructions = pca.decode(Z[:10, :]).reshape(-1, 28, 28).numpy()
for i in range(10):
    ax[0, i].imshow(mnist_data.data[i, :].reshape(28, 28).numpy(), cmap=plt.cm.Greys_r)
    ax[1, i].imshow(reconstructions[i], cmap=plt.cm.Greys_r)
# Hide ticks and frames instead of calling axis('off'): the latter would also
# hide the row labels set below
for panel in ax.ravel():
    panel.set_xticks([])
    panel.set_yticks([])
    for spine in panel.spines.values():
        spine.set_visible(False)
ax[0, 0].set_ylabel('Original')
ax[1, 0].set_ylabel('Reconstructed');

The two most important principal components in this case are

In [None]:
# Show each column of the projection matrix as a 28x28 image
fig, ax = plt.subplots(1, 2, figsize=(4, 1.5), tight_layout=True)
for i in range(2):
    ax[i].imshow(pca.W[:, i].reshape(28, 28).detach().numpy())
    ax[i].axis('off')
    # 1-based numbering, matching the 'PC 1'/'PC 2' axis labels of the latent space plot
    ax[i].set_title('PC %d' %(i + 1))

Two continuous latent variables are not enough to model the digits given this linear model. Later we will see how this changes using non-linear models

## Probabilistic interpretation for PCA

We can give a probabilistic interpretation to PCA as an LVM

We start by modeling an observed sample $x_i \in \mathbb{R}^D$ as

$$
x_i = W z_i + B + \epsilon
$$


where 

- $B \in \mathbb{R}^D$ is the mean of $X$
- $W \in \mathbb{R}^{D\times K}$ is a linear transformation matrix
- $\epsilon$ is the noise
- $z_i \in  \mathbb{R}^K$ is a continuous latent variable with $K\ll D$
- $x$ (observed) is related to $z$ (latent) via a **linear mapping**

This model has the following assumptions

1. The noise is independent and Gaussian distributed with variance $\sigma^2$
1. The latent variable has a standard Gaussian prior

Using these we can write

$$
p(x_i | z_i) = \mathcal{N}(B + W z_i, I \sigma^2)
$$

and 

$$
p(z_i) = \mathcal{N}(0, I)
$$

Given that the Gaussian is conjugate to itself the marginal likelihood is

$$
\begin{align}
p(x) &= \int p(x|z) p(z) \,dz \nonumber \\
&= \mathcal{N}(x|B, W W^T + I\sigma^2 ) \nonumber
\end{align}
$$

Note that we have parameterized a Gaussian with full covariance from two Gaussians with diagonal covariance

The parameters of the marginal come from

- $\mathbb{E}[x] = W\mathbb{E}[z] + B + \mathbb{E}[\epsilon] = B$
- $\mathbb{E}[(Wz + \epsilon)(Wz + \epsilon)^T] = W \mathbb{E}[zz^T] W^T + \mathbb{E}[\epsilon \epsilon^T] = W W^T + I\sigma^2$


**The posterior**

Using this formalism we can write the posterior to go from observed to latent as

$$
p(z|x) = \mathcal{N}(z|M^{-1}W^T(x-B), \sigma^{2} M^{-1} )
$$

where

$$
M = W^T W + I\sigma^2
$$

**Training**

We find $W$, $B$ and $\sigma$ that best fit the data by maximizing the log marginal likelihood

$$
\hat W, \hat B, \hat \sigma^2 = \text{arg} \max_{W, B, \sigma^2} \sum_{i=1}^N \log p(x_i)
$$

which has a closed form analytical solution. Note that the solution for $W$ is equivalent to conventional PCA ($\sigma^2 \to 0$). The main difference is that we have $\sigma$ and we can generate data with $p(x|z)p(z)$


For more details on the probabilistic PCA see chapter 21 and Murphy, Chapter 12


## A short review of autoencoders

An autoencoder is an artificial neural network for representation learning and dimensionality reduction

The schematic exemplifies the architecture of an autoencoder 

<img src="images/ae.png" width="800">

In general

- The input and output dimensionality are equivalent
- The code or bottleneck has a smaller dimensionality than the input/output

We call **encoder** to the neural net that maps the input to the code

$$
z = g_\phi(x)
$$

and **decoder** to the neural net that maps the code to the output

$$
\hat x = f_\theta(z)
$$

Autoencoders are trained by minimizing an error, e.g. the mean square error (MSE) or cross-entropy, between the input and the output, i.e. the data is used as target (self-supervision)

For example we may use the MSE

$$
\hat \theta, \hat \phi = \text{arg} \min_{\phi, \theta} \| x - f_\theta(g_\phi(x)) \|^2
$$

which is equivalent to the maximum likelihood (MLE) solution assuming a spherical Gaussian likelihood function (cross entropy is equivalent to the MLE given a Bernoulli likelihood)

Adding an L2 regularizer on $\theta$ and $\phi$ is equivalent to incorporating a spherical gaussian prior (MAP solution)






### Example: Autoencoder for MNIST in pytorch

In this example we define one module for the encoder and one for the decoder, each with two hidden layers.

In [None]:
class Decoder(torch.nn.Module):
    """Fully connected decoder: latent code -> two ReLU hidden layers -> linear output."""

    def __init__(self, latent_dim, output_dim, hidden_dim):
        super().__init__()
        nn = torch.nn
        self.hidden1 = nn.Linear(latent_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.ReLU()

    def forward(self, z):
        """Return the output of the last linear layer (no output activation is applied)."""
        first = self.activation(self.hidden1(z))
        second = self.activation(self.hidden2(first))
        out = self.output(second)
        return out

class Encoder(torch.nn.Module):
    """Fully connected encoder: input -> two ReLU hidden layers -> linear code."""

    def __init__(self, latent_dim, input_dim, hidden_dim):
        super().__init__()
        Linear = torch.nn.Linear
        self.hidden1 = Linear(input_dim, hidden_dim)
        self.hidden2 = Linear(hidden_dim, hidden_dim)
        self.code = Linear(hidden_dim, latent_dim)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        """Return the code, i.e. the output of the last linear layer, for input `x`."""
        features = x
        # Both hidden layers are followed by the same activation
        for hidden in (self.hidden1, self.hidden2):
            features = self.activation(hidden(features))
        return self.code(features)
    
class AutoEncoder(torch.nn.Module):
    """Encoder followed by a decoder; input and output share the same dimensionality."""

    def __init__(self, latent_dim, input_dim=28*28, hidden_dim=128):
        super().__init__()
        self.encoder = Encoder(latent_dim, input_dim, hidden_dim=hidden_dim)
        self.decoder = Decoder(latent_dim, input_dim, hidden_dim=hidden_dim)

    def forward(self, x):
        """Encode `x` to its code and return the decoder output for that code."""
        code = self.encoder(x)
        return self.decoder(code)

To train we use the binary cross entropy and the ADAM optimizer

In [None]:
# Fix the seed so that the weight initialization and the batch order are repeatable
torch.manual_seed(0)

model = AutoEncoder(latent_dim=2)
# The model outputs logits, hence the "WithLogits" variant; losses are summed
# over pixels and samples and normalized per sample when reporting
criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

use_gpu = False
if use_gpu:
    model = model.cuda()

for nepoch in tqdm(range(50)):
    epoch_loss = 0.0
    n_samples = 0
    for x, y in train_loader:
        if use_gpu:
            x = x.cuda()
        # Flatten once; the same tensor is both the input and the target
        x_flat = x.reshape(-1, 28*28)
        optimizer.zero_grad()
        # Call the module itself (not .forward) so that registered hooks run
        hatx = model(x_flat)
        loss = criterion(hatx, x_flat)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_samples += x_flat.shape[0]
    # Normalize by the number of samples actually seen: the last batch may be
    # smaller than batch_size because drop_last=False
    print(f"{nepoch}: {epoch_loss/n_samples}")

The latent space and reconstructions in this case are:

In [None]:
# The loader yields CPU batches, so they have to be moved to wherever the model lives
device = next(model.parameters()).device

# Collect per-batch results in lists and concatenate once at the end
codes, labels = [], []
with torch.no_grad():
    for x, y in train_loader:
        codes.append(model.encoder(x.reshape(-1, 28*28).to(device)).cpu())
        labels.append(y)

Z = torch.cat(codes).numpy()
Y = torch.cat(labels).numpy()
fig, ax = plt.subplots(figsize=(6, 4), tight_layout=True)
for digit in range(10):
    mask = Y == digit
    # `cmap` is not passed: without `c` matplotlib ignores it and may warn;
    # each call takes the next color of the default cycle
    ax.scatter(Z[mask, 0], Z[mask, 1],
               s=5, alpha=0.5, label=str(digit))
ax.legend()
ax.set_title('Latent space')

# The decoder returns logits; the sigmoid maps them to pixel intensities
output_activation = torch.nn.Sigmoid()
fig, ax = plt.subplots(2, 10, figsize=(8, 2), tight_layout=True)
# `x` is the last batch produced by the loop above
with torch.no_grad():
    hatx = model(x.reshape(-1, 28*28).to(device))
reconstructions = output_activation(hatx).reshape(-1, 28, 28).cpu().numpy()
for i in range(10):
    # Top row: originals, bottom row: reconstructions (samples 10 to 19 of the batch)
    ax[0, i].imshow(x.detach().cpu().numpy()[i+10, 0, :, :], cmap=plt.cm.Greys_r)
    ax[0, i].axis('off')
    ax[1, i].imshow(reconstructions[i+10], cmap=plt.cm.Greys_r)
    ax[1, i].axis('off')

## Variational inference  for LVM

The LVM is defined by the joint density between observed and latent variables

$$
p(\textbf{x}, \textbf{z}) = \prod_{i=1}^N p(\textbf{x}_i|\textbf{z}_i) p(\textbf{z}_i)
$$

If we use the **PCA recipe** (Linear mapping, Gaussian likelihood and Gaussian prior) we obtain an analytical Gaussian posterior and evidence

If we use a more complex (non-linear) mapping the posterior and evidence may not be tractable. In such case, we can use **Variational inference (VI)**, i.e. we propose an approximate posterior and maximize the ELBO

$$
\begin{align}
\log p(x) \geq \mathcal{L}(\phi) &= \mathbb{E}_{z\sim q_\phi(z|x)} \left[\log \frac{p(x, z)}{q_\phi(z|x)}\right] \nonumber \\
&= \int q_\phi(z|x) \log \frac{p(x, z)}{q_\phi(z|x)} dz \nonumber 
\end{align} 
$$

to find the best parameters $\hat \phi$. 

In what follows we review the variational autoencoder which combines amortized inference and VI

## Variational Autoencoder (VAE)

The Variational Autoencoder (VAE) is an LVM where **deep neural networks** are used to model the **conditional distributions** between latent $z$ and observed $x$ variables

It was proposed simultaneously by [(Kingma and Welling, ICLR, Dec. 2013)](https://arxiv.org/pdf/1312.6114.pdf) and [(Rezende *et al*, ICML, Jan. 2014)](https://arxiv.org/abs/1401.4082) perhaps sparking the revived interest into **Deep Learning plus Approximate Bayesian Inference** that we see today


The difference with a regular autoencoder is that the latent (code) is now a stochastic variable

- a prior distribution is placed on $z$: $p(z)$
- a neural network is used to model the parameters of the likelihood: $p_\theta(x | z)$
- a neural network is used to model the parameters of the approximate posterior: $q_\phi(z|x)$
- The weight and biases of the networks $\theta$ and $\phi$ are deterministic, *i.e.* VAE is not a "fully" bayesian neural network

In VAE, variational inference is used to obtain the posterior and point estimates of the global parameters

In what follows we will review the assumptions and the key contributions of this work to the field of Bayesian Neural Networks


### Assumptions

VAE assumes a particular prior and approximate posterior. The likelihood function is chosen depending on the data

1. The latent variable has a standard Gaussian prior (the same as in PCA)
1. The approximate posterior is a factorized (diagonal) Gaussian
1. For continuous data the likelihood is typically set as a spherical Gaussian or diagonal Gaussian. For binary data the likelihood is set as Bernoulli. 

Mathematically this is 

$$
p(z_i) = \mathcal{N}(0, I)
$$

$$
q(z_i|x_i) = \mathcal{N}(\mu_i, I \sigma_i^2)
$$

and

$$
p(x_i|z_i) = \mathcal{N}(\hat \mu_i, I \hat \sigma_i^2)
$$

for $x_i \in \mathbb{R}^D$. 

### Amortization

In the previous formulation the amount of variational parameters ($\mu_i$ and $\sigma_i$) scales linearly with $N$. This is impractical for large datasets. 

Instead of having parameters per data point we can have a function that maps the data to the parameters. This is known as **amortization**

In the particular case of VAE we have

$$
\mu_i, \sigma_i = g_\phi(x_i)
$$

where $g_\phi(\cdot)$ is the **encoder network** and

$$
\hat \mu_i, \hat \sigma_i = f_\theta(z_i)
$$

where $f_\theta(\cdot)$ is the **decoder network** (for a diagonal Gaussian likelihood)



### Details on the VAE training

The VAE is trained by maximizing the ELBO, which in this case is

$$
\begin{align}
\mathcal{L}(\theta, \phi) &= \mathbb{E}_{z\sim q_\phi(z|x)} \left [\log p_\theta(x|z) + \log p(z) - \log q_\phi(z|x) \right ] \nonumber \\
&= \mathbb{E}_{z\sim q_\phi(z|x)} \left [\log p_\theta(x|z) \right ] - D_{KL}\left[ q_\phi(z|x) || p(z) \right]\nonumber
\end{align}
$$

By maximizing the ELBO we

- Maximize the log likelihood when sampling from the approximate posterior: **Faithful data reconstructions**
- Minimize the divergence between the approximate posterior and prior: **Regularization for the posterior**


The ELBO is typically optimized via gradient ascent updates for $\theta$ and $\phi$

$$
\theta_{t+1} = \theta_{t} + \eta \nabla_\theta \mathcal{L}(\theta_{t}, \phi_{t})
$$

$$
\phi_{t+1} = \phi_{t} + \eta \nabla_\phi \mathcal{L}(\theta_{t}, \phi_{t})
$$

In what follows we review how to obtain the derivatives of the ELBO with respect to $\theta$ and $\phi$

**The derivative with respect to $\theta$** 

The only term that depends on $\theta$ is $p_\theta(x|z)$, then

$$
\begin{align}
\nabla_\theta \mathcal{L}(\theta, \phi)  &= \nabla_\theta \mathbb{E}_{z\sim q_\phi(z|x)}\left [\log p_\theta(x|z)\right ] \nonumber\\ &= \mathbb{E}_{z\sim q_\phi(z|x)} \left [\nabla_\theta \log  p_\theta(x|z)\right ] \nonumber
\end{align}
$$

where the gradient can be swapped with the expectation operator. The expectation can be estimated via monte-carlo integration as

$$
\begin{align}
\mathbb{E}_{z\sim q_\phi(z|x)} \left [\nabla_\theta \log  p_\theta(x|z)\right ] &= \int q_\phi(z|x) \nabla_\theta \log  p_\theta(x|z) \,dz \nonumber \\&\approx \frac{1}{S} \sum_{k=1}^S  \nabla_\theta \log p_\theta(x|z^{(k)}) \quad z^{(k)} \sim  q_\phi(z|x) \nonumber
\end{align}
$$

**The derivative wrt to $\phi$**

Let's write $f(z) = \log p_\theta(x|z)$, a function of $z$ that does not depend on $\phi$. In this case

$$
\begin{align}
\nabla_\phi  \mathcal{L}(\theta, \phi) &=  \nabla_\phi \mathbb{E}_{z\sim q_\phi(z|x)}\left [f(z) \right ] \nonumber \\ &= \nabla_\phi \int q_\phi(z|x) f(z) dz \nonumber \\ &=  \int f(z) \nabla_\phi q_\phi(z|x)  dz \nonumber
\end{align}
$$

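but this integral is no longer an expectation with respect to $q_\phi(z|x)$ (note that $\nabla_\phi q_\phi(z|x)$ is not a density), so we cannot approximate it directly via Monte Carlo integration anymore. 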

A classical solution to this is the [REINFORCE](http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf) algorithm, which is based on the identity 

$$
\nabla_\phi\log q_\phi(z) = \frac{1}{ q_\phi(z)} \nabla_\phi q_\phi(z) 
$$

then

$$
\begin{align}
\nabla_\phi \mathbb{E}_{z\sim q_\phi(z|x)}\left [f(z)\right ]  &= \int f(z) \nabla_\phi q_\phi(z|x)  dz \nonumber \\ &= \int q_\phi(z|x) f(z) \nabla_\phi\log q_\phi(z|x)  dz \nonumber \\ &\approx \frac{1}{S} \sum_{k=1}^S f(z^{(k)}) \nabla_\phi \log q_\phi(z^{(k)}|x) \quad z^{(k)} \sim  q_\phi(z|x) \nonumber
\end{align}
$$

but in practice training with this estimator is very hard. Although the REINFORCE estimator is unbiased it has a very high variance

For the particular case of VAE there is a low variance (and very elegant) alternative

**Key contribution**: The reparameterization trick

In VAE the latent variable is distributed as

$$
z \sim \mathcal{N}(\mu_\phi(x), \sigma_\phi^2 (x) )
$$

we can rewrite this as first sampling from a standard Gaussian 

$$
\epsilon \sim \mathcal{N}(0, I)
$$

and then applying a transformation

$$
z = g(\phi, \epsilon) = \mu_\phi (x) + \epsilon \sigma_\phi (x)  
$$

Then we can rewrite the expectation of $f(z)$ as 

$$
\mathbb{E}_{z\sim q_\phi(z|x)}\left [f(z) \right ] =  \mathbb{E}_{\epsilon\sim \mathcal{N}(0, I)}\left [  f(g(\phi, \epsilon))  \right ] 
$$

Now that the distribution over which the expectation is taken does not depend on $\phi$ we can use the following estimator for its gradient

$$
\begin{align}
\nabla_\phi \mathbb{E}_{\epsilon\sim \mathcal{N}(0, I)}\left [  f(g(\phi, \epsilon))  \right ] &= \mathbb{E}_{\epsilon\sim \mathcal{N}(0, I)}\left [  f'(g(\phi, \epsilon)) \nabla_\phi g(\phi, \epsilon) \right ] \nonumber \\
&\approx \frac{1}{S} \sum_{k=1}^S f'(g(\phi, \epsilon^{(k)})) \nabla_\phi g(\phi, \epsilon^{(k)}) \quad \epsilon^{(k)} \sim  \mathcal{N}(0,I) \nonumber
\end{align}
$$

which has a [much lower variance than REINFORCE](https://nbviewer.jupyter.org/github/gokererdogan/Notebooks/blob/master/Reparameterization%20Trick.ipynb)

**More variance reduction:** Closed-form terms

We have focused on the left hand term of the ELBO

$$
\mathcal{L}(\theta, \phi) = \mathbb{E}_{z\sim q_\phi(z|x)} \left [\log p_\theta(x|z) \right ] - D_{KL}\left[ q_\phi(z|x) || p(z) \right]
$$

The right hand term is the KL divergence between two multivariate Gaussian distributions. This has a [closed analytical solution](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Multivariate_normal_distributions)

$$
D_\text{KL}\left[q_\phi(z|x) || p(z) \right] = \frac{1}{2}\sum_{j=1}^K \left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right)
$$

where $K$ is the dimensionality of the latent variable

The derivatives of this estimator are straightforward and the variance is low

**`pyro` related notes**

When [`TraceMeanField_ELBO`](https://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.trace_mean_field_elbo.TraceMeanField_ELBO) is used pyro assumes that the latent variables in the guide can be reparameterized and uses the analytical KL

You can read more on variance reduction in this official [pyro tutorial](https://pyro.ai/examples/svi_part_iii.html)

### Writing a VAE in Pyro

Similarly to the previous AE example we use two hidden layers, but now the encoder is "dual-headed", i.e. it outputs the parameters of the factorized gaussian associated to the latent variable

We have to make sure that the standard deviation is strictly positive (here by applying an exponential to the network output)

In [None]:
class EncoderDual(torch.nn.Module):
    """Two-headed encoder returning a location and a positive scale per latent dimension."""

    def __init__(self, latent_dim, input_dim=28*28, hidden_dim=128):
        super().__init__()
        nn = torch.nn
        self.hidden1 = nn.Linear(input_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.z_loc = nn.Linear(hidden_dim, latent_dim)
        self.z_scale = nn.Linear(hidden_dim, latent_dim)
        self.activation = nn.Softplus()

    def forward(self, x):
        """Return the tuple (loc, scale) computed from the shared hidden features."""
        shared = self.activation(self.hidden1(x))
        shared = self.activation(self.hidden2(shared))
        loc = self.z_loc(shared)
        # The exponential keeps the scale strictly positive
        scale = torch.exp(self.z_scale(shared))
        return loc, scale
    
class Decoder(torch.nn.Module):
    """Softplus MLP decoder mapping latent codes to one output value per data dimension.

    Note: this definition replaces the ReLU `Decoder` defined earlier in the notebook.
    """

    def __init__(self, latent_dim, output_dim=28*28, hidden_dim=128):
        super().__init__()
        Linear = torch.nn.Linear
        self.hidden1 = Linear(latent_dim, hidden_dim)
        self.hidden2 = Linear(hidden_dim, hidden_dim)
        self.output = Linear(hidden_dim, output_dim)
        self.activation = torch.nn.Softplus()

    def forward(self, z):
        """Return the output of the last linear layer (no output activation is applied)."""
        features = z
        # Both hidden layers are followed by the same activation
        for hidden in (self.hidden1, self.hidden2):
            features = self.activation(hidden(features))
        return self.output(features)

We write the model and the guide within a `torch.nn.Module` 

To register the weights and biases of `Encoder` and `Decoder` to the parameter dictionary of pyro we register them using the [`pyro.module`](https://docs.pyro.ai/en/stable/primitives.html#pyro.primitives.module) primitive, i.e. we don't need to use `pyro.param` in this case. 

We use plates to make the model conditionally independent on the batch dimension (leftmost dimension)

In [None]:
import pyro.distributions as dists

class VariationalAutoEncoder(torch.nn.Module):
    """VAE with a standard Gaussian prior, Bernoulli likelihood and diagonal Gaussian guide.

    Parameters
    ----------
    latent_dim : int
        Dimensionality of the latent variable.
    data_dim : int
        Number of values per (flattened) sample.
    hidden_dim : int
        Width of the hidden layers of the encoder and decoder.
    """
    
    def __init__(self, latent_dim, data_dim=28*28, hidden_dim=128):
        super(VariationalAutoEncoder, self).__init__() 
        self.encoder = EncoderDual(latent_dim, input_dim=data_dim, hidden_dim=hidden_dim)
        self.decoder = Decoder(latent_dim, output_dim=data_dim, hidden_dim=hidden_dim)
        self.latent_dim = latent_dim
        # Kept so that model/guide flatten the input consistently with the networks
        self.data_dim = data_dim
        
    def model(self, x):
        """Generative model p(x|z)p(z) for a batch `x` (first dimension is the batch)."""
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", size=x.shape[0]):
            # p(z)
            z_loc = torch.zeros(x.shape[0], self.latent_dim, device=x.device)
            z_scale = torch.ones(x.shape[0], self.latent_dim, device=x.device)
            z = pyro.sample("latent", dists.Normal(z_loc, z_scale).to_event(1))
            # p(x|z)
            p_logits = self.decoder(z)
            # Argument validation is switched off for the Bernoulli likelihood
            pyro.sample("observed", 
                        dists.Bernoulli(logits=p_logits, validate_args=False).to_event(1), 
                        obs=x.reshape(-1, self.data_dim))
    
    def guide(self, x):
        """Amortized approximate posterior q(z|x) for a batch `x`."""
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", size=x.shape[0]):
            # q(z|x)
            z_loc, z_scale  = self.encoder(x.reshape(-1, self.data_dim))
            pyro.sample("latent", 
                        dists.Normal(z_loc, z_scale).to_event(1))

The model is trained using SVI with the ClippedAdam optimizer

The guide satisfies the mean field condition, this means we can use `TraceMeanField_ELBO`

In [None]:
pyro.enable_validation(True) 
pyro.clear_param_store()
# Fix the seed so that weight initialization, batch order and ELBO samples are
# repeatable across runs
pyro.set_rng_seed(0)

vae = VariationalAutoEncoder(latent_dim=2)

use_gpu = False
if use_gpu:
    vae = vae.cuda()
    
svi = pyro.infer.SVI(model=vae.model, 
                     guide=vae.guide, 
                     optim=pyro.optim.ClippedAdam({"lr": 1e-2}), 
                     loss=pyro.infer.TraceMeanField_ELBO(num_particles=5, 
                                                         vectorize_particles=True))

for nepoch in tqdm(range(50)):    
    epoch_loss = 0.0
    n_samples = 0
    for x, y in train_loader:
        if use_gpu:
            x = x.cuda()
        # svi.step takes one gradient step and returns the loss of the batch
        epoch_loss += svi.step(x)
        n_samples += x.shape[0]
    # Normalize by the number of samples actually seen: the last batch may be
    # smaller than batch_size because drop_last=False
    print(f"{nepoch}: {epoch_loss/n_samples:0.4f}")

The following shows the latent space, errorbars are used to show the mean and standard deviation of the latent variables 

In [None]:
# Everything below runs on the CPU: the loader yields CPU batches, so the
# model is moved there as well (no-op if it already is on the CPU)
vae = vae.cpu()

# Collect per-batch results in lists and concatenate once at the end
encodings, labels = [], []
with torch.no_grad():
    for x, y in valid_loader:
        z_loc, z_scale = vae.encoder(x.reshape(-1, 28*28))
        # Columns: [loc_1, loc_2, scale_1, scale_2] for a two dimensional latent
        encodings.append(torch.cat((z_loc, z_scale), dim=1))
        labels.append(y)

Z = torch.cat(encodings, dim=0).numpy()
Y = torch.cat(labels).numpy()
fig, ax = plt.subplots(figsize=(6, 4), tight_layout=True)
for digit in range(10):
    mask = Y == digit
    # `cmap` is not a valid errorbar/Line2D property, so it is not passed;
    # each call takes the next color of the default cycle
    ax.errorbar(x=Z[mask, 0], y=Z[mask, 1], 
                xerr=Z[mask, 2], yerr=Z[mask, 3], 
                fmt='none', alpha=0.5, label=str(digit))
ax.legend()

The following plot shows reconstructions as a function of $z$

The white contour represents the approximate posterior

In [None]:
M = 30
z_plot = np.linspace(-3, 3, num=M)
# Every 28x28 tile is centered on one grid value, so the picture covers half a
# grid step beyond the first and last value of z_plot. Using the same limits
# for the image extent and for the histogram makes the M bin centers coincide
# with z_plot, i.e. image, axes and contour all share one coordinate system
half_step = 0.5*(z_plot[1] - z_plot[0])
z_min, z_max = z_plot[0] - half_step, z_plot[-1] + half_step

# Decode the whole grid in a single batch.
# Row i holds z2 = z_plot[M-1-i] (largest value at the top), column j holds z1 = z_plot[j]
Z1_grid, Z2_grid = np.meshgrid(z_plot, z_plot[::-1])
z = torch.tensor(np.stack([Z1_grid.ravel(), Z2_grid.ravel()], axis=1), dtype=torch.float32)
device = next(vae.parameters()).device
with torch.no_grad():
    # The decoder returns logits; the sigmoid maps them to pixel intensities
    xhat = torch.sigmoid(vae.decoder(z.to(device))).reshape(M, M, 28, 28).cpu().numpy()
# (row, col, 28, 28) -> (row, 28, col, 28) -> one (28*M, 28*M) mosaic
big_imag = xhat.transpose(0, 2, 1, 3).reshape(28*M, 28*M)

fig, ax = plt.subplots(figsize=(9, 9), tight_layout=True)
Z_plot1, Z_plot2 = np.meshgrid(z_plot, z_plot)
ax.matshow(big_imag, vmin=0.0, vmax=1.0, cmap=plt.cm.gray, extent=[z_min, z_max, z_min, z_max])
# 2-D histogram of the encoded means (first two columns of Z); it is transposed
# below because histogram2d puts the first coordinate along rows
H, xedge, yedge = np.histogram2d(Z[:, 0], Z[:, 1], bins=M, range=[[z_min, z_max], [z_min, z_max]])
ax.contour(Z_plot1, Z_plot2, H.T, linewidths=3, levels=[1], cmap=plt.cm.Reds);