In [None]:
# Install every third-party package the notebook loads in the next cell
# (`using Flux`, `using MNIST`, `using CuArrays`); Flux was previously missing here.
# NOTE(review): on Julia >= 0.7, `Pkg` is a stdlib and needs `using Pkg` before this.
for pkg in ("Flux", "MNIST", "CuArrays")
    Pkg.add(pkg)
end

In [None]:
using Flux
using Flux: onehotbatch, argmax, crossentropy, throttle
using Base.Iterators: repeated, partition
using MNIST
using CuArrays

In [None]:
# Load the MNIST training and test splits (features, labels) from the MNIST package.
# The trailing semicolon suppresses notebook output of the large arrays.
(X_train, y_train), (X_test, y_test) = traindata(), testdata();

In [None]:
# Split the training set into GPU mini-batches of 100 images. Each element of
# `dataset` is a (28×28×1×b image array, 10×b one-hot label matrix) pair, where
# b = 100 except possibly for a smaller final batch.
# The one-hot labels are bound locally (inside `let`) rather than overwriting the
# global `y_train`: later code one-hot encodes `y_train` again and needs the raw
# digit labels, otherwise it would re-encode an already one-hot matrix.
dataset = let labels = onehotbatch(y_train, 0:9)
    [(reshape(cu(X_train[:, idx]), (28, 28, 1, length(idx))), cu(labels[:, idx]))
     for idx in partition(1:size(X_train, 2), 100)]
end

In [None]:
# Inspect the label half of the first (image, label) mini-batch.
first(dataset)[2]

In [None]:
# Move the pixel data to the GPU and reshape it into the 28×28×1×N layout the
# Conv2D layers consume. N is taken from the data (last dimension) instead of
# hard-coding 60000 / 10000, so subsets work and re-running the cell is a no-op.
X_train = reshape(cu(X_train), (28, 28, 1, last(size(X_train))))
X_test = reshape(cu(X_test), (28, 28, 1, last(size(X_test))))

"""
    one_hot_encoding(x::Real, nclasses::Integer = 10)

Return a `1 × nclasses` row (the adjoint of a `Float64` vector) that is zero
everywhere except for a `1.0` in column `Int(x) + 1`; i.e. label `x` in
`0:nclasses-1` maps to column `x + 1`.

Accepts any integer-valued real label (e.g. `3` or `3.0`); previously only
`Float64` was accepted and the class count was fixed at 10.

# Throws
- `InexactError` if `x` is not integer-valued.
- `BoundsError` if `x` lies outside `0:nclasses-1`.

NOTE(review): not called anywhere in this notebook; labels are encoded with
`onehotbatch` instead.
"""
function one_hot_encoding(x::Real, nclasses::Integer = 10)
    row = zeros(nclasses)'
    row[Int(x) + 1] = 1
    return row
end

# One-hot encode the digit labels over the classes 0:9 and move them to the GPU.
# NOTE(review): assumes y_train / y_test still hold raw digit labels; if an earlier
# cell has already replaced one of them with a one-hot matrix, it would be encoded
# a second time here.
y_train = onehotbatch(y_train, 0:9) |> cu
y_test = onehotbatch(y_test, 0:9) |> cu

In [None]:
# Print the concrete type of each feature / label array, one per line
# (a quick check that `cu` moved them to the GPU).
foreach(arr -> println(typeof(arr)), (X_train, y_train, X_test, y_test))

In [None]:
# Small CNN for MNIST. Input: 28×28×1×N arrays (spatial × spatial × channel × batch);
# output: 10×N class probabilities (final softmax). Shapes are noted per layer.
model = let pool = x -> maxpool2d(x, 2),
            flatten_batch = x -> reshape(x, :, size(x, 4))
    Chain(
        Conv2D((3, 3), 1=>32, relu),    # 28×28×1×N  -> 26×26×32×N
        pool,                           # 26×26×32×N -> 13×13×32×N
        Conv2D((3, 3), 32=>64, relu),   # 13×13×32×N -> 11×11×64×N
        pool,                           # 11×11×64×N -> 5×5×64×N
        flatten_batch,                  # 5×5×64×N   -> 1600×N
        Dense(1600, 128, relu),         # 1600 = 5 * 5 * 64
        Dense(128, 10),
        softmax)
end

# Apply `cu` to the model's leaf arrays (its parameters), moving them to the GPU.
model = mapleaves(cu, model)

In [None]:
"""
    loss(x, y)

Cross-entropy between `model`'s (softmax) output for the batch `x` and the
one-hot targets `y`.
"""
function loss(x, y)
    predictions = model(x)
    return crossentropy(predictions, y)
end

"""
    accuracy(x, y)

Mean agreement between the class predicted by `model` for each sample of `x`
and the class encoded in `y`, using the `argmax` imported from Flux
(presumably applied column-wise, one sample per column — confirm).
"""
function accuracy(x, y)
    predicted = argmax(model(x))
    expected = argmax(y)
    return mean(predicted .== expected)
end

In [None]:
"""
    eval_cb()

Training callback: print `model`'s loss and accuracy on the whole test split
(`X_test`, `y_test`), followed by a blank line. Training-set metrics are not
reported. Returns `nothing`.
"""
function eval_cb()
    test_loss = loss(X_test, y_test)
    test_acc = accuracy(X_test, y_test)
    println("Test loss: $(test_loss), acc: $(test_acc)")
    println()
end

# ADAM optimizer over the model's trainable parameters.
optimizer = ADAM(params(model))

# To report test metrics during training, pass the callback, throttled to run at
# most once every 10 seconds:
# Flux.train!(loss, dataset, optimizer, cb = throttle(eval_cb, 10))

In [None]:
# Train `model` on the GPU mini-batches in `dataset` and time the run.
# NOTE(review): presumably a single pass (one epoch) over `dataset`; wrap in a loop
# for more epochs. No callback is passed, so `eval_cb` is not invoked here.
@time Flux.train!(loss, dataset, optimizer)

In [None]:
# Held-out evaluation data: a single batch holding the whole test split.
# The original built this from X_train / y_train, which (a) evaluated on training
# data and (b) threw, because X_train is already 28×28×1×N here, so X_train[:, i]
# is 28×10000 and cannot be reshaped to 28×28×1×10000.
# X_test / y_test were already reshaped and moved to the GPU in an earlier cell.
# NOTE(review): all 10000 images in one batch may exhaust GPU memory; if so, split
# into smaller batches and average the per-batch accuracies.
test_set = [(X_test, y_test)]

In [None]:
# Accuracy of the trained model on the first (image, label) batch of `test_set`.
accuracy(first(test_set)...)