In [None]:
using Flux
using Flux: onehotbatch, argmax, crossentropy, throttle
using Base.Iterators: repeated, partition
using MNIST
using CLArrays

In [None]:
# Load MNIST. The original transpose-then-reshape implies traindata()/testdata()
# return features as a 784×N matrix with one flattened image per column.
X_train, y_train = traindata()
X_test, y_test = testdata()

# Split every column into its own 28×28 matrix. Indexing columns directly gives
# the same matrices as transposing, reshaping to N×28×28 and taking `A[i, :]`,
# but avoids two full-dataset copies and the partial linear indexing on a 3-D
# array that newer Julia versions reject. Counts come from the data itself.
X_train = [reshape(X_train[:, i], (28, 28)) for i in 1:size(X_train, 2)]
X_test = [reshape(X_test[:, i], (28, 28)) for i in 1:size(X_test, 2)]

println(size(X_train))
println(size(y_train))
println()
println(size(X_test))
println(size(y_test))

In [None]:
# One-hot encode the digit labels against the classes 0 through 9.
y_train = onehotbatch(y_train, 0:9)
y_test = onehotbatch(y_test, 0:9)

# Show type and size of the encoded labels (test set first, then training set).
for ys in (y_test, y_train)
    println(typeof(ys))
    println(size(ys))
end

In [None]:
# Sanity-check the per-image containers built above: types first, then sizes.
foreach(v -> println(typeof(v)), (X_train, X_test))
foreach(v -> println(size(v)), (X_train, X_test))

In [None]:
# Move the test set onto the OpenCL device as Float32.
# `X_test` is a Vector of 28×28 matrices at this point, so it is first stacked
# into one 28×28×1×N array (same layout as the later `cat(4, ...)` cell).
# Broadcasting `Float32` over the Vector itself would apply `Float32` to each
# matrix. That either errors or yields a Vector{Matrix}, which CLArray cannot hold.
X_test_f32 = CLArray(Float32.(cat(4, X_test...)))
y_test_f32 = CLArray(Float32.(y_test))

In [None]:
# Inspect the device-side test arrays. Only the test set is converted in the
# previous cell; `y_train_f32` is never assigned, so printing it would raise
# an UndefVarError.
println(typeof(X_test_f32))
println(typeof(y_test_f32))

In [None]:
# Group the training images into mini-batches of `batch_size`. Each batch is a
# (28×28×1×k image array, matching one-hot label columns) tuple. The range is
# derived from the data, and Iterators.partition leaves a smaller final batch
# when the sample count is not a multiple of `batch_size`.
batch_size = 100
training_data = [(cat(4, float.(X_train[idx])...), y_train[:, idx])
                 for idx in partition(1:length(X_train), batch_size)]

# Stack every image into a single 28×28×1×N array for whole-set evaluation.
X_train = cat(4, float.(X_train)...)
X_test = cat(4, float.(X_test)...)

In [None]:
# Check the stacked 4-D arrays: sizes first, then element/container types.
for arr in (X_test, X_train)
    println(size(arr))
end

for arr in (X_test, X_train)
    println(typeof(arr))
end

In [None]:
# Two unpadded 3×3 conv + 2×2 max-pool stages followed by a dense classifier.
# Trailing comments give the per-sample (height × width × channels) shape.
model = Chain(
    Conv2D((3, 3), 1 => 32, relu),     # 28×28×1  -> 26×26×32
    x -> maxpool2d(x, 2),              # 26×26×32 -> 13×13×32
    Conv2D((3, 3), 32 => 64, relu),    # 13×13×32 -> 11×11×64
    x -> maxpool2d(x, 2),              # 11×11×64 -> 5×5×64
    x -> reshape(x, :, size(x, 4)),    # flatten: 5*5*64 = 1600 features per sample
    Dense(1600, 128, relu),
    Dense(128, 10),
    softmax                            # class probabilities over the 10 digits
)

In [None]:
# Training objective: cross-entropy between the model output and the one-hot
# targets `y`.
function loss(x, y)
    return crossentropy(model(x), y)
end

# Fraction of samples whose predicted class matches the target class.
function accuracy(x, y)
    predicted = argmax(model(x))
    expected = argmax(y)
    return mean(predicted .== expected)
end

In [None]:
# Smoke test: run the untrained model on the images of the first mini-batch so
# any layer-shape mismatch surfaces before training starts.
let first_batch = first(training_data)
    model(first_batch[1])
end

In [None]:
# Progress callback: print loss and accuracy on the full test set, followed by
# a blank line.
function eval_cb()
    println("Test loss: $(loss(X_test, y_test)), acc: $(accuracy(X_test, y_test))")
    println()
    return nothing
end

optimizer = ADAM(params(model))

# Run the callback at most once every 10 seconds while training.
throttled_cb = throttle(eval_cb, 10)
Flux.train!(loss, training_data, optimizer; cb = throttled_cb)

In [None]:
# Final report: test-set accuracy to three decimal places.
println("Optimization Finished!")
test_accuracy = accuracy(X_test, y_test)
@printf("Model accuracy on test set: %.3f\n", test_accuracy)