# Convolutional neural network

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Statistics: mean
using CuArrays

┌ Info: Precompiling Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
└ @ Base loading.jl:1273


## Load and preprocess data

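Classify MNIST handwritten digits (0–9) with a convolutional network. Each minibatch packs its images into a `Float32` array of size (width, height, 1, batch) and one-hot encodes its labels over the classes 0:9.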

In [2]:
# Number of images packed into each training minibatch (a power of two).
batch_size = 2^7

128

In [3]:
"""
    minibatch(X, Y, idxs)

Pack the samples `X[idxs]` and their labels `Y[idxs]` into one batch.

Returns a tuple `(batch, labels)`:
- `batch` is a `Float32` array of size `(size(X[1])..., 1, length(idxs))`. Each
  sample gets a singleton channel dimension, and samples are stacked along the
  last axis in the order given by `idxs`.
- `labels` is `onehotbatch(Y[idxs], 0:9)`, i.e. the labels one-hot encoded over
  the classes 0 through 9.
"""
function minibatch(X, Y, idxs)
    batch = Array{Float32}(undef, size(X[1])..., 1, length(idxs))
    # `slot` is the 1-based position inside the batch; `idx` is the sample index.
    for (slot, idx) in enumerate(idxs)
        batch[:, :, :, slot] = Float32.(X[idx])
    end
    labels = onehotbatch(Y[idxs], 0:9)
    return batch, labels
end

minibatch (generic function with 1 method)

In [4]:
# Load the MNIST training split, cut its indices into chunks of `batch_size`,
# and pre-build one (images, one-hot labels) minibatch per chunk.
train_labels = MNIST.labels()
train_imgs = MNIST.images()
mb_idxs = partition(eachindex(train_imgs), batch_size)
train = map(chunk -> minibatch(train_imgs, train_labels, chunk), mb_idxs);

In [5]:
# Load the MNIST test split and pack the whole split into a single batch.
test_imgs = MNIST.images(:test)
test_labels = MNIST.labels(:test)
test = minibatch(test_imgs, test_labels, eachindex(test_imgs));

## Model

In [6]:
# Input batches are (width, height, channels, batch) Float32 arrays, as built by
# `minibatch`; for MNIST each sample is (28, 28, 1).
# Per-sample shapes are annotated below. Each 3×3 conv with `pad=(2,2)` grows
# every spatial dimension by 2. Each 2×2 MaxPool (stride 2, no padding) halves
# it, rounding down. That is why the flattened feature vector has 5*5*32 = 800
# entries, matching `Dense(800, 10)`.
model = Chain(
    Conv((3, 3), 1=>16, pad=(2,2), relu),  # (28, 28, 1)  -> (30, 30, 16)
    MaxPool((2, 2)),                       # (30, 30, 16) -> (15, 15, 16)
    Conv((3, 3), 16=>32, pad=(2,2), relu), # (15, 15, 16) -> (17, 17, 32)
    MaxPool((2, 2)),                       # (17, 17, 32) -> (8, 8, 32)
    Conv((3, 3), 32=>32, pad=(2,2), relu), # (8, 8, 32)   -> (10, 10, 32)
    MaxPool((2, 2)),                       # (10, 10, 32) -> (5, 5, 32)
    x -> reshape(x, :, size(x, 4)),        # (5, 5, 32)   -> 800 (one column per sample)
    Dense(800, 10), softmax)               # 800 -> 10 class probabilities

Chain(Conv((3, 3), 1=>16, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 16=>32, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 32=>32, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), #5, Dense(800, 10), softmax)

In [7]:
# Sanity check before training: forward the images of the first training batch.
model(first(first(train)))

10×128 Array{Float32,2}:
 0.112337   0.113622   0.112011   …  0.111775   0.114617   0.107833 
 0.0956268  0.0948047  0.0968034     0.0972567  0.0935568  0.0957295
 0.0981474  0.0991135  0.0975963     0.0988294  0.0962602  0.100368 
 0.107728   0.113859   0.119287      0.109618   0.115767   0.109394 
 0.0933075  0.0863181  0.0889688     0.0919112  0.0839984  0.0907601
 0.107482   0.113453   0.110401   …  0.105376   0.112065   0.112879 
 0.0966003  0.0948287  0.0940934     0.097852   0.0972406  0.0996174
 0.093528   0.0960599  0.0925315     0.0947581  0.0927276  0.0976891
 0.101795   0.0911661  0.0928497     0.0938984  0.0947846  0.0940598
 0.0934481  0.0967738  0.0954577     0.0987258  0.098983   0.0916699

In [8]:
# Move every training batch, both halves of the test tuple, and the model to the GPU.
train = map(gpu, train)
test = map(gpu, test)
model = model |> gpu

Chain(Conv((3, 3), 1=>16, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 16=>32, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 32=>32, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), #5, Dense(800, 10), softmax)

```julia
Conv(size, input=>output, activation)
```

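* size: size of the filter (kernel), e.g. `(3, 3)`
* input: number of input channels
* output: number of output channels (i.e. the number of filters)
* activation: activation function applied element-wise to the output, e.g. `relu`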

```julia
Conv(k, ch, σ = identity;
     init = initn,
     stride = map(_->1,k),
     pad = map(_->0,k),
     dilation = map(_->1,k))

Conv((2, 2), 1=>16, relu; init=initn, stride=(1, 1), pad=(0, 0), dilation=(1, 1))
```

```julia
MaxPool(k; pad, stride)
```

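* k: size of the pooling window, e.g. `(2, 2)`. By default `stride = k` and `pad = 0`, so a 2×2 pool halves each spatial dimension (rounding down).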

```julia
x -> reshape(x, :, size(x, 4))
```

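Flatten: reshape the (width, height, channels, batch) feature maps into a matrix with one column per sample, so the result can be fed to the `Dense` layer.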

## Loss function

In [9]:
# Cross-entropy between the model's softmax output for `x` and the one-hot targets `y`.
function loss(x, y)
    ŷ = model(x)
    return crossentropy(ŷ, y)
end

loss (generic function with 1 method)

In [10]:
# Fraction of samples whose predicted class (`onecold` of the model output)
# equals the class encoded in the one-hot targets `y`.
function accuracy(x, y)
    predicted = onecold(model(x))
    actual = onecold(y)
    return mean(predicted .== actual)
end

accuracy (generic function with 1 method)

## Optimizer

In [11]:
# Training callback: print (via `@show`, which echoes the expression text) and
# return the current accuracy on the test batch.
function evalcb()
    return @show accuracy(test[1], test[2])
end
# ADAM optimiser with learning rate 0.002.
opt = ADAM(2e-3)

ADAM(0.002, (0.9, 0.999), IdDict{Any,Any}())

In [12]:
# Loss of the not-yet-trained model on the first (images, labels) training batch.
loss(train[1]...)

2.3097384f0

## Training

In [13]:
# Run 10 epochs over all training batches (`@epochs` logs each epoch). The
# callback is throttled so test accuracy is reported at most once per
# 10-second window.
@epochs 10 Flux.train!(loss, params(model), train, opt;
                       cb = throttle(evalcb, 10))

┌ Info: Epoch 1
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99
└ @ GPUArrays /home/yuehhua/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:16


accuracy(test[1], test[2]) = 0.1017


┌ Info: Epoch 2
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


accuracy(test[1], test[2]) = 0.9745


┌ Info: Epoch 3
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


accuracy(test[1], test[2]) = 0.9842


┌ Info: Epoch 4
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


accuracy(test[1], test[2]) = 0.9862


┌ Info: Epoch 5
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


accuracy(test[1], test[2]) = 0.9883


┌ Info: Epoch 6
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


accuracy(test[1], test[2]) = 0.9864


┌ Info: Epoch 7
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


accuracy(test[1], test[2]) = 0.9874


┌ Info: Epoch 8
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


accuracy(test[1], test[2]) = 0.9886


┌ Info: Epoch 9
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


accuracy(test[1], test[2]) = 0.9866


┌ Info: Epoch 10
└ @ Main /home/yuehhua/.julia/packages/Flux/2i5P1/src/optimise/train.jl:99


accuracy(test[1], test[2]) = 0.9884
