In [1]:
using Knet, Printf
#import Pkg 
#Pkg.add("Printf")

In [2]:
# Run the MNIST helper script shipped in Knet's data directory
# (presumably the source of `mnistdata` used below).
include(Knet.dir("data", "mnist.jl"))
# dtrn and dtst = [ (x1,y1), (x2,y2), ... ] where xi,yi are minibatches of 100.
# NOTE(review): xsize=(784,:) appears to request each image as a 784-row column — confirm.
dtrn,dtst = mnistdata(xsize=(784,:));

┌ Info: Loading MNIST...
└ @ Main /kuacc/users/sozcelik19/.julia/packages/Knet/2xiR8/data/mnist.jl:33


In [3]:
# Display the size and element type of the training-data iterator.
summary(dtrn)

"600-element Knet.Data{Tuple{KnetArray{Float32,2},Array{UInt8,1}}}"

In [4]:
# Fully connected layer followed by an element-wise activation.
# Calling a `Hidden` on an input `x` evaluates `fun.(w * x .+ b)`.
mutable struct Hidden
    w # weight matrix; multiplies the input from the left
    b # bias; broadcast-added to `w * x`
    fun # non-linear activation function, applied element-wise
end

# Affine output layer: calling a `Linear` on `x` evaluates `w * x .+ b`.
# No softmax is applied by the layer itself; the network's loss method passes
# these scores straight to `nll`.
mutable struct Linear # output layer feeding the softmax-based loss
    w # weight matrix; multiplies the input from the left
    b # bias; broadcast-added to `w * x`
end

In [5]:
# Build a `Hidden` layer that maps `xsize` inputs to `ysize` outputs.
# The weight matrix (`ysize` x `xsize`) and the bias (`ysize`) are sampled with
# `randn`, converted to `KnetArray{Float32}` and wrapped in `Param` so that they
# are picked up as trainable parameters. `fun` defaults to `relu`.
function Hidden(xsize::Int, ysize::Int, fun=relu)
    gaussian(dims...) = Param(convert(KnetArray{Float32}, randn(dims...)))
    weight = gaussian(ysize, xsize)
    bias = gaussian(ysize)
    return Hidden(weight, bias, fun)
end

# Build a `Linear` output layer that maps `xsize` inputs to `ysize` scores.
# Both the weight matrix (`ysize` x `xsize`) and the bias (`ysize`) are sampled
# with `randn`, stored as `KnetArray{Float32}` and wrapped in `Param`.
function Linear(xsize::Int, ysize::Int)
    storage = KnetArray{Float32}
    weight = Param(convert(storage, randn(ysize, xsize)))
    bias = Param(convert(storage, randn(ysize)))
    return Linear(weight, bias)
end

Linear

In [6]:
# Forward pass of a hidden layer: affine map `w * x .+ b` followed by the
# layer's activation applied element-wise.
function (layer::Hidden)(x)
    activation = layer.fun
    return activation.(layer.w * x .+ layer.b)
end

# Forward pass of the output layer: the affine map `w * x .+ b`, with no
# activation or normalisation applied to the result.
(layer::Linear)(x) = layer.w * x .+ layer.b

In [7]:
# Two-layer network: a `Hidden` layer whose output is fed to a `Linear` layer.
mutable struct Network
    hidden_layer::Hidden # first stage, applied directly to the input
    softmax_layer::Linear # second stage, applied to the hidden layer's output
end

In [8]:
# Convenience constructor: builds the hidden layer from
# `(hiddenx, hiddeny, fun)` and the output layer from `(softx, softy)`,
# then wraps both in a `Network`.
function Network(hiddenx, hiddeny, fun, softx, softy)
    hidden = Hidden(hiddenx, hiddeny, fun)
    output = Linear(softx, softy)
    return Network(hidden, output)
end

Network

In [9]:
# Forward pass: run `x` through the hidden layer, then through the output
# layer, and return the output layer's result.
(n::Network)(x) = n.softmax_layer(n.hidden_layer(x))

In [10]:
# Loss for a minibatch: `nll` applied to the network's output for `x`
# and the labels `y`.
(n::Network)(x, y) = nll(n(x), y)

In [11]:
# Perform one update step on the minibatch `(x, y)`.
# The loss `n(x, y)` is recorded with `@diff`; every parameter found by
# `params(n)` is then updated in place with its gradient via `update!`.
# Returns the plain loss value computed before the update.
function trainNetwork!(n::Network, x, y)
    tape = @diff n(x, y)
    foreach(params(n)) do p
        update!(value(p), grad(tape, p))
    end
    return value(tape)
end

trainNetwork! (generic function with 1 method)

In [12]:
# Fraction of examples in `data` that `net` classifies correctly.
#
# `data` must iterate `(x, y)` minibatches. `net(x)` must return a score matrix
# with one column per example, and `y[i]` is the 1-based class index expected
# for column `i`. Returns `NaN` when `data` yields no examples (0/0).
function accuracy(net, data)
    correct = 0
    total = 0
    for (x, y) in data
        # Convert the scores to a plain `Array` once per minibatch so that the
        # per-example arg-max works on ordinary array views instead of slicing
        # the network output example by example.
        scores = convert(Array, net(x))
        for (i, label) in enumerate(y)
            # softmax is strictly increasing, so the arg-max of the raw scores
            # already identifies the predicted class; normalising first only
            # adds cost and can introduce rounding ties.
            if argmax(view(scores, :, i)) == label
                correct += 1
            end
            total += 1
        end
    end
    return correct / total
end

accuracy (generic function with 1 method)

In [13]:
# Train a 784-500-10 network with relu hidden units for 10 epochs. After each
# epoch, print the mean minibatch loss and the accuracy on the training and
# test sets.
net = Network(784, 500, relu, 500, 10)
for epoch in 1:10
    loss = 0 # running sum of the minibatch losses for this epoch
    for (x, y) in dtrn
        loss += trainNetwork!(net, x, y)
    end
    trnacc = accuracy(net, dtrn)
    tstacc = accuracy(net, dtst)
    # Average over the actual number of minibatches rather than a hard-coded
    # 600, so the reported loss stays correct if the batch size or dataset changes.
    @printf("epoch: %d loss: %g trn accuracy: %g tst accuracy: %g\n",
            epoch, loss / length(dtrn), trnacc, tstacc)
end

epoch: 1 loss: 9.96154 trn accuracy: 0.92345 tst accuracy: 0.9187
epoch: 2 loss: 2.63946 trn accuracy: 0.944083 tst accuracy: 0.9309
epoch: 3 loss: 1.64889 trn accuracy: 0.954833 tst accuracy: 0.9364
epoch: 4 loss: 1.12979 trn accuracy: 0.957783 tst accuracy: 0.9394
epoch: 5 loss: 0.819794 trn accuracy: 0.964217 tst accuracy: 0.9422
epoch: 6 loss: 0.612769 trn accuracy: 0.966683 tst accuracy: 0.9422
epoch: 7 loss: 0.462772 trn accuracy: 0.969517 tst accuracy: 0.9423
epoch: 8 loss: 0.362421 trn accuracy: 0.971683 tst accuracy: 0.9429
epoch: 9 loss: 0.276665 trn accuracy: 0.976017 tst accuracy: 0.9446
epoch: 10 loss: 0.219735 trn accuracy: 0.973667 tst accuracy: 0.9426
