# Variational autoencoder

In [None]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, throttle, params
using Distributions
import Distributions: logpdf

**For numerical stability, slightly extend the Bernoulli `logpdf` so that it stays finite when `p` is close to 1 or 0.**

In [None]:
# Bernoulli log-density for a Boolean outcome `y`, with a machine-epsilon offset
# inside each logarithm so the value stays finite when `b.p` reaches 0 or 1.
function logpdf(b::Bernoulli, y::Bool)
    ϵ = eps()
    return y * log(b.p + ϵ) + (1 - y) * log(1 - b.p + ϵ)
end

# Load data

In [None]:
# Fetch the MNIST images through the `Flux.Data.MNIST` helper.
# The trailing `;` stops the notebook from echoing the whole collection.
X = MNIST.images();

## Binarise

In [None]:
# Flatten every image into a column, stack the columns into one matrix and
# threshold at 0.5 so each pixel becomes a Bool. `reduce(hcat, ...)` builds the
# same matrix as `hcat(...)` without splatting the whole dataset into varargs.
X = float.(reduce(hcat, vec.(X))) .> 0.5

## Mini-batches

In [None]:
# N: number of examples (columns of X); M: examples per mini-batch.
N = size(X, 2)
M = 100
# Consecutive column slices of X, at most M columns each.
data = [X[:, cols] for cols in Iterators.partition(1:N, M)]

# Define Model

In [None]:
# Layer widths along the path D -> Dh -> Dz -> Dh -> D:
# D is the flattened 28x28 input, Dh the hidden width, Dz the latent width.
D = 28^2
Dh = 500
Dz = 5

# Encoder: a shared tanh hidden layer `A` feeds two heads (no explicit
# activation) that emit the mean and the log standard deviation of the
# approximate posterior.
A = Dense(D, Dh, tanh)
μ = Dense(Dh, Dz)
logσ = Dense(Dh, Dz)

# Run the encoder on `X` and return the `(mean, log-std)` pair.
function g(X)
    h = A(X)
    return μ(h), logσ(h)
end

# Reparameterised draw for one latent coordinate: mean plus scaled standard
# normal noise. Written for scalars; broadcast it (`z.(...)`) over arrays.
function z(μ, logσ)
    return μ + exp(logσ) * randn()
end

# Decoder: maps a latent code back to D values; the final `σ` activation lets
# the outputs be used as Bernoulli probabilities by the likelihood term.
f = Chain(
    Dense(Dz, Dh, tanh),
    Dense(Dh, D, σ),
)

# Loss function and ELBO

In [None]:
# KL-divergence between the approximate posterior N(μ, σ²) and the N(0, 1) prior,
# summed over all entries:
#     KL = 0.5 * Σ (σ² + μ² - 1 - log σ²),   with σ² = exp(2 logσ), log σ² = 2 logσ.
# `μ` and `logσ` are the mean and log standard deviation (same shape).
kl_q_p(μ, logσ) = 0.5 * sum(exp.(2 .* logσ) .+ μ.^2 .- 1 .- 2 .* logσ)

# log p(x|z): Bernoulli log-likelihood of the data `x` under the decoder's
# output for the latents `z`, summed over every entry.
function logp_x_z(x, z)
    decoded = f(z)
    return sum(logpdf.(Bernoulli.(decoded), x))
end

# Single-draw Monte Carlo estimate of the ELBO for the batch `X`, divided by
# the global mini-batch size M to give a per-example mean.
function L̄(X)
    μ̂, logσ̂ = g(X)
    latents = z.(μ̂, logσ̂)
    return (logp_x_z(X, latents) - kl_q_p(μ̂, logσ̂)) / M
end

# Training objective: negative mean ELBO plus a squared-norm penalty
# (weight 0.01) on the decoder parameters.
function loss(X)
    neg_elbo = -L̄(X)
    penalty = sum(x -> sum(x.^2), params(f))
    return neg_elbo + 0.01 * penalty
end

# Draw one sample from the learned model: take latents with mean 0 and
# log-std 0 (i.e. standard normal), decode them, then sample each output
# from its Bernoulli distribution.
function modelsample()
    latents = z.(zeros(Dz), zeros(Dz))
    return rand.(Bernoulli.(f(latents)))
end

# Learning

In [None]:
# Progress callback: shows the negative mean ELBO of M randomly chosen columns of X.
# `throttle(..., 30)` is meant to limit it to about one report every 30 seconds.
evalcb = throttle(() -> @show(-L̄(X[:, rand(1:N, M)])), 30)
# ADAM optimiser over the parameters of the encoder layers and the decoder.
opt = ADAM(params(A, μ, logσ, f))
# Ten passes over the mini-batches. `zip(data)` wraps each batch in a 1-tuple,
# presumably so that `train!` calls `loss(batch)` — confirm against the Flux version in use.
@epochs 10 Flux.train!(loss, zip(data), opt, cb=evalcb)

# Sample Output

In [None]:
using Images

# Turn a flat sample into a 28x28 grayscale image.
function img(x)
    pixels = reshape(x, 28, 28)
    return Gray.(pixels)
end

# Work in this file's directory so the image is written next to it.
cd(@__DIR__)
# Ten independent model samples laid out side by side in one image.
sample = hcat(map(_ -> img(modelsample()), 1:10)...)
save("vae-sample.png", sample)