In [2]:
using DiffEqFlux, OrdinaryDiffEq, Flux, NNlib, MLDataUtils, MLDatasets, Printf
using Flux: logitcrossentropy
using Flux.Data: DataLoader

In [5]:
"""
    loadmnist(batchsize = 128, train_split = 0.9)

Load the MNIST training images and split them into stratified train/test `DataLoader`s.

Images are reshaped to `(H, W, C, N)` with a single channel and converted to `Float32`.
Labels are one-hot encoded against the fixed label set `0:9`, giving one column per
observation. The "test" loader is a held-out split of `MNIST.traindata()`, not the
official MNIST test set.

# Arguments
- `batchsize`: minibatch size used by both loaders (the default no longer depends on a
  global `bs` defined in a later cell).
- `train_split`: fraction of the observations that go to the training loader.

# Returns
A tuple `(train_loader, test_loader)`. Only the training loader shuffles its data.
"""
function loadmnist(batchsize = 128, train_split = 0.9)
    # one-hot encode against the fixed label set 0:9 -> one column per observation
    onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw,
                                      LabelEnc.NativeLabels(collect(0:9)))
    imgs, labels_raw = MNIST.traindata()
    # add a singleton channel axis: (H, W, N) -> (H, W, 1, N), as Float32
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehot(labels_raw)
    # stratified split keeps the per-class proportions equal in both subsets
    (x_train, y_train), (x_test, y_test) = stratifiedobs((x_data, y_data), p = train_split)
    # collect() materializes the subsets returned by stratifiedobs before batching
    train_loader = DataLoader(cpu.(collect.((x_train, y_train)));
                              batchsize = batchsize, shuffle = true)
    # the held-out loader is not shuffled so evaluation order is stable
    test_loader = DataLoader(cpu.(collect.((x_test, y_test)));
                             batchsize = batchsize, shuffle = false)
    return train_loader, test_loader
end

loadmnist (generic function with 3 methods)

In [6]:
# Minibatch size, and the fraction of the MNIST training images used for training.
# The remainder becomes the held-out "test" loader.
const bs = 128
const train_split = 0.9
train_dataloader, test_dataloader = loadmnist(bs, train_split);

(DataLoader{Tuple{Array{Float32,4},Array{Int64,2}}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]

Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]

Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]

...

Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]

Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]

Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0]), 128, 54000, true, 54000, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  53991, 53992, 53993, 53994, 53995, 53996, 53997, 53998, 53999, 54000], true), DataLoader{Tuple{Array{Float32,4},Array{Int64,2}}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]

Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 …

In [52]:
# Inspect the stored arrays. Images are (H, W, C, N) and labels are one-hot (classes, N).
@show train_dataloader.data[1] |> size
@show train_dataloader.data[2] |> size

train_dataloader.data[1] |> size = (28, 28, 1, 54000)
train_dataloader.data[2] |> size = (10, 54000)


(10, 54000)

In [58]:
# Encoder: flatten each 28×28×1 image into a 784-vector and project it to a
# 20-dimensional hidden state, which becomes the initial condition of the ODE.
down = Chain(flatten, Dense(784, 20, tanh))

# Right-hand side of the ODE: a 20 → 10 → 10 → 20 MLP, so the state keeps size 20.
nn = Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh))

# Integrate the hidden state over t ∈ [0, 2] with Tsit5, using relative and absolute
# tolerances of 1e-3. With save_everystep = false and save_start = false, only the
# final state is kept. The keyword order matches the original definition.
nn_ode = NeuralODE(nn, (0.0f0, 2.0f0), Tsit5();
                   save_everystep = false, reltol = 1e-3, abstol = 1e-3,
                   save_start = false)

# Classifier head: maps the final 20-dimensional state to 10 outputs (class logits).
fc = Chain(Dense(20, 10))

"""
    DiffEqArray_to_Array(x)

Turn the `NeuralODE` output `x` into a plain array so it can feed the `Dense` head.

`x` is moved to the CPU with `cpu` and then reshaped to its first two dimensions. This
assumes the solution holds exactly one saved time point, so the trailing axis has
length 1. The `NeuralODE` in this notebook is built with `save_everystep = false` and
`save_start = false`; TODO confirm this still holds if those solver options change.
"""
function DiffEqArray_to_Array(x)
    host = cpu(x)
    kept_dims = size(host)[1:2]
    return reshape(host, kept_dims)
end

# Full pipeline: encode image -> integrate the NeuralODE -> convert to a matrix -> logits
model = Chain(down, nn_ode, DiffEqArray_to_Array, fc)

Chain(Chain(flatten, Dense(784, 20, tanh)), NeuralODE{Chain{Tuple{Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}}}},Array{Float32,1},Flux.var"#34#36"{Chain{Tuple{Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}}}}},Tuple{Float32,Float32},Tuple{Tsit5},Base.Iterators.Pairs{Symbol,Real,NTuple{4,Symbol},NamedTuple{(:save_everystep, :reltol, :abstol, :save_start),Tuple{Bool,Float64,Float64,Bool}}}}(Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh)), Float32[-0.18468359, 0.4422882, -0.18116233, 0.015536827, 0.3745248, 0.37027827, 0.29794687, 0.08216352, 0.12096443, 0.29873174  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Flux.var"#34#36"{Chain{Tuple{Dense{typeof(tanh),Array{Float32,2},Array{Float32,1}},Dense{typeof(tanh),Array{Float32,2},Array

In [59]:
# `img` and `lab` were never defined in this notebook, so this cell failed on a fresh
# kernel. Take the first training example as a batch of one, keeping the trailing
# batch dimension so the model sees (28, 28, 1, 1) and the labels stay (10, 1).
# `img` and `lab` are reused later to check the initial loss.
img, lab = train_dataloader.data[1][:, :, :, 1:1], train_dataloader.data[2][:, 1:1]
# forward pass through the whole encoder -> NeuralODE -> classifier pipeline
x_m = model(img)

10×1 Array{Float32,2}:
 -1.1912977
  0.07104479
  0.84480995
 -1.2965094
  0.7101893
  0.12247431
  1.0377375
  0.1360069
 -0.13325228
  1.0457156

In [60]:
"""
    classify(x)

Return, for every column of `x`, the (1-based) row index of its largest entry. This is
the predicted class index when samples are stored column-wise. A vector is treated as a
single column.
"""
classify(x) = [argmax(column) for column in eachcol(x)]

classify (generic function with 1 method)

In [61]:
classify([1, 2, 3])

1-element Array{Int64,1}:
 3

In [62]:
"""
    accuracy(model, data; n_batches = 100)

Return the fraction of correctly classified samples over the first `n_batches`
minibatches of `data`.

Both the targets `y` and the predictions `model(x)` are reduced to class indices with
`classify`, so each must hold one sample per column (e.g. one-hot targets and logits).
Returns `NaN` (0/0) when no samples are evaluated, e.g. `n_batches <= 0` or empty `data`.
"""
function accuracy(model, data; n_batches = 100)
    total_correct = 0
    total = 0
    # Iterate the loader lazily. The previous `collect(data)` materialized every
    # minibatch of the dataset on each call, only to use the first `n_batches` of them.
    for (i, (x, y)) in enumerate(data)
        i > n_batches && break
        target_class = classify(cpu(y))
        predicted_class = classify(cpu(model(x)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

# Cross-entropy between the logits computed from the final ODE state and the one-hot
# targets. No intermediate time points enter the loss, so the NeuralODE's integration
# interval `tspan` is a free modelling choice rather than something fixed by the data.
function loss(x, y)
    logits = model(x)
    return logitcrossentropy(logits, y)
end

loss (generic function with 1 method)

In [63]:
# Baseline metrics before training (accuracy uses the first 100 training batches by default).
# NOTE(review): requires `img`/`lab` (a single-sample batch) to be defined by an earlier cell.
@show accuracy(model, train_dataloader)
@show loss(img, lab)

accuracy(model, train_dataloader) = 0.071484375
loss(img, lab) = 2.574335f0


2.574335f0

In [64]:
opt = ADAM(0.05)  # Adam optimiser with learning rate 0.05
iter = 0          # global iteration counter, incremented by the training callback `cb`

0

In [65]:
"""
    cb()

Training callback: bump the global `iter` counter and print train/test accuracy on
iterations 1, 11, 21, …. Training accuracy uses the default number of batches, and test
accuracy covers every batch of `test_dataloader`. Always returns `nothing`.
"""
function cb()
    global iter += 1
    # evaluation is only done on every 10th iteration
    iter % 10 == 1 || return nothing
    train_acc = 100 * accuracy(model, train_dataloader)
    test_acc = 100 * accuracy(model, test_dataloader;
                              n_batches = length(test_dataloader))
    @printf("Iter: %3d || Train Accuracy: %2.3f || Test Accuracy: %2.3f\n",
            iter, train_acc, test_acc)
    return nothing
end

cb (generic function with 1 method)

In [68]:
# Run one pass of Flux.train! over the training loader, updating every parameter
# collected by `params(model)`. The callback `cb` prints train/test accuracy every
# 10 iterations.
Flux.train!(loss, Flux.params(model), train_dataloader, opt; cb = cb)

Iter:  41 || Train Accuracy: 81.086 || Test Accuracy: 80.117
Iter:  51 || Train Accuracy: 81.047 || Test Accuracy: 80.567
Iter:  61 || Train Accuracy: 81.242 || Test Accuracy: 81.050
Iter:  71 || Train Accuracy: 84.430 || Test Accuracy: 83.667


LoadError: InterruptException: