# Convolutional neural network

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Statistics: mean
# using CuArrays

## Load data

Classify MNIST digits with a convolutional network

In [2]:
# Fetch the MNIST training images and one-hot encode their labels over the digits 0–9
# (one column per sample, one row per digit class).
imgs = MNIST.images()
labels = onehotbatch(MNIST.labels(), 0:9)

10×60000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false   true  false  false  false  …  false  false  false  false  false
 false  false  false   true  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false   true  false  false  false
 false  false   true  false  false     false  false  false  false  false
  true  false  false  false  false  …  false  false   true  false  false
 false  false  false  false  false     false  false  false   true  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false      true  false  false  false   true
 false  false  false  false   true     false  false  false  false  false

## Preprocessing

In [3]:
# Partition the training set into mini-batches of 1,000 samples.
# Each batch is a tuple (x, y): x stacks the batch's images along a 4th dimension
# (28×28×1×1000 for MNIST) and y holds the matching one-hot label columns.
# The index range is derived from `imgs` instead of hard-coding the dataset size,
# so the batching stays correct if the number of loaded images changes.
train = [(cat(float.(imgs[i])...; dims=4), labels[:, i])
         for i in partition(1:length(imgs), 1000)];
# train = gpu.(train)

In [4]:
# Show the image tensor (the first tuple element) of the first training batch.
first(first(train))

28×28×1×1000 Array{Float64,4}:
[:, :, 1, 1] =
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.498039  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.25098   0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 ⋮   

In [5]:
# Build the evaluation set from the first 1,000 MNIST test samples:
# images stacked along dimension 4, labels one-hot encoded over the digits 0–9.
test_x = cat(map(float, MNIST.images(:test)[1:1000])...; dims=4)  # |> gpu
test_y = onehotbatch(MNIST.labels(:test)[1:1000], 0:9)  # |> gpu

10×1000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false  false  false   true  false  …  false  false   true  false  false
 false  false   true  false  false     false  false  false  false  false
 false   true  false  false  false      true   true  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false   true     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
  true  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false   true  false
 false  false  false  false  false     false  false  false  false   true

## Model

In [6]:
# Convolutional classifier: two Conv + MaxPool stages followed by a dense layer
# and a softmax over the 10 digit classes.
# The shape comments below are per sample (width, height, channels); the batch
# dimension (4th axis of the input) is carried through unchanged.
m = Chain(
    Conv((2, 2), 1=>16, relu),  # (28, 28, 1) -> (27, 27, 16)
    MaxPool((2, 2)),  # (27, 27, 16) -> (13, 13, 16)
    Conv((2, 2), 16=>8, relu),  # (13, 13, 16) -> (12, 12, 8)
    MaxPool((2, 2)),  # (12, 12, 8) -> (6, 6, 8)
    x -> reshape(x, :, size(x, 4)),  # flatten: (6, 6, 8) -> 288 (= 6*6*8) per sample, batch size kept
    Dense(288, 10), softmax)# |> gpu

Chain(Conv((2, 2), 1=>16, NNlib.relu), MaxPool((2, 2), pad = (0, 0), stride = (2, 2)), Conv((2, 2), 16=>8, NNlib.relu), MaxPool((2, 2), pad = (0, 0), stride = (2, 2)), getfield(Main, Symbol("##5#6"))(), Dense(288, 10), NNlib.softmax)

```julia
Conv(size, input=>output, activation)
```

* size: size of the convolution filter (kernel), e.g. `(2, 2)`
* input: number of input channels
* output: number of output channels (i.e. the number of filters)

```julia
Conv(k, ch, σ = identity;
     init = initn,
     stride = map(_->1,k),
     pad = map(_->0,k),
     dilation = map(_->1,k))

Conv((2, 2), 1=>16, relu; init=initn, stride=(1, 1), pad=(0, 0), dilation=(1, 1))
```

```julia
MaxPool(k; pad, stride)
```

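* k: size of the pooling window, e.g. `(2, 2)` (as the printed model above shows, the stride defaults to the window size)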

```julia
x -> reshape(x, :, size(x, 4))
```

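拉直 (flatten): reshape each sample's feature maps into a single vector while keeping the batch dimension, so that the result can be fed to the `Dense` layer.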

## Loss function

In [7]:
# Cross-entropy between the model's predictions for the batch `x`
# and the one-hot targets `y`.
function loss(x, y)
    predicted = m(x)
    return crossentropy(predicted, y)
end

loss (generic function with 1 method)

In [8]:
# Fraction of samples in `x` whose predicted class (the largest model output)
# matches the class encoded by the one-hot targets `y`.
function accuracy(x, y)
    predicted = onecold(m(x))
    actual = onecold(y)
    return mean(predicted .== actual)
end

accuracy (generic function with 1 method)

## Optimizer

In [9]:
# Callback that prints the test-set accuracy, throttled with a timeout of 10.
evalcb = throttle(10) do
    @show(accuracy(test_x, test_y))
end
# ADAM optimiser over all trainable parameters of the model.
opt = m |> params |> ADAM

#43 (generic function with 1 method)

## Training

In [10]:
# Run 5 passes over the training batches, invoking `evalcb` as the training callback.
@epochs 5 Flux.train!(loss, train, opt; cb=evalcb)

┌ Info: Epoch 1
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.194
accuracy(test_x, test_y) = 0.127
accuracy(test_x, test_y) = 0.173
accuracy(test_x, test_y) = 0.222
accuracy(test_x, test_y) = 0.345
accuracy(test_x, test_y) = 0.442
accuracy(test_x, test_y) = 0.533
accuracy(test_x, test_y) = 0.607
accuracy(test_x, test_y) = 0.654


┌ Info: Epoch 2
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.668
accuracy(test_x, test_y) = 0.682
accuracy(test_x, test_y) = 0.706
accuracy(test_x, test_y) = 0.749
accuracy(test_x, test_y) = 0.785
accuracy(test_x, test_y) = 0.784
accuracy(test_x, test_y) = 0.808
accuracy(test_x, test_y) = 0.822
accuracy(test_x, test_y) = 0.824


┌ Info: Epoch 3
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.833
accuracy(test_x, test_y) = 0.84
accuracy(test_x, test_y) = 0.844
accuracy(test_x, test_y) = 0.858
accuracy(test_x, test_y) = 0.871
accuracy(test_x, test_y) = 0.876
accuracy(test_x, test_y) = 0.883


┌ Info: Epoch 4
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.885
accuracy(test_x, test_y) = 0.891
accuracy(test_x, test_y) = 0.894
accuracy(test_x, test_y) = 0.902
accuracy(test_x, test_y) = 0.905
accuracy(test_x, test_y) = 0.908
accuracy(test_x, test_y) = 0.909


┌ Info: Epoch 5
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.911
accuracy(test_x, test_y) = 0.915
accuracy(test_x, test_y) = 0.913
accuracy(test_x, test_y) = 0.918
accuracy(test_x, test_y) = 0.921
accuracy(test_x, test_y) = 0.927
accuracy(test_x, test_y) = 0.931
