# Autoencoder

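The `Images` and `ImageMagick` packages must be installed; they are needed for the "Sample output" section, which builds and saves a comparison image.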

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, mse, throttle
using Base.Iterators: partition
# using CuArrays

## Load data

In [2]:
# Goal: compress each MNIST digit into a short code vector that the decoder can
# expand back into an image. Start by loading the training-set images.
imgs = MNIST.images(:train);

In [3]:
# Flatten every image into a column vector and group the columns into batches of
# `batchsize` images each. The loop variable is named `batch` so it does not shadow
# the global `imgs`, and `reduce(hcat, …)` avoids splatting a whole batch into a
# single varargs `hcat` call.
batchsize = 1000
data = [float(reduce(hcat, vec.(batch))) for batch in partition(imgs, batchsize)];

## Model

In [4]:
# Length of the code vector each image is compressed into.
N = 32

# Dense layers: 784 (= 28 × 28 pixels) -> N for the encoder, N -> 784 for the decoder.
encoder = Dense(28 * 28, N, relu)
decoder = Dense(N, 28 * 28, relu)

# Running an image through the encoder and then the decoder yields its reconstruction.
m = Chain(encoder, decoder)

Chain(Dense(784, 32, NNlib.relu), Dense(32, 784, NNlib.relu))

## Loss function

In [5]:
# Reconstruction loss: mean squared error between the autoencoder's output and
# the input batch it is asked to reproduce.
function loss(x)
  reconstruction = m(x)
  return mse(reconstruction, x)
end

loss (generic function with 1 method)

## Optimizer

Define an evaluation callback that prints the loss on the first batch during training, and create the ADAM optimizer.

In [6]:
# Training callback: print the loss on the first batch so progress is visible.
function evalcb()
  return @show loss(data[1])
end

# ADAM with its default settings (learning rate 0.001, betas (0.9, 0.999)).
opt = ADAM()

ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}())

## Training

In [7]:
# Train for 10 epochs, each one a full pass over every batch. `zip(data)` yields
# 1-tuples `(batch,)`, which `train!` splats into `loss`. The callback is throttled
# so it fires at most once every 5 seconds.
for epoch in 1:10
  @info "Epoch $epoch"
  Flux.train!(loss, params(m), zip(data), opt, cb = throttle(evalcb, 5))
end

┌ Info: Epoch 1
└ @ Main /home/pika/.julia/packages/Flux/8XpDt/src/optimise/train.jl:107


loss(data[1]) = 0.10269564859266456 (tracked)
loss(data[1]) = 0.0636520845937652 (tracked)
loss(data[1]) = 0.051288734981035804 (tracked)


┌ Info: Epoch 2
└ @ Main /home/pika/.julia/packages/Flux/8XpDt/src/optimise/train.jl:107


loss(data[1]) = 0.04618587347542145 (tracked)
loss(data[1]) = 0.0375779217002654 (tracked)
loss(data[1]) = 0.0319618017877095 (tracked)


┌ Info: Epoch 3
└ @ Main /home/pika/.julia/packages/Flux/8XpDt/src/optimise/train.jl:107


loss(data[1]) = 0.030435651066421494 (tracked)
loss(data[1]) = 0.0275808924795463 (tracked)
loss(data[1]) = 0.025250630856000815 (tracked)


┌ Info: Epoch 4
└ @ Main /home/pika/.julia/packages/Flux/8XpDt/src/optimise/train.jl:107


loss(data[1]) = 0.024122853721840528 (tracked)
loss(data[1]) = 0.02253089834765417 (tracked)
loss(data[1]) = 0.02149763743737937 (tracked)
loss(data[1]) = 0.020446787007062567 (tracked)


┌ Info: Epoch 5
└ @ Main /home/pika/.julia/packages/Flux/8XpDt/src/optimise/train.jl:107


loss(data[1]) = 0.020368844036079477 (tracked)
loss(data[1]) = 0.019593915794078143 (tracked)
loss(data[1]) = 0.01884469580167241 (tracked)


┌ Info: Epoch 6
└ @ Main /home/pika/.julia/packages/Flux/8XpDt/src/optimise/train.jl:107


loss(data[1]) = 0.01832676533268173 (tracked)
loss(data[1]) = 0.017828684007007926 (tracked)
loss(data[1]) = 0.017321188580069938 (tracked)


┌ Info: Epoch 7
└ @ Main /home/pika/.julia/packages/Flux/8XpDt/src/optimise/train.jl:107


loss(data[1]) = 0.016932931796469954 (tracked)
loss(data[1]) = 0.016527229204340157 (tracked)
loss(data[1]) = 0.016232760003952084 (tracked)


┌ Info: Epoch 8
└ @ Main /home/pika/.julia/packages/Flux/8XpDt/src/optimise/train.jl:107


loss(data[1]) = 0.015983789519647143 (tracked)
loss(data[1]) = 0.015790827113630477 (tracked)
loss(data[1]) = 0.015477510234026833 (tracked)


┌ Info: Epoch 9
└ @ Main /home/pika/.julia/packages/Flux/8XpDt/src/optimise/train.jl:107


loss(data[1]) = 0.015334620415942277 (tracked)
loss(data[1]) = 0.015171926964723428 (tracked)
loss(data[1]) = 0.014996345680425804 (tracked)


┌ Info: Epoch 10
└ @ Main /home/pika/.julia/packages/Flux/8XpDt/src/optimise/train.jl:107


loss(data[1]) = 0.014886054521911333 (tracked)
loss(data[1]) = 0.014780758925793391 (tracked)
loss(data[1]) = 0.014723165998320086 (tracked)


## Sample output

In [8]:
using Images

In [9]:
# Turn a flattened 784-element image vector into a 28×28 grayscale image. Values are
# clamped to [0, 1], the valid Gray intensity range, because network outputs are not
# guaranteed to lie inside it. Accepts any AbstractVector (e.g. views), not just Vector.
img(x::AbstractVector) = Gray.(reshape(clamp.(x, 0, 1), 28, 28))

# Build a comparison grid for `n` randomly chosen digits (default 20). Each column
# stacks an original image on top of its autoencoder reconstruction, and the columns
# are placed side by side.
#
# Throws ArgumentError if `n` is not positive (an empty grid cannot be assembled).
function sample(n::Integer = 20)
  n > 0 || throw(ArgumentError("n must be positive, got $n"))
  # Draw n images uniformly at random, with replacement.
  before = rand(imgs, n)
  # Reconstruct each one; `.data` unwraps the tracked model output into a plain array.
  after = [img(m(float(vec(x))).data) for x in before]
  # Stack each original above its reconstruction, then join the pairs horizontally
  # without splatting them into a varargs call.
  return reduce(hcat, vcat.(before, after))
end

sample (generic function with 1 method)

In [10]:
# Render the before/after comparison grid and write it to an image file.
let grid = sample()
  save("sample.jpg", grid)
end