# Convolutional neural network

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, argmax, crossentropy, throttle
using Base.Iterators: repeated, partition
# using CuArrays

## Load data

Classify MNIST digits with a convolutional network

In [2]:
# Fetch the MNIST training images, then one-hot encode their labels
# against the ten digit classes 0–9 (one column per sample).
imgs = MNIST.images()
labels = onehotbatch(MNIST.labels(), 0:9)

10×60000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false   true  false  false  false  …  false  false  false  false  false
 false  false  false   true  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false   true  false  false  false
 false  false   true  false  false     false  false  false  false  false
  true  false  false  false  false  …  false  false   true  false  false
 false  false  false  false  false     false  false  false   true  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false      true  false  false  false   true
 false  false  false  false   true     false  false  false  false  false

## Preprocessing

In [3]:
# Split the 60,000 training samples into mini-batches of 1,000.
# Each batch is an (images, labels) tuple: the images are converted to
# floating point and stacked along the 4th dimension, and the labels are
# the matching columns of the one-hot label matrix.
train = map(partition(1:60_000, 1000)) do idx
    (cat(4, float.(imgs[idx])...), labels[:, idx])
end;
# train = gpu.(train)

In [4]:
# Show the image tensor (first tuple element) of the first mini-batch.
first(train)[1]

28×28×1×1000 Array{Float64,4}:
[:, :, 1, 1] =
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.498039  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.25098   0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0       …  0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0       0.0          0.0       0.0  0.0  0.0  0.0
 ⋮   

In [5]:
# Build a small evaluation set from the first 1,000 test samples:
# images stacked along the 4th dimension, labels one-hot encoded over 0–9.
test_x = cat(4, float.(MNIST.images(:test)[1:1000])...)  # |> gpu
test_y = onehotbatch(MNIST.labels(:test)[1:1000], 0:9)  # |> gpu

10×1000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false  false  false   true  false  …  false  false   true  false  false
 false  false   true  false  false     false  false  false  false  false
 false   true  false  false  false      true   true  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false   true     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
  true  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false   true  false
 false  false  false  false  false     false  false  false  false   true

## Model

In [6]:
# Small CNN for 28×28 single-channel MNIST images.
# Shapes in the comments are per sample (width, height, channels); the batch
# is carried in the 4th dimension. `maxpool(x, (2, 2))` is called without an
# explicit stride, so it strides by the window size and each pooling step
# halves the spatial dimensions (rounding down). That is why the flattened
# feature vector has 6 * 6 * 8 = 288 entries, matching `Dense(288, 10)`.
m = Chain(
    Conv((2, 2), 1=>16, relu),       # (28, 28, 1)  -> (27, 27, 16)
    x -> maxpool(x, (2, 2)),         # (27, 27, 16) -> (13, 13, 16)
    Conv((2, 2), 16=>8, relu),       # (13, 13, 16) -> (12, 12, 8)
    x -> maxpool(x, (2, 2)),         # (12, 12, 8)  -> (6, 6, 8)
    x -> reshape(x, :, size(x, 4)),  # (6, 6, 8)    -> 288 features per sample
    Dense(288, 10), softmax)# |> gpu

Chain(Conv((2, 2), 1=>16, NNlib.relu), #3, Conv((2, 2), 16=>8, NNlib.relu), #4, #5, Dense(288, 10), NNlib.softmax)

```julia
Conv(size, input=>output, activation)
```

* size: size of the convolution filter (kernel), e.g. `(2, 2)`
* input: number of input channels
* output: number of output channels

```julia
Conv(k, ch, σ = identity;
     init = initn,
     stride = map(_->1,k),
     pad = map(_->0,k),
     dilation = map(_->1,k))

Conv((2, 2), 1=>16, relu; init=initn, stride=(1, 1), pad=(0, 0), dilation=(1, 1))
```

```julia
maxpool(x, k; pad, stride)
```

* x: input data
* k: size of the pooling window, e.g. `(2, 2)`

```julia
x -> reshape(x, :, size(x, 4))
```

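Flatten (拉直): reshape each sample's feature maps into a single vector, keeping the batch dimension (dimension 4) as the columns.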

## Loss function

In [7]:
# Training objective: cross-entropy between the model output for `x`
# and the one-hot targets `y`.
function loss(x, y)
    predicted = m(x)
    return crossentropy(predicted, y)
end

loss (generic function with 1 method)

In [8]:
# Fraction of samples in `x` whose predicted class (the `argmax` of the
# model output) matches the class encoded in `y`.
function accuracy(x, y)
    predicted = argmax(m(x))
    expected = argmax(y)
    return mean(predicted .== expected)
end

accuracy (generic function with 1 method)

## Optimizer

In [9]:
# Callback that prints the accuracy on the held-out test batch. `throttle`
# is given a timeout of 10 so the evaluation does not run on every
# training step.
evalcb = throttle(10) do
    @show(accuracy(test_x, test_y))
end
# ADAM optimiser built over the parameters collected from the model.
opt = ADAM(params(m))

(::#80) (generic function with 1 method)

## Training

In [10]:
# Run five passes over the training batches; `evalcb` is passed as the
# training callback and reports test accuracy along the way.
@epochs 5 Flux.train!(loss, train, opt; cb = evalcb)

[1m[36mINFO: [39m[22m[36mEpoch 1
[39m

accuracy(test_x, test_y) = 0.2
accuracy(test_x, test_y) = 0.133
accuracy(test_x, test_y) = 0.242
accuracy(test_x, test_y) = 0.317
accuracy(test_x, test_y) = 0.366
accuracy(test_x, test_y) = 0.498
accuracy(test_x, test_y) = 0.613
accuracy(test_x, test_y) = 0.664


[1m[36mINFO: [39m[22m[36mEpoch 2
[39m

accuracy(test_x, test_y) = 0.686
accuracy(test_x, test_y) = 0.708
accuracy(test_x, test_y) = 0.743
accuracy(test_x, test_y) = 0.781
accuracy(test_x, test_y) = 0.778
accuracy(test_x, test_y) = 0.799
accuracy(test_x, test_y) = 0.809
accuracy(test_x, test_y) = 0.823


[1m[36mINFO: [39m[22m[36mEpoch 3
[39m

accuracy(test_x, test_y) = 0.818
accuracy(test_x, test_y) = 0.828
accuracy(test_x, test_y) = 0.839
accuracy(test_x, test_y) = 0.843
accuracy(test_x, test_y) = 0.849
accuracy(test_x, test_y) = 0.85
accuracy(test_x, test_y) = 0.863


[1m[36mINFO: [39m[22m[36mEpoch 4
[39m

accuracy(test_x, test_y) = 0.859
accuracy(test_x, test_y) = 0.866
accuracy(test_x, test_y) = 0.865
accuracy(test_x, test_y) = 0.867
accuracy(test_x, test_y) = 0.868
accuracy(test_x, test_y) = 0.882
accuracy(test_x, test_y) = 0.877
accuracy(test_x, test_y) = 0.886


[1m[36mINFO: [39m[22m[36mEpoch 5
[39m

accuracy(test_x, test_y) = 0.887
accuracy(test_x, test_y) = 0.888
accuracy(test_x, test_y) = 0.893
accuracy(test_x, test_y) = 0.892
accuracy(test_x, test_y) = 0.9
accuracy(test_x, test_y) = 0.9
accuracy(test_x, test_y) = 0.903
