In [39]:
using PaddedViews
using LinearAlgebra
using TimerOutputs
using BenchmarkTools
using Profile
to = TimerOutput();


In [40]:
abstract type Node end
abstract type Operator <: Node end

mutable struct Variable{N} <: Node
    name::String
    output::Array{Float64, N}
    gradient::Array{Float64, N}
    v₁::Array{Float64, N}
    v₂::Array{Float64, N}
    v̂₁::Array{Float64, N}
    v̂₂::Array{Float64, N}
    Variable(N, output; name = "?") = new{N}(name, output, zeros(size(output)), zeros(size(output)), zeros(size(output)), zeros(size(output)), zeros(size(output)))
end

mutable struct NodeOperator{F, N} <: Operator
    name::String
    inputs::Vector{Node}
    output::Array{Float64, N}
    gradient::Array{Float64, N}
    NodeOperator(fun, inputs...; name = "?", shape=(1,1,1)) = new{typeof(fun), length(shape)}(name, [inputs...], zeros(shape), zeros(shape))
end

mutable struct RNNOperator{F, N} <: Operator
    name::String
    h::Array{Float64,2}
    inputs::Vector{Node}
    output::Array{Float64, N}
    gradient::Array{Float64, N}
    RNNOperator(fun, h, inputs...; name = "?", shape=(1,1,1)) = new{typeof(fun), length(shape)}(name,h, [inputs...], zeros(shape), zeros(shape))
end

In [41]:

function visit(node::Node, visited::Set, order::Vector)
    if node ∉ visited
        push!(visited, node)
        push!(order, node)
    end
end

function visit(node::Operator, visited::Set, order::Vector)
    if node ∉ visited
        for input in node.inputs
            visit(input, visited, order)
        end
        push!(visited, node)
        push!(order, node)
    end
end


function create_graph(root::Node)
    visited = Set{Node}()
    order = Vector{Node}()
    visit(root, visited, order)
    return order
end

create_graph (generic function with 1 method)

In [42]:
import Base: show, summary
show(io::IO, x::NodeOperator{F}) where {F} = print(io, "op ", "(", F, ")");
show(io::IO, x::RNNOperator{F}) where {F} = print(io, "op ", "(", F, ")");
show(io::IO, x::Variable) = begin
    print(io, "var ", x.name);
    print(io, "\n ┣━ ^ "); summary(io, x.output)
    print(io, "\n ┗━ ∇ ");  summary(io, x.gradient)
end

show (generic function with 611 methods)

In [43]:
zero_gradient!(node::Node) = fill!(node.gradient, 0)

compute!(node::Variable) = nothing
compute!(node::Operator) = node.output .= forward(node, [input.output for input in node.inputs]...)

function forward!(order::Vector{Node})::Float64
    for node in order
        compute!(node)
        zero_gradient!(node)
    end
    
    return last(order).output[1]
end   

forward! (generic function with 1 method)

In [44]:
update!(node::Node, gradient) = if isempty(node.gradient)
    node.gradient = gradient else node.gradient .+= gradient
end

function backward!(order::Vector; seed = 1.0)
    result = last(order)
    result.gradient .= [seed]
    
    for node in reverse(order)
        backward!(node)
    end
end

backward!(node::Variable) = nothing

function backward!(node::Operator)
    backward(node, [input.output for input in node.inputs]..., node.gradient)
end

backward! (generic function with 3 methods)

In [45]:
struct RNNParams
    W::Variable{2}
    U::Variable{2}
    b::Variable{1}
end

struct DenseParams
    weights::Variable{2}
    bias::Variable{1}
end
struct NetworkParams
    rnn::RNNParams
    dense::DenseParams
end

In [46]:
recurent_layer(x::Node, h, W::Node,U::Node, b::Node, shape) = RNNOperator(recurent_layer,h,name="rnn",shape=shape,x,W,U,b)
@timeit to "recurent forward" forward(node::RNNOperator{typeof(recurent_layer)}, x, W, U, b) = let
   fill!(node.h,0)
    @views x_cur =  x[1:recsize]
    node.h[:,1]  = tanh.(W*node.h[:,1]+U*x_cur+b)   
    for t in range(2,layerNumber)
       @views x_cur =  x[(t-1)*(recsize)+1:(t)*(recsize)]
       node.h[:,t] = tanh.(W*node.h[:,t-1]+U*x_cur+b)   
    end
    return node.h[:,layerNumber]
end

@timeit to "recurent backward" backward(node::RNNOperator{typeof(recurent_layer)}, x, W, U, b, g) = let
   gradienth = zeros(size(node.gradient)[1],layerNumber)
   gradienth[:,layerNumber] = node.gradient
  gradientW =  node.inputs[2].gradient
  gradientU = node.inputs[3].gradient
  gradientB = node.inputs[4].gradient
 t = layerNumber-1
  while t >=1
    gradienth[:,t] = ((node.inputs[2].output'*gradienth[:,t+1])'*diagm(1 .- node.h[:,t+1].^2) )'
    diag = diagm(1 .- node.h[:,t].^2);
    t-=1
 end
 t = layerNumber
 while t>=1
   diag = diagm(1 .- node.h[:,t].^2);
    gradientU += diag*gradienth[:,t]*node.inputs[1].output[(t-1)*(recsize)+1:(t)*(recsize)]'
    gradientB += diag*gradienth[:,t]
    t-=1
 end
 t = layerNumber
 while t>1
   diag = diagm(1 .- node.h[:,t].^2);
   gradientW += diag*gradienth[:,t]*node.h[:,t-1]'
   t-=1
 end
      node.inputs[2].gradient = gradientW 
      node.inputs[3].gradient = gradientU
      node.inputs[4].gradient = gradientB
    return tuple(zeros(size(x)), gradientW, gradientU, gradientB)
end



backward (generic function with 4 methods)

In [47]:
dense_layer(x::Node, w::Node, b::Node, shape) = NodeOperator(dense_layer, name="dense", shape=shape, x, w, b)

@timeit to "dense forward" forward(::NodeOperator{typeof(dense_layer)}, x, w, b) = let
    return w * x + b
end

@timeit to "dense backward" backward(node::NodeOperator{typeof(dense_layer)}, x, w, b, g) = let
    node.inputs[1].gradient = w' * g
    node.inputs[2].gradient = g * x'
    node.inputs[3].gradient = g
end

backward (generic function with 4 methods)

In [48]:

softmax(x::Node, shape) = NodeOperator(softmax, name="softmax", shape=shape, x)

@timeit to "softmax forward" forward(::NodeOperator{typeof(softmax)}, x) = let
    return exp.(x) ./ sum(exp.(x))
end

@timeit to "softmax backward" backward(node::NodeOperator{typeof(softmax)}, x, g) = let
    y = node.output
    J = diagm(y) .- y * y'
    node.inputs[1].gradient = (J' * g)
end

backward (generic function with 4 methods)

In [49]:
cross_entropy_loss(ŷ::Node, y::Node, shape) = NodeOperator(cross_entropy_loss, name="cross_entropy_loss", shape=shape,  ŷ, y)

@timeit to "cross_entropy_loss forward" forward(::NodeOperator{typeof(cross_entropy_loss)}, ŷ, y) = let
    return sum((ŷ-y) .^ 2 ./ 10)
end

@timeit to "cross_entropy_loss backward" backward(node::NodeOperator{typeof(cross_entropy_loss)}, ŷ, y, g) = let
    x = zeros(10)
    node.inputs[1].gradient = (ŷ-y)/5
    node.inputs[2].gradient =[0.0]
end
     

backward (generic function with 4 methods)

In [50]:
function create_network(x::Variable{1}, y::Variable{1}, params::NetworkParams)
    x₁ = recurent_layer(x,zeros(64,layerNumber),params.rnn.W,params.rnn.U,params.rnn.b,(64))
    x₂ = dense_layer(x₁,params.dense.weights,params.dense.bias,(10))
    ŷ = softmax(x₂, (10))
    loss = cross_entropy_loss(ŷ, y, (1))
    return create_graph(loss)
end

create_network (generic function with 1 method)

In [51]:
function he_weights_init(prev, shape...)
    std = sqrt(2.0/prev)
    weights = rand(Float64, shape) .*2 .-1
    return weights .* std
end

he_weights_init (generic function with 1 method)

In [52]:

mutable struct Adam
    α::Float64
    ε::Float64
    m₁::Float64
    m₂::Float64
    k::Int64
    Adam(α=0.001, m₁=0.9, m₂=0.999, ε=1e-8) = new(α, ε, m₁, m₂, 1)
end

In [53]:
function update_weights!(graph, M::Adam)
    for node in graph
        if (typeof(node) == Variable{1} || typeof(node) == Variable{2}) && (node.name !="x"&&node.name!="y")
            update_weights_N!(node, M)
        end
    end
    M.k += 1
end



function update_weights_N!(node::Variable{T}, M::Adam) where T
    g = node.gradient
    v₁ = node.v₁
    v₂ = node.v₂
    v̂₁ = node.v̂₁
    v̂₂ = node.v̂₂
    m₁, m₂, k, α, ε = M.m₁, M.m₂, M.k, M.α, M.ε
    v₁ .= @. m₁ * v₁ + (1.0 - m₁) * g
    v₂ .= @. m₂ * v₂ + (1.0 - m₂) * (g .* g)

    v̂₁ .= v₁ ./ (1.0 - m₁^k)
    v̂₂ .= v₂ ./ (1.0 - m₂^k)

    node.output .-= @. α*v̂₁ / (sqrt(v̂₂) + ε)
    
    nothing
end

update_weights_N! (generic function with 1 method)

In [54]:
function validate(x, y, graph, test_data)::Float64
    correct = 0
    correct_class = zeros(10)
    (x_data,y_data) = loader(test_data)
    length = size(y_data)[2]
    for i in range(1,length)
        x.output = x_data[:,i]
        y.output = y_data[:,i]
        forward!(graph)
        pred = argmax(graph[9].output)
        if 1 == y_data[:,i][pred]
            correct += 1
        end
    end
    
    acc_val = correct/size(y_data)[2]
    
    return acc_val
end

validate (generic function with 1 method)

In [55]:
recsize = 14*14

recurent = RNNParams(
    Variable(2, he_weights_init(64, 64, 64), name="W1"),
    Variable(2, he_weights_init(64, 64, recsize), name="U1"),
    Variable(1, zeros(64), name="b1")
)
dense = DenseParams(
    Variable(2, he_weights_init(64, 10, 64), name="w2"),
    Variable(1, zeros(10), name="b2")
)
networkparams = NetworkParams(recurent,dense)

NetworkParams(RNNParams(var W1
 ┣━ ^ 64×64 Matrix{Float64}
 ┗━ ∇ 64×64 Matrix{Float64}, var U1
 ┣━ ^ 64×196 Matrix{Float64}
 ┗━ ∇ 64×196 Matrix{Float64}, var b1
 ┣━ ^ 64-element Vector{Float64}
 ┗━ ∇ 64-element Vector{Float64}), DenseParams(var w2
 ┣━ ^ 10×64 Matrix{Float64}
 ┗━ ∇ 10×64 Matrix{Float64}, var b2
 ┣━ ^ 10-element Vector{Float64}
 ┗━ ∇ 10-element Vector{Float64}))

In [56]:
# Opracowane na podstawie https://minpy.readthedocs.io/en/latest/tutorial/rnn_mnist.html
using MLDatasets, Flux
train_data = MLDatasets.MNIST(split=:train)
test_data  = MLDatasets.MNIST(split=:test)

function loader(data)
    x1dim = reshape(data.features, 28 * 28, :) # reshape 28×28 pixels into a vector of pixels
    yhot  = Flux.onehotbatch(data.targets, 0:9) # make a 10×60000 OneHotMatrix
    (x1dim, yhot)
end
(x_data,y_data) = loader(train_data)
x::Variable{1} = Variable(1, x_data[:,1], name="x")
y::Variable{1} = Variable(1, y_data[:,1], name="y")
layerNumber = floor(Int,size(x_data[:,1])[1]/recsize)
net = create_network(x,y,networkparams)

11-element Vector{Node}:
 var x
 ┣━ ^ 784-element Vector{Float64}
 ┗━ ∇ 784-element Vector{Float64}
 var W1
 ┣━ ^ 64×64 Matrix{Float64}
 ┗━ ∇ 64×64 Matrix{Float64}
 var U1
 ┣━ ^ 64×196 Matrix{Float64}
 ┗━ ∇ 64×196 Matrix{Float64}
 var b1
 ┣━ ^ 64-element Vector{Float64}
 ┗━ ∇ 64-element Vector{Float64}
 op (typeof(recurent_layer))
 var w2
 ┣━ ^ 10×64 Matrix{Float64}
 ┗━ ∇ 10×64 Matrix{Float64}
 var b2
 ┣━ ^ 10-element Vector{Float64}
 ┗━ ∇ 10-element Vector{Float64}
 op (typeof(dense_layer))
 op (typeof(softmax))
 var y
 ┣━ ^ 10-element Vector{Float64}
 ┗━ ∇ 10-element Vector{Float64}
 op (typeof(cross_entropy_loss))

In [57]:
using ProgressMeter
loss::Float64 = 0.0
pred::UInt8 = 0
epochs = 5
losses = zeros(epochs)
acc = zeros(epochs)
test_acc = zeros(epochs)
correct = 0 
adam = Adam()
for epoch in 1:epochs
    loss = 0
    correct = 0
    length = size(y_data)[2]
    @time for i in range(1,length)
         x.output .= @view x_data[:,i]
         y.output .= @view y_data[:,i]
        @timeit to "forward" loss += forward!(net)
        prob::NodeOperator{typeof(softmax), 1} = net[9]
        pred = argmax(prob.output)
        if 1 == y_data[:,i][pred]
            correct += 1
        end
        @timeit to "backward" backward!(net)
        
        @timeit to "update weights" update_weights!(net, adam)
    end
    losses[epoch] = loss/length
    acc[epoch] = correct/length
    test_acc[epoch] = validate(x,y,net,test_data)
    println("Epoch: ", epoch, "\tAverage loss: ", round(losses[epoch], digits=3), "\tAverage acc: ", round(acc[epoch],digits=3),"\tAverage test acc: ",round(test_acc[epoch],digits=3))
end
show(to)
reset_timer!(to);

 46.780334 seconds (37.79 M allocations: 84.525 GiB, 9.71% gc time, 2.98% compilation time)
Epoch: 1	Average loss: 0.014	Average acc: 0.904	Average test acc: 0.906
 42.950336 seconds (37.07 M allocations: 84.478 GiB, 8.79% gc time)
Epoch: 2	Average loss: 0.012	Average acc: 0.922	Average test acc: 0.912
 43.120719 seconds (37.07 M allocations: 84.478 GiB, 9.07% gc time)
Epoch: 3	Average loss: 0.012	Average acc: 0.92	Average test acc: 0.916
 42.568344 seconds (37.07 M allocations: 84.478 GiB, 8.91% gc time)
Epoch: 4	Average loss: 0.012	Average acc: 0.923	Average test acc: 0.915
 41.379916 seconds (37.07 M allocations: 84.478 GiB, 8.83% gc time)
Epoch: 5	Average loss: 0.012	Average acc: 0.923	Average test acc: 0.921
[0m[1m ────────────────────────────────────────────────────────────────────────────────[22m
[0m[1m                               [22m         Time                    Allocations      
                               ───────────────────────   ────────────────────────
     