# Autoencoder

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, mse, throttle
using Base.Iterators: partition
using Juno: @progress
# using CuArrays

## Load data

In [2]:
# Goal of this notebook: encode MNIST images as compressed vectors that can later be
# decoded back into images.
# Load the MNIST images. Judging by the 28^2-input model and the 28×28 reshape in
# `img` below, each element is presumably a 28×28 grayscale matrix — TODO confirm.
# The trailing `;` suppresses notebook output of the (large) image vector.
imgs = MNIST.images();

In [3]:
# Flatten every image to a vector and group them into batches of (at most) 1000,
# producing a vector of matrices with one column per image.
# `reduce(hcat, ...)` has a specialized method that concatenates in a single pass,
# instead of splatting 1000 separate arguments into `hcat`. The loop variable is
# named `batch` so it does not shadow the global `imgs`.
data = [float(reduce(hcat, vec.(batch))) for batch in partition(imgs, 1000)];

## Model

In [4]:
N = 32  # dimensionality of the latent code each image is compressed into

# The encoder squeezes the 28×28 = 784 pixels down to N ReLU features; the decoder
# expands those N features back to 784 ReLU outputs. The encoder is built first.
encoder, decoder = Dense(28 * 28, N, relu), Dense(N, 28 * 28, relu)

# Autoencoder: run the encoder, then feed its code to the decoder.
m = Chain(encoder, decoder)

Chain(Dense(784, 32, NNlib.relu), Dense(32, 784, NNlib.relu))

## Loss function

In [5]:
# Reconstruction loss: Flux's mean squared error between the autoencoder's output
# for a batch and that same batch (the target is the input itself).
function loss(x)
    reconstruction = m(x)
    return mse(reconstruction, x)
end

loss (generic function with 1 method)

## Optimizer

Define an evaluation callback that prints the loss on the first batch during training, and create the ADAM optimizer.

In [6]:
# Evaluation callback: prints (via @show) and returns the loss on the first batch;
# this produces the `loss(data[1]) = ...` lines in the training log.
function evalcb()
    return @show loss(data[1])
end

# ADAM optimizer with Flux's default hyperparameters.
opt = ADAM()

ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}())

## Training

In [7]:
# Train for 10 epochs; each epoch is one pass of `Flux.train!` over all batches.
# `zip(data)` wraps each batch in a 1-tuple, which `train!` splats into `loss(batch)`.
# `throttle(evalcb, 5)` limits the progress printout to at most once every 5 seconds.
@epochs 10 Flux.train!(loss, params(m), zip(data), opt, cb=throttle(evalcb, 5))

┌ Info: Epoch 1
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


loss(data[1]) = 0.10143132390843942 (tracked)
loss(data[1]) = 0.06487017723375788 (tracked)
loss(data[1]) = 0.05109868361424779 (tracked)


┌ Info: Epoch 2
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


loss(data[1]) = 0.045234982304229045 (tracked)
loss(data[1]) = 0.03674250857078072 (tracked)
loss(data[1]) = 0.031240617984398342 (tracked)


┌ Info: Epoch 3
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


loss(data[1]) = 0.029858290504550914 (tracked)
loss(data[1]) = 0.026602000670371932 (tracked)
loss(data[1]) = 0.024534749546358564 (tracked)


┌ Info: Epoch 4
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


loss(data[1]) = 0.02374443094912763 (tracked)
loss(data[1]) = 0.02225817183362645 (tracked)
loss(data[1]) = 0.020935166616510723 (tracked)


┌ Info: Epoch 5
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


loss(data[1]) = 0.020404263813002324 (tracked)
loss(data[1]) = 0.019693673784473335 (tracked)
loss(data[1]) = 0.019289624047367563 (tracked)
loss(data[1]) = 0.018807390849681346 (tracked)
loss(data[1]) = 0.018284567136729282 (tracked)


┌ Info: Epoch 6
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


loss(data[1]) = 0.018224655930689726 (tracked)
loss(data[1]) = 0.017784779507541322 (tracked)
loss(data[1]) = 0.017436541043695578 (tracked)
loss(data[1]) = 0.017045513268953608 (tracked)


┌ Info: Epoch 7
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


loss(data[1]) = 0.016755984918667647 (tracked)
loss(data[1]) = 0.016484850166958092 (tracked)
loss(data[1]) = 0.01630692474895048 (tracked)
loss(data[1]) = 0.016035606510307592 (tracked)
loss(data[1]) = 0.015866883494122026 (tracked)


┌ Info: Epoch 8
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


loss(data[1]) = 0.01578553885580495 (tracked)
loss(data[1]) = 0.015678063034902594 (tracked)
loss(data[1]) = 0.015551235833276241 (tracked)
loss(data[1]) = 0.015484794597874971 (tracked)
loss(data[1]) = 0.015381737458824592 (tracked)
loss(data[1]) = 0.01525126992440941 (tracked)


┌ Info: Epoch 9
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


loss(data[1]) = 0.015129474657067172 (tracked)
loss(data[1]) = 0.01500149688668255 (tracked)
loss(data[1]) = 0.014948678820193568 (tracked)
loss(data[1]) = 0.014930155856202363 (tracked)
loss(data[1]) = 0.014773730296627013 (tracked)


┌ Info: Epoch 10
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


loss(data[1]) = 0.014677877305216901 (tracked)
loss(data[1]) = 0.014593874336272829 (tracked)
loss(data[1]) = 0.014595970168856508 (tracked)
loss(data[1]) = 0.014503345691405847 (tracked)
loss(data[1]) = 0.014442656438763184 (tracked)


## Sample output

In [8]:
using Images

In [9]:
"""
    img(x)

Turn a flattened pixel vector back into a 28×28 grayscale image. Values are
clamped into [0, 1] first, because the ReLU decoder output is not bounded above.
Accepts any `AbstractVector` (views, ranges, ...), not only `Vector`.
"""
img(x::AbstractVector) = Gray.(reshape(clamp.(x, 0, 1), 28, 28))

"""
    sample(n = 20)

Draw `n` random digits (with replacement) from `imgs` and pair each with its
reconstruction by the autoencoder `m`.

Returns a single image: each original digit is stacked above its reconstruction,
and the `n` pairs are laid out side by side.
"""
function sample(n::Integer = 20)
    # Random originals, sampled directly from the image collection.
    before = rand(imgs, n)
    # Encode + decode each digit; `.data` unwraps the tracked model output.
    after = [img(m(float(vec(x))).data) for x in before]
    # Stack original over reconstruction, then concatenate the pairs horizontally
    # (specialized reduce(hcat, ...) instead of splatting).
    return reduce(hcat, vcat.(before, after))
end

sample (generic function with 1 method)

In [10]:
# Write a fresh grid of original/reconstructed digits to sample.jpg.
save("sample.jpg", sample())