### Implementing Long Short-Term Memory to detect and classify Parkinsons' Freezing of Gait types in time series data

In [400]:
using Pkg

# Pkg.add("CSV")
# Pkg.add("NNlib")
# Pkg.add("DataFrames")
# Pkg.add("Distributions")
# Pkg.add("ResumableFunctions")
# Pkg.add("Flux")


In [401]:
using Flux
using Flux: @epochs, batch, throttle

using CSV
using NNlib
using DataFrames
using Distributions
using ResumableFunctions
using Random

In [402]:
parkinson = CSV.read("./filtered_data.csv", DataFrame)

Row,Id,Subject,Visit,Test,Medication,Time,AccV,AccML,AccAP,StartHesitation,Turn,Walking,event
Unnamed: 0_level_1,String15,String7,Int64,Int64,String3,Int64,Float64,Float64,Float64,Int64,Int64,Int64,String7
1,009ee11563,f62eec,4,2,on,0,-9.4173,0.767819,-1.75824,0,0,0,Normal
2,009ee11563,f62eec,4,2,on,1,-9.4251,0.768246,-1.75058,0,0,0,Normal
3,009ee11563,f62eec,4,2,on,2,-9.41995,0.779039,-1.74259,0,0,0,Normal
4,009ee11563,f62eec,4,2,on,3,-9.42127,0.772523,-1.74651,0,0,0,Normal
5,009ee11563,f62eec,4,2,on,4,-9.42811,0.777142,-1.75555,0,0,0,Normal
6,009ee11563,f62eec,4,2,on,5,-9.42602,0.774812,-1.76021,0,0,0,Normal
7,009ee11563,f62eec,4,2,on,6,-9.42569,0.768126,-1.7736,0,0,0,Normal
8,009ee11563,f62eec,4,2,on,7,-9.43509,0.765923,-1.76937,0,0,0,Normal
9,009ee11563,f62eec,4,2,on,8,-9.43209,0.770584,-1.77557,0,0,0,Normal
10,009ee11563,f62eec,4,2,on,9,-9.4295,0.775137,-1.77076,0,0,0,Normal


In [403]:
@resumable function data_loader(parkinson_dataframe, num_sequences, num_selected_sequences, batch_size; labels=["StartHesitation", "Turn", "Walking", "Normal"])
    pdf = deepcopy(parkinson_dataframe)
    
    # Sending to the model
    for i in 1:num_selected_sequences
        # if benne van a random generalt elemek tombunkben
        x = hcat(
            pdf[!, "AccV"][sequences[i] * batch_size : sequences[i] * batch_size + batch_size],
            pdf[!, "AccML"][sequences[i] * batch_size : sequences[i] * batch_size + batch_size],
            pdf[!, "AccAP"][sequences[i] * batch_size : sequences[i] * batch_size + batch_size]
        )
    
        y = Flux.onehotbatch(pdf[!, "event"][sequences[i] * batch_size : sequences[i] * batch_size + batch_size], labels)
        @yield x, y
    end
end

data_loader (generic function with 1 method)

In [404]:
function init_params(in::Integer, out::Integer; mean=0.0, std=1.0)
    [
        in, out,
        rand(Truncated(Normal(mean, std), -1, 1), out), # Wf
        rand(Truncated(Normal(mean, std), -1, 1), out), # Wi
        rand(Truncated(Normal(mean, std), -1, 1), out), # Wc
        rand(Truncated(Normal(mean, std), -1, 1), out), # Wo
        rand(Truncated(Normal(mean, std), -1, 1), out), # bf
        rand(Truncated(Normal(mean, std), -1, 1), out), # bi
        rand(Truncated(Normal(mean, std), -1, 1), out), # bc
        rand(Truncated(Normal(mean, std), -1, 1), out), # bo
        rand(Truncated(Normal(mean, std), -1, 1), out), # bo
        rand(Truncated(Normal(mean, std), -1, 1), out), # bo
        rand(Truncated(Normal(mean, std), -1, 1), out), # bo
        rand(Truncated(Normal(mean, std), -1, 1), out), # bo

        # both the Long-Term and Short-Term memories are initialized with 0 values
        zeros(out),  # c
        zeros(out),  # h
        true # update_memory
    ]
end


init_params (generic function with 1 method)

In [405]:
# implementing the forwarding method which is used in the Chaining process
function forward(x, lstm)

    for i in 1:size(x, 1)
        long_remember_percent = NNlib.sigmoid_fast((lstm.c .* lstm.wlr1) + (x[i, :] .* lstm.wlr2) + lstm.blr1)
        potential_remember_percent = NNlib.sigmoid_fast((lstm.h .* lstm.wpr1) + (x[i, :] .* lstm.wpr2) + lstm.bpr1)
        potential_memory = NNlib.tanh_fast((lstm.c .* lstm.wp1) + (x[i, :] .* lstm.wp2) + lstm.bp1)
        updated_long_memory = (lstm.c .* long_remember_percent) + (potential_memory .* potential_remember_percent)
        output_percent = NNlib.sigmoid_fast((lstm.h .* lstm.wo1) + (x[i, :] .* lstm.wo2) + lstm.bo1)
        updated_short_memory = NNlib.tanh_fast(updated_long_memory) .* output_percent

        # updating the memory
        if lstm.update_memory
            lstm.c, lstm.h = updated_long_memory, updated_short_memory
        end
    end

    lstm.c
end

forward (generic function with 1 method)

In [406]:
# custom Long Short-Term Memory layer
mutable struct LSTM
    # input and output size of the layer
    in::Integer
    out::Integer

    wlr1::Vector{Float32}
    wlr2::Vector{Float32}
    blr1::Vector{Float32}

    wpr1::Vector{Float32}
    wpr2::Vector{Float32}
    bpr1::Vector{Float32}

    wp1::Vector{Float32}
    wp2::Vector{Float32}
    bp1::Vector{Float32}

    wo1::Vector{Float32}
    wo2::Vector{Float32}
    bo1::Vector{Float32}


    # cell state (aka. long-term memory) and hidden state (aka. short-term memory)
    c::Vector{Float32}
    h::Vector{Float32}

    # prevents the model from modifying the memory state in case of testing and loss calculation
    # it needs to be set explicitely before and after calling the model(x) funtion
    update_memory::Bool
end

In [407]:
# defining the constructor
LSTM(in::Integer, out::Integer) = LSTM(init_params(in, out)...)

LSTM

In [408]:
# Overload call, so the object can be used as a function
(lstm::LSTM)(x) = forward(x, lstm)

In [409]:
# creating a functor from the struct, so that the training can optimize its parameters
Flux.@functor LSTM

In [410]:
# creating the Long Short-Term Memory layer
function LSTM((in, out)::Pair)
    LSTM(in, out) # constructor
end

LSTM

In [411]:
# explicitely defining the trainable parameters of the layer
# all the Wrights and Biases are trainable
# exceptions >> Cell State and Hidden State
# Flux.trainable(lstm::LSTM) = (lstm.Wf, lstm.Wi, lstm.Wc, lstm.Wo, lstm.bf, lstm.bi, lstm.bc, lstm.bo,)
Flux.trainable(lstm::LSTM) = (lstm.wlr1, lstm.wlr2, lstm.blr1, lstm.wpr1, lstm.wpr2, lstm.bpr1, lstm.wp1, lstm.wp2, lstm.bp1, lstm.wo1, lstm.wo2, lstm.bo1)

In [412]:
# Random generating sequences
num_sequences = 10000
batch_size = div(size(parkinson, 1), num_sequences)
@show batch_size
sequences = collect(1:num_sequences)
@show sequences
shuffle!(sequences)
@show sequences
print(batch_size * sequences[1] : batch_size * sequences[1] + batch_size)

input_size = batch_size
hidden_size = 3
num_classes = 4 

model = Chain(
    LSTM(input_size => hidden_size),
    Dense(hidden_size => num_classes),
    softmax
)

;

batch_size = 52
sequences = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216,

In [413]:
function loss(x, y)
    model[1].update_memory = false
    l = Flux.crossentropy.(eachrow(model(x)'), eachrow(y'))
    model[1].update_memory = true

    mean(l)
end

optimizer = ADAM(0.001)
epochs = 10


;

In [414]:
for epoch in 1:epochs
  for (input, output) in data_loader(parkinson, num_sequences, 1, batch_size) #[(1, [0,0,1,0]), (1, [0,0,1,0])]
    grads = Flux.gradient(Flux.params(model)) do
      loss(input, output)
    end
    Flux.update!(optimizer, Flux.params(model), grads)

    break
  end

end

"loss" = "loss"
"loss" = "loss"
"loss" = "loss"
"loss" = "loss"
"loss" = "loss"
"loss" = "loss"
"loss" = "loss"
"loss" = "loss"
"loss" = "loss"
"loss" = "loss"


In [415]:
model[1].update_memory = false
model([0,-50,90])

4-element Vector{Float32}:
 0.24874395
 0.24874395
 0.24874395
 0.2537681