In [1]:
using Flux
using Flux: Data.DataLoader
using Flux: onehotbatch, onecold, crossentropy, throttle
using Flux: @epochs
using Statistics
using MLDatasets: SVHN2
using Serialization: serialize, deserialize
using PyCall
using ImageView
using Images
using Flux
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Printf, BSON, LinearAlgebra
using Metalhead: trainimgs

In [2]:
# Load the SVHN2 training and test splits. Each call returns an (images, labels)
# pair; the rest of the notebook indexes the images by sample along dimension 4.
(x_train, y_train) = SVHN2.traindata()
(x_test, y_test) = SVHN2.testdata();

In [19]:
# Training hyperparameters: number of passes over the training batches and the
# number of samples per mini-batch.
epochs, batch_size = 40, 256;

In [5]:
# Per-sample Float32 copies of the training images (one 3-D array per sample).
train_imgs = [Float32.(x_train[:, :, :, i]) for i in 1:size(x_train, 4)]

# One-hot encode the labels over the classes 1:10. The labels are passed as a
# plain vector (the same way the test cell does it) so the result is a
# 10×N matrix with one column per sample; reshaping them to a 1×N matrix first
# adds a spurious dimension that breaks the column slicing below.
train_labels = onehotbatch(vec(y_train), 1:10)

# Mini-batches as (images, labels) tuples. Slicing `x_train` along dimension 4
# yields the same W×H×C×B array as concatenating the individual samples, without
# splatting hundreds of arrays into `cat`. The last batch may be smaller than
# `batch_size`.
train_X = [
    (Float32.(x_train[:, :, :, idx]), train_labels[:, idx])
    for idx in partition(1:size(x_train, 4), batch_size)
];

In [6]:
# Evaluation subset: only the first 5000 test samples are used.
# `test_imgs` holds one Float32 array per sample.
test_imgs = map(i -> Float32.(x_test[:, :, :, i]), 1:5000)
# One-hot labels over the classes 1:10, one column per sample.
test_labels = onehotbatch(y_test[1:5000], 1:10)
# All 5000 evaluation images stacked along dimension 4 into a single array.
test_X = Float32.(x_test[:, :, :, 1:5000])
# The individual one-hot label columns, collected into a vector.
test_Y = map(i -> test_labels[:, i], 1:5000);

In [11]:
# Convolutional classifier: four 3×3 conv + batch-norm stages, a 4×4 mean pool,
# then a dense layer producing 10 class probabilities.
# Spatial size per stage with stride 1: pad 1 keeps the size, pad 2 grows each
# side by 2. For a 32×32 input (assumed SVHN image size — confirm against the
# dataset) this gives 32 -> 32 -> 34 -> 36 -> 38, pooled to 9×9, hence
# 9 * 9 * 32 = 2592 inputs to the dense layer.
model = Chain(
    Conv((3, 3), 3 => 8, relu; pad = (1, 1), stride = (1, 1)),
    BatchNorm(8),
    Conv((3, 3), 8 => 16, relu; pad = (2, 2), stride = (1, 1)),
    BatchNorm(16),
    Conv((3, 3), 16 => 32, relu; pad = (2, 2), stride = (1, 1)),
    BatchNorm(32),
    Conv((3, 3), 32 => 32, relu; pad = (2, 2), stride = (1, 1)),
    BatchNorm(32),
    MeanPool((4, 4)),
    # Flatten everything except the batch dimension (dimension 4).
    x -> reshape(x, :, size(x, 4)),
    Dense(2592, 10),
    softmax,
);

In [9]:
# Cross-entropy between the model's output for `x` and the one-hot targets `y`.
function loss(x, y)
    predicted = model(x)
    return crossentropy(predicted, y)
end

# Fraction of samples whose most probable predicted class (out of 1:10) equals
# the class encoded in the one-hot targets `y`.
function accuracy(x, y)
    predicted_class = onecold(model(x), 1:10)
    expected_class = onecold(y, 1:10)
    return mean(predicted_class .== expected_class)
end

# ADAM optimiser with learning rate 0.01; the training loop lowers `opt.eta`.
opt = ADAM(0.01)

ADAM(0.01, (0.9, 0.999), IdDict{Any,Any}())

In [12]:
# Sanity check: run the model on the images of the first training batch.
model(first(first(train_X)))

10×256 Array{Float32,2}:
 0.0988194  0.0981398  0.0888327  …  0.096061   0.0972657  0.0962161
 0.101817   0.102203   0.109146      0.103691   0.103138   0.104832
 0.0992016  0.0989206  0.0951156     0.102918   0.0998823  0.10047
 0.100163   0.100727   0.101675      0.100443   0.0999007  0.100268
 0.0967751  0.0971531  0.0921852     0.0952351  0.0948097  0.0930415
 0.106066   0.105549   0.115818   …  0.105615   0.106276   0.106259
 0.0993433  0.0977044  0.100844      0.0997558  0.0987005  0.099664
 0.0936222  0.0935635  0.0864547     0.0926201  0.0938842  0.0942821
 0.0975791  0.0996181  0.0982734     0.0957861  0.0995257  0.100255
 0.106614   0.106421   0.111656      0.107874   0.106617   0.104711

In [13]:
# Accuracy of the current model on the evaluation subset (`test_X` with its
# one-hot `test_labels`).
accuracy(test_X, test_labels)

0.0764

In [20]:
# Training loop: run up to `epochs` passes over `train_X`, checkpoint the model
# whenever the evaluation accuracy is at least as good as the best seen so far,
# divide the learning rate by 10 after 5 passes without improvement, and stop
# once 10 passes go by without improvement.
# NOTE(review): checkpoint selection uses the same data that is reported as
# "Test accuracy", so the reported best accuracy is optimistically biased;
# consider a separate validation split.
@info("Beginning training loop...")
best_acc = 0.0          # Float64 from the start so the global never changes type
last_improvement = 0
for epoch = 1:epochs    # exactly `epochs` passes (0:epochs ran one extra)
    global best_acc, last_improvement
    # One pass over every mini-batch in `train_X`.
    Flux.train!(loss, params(model), train_X, opt)
    acc = accuracy(test_X, test_labels)
    @info(@sprintf("[%d]: Test accuracy: %.4f", epoch, acc))
    # Checkpoint on a new (or equal) best accuracy.
    if acc >= best_acc
        @info(" -> New best accuracy! Saving model out to SVHN_conv.bson")
        BSON.@save "SVHN_conv.bson" model epoch acc
        best_acc = acc
        last_improvement = epoch
    end
    # Plateau for 5 passes: lower the learning rate, unless it is already tiny.
    if epoch - last_improvement >= 5 && opt.eta > 1e-6
        opt.eta /= 10.0
        @warn(" -> Haven't improved in a while, dropping learning rate to $(opt.eta)!")
        # Reset the counter so the new learning rate gets its own 5 passes.
        last_improvement = epoch
    end
    # Only reachable once the learning rate can no longer be lowered (the
    # branch above resets the counter otherwise).
    if epoch - last_improvement >= 10
        @warn(" -> We're calling this converged.")
        break
    end
end

┌ Info: Beginning training loop...
└ @ Main In[20]:1
┌ Info: [10]: Test accuracy: 0.8656
└ @ Main In[20]:8
┌ Info: [11]: Test accuracy: 0.8564
└ @ Main In[20]:8
┌ Info: [12]: Test accuracy: 0.8532
└ @ Main In[20]:8
┌ Info: [13]: Test accuracy: 0.8598
└ @ Main In[20]:8
┌ Info: [14]: Test accuracy: 0.8634
└ @ Main In[20]:8
└ @ Main In[20]:17
┌ Info: [15]: Test accuracy: 0.8828
└ @ Main In[20]:8
┌ Info:  -> New best accuracy! Saving model out to SVHN_conv.bson
└ @ Main In[20]:10
┌ Info: [16]: Test accuracy: 0.8834
└ @ Main In[20]:8
┌ Info:  -> New best accuracy! Saving model out to SVHN_conv.bson
└ @ Main In[20]:10
┌ Info: [17]: Test accuracy: 0.8838
└ @ Main In[20]:8
┌ Info:  -> New best accuracy! Saving model out to SVHN_conv.bson
└ @ Main In[20]:10
┌ Info: [18]: Test accuracy: 0.8836
└ @ Main In[20]:8
┌ Info: [19]: Test accuracy: 0.8818
└ @ Main In[20]:8
┌ Info: [20]: Test accuracy: 0.8810
└ @ Main In[20]:8
┌ Info: [21]: Test accuracy: 0.8804
└ @ Main In[20]:8
┌ Info: [22]: Test accura