In [1]:
using Flux, Flux.Data.MNIST
using Flux: onehotbatch, argmax, crossentropy, throttle
using Base.Iterators: repeated, partition

In [2]:
# Load the MNIST images for the default split and for the `:test` split.
X_train, X_test = MNIST.images(), MNIST.images(:test)

# One-hot encode the labels of each split against the classes 0:9.
# The test labels are assigned last so the cell still displays `y_test`.
y_train = onehotbatch(MNIST.labels(), 0:9)
y_test = onehotbatch(MNIST.labels(:test), 0:9)

10×10000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false  false  false   true  false  …  false  false  false  false  false
 false  false   true  false  false     false  false  false  false  false
 false   true  false  false  false      true  false  false  false  false
 false  false  false  false  false     false   true  false  false  false
 false  false  false  false   true     false  false   true  false  false
 false  false  false  false  false  …  false  false  false   true  false
 false  false  false  false  false     false  false  false  false   true
  true  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false

In [None]:
# Partition the training set into mini-batches of `batch_size` images.
# Each batch is a tuple of (images concatenated along dimension 4, the matching
# one-hot label columns).
batch_size = 100
training_data = [(cat(4, float.(X_train[i])...), y_train[:, i])
                 for i in partition(1:length(X_train), batch_size)]

# Stack each split into a single tensor so that `loss` and `accuracy` can be
# evaluated on the whole split, as the training callback and the final report do.
# `X_train` is assembled from the batches built above (same image order), which
# splats one array per batch into `cat` instead of one array per image.
# NOTE: this rebinds `X_train` and `X_test` from image vectors to tensors, so the
# data-loading cell must be re-run before this cell is run a second time.
# NOTE(review): a forward pass over the full training tensor may need a lot of
# memory; consider evaluating on a subset if that becomes a problem.
X_train = cat(4, map(first, training_data)...)
X_test = cat(4, float.(X_test)...)

In [None]:
# Convolutional classifier: two convolution + max-pooling stages, a flattening
# step, and a dense layer with 10 outputs followed by softmax.
# `flatten` collapses every leading dimension and keeps dimension 4 as the columns.
# NOTE(review): Dense(288, 10) implies 288 features per sample after flattening
# (e.g. 6×6×8) — presumably for 28×28 single-channel inputs; confirm.
model = let pool = x -> maxpool2d(x, 2),
            flatten = x -> reshape(x, :, size(x, 4))
    Chain(
        Conv2D((2,2), 1=>16, relu),
        pool,
        Conv2D((2,2), 16=>8, relu),
        pool,
        flatten,
        Dense(288, 10),
        softmax)
end

In [None]:
"""
    loss(x, y)

Return the cross entropy between the output of the global `model` on `x` and the
targets `y`.
"""
function loss(x, y)
    predictions = model(x)
    return crossentropy(predictions, y)
end
"""
    accuracy(x, y)

Return the mean of the element-wise comparison between `argmax` of the global
`model`'s output on `x` and `argmax` of the targets `y`.

`argmax` is the function imported from Flux at the top of the file.
"""
function accuracy(x, y)
    predicted = argmax(model(x))
    expected = argmax(y)
    return mean(predicted .== expected)
end

In [None]:
"""
    eval_cb()

Training callback: print the current loss and accuracy on the training globals
(`X_train`, `y_train`) and then on the test globals (`X_test`, `y_test`),
followed by a blank line. Returns `nothing`.
"""
function eval_cb()
    train_loss = loss(X_train, y_train)
    train_acc = accuracy(X_train, y_train)
    println("Training loss: $(train_loss), acc: $(train_acc)")

    test_loss = loss(X_test, y_test)
    test_acc = accuracy(X_test, y_test)
    println("Test loss: $(test_loss), acc: $(test_acc)")

    println()
    return nothing
end

# Build the ADAM optimiser over the parameters collected from `model`.
optimizer = model |> params |> ADAM

# Run `Flux.train!` once over `training_data`, minimising `loss`. `eval_cb` is
# passed as the callback, wrapped in `throttle` with a timeout argument of 10
# (presumably seconds — confirm against the Flux `throttle` documentation).
Flux.train!(loss, training_data, optimizer; cb = throttle(eval_cb, 10))

In [None]:
# Announce the end of training, then print the accuracy on the training globals
# and on the test globals, each formatted with three decimal places.
println("Optimization Finished!")
@printf("Model accuracy on training set: %.3f\n", accuracy(X_train, y_train))
@printf("Model accuracy on test set: %.3f\n", accuracy(X_test, y_test))