# Convolutional neural network

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Statistics: mean
# using CuArrays

## Load and preprocess data

Classify MNIST digits with a convolutional network.

In [2]:
batch_size = 128  # number of images per training minibatch

128

In [3]:
"""
    minibatch(X, Y, idxs)

Assemble the samples of `X` and `Y` selected by `idxs` into one minibatch.

Each `X[i]` is expected to be a 2-D image (all of the same size as `first(X)`).
The selected images are converted to `Float32` and stacked into an array of size
`size(first(X))..., 1, length(idxs)`: a single channel in the third dimension and
one sample per slice along the fourth. The labels `Y[idxs]` are one-hot encoded
over the classes `0:9` with `onehotbatch`.

Returns the tuple `(X_batch, Y_batch)`.
"""
function minibatch(X, Y, idxs)
    X_batch = Array{Float32}(undef, size(first(X))..., 1, length(idxs))
    for (i, idx) in enumerate(idxs)
        # Fused broadcast: convert the pixels straight into the batch slice
        # (channel 1, sample i) without allocating a temporary image.
        X_batch[:, :, 1, i] .= Float32.(X[idx])
    end
    Y_batch = onehotbatch(Y[idxs], 0:9)
    return (X_batch, Y_batch)
end

minibatch (generic function with 1 method)

In [4]:
# Load the MNIST training split and pre-build every minibatch up front.
train_labels = MNIST.labels()
train_imgs = MNIST.images()
mb_idxs = partition(eachindex(train_imgs), batch_size)
train = map(batch_idxs -> minibatch(train_imgs, train_labels, batch_idxs), mb_idxs);

In [5]:
# The whole MNIST test split is packed into a single batch for evaluation.
test_imgs = MNIST.images(:test)
test_labels = MNIST.labels(:test)
all_test_idxs = eachindex(test_imgs)
test = minibatch(test_imgs, test_labels, all_test_idxs);

## Model

In [6]:
# Shapes are width × height × channels (the batch dimension N is omitted until
# the reshape). Each 3×3 Conv with pad=(2,2) grows the spatial size by 2
# (out = in + 2*2 - 3 + 1); each 2×2 MaxPool (stride 2) halves it, rounding down.
# Starting from 28×28 inputs this ends at 5×5×32 = 800 features, matching Dense(800, 10).
model = Chain(
    Conv((3, 3), 1=>16, pad=(2,2), relu),  # 28×28×1  -> 30×30×16
    MaxPool((2, 2)),                       # 30×30×16 -> 15×15×16
    Conv((3, 3), 16=>32, pad=(2,2), relu), # 15×15×16 -> 17×17×32
    MaxPool((2, 2)),                       # 17×17×32 -> 8×8×32
    Conv((3, 3), 32=>32, pad=(2,2), relu), # 8×8×32   -> 10×10×32
    MaxPool((2, 2)),                       # 10×10×32 -> 5×5×32
    x -> reshape(x, :, size(x, 4)),        # 5×5×32×N -> 800×N (flatten per sample)
    Dense(800, 10), softmax)               # 800×N -> 10×N class probabilities

Chain(Conv((3, 3), 1=>16, NNlib.relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 16=>32, NNlib.relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 32=>32, NNlib.relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), getfield(Main, Symbol("##5#6"))(), Dense(800, 10), NNlib.softmax)

In [7]:
first(train)[1] |> model  # forward pass on the first training batch (untrained model)

Tracked 10×128 Array{Float32,2}:
 0.00488587  0.0059864    0.0172354    …  0.0160831    0.0181191  
 0.00328304  0.000408046  0.00628934      0.0131039    0.00512937 
 0.475399    0.189241     0.681357        0.523271     0.306449   
 0.011853    0.0178683    0.000955412     0.0959452    0.034394   
 0.159037    0.149042     0.0538785       0.0298879    0.395101   
 2.43308e-6  1.94792e-5   1.6662e-5    …  2.38639e-5   4.22073e-5 
 7.1593e-5   0.000156501  0.00059031      0.000301392  0.000219752
 0.235879    0.0950901    0.138555        0.254942     0.0665807  
 0.0162545   0.00619433   0.00344156      0.0128089    0.00580126 
 0.0933343   0.535994     0.0976801       0.053633     0.168164   

In [8]:
# train = gpu.(train)
# test = gpu.(test)
# model = gpu(model)

```julia
Conv(size, input=>output, activation)
```

* size: size of the filter (convolution kernel), e.g. `(3, 3)`
* input: number of input channels
* output: number of output channels

```julia
Conv(k, ch, σ = identity;
     init = initn,
     stride = map(_->1,k),
     pad = map(_->0,k),
     dilation = map(_->1,k))

Conv((2, 2), 1=>16, relu; init=initn, stride=(1, 1), pad=(0, 0), dilation=(1, 1))
```

```julia
MaxPool(k; pad, stride)
```

* k: size of the pooling window, e.g. `(2, 2)`; the stride defaults to `k`, so the windows do not overlap

```julia
x -> reshape(x, :, size(x, 4))
```

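Flatten (拉直): reshape the 4-D feature map (width × height × channels × batch) into a matrix with one column per sample, so it can be fed to the `Dense` layer.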

## Loss function

In [9]:
"""
    loss(x, y)

Cross-entropy between the model's predicted class probabilities for `x`
and the one-hot targets `y`.
"""
function loss(x, y)
    predicted = model(x)
    return crossentropy(predicted, y)
end

loss (generic function with 1 method)

In [10]:
"""
    accuracy(x, y)

Fraction of samples whose predicted class (`onecold` of the model output)
equals the class encoded by the one-hot labels `y`.
"""
function accuracy(x, y)
    predicted_classes = onecold(model(x))
    true_classes = onecold(y)
    return mean(predicted_classes .== true_classes)
end

accuracy (generic function with 1 method)

## Optimizer

In [11]:
# Training callback: print (and return) the accuracy on the full test batch.
function evalcb()
    @show accuracy(test[1], test[2])
end
opt = ADAM()  # default hyperparameters: learning rate 0.001, betas (0.9, 0.999)

ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}())

In [12]:
loss(train[1]...)  # loss on the first (images, labels) batch before training

4.2682443f0 (tracked)

## Training

In [13]:
# Ten passes over the pre-built training minibatches. `throttle` limits the
# test-set evaluation callback to at most one call every 10 seconds; it is
# rebuilt for every epoch because the whole expression is re-evaluated per epoch.
@epochs 10 Flux.train!(loss, params(model), train, opt;
                       cb = throttle(evalcb, 10))

┌ Info: Epoch 1
└ @ Main /home/pika/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105


accuracy(test[1], test[2]) = 0.1154
accuracy(test[1], test[2]) = 0.702
accuracy(test[1], test[2]) = 0.8391
accuracy(test[1], test[2]) = 0.8596
accuracy(test[1], test[2]) = 0.8807
accuracy(test[1], test[2]) = 0.903
accuracy(test[1], test[2]) = 0.924
accuracy(test[1], test[2]) = 0.9311
accuracy(test[1], test[2]) = 0.9303
accuracy(test[1], test[2]) = 0.921
accuracy(test[1], test[2]) = 0.9232
accuracy(test[1], test[2]) = 0.9401
accuracy(test[1], test[2]) = 0.9424
accuracy(test[1], test[2]) = 0.9408
accuracy(test[1], test[2]) = 0.9475
accuracy(test[1], test[2]) = 0.9495
accuracy(test[1], test[2]) = 0.9561
accuracy(test[1], test[2]) = 0.9514
accuracy(test[1], test[2]) = 0.951
accuracy(test[1], test[2]) = 0.9509
accuracy(test[1], test[2]) = 0.9461
accuracy(test[1], test[2]) = 0.9595
accuracy(test[1], test[2]) = 0.9623
accuracy(test[1], test[2]) = 0.9636
accuracy(test[1], test[2]) = 0.9633
accuracy(test[1], test[2]) = 0.961
accuracy(test[1], test[2]) = 0.961
accuracy(test[1], test[2]) = 0.9692

┌ Info: Epoch 2
└ @ Main /home/pika/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105


accuracy(test[1], test[2]) = 0.9662
accuracy(test[1], test[2]) = 0.9583
accuracy(test[1], test[2]) = 0.9691
accuracy(test[1], test[2]) = 0.9737
accuracy(test[1], test[2]) = 0.9692
accuracy(test[1], test[2]) = 0.9712
accuracy(test[1], test[2]) = 0.9534
accuracy(test[1], test[2]) = 0.9744
accuracy(test[1], test[2]) = 0.9707
accuracy(test[1], test[2]) = 0.973
accuracy(test[1], test[2]) = 0.9681
accuracy(test[1], test[2]) = 0.973
accuracy(test[1], test[2]) = 0.9751
accuracy(test[1], test[2]) = 0.9747
accuracy(test[1], test[2]) = 0.9764
accuracy(test[1], test[2]) = 0.9736
accuracy(test[1], test[2]) = 0.9725
accuracy(test[1], test[2]) = 0.9779
accuracy(test[1], test[2]) = 0.969
accuracy(test[1], test[2]) = 0.9755
accuracy(test[1], test[2]) = 0.9766
accuracy(test[1], test[2]) = 0.9785
accuracy(test[1], test[2]) = 0.9751
accuracy(test[1], test[2]) = 0.977
accuracy(test[1], test[2]) = 0.9782
accuracy(test[1], test[2]) = 0.9803
accuracy(test[1], test[2]) = 0.9739
accuracy(test[1], test[2]) = 0.9

┌ Info: Epoch 3
└ @ Main /home/pika/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105


accuracy(test[1], test[2]) = 0.9773
accuracy(test[1], test[2]) = 0.9752
accuracy(test[1], test[2]) = 0.9796
accuracy(test[1], test[2]) = 0.9809
accuracy(test[1], test[2]) = 0.9826
accuracy(test[1], test[2]) = 0.9829
accuracy(test[1], test[2]) = 0.9787
accuracy(test[1], test[2]) = 0.9694
accuracy(test[1], test[2]) = 0.9824
accuracy(test[1], test[2]) = 0.9826
accuracy(test[1], test[2]) = 0.9832
accuracy(test[1], test[2]) = 0.9826
accuracy(test[1], test[2]) = 0.9829
accuracy(test[1], test[2]) = 0.9816
accuracy(test[1], test[2]) = 0.9799
accuracy(test[1], test[2]) = 0.9811
accuracy(test[1], test[2]) = 0.9802
accuracy(test[1], test[2]) = 0.98
accuracy(test[1], test[2]) = 0.9813
accuracy(test[1], test[2]) = 0.9814
accuracy(test[1], test[2]) = 0.9764
accuracy(test[1], test[2]) = 0.9808
accuracy(test[1], test[2]) = 0.9822
accuracy(test[1], test[2]) = 0.9828
accuracy(test[1], test[2]) = 0.9842
accuracy(test[1], test[2]) = 0.9833
accuracy(test[1], test[2]) = 0.9834
accuracy(test[1], test[2]) = 0

┌ Info: Epoch 4
└ @ Main /home/pika/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105


accuracy(test[1], test[2]) = 0.9826
accuracy(test[1], test[2]) = 0.9837
accuracy(test[1], test[2]) = 0.9834
accuracy(test[1], test[2]) = 0.9818
accuracy(test[1], test[2]) = 0.9824
accuracy(test[1], test[2]) = 0.9833
accuracy(test[1], test[2]) = 0.979
accuracy(test[1], test[2]) = 0.9787
accuracy(test[1], test[2]) = 0.9842
accuracy(test[1], test[2]) = 0.986
accuracy(test[1], test[2]) = 0.984
accuracy(test[1], test[2]) = 0.9854
accuracy(test[1], test[2]) = 0.984
accuracy(test[1], test[2]) = 0.9869
accuracy(test[1], test[2]) = 0.984
accuracy(test[1], test[2]) = 0.9868
accuracy(test[1], test[2]) = 0.9839
accuracy(test[1], test[2]) = 0.9827
accuracy(test[1], test[2]) = 0.984
accuracy(test[1], test[2]) = 0.9842
accuracy(test[1], test[2]) = 0.9833
accuracy(test[1], test[2]) = 0.9852
accuracy(test[1], test[2]) = 0.9853
accuracy(test[1], test[2]) = 0.9845
accuracy(test[1], test[2]) = 0.9866
accuracy(test[1], test[2]) = 0.9873
accuracy(test[1], test[2]) = 0.9852
accuracy(test[1], test[2]) = 0.987

┌ Info: Epoch 5
└ @ Main /home/pika/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105


accuracy(test[1], test[2]) = 0.9859
accuracy(test[1], test[2]) = 0.9861
accuracy(test[1], test[2]) = 0.987
accuracy(test[1], test[2]) = 0.9842
accuracy(test[1], test[2]) = 0.9836
accuracy(test[1], test[2]) = 0.9788
accuracy(test[1], test[2]) = 0.9814
accuracy(test[1], test[2]) = 0.9854
accuracy(test[1], test[2]) = 0.9844
accuracy(test[1], test[2]) = 0.9867
accuracy(test[1], test[2]) = 0.9869
accuracy(test[1], test[2]) = 0.988
accuracy(test[1], test[2]) = 0.9853
accuracy(test[1], test[2]) = 0.9853
accuracy(test[1], test[2]) = 0.9846
accuracy(test[1], test[2]) = 0.9875
accuracy(test[1], test[2]) = 0.9844
accuracy(test[1], test[2]) = 0.9855
accuracy(test[1], test[2]) = 0.986
accuracy(test[1], test[2]) = 0.9874
accuracy(test[1], test[2]) = 0.9856
accuracy(test[1], test[2]) = 0.9855
accuracy(test[1], test[2]) = 0.9891
accuracy(test[1], test[2]) = 0.9847
accuracy(test[1], test[2]) = 0.987
accuracy(test[1], test[2]) = 0.9878
accuracy(test[1], test[2]) = 0.9844
accuracy(test[1], test[2]) = 0.9

┌ Info: Epoch 6
└ @ Main /home/pika/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105


accuracy(test[1], test[2]) = 0.9871
accuracy(test[1], test[2]) = 0.9877
accuracy(test[1], test[2]) = 0.9879
accuracy(test[1], test[2]) = 0.9823
accuracy(test[1], test[2]) = 0.9835
accuracy(test[1], test[2]) = 0.9809
accuracy(test[1], test[2]) = 0.9833
accuracy(test[1], test[2]) = 0.986
accuracy(test[1], test[2]) = 0.9854
accuracy(test[1], test[2]) = 0.9874
accuracy(test[1], test[2]) = 0.9889
accuracy(test[1], test[2]) = 0.9882
accuracy(test[1], test[2]) = 0.9871
accuracy(test[1], test[2]) = 0.9883
accuracy(test[1], test[2]) = 0.987
accuracy(test[1], test[2]) = 0.9866
accuracy(test[1], test[2]) = 0.9876
accuracy(test[1], test[2]) = 0.9841
accuracy(test[1], test[2]) = 0.9874
accuracy(test[1], test[2]) = 0.988
accuracy(test[1], test[2]) = 0.9879
accuracy(test[1], test[2]) = 0.9896
accuracy(test[1], test[2]) = 0.9864
accuracy(test[1], test[2]) = 0.9862
accuracy(test[1], test[2]) = 0.9865
accuracy(test[1], test[2]) = 0.988
accuracy(test[1], test[2]) = 0.9877
accuracy(test[1], test[2]) = 0.9

┌ Info: Epoch 7
└ @ Main /home/pika/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105


accuracy(test[1], test[2]) = 0.9881
accuracy(test[1], test[2]) = 0.9875
accuracy(test[1], test[2]) = 0.9878
accuracy(test[1], test[2]) = 0.9818
accuracy(test[1], test[2]) = 0.9846
accuracy(test[1], test[2]) = 0.9806
accuracy(test[1], test[2]) = 0.987
accuracy(test[1], test[2]) = 0.9856
accuracy(test[1], test[2]) = 0.9864
accuracy(test[1], test[2]) = 0.9877
accuracy(test[1], test[2]) = 0.9888
accuracy(test[1], test[2]) = 0.9882
accuracy(test[1], test[2]) = 0.987
accuracy(test[1], test[2]) = 0.9882
accuracy(test[1], test[2]) = 0.9862
accuracy(test[1], test[2]) = 0.9882
accuracy(test[1], test[2]) = 0.9871
accuracy(test[1], test[2]) = 0.9846
accuracy(test[1], test[2]) = 0.9854
accuracy(test[1], test[2]) = 0.9894
accuracy(test[1], test[2]) = 0.9865
accuracy(test[1], test[2]) = 0.9864
accuracy(test[1], test[2]) = 0.99
accuracy(test[1], test[2]) = 0.9881
accuracy(test[1], test[2]) = 0.9887
accuracy(test[1], test[2]) = 0.9883
accuracy(test[1], test[2]) = 0.9877
accuracy(test[1], test[2]) = 0.9

┌ Info: Epoch 8
└ @ Main /home/pika/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105


accuracy(test[1], test[2]) = 0.9884
accuracy(test[1], test[2]) = 0.9879
accuracy(test[1], test[2]) = 0.988
accuracy(test[1], test[2]) = 0.9807
accuracy(test[1], test[2]) = 0.9841
accuracy(test[1], test[2]) = 0.983
accuracy(test[1], test[2]) = 0.9869
accuracy(test[1], test[2]) = 0.9846
accuracy(test[1], test[2]) = 0.9869
accuracy(test[1], test[2]) = 0.988
accuracy(test[1], test[2]) = 0.9881
accuracy(test[1], test[2]) = 0.9881
accuracy(test[1], test[2]) = 0.9878
accuracy(test[1], test[2]) = 0.9884
accuracy(test[1], test[2]) = 0.9864
accuracy(test[1], test[2]) = 0.9881
accuracy(test[1], test[2]) = 0.9877
accuracy(test[1], test[2]) = 0.9855
accuracy(test[1], test[2]) = 0.9861
accuracy(test[1], test[2]) = 0.9901
accuracy(test[1], test[2]) = 0.9875
accuracy(test[1], test[2]) = 0.9907
accuracy(test[1], test[2]) = 0.9867
accuracy(test[1], test[2]) = 0.9869
accuracy(test[1], test[2]) = 0.9865
accuracy(test[1], test[2]) = 0.9884
accuracy(test[1], test[2]) = 0.9886
accuracy(test[1], test[2]) = 0.

┌ Info: Epoch 9
└ @ Main /home/pika/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105


accuracy(test[1], test[2]) = 0.9882
accuracy(test[1], test[2]) = 0.9869
accuracy(test[1], test[2]) = 0.9876
accuracy(test[1], test[2]) = 0.9826
accuracy(test[1], test[2]) = 0.986
accuracy(test[1], test[2]) = 0.9812
accuracy(test[1], test[2]) = 0.9887
accuracy(test[1], test[2]) = 0.9858
accuracy(test[1], test[2]) = 0.9861
accuracy(test[1], test[2]) = 0.9876
accuracy(test[1], test[2]) = 0.9886
accuracy(test[1], test[2]) = 0.9883
accuracy(test[1], test[2]) = 0.9893
accuracy(test[1], test[2]) = 0.9886
accuracy(test[1], test[2]) = 0.9869
accuracy(test[1], test[2]) = 0.9879
accuracy(test[1], test[2]) = 0.9864
accuracy(test[1], test[2]) = 0.9865
accuracy(test[1], test[2]) = 0.9868
accuracy(test[1], test[2]) = 0.9896
accuracy(test[1], test[2]) = 0.9862
accuracy(test[1], test[2]) = 0.9907
accuracy(test[1], test[2]) = 0.9868
accuracy(test[1], test[2]) = 0.9867
accuracy(test[1], test[2]) = 0.9862
accuracy(test[1], test[2]) = 0.9881
accuracy(test[1], test[2]) = 0.9864
accuracy(test[1], test[2]) = 

┌ Info: Epoch 10
└ @ Main /home/pika/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105


accuracy(test[1], test[2]) = 0.989
accuracy(test[1], test[2]) = 0.9869
accuracy(test[1], test[2]) = 0.987
accuracy(test[1], test[2]) = 0.9881
accuracy(test[1], test[2]) = 0.9836
accuracy(test[1], test[2]) = 0.9855
accuracy(test[1], test[2]) = 0.9877
accuracy(test[1], test[2]) = 0.9872
accuracy(test[1], test[2]) = 0.9795
accuracy(test[1], test[2]) = 0.9861
accuracy(test[1], test[2]) = 0.9886
accuracy(test[1], test[2]) = 0.9867
accuracy(test[1], test[2]) = 0.988
accuracy(test[1], test[2]) = 0.9869
accuracy(test[1], test[2]) = 0.9887
accuracy(test[1], test[2]) = 0.9885
accuracy(test[1], test[2]) = 0.9877
accuracy(test[1], test[2]) = 0.9858
accuracy(test[1], test[2]) = 0.9889
accuracy(test[1], test[2]) = 0.9884
accuracy(test[1], test[2]) = 0.9861
accuracy(test[1], test[2]) = 0.9871
accuracy(test[1], test[2]) = 0.9875
accuracy(test[1], test[2]) = 0.9862
accuracy(test[1], test[2]) = 0.9859
accuracy(test[1], test[2]) = 0.9878
accuracy(test[1], test[2]) = 0.9877
accuracy(test[1], test[2]) = 0.