# Convolutional neural network

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Statistics: mean
# using CuArrays

┌ Info: Recompiling stale cache file /home/pika/.julia/compiled/v1.0/Flux/QdkVy.ji for Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
└ @ Base loading.jl:1190


## Load data

Classify MNIST digits with a convolutional network

In [2]:
# Load the MNIST training images and one-hot encode their labels over the
# digit classes 0 through 9 (a 10×60000 matrix, as the output below shows).
imgs = Flux.Data.MNIST.images()
labels = let raw_labels = Flux.Data.MNIST.labels()
    onehotbatch(raw_labels, 0:9)
end

10×60000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false   true  false  false  false  …  false  false  false  false  false
 false  false  false   true  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false   true  false  false  false
 false  false   true  false  false     false  false  false  false  false
  true  false  false  false  false  …  false  false   true  false  false
 false  false  false  false  false     false  false  false   true  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false      true  false  false  false   true
 false  false  false  false   true     false  false  false  false  false

## Preprocessing

In [3]:
# Partition into batches of size 1,000.
# Each element of `train` is an (x, y) pair: x is a 28×28×1×1000 array holding the
# batch's images stacked along dimension 4, y the matching 10×1000 one-hot labels.
# `reduce(hcat, ...)` joins the images in a single call instead of splatting 1,000
# arguments into `cat` (large varargs are slow to compile). Julia arrays are
# column-major, so reshaping the 28×28000 result gives exactly the same layout
# as `cat(...; dims=4)`.
train = [(reshape(reduce(hcat, float.(imgs[i])), 28, 28, 1, :),
          labels[:, i]) for i in partition(1:60_000, 1000)];
# train = gpu.(train)

In [4]:
# Peek at the image array of the first training batch (28×28×1×1000).
first(train)[1]

28×28×1×1000 Array{Float64,4}:
[:, :, 1, 1] =
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.498039  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.25098   0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 ⋮   

In [5]:
# Prepare test set (first 1,000 images), stacked the same way as the training
# batches: test_x is 28×28×1×1000, test_y holds the matching 10×1000 one-hot labels.
# `reduce(hcat, ...)` + `reshape` avoids splatting 1,000 matrices into `cat` and
# yields the same layout as `cat(...; dims=4)`.
test_x = reshape(reduce(hcat, float.(Flux.Data.MNIST.images(:test)[1:1000])),
                 28, 28, 1, :)# |> gpu
test_y = onehotbatch(Flux.Data.MNIST.labels(:test)[1:1000], 0:9)# |> gpu

10×1000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false  false  false   true  false  …  false  false   true  false  false
 false  false   true  false  false     false  false  false  false  false
 false   true  false  false  false      true   true  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false   true     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
  true  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false   true  false
 false  false  false  false  false     false  false  false  false   true

## Model

In [6]:
# Two convolution + max-pool stages, then a dense softmax classifier over 10 digits.
# Shape comments give the per-sample (dim1, dim2, channels); the batch size is the
# trailing 4th dimension, which the flattening step reads via `size(x, 4)`.
m = Chain(
    Conv((2, 2), 1 => 16, relu),    # (28, 28, 1)  -> (27, 27, 16); stride 1, no padding
    MaxPool((2, 2)),                # (27, 27, 16) -> (13, 13, 16); stride 2, odd edge dropped
    Conv((2, 2), 16 => 8, relu),    # (13, 13, 16) -> (12, 12, 8)
    MaxPool((2, 2)),                # (12, 12, 8)  -> (6, 6, 8)
    x -> reshape(x, :, size(x, 4)), # (6, 6, 8)    -> 288 = 6 * 6 * 8 features per sample
    Dense(288, 10),                 # 288 features -> 10 class scores
    softmax,                        # scores -> class probabilities
)# |> gpu

Chain(Conv((2, 2), 1=>16, NNlib.relu), MaxPool((2, 2), pad = (0, 0), stride = (2, 2)), Conv((2, 2), 16=>8, NNlib.relu), MaxPool((2, 2), pad = (0, 0), stride = (2, 2)), getfield(Main, Symbol("##5#6"))(), Dense(288, 10), NNlib.softmax)

```julia
Conv(size, input=>output, activation)
```

* size: size of the filter (kernel), e.g. `(2, 2)`
* input: number of input channels
* output: number of output channels (i.e. the number of filters)

```julia
Conv(k, ch, σ = identity;
     init = initn,
     stride = map(_->1,k),
     pad = map(_->0,k),
     dilation = map(_->1,k))

Conv((2, 2), 1=>16, relu; init=initn, stride=(1, 1), pad=(0, 0), dilation=(1, 1))
```

```julia
MaxPool(k; pad, stride)
```

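* k: size of the pooling window; when `stride` is not given it defaults to `k` (the model printout above shows `stride = (2, 2)` for `MaxPool((2, 2))`)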

```julia
x -> reshape(x, :, size(x, 4))
```

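Flatten (拉直): reshape each sample's (6, 6, 8) feature maps into a 288-element column, giving a 288×N matrix (N = batch size) that feeds `Dense(288, 10)`.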

## Loss function

In [7]:
# Cross-entropy between the model's predicted probabilities and the one-hot targets `y`.
function loss(x, y)
    ŷ = m(x)
    return crossentropy(ŷ, y)
end

loss (generic function with 1 method)

In [8]:
# Fraction of samples whose most probable predicted class matches the true label.
function accuracy(x, y)
    predicted = onecold(m(x))
    expected = onecold(y)
    return mean(predicted .== expected)
end

accuracy (generic function with 1 method)

## Optimizer and evaluation callback

In [9]:
# Training callback: print the test-set accuracy, at most once every 10 seconds.
evalcb = throttle(10) do
    @show(accuracy(test_x, test_y))
end
# NOTE(review): `ADAM(params(m))` is the pre-0.8 Flux optimiser API this notebook was
# run with; newer Flux releases use `ADAM()` and pass `params(m)` to `train!` instead.
opt = ADAM(params(m))

#43 (generic function with 1 method)

## Training

In [10]:
# Run 10 passes over the training batches; `evalcb` reports test accuracy along the way.
@epochs 10 Flux.train!(loss, train, opt; cb = evalcb)

┌ Info: Epoch 1
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.143
accuracy(test_x, test_y) = 0.142
accuracy(test_x, test_y) = 0.106
accuracy(test_x, test_y) = 0.151
accuracy(test_x, test_y) = 0.14
accuracy(test_x, test_y) = 0.129
accuracy(test_x, test_y) = 0.141
accuracy(test_x, test_y) = 0.222
accuracy(test_x, test_y) = 0.36
accuracy(test_x, test_y) = 0.471
accuracy(test_x, test_y) = 0.531
accuracy(test_x, test_y) = 0.581
accuracy(test_x, test_y) = 0.636


┌ Info: Epoch 2
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.677
accuracy(test_x, test_y) = 0.689
accuracy(test_x, test_y) = 0.691
accuracy(test_x, test_y) = 0.709
accuracy(test_x, test_y) = 0.746
accuracy(test_x, test_y) = 0.764
accuracy(test_x, test_y) = 0.776
accuracy(test_x, test_y) = 0.759
accuracy(test_x, test_y) = 0.788
accuracy(test_x, test_y) = 0.797
accuracy(test_x, test_y) = 0.808
accuracy(test_x, test_y) = 0.813
accuracy(test_x, test_y) = 0.81


┌ Info: Epoch 3
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.814
accuracy(test_x, test_y) = 0.818
accuracy(test_x, test_y) = 0.822
accuracy(test_x, test_y) = 0.813
accuracy(test_x, test_y) = 0.827
accuracy(test_x, test_y) = 0.83
accuracy(test_x, test_y) = 0.833
accuracy(test_x, test_y) = 0.834
accuracy(test_x, test_y) = 0.835
accuracy(test_x, test_y) = 0.836
accuracy(test_x, test_y) = 0.837
accuracy(test_x, test_y) = 0.842
accuracy(test_x, test_y) = 0.843


┌ Info: Epoch 4
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.844
accuracy(test_x, test_y) = 0.842
accuracy(test_x, test_y) = 0.845
accuracy(test_x, test_y) = 0.846
accuracy(test_x, test_y) = 0.851
accuracy(test_x, test_y) = 0.851
accuracy(test_x, test_y) = 0.854
accuracy(test_x, test_y) = 0.849
accuracy(test_x, test_y) = 0.856
accuracy(test_x, test_y) = 0.858
accuracy(test_x, test_y) = 0.857
accuracy(test_x, test_y) = 0.858
accuracy(test_x, test_y) = 0.86


┌ Info: Epoch 5
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.86
accuracy(test_x, test_y) = 0.857
accuracy(test_x, test_y) = 0.859
accuracy(test_x, test_y) = 0.857
accuracy(test_x, test_y) = 0.864
accuracy(test_x, test_y) = 0.861
accuracy(test_x, test_y) = 0.869
accuracy(test_x, test_y) = 0.863
accuracy(test_x, test_y) = 0.871
accuracy(test_x, test_y) = 0.87
accuracy(test_x, test_y) = 0.875
accuracy(test_x, test_y) = 0.873


┌ Info: Epoch 6
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.866
accuracy(test_x, test_y) = 0.872
accuracy(test_x, test_y) = 0.87
accuracy(test_x, test_y) = 0.871
accuracy(test_x, test_y) = 0.87
accuracy(test_x, test_y) = 0.872
accuracy(test_x, test_y) = 0.872
accuracy(test_x, test_y) = 0.87
accuracy(test_x, test_y) = 0.879
accuracy(test_x, test_y) = 0.876
accuracy(test_x, test_y) = 0.882
accuracy(test_x, test_y) = 0.881


┌ Info: Epoch 7
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.88
accuracy(test_x, test_y) = 0.882
accuracy(test_x, test_y) = 0.882
accuracy(test_x, test_y) = 0.877
accuracy(test_x, test_y) = 0.875
accuracy(test_x, test_y) = 0.889
accuracy(test_x, test_y) = 0.884
accuracy(test_x, test_y) = 0.88
accuracy(test_x, test_y) = 0.885
accuracy(test_x, test_y) = 0.886
accuracy(test_x, test_y) = 0.888
accuracy(test_x, test_y) = 0.886


┌ Info: Epoch 8
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.885
accuracy(test_x, test_y) = 0.886
accuracy(test_x, test_y) = 0.89
accuracy(test_x, test_y) = 0.89
accuracy(test_x, test_y) = 0.886
accuracy(test_x, test_y) = 0.891
accuracy(test_x, test_y) = 0.897
accuracy(test_x, test_y) = 0.891
accuracy(test_x, test_y) = 0.895
accuracy(test_x, test_y) = 0.893
accuracy(test_x, test_y) = 0.9
accuracy(test_x, test_y) = 0.893


┌ Info: Epoch 9
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.895
accuracy(test_x, test_y) = 0.893
accuracy(test_x, test_y) = 0.898
accuracy(test_x, test_y) = 0.897
accuracy(test_x, test_y) = 0.894
accuracy(test_x, test_y) = 0.896
accuracy(test_x, test_y) = 0.901
accuracy(test_x, test_y) = 0.902
accuracy(test_x, test_y) = 0.902
accuracy(test_x, test_y) = 0.902
accuracy(test_x, test_y) = 0.904
accuracy(test_x, test_y) = 0.906
accuracy(test_x, test_y) = 0.902


┌ Info: Epoch 10
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.903
accuracy(test_x, test_y) = 0.908
accuracy(test_x, test_y) = 0.908
accuracy(test_x, test_y) = 0.903
accuracy(test_x, test_y) = 0.899
accuracy(test_x, test_y) = 0.908
accuracy(test_x, test_y) = 0.909
accuracy(test_x, test_y) = 0.903
accuracy(test_x, test_y) = 0.905
accuracy(test_x, test_y) = 0.911
accuracy(test_x, test_y) = 0.909
accuracy(test_x, test_y) = 0.908
