In [1]:
using Flux
using Flux: Data.DataLoader
using Flux: onehotbatch, onecold, crossentropy, throttle
using Flux: @epochs
using Statistics
using MLDatasets: SVHN2
using Serialization: serialize, deserialize
using PyCall
using ImageView
using Images
using Flux
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Printf, BSON, LinearAlgebra
using Metalhead: trainimgs

In [2]:
# Load the SVHN2 train and test splits from MLDatasets. Each call is destructured
# into a (features, labels) pair; later cells index the features as a 4-D array
# with one sample per slice along dimension 4.
x_train, y_train = SVHN2.traindata()
# The trailing `;` suppresses the notebook's echo of the returned value.
x_test, y_test = SVHN2.testdata();

In [47]:
# Training hyperparameters.
# Upper bound on passes over the training batches (the training loop can stop earlier).
epochs = 20;
# Number of samples grouped into each training minibatch.
batch_size = 64;

In [48]:
# Per-sample Float32 copies of the training images. The minibatches below are
# built straight from `x_train`, so this list is kept only for cells that want
# to look at individual images.
train_imgs = [Float32.(x_train[:, :, :, i]) for i in 1:size(x_train, 4)]

# One-hot encode the labels against the classes 1:10. `vec` hands `onehotbatch`
# a flat label vector (the same form the test cell passes), so the result is a
# plain classes×samples matrix that `train_labels[:, idx]` can slice per batch.
train_labels = onehotbatch(vec(y_train), 1:10)

# Minibatches as (images, labels) tuples. Slicing a whole index range out of
# `x_train` yields the same 4-D Float32 batch as concatenating the individual
# images along dimension 4, without splatting `batch_size` arrays into `cat`.
# The final batch is shorter when the sample count is not a multiple of
# `batch_size`.
train_X = [
    (Float32.(x_train[:, :, :, idx]), train_labels[:, idx])
    for idx in partition(1:size(x_train, 4), batch_size)
];

In [49]:
# Evaluate on (at most) the first 5000 test samples. The count is named once and
# clamped to the number of available samples so a smaller test split cannot
# cause an out-of-bounds slice.
n_test = min(5000, size(x_test, 4))

# Per-sample Float32 copies of the evaluation images.
test_imgs = [Float32.(x_test[:, :, :, i]) for i in 1:n_test]

# One-hot labels for the evaluation subset, encoded against the classes 1:10.
test_labels = onehotbatch(y_test[1:n_test], 1:10)

# Single 4-D evaluation batch. Slicing the range directly gives the same array
# as concatenating `test_imgs` along dimension 4, without splatting thousands
# of arguments into `cat`.
test_X = Float32.(x_test[:, :, :, 1:n_test])

# Per-sample one-hot label columns.
test_Y = [test_labels[:, i] for i in 1:n_test];

In [51]:
# One feature stage: a 3×3 convolution with ReLU and 1-pixel padding (keeps the
# spatial size), followed by 2×2 max pooling. `channels` is an `in => out` pair.
conv_stage(channels) = (
    Conv((3, 3), channels, relu; pad=(1, 1)),
    MaxPool((2, 2)),
)

# Classifier: three conv stages (3 → 16 → 32 → 32 channels), a flatten step,
# and a dense layer producing 10 class scores normalised by softmax.
# The stages are splatted so the resulting Chain is flat, with the layers
# constructed in the order listed.
model = Chain(
    conv_stage(3 => 16)...,
    conv_stage(16 => 32)...,
    conv_stage(32 => 32)...,
    # Collapse every dimension except the batch (4th) into one feature axis.
    # Dense below expects 512 features per sample.
    feats -> reshape(feats, :, size(feats, 4)),
    Dense(512, 10),
    softmax,
);

In [52]:
"""
    loss(x, y)

Return the mean squared error between the global `model`'s output for `x` and
the targets `y`.
"""
function loss(x, y)
    prediction = model(x)
    return Flux.mse(prediction, y)
end

"""
    accuracy(x, y)

Return the fraction of samples in `x` for which the global `model`'s most
probable class (decoded against the labels `1:10`) equals the class encoded
in `y`.
"""
function accuracy(x, y)
    predicted = onecold(model(x), 1:10)
    expected = onecold(y, 1:10)
    return mean(predicted .== expected)
end

# Optimiser handed to `Flux.train!` in the training loop, built with OADAM's
# defaults. The training loop lowers `opt.eta` when test accuracy stalls.
opt = OADAM()

OADAM(0.001, (0.5, 0.9), IdDict{Any,Any}())

In [53]:
# Sanity check: run the untrained model on the image half of the first training
# batch to confirm the forward pass works and yields one column per sample.
model(train_X[1][1])

10×64 Array{Float32,2}:
 0.101072   0.103532   0.107408   …  0.106786   0.113866   0.0994656
 0.110122   0.106725   0.123303      0.112537   0.11084    0.109236
 0.0916596  0.0923067  0.0832604     0.0809823  0.0838644  0.0924811
 0.103014   0.102289   0.100855      0.104458   0.102226   0.103217
 0.106402   0.109186   0.126762      0.115599   0.118602   0.110489
 0.0994335  0.0999715  0.0951587  …  0.0979461  0.0990117  0.100463
 0.102674   0.102719   0.105738      0.10679    0.103628   0.10062
 0.10147    0.102919   0.103368      0.110249   0.101639   0.102591
 0.0911628  0.0901203  0.0745476     0.0824927  0.0855616  0.0899217
 0.0929902  0.0902318  0.0795987     0.0821595  0.0807616  0.0915141

In [54]:
# Accuracy of the untrained model on the evaluation subset, as a baseline.
accuracy(test_X, test_labels)

0.1134

In [55]:
# Train for up to `epochs` passes over `train_X`, checkpointing the model each
# time test accuracy reaches a new best, cutting the learning rate by 10× after
# 5 epochs without improvement, and stopping once the rate has bottomed out and
# 10 epochs pass without improvement.
@info "Beginning training loop..."
best_acc = 0.0
last_improvement = 0
for epoch in 1:epochs
    # The trackers live at top level; declare them so the loop updates them.
    global best_acc, last_improvement

    # One pass over every minibatch, then evaluate on the held-out subset.
    Flux.train!(loss, params(model), train_X, opt)
    acc = accuracy(test_X, test_labels)
    @info @sprintf("[%d]: Test accuracy: %.4f", epoch, acc)

    # Checkpoint whenever the evaluation accuracy reaches a new best.
    if acc > best_acc
        @info " -> New best accuracy! Saving model out to SVHN_conv.bson"
        BSON.@save "SVHN_conv.bson" model epoch acc
        best_acc, last_improvement = acc, epoch
    end

    # Epochs since the last improvement (or since the last learning-rate drop,
    # which resets the counter).
    stalled = epoch - last_improvement
    if stalled >= 5 && opt.eta > 1e-6
        opt.eta /= 10.0
        @warn " -> Haven't improved in a while, dropping learning rate to $(opt.eta)!"
        last_improvement = epoch
    elseif stalled >= 10
        # Only reachable once the learning rate can no longer be lowered.
        @warn " -> We're calling this converged."
        break
    end
end

┌ Info: Beginning training loop...
└ @ Main In[55]:1
┌ Info: [1]: Test accuracy: 0.7798
└ @ Main In[55]:8
┌ Info:  -> New best accuracy! Saving model out to SVHN_conv.bson
└ @ Main In[55]:10
┌ Info: [2]: Test accuracy: 0.8188
└ @ Main In[55]:8
┌ Info:  -> New best accuracy! Saving model out to SVHN_conv.bson
└ @ Main In[55]:10
┌ Info: [3]: Test accuracy: 0.8332
└ @ Main In[55]:8
┌ Info:  -> New best accuracy! Saving model out to SVHN_conv.bson
└ @ Main In[55]:10
┌ Info: [4]: Test accuracy: 0.8560
└ @ Main In[55]:8
┌ Info:  -> New best accuracy! Saving model out to SVHN_conv.bson
└ @ Main In[55]:10
┌ Info: [5]: Test accuracy: 0.8604
└ @ Main In[55]:8
┌ Info:  -> New best accuracy! Saving model out to SVHN_conv.bson
└ @ Main In[55]:10
┌ Info: [6]: Test accuracy: 0.8686
└ @ Main In[55]:8
┌ Info:  -> New best accuracy! Saving model out to SVHN_conv.bson
└ @ Main In[55]:10
┌ Info: [7]: Test accuracy: 0.8682
└ @ Main In[55]:8
┌ Info: [8]: Test accuracy: 0.8690
└ @ Main In[55]:8
┌ Info:  -> N

In [56]:
# Reload the checkpoint written by the training loop. This rebinds the global
# `model`, which `loss` and `accuracy` read, so later calls use the restored
# weights.
BSON.@load "SVHN_conv.bson" model

In [57]:
# Accuracy of the reloaded checkpoint on the evaluation subset.
accuracy(test_X, test_labels)

0.8866