In [135]:
using MLDatasets, Printf
using SimpleChains, Flux

In [100]:
# Hyperparameters shared by the cells below.
learning_rate = 3.0e-4   # step size handed to SimpleChains.ADAM
num_image_classes = 10   # number of output classes (MNIST digits)
num_epochs = 10          # epoch count handed to SimpleChains.train_batched!

10

In [101]:
"""
    get_data(split)

Load the MNIST `split` (e.g. `:train` or `:test`) and return `(features, labels)`:
`features` holds one flattened image per column, and `labels` are the digit
targets shifted by +1 (digit 0 -> class 1) and stored as `UInt32`.
"""
function get_data(split)
    x, y = MLDatasets.MNIST(split)[:]
    return _flatten_features_and_labels(x, y)
end

"""
    _flatten_features_and_labels(x, y)

Pure preprocessing step of `get_data`: reshape `x` so each slice along its last
dimension (one observation) becomes a column, and convert the 0-based integer
targets `y` into 1-based `UInt32` class indices.
"""
function _flatten_features_and_labels(x::AbstractArray, y)
    # `:` infers the per-observation length (28*28 for MNIST) instead of hard-coding it.
    nobs = size(x, ndims(x))
    features = reshape(x, :, nobs)
    labels = UInt32.(y .+ 1)
    return features, labels
end

"""
    display_loss([io::IO = stdout,] accuracy, loss; label = "training")

Print one indented line `"    <label> accuracy XX.XX, loss Y.YYYY"` to `io`.
`accuracy` is multiplied by 100 and shown as a percentage, so pass a fraction.
`label` defaults to `"training"`; pass e.g. `label = "test"` for held-out metrics.
Returns `nothing`.
"""
function display_loss(io::IO, accuracy, loss; label::AbstractString = "training")
    @printf(io, "    %s accuracy %.2f, loss %.4f\n", label, 100 * accuracy, loss)
    return nothing
end

display_loss(accuracy, loss; kwargs...) = display_loss(stdout, accuracy, loss; kwargs...)

display_loss (generic function with 1 method)

In [102]:
# Flattened features and 1-based labels for both MNIST splits
# (the training split is still loaded first).
(xtrain, ytrain), (xtest, ytest) = get_data(:train), get_data(:test);

In [139]:
"""
    σ(x)

Rectified linear unit: `max(xᵢ, 0)` for every element of `x` (works for a scalar
or an array). Used as the hidden-layer activation of `sc_model`.

NOTE(review): `using Flux` also brings a `σ` (the sigmoid) into scope; this
definition shadows it in this notebook, so here `σ` means ReLU, not sigmoid.
"""
σ(x) = broadcast(max, x, 0)

σ (generic function with 1 method)

In [146]:
"""
    sc_model(; input_dim = 28 * 28, hidden = 32, nclasses = num_image_classes)

Build the SimpleChains MLP used in this notebook:
`input_dim` inputs -> `hidden` units with the `σ` (ReLU) activation ->
`nclasses` linear outputs (logits, consumed by `LogitCrossEntropyLoss`).

Calling `sc_model()` reproduces the original fixed 784 -> 32 -> 10 network;
`nclasses` follows the notebook's `num_image_classes` so the output width stays
in sync with the label count.
"""
function sc_model(; input_dim::Integer = 28 * 28, hidden::Integer = 32,
                  nclasses::Integer = num_image_classes)
    return SimpleChain(
        static(input_dim),               # input width must be a static integer
        TurboDense(σ, hidden),
        TurboDense(identity, nclasses),  # identity: raw logits
    )
end

# function sc_model() 
#     return SimpleChain(
#                 static(28 * 28),
#                 TurboDense{true}(σ, 32),
#                 TurboDense{true}(identity, 10))
# end

sc_model (generic function with 1 method)

In [147]:
model = sc_model()
# Pair the chain with the training targets; `train_batched!` optimises this loss.
sc_loss = SimpleChains.add_loss(model, LogitCrossEntropyLoss(ytrain))

# Gradient buffer and initial flat parameter vector.
G = SimpleChains.alloc_threaded_grad(sc_loss)
p = SimpleChains.init_params(model)

opt = SimpleChains.ADAM(learning_rate)
SimpleChains.train_batched!(G, p, sc_loss, xtrain, opt, num_epochs)

train_acc, train_loss = SimpleChains.accuracy_and_loss(sc_loss, xtrain, ytrain, p)
display_loss(train_acc, train_loss)

# Labels are passed explicitly here, so reusing `sc_loss` scores against `ytest`.
# `display_loss` labels its line "training" by default, so print the held-out
# metrics with their own label to avoid mistaking them for training numbers.
test_acc, test_loss = SimpleChains.accuracy_and_loss(sc_loss, xtest, ytest, p)
@printf("    test accuracy %.2f, loss %.4f\n", 100 * test_acc, test_loss)

    training accuracy 94.07, loss 0.2096
    training accuracy 93.85, loss 0.2112


In [148]:
# Expected parameter count with biases: 784*32 + 32 + 32*10 + 10 = 25450
# (equivalently 785*32 + 33*10, folding each layer's bias into its weight matrix)

In [149]:
# Length of the flat parameter vector `p`; expected 784*32 + 32 + 32*10 + 10 = 25450.
size(p)

(25450,)

In [150]:
# First training image as a flat column vector (`[:, 1]` indexing makes a copy).
x = xtrain[:,1];

In [151]:
# Forward pass of the bare chain (no loss attached) on that single image;
# the result has one output (logit) per class.
model(x,p)

10-element StaticArraysCore.SVector{10, Float32} with indices SOneTo(10):
  -1.3173958
  -5.99278
  -2.911115
   4.409305
 -10.200331
   5.847276
  -3.1492841
   1.5898215
  -4.785516
  -3.0534124

In [152]:
# Unpack the per-layer weight matrices and bias vectors stored in `p`, then redo
# the forward pass by hand; it should match `model(x, p)` up to Float32 rounding.
W1, W2 = SimpleChains.weights(model, p)
b1, b2 = SimpleChains.biases(model, p)

W2*σ(W1*x + b1) + b2

10-element Vector{Float32}:
  -1.3173966
  -5.99278
  -2.9111135
   4.4093046
 -10.200329
   5.8472757
  -3.149284
   1.5898207
  -4.785516
  -3.0534134

In [155]:
# Same manual forward pass with Flux's ReLU, confirming that `σ` above is a ReLU.
# `relu` is a scalar activation, so broadcast it explicitly (`relu.(…)`) instead
# of relying on an array method (older NNlib versions raise an error for that).
W2*Flux.relu.(W1*x + b1) + b2

10-element Vector{Float32}:
  -1.3173966
  -5.99278
  -2.9111135
   4.4093046
 -10.200329
   5.8472757
  -3.149284
   1.5898207
  -4.785516
  -3.0534134

In [156]:
# d = Float32.(p)
# W1 = reshape(d[1:32*784], 32,784)
# b1 = d[32*784+1 : 32*784+32]
# W2 = reshape(d[32*784+32+1:32*784+32+32*10], 10, 32)
# b2 = d[end-9:end]

# W2*σ(W1*x + b1) + b2

10-element Vector{Float32}:
  -1.3173966
  -5.99278
  -2.9111135
   4.4093046
 -10.200329
   5.8472757
  -3.149284
   1.5898207
  -4.785516
  -3.0534134

In [None]:
# begin
#     for i in 1:2
#         println("SimpleChains run #$i")
#         @time "  gradient buffer allocation" G = SimpleChains.alloc_threaded_grad(sc_loss)
#         @time "  parameter initialization" p = SimpleChains.init_params(model)
        
#         # @time "  forward pass" model(xtrain, p)
#         # g = similar(p);
#         # @time "  valgrad!" valgrad!(g, sc_loss, xtrain, p)
    
#         opt = SimpleChains.ADAM(learning_rate)
#         @time "  train $(num_epochs) epochs" SimpleChains.train_batched!(G, p, sc_loss, xtrain, opt, num_epochs)
    
#         @time "  compute training accuracy and loss" train_acc, train_loss = SimpleChains.accuracy_and_loss(sc_loss, xtrain, ytrain, p)
#         display_loss(train_acc, train_loss)
    
#         @time "  compute test accuracy and loss" test_acc, test_loss = SimpleChains.accuracy_and_loss(sc_loss, xtest, ytest, p)
#         display_loss(test_acc, test_loss)
    
#         println("------------------------------------")
#     end
# end

In [161]:
# model.layers[1]

TurboDense static(32) with bias.
Activation layer applying: σ