# Variational autoencoder

In [1]:
using Flux, Flux.Data.MNIST
using Flux: throttle, params
using Juno: @progress

**For numerically stable, extend distributions slightly to have a logpdf for `p` close to 1 or 0.**

In [2]:
using Distributions
import Distributions: logpdf
logpdf(b::Bernoulli, y::Bool) = y * log(b.p + eps()) + (1 - y) * log(1 - b.p + eps())



logpdf (generic function with 62 methods)

# Load data

In [3]:
X = MNIST.images();

## Binarise

In [4]:
X = float.(hcat(vec.(X)...)) .> 0.5

784×60000 BitArray{2}:
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
     ⋮                      

## Mini-batches

In [5]:
N, M = size(X, 2), 100
data = [X[:,i] for i in Iterators.partition(1:N,M)]

600-element Array{BitArray{2},1}:
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false

# Define Model

In [6]:
# Dimensions (D -> Dh -> Dz -> Dh -> D)
Dz, Dh, D = 5, 500, 28^2

# Encoder
A, μ, logσ = Dense(D, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g(X) = (h = A(X); (μ(h), logσ(h)))
z(μ, logσ) = μ + exp(logσ) * randn()

# Decoder
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, D, σ))

Chain(Dense(5, 500, tanh), Dense(500, 784, NNlib.σ))

# Loss function and ELBO

In [7]:
# KL-divergence between approximation posterior and N(0, 1) prior.
kl_q_p(μ, logσ) = 0.5 * sum(exp.(2 .* logσ) + μ.^2 .- 1 .+ logσ.^2)

# logp(x|z) - conditional probability of data given latents.
logp_x_z(x, z) = sum(logpdf.(Bernoulli.(f(z)), x))

# Monte Carlo estimator of mean ELBO using M samples.
L̄(X) = ((μ̂, logσ̂) = g(X); (logp_x_z(X, z.(μ̂, logσ̂)) - kl_q_p(μ̂, logσ̂)) / M)

loss(X) = -L̄(X) + 0.01 * sum(x->sum(x.^2), params(f))

# Sample from the learned model.
modelsample() = rand.(Bernoulli.(f(z.(zeros(Dz), zeros(Dz)))))

modelsample (generic function with 1 method)

# Learning

In [8]:
evalcb = throttle(() -> @show(-L̄(X[:, rand(1:N, M)])), 30)
opt = ADAM(params(A, μ, logσ, f))
@progress for i = 1:10
  println("Epoch $i")
  Flux.train!(loss, zip(data), opt, cb=evalcb)
end

Epoch 1
-(L̄(X[:, rand(1:N, M)])) = 545.1419619182199 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 209.64264420630295 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 190.1063487606253 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 186.31947967277557 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 183.81528556072715 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 173.3504006933128 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 178.7215420450704 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 177.792189159248 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 168.01247836758571 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 181.54266353444308 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 167.59280986056675 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 171.32855676469953 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 166.5451867455273 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 173.43650763305786 (tracked)
Epoch 2
-(L̄(X[:, rand(1:N, M)])) = 172.0884107137738 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 182.1772309214022 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 189.42716652112756 (tracked)
-(L̄(X[:, rand(1:N, M)])

# Sample Output

In [9]:
using Images

img(x) = Gray.(reshape(x, 28, 28))

cd(@__DIR__)
sample = hcat(img.([modelsample() for i = 1:10])...)
save("vae-sample.png", sample)