# Convolutional neural network

In [1]:
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, onecold, throttle, logitcrossentropy
using Base.Iterators: repeated, partition
using Statistics: mean
using CuArrays

## Load data and preprocessing

Classify MNIST digits with a convolutional network

In [2]:
batch_size = 128

128

In [3]:
function minibatch(X, Y, idxs)
    X_batch = Array{Float32}(undef, size(X[1])..., 1, length(idxs))
    for i in 1:length(idxs)
        X_batch[:, :, :, i] = Float32.(X[idxs[i]])
    end
    Y_batch = onehotbatch(Y[idxs], 0:9)
    return (X_batch, Y_batch)
end

minibatch (generic function with 1 method)

In [4]:
train_labels = MNIST.labels()
train_imgs = MNIST.images()
mb_idxs = partition(1:length(train_imgs), batch_size)
train = [minibatch(train_imgs, train_labels, i) for i in mb_idxs];

In [5]:
test_imgs = MNIST.images(:test)
test_labels = MNIST.labels(:test)
test = minibatch(test_imgs, test_labels, 1:length(test_imgs));

## Model

In [6]:
model = Chain(
    Conv((3, 3), 1=>16, pad=(2,2), relu),  # (1, 1, 1) -> (3, 3, 16)
    MaxPool((2, 2)),                       # (3, 3, 16) -> (1, 1, 16)
    Conv((3, 3), 16=>32, pad=(2,2), relu), # (1, 1, 16) -> (3, 3, 32)
    MaxPool((2, 2)),                       # (3, 3, 32) -> (1, 1, 32)
    Conv((3, 3), 32=>32, pad=(2,2), relu), # (1, 1, 32) -> (3, 3, 32)
    MaxPool((2, 2)),                       # (3, 3, 32) -> (1, 1, 32)
    flatten,                               # (1, 1, 32) -> 32
    Dense(800, 10), softmax)

Chain(Conv((3, 3), 1=>16, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 16=>32, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 32=>32, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), flatten, Dense(800, 10), softmax)

In [7]:
model(train[1][1])

10×128 Array{Float32,2}:
 0.117802   0.113581   0.106564   …  0.111015   0.113462   0.114249
 0.11308    0.106354   0.110325      0.113511   0.112378   0.108182
 0.0971398  0.108006   0.103689      0.101266   0.102246   0.0987219
 0.0823925  0.0793182  0.0811307     0.0868874  0.0777472  0.0882899
 0.111112   0.119631   0.111308      0.106583   0.119035   0.104185
 0.0824723  0.0779214  0.0888064  …  0.0865868  0.085736   0.0883279
 0.0906533  0.0859703  0.0882584     0.0889074  0.087638   0.0878237
 0.109116   0.114446   0.103246      0.107745   0.106808   0.105774
 0.100434   0.0994196  0.105651      0.0994514  0.0932377  0.0983425
 0.0957983  0.095353   0.101022      0.0980463  0.101712   0.106103

In [8]:
train = gpu.(train)
test = gpu.(test)
model = gpu(model)

Chain(Conv((3, 3), 1=>16, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 16=>32, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 32=>32, relu), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), flatten, Dense(800, 10), softmax)

```julia
Conv(size, input=>output, activation)
```

* size: size of filter
* input: input channel size
* output: output channel size

```julia
Conv(k, ch, σ = identity;
     init = initn,
     stride = map(_->1,k),
     pad = map(_->0,k),
     dilation = map(_->1,k))

Conv((2, 2), 1=>16, relu; init=initn, stride=(1, 1), pad=(0, 0), dilation=(1, 1))
```

```julia
MaxPool(k; pad, stride)
```

* k: size of filter

#### 拉直

```julia
flatten
```

## Loss function

In [9]:
loss(x, y) = logitcrossentropy(model(x), y)

loss (generic function with 1 method)

In [10]:
accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))

accuracy (generic function with 1 method)

## Optimizer

In [11]:
evalcb() = @show(accuracy(test[1], test[2]))
opt = ADAM(0.002)

ADAM(0.002, (0.9, 0.999), IdDict{Any,Any}())

In [12]:
loss(train[1][1], train[1][2])

2.3012743f0

## Training

In [13]:
@epochs 10 Flux.train!(loss, params(model), train, opt, cb=throttle(evalcb, 10))

┌ Info: Epoch 1
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121
└ @ GPUArrays /home/yuehhua/.julia/packages/GPUArrays/JqOUg/src/host/indexing.jl:43


accuracy(test[1], test[2]) = 0.2056


┌ Info: Epoch 2
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


accuracy(test[1], test[2]) = 0.878


┌ Info: Epoch 3
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


accuracy(test[1], test[2]) = 0.8814


┌ Info: Epoch 4
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


accuracy(test[1], test[2]) = 0.8875


┌ Info: Epoch 5
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


accuracy(test[1], test[2]) = 0.8905


┌ Info: Epoch 6
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


accuracy(test[1], test[2]) = 0.8909


┌ Info: Epoch 7
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


accuracy(test[1], test[2]) = 0.8911


┌ Info: Epoch 8
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


accuracy(test[1], test[2]) = 0.8934


┌ Info: Epoch 9
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


accuracy(test[1], test[2]) = 0.8914


┌ Info: Epoch 10
└ @ Main /home/yuehhua/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121


accuracy(test[1], test[2]) = 0.9852


In [14]:
accuracy(test[1], test[2])

0.9864