# Autoencoder

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, mse, throttle
using Base.Iterators: partition
using Juno: @progress
# using CuArrays

## Load data

In [2]:
# Load the MNIST digit images. The autoencoder below learns to squeeze each image
# into a short code vector and then decode that code back into an image.
imgs = Flux.Data.MNIST.images();

In [3]:
# Partition the images into batches of 1000. Each batch becomes one matrix with a
# flattened image per column. `reduce(hcat, ...)` uses Base's specialized
# concatenation instead of splatting 1000 arrays into `hcat`. The loop variable is
# `batch` so it does not shadow the global `imgs`.
data = [float(reduce(hcat, vec.(batch))) for batch in partition(imgs, 1000)];

## Model

In [4]:
# Width of the bottleneck: the length of the code each image is compressed into.
N = 32

# Encoder: 28×28 = 784 pixels -> N-dimensional code.
encoder = Dense(28 * 28, N, relu)
# Decoder: N-dimensional code -> 784 pixels.
decoder = Dense(N, 28 * 28, relu)

# The full autoencoder applies the encoder, then the decoder.
m = Chain(encoder, decoder)

Chain(Dense(784, 32, NNlib.relu), Dense(32, 784, NNlib.relu))

## Loss function

In [5]:
# Reconstruction loss: mean squared error between the autoencoder's output and its own input.
function loss(x)
    reconstruction = m(x)
    return mse(reconstruction, x)
end

loss (generic function with 1 method)

## Optimizer

Define a throttled evaluation callback that reports the loss during training, and the ADAM optimizer.

In [6]:
# Throttled callback: print the loss on the first batch at most once per 5-unit
# `throttle` window, so training is not flooded with output.
evalcb = throttle(5) do
    @show(loss(data[1]))
end
# ADAM optimizer over all trainable parameters of the autoencoder.
opt = ADAM(params(m))

#43 (generic function with 1 method)

## Training

In [7]:
# Run 10 epochs over all batches. `zip(data)` wraps each batch in a 1-tuple,
# presumably because `train!` splats each element into `loss` (verify against the
# Flux version in use). `evalcb` reports progress during training.
@epochs 10 Flux.train!(loss, zip(data), opt; cb = evalcb)

┌ Info: Epoch 1
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


loss(data[1]) = 0.1029851723607906 (tracked)
loss(data[1]) = 0.07671148831222126 (tracked)
loss(data[1]) = 0.06864255985283794 (tracked)
loss(data[1]) = 0.06451378182325296 (tracked)


┌ Info: Epoch 2
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


loss(data[1]) = 0.05923904636871739 (tracked)
loss(data[1]) = 0.05442764513338547 (tracked)
loss(data[1]) = 0.048660546412125805 (tracked)


┌ Info: Epoch 3
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


loss(data[1]) = 0.04451510359687858 (tracked)
loss(data[1]) = 0.041220726858272626 (tracked)
loss(data[1]) = 0.03874185213907743 (tracked)


┌ Info: Epoch 4
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


loss(data[1]) = 0.036448726315656176 (tracked)
loss(data[1]) = 0.034762413928990314 (tracked)
loss(data[1]) = 0.03348931665422283 (tracked)


┌ Info: Epoch 5
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


loss(data[1]) = 0.03241610773942246 (tracked)
loss(data[1]) = 0.031310457390174976 (tracked)
loss(data[1]) = 0.030366144671311643 (tracked)


┌ Info: Epoch 6
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


loss(data[1]) = 0.02951633174461925 (tracked)
loss(data[1]) = 0.02886390754699453 (tracked)
loss(data[1]) = 0.028262986230630293 (tracked)


┌ Info: Epoch 7
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


loss(data[1]) = 0.027580934880922923 (tracked)
loss(data[1]) = 0.027064552071918487 (tracked)
loss(data[1]) = 0.026557285757991097 (tracked)
loss(data[1]) = 0.026139458351797498 (tracked)


┌ Info: Epoch 8
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


loss(data[1]) = 0.025481566420839867 (tracked)
loss(data[1]) = 0.02506963154784914 (tracked)
loss(data[1]) = 0.02467425861158518 (tracked)


┌ Info: Epoch 9
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


loss(data[1]) = 0.024213306137517605 (tracked)
loss(data[1]) = 0.023804294917928518 (tracked)
loss(data[1]) = 0.023521127123377502 (tracked)


┌ Info: Epoch 10
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


loss(data[1]) = 0.023190394737163115 (tracked)
loss(data[1]) = 0.022996516634977832 (tracked)
loss(data[1]) = 0.022771939417715152 (tracked)


## Sample output

In [8]:
using Images

In [9]:
"""
    img(x::AbstractVector)

Turn a flattened 784-element pixel vector back into a 28×28 grayscale image.
Values are clamped into [0, 1] first, because the decoder's `relu` output is not
bounded above.
"""
img(x::AbstractVector) = Gray.(reshape(clamp.(x, 0, 1), 28, 28))

"""
    sample(n::Integer = 20)

Pick `n` random images from `imgs` (indices drawn with replacement) and run each
through the autoencoder `m`. Return one image grid in which every column stacks
an original digit above its reconstruction.

Throws `ArgumentError` if `n` is not positive.
"""
function sample(n::Integer = 20)
    n > 0 || throw(ArgumentError("n must be positive, got $n"))
    # n random digits
    before = [imgs[i] for i in rand(eachindex(imgs), n)]
    # Reconstruct each digit; `.data` unwraps the tracked model output into its
    # underlying array before converting it to an image.
    after = img.(map(x -> m(float(vec(x))).data, before))
    # Stack each original above its reconstruction, then lay the pairs side by side
    # without splatting them into varargs.
    return reduce(hcat, vcat.(before, after))
end

sample (generic function with 1 method)

In [10]:
# Write the before/after comparison grid to disk.
let grid = sample()
    save("sample.jpg", grid)
end