# Autoencoder

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, argmax, mse, throttle
using Base.Iterators: partition
using Juno: @progress
# using CuArrays

## Load data

In [2]:
# Load the MNIST digit images. The autoencoder built below encodes them as
# compressed vectors that can later be decoded back into images.
imgs = MNIST.images();

In [3]:
# Group the images into batches of 1000; within a batch every image is
# flattened to a vector and the vectors are stacked as columns of one matrix
data = map(partition(imgs, 1000)) do batch
  columns = vec.(batch)
  float(hcat(columns...))
end

60-element Array{Array{Float64,2},1}:
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.

## Model

In [4]:
N = 32 # Size of the encoding

# Encoder: dense layer compressing a flattened 28x28 image (784 values) to N values
encoder = Dense(28^2, N, relu)
# Decoder: dense layer expanding the N-value code back to 784 pixel values
decoder = Dense(N, 28^2, relu)

# Full autoencoder: encode, then decode
m = Chain(encoder, decoder)

Chain(Dense(784, 32, NNlib.relu), Dense(32, 784, NNlib.relu))

## Loss function

In [5]:
# Reconstruction loss: mean squared error between the model's output for `x`
# and `x` itself
function loss(x)
  reconstruction = m(x)
  return mse(reconstruction, x)
end

loss (generic function with 1 method)

## Optimizer

Evaluation callback function and ADAM optimizer

In [6]:
# Callback that prints the loss on the first batch; throttled with a timeout of 5
# (seconds, per Flux's throttle) so it does not fire on every training step
evalcb = throttle(5) do
  @show loss(data[1])
end
# ADAM optimiser over the model's parameters
opt = ADAM(params(m))

(::#80) (generic function with 1 method)

## Training

In [7]:
# Run 10 passes over the batches, with evalcb passed as the callback.
# zip(data) wraps each batch in a 1-tuple -- presumably so that train! calls
# loss with that single batch as its only argument.
@epochs 10 Flux.train!(loss, zip(data), opt, cb=evalcb)

[1m[36mINFO: [39m[22m[36mEpoch 1
[39m

loss(data[1]) = 0.10301851420531254 (tracked)
loss(data[1]) = 0.058225334903295095 (tracked)
loss(data[1]) = 0.04941644979092086 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 2
[39m

loss(data[1]) = 0.04215141141027766 (tracked)
loss(data[1]) = 0.03749660426425393 (tracked)
loss(data[1]) = 0.03358730545090354 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 3
[39m

loss(data[1]) = 0.03024133924042138 (tracked)
loss(data[1]) = 0.028000318548681694 (tracked)
loss(data[1]) = 0.02595731576951134 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 4
[39m

loss(data[1]) = 0.024130078067049088 (tracked)
loss(data[1]) = 0.022682590687048528 (tracked)
loss(data[1]) = 0.021590732565155316 (tracked)
loss(data[1]) = 0.02049329715207205 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 5
[39m

loss(data[1]) = 0.01954677059541594 (tracked)
loss(data[1]) = 0.018843326639288994 (tracked)
loss(data[1]) = 0.018253449000807796 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 6
[39m

loss(data[1]) = 0.017508530736130103 (tracked)
loss(data[1]) = 0.017036168681417156 (tracked)
loss(data[1]) = 0.01654713851152025 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 7
[39m

loss(data[1]) = 0.01609399568118102 (tracked)
loss(data[1]) = 0.01574832057018676 (tracked)
loss(data[1]) = 0.015502287687207775 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 8
[39m

loss(data[1]) = 0.015201150100779735 (tracked)
loss(data[1]) = 0.014914934395975233 (tracked)
loss(data[1]) = 0.014758935207491247 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 9
[39m

loss(data[1]) = 0.014447534557420174 (tracked)
loss(data[1]) = 0.014355098964354046 (tracked)
loss(data[1]) = 0.014216107273898921 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 10
[39m

loss(data[1]) = 0.01410818750114153 (tracked)
loss(data[1]) = 0.014000397195330786 (tracked)
loss(data[1]) = 0.013916149107299958 (tracked)


## Sample output

In [8]:
using Images

In [9]:
# Turn a flat vector of pixel values into a 28x28 grayscale image, clamping
# every value into [0, 1] first
function img(x::Vector)
  pixels = clamp.(x, 0, 1)
  return Gray.(reshape(pixels, 28, 28))
end

# Build a comparison image: 20 randomly chosen digits, each stacked above the
# autoencoder's reconstruction of it, with the pairs laid out side by side.
function sample()
  # 20 random digits (drawn with replacement)
  before = rand(imgs, 20)
  # Move the model to the CPU once, rather than once per image
  model = cpu(m)
  # Reconstruct each digit: flatten it, run it through the model, and unwrap
  # the tracked output before converting it back into an image
  after = [img(model(float(vec(x))).data) for x in before]
  # Stack each original above its reconstruction, then join the pairs horizontally
  return hcat(vcat.(before, after)...)
end

sample (generic function with 1 method)

In [10]:
# Write the before/after comparison image produced by sample() to sample.png
save("sample.png", sample())