# Autoencoder

The packages `Images` and `ImageMagick` are required to build and save the sample output images.

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, mse, throttle
using Base.Iterators: partition

## Load data

In [2]:
# Load the MNIST digit images. The autoencoder below learns to compress each image
# into a short vector that can later be decoded back into an image.
imgs = Flux.Data.MNIST.images();

In [3]:
# Partition the images into batches of 1000. Each batch becomes one matrix with one
# flattened image per column. `reduce(hcat, ...)` avoids splatting 1000 arguments into
# `hcat`, and the loop variable is named `batch` so it no longer shadows the global `imgs`.
data = [float(reduce(hcat, vec.(batch))) for batch in partition(imgs, 1000)];

## Model

In [4]:
N = 32 # Size of the encoding (length of the compressed vector)

# The encoder maps a flattened 28×28 image down to N values; the decoder maps back up.
encoder, decoder = let npixels = 28^2
    (Dense(npixels, N, relu), Dense(N, npixels, relu))
end

m = Chain(encoder, decoder)

Chain(Dense(784, 32, relu), Dense(32, 784, relu))

## Loss function

In [5]:
# Reconstruction loss: mean squared error between the autoencoder's output and its input.
function loss(x)
    reconstruction = m(x)
    return mse(reconstruction, x)
end

loss (generic function with 1 method)

## Optimizer

Evaluation callback: print the loss on the first batch at most once every 10 seconds during training.

In [6]:
# Progress report: show the loss on the first batch, throttled to at most once per
# 10 seconds. The do-block is passed as the first argument of `throttle`.
evalcb = throttle(10) do
    @show(loss(data[1]))
end
# ADAM optimizer with its default hyperparameters.
opt = ADAM()

ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}())

## Training

In [7]:
# Train for 10 epochs. `zip(data)` yields each batch as a 1-tuple, which `train!`
# unpacks into the single argument of `loss`; `evalcb` reports progress along the way.
let ps = params(m), batches = zip(data)
    @epochs 10 Flux.train!(loss, ps, batches, opt; cb = evalcb)
end

┌ Info: Epoch 1
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(data[1]) = 0.10198192405514445
loss(data[1]) = 0.09329526919395476
loss(data[1]) = 0.08480317839952525
loss(data[1]) = 0.07687969855579678
loss(data[1]) = 0.0704605695707707
loss(data[1]) = 0.06624697607165259
loss(data[1]) = 0.06339369960105612
loss(data[1]) = 0.061584227817640606
loss(data[1]) = 0.05991603237391989
loss(data[1]) = 0.05797558678584206
loss(data[1]) = 0.05589186258902266
loss(data[1]) = 0.05376184919410346
loss(data[1]) = 0.05181396076403669
loss(data[1]) = 0.049990004702938344
loss(data[1]) = 0.04822352548346777


┌ Info: Epoch 2
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(data[1]) = 0.046479092977078446
loss(data[1]) = 0.044722812414763474
loss(data[1]) = 0.04311381999873915
loss(data[1]) = 0.04168233012975945
loss(data[1]) = 0.040362363803873744
loss(data[1]) = 0.03915851543536165
loss(data[1]) = 0.03800651892741241
loss(data[1]) = 0.036910688408067674
loss(data[1]) = 0.0357922192472344
loss(data[1]) = 0.03473011137537648
loss(data[1]) = 0.03385541809221263
loss(data[1]) = 0.03297559844193548
loss(data[1]) = 0.03220470289981109
loss(data[1]) = 0.03156101867347058
loss(data[1]) = 0.03084506864345676


┌ Info: Epoch 3
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(data[1]) = 0.030061033262188088
loss(data[1]) = 0.029366746009795247
loss(data[1]) = 0.028845320008677327
loss(data[1]) = 0.028368300633456755
loss(data[1]) = 0.027779128019652995
loss(data[1]) = 0.027224191096114386
loss(data[1]) = 0.026875181001271987
loss(data[1]) = 0.026463312140465656
loss(data[1]) = 0.02599412082088445
loss(data[1]) = 0.025580061543330156
loss(data[1]) = 0.02527431734596715
loss(data[1]) = 0.024861847602127198
loss(data[1]) = 0.024504343886634517
loss(data[1]) = 0.02417441940055002
loss(data[1]) = 0.02384312598818848


┌ Info: Epoch 4
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(data[1]) = 0.023455921017546476
loss(data[1]) = 0.023139236129541562
loss(data[1]) = 0.022874373065838643
loss(data[1]) = 0.022597287273806596
loss(data[1]) = 0.02230199491772898
loss(data[1]) = 0.022042057524876542
loss(data[1]) = 0.021840955605840547
loss(data[1]) = 0.021577619810620802
loss(data[1]) = 0.021370233075295633
loss(data[1]) = 0.02113111286512837
loss(data[1]) = 0.02096089615727962
loss(data[1]) = 0.02075819816071669
loss(data[1]) = 0.020543427525220006
loss(data[1]) = 0.02037074114825104
loss(data[1]) = 0.020191940940464913


┌ Info: Epoch 5
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(data[1]) = 0.019941063314698932
loss(data[1]) = 0.019791061754766057
loss(data[1]) = 0.01960693027115865
loss(data[1]) = 0.019456742205425228
loss(data[1]) = 0.01928390982837343
loss(data[1]) = 0.019143734025688842
loss(data[1]) = 0.01903434071604811
loss(data[1]) = 0.01885409038331693
loss(data[1]) = 0.018746390636598737
loss(data[1]) = 0.018558651606137934
loss(data[1]) = 0.01843731679959365
loss(data[1]) = 0.018281980985416432
loss(data[1]) = 0.01813634684258975
loss(data[1]) = 0.01797165512807349
loss(data[1]) = 0.017857163244587845


┌ Info: Epoch 6
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(data[1]) = 0.017628806078106855
loss(data[1]) = 0.017483612051681228
loss(data[1]) = 0.017376730739975085
loss(data[1]) = 0.017253727668885523
loss(data[1]) = 0.01711098399057115
loss(data[1]) = 0.017004178488847416
loss(data[1]) = 0.01692958183644551
loss(data[1]) = 0.01678798489593995
loss(data[1]) = 0.01673292765351836
loss(data[1]) = 0.016609407384690394
loss(data[1]) = 0.016557257341416082
loss(data[1]) = 0.016457967808757083
loss(data[1]) = 0.016351236663412694
loss(data[1]) = 0.01626824584461809
loss(data[1]) = 0.01620034647043706


┌ Info: Epoch 7
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(data[1]) = 0.016064412377830444
loss(data[1]) = 0.015965978689774233
loss(data[1]) = 0.015929268235421963
loss(data[1]) = 0.015856010209101188
loss(data[1]) = 0.01576412885070873
loss(data[1]) = 0.01569620540871551
loss(data[1]) = 0.015674920541536547
loss(data[1]) = 0.01557576015949725
loss(data[1]) = 0.015551715051129386
loss(data[1]) = 0.015476205424077223
loss(data[1]) = 0.01545749269807157
loss(data[1]) = 0.01538973326481256
loss(data[1]) = 0.015312492255264817
loss(data[1]) = 0.015260923952296336
loss(data[1]) = 0.015214457427852355


┌ Info: Epoch 8
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(data[1]) = 0.015119306483456437
loss(data[1]) = 0.015039645283523587
loss(data[1]) = 0.015057447078583868
loss(data[1]) = 0.014984819298099679
loss(data[1]) = 0.014931646734518094
loss(data[1]) = 0.014881696174369312
loss(data[1]) = 0.014882715861900875
loss(data[1]) = 0.014811002204524951
loss(data[1]) = 0.014802739108632446
loss(data[1]) = 0.01475281949535716
loss(data[1]) = 0.014757662400246408
loss(data[1]) = 0.014702735850471793
loss(data[1]) = 0.014650846857443488
loss(data[1]) = 0.014607535358138225
loss(data[1]) = 0.014584196573505731


┌ Info: Epoch 9
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(data[1]) = 0.014512652148318003
loss(data[1]) = 0.014440865979096093
loss(data[1]) = 0.014488150961347392
loss(data[1]) = 0.014415174603554832
loss(data[1]) = 0.01439031698662705
loss(data[1]) = 0.01434740995870472
loss(data[1]) = 0.01435921406212487
loss(data[1]) = 0.014308646800560978
loss(data[1]) = 0.01430581533229613
loss(data[1]) = 0.01427598980245602
loss(data[1]) = 0.014293835496236522
loss(data[1]) = 0.014248941350340796
loss(data[1]) = 0.014211355461671604
loss(data[1]) = 0.01417088456822883
loss(data[1]) = 0.01416675423567998


┌ Info: Epoch 10
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(data[1]) = 0.014103417322901926
loss(data[1]) = 0.01403955391237272
loss(data[1]) = 0.014100175478855595
loss(data[1]) = 0.01402885703128089
loss(data[1]) = 0.014021365052729213
loss(data[1]) = 0.013982655182365432
loss(data[1]) = 0.014002185318177509
loss(data[1]) = 0.013962357598580416
loss(data[1]) = 0.013961210374713973
loss(data[1]) = 0.013944749260351871
loss(data[1]) = 0.013968432677243863
loss(data[1]) = 0.013931154294073936
loss(data[1]) = 0.01390171266384128
loss(data[1]) = 0.013864744783436881
loss(data[1]) = 0.013871326029007303


## Sample output

In [8]:
using Images



In [11]:
"""
    img(x::AbstractVector)

Turn a flattened 784-element vector back into a 28×28 grayscale image.

Values are clamped into `[0, 1]` first. The decoder's `relu` output is non-negative but
unbounded above, so out-of-range intensities are possible. Accepts any `AbstractVector`
(e.g. views), not just `Vector`.
"""
img(x::AbstractVector) = Gray.(reshape(clamp.(x, 0, 1), 28, 28))

"""
    sample(n::Integer = 20)

Draw `n` random images (with replacement) from the global `imgs` and return one image
grid. Each column pair shows an original digit on top and its autoencoder
reconstruction underneath.

Throws `ArgumentError` if `n` is not positive, since an empty grid cannot be built.
"""
function sample(n::Integer = 20)
    n > 0 || throw(ArgumentError("n must be positive, got $n"))
    # n random digits
    before = [imgs[i] for i in rand(eachindex(imgs), n)]
    # Reconstructions: flatten, run through the autoencoder, reshape back into an image
    after = img.(map(x -> m(float(vec(x))), before))
    # Stack each before/after pair vertically, then lay the pairs out side by side
    return reduce(hcat, vcat.(before, after))
end

sample (generic function with 1 method)

In [13]:
# Build the before/after comparison grid and write it to an image file.
let grid = sample()
    save("sample.jpg", grid)
end

┌ Info: Precompiling ImageMagick [6218d12a-5da1-5696-b52f-db25d2ecc6d1]
└ @ Base loading.jl:1273
