# $\beta$-variational autoencoder

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, throttle, params
using Distributions
import Distributions: logpdf



**For numerically stable, extend distributions slightly to have a logpdf for `p` close to 1 or 0.**

In [2]:
logpdf(b::Bernoulli, y::Bool) = y * log(b.p + eps()) + (1 - y) * log(1 - b.p + eps())

logpdf (generic function with 62 methods)

# Load data

In [3]:
X = MNIST.images();

## Binarise

In [4]:
X = float.(hcat(vec.(X)...)) .> 0.5

784×60000 BitArray{2}:
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
     ⋮                      

## Mini-batches

In [5]:
N, M = size(X, 2), 100
data = [X[:,i] for i in Iterators.partition(1:N,M)]

600-element Array{BitArray{2},1}:
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false

# Define Model

In [6]:
# Dimensions (D -> Dh -> Dz -> Dh -> D)
Dz, Dh, D = 5, 500, 28^2
β = 1.5

# Encoder
A, μ, logσ = Dense(D, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g(X) = (h = A(X); (μ(h), logσ(h)))
z(μ, logσ) = μ + exp(logσ) * randn()

# Decoder
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, D, σ))

Chain(Dense(5, 500, tanh), Dense(500, 784, NNlib.σ))

# Loss function and ELBO

In [7]:
# KL-divergence between approximation posterior and N(0, 1) prior.
kl_q_p(μ, logσ) = 0.5 * sum(exp.(2 .* logσ) + μ.^2 .- 1 .+ logσ.^2)

# logp(x|z) - conditional probability of data given latents.
logp_x_z(x, z) = sum(logpdf.(Bernoulli.(f(z)), x))

# Monte Carlo estimator of mean ELBO using M samples.
L̄(X) = ((μ̂, logσ̂) = g(X); (logp_x_z(X, z.(μ̂, logσ̂)) - β * kl_q_p(μ̂, logσ̂)) / M)

loss(X) = -L̄(X) + 0.01 * sum(x->sum(x.^2), params(f))

# Sample from the learned model.
modelsample() = rand.(Bernoulli.(f(z.(zeros(Dz), zeros(Dz)))))

modelsample (generic function with 1 method)

# Learning

In [8]:
evalcb = throttle(() -> @show(-L̄(X[:, rand(1:N, M)])), 30)
opt = ADAM(params(A, μ, logσ, f))
@epochs 10 Flux.train!(loss, zip(data), opt, cb=evalcb)

┌ Info: Epoch 1
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


-(L̄(X[:, rand(1:N, M)])) = 543.4741079735873 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 240.77436031284057 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 214.55491107573835 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 202.6194659295882 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 196.061716175536 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 229.71551460594156 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 190.74841643774397 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 182.52543165371193 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 178.50423430269012 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 182.71750069258906 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 186.8326452119993 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 176.69945552584286 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 176.52656385132178 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 178.5466491467675 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 174.33262594283798 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 182.798034980013 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 172.04000194800375 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 174.8403922

┌ Info: Epoch 2
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


-(L̄(X[:, rand(1:N, M)])) = 176.7694080113036 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 163.58063721559668 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 178.51966218664376 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 171.40585714926885 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 170.13388389106186 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 163.19634013803807 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 172.48401042392106 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 170.60879724307378 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 169.51257890430148 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 168.75875447803827 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 167.13314345343005 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 167.88887663319517 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 167.29259625839924 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 159.64977281491463 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 165.939097913098 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 163.31506943301912 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 156.61848813372083 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.81

┌ Info: Epoch 3
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


-(L̄(X[:, rand(1:N, M)])) = 161.09267978767036 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 152.08125992721168 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 159.77174894245704 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.62701131948194 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 160.52247998460555 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 157.66476805463665 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 163.35711716697972 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 162.9161286389841 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 160.8087336119899 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 162.7831659362523 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 158.56166254924574 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 154.35628903260533 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 164.78654371710377 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 151.18449950755928 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 161.45906718603118 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 157.00063626283873 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 155.6758089627048 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 146.426

┌ Info: Epoch 4
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


-(L̄(X[:, rand(1:N, M)])) = 149.95255968988707 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 151.4364823374859 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 157.60619066195696 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.8723283826835 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 155.04766869792473 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 157.9817558567911 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 155.83468796389738 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 150.25107150850647 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 151.20531800119207 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 151.30437006258734 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 146.03711623705402 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 160.52275485397394 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 158.16557467679297 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 155.2483544970169 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.878413088408 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.14098868843067 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.96702156600406 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 158.20997

┌ Info: Epoch 5
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


-(L̄(X[:, rand(1:N, M)])) = 148.37397200927887 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.94518502696067 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 149.77077032238122 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 151.06459035170104 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 152.01470811981812 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.10409909813978 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 151.14902791476433 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.92168819153693 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 151.1257209156018 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 155.18631440696055 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 148.9823239459329 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 154.90394443267053 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 147.19398818731966 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 150.13252773829743 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 155.07548314383388 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 137.3824091669641 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 150.9803126706165 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.554

┌ Info: Epoch 6
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


-(L̄(X[:, rand(1:N, M)])) = 149.14822855222482 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 150.2908843766514 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 142.34307568393243 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 155.90217932080563 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.32342775878413 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.67369976270925 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 154.37190423537285 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.08667806127238 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 144.82554396232786 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 150.1980765930353 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 149.30187899428412 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.14639419716897 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.97729106679836 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 152.63574428892716 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 141.55000648203554 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 154.76577390255719 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 155.2906961504228 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 147.95

┌ Info: Epoch 7
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


-(L̄(X[:, rand(1:N, M)])) = 148.01814564221402 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 144.79265698143644 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 163.0880628639331 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 142.94253328528745 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 152.75367991155096 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.61398402269947 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.95199194974074 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 140.79064909178007 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 155.47965718125667 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 148.87593292571233 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.4982064315091 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 144.68238414638628 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 156.3014173827816 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.01518464845148 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.49789434140573 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 155.3980577320025 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 147.17813916096048 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 152.796

┌ Info: Epoch 8
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


-(L̄(X[:, rand(1:N, M)])) = 145.69178738348367 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.03059457378646 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.68878188256065 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 149.70149028870503 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 151.63543402318442 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 154.08904601822772 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 152.38501227682318 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 144.64228973274848 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.75462604496522 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 150.3079787636872 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 136.8346077583917 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 138.6232755949978 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 150.9411427728309 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 157.14455847874189 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 139.79865725888712 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 144.63669022121633 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 152.58803841932263 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.999

┌ Info: Epoch 9
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


-(L̄(X[:, rand(1:N, M)])) = 154.67025227613965 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 140.05794553858078 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 151.84244307685066 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 146.5560710661267 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 144.33194537413357 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 146.18949128316967 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.44932236031562 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 142.44041401622647 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.2582109542259 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 147.52041034419204 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.495032425886 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 149.47049631054148 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 148.48016334982924 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 146.43666520401402 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 149.18736640463985 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 151.96899008231364 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 147.09192625900977 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 142.320

┌ Info: Epoch 10
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


-(L̄(X[:, rand(1:N, M)])) = 146.01713837483206 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 149.4676124294938 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 139.2853205014667 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 138.43811662813982 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.5935785430124 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 152.00983110291057 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 150.20575321995221 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 152.40143788468592 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 146.60731551609481 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 148.28766532721644 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 155.15046255220884 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 147.79844296798484 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 149.24441069570747 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.35524750719068 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.46057624228789 (tracked)


# Sample Output

In [9]:
using Images

img(x) = Gray.(reshape(x, 28, 28))

cd(@__DIR__)
sample = hcat(img.([modelsample() for i = 1:10])...)
save("beta-vae-sample.png", sample)