# Convolutional neural network

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Statistics: mean
# using CuArrays

## Load data

Classify MNIST digits with a convolutional network

In [2]:
# MNIST training set: `imgs` holds the images, `labels` the matching digits
# one-hot encoded over the classes 0:9 (a 10×60000 matrix, as printed below).
imgs = Flux.Data.MNIST.images()
labels = let raw_digits = Flux.Data.MNIST.labels(), classes = 0:9
    onehotbatch(raw_digits, classes)
end

10×60000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false   true  false  false  false  …  false  false  false  false  false
 false  false  false   true  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false   true  false  false  false
 false  false   true  false  false     false  false  false  false  false
  true  false  false  false  false  …  false  false   true  false  false
 false  false  false  false  false     false  false  false   true  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false      true  false  false  false   true
 false  false  false  false   true     false  false  false  false  false

## Preprocessing

In [3]:
# Partition the training images into mini-batches of 1,000.
# Each batch is an (images, one-hot labels) tuple. The image part is a
# 28×28×1×1000 Float64 array: a singleton channel axis, then the batch axis.
#
# `reduce(hcat, ...)` concatenates all images of a batch in a single allocation.
# `reshape` then restores the per-image 28×28 layout. Splatting 1,000 arrays
# into `cat(...; dims=4)` instead forces a huge varargs call for every batch.
train = [(reshape(reduce(hcat, float.(imgs[i])), size(first(imgs))..., 1, :),
          labels[:, i])
         for i in partition(eachindex(imgs), 1000)];
# Uncomment to move the batches to the GPU (requires `using CuArrays`):
# train = train |> gpu

In [4]:
# Peek at the image array of the first mini-batch (28×28×1×1000).
first(train)[1]

28×28×1×1000 Array{Float64,4}:
[:, :, 1, 1] =
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.498039  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.25098   0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 ⋮   

In [5]:
# Prepare the evaluation set: the first 1,000 MNIST test images and their labels.
# `test_x` uses the same 28×28×1×N layout as the training batches. It is built
# with `reduce(hcat, ...)` + `reshape` rather than splatting 1,000 arrays into `cat`.
# `test_y` one-hot encodes the digits over the classes 0:9.
# Append `|> gpu` to both to evaluate on the GPU (requires `using CuArrays`).
test_x = let test_imgs = float.(Flux.Data.MNIST.images(:test)[1:1000])
    reshape(reduce(hcat, test_imgs), size(first(test_imgs))..., 1, :)
end
test_y = onehotbatch(Flux.Data.MNIST.labels(:test)[1:1000], 0:9)

10×1000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false  false  false   true  false  …  false  false   true  false  false
 false  false   true  false  false     false  false  false  false  false
 false   true  false  false  false      true   true  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false   true     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
  true  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false   true  false
 false  false  false  false  false     false  false  false  false   true

## Model

In [6]:
# Two convolution + max-pool stages, then flatten and classify into 10 digits.
# Per-sample shapes are annotated; the batch axis is always the last one.
m = Chain(
    # Stage 1: (28, 28, 1) -conv-> (27, 27, 16) -pool-> (13, 13, 16)
    Conv((2, 2), 1 => 16, relu),
    MaxPool((2, 2)),

    # Stage 2: (13, 13, 16) -conv-> (12, 12, 8) -pool-> (6, 6, 8)
    Conv((2, 2), 16 => 8, relu),
    MaxPool((2, 2)),

    # Flatten each sample's (6, 6, 8) feature maps into a 288-vector.
    x -> reshape(x, :, size(x, 4)),

    # Classifier: 288 features -> 10 class probabilities.
    Dense(288, 10),
    softmax,
)  # append `|> gpu` to run on the GPU (requires `using CuArrays`)

Chain(Conv((2, 2), 1=>16, NNlib.relu), MaxPool((2, 2), pad = (0, 0), stride = (2, 2)), Conv((2, 2), 16=>8, NNlib.relu), MaxPool((2, 2), pad = (0, 0), stride = (2, 2)), getfield(Main, Symbol("##5#6"))(), Dense(288, 10), NNlib.softmax)

```julia
Conv(size, input=>output, activation)
```

* size: size of the convolution filter (kernel), e.g. `(2, 2)`
* input: number of input channels
* output: number of output channels (i.e. the number of filters)

```julia
Conv(k, ch, σ = identity;
     init = initn,
     stride = map(_->1,k),
     pad = map(_->0,k),
     dilation = map(_->1,k))

Conv((2, 2), 1=>16, relu; init=initn, stride=(1, 1), pad=(0, 0), dilation=(1, 1))
```

```julia
MaxPool(k; pad, stride)
```

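* k: size of the pooling window. As the printed model above shows (`pad = (0, 0), stride = (2, 2)` for `MaxPool((2, 2))`), `stride` defaults to `k` (non-overlapping windows) and `pad` defaults to zero.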

```julia
x -> reshape(x, :, size(x, 4))
```

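Flatten: reshape the (6, 6, 8, batch) feature maps into a 288×batch matrix so they can be fed to the `Dense(288, 10)` layer.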

## Loss function

In [7]:
"""
    loss(x, y)

Cross-entropy between the model's predicted class probabilities `m(x)` and the
one-hot targets `y`.
"""
function loss(x, y)
    ŷ = m(x)
    return crossentropy(ŷ, y)
end

loss (generic function with 1 method)

In [8]:
"""
    accuracy(x, y)

Fraction of samples whose most likely predicted class (`onecold` of `m(x)`)
matches the class encoded in the one-hot targets `y`.
"""
function accuracy(x, y)
    predicted = onecold(m(x))
    expected = onecold(y)
    return mean(predicted .== expected)
end

accuracy (generic function with 1 method)

## Optimizer

In [9]:
# Training callback: print (and return) the accuracy on the held-out test subset.
function evalcb()
    @show accuracy(test_x, test_y)
end

# ADAM optimizer with its default hyperparameters (printed below: η = 0.001, β = (0.9, 0.999)).
opt = ADAM()

ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}())

## Training

In [10]:
# Train for 10 epochs over all mini-batches.
# `throttle(evalcb, 10)` limits evaluation to at most one call per 10 seconds.
# It sits inside the repeated expression, so each epoch gets a fresh throttled callback.
@epochs 10 begin
    Flux.train!(loss, params(m), train, opt; cb = throttle(evalcb, 10))
end

┌ Info: Epoch 1
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


accuracy(test_x, test_y) = 0.146
accuracy(test_x, test_y) = 0.235
accuracy(test_x, test_y) = 0.301
accuracy(test_x, test_y) = 0.398
accuracy(test_x, test_y) = 0.524
accuracy(test_x, test_y) = 0.625
accuracy(test_x, test_y) = 0.692


┌ Info: Epoch 2
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


accuracy(test_x, test_y) = 0.721
accuracy(test_x, test_y) = 0.757
accuracy(test_x, test_y) = 0.8
accuracy(test_x, test_y) = 0.821
accuracy(test_x, test_y) = 0.836


┌ Info: Epoch 3
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


accuracy(test_x, test_y) = 0.84
accuracy(test_x, test_y) = 0.843
accuracy(test_x, test_y) = 0.857
accuracy(test_x, test_y) = 0.868
accuracy(test_x, test_y) = 0.875


┌ Info: Epoch 4
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


accuracy(test_x, test_y) = 0.877
accuracy(test_x, test_y) = 0.883
accuracy(test_x, test_y) = 0.896
accuracy(test_x, test_y) = 0.903
accuracy(test_x, test_y) = 0.904


┌ Info: Epoch 5
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


accuracy(test_x, test_y) = 0.908
accuracy(test_x, test_y) = 0.91
accuracy(test_x, test_y) = 0.917
accuracy(test_x, test_y) = 0.921
accuracy(test_x, test_y) = 0.925


┌ Info: Epoch 6
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


accuracy(test_x, test_y) = 0.923
accuracy(test_x, test_y) = 0.927
accuracy(test_x, test_y) = 0.931
accuracy(test_x, test_y) = 0.934
accuracy(test_x, test_y) = 0.937


┌ Info: Epoch 7
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


accuracy(test_x, test_y) = 0.937
accuracy(test_x, test_y) = 0.935
accuracy(test_x, test_y) = 0.938
accuracy(test_x, test_y) = 0.944
accuracy(test_x, test_y) = 0.947


┌ Info: Epoch 8
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


accuracy(test_x, test_y) = 0.948
accuracy(test_x, test_y) = 0.943
accuracy(test_x, test_y) = 0.948
accuracy(test_x, test_y) = 0.953
accuracy(test_x, test_y) = 0.952


┌ Info: Epoch 9
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


accuracy(test_x, test_y) = 0.955
accuracy(test_x, test_y) = 0.959
accuracy(test_x, test_y) = 0.952
accuracy(test_x, test_y) = 0.959
accuracy(test_x, test_y) = 0.954


┌ Info: Epoch 10
└ @ Main /home/pika/.julia/packages/Flux/T3PhK/src/optimise/train.jl:103


accuracy(test_x, test_y) = 0.955
accuracy(test_x, test_y) = 0.961
accuracy(test_x, test_y) = 0.956
accuracy(test_x, test_y) = 0.962
accuracy(test_x, test_y) = 0.96
