# Convolutional neural network

In [1]:
using MLDatasets
using Flux
using Flux: @epochs, onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Statistics: mean
# using CuArrays

┌ Info: Precompiling Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
└ @ Base loading.jl:1192


## Load data

Classify CIFAR10 with a convolutional network

In [2]:
train_x, train_y = CIFAR10.traindata()
test_x,  test_y  = CIFAR10.testdata()

(FixedPointNumbers.Normed{UInt8,8}[0.62 0.596 … 0.239 0.212; 0.624 0.592 … 0.192 0.22; … ; 0.494 0.49 … 0.114 0.133; 0.455 0.467 … 0.078 0.082]

FixedPointNumbers.Normed{UInt8,8}[0.439 0.439 … 0.455 0.42; 0.435 0.431 … 0.4 0.412; … ; 0.357 0.357 … 0.322 0.329; 0.333 0.345 … 0.251 0.263]

FixedPointNumbers.Normed{UInt8,8}[0.192 0.2 … 0.659 0.627; 0.184 0.157 … 0.58 0.584; … ; 0.141 0.125 … 0.494 0.506; 0.129 0.133 … 0.42 0.431]

FixedPointNumbers.Normed{UInt8,8}[0.922 0.933 … 0.322 0.333; 0.906 0.922 … 0.18 0.243; … ; 0.914 0.925 … 0.725 0.706; 0.91 0.922 … 0.733 0.729]

FixedPointNumbers.Normed{UInt8,8}[0.922 0.933 … 0.376 0.396; 0.906 0.922 … 0.224 0.294; … ; 0.914 0.925 … 0.784 0.765; 0.91 0.922 … 0.792 0.784]

FixedPointNumbers.Normed{UInt8,8}[0.922 0.933 … 0.322 0.325; 0.906 0.922 … 0.141 0.188; … ; 0.914 0.925 … 0.769 0.749; 0.91 0.922 … 0.784 0.78]

FixedPointNumbers.Normed{UInt8,8}[0.62 0.667 … 0.09 0.11; 0.62 0.675 … 0.106 0.118; … ; 0.929 0.965 … 0.016 0.016; 0.933 0.965 … 0.0

In [3]:
train_x = Float64.(train_x / 255)
train_y = onehotbatch(train_y, 0:9)
test_x = Float64.(test_x / 255)
test_y = onehotbatch(test_y, 0:9)

10×10000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false  false  false   true  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false   true  false
 false  false  false  false  false     false  false  false  false  false
  true  false  false  false  false     false   true  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false  …  false  false   true  false  false
 false  false  false  false   true     false  false  false  false  false
 false  false  false  false  false     false  false  false  false   true
 false   true   true  false  false      true  false  false  false  false
 false  false  false  false  false     false  false  false  false  false

## Preprocessing

In [4]:
# Partition into batches of size 1,000
train = [(train_x[:, :, :, i:(i+999)], train_y[:, i:(i+999)]) for i = 1:1000:50_000];
# train = gpu.(train)

In [5]:
# Prepare test set (first 1,000 images)
test_x = test_x[:, :, :, 1:1000]# |> gpu
test_y = test_y[:, 1:1000]# |> gpu

10×1000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 false  false  false   true  false  …  false  false  false  false  false
 false  false  false  false  false     false  false   true  false  false
 false  false  false  false  false     false  false  false  false  false
  true  false  false  false  false      true  false  false   true  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false   true     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false   true   true  false  false     false   true  false  false   true
 false  false  false  false  false     false  false  false  false  false

## Model

In [6]:
m = Chain(
    Conv((3, 3), 3=>32, relu),  # (32, 32, 3) -> (30, 30, 32)
    Conv((3, 3), 32=>64, relu),  # (30, 30, 32) -> (28, 28, 64)
    MaxPool((2, 2); stride=2),  # (28, 28, 64) -> (14, 14, 64)
    BatchNorm(64),
#     Dropout(0.25),
    Conv((3, 3), 64=>128, relu),  # (14, 14, 64) -> (12, 12, 128)
    MaxPool((2, 2); stride=2),  # (12, 12, 128) -> (6, 6, 128)
    Conv((2, 2), 128=>128, relu),  # (6, 6, 128) -> (5, 5, 128)
    MaxPool((2, 2); stride=2),  # (5, 5, 128) -> (2, 2, 128)
    BatchNorm(128),
#     Dropout(0.25),
    x -> reshape(x, :, size(x, 4)),  # (2, 2, 128) -> 512
    Dense(512, 1500, relu),
    BatchNorm(1500),
#     Dropout(0.5),
    Dense(1500, 10), softmax)# |> gpu

Chain(Conv((3, 3), 3=>32, NNlib.relu), Conv((3, 3), 32=>64, NNlib.relu), MaxPool((2, 2), pad = (0, 0), stride = (2, 2)), BatchNorm(64), Conv((3, 3), 64=>128, NNlib.relu), MaxPool((2, 2), pad = (0, 0), stride = (2, 2)), Conv((2, 2), 128=>128, NNlib.relu), MaxPool((2, 2), pad = (0, 0), stride = (2, 2)), BatchNorm(128), getfield(Main, Symbol("##5#6"))(), Dense(512, 1500, NNlib.relu), BatchNorm(1500), Dense(1500, 10), NNlib.softmax)

In [7]:
m(train[1][1])  # test if model works

Tracked 10×1000 Array{Float64,2}:
 0.0568679  0.0103461   0.00703472  …  0.0392362   0.130936     0.244954  
 0.0270128  0.0046239   0.022868       0.111868    0.00116973   0.028365  
 0.0213119  0.0113251   0.100495       0.00824322  0.00134271   0.174683  
 0.0612366  0.191805    0.565134       0.0136359   0.00539958   0.0704018 
 0.0151398  0.00521066  0.00566327     0.526132    0.000660848  0.0676    
 0.0280726  0.0688699   0.0108432   …  0.0250514   0.569193     0.121266  
 0.0611917  0.179117    0.0073833      0.0530104   0.00380535   0.0566555 
 0.0110749  0.171776    0.098394       0.105363    0.0076263    0.00394231
 0.699222   0.00477428  0.175787       0.110197    0.00107982   0.174497  
 0.0188699  0.352152    0.00639764     0.00726344  0.278786     0.0576363 

## Loss function

In [8]:
loss(x, y) = crossentropy(m(x), y)

loss (generic function with 1 method)

In [9]:
accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))

accuracy (generic function with 1 method)

## Optimizer

In [10]:
evalcb = throttle(() -> @show(accuracy(test_x, test_y)), 10)
opt = ADAM(params(m))

#43 (generic function with 1 method)

## Training

In [11]:
@epochs 10 Flux.train!(loss, train, opt, cb=evalcb)

┌ Info: Epoch 1
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.203
accuracy(test_x, test_y) = 0.238
accuracy(test_x, test_y) = 0.29
accuracy(test_x, test_y) = 0.324
accuracy(test_x, test_y) = 0.369
accuracy(test_x, test_y) = 0.383
accuracy(test_x, test_y) = 0.385
accuracy(test_x, test_y) = 0.393
accuracy(test_x, test_y) = 0.426
accuracy(test_x, test_y) = 0.416
accuracy(test_x, test_y) = 0.446
accuracy(test_x, test_y) = 0.461
accuracy(test_x, test_y) = 0.472
accuracy(test_x, test_y) = 0.479
accuracy(test_x, test_y) = 0.462
accuracy(test_x, test_y) = 0.481
accuracy(test_x, test_y) = 0.479


┌ Info: Epoch 2
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.512
accuracy(test_x, test_y) = 0.494
accuracy(test_x, test_y) = 0.519
accuracy(test_x, test_y) = 0.482
accuracy(test_x, test_y) = 0.504
accuracy(test_x, test_y) = 0.523
accuracy(test_x, test_y) = 0.514
accuracy(test_x, test_y) = 0.501
accuracy(test_x, test_y) = 0.53
accuracy(test_x, test_y) = 0.534
accuracy(test_x, test_y) = 0.496
accuracy(test_x, test_y) = 0.53
accuracy(test_x, test_y) = 0.528
accuracy(test_x, test_y) = 0.526
accuracy(test_x, test_y) = 0.55
accuracy(test_x, test_y) = 0.542
accuracy(test_x, test_y) = 0.55


┌ Info: Epoch 3
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.566
accuracy(test_x, test_y) = 0.559
accuracy(test_x, test_y) = 0.563
accuracy(test_x, test_y) = 0.544
accuracy(test_x, test_y) = 0.534
accuracy(test_x, test_y) = 0.523
accuracy(test_x, test_y) = 0.548
accuracy(test_x, test_y) = 0.544
accuracy(test_x, test_y) = 0.567
accuracy(test_x, test_y) = 0.568
accuracy(test_x, test_y) = 0.568
accuracy(test_x, test_y) = 0.595
accuracy(test_x, test_y) = 0.551
accuracy(test_x, test_y) = 0.566
accuracy(test_x, test_y) = 0.59
accuracy(test_x, test_y) = 0.573
accuracy(test_x, test_y) = 0.583


┌ Info: Epoch 4
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.588
accuracy(test_x, test_y) = 0.588
accuracy(test_x, test_y) = 0.603
accuracy(test_x, test_y) = 0.563
accuracy(test_x, test_y) = 0.568
accuracy(test_x, test_y) = 0.569
accuracy(test_x, test_y) = 0.594
accuracy(test_x, test_y) = 0.597
accuracy(test_x, test_y) = 0.601
accuracy(test_x, test_y) = 0.586
accuracy(test_x, test_y) = 0.576
accuracy(test_x, test_y) = 0.594
accuracy(test_x, test_y) = 0.595
accuracy(test_x, test_y) = 0.605
accuracy(test_x, test_y) = 0.582
accuracy(test_x, test_y) = 0.59
accuracy(test_x, test_y) = 0.58


┌ Info: Epoch 5
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.61
accuracy(test_x, test_y) = 0.606
accuracy(test_x, test_y) = 0.601
accuracy(test_x, test_y) = 0.596
accuracy(test_x, test_y) = 0.597
accuracy(test_x, test_y) = 0.596
accuracy(test_x, test_y) = 0.615
accuracy(test_x, test_y) = 0.609
accuracy(test_x, test_y) = 0.61
accuracy(test_x, test_y) = 0.621
accuracy(test_x, test_y) = 0.589
accuracy(test_x, test_y) = 0.614
accuracy(test_x, test_y) = 0.621
accuracy(test_x, test_y) = 0.614
accuracy(test_x, test_y) = 0.611
accuracy(test_x, test_y) = 0.601
accuracy(test_x, test_y) = 0.61


┌ Info: Epoch 6
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.634
accuracy(test_x, test_y) = 0.628
accuracy(test_x, test_y) = 0.635
accuracy(test_x, test_y) = 0.615
accuracy(test_x, test_y) = 0.61
accuracy(test_x, test_y) = 0.587
accuracy(test_x, test_y) = 0.596
accuracy(test_x, test_y) = 0.599
accuracy(test_x, test_y) = 0.61
accuracy(test_x, test_y) = 0.62
accuracy(test_x, test_y) = 0.627
accuracy(test_x, test_y) = 0.627
accuracy(test_x, test_y) = 0.624
accuracy(test_x, test_y) = 0.626
accuracy(test_x, test_y) = 0.643
accuracy(test_x, test_y) = 0.642
accuracy(test_x, test_y) = 0.65


┌ Info: Epoch 7
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.647
accuracy(test_x, test_y) = 0.635
accuracy(test_x, test_y) = 0.647
accuracy(test_x, test_y) = 0.648
accuracy(test_x, test_y) = 0.63
accuracy(test_x, test_y) = 0.635
accuracy(test_x, test_y) = 0.638
accuracy(test_x, test_y) = 0.635
accuracy(test_x, test_y) = 0.632
accuracy(test_x, test_y) = 0.631
accuracy(test_x, test_y) = 0.623
accuracy(test_x, test_y) = 0.652
accuracy(test_x, test_y) = 0.64
accuracy(test_x, test_y) = 0.654
accuracy(test_x, test_y) = 0.652
accuracy(test_x, test_y) = 0.646
accuracy(test_x, test_y) = 0.668


┌ Info: Epoch 8
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.647
accuracy(test_x, test_y) = 0.658
accuracy(test_x, test_y) = 0.641
accuracy(test_x, test_y) = 0.646
accuracy(test_x, test_y) = 0.652
accuracy(test_x, test_y) = 0.647
accuracy(test_x, test_y) = 0.626
accuracy(test_x, test_y) = 0.64
accuracy(test_x, test_y) = 0.627
accuracy(test_x, test_y) = 0.612
accuracy(test_x, test_y) = 0.626
accuracy(test_x, test_y) = 0.651
accuracy(test_x, test_y) = 0.661
accuracy(test_x, test_y) = 0.668
accuracy(test_x, test_y) = 0.648
accuracy(test_x, test_y) = 0.646
accuracy(test_x, test_y) = 0.653


┌ Info: Epoch 9
└ @ Main /home/pika/.julia/packages/Flux/rcN9D/src/optimise/train.jl:93


accuracy(test_x, test_y) = 0.67
accuracy(test_x, test_y) = 0.651
accuracy(test_x, test_y) = 0.658


InterruptException: InterruptException: