In [1]:
# Fashion-MNIST classifier using a multilayer perceptron (MLP)
#
# reference: https://github.com/FluxML/model-zoo/blob/master/mnist/mlp.jl

In [2]:
using Flux
using Flux: onehotbatch, argmax, crossentropy, throttle, @epochs
using BSON: @save, @load
using Base.Iterators: repeated

using MLDatasets # FashionMNIST
using ColorTypes: N0f8, Gray

const Img = Matrix{Gray{N0f8}}

"""
    prepare_train(; trainrange = 1:6_000)

Load the FashionMNIST training set and build a training batch from the samples
selected by `trainrange`.

# Keywords
- `trainrange`: indices of the samples to keep. Defaults to `1:6_000`, a subset of
  the full set; pass e.g. `1:60_000` to use more of it.

# Returns
A tuple `(X, Y)`:
- `X`: the selected images, each flattened to a vector and concatenated
  column-wise with `hcat`, then passed through `gpu`.
- `Y`: the matching labels one-hot encoded over the classes `0:9`, then passed
  through `gpu`.
"""
function prepare_train(; trainrange = 1:6_000)
    # load full training set
    train_x, train_y = FashionMNIST.traindata() # 60_000

    # One image per index along the third dimension, converted to `Img`.
    imgs = Img.([train_x[:, :, i] for i in trainrange])
    # Stack images into one large batch (one flattened image per column)
    X = hcat(float.(reshape.(imgs, :))...) |> gpu
    # One-hot-encode the labels
    Y = onehotbatch(train_y[trainrange], 0:9) |> gpu
    return X, Y
end

"""
    prepare_test(; testrange = 1:1_000)

Load the FashionMNIST test set and build an evaluation batch from the samples
selected by `testrange`.

# Keywords
- `testrange`: indices of the samples to keep. Defaults to `1:1_000`, a subset of
  the full set; pass e.g. `1:10_000` to use more of it.

# Returns
A tuple `(tX, tY)`:
- `tX`: the selected images, each flattened to a vector and concatenated
  column-wise with `hcat`, then passed through `gpu`.
- `tY`: the matching labels one-hot encoded over the classes `0:9`, then passed
  through `gpu`.
"""
function prepare_test(; testrange = 1:1_000)
    # load full test set
    test_x, test_y = FashionMNIST.testdata() # 10_000

    # One image per index along the third dimension, converted to `Img`.
    test_imgs = Img.([test_x[:, :, i] for i in testrange])
    # Stack images into one large batch (one flattened image per column)
    tX = hcat(float.(reshape.(test_imgs, :))...) |> gpu
    # One-hot-encode the labels
    tY = onehotbatch(test_y[testrange], 0:9) |> gpu
    return tX, tY
end

prepare_test (generic function with 1 method)

In [3]:
# Build the training batch and the held-out evaluation batch.
X, Y = prepare_train()
tX, tY = prepare_test()

# Two dense layers (28^2 inputs -> 32 relu units -> 10 outputs) followed by softmax.
m = gpu(Chain(Dense(28^2, 32, relu), Dense(32, 10), softmax))

# Cross-entropy between the model's output for `x` and the targets `y`.
function loss(x, y)
    crossentropy(m(x), y)
end

# Mean of the element-wise agreement between the `argmax` of the model output
# and the `argmax` of the one-hot targets.
function accuracy(x, y)
    predicted = argmax(m(x))
    expected = argmax(y)
    mean(predicted .== expected)
end

# The same (X, Y) batch repeated 200 times forms one epoch's worth of data.
dataset = repeated((X, Y), 200)
# Callback that prints the current loss on the training batch.
evalcb = () -> @show(loss(X, Y))
# ADAM optimiser over the model's parameters.
opt = ADAM(params(m))

(::#71) (generic function with 1 method)

In [4]:
# Train for 5 epochs over `dataset`, minimising `loss` with `opt`.
# `evalcb` is wrapped in `throttle(evalcb, 2)` so the loss report is rate-limited
# (the 2 is presumably the minimum interval in seconds -- confirm against Flux docs).
@epochs 5 Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 2))


[1m[36mINFO: [39m[22m[36mEpoch 1
[39m

loss(X, Y) = 2.3015550915093077 (tracked)
loss(X, Y) = 0.8855032571942232 (tracked)
loss(X, Y) = 0.6416434557583437 (tracked)
loss(X, Y) = 0.541326148273287 (tracked)
loss(X, Y) = 0.48179262874649564 (tracked)
loss(X, Y) = 0.43759184322745537 (tracked)
loss(X, Y) = 0.4038686027324355 (tracked)
loss(X, Y) = 0.3752677778454351 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 2
[39m

loss(X, Y) = 0.37339573477617 (tracked)
loss(X, Y) = 0.34900580044461926 (tracked)
loss(X, Y) = 0.32657241548408367 (tracked)
loss(X, Y) = 0.3076585024362067 (tracked)
loss(X, Y) = 0.28964055129595195 (tracked)
loss(X, Y) = 0.2735883810775052 (tracked)
loss(X, Y) = 0.2585356075135701 (tracked)
loss(X, Y) = 0.2459053835983541 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 3
[39m

loss(X, Y) = 0.24080511146521177 (tracked)
loss(X, Y) = 0.2289355641560989 (tracked)
loss(X, Y) = 0.21716029263167755 (tracked)
loss(X, Y) = 0.206020928398754 (tracked)
loss(X, Y) = 0.19543263216624565 (tracked)
loss(X, Y) = 0.18462336414055677 (tracked)
loss(X, Y) = 0.17438605239311952 (tracked)
loss(X, Y) = 0.16441522799281238 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 4
[39m

loss(X, Y) = 0.15790689188090581 (tracked)
loss(X, Y) = 0.1489303852719797 (tracked)
loss(X, Y) = 0.1407690692366152 (tracked)
loss(X, Y) = 0.13319247974534595 (tracked)
loss(X, Y) = 0.12551529003305983 (tracked)
loss(X, Y) = 0.11878615529760551 (tracked)
loss(X, Y) = 0.11232176680217024 (tracked)
loss(X, Y) = 0.10652097499550225 (tracked)


[1m[36mINFO: [39m[22m[36mEpoch 5
[39m

loss(X, Y) = 0.10473497732555366 (tracked)
loss(X, Y) = 0.09914006649404462 (tracked)
loss(X, Y) = 0.09392987895425263 (tracked)
loss(X, Y) = 0.08874919289669157 (tracked)
loss(X, Y) = 0.08439056644404323 (tracked)
loss(X, Y) = 0.08015316590749122 (tracked)
loss(X, Y) = 0.07605457560769505 (tracked)
loss(X, Y) = 0.07206892567811944 (tracked)


In [5]:
# Accuracy on the training batch (X, Y), i.e. the same data the model was fitted on.
accuracy(X, Y)

0.9886666666666667

In [6]:
# Accuracy on the test batch (tX, tY), which was not part of the training `dataset`.
accuracy(tX, tY)

0.833