# Variational autoencoder

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, throttle, params
using Distributions
import Distributions: logpdf



**For numerically stable, extend distributions slightly to have a logpdf for `p` close to 1 or 0.**

In [2]:
logpdf(b::Bernoulli, y::Bool) = y * log(b.p + eps()) + (1 - y) * log(1 - b.p + eps())

logpdf (generic function with 64 methods)

# Load data

In [3]:
X = MNIST.images();

## Binarise

In [4]:
X = float.(hcat(vec.(X)...)) .> 0.5

784×60000 BitArray{2}:
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0   

## Mini-batches

In [5]:
N, M = size(X, 2), 100
data = [X[:,i] for i in Iterators.partition(1:N,M)]

600-element Array{BitArray{2},1}:
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 ⋮
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 

# Define Model

In [6]:
# Dimensions (D -> Dh -> Dz -> Dh -> D)
Dz, Dh, D = 5, 500, 28^2

# Encoder
A, μ, logσ = Dense(D, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g(X) = (h = A(X); (μ(h), logσ(h)))
z(μ, logσ) = μ + exp(logσ) * randn()

# Decoder
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, D, σ))

Chain(Dense(5, 500, tanh), Dense(500, 784, σ))

# Loss function and ELBO

In [7]:
# KL-divergence between approximation posterior and N(0, 1) prior.
kl_q_p(μ, logσ) = 0.5 * sum(exp.(2 .* logσ) + μ.^2 .- 1 .+ logσ.^2)

# logp(x|z) - conditional probability of data given latents.
logp_x_z(x, z) = sum(logpdf.(Bernoulli.(f(z)), x))

# Monte Carlo estimator of mean ELBO using M samples.
L̄(X) = ((μ̂, logσ̂) = g(X); (logp_x_z(X, z.(μ̂, logσ̂)) - kl_q_p(μ̂, logσ̂)) / M)

loss(X) = -L̄(X) + 0.01 * sum(x->sum(x.^2), params(f))

# Sample from the learned model.
modelsample() = rand.(Bernoulli.(f(z.(zeros(Dz), zeros(Dz)))))

modelsample (generic function with 1 method)

# Learning

In [8]:
evalcb = throttle(() -> @show(-L̄(X[:, rand(1:N, M)])), 30)
opt = ADAM()
@epochs 10 Flux.train!(loss, params(A, μ, logσ, f), zip(data), opt, cb=evalcb)

┌ Info: Epoch 1
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 543.2253876344163
-(L̄(X[:, rand(1:N, M)])) = 210.90770182909597
-(L̄(X[:, rand(1:N, M)])) = 182.37667903594695
-(L̄(X[:, rand(1:N, M)])) = 176.8284544304646
-(L̄(X[:, rand(1:N, M)])) = 166.23625509264545
-(L̄(X[:, rand(1:N, M)])) = 164.13378323882003
-(L̄(X[:, rand(1:N, M)])) = 166.2227478943171
-(L̄(X[:, rand(1:N, M)])) = 167.46112752433044
-(L̄(X[:, rand(1:N, M)])) = 165.69806285103962
-(L̄(X[:, rand(1:N, M)])) = 162.4959224058759
-(L̄(X[:, rand(1:N, M)])) = 159.14507198446748
-(L̄(X[:, rand(1:N, M)])) = 165.51285004401856
-(L̄(X[:, rand(1:N, M)])) = 162.4391337805895
-(L̄(X[:, rand(1:N, M)])) = 158.4599822592562


┌ Info: Epoch 2
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 161.03412204020972
-(L̄(X[:, rand(1:N, M)])) = 156.7389618437368
-(L̄(X[:, rand(1:N, M)])) = 158.1771271355005
-(L̄(X[:, rand(1:N, M)])) = 160.88672806869374
-(L̄(X[:, rand(1:N, M)])) = 155.79862867113846
-(L̄(X[:, rand(1:N, M)])) = 160.02524540488443
-(L̄(X[:, rand(1:N, M)])) = 150.3229243267442
-(L̄(X[:, rand(1:N, M)])) = 155.1546929795795
-(L̄(X[:, rand(1:N, M)])) = 158.97106259231666
-(L̄(X[:, rand(1:N, M)])) = 151.70704697131768
-(L̄(X[:, rand(1:N, M)])) = 155.14125917124053
-(L̄(X[:, rand(1:N, M)])) = 154.15639887942905
-(L̄(X[:, rand(1:N, M)])) = 145.10338762876364
-(L̄(X[:, rand(1:N, M)])) = 164.4805176921039


┌ Info: Epoch 3
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 148.51659831133435
-(L̄(X[:, rand(1:N, M)])) = 153.5461828286335
-(L̄(X[:, rand(1:N, M)])) = 147.2820018498238
-(L̄(X[:, rand(1:N, M)])) = 148.86039214863516
-(L̄(X[:, rand(1:N, M)])) = 151.0876332741134
-(L̄(X[:, rand(1:N, M)])) = 150.71910344321128
-(L̄(X[:, rand(1:N, M)])) = 149.4393017566366
-(L̄(X[:, rand(1:N, M)])) = 151.50084433905374
-(L̄(X[:, rand(1:N, M)])) = 148.03157701201147
-(L̄(X[:, rand(1:N, M)])) = 150.7750336383714
-(L̄(X[:, rand(1:N, M)])) = 143.9872031495771
-(L̄(X[:, rand(1:N, M)])) = 146.48314006619918
-(L̄(X[:, rand(1:N, M)])) = 152.95312741773247


┌ Info: Epoch 4
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 148.17424985604464
-(L̄(X[:, rand(1:N, M)])) = 151.7466289691474
-(L̄(X[:, rand(1:N, M)])) = 146.4476323311858
-(L̄(X[:, rand(1:N, M)])) = 145.74483143940458
-(L̄(X[:, rand(1:N, M)])) = 144.53607755980426
-(L̄(X[:, rand(1:N, M)])) = 150.01887649783197
-(L̄(X[:, rand(1:N, M)])) = 144.7111089045797
-(L̄(X[:, rand(1:N, M)])) = 137.4631344712804
-(L̄(X[:, rand(1:N, M)])) = 144.50302983320714
-(L̄(X[:, rand(1:N, M)])) = 146.95350558937076
-(L̄(X[:, rand(1:N, M)])) = 139.26244098124513
-(L̄(X[:, rand(1:N, M)])) = 142.86931438422022
-(L̄(X[:, rand(1:N, M)])) = 145.04857109110947


┌ Info: Epoch 5
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 146.78166291085452
-(L̄(X[:, rand(1:N, M)])) = 149.55042317698576
-(L̄(X[:, rand(1:N, M)])) = 143.42558181554992
-(L̄(X[:, rand(1:N, M)])) = 149.40420360091608
-(L̄(X[:, rand(1:N, M)])) = 144.96669214455568
-(L̄(X[:, rand(1:N, M)])) = 143.76979748001227
-(L̄(X[:, rand(1:N, M)])) = 145.42393278857972
-(L̄(X[:, rand(1:N, M)])) = 140.62927204255107
-(L̄(X[:, rand(1:N, M)])) = 143.26140974259715
-(L̄(X[:, rand(1:N, M)])) = 135.53553238240622
-(L̄(X[:, rand(1:N, M)])) = 140.0573741000583
-(L̄(X[:, rand(1:N, M)])) = 142.98340302800239
-(L̄(X[:, rand(1:N, M)])) = 139.90787737949992
-(L̄(X[:, rand(1:N, M)])) = 147.9919342269344


┌ Info: Epoch 6
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 156.53304805578634
-(L̄(X[:, rand(1:N, M)])) = 140.47259311682149
-(L̄(X[:, rand(1:N, M)])) = 138.26591734950435
-(L̄(X[:, rand(1:N, M)])) = 144.86717807820867
-(L̄(X[:, rand(1:N, M)])) = 142.68051808334687
-(L̄(X[:, rand(1:N, M)])) = 142.32425259448482
-(L̄(X[:, rand(1:N, M)])) = 147.95957363093913
-(L̄(X[:, rand(1:N, M)])) = 148.02098878033357
-(L̄(X[:, rand(1:N, M)])) = 148.4827426125895
-(L̄(X[:, rand(1:N, M)])) = 141.2648393738112
-(L̄(X[:, rand(1:N, M)])) = 139.77244143318464
-(L̄(X[:, rand(1:N, M)])) = 147.81476616221343
-(L̄(X[:, rand(1:N, M)])) = 138.44730365311068


┌ Info: Epoch 7
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 139.30169455894173
-(L̄(X[:, rand(1:N, M)])) = 146.2446533589009
-(L̄(X[:, rand(1:N, M)])) = 140.2470440994723
-(L̄(X[:, rand(1:N, M)])) = 142.60571132366624
-(L̄(X[:, rand(1:N, M)])) = 147.51699421115313
-(L̄(X[:, rand(1:N, M)])) = 138.03915898617404
-(L̄(X[:, rand(1:N, M)])) = 139.78039568775898
-(L̄(X[:, rand(1:N, M)])) = 140.31444011523843
-(L̄(X[:, rand(1:N, M)])) = 142.8195495008625
-(L̄(X[:, rand(1:N, M)])) = 147.44421428530913
-(L̄(X[:, rand(1:N, M)])) = 144.54885158765296
-(L̄(X[:, rand(1:N, M)])) = 143.573021558816
-(L̄(X[:, rand(1:N, M)])) = 137.66572906063965
-(L̄(X[:, rand(1:N, M)])) = 136.61115505613074


┌ Info: Epoch 8
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 139.57253036017062
-(L̄(X[:, rand(1:N, M)])) = 148.89092090742378
-(L̄(X[:, rand(1:N, M)])) = 149.26619897792014
-(L̄(X[:, rand(1:N, M)])) = 143.243834831569
-(L̄(X[:, rand(1:N, M)])) = 147.86138192998067
-(L̄(X[:, rand(1:N, M)])) = 148.0655915586063
-(L̄(X[:, rand(1:N, M)])) = 147.8141323817697
-(L̄(X[:, rand(1:N, M)])) = 145.4305581754671
-(L̄(X[:, rand(1:N, M)])) = 138.2097762034181
-(L̄(X[:, rand(1:N, M)])) = 134.9522158469559
-(L̄(X[:, rand(1:N, M)])) = 134.5969250614072
-(L̄(X[:, rand(1:N, M)])) = 139.542342818872
-(L̄(X[:, rand(1:N, M)])) = 136.93176063659178
-(L̄(X[:, rand(1:N, M)])) = 137.45656662160923


┌ Info: Epoch 9
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 148.29414316052817
-(L̄(X[:, rand(1:N, M)])) = 139.59843323808767
-(L̄(X[:, rand(1:N, M)])) = 136.7469974385336
-(L̄(X[:, rand(1:N, M)])) = 140.03094821240018
-(L̄(X[:, rand(1:N, M)])) = 141.66299899141615
-(L̄(X[:, rand(1:N, M)])) = 139.63165820842949
-(L̄(X[:, rand(1:N, M)])) = 136.05798960674645
-(L̄(X[:, rand(1:N, M)])) = 133.51882811158828
-(L̄(X[:, rand(1:N, M)])) = 137.002564251954
-(L̄(X[:, rand(1:N, M)])) = 144.94032442533734
-(L̄(X[:, rand(1:N, M)])) = 139.91165662733604
-(L̄(X[:, rand(1:N, M)])) = 149.01580826369454
-(L̄(X[:, rand(1:N, M)])) = 139.45007465023917


┌ Info: Epoch 10
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 147.49770785587486
-(L̄(X[:, rand(1:N, M)])) = 141.5444633022392
-(L̄(X[:, rand(1:N, M)])) = 139.3365246722247
-(L̄(X[:, rand(1:N, M)])) = 143.209801681998
-(L̄(X[:, rand(1:N, M)])) = 140.38463613252443
-(L̄(X[:, rand(1:N, M)])) = 143.2812841633853
-(L̄(X[:, rand(1:N, M)])) = 137.4326181092248
-(L̄(X[:, rand(1:N, M)])) = 148.74905374253748
-(L̄(X[:, rand(1:N, M)])) = 135.08367360663794
-(L̄(X[:, rand(1:N, M)])) = 140.7546174413095
-(L̄(X[:, rand(1:N, M)])) = 130.7631354278858
-(L̄(X[:, rand(1:N, M)])) = 144.48058290052003
-(L̄(X[:, rand(1:N, M)])) = 140.47611935270174
-(L̄(X[:, rand(1:N, M)])) = 143.75693510937379


# Sample Output

In [9]:
using Images

img(x) = Gray.(reshape(x, 28, 28))

cd(@__DIR__)
sample = hcat(img.([modelsample() for i = 1:10])...)
save("vae-sample.png", sample)

┌ Info: Precompiling Images [916415d5-f1e6-5110-898d-aaa5f9f070e0]
└ @ Base loading.jl:1260
┌ Info: Precompiling ImageMagick [6218d12a-5da1-5696-b52f-db25d2ecc6d1]
└ @ Base loading.jl:1260
