# $\beta$-variational autoencoder

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, throttle, params
using Distributions
import Distributions: logpdf
# using CuArrays



**For numerically stable, extend distributions slightly to have a logpdf for `p` close to 1 or 0.**

In [2]:
logpdf(b::Bernoulli, y::Bool) = y * log(b.p + eps()) + (1 - y) * log(1 - b.p + eps())

logpdf (generic function with 64 methods)

# Load data

In [3]:
X = MNIST.images();

## Binarise

In [4]:
X = float.(hcat(vec.(X)...)) .> 0.5

784×60000 BitArray{2}:
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0   

## Mini-batches

In [5]:
N, M = size(X, 2), 100
data = [X[:,i] for i in Iterators.partition(1:N,M)]

600-element Array{BitArray{2},1}:
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 ⋮
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]
 [0 0 … 0 0; 

# Define Model

In [6]:
# Dimensions (D -> Dh -> Dz -> Dh -> D)
Dz, Dh, D = 5, 500, 28^2
β = 2

# Encoder
A, μ, logσ = Dense(D, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g(X) = (h = A(X); (μ(h), logσ(h)))
z(μ, logσ) = μ + exp(logσ) * randn()

# Decoder
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, D, σ))

Chain(Dense(5, 500, tanh), Dense(500, 784, σ))

# Loss function and ELBO

In [7]:
# KL-divergence between approximation posterior and N(0, 1) prior.
kl_q_p(μ, logσ) = 0.5 * sum(exp.(2 .* logσ) + μ.^2 .- 1 .+ logσ.^2)

# logp(x|z) - conditional probability of data given latents.
logp_x_z(x, z) = sum(logpdf.(Bernoulli.(f(z)), x))

# Monte Carlo estimator of mean ELBO using M samples.
L̄(X) = ((μ̂, logσ̂) = g(X); (logp_x_z(X, z.(μ̂, logσ̂)) - β * kl_q_p(μ̂, logσ̂)) / M)

loss(X) = -L̄(X) + 0.01 * sum(x->sum(x.^2), params(f))

# Sample from the learned model.
modelsample() = rand.(Bernoulli.(f(z.(zeros(Dz), zeros(Dz)))))

modelsample (generic function with 1 method)

# Learning

In [8]:
evalcb = throttle(() -> @show(-L̄(X[:, rand(1:N, M)])), 30)
opt = ADAM()
@epochs 10 Flux.train!(loss, params(A, μ, logσ, f), zip(data), opt, cb=evalcb)

┌ Info: Epoch 1
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 546.1610727177351
-(L̄(X[:, rand(1:N, M)])) = 213.3832641424489
-(L̄(X[:, rand(1:N, M)])) = 196.60931167495124
-(L̄(X[:, rand(1:N, M)])) = 185.77055199168817
-(L̄(X[:, rand(1:N, M)])) = 173.49973063496117
-(L̄(X[:, rand(1:N, M)])) = 174.96237527696044
-(L̄(X[:, rand(1:N, M)])) = 174.5052165376622
-(L̄(X[:, rand(1:N, M)])) = 167.9713418206042
-(L̄(X[:, rand(1:N, M)])) = 174.87622953899452
-(L̄(X[:, rand(1:N, M)])) = 174.53430368502603
-(L̄(X[:, rand(1:N, M)])) = 176.01295460856446
-(L̄(X[:, rand(1:N, M)])) = 172.9049593277661
-(L̄(X[:, rand(1:N, M)])) = 163.9593267978114
-(L̄(X[:, rand(1:N, M)])) = 173.23970672429977


┌ Info: Epoch 2
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 165.78650407778107
-(L̄(X[:, rand(1:N, M)])) = 163.9335055994556
-(L̄(X[:, rand(1:N, M)])) = 164.1849055124217
-(L̄(X[:, rand(1:N, M)])) = 161.72780596811205
-(L̄(X[:, rand(1:N, M)])) = 171.323078284675
-(L̄(X[:, rand(1:N, M)])) = 172.3162659488449
-(L̄(X[:, rand(1:N, M)])) = 160.9627313530683
-(L̄(X[:, rand(1:N, M)])) = 153.7854391960724
-(L̄(X[:, rand(1:N, M)])) = 154.96869187359076
-(L̄(X[:, rand(1:N, M)])) = 160.2170826972586
-(L̄(X[:, rand(1:N, M)])) = 159.0839512498177
-(L̄(X[:, rand(1:N, M)])) = 161.67495475031308
-(L̄(X[:, rand(1:N, M)])) = 163.91301064078053


┌ Info: Epoch 3
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 168.9295800931025
-(L̄(X[:, rand(1:N, M)])) = 157.7221454549756
-(L̄(X[:, rand(1:N, M)])) = 164.18522081315044
-(L̄(X[:, rand(1:N, M)])) = 164.05971570446528
-(L̄(X[:, rand(1:N, M)])) = 157.840468791587
-(L̄(X[:, rand(1:N, M)])) = 164.20323363372302
-(L̄(X[:, rand(1:N, M)])) = 156.41675220477026
-(L̄(X[:, rand(1:N, M)])) = 161.5022686672569
-(L̄(X[:, rand(1:N, M)])) = 153.92443874145218
-(L̄(X[:, rand(1:N, M)])) = 157.11920603257994
-(L̄(X[:, rand(1:N, M)])) = 161.07139583195269
-(L̄(X[:, rand(1:N, M)])) = 163.12706489529074
-(L̄(X[:, rand(1:N, M)])) = 158.81918276686224
-(L̄(X[:, rand(1:N, M)])) = 161.90397481768122


┌ Info: Epoch 4
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 165.98247385499005
-(L̄(X[:, rand(1:N, M)])) = 156.279943163065
-(L̄(X[:, rand(1:N, M)])) = 156.85914938593552
-(L̄(X[:, rand(1:N, M)])) = 154.7305835481507
-(L̄(X[:, rand(1:N, M)])) = 158.00143051786603
-(L̄(X[:, rand(1:N, M)])) = 154.70340335662533
-(L̄(X[:, rand(1:N, M)])) = 159.51532314278492
-(L̄(X[:, rand(1:N, M)])) = 151.11829530457112
-(L̄(X[:, rand(1:N, M)])) = 162.71839517698172
-(L̄(X[:, rand(1:N, M)])) = 153.8136360458141
-(L̄(X[:, rand(1:N, M)])) = 153.11729608216064
-(L̄(X[:, rand(1:N, M)])) = 155.23612119027905
-(L̄(X[:, rand(1:N, M)])) = 157.53606306601344


┌ Info: Epoch 5
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 156.80547246847175
-(L̄(X[:, rand(1:N, M)])) = 156.91851793742086
-(L̄(X[:, rand(1:N, M)])) = 148.78056381871914
-(L̄(X[:, rand(1:N, M)])) = 151.84390222071477
-(L̄(X[:, rand(1:N, M)])) = 159.0458740649267
-(L̄(X[:, rand(1:N, M)])) = 149.71656080146306
-(L̄(X[:, rand(1:N, M)])) = 152.23794635436002
-(L̄(X[:, rand(1:N, M)])) = 152.14742580005816
-(L̄(X[:, rand(1:N, M)])) = 146.662333108082
-(L̄(X[:, rand(1:N, M)])) = 153.4659875020513
-(L̄(X[:, rand(1:N, M)])) = 154.39522564742998
-(L̄(X[:, rand(1:N, M)])) = 153.39976142253425
-(L̄(X[:, rand(1:N, M)])) = 152.6236269705001
-(L̄(X[:, rand(1:N, M)])) = 149.51247770403253


┌ Info: Epoch 6
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 151.05370492017602
-(L̄(X[:, rand(1:N, M)])) = 152.48933506067732
-(L̄(X[:, rand(1:N, M)])) = 140.2174289656961
-(L̄(X[:, rand(1:N, M)])) = 156.5434934375831
-(L̄(X[:, rand(1:N, M)])) = 154.93187766013273
-(L̄(X[:, rand(1:N, M)])) = 156.85095048087737
-(L̄(X[:, rand(1:N, M)])) = 151.89831862421448
-(L̄(X[:, rand(1:N, M)])) = 152.410165256123
-(L̄(X[:, rand(1:N, M)])) = 151.94719756765272
-(L̄(X[:, rand(1:N, M)])) = 141.94465971869545
-(L̄(X[:, rand(1:N, M)])) = 150.6372129760154
-(L̄(X[:, rand(1:N, M)])) = 152.89084535310315
-(L̄(X[:, rand(1:N, M)])) = 150.24827004629267
-(L̄(X[:, rand(1:N, M)])) = 151.28491821852194


┌ Info: Epoch 7
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 145.46795515786496
-(L̄(X[:, rand(1:N, M)])) = 151.56326656021602
-(L̄(X[:, rand(1:N, M)])) = 148.79184585525107
-(L̄(X[:, rand(1:N, M)])) = 151.67230999809001
-(L̄(X[:, rand(1:N, M)])) = 153.07298623950498
-(L̄(X[:, rand(1:N, M)])) = 142.38280421937415
-(L̄(X[:, rand(1:N, M)])) = 156.7254122494012
-(L̄(X[:, rand(1:N, M)])) = 148.24253885081015
-(L̄(X[:, rand(1:N, M)])) = 149.35615488116366
-(L̄(X[:, rand(1:N, M)])) = 154.54883378566413
-(L̄(X[:, rand(1:N, M)])) = 150.4988632728684
-(L̄(X[:, rand(1:N, M)])) = 142.48343785836104
-(L̄(X[:, rand(1:N, M)])) = 146.25005476087208


┌ Info: Epoch 8
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 147.08972480531548
-(L̄(X[:, rand(1:N, M)])) = 153.45731042063332
-(L̄(X[:, rand(1:N, M)])) = 148.48322583179345
-(L̄(X[:, rand(1:N, M)])) = 144.7362887682248
-(L̄(X[:, rand(1:N, M)])) = 150.72069298090693
-(L̄(X[:, rand(1:N, M)])) = 144.983706579597
-(L̄(X[:, rand(1:N, M)])) = 149.20849445469904
-(L̄(X[:, rand(1:N, M)])) = 149.32374689799937
-(L̄(X[:, rand(1:N, M)])) = 145.99695607193254
-(L̄(X[:, rand(1:N, M)])) = 144.13619332000528
-(L̄(X[:, rand(1:N, M)])) = 147.40730895505408
-(L̄(X[:, rand(1:N, M)])) = 149.37388026844854
-(L̄(X[:, rand(1:N, M)])) = 158.53454391885222


┌ Info: Epoch 9
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 147.40043473483627
-(L̄(X[:, rand(1:N, M)])) = 150.68261635727544
-(L̄(X[:, rand(1:N, M)])) = 146.8343245335061
-(L̄(X[:, rand(1:N, M)])) = 149.00434675414027
-(L̄(X[:, rand(1:N, M)])) = 149.7725541630081
-(L̄(X[:, rand(1:N, M)])) = 148.6570610956676
-(L̄(X[:, rand(1:N, M)])) = 155.65189411756413
-(L̄(X[:, rand(1:N, M)])) = 141.78608376865708
-(L̄(X[:, rand(1:N, M)])) = 142.9771419542146
-(L̄(X[:, rand(1:N, M)])) = 151.3683356886959
-(L̄(X[:, rand(1:N, M)])) = 149.70029616439436
-(L̄(X[:, rand(1:N, M)])) = 149.96069420331082
-(L̄(X[:, rand(1:N, M)])) = 147.57021607845846
-(L̄(X[:, rand(1:N, M)])) = 150.91445356830653


┌ Info: Epoch 10
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


-(L̄(X[:, rand(1:N, M)])) = 150.95949209566325
-(L̄(X[:, rand(1:N, M)])) = 138.0281324324919
-(L̄(X[:, rand(1:N, M)])) = 145.61387397027374
-(L̄(X[:, rand(1:N, M)])) = 147.3387607585317
-(L̄(X[:, rand(1:N, M)])) = 149.62891295824573
-(L̄(X[:, rand(1:N, M)])) = 146.82606303685188
-(L̄(X[:, rand(1:N, M)])) = 145.31980195796962
-(L̄(X[:, rand(1:N, M)])) = 155.19789914285474
-(L̄(X[:, rand(1:N, M)])) = 150.86086697675074
-(L̄(X[:, rand(1:N, M)])) = 145.3388792208747
-(L̄(X[:, rand(1:N, M)])) = 147.74668473623015
-(L̄(X[:, rand(1:N, M)])) = 143.66070167468067
-(L̄(X[:, rand(1:N, M)])) = 146.1365329136349


# Sample Output

In [9]:
using Images

img(x) = Gray.(reshape(x, 28, 28))

cd(@__DIR__)
sample = hcat(img.([modelsample() for i = 1:10])...)
save("beta-vae-sample.png", sample)

