In [1]:
using MLDatasets, Plots, LinearAlgebra, ProgressMeter, Statistics

In [2]:
using Flux
using Flux: binarycrossentropy
using Flux: params #Used for automatic differenation (but in Project replace auto-diff with explicit gradient)
using Statistics, Random, StatsBase
using Flux: onehotbatch, crossentropy, update!
Random.seed!(0);

In [3]:
train_x, train_y = MNIST(split=:train)[:];
test_x,  test_y  = MNIST(split=:test)[:];

In [4]:
X_total = hcat(ones(60000), hcat([float.(vec(train_x[:,:,i])) for i in 1:60000]...)');
size(X_total)

(60000, 785)

In [5]:
X_per_digit = [X_total[train_y .== i, :] for i in 0:9]
size.(X_per_digit)

10-element Vector{Tuple{Int64, Int64}}:
 (5923, 785)
 (6742, 785)
 (5958, 785)
 (6131, 785)
 (5842, 785)
 (5421, 785)
 (5918, 785)
 (6265, 785)
 (5851, 785)
 (5949, 785)

# (1) Linear One vs. Rest

In [6]:
pinv_total = pinv(X_total)

"""
Make a classifier for the `digit_pos`.
"""
function make_one_vs_one_pred(digit_pos)
    Y_one_vs_rest = 2(train_y .== digit_pos) .- 1
    
    #beta is the set of weights for the specific digit...
    β_one_vs_rest = pinv_total*Y_one_vs_rest
    pred_one_vs_rest(img) = vcat(1,vec(img))'β_one_vs_rest
    return pred_one_vs_rest
end

preds_one_vs_rest = [make_one_vs_one_pred(i) for i in 0:9]

#Here is the "one vs. rest"
predict_one_vs_rest(img) = argmax([preds_one_vs_rest[i+1](img) for i in 0:9]) - 1

accuracy_one_vs_rest = mean([predict_one_vs_rest(test_x[:,:,i]) == test_y[i] for i in 1:10000]) 

@show accuracy_one_vs_rest;

function conf_matrix(pred_fun) 
    predictions = [predict_one_vs_rest(test_x[:,:,i]) for i in 1:10000]
    confusionMatrix = [sum((predictions .== i) .& (test_y .== j)) for i in 0:9, j in 0:9]
end
conf_matrix(predict_one_vs_rest)

accuracy_one_vs_rest = 0.8603


10×10 Matrix{Int64}:
 944     0   18    4    0   23   18    5   14   15
   0  1107   54   17   22   18   10   40   46   11
   1     2  813   23    6    3    9   16   11    2
   2     2   26  880    1   72    0    6   30   17
   2     3   15    5  881   24   22   26   27   80
   7     1    0   17    5  659   17    0   40    1
  14     5   42    9   10   23  875    1   15    1
   2     1   22   21    2   14    0  884   12   77
   7    14   37   22   11   39    7    0  759    4
   1     0    5   12   44   17    0   50   20  801

# (2) Linear One vs. One

In [7]:
function make_one_vs_one_pred(digit_pos, digit_neg)
    X_one_vs_one = vcat(X_per_digit[digit_pos+1],X_per_digit[digit_neg+1])
    Y_one_vs_one = vcat(ones(size(X_per_digit[digit_pos+1])[1]), -ones(size(X_per_digit[digit_neg+1])[1]))
    β_one_vs_one = pinv(X_one_vs_one)*Y_one_vs_one
    pred_one_vs_one(img) = vcat(1,vec(img))'β_one_vs_one
    return pred_one_vs_one
end

preds_one_vs_one = Dict()
@showprogress for i in 0:9
    @showprogress for j in 0:9
        i == j && continue
        preds_one_vs_one[(i,j)] = make_one_vs_one_pred(i,j)
    end
end

predict_one_vs_one_sign(img) = argmax([sum([sign(preds_one_vs_one[(i,j)](img)) for j in setdiff(0:9,i)]) for i in 0:9])-1

accuracy_one_vs_one_sign = mean([predict_one_vs_one_sign(test_x[:,:,i]) == test_y[i] for i in 1:10000]) 

@show accuracy_one_vs_one_sign;

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:08[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:08[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:07[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:07[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:07[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:08[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:09[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:09[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:09[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:09[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:26[39m


accuracy_one_vs_one_sign = 0.9297


In [8]:
conf_matrix(predict_one_vs_one_sign)

10×10 Matrix{Int64}:
 944     0   18    4    0   23   18    5   14   15
   0  1107   54   17   22   18   10   40   46   11
   1     2  813   23    6    3    9   16   11    2
   2     2   26  880    1   72    0    6   30   17
   2     3   15    5  881   24   22   26   27   80
   7     1    0   17    5  659   17    0   40    1
  14     5   42    9   10   23  875    1   15    1
   2     1   22   21    2   14    0  884   12   77
   7    14   37   22   11   39    7    0  759    4
   1     0    5   12   44   17    0   50   20  801

# (3) Logistic regression one vs. rest

In [9]:
train_y_digit_one_vs_rest = [train_y .== k for k in 0:9];

In [23]:
sig(x) = 1/(1+float(MathConstants.e)^-x)
logistic_predict(img_vec, w) = sig.(w'*img_vec);

function train_logistic(data_x, data_y, train_number;
        num_epochs = 200, 
        mini_batch_size = 2000, 
        η = 0.02)
    println("starting $train_number")
    
    w = randn(28*28+1)
    loss(x, y) = binarycrossentropy(logistic_predict(x, w), y);
    loss_value = 0.0
    opt = ADAM(η)
    n = size(data_x)[1]
    for epoch_num in 1:num_epochs
        for batch in Iterators.partition(1:n, mini_batch_size)
            gs = gradient(()->loss( data_x'[:,batch], 
                                    data_y[batch]'), 
                                    params(w))
            for p in (w,)
                update!(opt, p, gs[p])
            end

        end
        print(".")
    end
    println()
    return w
end

train_logistic (generic function with 1 method)

In [24]:
logistic_one_rest_models = [train_logistic(X_total, train_y_digit_one_vs_rest[k+1], k) for k in 0:9];

starting 0
........................................................................................................................................................................................................
starting 1
........................................................................................................................................................................................................
starting 2
........................................................................................................................................................................................................
starting 3
........................................................................................................................................................................................................
starting 4
.............................................................................................................................................

In [25]:
function make_one_vs_one_pred_logistic(digit_pos)
    Y_one_vs_rest = 2(train_y .== digit_pos) .- 1
    w = logistic_one_rest_models[digit_pos]
    predict_f(img) = logistic_predict(vcat(1,vec(img)), w)
    return predict_f
end

logistic_pred = [make_one_vs_one_pred_logistic(k) for k in 1:10];
pred_img_log_one_rest(img) = argmax([logistic_pred[k](img) for k in 1:10])-1

accuracy_one_vs_rest_log = mean([pred_img_log_one_rest(test_x[:,:,i]) == test_y[i] for i in 1:10000]) 

@show accuracy_one_vs_rest_log;

conf_matrix(pred_img_log_one_rest)

accuracy_one_vs_rest_log = 0.9174


10×10 Matrix{Int64}:
 944     0   18    4    0   23   18    5   14   15
   0  1107   54   17   22   18   10   40   46   11
   1     2  813   23    6    3    9   16   11    2
   2     2   26  880    1   72    0    6   30   17
   2     3   15    5  881   24   22   26   27   80
   7     1    0   17    5  659   17    0   40    1
  14     5   42    9   10   23  875    1   15    1
   2     1   22   21    2   14    0  884   12   77
   7    14   37   22   11   39    7    0  759    4
   1     0    5   12   44   17    0   50   20  801

# (4) Logistic regression one vs. one

In [26]:
function make_one_vs_one_pred_x_y(digit_pos, digit_neg)
    X_one_vs_one = vcat(X_per_digit[digit_pos+1],X_per_digit[digit_neg+1])
    Y_one_vs_one = vcat(ones(size(X_per_digit[digit_pos+1])[1]), zeros(size(X_per_digit[digit_neg+1])[1]))
    return X_one_vs_one, Y_one_vs_one
end

preds_one_vs_one_log = Dict()
for i in 0:9
    for j in 0:9
        i == j && continue
        x, y = make_one_vs_one_pred_x_y(i,j)
        preds_one_vs_one_log[(i,j)] = train_logistic(x, y, (i,j))
    end
end

starting (0, 1)
........................................................................................................................................................................................................
starting (0, 2)
........................................................................................................................................................................................................
starting (0, 3)
........................................................................................................................................................................................................
starting (0, 4)
........................................................................................................................................................................................................
starting (0, 5)
....................................................................................................................

........................................................................................................................................................................................................
starting (4, 2)
........................................................................................................................................................................................................
starting (4, 3)
........................................................................................................................................................................................................
starting (4, 5)
........................................................................................................................................................................................................
starting (4, 6)
....................................................................................................................................

........................................................................................................................................................................................................
starting (8, 3)
........................................................................................................................................................................................................
starting (8, 4)
........................................................................................................................................................................................................
starting (8, 5)
........................................................................................................................................................................................................
starting (8, 6)
....................................................................................................................................

In [27]:
logistic_predict_one_vs_one(i,j,img) = logistic_predict(vcat(1,vec(img)), 
                                                preds_one_vs_one_log[(i,j)]) - 0.5

predict_one_vs_one_sign_log(img) = argmax([sum([sign(logistic_predict_one_vs_one(i,j,img)) for j in setdiff(0:9,i)]) for i in 0:9])-1

accuracy_one_vs_rest_log = mean([predict_one_vs_one_sign_log(test_x[:,:,i]) == test_y[i] for i in 1:10000])

@show accuracy_one_vs_rest_log
conf_matrix(predict_one_vs_one_sign_log)

accuracy_one_vs_rest_log = 0.9321


10×10 Matrix{Int64}:
 944     0   18    4    0   23   18    5   14   15
   0  1107   54   17   22   18   10   40   46   11
   1     2  813   23    6    3    9   16   11    2
   2     2   26  880    1   72    0    6   30   17
   2     3   15    5  881   24   22   26   27   80
   7     1    0   17    5  659   17    0   40    1
  14     5   42    9   10   23  875    1   15    1
   2     1   22   21    2   14    0  884   12   77
   7    14   37   22   11   39    7    0  759    4
   1     0    5   12   44   17    0   50   20  801

# (5) Multiclass classifier

In [15]:
n_test = length(test_y);
n_train = length(train_y);
X_test = vcat([vec(test_x[:,:,k])' for k in 1:n_test]...);

In [16]:
logistic_softmax_predict(img_vec, W) = softmax(W*img_vec)

function train_softmax_logistic(;mini_batch_size = 2000, num_epochs = 200)
    
    #Initilize parameters
    W = randn(10,28*28+1)

    opt = ADAM(0.02)
    loss(x, y) = crossentropy(logistic_softmax_predict(x, W), onehotbatch(y,0:9))

    loss_value = 0.0
    epoch_num = 0

    #Training loop
    for epoch_num in 1:num_epochs
        prev_loss_value = loss_value
        
        #Loop over mini-batches in epoch
        start_time = time_ns()
        for batch in Iterators.partition(1:n_train, mini_batch_size)
            gs = gradient(()->loss(X_total'[:,batch], train_y[batch]), params(W))
            for p in (W,)
                update!(opt, p, gs[p])
            end
        end
        end_time = time_ns()

        #record/display progress
        epoch_num += 1
        loss_value = loss(X_total', train_y)
        println("Epoch = $epoch_num ($(round((end_time-start_time)/1e9,digits=2)) sec) Loss = $loss_value")        
    end
    return W
end

train_softmax_logistic (generic function with 1 method)

In [17]:
W = train_softmax_logistic();

Epoch = 2 (8.7 sec) Loss = 1.7239047373639682
Epoch = 3 (5.29 sec) Loss = 1.0180409007626112
Epoch = 4 (5.58 sec) Loss = 0.7939230320557332
Epoch = 5 (5.6 sec) Loss = 0.6725705514246488
Epoch = 6 (5.57 sec) Loss = 0.5947298538518564
Epoch = 7 (5.59 sec) Loss = 0.5395838202235905
Epoch = 8 (5.7 sec) Loss = 0.49791725265881154
Epoch = 9 (6.05 sec) Loss = 0.4653693151230564
Epoch = 10 (6.51 sec) Loss = 0.43924988192812864
Epoch = 11 (5.84 sec) Loss = 0.41796796020110644
Epoch = 12 (5.87 sec) Loss = 0.40029340216611226
Epoch = 13 (6.39 sec) Loss = 0.38533789518321715
Epoch = 14 (5.7 sec) Loss = 0.372478676681453
Epoch = 15 (5.63 sec) Loss = 0.3612562430424544
Epoch = 16 (5.27 sec) Loss = 0.35134139457241437
Epoch = 17 (5.17 sec) Loss = 0.3425086780990951
Epoch = 18 (5.21 sec) Loss = 0.3346090028665252
Epoch = 19 (5.27 sec) Loss = 0.3275363932741277
Epoch = 20 (5.56 sec) Loss = 0.32119988894055646
Epoch = 21 (5.57 sec) Loss = 0.31551023230648045
Epoch = 22 (5.56 sec) Loss = 0.31037967460111

Epoch = 170 (6.81 sec) Loss = 0.21500919151148568
Epoch = 171 (6.83 sec) Loss = 0.21496160173870624
Epoch = 172 (6.73 sec) Loss = 0.21491444693197687
Epoch = 173 (6.72 sec) Loss = 0.2148677244730512
Epoch = 174 (8.18 sec) Loss = 0.2148214319071547
Epoch = 175 (7.51 sec) Loss = 0.21477556692394303
Epoch = 176 (7.18 sec) Loss = 0.21473012733951616
Epoch = 177 (7.31 sec) Loss = 0.21468511107944505
Epoch = 178 (7.13 sec) Loss = 0.21464051616277194
Epoch = 179 (6.63 sec) Loss = 0.2145963406869443
Epoch = 180 (7.46 sec) Loss = 0.214552582813649
Epoch = 181 (6.68 sec) Loss = 0.2145092407555109
Epoch = 182 (7.87 sec) Loss = 0.21446631276362743
Epoch = 183 (7.31 sec) Loss = 0.21442379711590798
Epoch = 184 (7.93 sec) Loss = 0.21438169210619284
Epoch = 185 (7.57 sec) Loss = 0.21433999603412526
Epoch = 186 (7.33 sec) Loss = 0.21429870719575267
Epoch = 187 (7.36 sec) Loss = 0.21425782387483464
Epoch = 188 (6.99 sec) Loss = 0.21421734433483514
Epoch = 189 (6.65 sec) Loss = 0.21417726681157873
Epoch 

In [19]:
predict_softmax(img) = argmax(logistic_softmax_predict(vcat(1,vec(img)),W)) - 1
accuracy_softmax = mean([predict_softmax(test_x[:,:,i]) == test_y[i] for i in 1:10000])

@show accuracy_softmax
conf_matrix(predict_softmax)

accuracy_softmax = 0.9221


10×10 Matrix{Int64}:
 944     0   18    4    0   23   18    5   14   15
   0  1107   54   17   22   18   10   40   46   11
   1     2  813   23    6    3    9   16   11    2
   2     2   26  880    1   72    0    6   30   17
   2     3   15    5  881   24   22   26   27   80
   7     1    0   17    5  659   17    0   40    1
  14     5   42    9   10   23  875    1   15    1
   2     1   22   21    2   14    0  884   12   77
   7    14   37   22   11   39    7    0  759    4
   1     0    5   12   44   17    0   50   20  801