# Multi-layer perceptron (classification)

In [1]:
using MLDatasets
using Flux
using Flux: @epochs, onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated
using Statistics: mean

## Load data

In [2]:
# Load the MNIST training and test splits; each call is destructured into an
# (images, labels) pair. Both loaders are called without arguments, so the images
# keep the loader's default element type (the captured output shows `Normed{UInt8,8}`).
train_x, train_y = MNIST.traindata()
test_x,  test_y  = MNIST.testdata()

(FixedPointNumbers.Normed{UInt8,8}[0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8]

FixedPointNumbers.Normed{UInt8,8}[0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8]

FixedPointNumbers.Normed{UInt8,8}[0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8]

...

FixedPointNumbers.Normed{UInt8,8}[0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8]

FixedPointNumbers.Normed{UInt8,8}[0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8]

FixedPointNumbers.Normed{UInt8,8}[0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N

## Preprocessing

In [3]:
# Scale the pixel values and flatten every image into one column of a matrix.
# NOTE(review): the loader output above shows the pixels as `Normed{UInt8,8}` (N0f8),
# which presumably already represent values in [0, 1]; dividing by 255 would then
# scale them a second time. Kept unchanged here so the training and test inputs stay
# on the same scale -- confirm, and if needed change both preprocessing cells together.
train_x = train_x / 255
# Arrays are column-major, so reshaping the (rows, cols, images) array to
# (rows * cols, images) yields exactly the matrix that hcat-ing `vec` of every
# slice would, without splatting one argument per image into `hcat`.
train_x = reshape(train_x, :, size(train_x, 3))

784×60000 Array{Float32,2}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  

In [4]:
# One-hot encode the digit labels over the classes 0:9: each label becomes one
# column of a 10-row matrix (see the 10×60000 output below).
train_y = onehotbatch(train_y, 0:9)

10×60000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 0  1  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  1  0  0  1  0  1  0  0  0  0     0  0  0  0  0  0  1  0  0  0  0  0
 0  0  0  0  0  1  0  0  0  0  0  0  0     0  0  0  1  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  1  0  0  1  0  1     0  0  0  0  0  0  0  0  1  0  0  0
 0  0  1  0  0  0  0  0  0  1  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 1  0  0  0  0  0  0  0  0  0  0  1  0  …  0  0  0  0  0  1  0  0  0  1  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  1  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     1  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  1  0  0  0  0  0  1  0  0  0  1
 0  0  0  0  1  0  0  0  0  0  0  0  0     0  0  1  0  1  0  0  0  0  0  0  0

In [5]:
# Training "dataset": an iterator that yields the same full-batch
# (train_x, train_y) tuple 200 times, so one pass over `dataset` presents the
# whole training set 200 times.
dataset = repeated((train_x, train_y), 200)

Base.Iterators.Take{Base.Iterators.Repeated{Tuple{Array{Float32,2},Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}}}}(Base.Iterators.Repeated{Tuple{Array{Float32,2},Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Bool[0 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 0])), 200)

## Model

In [6]:
# Two-layer perceptron: 784 inputs -> 32 hidden units (relu) -> 10 outputs,
# with a softmax applied to the 10 outputs.
m = Chain(Dense(28 * 28, 32, relu), Dense(32, 10), softmax)

Chain(Dense(784, 32, relu), Dense(32, 10), softmax)

In [7]:
# Training objective: cross-entropy between the model output for `x` and the
# targets `y`.
function loss(x, y)
    prediction = m(x)
    return crossentropy(prediction, y)
end

loss (generic function with 1 method)

In [8]:
# Fraction of columns for which the class decoded from the model output equals
# the class decoded from the one-hot targets `y`.
function accuracy(x, y)
    predicted = onecold(m(x))
    expected = onecold(y)
    return mean(predicted .== expected)
end

accuracy (generic function with 1 method)

## Optimizer

In [9]:
# Progress callback: print (and return) the loss over the full training set.
function evalcb()
    return @show loss(train_x, train_y)
end

# ADAM optimiser constructed with 0.005 as its first (step-size) argument.
opt = ADAM(0.005)

ADAM(0.005, (0.9, 0.999), IdDict{Any,Any}())

In [10]:
# Run 5 epochs; each epoch is one `Flux.train!` pass over `dataset`, updating the
# parameters of `m` with `opt` to reduce `loss`. `evalcb` is wrapped in
# `throttle(..., 10)` so the loss is not printed on every step (per Flux's
# `throttle`, at most once per 10 seconds -- confirm against the installed Flux
# version). `@time` reports the total time and allocations of the whole run.
@time @epochs 5 Flux.train!(loss, params(m), dataset, opt, cb=throttle(evalcb, 10))

┌ Info: Epoch 1
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(train_x, train_y) = 2.301369f0
loss(train_x, train_y) = 2.1804109f0
loss(train_x, train_y) = 1.8360584f0
loss(train_x, train_y) = 1.4171159f0
loss(train_x, train_y) = 1.0976263f0


┌ Info: Epoch 2
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(train_x, train_y) = 0.9808849f0
loss(train_x, train_y) = 0.7923923f0
loss(train_x, train_y) = 0.67392635f0
loss(train_x, train_y) = 0.5831098f0
loss(train_x, train_y) = 0.5222042f0


┌ Info: Epoch 3
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(train_x, train_y) = 0.51142335f0
loss(train_x, train_y) = 0.46765563f0
loss(train_x, train_y) = 0.43529418f0
loss(train_x, train_y) = 0.40860063f0
loss(train_x, train_y) = 0.38671792f0


┌ Info: Epoch 4
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(train_x, train_y) = 0.38316342f0
loss(train_x, train_y) = 0.36552578f0
loss(train_x, train_y) = 0.3516401f0
loss(train_x, train_y) = 0.34025624f0
loss(train_x, train_y) = 0.32911277f0


┌ Info: Epoch 5
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


loss(train_x, train_y) = 0.32770792f0
loss(train_x, train_y) = 0.31789044f0
loss(train_x, train_y) = 0.30964553f0
loss(train_x, train_y) = 0.30193678f0
227.143020 seconds (1.82 G allocations: 278.935 GiB, 6.43% gc time)


In [11]:
# Accuracy on the data the model was trained on: fraction of training images whose
# predicted class matches the label.
accuracy(train_x, train_y)

0.9158166666666666

In [12]:
# Test set accuracy
# Preprocess the test split with the same scaling and layout as the training data,
# so the model sees inputs on the scale it was trained on.
test_x = test_x / 255
# Column-major layout means reshaping (rows, cols, images) to (rows * cols, images)
# gives the same matrix as hcat-ing `vec` of every slice, without splatting one
# argument per image into `hcat`.
test_x = reshape(test_x, :, size(test_x, 3))
# One-hot encode the test labels over the classes 0:9.
test_y = onehotbatch(test_y, 0:9)

10×10000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 0  0  0  1  0  0  0  0  0  0  1  0  0  …  0  0  0  0  0  1  0  0  0  0  0  0
 0  0  1  0  0  1  0  0  0  0  0  0  0     0  0  0  0  0  0  1  0  0  0  0  0
 0  1  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  1  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  1  0  0  0
 0  0  0  0  1  0  1  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  1  0  0
 0  0  0  0  0  0  0  0  1  0  0  0  0  …  1  0  0  0  0  0  0  0  0  0  1  0
 0  0  0  0  0  0  0  0  0  0  0  1  0     0  1  0  0  0  0  0  0  0  0  0  1
 1  0  0  0  0  0  0  0  0  0  0  0  0     0  0  1  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  1  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  1  0  1  0  0  1     0  0  0  0  1  0  0  0  0  0  0  0

In [13]:
# Accuracy on the held-out test split: fraction of test images whose predicted
# class matches the label.
accuracy(test_x, test_y)

0.9178