# Variational autoencoder

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, throttle, params
using Distributions
import Distributions: logpdf
using CuArrays



**For numerically stable, extend distributions slightly to have a logpdf for `p` close to 1 or 0.**

In [2]:
logpdf(b::Bernoulli, y::Bool) = y * log(b.p + eps()) + (1 - y) * log(1 - b.p + eps())

logpdf (generic function with 64 methods)

# Load data

In [3]:
X = MNIST.images();

## Binarise

In [4]:
X = float.(hcat(vec.(X)...)) .> 0.5

784×60000 BitArray{2}:
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0   

## Mini-batches

In [5]:
N, M = size(X, 2), 100
data = [X[:,i] for i in Iterators.partition(1:N,M)]

600-element Array{BitArray{2},1}:
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 ⋮                                               
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0

# Define Model

In [6]:
# Dimensions (D -> Dh -> Dz -> Dh -> D)
Dz, Dh, D = 5, 500, 28^2

# Encoder
A, μ, logσ = Dense(D, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g(X) = (h = A(X); (μ(h), logσ(h)))
z(μ, logσ) = μ + exp(logσ) * randn()

# Decoder
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, D, σ))

Chain(Dense(5, 500, tanh), Dense(500, 784, σ))

# Loss function and ELBO

In [7]:
# KL-divergence between approximation posterior and N(0, 1) prior.
kl_q_p(μ, logσ) = 0.5 * sum(exp.(2 .* logσ) + μ.^2 .- 1 .+ logσ.^2)

# logp(x|z) - conditional probability of data given latents.
logp_x_z(x, z) = sum(logpdf.(Bernoulli.(f(z)), x))

# Monte Carlo estimator of mean ELBO using M samples.
L̄(X) = ((μ̂, logσ̂) = g(X); (logp_x_z(X, z.(μ̂, logσ̂)) - kl_q_p(μ̂, logσ̂)) / M)

loss(X) = -L̄(X) + 0.01 * sum(x->sum(x.^2), params(f))

# Sample from the learned model.
modelsample() = rand.(Bernoulli.(f(z.(zeros(Dz), zeros(Dz)))))

modelsample (generic function with 1 method)

# Learning

In [8]:
evalcb = throttle(() -> @show(-L̄(X[:, rand(1:N, M)])), 30)
opt = ADAM()
@epochs 10 Flux.train!(loss, params(A, μ, logσ, f), zip(data), opt, cb=evalcb)

┌ Info: Epoch 1
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


-(L̄(X[:, rand(1:N, M)])) = 541.7407199471362
-(L̄(X[:, rand(1:N, M)])) = 214.72542568485625
-(L̄(X[:, rand(1:N, M)])) = 203.42657103518593
-(L̄(X[:, rand(1:N, M)])) = 178.949635427151
-(L̄(X[:, rand(1:N, M)])) = 168.35816658830868
-(L̄(X[:, rand(1:N, M)])) = 166.11557725417492
-(L̄(X[:, rand(1:N, M)])) = 161.6899527789075
-(L̄(X[:, rand(1:N, M)])) = 162.28474149098153
-(L̄(X[:, rand(1:N, M)])) = 165.8272966943006
-(L̄(X[:, rand(1:N, M)])) = 162.09767988431912
-(L̄(X[:, rand(1:N, M)])) = 158.76872763982703
-(L̄(X[:, rand(1:N, M)])) = 164.4417796971319
-(L̄(X[:, rand(1:N, M)])) = 159.081885643022
-(L̄(X[:, rand(1:N, M)])) = 161.23722101419753
-(L̄(X[:, rand(1:N, M)])) = 161.77759667428683
-(L̄(X[:, rand(1:N, M)])) = 153.07368635785858
-(L̄(X[:, rand(1:N, M)])) = 165.53754977210403
-(L̄(X[:, rand(1:N, M)])) = 158.68665603439223
-(L̄(X[:, rand(1:N, M)])) = 159.85287863820696
-(L̄(X[:, rand(1:N, M)])) = 152.5075388574712
-(L̄(X[:, rand(1:N, M)])) = 158.66419254779237


┌ Info: Epoch 2
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


-(L̄(X[:, rand(1:N, M)])) = 160.00149858271052
-(L̄(X[:, rand(1:N, M)])) = 151.71590312286938
-(L̄(X[:, rand(1:N, M)])) = 163.90929956831903
-(L̄(X[:, rand(1:N, M)])) = 153.4392633687689
-(L̄(X[:, rand(1:N, M)])) = 158.433274002692
-(L̄(X[:, rand(1:N, M)])) = 160.99774749311234
-(L̄(X[:, rand(1:N, M)])) = 152.44432808291148
-(L̄(X[:, rand(1:N, M)])) = 147.62808525842908
-(L̄(X[:, rand(1:N, M)])) = 154.0421915229954
-(L̄(X[:, rand(1:N, M)])) = 154.0824997119272
-(L̄(X[:, rand(1:N, M)])) = 152.7086095439197
-(L̄(X[:, rand(1:N, M)])) = 149.4225534575111
-(L̄(X[:, rand(1:N, M)])) = 154.1972294082436
-(L̄(X[:, rand(1:N, M)])) = 151.50959506262186
-(L̄(X[:, rand(1:N, M)])) = 147.23162035360374
-(L̄(X[:, rand(1:N, M)])) = 146.96179113028577
-(L̄(X[:, rand(1:N, M)])) = 153.2261914242795
-(L̄(X[:, rand(1:N, M)])) = 154.0460852788691
-(L̄(X[:, rand(1:N, M)])) = 146.2673600046214
-(L̄(X[:, rand(1:N, M)])) = 148.2154728050556
-(L̄(X[:, rand(1:N, M)])) = 145.804277765859


┌ Info: Epoch 3
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


-(L̄(X[:, rand(1:N, M)])) = 146.31888691208675
-(L̄(X[:, rand(1:N, M)])) = 153.56963807462282
-(L̄(X[:, rand(1:N, M)])) = 144.53128907218118
-(L̄(X[:, rand(1:N, M)])) = 142.96883256775283
-(L̄(X[:, rand(1:N, M)])) = 142.9206558103404
-(L̄(X[:, rand(1:N, M)])) = 147.82756559052058
-(L̄(X[:, rand(1:N, M)])) = 136.80241449713876
-(L̄(X[:, rand(1:N, M)])) = 152.32872547586925
-(L̄(X[:, rand(1:N, M)])) = 153.0723823611273
-(L̄(X[:, rand(1:N, M)])) = 134.6572933325032
-(L̄(X[:, rand(1:N, M)])) = 137.89286961284571
-(L̄(X[:, rand(1:N, M)])) = 138.28334377038868
-(L̄(X[:, rand(1:N, M)])) = 149.65103820508196
-(L̄(X[:, rand(1:N, M)])) = 144.50743019817085
-(L̄(X[:, rand(1:N, M)])) = 144.5439524659022
-(L̄(X[:, rand(1:N, M)])) = 148.48878576926822
-(L̄(X[:, rand(1:N, M)])) = 143.99685338017315
-(L̄(X[:, rand(1:N, M)])) = 138.7548153688293
-(L̄(X[:, rand(1:N, M)])) = 144.3353897796557
-(L̄(X[:, rand(1:N, M)])) = 140.30453984206812
-(L̄(X[:, rand(1:N, M)])) = 145.83519640052916


┌ Info: Epoch 4
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


-(L̄(X[:, rand(1:N, M)])) = 137.2877073039475
-(L̄(X[:, rand(1:N, M)])) = 142.99763633894116
-(L̄(X[:, rand(1:N, M)])) = 136.98130086938318
-(L̄(X[:, rand(1:N, M)])) = 141.26412767172903
-(L̄(X[:, rand(1:N, M)])) = 136.36386909596737
-(L̄(X[:, rand(1:N, M)])) = 136.73048534462498
-(L̄(X[:, rand(1:N, M)])) = 133.06010498018858
-(L̄(X[:, rand(1:N, M)])) = 144.53486410328708
-(L̄(X[:, rand(1:N, M)])) = 138.24347598779622
-(L̄(X[:, rand(1:N, M)])) = 131.99620245504636
-(L̄(X[:, rand(1:N, M)])) = 140.0788668223345
-(L̄(X[:, rand(1:N, M)])) = 130.2924765732042
-(L̄(X[:, rand(1:N, M)])) = 125.57503011433022
-(L̄(X[:, rand(1:N, M)])) = 138.4472476514121
-(L̄(X[:, rand(1:N, M)])) = 137.2665038060604
-(L̄(X[:, rand(1:N, M)])) = 140.25453512310403
-(L̄(X[:, rand(1:N, M)])) = 138.424042005107
-(L̄(X[:, rand(1:N, M)])) = 134.90956021556147
-(L̄(X[:, rand(1:N, M)])) = 144.12979759831967
-(L̄(X[:, rand(1:N, M)])) = 135.24004261731295
-(L̄(X[:, rand(1:N, M)])) = 145.0475359891398


┌ Info: Epoch 5
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


-(L̄(X[:, rand(1:N, M)])) = 141.23085189531287
-(L̄(X[:, rand(1:N, M)])) = 135.64594545162075
-(L̄(X[:, rand(1:N, M)])) = 131.6279254702832
-(L̄(X[:, rand(1:N, M)])) = 135.01854771250046
-(L̄(X[:, rand(1:N, M)])) = 131.81703623994454
-(L̄(X[:, rand(1:N, M)])) = 125.24471367662046
-(L̄(X[:, rand(1:N, M)])) = 135.66779950591106
-(L̄(X[:, rand(1:N, M)])) = 135.4343034850009
-(L̄(X[:, rand(1:N, M)])) = 137.38026688799556
-(L̄(X[:, rand(1:N, M)])) = 141.02832094429385
-(L̄(X[:, rand(1:N, M)])) = 137.38136639958958
-(L̄(X[:, rand(1:N, M)])) = 128.31415298478296
-(L̄(X[:, rand(1:N, M)])) = 134.31674668308258
-(L̄(X[:, rand(1:N, M)])) = 136.8816689340441
-(L̄(X[:, rand(1:N, M)])) = 133.238344359209
-(L̄(X[:, rand(1:N, M)])) = 126.13597576343417
-(L̄(X[:, rand(1:N, M)])) = 135.42933487029214
-(L̄(X[:, rand(1:N, M)])) = 145.2054170586898
-(L̄(X[:, rand(1:N, M)])) = 129.5815836797302
-(L̄(X[:, rand(1:N, M)])) = 130.32190254418228
-(L̄(X[:, rand(1:N, M)])) = 137.75822334166395


┌ Info: Epoch 6
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


-(L̄(X[:, rand(1:N, M)])) = 124.43256893404993
-(L̄(X[:, rand(1:N, M)])) = 135.3573067865918
-(L̄(X[:, rand(1:N, M)])) = 134.072460658761
-(L̄(X[:, rand(1:N, M)])) = 134.86230712844244
-(L̄(X[:, rand(1:N, M)])) = 132.22009275193986
-(L̄(X[:, rand(1:N, M)])) = 143.21634645943945
-(L̄(X[:, rand(1:N, M)])) = 136.79634253621
-(L̄(X[:, rand(1:N, M)])) = 123.95255920937639
-(L̄(X[:, rand(1:N, M)])) = 135.6306750636356
-(L̄(X[:, rand(1:N, M)])) = 136.41036356343216
-(L̄(X[:, rand(1:N, M)])) = 133.66321689231404
-(L̄(X[:, rand(1:N, M)])) = 132.689269008089
-(L̄(X[:, rand(1:N, M)])) = 132.053161661826
-(L̄(X[:, rand(1:N, M)])) = 135.22932583088695
-(L̄(X[:, rand(1:N, M)])) = 135.8441041146297
-(L̄(X[:, rand(1:N, M)])) = 131.31795596068673
-(L̄(X[:, rand(1:N, M)])) = 134.59723121150373
-(L̄(X[:, rand(1:N, M)])) = 133.05931851910805
-(L̄(X[:, rand(1:N, M)])) = 128.68801453266065
-(L̄(X[:, rand(1:N, M)])) = 128.00283682868064
-(L̄(X[:, rand(1:N, M)])) = 135.99474114196863
-(L̄(X[:, rand(1:N, M)]))

┌ Info: Epoch 7
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


-(L̄(X[:, rand(1:N, M)])) = 132.3460609423721
-(L̄(X[:, rand(1:N, M)])) = 125.81402237485617
-(L̄(X[:, rand(1:N, M)])) = 134.76487686266591
-(L̄(X[:, rand(1:N, M)])) = 128.90412946179475
-(L̄(X[:, rand(1:N, M)])) = 132.3820590902147
-(L̄(X[:, rand(1:N, M)])) = 131.7298368579205
-(L̄(X[:, rand(1:N, M)])) = 127.66921957165178
-(L̄(X[:, rand(1:N, M)])) = 131.7079428250923
-(L̄(X[:, rand(1:N, M)])) = 127.63573944433406
-(L̄(X[:, rand(1:N, M)])) = 132.2774497812939
-(L̄(X[:, rand(1:N, M)])) = 125.7099616179973
-(L̄(X[:, rand(1:N, M)])) = 118.2077991860798
-(L̄(X[:, rand(1:N, M)])) = 136.5077901732908
-(L̄(X[:, rand(1:N, M)])) = 129.01942027243243
-(L̄(X[:, rand(1:N, M)])) = 136.41528663699458
-(L̄(X[:, rand(1:N, M)])) = 128.3770027791361
-(L̄(X[:, rand(1:N, M)])) = 128.93468242539961
-(L̄(X[:, rand(1:N, M)])) = 131.10025885401643
-(L̄(X[:, rand(1:N, M)])) = 131.58259326197927
-(L̄(X[:, rand(1:N, M)])) = 136.73560041451046
-(L̄(X[:, rand(1:N, M)])) = 132.78834650688322


┌ Info: Epoch 8
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


-(L̄(X[:, rand(1:N, M)])) = 137.20050910200743
-(L̄(X[:, rand(1:N, M)])) = 124.97475658341934
-(L̄(X[:, rand(1:N, M)])) = 131.1826719268621
-(L̄(X[:, rand(1:N, M)])) = 124.1062471538006
-(L̄(X[:, rand(1:N, M)])) = 130.5517732147109
-(L̄(X[:, rand(1:N, M)])) = 134.4693057460509
-(L̄(X[:, rand(1:N, M)])) = 135.81367418990152
-(L̄(X[:, rand(1:N, M)])) = 128.97625616887817
-(L̄(X[:, rand(1:N, M)])) = 126.22570015033911
-(L̄(X[:, rand(1:N, M)])) = 130.111358489403
-(L̄(X[:, rand(1:N, M)])) = 129.77108926536698
-(L̄(X[:, rand(1:N, M)])) = 137.1032484912227
-(L̄(X[:, rand(1:N, M)])) = 125.87769412423397
-(L̄(X[:, rand(1:N, M)])) = 125.42826411698377
-(L̄(X[:, rand(1:N, M)])) = 135.23202535304387
-(L̄(X[:, rand(1:N, M)])) = 127.98343883175403
-(L̄(X[:, rand(1:N, M)])) = 126.81926483220508
-(L̄(X[:, rand(1:N, M)])) = 128.48574922011738
-(L̄(X[:, rand(1:N, M)])) = 130.43180665656374
-(L̄(X[:, rand(1:N, M)])) = 130.62384782153353
-(L̄(X[:, rand(1:N, M)])) = 133.0666470797613


┌ Info: Epoch 9
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


-(L̄(X[:, rand(1:N, M)])) = 127.50709408314026
-(L̄(X[:, rand(1:N, M)])) = 131.6725899108759
-(L̄(X[:, rand(1:N, M)])) = 129.75330867859145
-(L̄(X[:, rand(1:N, M)])) = 134.83446792560386
-(L̄(X[:, rand(1:N, M)])) = 129.29871138123738
-(L̄(X[:, rand(1:N, M)])) = 127.07453776939288
-(L̄(X[:, rand(1:N, M)])) = 127.94491788283246
-(L̄(X[:, rand(1:N, M)])) = 128.11025828312594
-(L̄(X[:, rand(1:N, M)])) = 133.40850346262152
-(L̄(X[:, rand(1:N, M)])) = 133.32116635936697
-(L̄(X[:, rand(1:N, M)])) = 124.28105485657272
-(L̄(X[:, rand(1:N, M)])) = 130.87517488548463
-(L̄(X[:, rand(1:N, M)])) = 125.89372276239594
-(L̄(X[:, rand(1:N, M)])) = 124.293003087255
-(L̄(X[:, rand(1:N, M)])) = 136.1065192523239
-(L̄(X[:, rand(1:N, M)])) = 117.82473142352879
-(L̄(X[:, rand(1:N, M)])) = 129.1981870449076
-(L̄(X[:, rand(1:N, M)])) = 128.93357918492367
-(L̄(X[:, rand(1:N, M)])) = 132.97258756643896
-(L̄(X[:, rand(1:N, M)])) = 134.16351897969582
-(L̄(X[:, rand(1:N, M)])) = 127.27849901463647
-(L̄(X[:, rand(1:N

┌ Info: Epoch 10
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


-(L̄(X[:, rand(1:N, M)])) = 129.20964146867274
-(L̄(X[:, rand(1:N, M)])) = 132.90283011545074
-(L̄(X[:, rand(1:N, M)])) = 128.9843533844208
-(L̄(X[:, rand(1:N, M)])) = 126.99632415814597
-(L̄(X[:, rand(1:N, M)])) = 126.96973657689297
-(L̄(X[:, rand(1:N, M)])) = 129.0894753582221
-(L̄(X[:, rand(1:N, M)])) = 132.41385536695032
-(L̄(X[:, rand(1:N, M)])) = 130.1043488807083
-(L̄(X[:, rand(1:N, M)])) = 129.83812437088037
-(L̄(X[:, rand(1:N, M)])) = 130.60964084425422
-(L̄(X[:, rand(1:N, M)])) = 124.15753863690945
-(L̄(X[:, rand(1:N, M)])) = 126.28980945958959
-(L̄(X[:, rand(1:N, M)])) = 127.93809608478801
-(L̄(X[:, rand(1:N, M)])) = 132.6159145454623
-(L̄(X[:, rand(1:N, M)])) = 126.59785462932248
-(L̄(X[:, rand(1:N, M)])) = 125.81240481052157
-(L̄(X[:, rand(1:N, M)])) = 128.38346555866875
-(L̄(X[:, rand(1:N, M)])) = 123.04236860123191
-(L̄(X[:, rand(1:N, M)])) = 137.06697659413277


# Sample Output

In [9]:
using Images

img(x) = Gray.(reshape(x, 28, 28))

cd(@__DIR__)
sample = hcat(img.([modelsample() for i = 1:10])...)
save("vae-sample.png", sample)

