## Imports

In [1]:
using MLDatasets
using Random
import Pkg
using Flux: onehotbatch
using PoissonRandom
using Distributions
using Statistics
using Printf

In [2]:
# Test-set sizes read by `train` when it calls `test_weights`.
n_full_test  = 10000 # number of examples to use for full tests  (every epoch)
n_quick_test = 100   # number of examples to use for quick tests (every 1000 examples)

100

## Global Variables

In [3]:
# global variables
# l_f_phase and l_f_phase_test are passed to `test_weights` by `train`;
# l_t_phase is not referenced anywhere else in the visible code.
l_f_phase      = 2  # length of forward phase (time steps)
l_t_phase      = 2  # length of target phase (time steps)
l_f_phase_test = 2  # length of forward phase for tests (time steps)
dt          = 1.0    # time step (ms)
# lambda_max scales both the network inputs and every sigmoid output (firing rates).
lambda_max  = 0.2*dt # maximum spike rate (spikes per time step)
# integration_time = 1 # time steps of integration of neuronal variables used for plasticity
# The integration time is not used here: it would make indexing start from 0,
# whereas Julia indices start from 1, so 1 is used directly instead of 0.

0.2

In [4]:
# kernel parameters
tau_L = 10.0 # leak time constant

# conductance parameters
g_B = 0.6                                   # basal conductance
g_L = 1.0/tau_L                             # leak conductance
g_D = g_B                                   # dendritic conductance in output layer

# Reversal potentials; they enter the final layer's somatic potential in the target phase.
E_E = 8                                     # excitation reversal potential
E_I = -8                                    # inhibition reversal potential

# steady state constants
# k_B scales the hidden layer's basal potential into its somatic potential;
# k_D does the same for the final layer during the forward phase.
k_B = g_B/(g_L + g_B)
k_D = g_D/(g_L + g_D)
# NOTE(review): this global k_I is not read by any function in the visible code
# (FinalLayer carries its own k_I field) - confirm whether it is still needed.
k_I = 1.0/(g_L + g_D)

# weight update constants
# Both factors multiply the learning rate in the target-phase weight and bias updates.
P_hidden = 20.0/lambda_max      # hidden layer error signal scaling factor
P_final  = 20.0/(lambda_max^2)  # final layer error signal scaling factor

499.9999999999999

## Functions

In [5]:
"""
    sigma(x)

Return the logistic sigmoid of the scalar `x`, i.e. `1 / (1 + exp(-x))`.
Broadcast it with `sigma.(xs)` to apply it element-wise to an array.
"""
function sigma(x)
    denominator = 1 + exp(-x)
    return 1 / denominator
end

sigma (generic function with 1 method)

In [6]:
"""
    deriv_sigma(x)

Return the derivative of the logistic sigmoid at the scalar `x`.

The derivative is an even function, so it is evaluated at `-abs(x)`. This keeps
the exponential in `(0, 1]`: the direct form `exp(-x) / (1 + exp(-x))^2`
overflows to `Inf / Inf = NaN` for very negative `x`, while this form returns 0.
For `x >= 0` the result is identical to the direct form.
"""
function deriv_sigma(x)
    decay = exp(-abs(x))   # always in (0, 1], cannot overflow
    return decay / (1 + decay)^2
end

deriv_sigma (generic function with 1 method)

## Objects

In [7]:
# State of the hidden layer. All fields are untyped; the array fields are
# (re)assigned by the network's forward and target phases.
mutable struct HiddenLayer
    m # The layer number, eg. m = 1 for the first hidden layer.
    B        # feedforward input: W[m]*x .+ b[m]
    A        # feedback input: Y[m] * (final layer's lambda_C)
    C        # k_B * B
    lambda_C # firing rate: lambda_max * sigma.(C)
    alpha_f  # sigma.(A) computed in the forward phase (plateau potential)
    alpha_t  # sigma.(A) computed in the target phase (plateau potential)
    E        # error term: (alpha_t - alpha_f) .* (-k_B * lambda_max * deriv_sigma.(C))
    delta_W  # weight gradient: E * x'
    delta_b  # bias gradient: E
end

In [8]:
# State of the final (output) layer. All fields are untyped; the array fields
# are (re)assigned by the network's forward and target phases.
mutable struct FinalLayer
    m # The layer number, ie. m = M where M is the total number of layers.
    lambda_C # firing rate: lambda_max * sigma.(C)
    B        # feedforward input: W[m] * (hidden layer's lambda_C) .+ b[m]
    C        # forward phase: k_D .* B; target phase: k_D2 .* B + k_E*E_E + k_I*E_I
    C_f      # C recorded at the end of the forward phase; zeroed at the end of the target phase
    k_D2     # g_D ./ (g_L + g_D + g_E + g_I), with g_E = target and g_I = 1 - target
    k_E      # g_E ./ (g_L + g_D + g_E + g_I)
    k_I      # g_I ./ (g_L + g_D + g_E + g_I)
    E        # error term computed in the target phase from lambda_C and C_f
    delta_W  # weight gradient: E * (hidden layer's lambda_C)'
    delta_b  # bias gradient: E
end

In [9]:
# The whole network: one hidden layer, one final layer and the shared parameters.
mutable struct Network
    n # Tuple - Number of units in each layer of the network, eg. (500, 10) here.
    hiddenLayer::HiddenLayer
    finalLayer::FinalLayer
    W      # feedforward weight matrices, indexed by layer number
    b      # bias columns, indexed by layer number
    Y      # feedback weight matrices; Y[m] maps final-layer rates to hidden layer m
    f_etas # learning rates for the feedforward weights, indexed by layer number
    loss   # mean squared difference between target-phase and forward-phase output rates
end

## Initialization Sub-Functions

In [10]:
"""
    initializeWbY(ln, n, n_in)

Create the randomly initialised parameters of a network with `ln` layers.

# Arguments
- `ln`: number of layers to initialise (`1 <= ln <= length(n)`).
- `n`: number of units in each layer, eg. `(500, 10)`.
- `n_in`: size of the input fed to the first layer, eg. `784`.

# Returns
`(W, b, Y)`, where `W[m]` is an `(n[m], N)` feedforward weight matrix
(`N` = units in the layer below, or `n_in` for `m == 1`), `b[m]` is an
`(n[m], 1)` bias column, and `Y[m-1]` (for `m >= 2`) is an `(N, n[end])`
feedback weight matrix.

# Throws
- `ArgumentError` if `ln` is not in `1:length(n)`.
"""
function initializeWbY(ln, n, n_in)
    1 <= ln <= length(n) ||
        throw(ArgumentError("ln must be between 1 and length(n) = $(length(n)), got $ln"))

    W = Array{Any}(undef, ln)
    b = Array{Any}(undef, ln)
    Y = Array{Any}(undef, ln-1)

    # weight optimization parameters
    V_avg = 3   # desired average of dendritic potential
    V_sd  = 3   # desired standard deviation of dendritic potential
    b_avg = 0.8  # desired average of bias
    b_sd  = 0.001 # desired standard deviation of bias
    nu    = lambda_max*0.25  # slope of linear region of activation function
    V_sm  = V_sd^2 + V_avg^2 # second moment of dendritic potential

    # Walk from the top layer down to the first one. The loop previously ran over a
    # hard-coded `2:-1:1`, which only initialised every layer when ln == 2.
    for m in ln:-1:1
        # get number of units in the layer below
        N = m != 1 ? n[m-1] : n_in

        # generate feedforward weights & biases
        # calculate weight variables needed to get desired average & standard deviations of somatic potentials
        W_avg = (V_avg - b_avg)/(nu*N*V_avg)
        W_sm  = (V_sm + (nu^2)*(N - N^2)*(W_avg^2)*(V_avg^2) - 2*N*nu*b_avg*V_avg*W_avg - (b_avg^2))/(N*(nu^2)*V_sm)
        W_sd  = sqrt(W_sm - W_avg^2)

        # NOTE(review): 3.465 is close to sqrt(12), which normalises a unit-width uniform;
        # Uniform(-1, 1) has width 2 - confirm the resulting spread is the intended one.
        W[m] = W_avg .+ 3.465*W_sd*rand(Uniform(-1,1), n[m], N) # eg. (500,784) , (10,500)
        b[m] = b_avg .+ 3.465*b_sd*rand(Uniform(-1,1), n[m], 1) # eg. (500,1) , (10,1)

        # feedback weights from the final layer to the layer below m
        if m != 1
            Y[m-1] = W_avg .+ 3.465*W_sd*rand(Uniform(-1,1), N, n[end]) # eg. (500,10)
        end
    end

    return W, b, Y
end

initializeWbY (generic function with 1 method)

## Initializations

In [11]:
# Build a hidden layer whose state arrays all start at zero.
# m            - The layer number; n[m] is the number of units in this layer (eg. 500).
# n            - Tuple with the number of units in each layer of the network.
# f_input_size - The size of feedforward input, eg. 784 for MNIST input. Not used in the body.
# b_input_size - The size of feedback input, i.e. the number of units in the next layer. Not used in the body.
# size_W       - Size of this layer's feedforward weight matrix; delta_W is allocated with it.
function HiddenLayer(m, n, f_input_size, b_input_size, size_W) # m=1
    n_units = n[m]
    # every call allocates a fresh, independent (n_units, 1) column of zeros
    zero_column() = zeros(n_units, 1)

    B        = zero_column()
    A        = zero_column()
    C        = zero_column()
    lambda_C = zero_column()
    alpha_f  = zero_column()
    alpha_t  = zero_column()
    E        = zero_column()
    delta_W  = zeros(size_W)
    delta_b  = zero_column()

    return HiddenLayer(m, B, A, C, lambda_C, alpha_f, alpha_t, E, delta_W, delta_b)
end

HiddenLayer

In [12]:
# Build the final layer with all of its state arrays starting at zero.
# m            - The layer number; n[m] is the number of units in this layer (eg. 10).
# n            - Tuple with the number of units in each layer of the network.
# f_input_size - The size of feedforward input, i.e. the number of units in the previous layer. Not used in the body.
# size_W       - Size of this layer's feedforward weight matrix; delta_W is allocated with it.
function FinalLayer(m, n, f_input_size, size_W) # m=2
    n_units = n[m]
    # every call allocates a fresh, independent (n_units, 1) column of zeros
    zero_column() = zeros(n_units, 1)

    lambda_C = zero_column()
    B        = zero_column()
    C        = zero_column()
    C_f      = zero_column()
    k_D2     = zero_column()
    k_E      = zero_column()
    k_I      = zero_column()
    E        = zero_column()
    delta_W  = zeros(size_W)
    delta_b  = zero_column()

    return FinalLayer(m, lambda_C, B, C, C_f, k_D2, k_E, k_I, E, delta_W, delta_b)
end

FinalLayer

In [13]:
# Build a network with one hidden layer (layer 1) and one final layer (layer 2).
# n    - Tuple with the number of units in each layer, eg. (500, 10).
# n_in - Size of the input, eg. 784.
function Network(n, n_in)
    n_layers = length(n)
    W, b, Y = initializeWbY(n_layers, n, n_in)

    hidden = HiddenLayer(1, n, n_in, n[end], size(W[1]))
    output = FinalLayer(2, n, n[end-1], size(W[2]))

    learning_rates = (0.21, 0.21)   # learning rates for each layer's feedforward weights
    initial_loss = 0

    return Network(n, hidden, output, W, b, Y, learning_rates, initial_loss)
end

Network

## Callable Objects

In [14]:
# net f_phase
# Perform a forward phase: propagate x up to the final layer, feed the final
# layer's rates back to the hidden layer, and record the final layer's C as C_f.
# x : Input array of size (X, 1) where X is the size of the input, eg. (784, 1).
# Returns the final layer's C (the same array that is stored in C_f).
function (net::Network)(x)
    hid = net.hiddenLayer
    out = net.finalLayer

    # hidden layer feedforward pass, eg. (500,784) * (784,1) .+ (500,1)
    hid.B = net.W[hid.m]*x .+ net.b[hid.m]
    hid.C = k_B*hid.B
    hid.lambda_C = lambda_max*sigma.(hid.C)

    # final layer feedforward pass, eg. (10,500) * (500,1) .+ (10,1)
    out.B = net.W[out.m]*hid.lambda_C .+ net.b[out.m]
    out.C = k_D.*out.B
    out.lambda_C = lambda_max*sigma.(out.C)

    # feedback from the final layer, eg. (500,10) * (10,1)
    hid.A = net.Y[hid.m]*out.lambda_C

    # calculate plateau potentials for hidden layer neurons
    hid.alpha_f = sigma.(hid.A)

    # record C(soma) forward potential; C_f refers to the same array as C
    out.C_f = out.C
    return out.C_f
end

In [15]:
# net t_phase
# Perform a target phase: repeat the feedforward pass with the final layer driven
# towards the target, then update both layers' weights and biases and the loss.
# x : Input array of size (X, 1) where X is the size of the input, eg. (784, 1).
# t : Target array of size (T, 1) where T is the size of the target, eg. (10, 1).
# Uses the C_f recorded by the preceding forward phase and zeroes it before returning.

function (net::Network)(x, t)
    hid = net.hiddenLayer
    out = net.finalLayer

    # hidden layer feedforward pass, eg. (500,784) * (784,1) .+ (500,1)
    hid.B = net.W[hid.m]*x .+ net.b[hid.m]
    hid.C = k_B*hid.B
    hid.lambda_C = lambda_max*sigma.(hid.C)

    # final layer feedforward input, eg. (10,500) * (500,1) .+ (10,1)
    out.B = net.W[out.m]*hid.lambda_C .+ net.b[out.m]

    # the target sets the excitatory conductance; the inhibitory one is its complement
    g_E = t
    g_I = -g_E .+ 1
    g_total = (g_L + g_D) .+ g_E + g_I
    out.k_D2 = g_D./g_total
    out.k_E  = g_E./g_total
    out.k_I  = g_I./g_total

    out.C = out.k_D2.*out.B + out.k_E*E_E + out.k_I*E_I
    out.lambda_C = lambda_max*sigma.(out.C)

    # feedback from the final layer, eg. (500,10) * (10,1)
    hid.A = net.Y[hid.m]*out.lambda_C

    # calculate plateau potentials for hidden layer neurons
    hid.alpha_t = sigma.(hid.A)

    # difference between the target-phase rate and the rate implied by the forward-phase C_f
    rate_gap = out.lambda_C - lambda_max*sigma.(out.C_f)

    # update final layer weights & biases, eg. (10,1)*(1,500) = (10,500)
    out.E = rate_gap.*((-k_D)*lambda_max*deriv_sigma.(out.C_f))
    out.delta_W = out.E * hid.lambda_C'
    net.W[out.m] += -net.f_etas[out.m]*P_final*out.delta_W
    out.delta_b = out.E
    net.b[out.m] += -net.f_etas[out.m]*P_final*out.delta_b

    # update hidden layer weights & biases, eg. (500,1)*(1,784) = (500,784)
    hid.E = (hid.alpha_t - hid.alpha_f).*((-k_B)*lambda_max*deriv_sigma.(hid.C))
    hid.delta_W = hid.E * x'
    net.W[hid.m] += -net.f_etas[hid.m]*P_hidden*hid.delta_W
    hid.delta_b = hid.E
    net.b[hid.m] += -net.f_etas[hid.m]*P_hidden*hid.delta_b

    net.loss = mean(rate_gap.^2)

    # reset averages (in place)
    out.C_f .*= 0
    return out.C_f
end

## Training Sub-Functions

In [16]:
"""
    clear_vars(net::Network)

Reset the per-example state arrays of both layers of `net` to zero, in place.

The arrays are overwritten with `fill!` rather than multiplied by zero:
`NaN * 0` and `Inf * 0` are both `NaN`, so multiplying would leave (or even
create) non-finite entries instead of clearing them.

Returns the final layer's `C_f` array.
"""
function clear_vars(net::Network)
    hidden = net.hiddenLayer
    final = net.finalLayer

    state_arrays = (
        hidden.A, hidden.B, hidden.C, hidden.lambda_C,
        hidden.E, hidden.delta_W, hidden.delta_b,
        hidden.alpha_f, hidden.alpha_t,
        final.B, final.C, final.lambda_C,
        final.E, final.delta_W, final.delta_b,
        final.C_f,
    )

    # Some fields may refer to the same array (eg. delta_b and E after a target
    # phase); zeroing such an array twice is harmless.
    for arr in state_arrays
        fill!(arr, 0)
    end

    return final.C_f
end

clear_vars (generic function with 1 method)

In [17]:
"""
    test_weights(net::Network, n_test, x_test, t_test, l_f_phase, l_f_phase_test)

Measure the classification error of the network's current weights on `n_test`
randomly chosen test examples and return it as a percentage.

Weights and biases are not modified, but the layers' state arrays are cleared
(via `clear_vars`) before every example and once more before returning; the
previous state is not restored.

# Arguments
- `net`: the network to evaluate.
- `n_test`: number of test examples to use (`1 <= n_test <= size(x_test, 2)`).
- `x_test`: test inputs, one example per column.
- `t_test`: test targets, one example per column.
- `l_f_phase`, `l_f_phase_test`: accepted for interface compatibility; the forward
  phase does not depend on them, so they are not used.

# Throws
- `ArgumentError` if `n_test` is outside `1:size(x_test, 2)`.
"""
function test_weights(net::Network, n_test, x_test, t_test, l_f_phase, l_f_phase_test)
    n_available = size(x_test, 2)
    1 <= n_test <= n_available ||
        throw(ArgumentError("n_test must be between 1 and $n_available, got $n_test"))

    # Pick the examples through a random permutation of the column indices rather
    # than materialising shuffled copies of the whole test set on every call.
    order = randperm(n_available)

    # initialize count of correct classifications
    num_correct = 0

    for j in view(order, 1:n_test)
        # clear all layer variables
        clear_vars(net)

        # get testing example data
        x = lambda_max*x_test[:, j]
        t = t_test[:, j]

        # do a forward phase & get the unit with maximum somatic potential
        net(x) # net f_phase
        sel_num = argmax(net.finalLayer.C_f)[1]

        # get the target number from testing example data
        target_num = argmax(t)

        # increment correct classification counter if they match
        if sel_num == target_num
            num_correct += 1
        end
    end

    # calculate percent error
    err_rate = (1.0 - num_correct/n_test)*100.0

    clear_vars(net)

    return err_rate
end

test_weights (generic function with 1 method)

## Training

In [18]:
"""
    load_MNIST()

Load the MNIST training and test sets through `MLDatasets`.

Returns `(x_train, t_train, x_test, t_test)`: the images converted to feature
matrices of `Float64`, and the labels one-hot encoded over the digits 0 to 9.
"""
function load_MNIST()
    digit_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    x_train = MNIST.convert2features(MNIST.traintensor(Float64))
    t_train = onehotbatch(MNIST.trainlabels(), digit_labels)

    x_test = MNIST.convert2features(MNIST.testtensor(Float64))
    t_test = onehotbatch(MNIST.testlabels(), digit_labels)

    return x_train, t_train, x_test, t_test
end

load_MNIST (generic function with 1 method)

In [19]:
"""
    train(net::Network, n_epochs = 10, n_training_examples = 60000)

Train `net` on MNIST for `n_epochs` epochs, using the first
`n_training_examples` examples of a fresh shuffle of the training set in each
epoch. Every example gets a forward phase followed by a target phase.

Progress is printed every 1000 examples and at the end of each epoch:
- `QE`: error (%) of a quick test on `n_quick_test` test examples;
- `FE`: error (%) of a full test on `n_full_test` test examples (end of epoch);
- `TE`: training error (%) accumulated over the epoch (end of epoch);
- `T`: seconds elapsed since the previous report.

The end-of-epoch report is produced even when `n_training_examples` is not a
multiple of 1000, so the training-error counter is always reset between epochs.

Returns `nothing`.
"""
function train(net::Network, n_epochs = 10, n_training_examples = 60000)

    x_train, t_train, x_test, t_test = load_MNIST()

    current_epoch = 1

    # don't do an initial weight test
    println("Start of epoch $(current_epoch)")

    # start time used for timing how long each 1000 examples take
    start_time = nothing

    # number of correctly classified training examples in the current epoch
    num_correct = 0

    for _ in 1:n_epochs
        # shuffle the training data
        order = randperm(size(x_train)[2]) #randperm(60000)
        x_train = x_train[:,order]
        t_train = t_train[:,order]

        for i in 1:n_training_examples
            # set start time
            if start_time === nothing
                start_time = time()
            end

            # get training example data
            x = lambda_max*x_train[:,i]
            t = t_train[:,i]

            # do forward phase
            net(x) # net f_phase

            sel_num = argmax(net.finalLayer.C_f)[1]

            # get the target number from training example data
            target_num = argmax(t)

            # increment correct classification counter if they match
            if sel_num == target_num
                num_correct += 1
            end

            # do target phase
            net(x, t) # net t_phase

            # Report every 1000 examples, and always on the last example of the epoch,
            # so the full test and the counter reset are never skipped.
            at_epoch_end = (i == n_training_examples)
            if i % 1000 == 0 || at_epoch_end
                if !at_epoch_end
                    # we're partway through an epoch; do a quick weight test
                    test_err = test_weights(net, n_quick_test, x_test, t_test, l_f_phase, l_f_phase_test)
                    @printf("Epoch %d, example %d/%d. QE: %g ", current_epoch, i, n_training_examples, test_err)
                else
                    # we've reached the end of an epoch; do a full weight test
                    test_err = test_weights(net, n_full_test, x_test, t_test, l_f_phase, l_f_phase_test)
                    @printf("FE: %g ", test_err)

                    # calculate percent training error for this epoch
                    err_rate = (1.0 - num_correct/n_training_examples)*100.0
                    @printf("TE: %g ", err_rate)

                    num_correct = 0
                end
                # get end time & reset start time
                end_time = time()
                time_elapsed = end_time - start_time
                @printf("T: %g\n", time_elapsed)
                start_time = nothing
            end
        end
        # update latest epoch counter
        current_epoch += 1
    end

    return nothing
end

train (generic function with 3 methods)

In [20]:
# Build a network with 500 hidden units and 10 output units for 784-element inputs,
# then train it with the default number of epochs and training examples.
n = (500, 10) # units per layer; for now everything is built for a single hidden layer
n_in  = 784   # size of the input
net = Network(n, n_in)
train(net)

Start of epoch 1
Epoch 1, example 1000/60000. QE: 64 T: 6.71466
Epoch 1, example 2000/60000. QE: 22 T: 1.9816
Epoch 1, example 3000/60000. QE: 19 T: 1.93428
Epoch 1, example 4000/60000. QE: 23 T: 1.89175
Epoch 1, example 5000/60000. QE: 18 T: 1.96241
Epoch 1, example 6000/60000. QE: 16 T: 1.89409
Epoch 1, example 7000/60000. QE: 14 T: 1.90264
Epoch 1, example 8000/60000. QE: 8 T: 1.93887
Epoch 1, example 9000/60000. QE: 13 T: 1.93291
Epoch 1, example 10000/60000. QE: 18 T: 1.93048
Epoch 1, example 11000/60000. QE: 12 T: 1.91588
Epoch 1, example 12000/60000. QE: 11 T: 1.87871
Epoch 1, example 13000/60000. QE: 8 T: 1.93153
Epoch 1, example 14000/60000. QE: 12 T: 1.91759
Epoch 1, example 15000/60000. QE: 11 T: 1.92545
Epoch 1, example 16000/60000. QE: 9 T: 1.90236
Epoch 1, example 17000/60000. QE: 9 T: 1.91745
Epoch 1, example 18000/60000. QE: 8 T: 1.9211
Epoch 1, example 19000/60000. QE: 6 T: 1.89625
Epoch 1, example 20000/60000. QE: 9 T: 1.92993
Epoch 1, example 21000/60000. QE: 12 T: 1

Epoch 3, example 56000/60000. QE: 1 T: 1.87293
Epoch 3, example 57000/60000. QE: 4 T: 1.90584
Epoch 3, example 58000/60000. QE: 4 T: 1.90877
Epoch 3, example 59000/60000. QE: 8 T: 1.88537
FE: 4.24 TE: 4.21167 T: 5.91066
Epoch 4, example 1000/60000. QE: 3 T: 1.854
Epoch 4, example 2000/60000. QE: 4 T: 1.77096
Epoch 4, example 3000/60000. QE: 2 T: 1.79936
Epoch 4, example 4000/60000. QE: 3 T: 1.85726
Epoch 4, example 5000/60000. QE: 2 T: 1.75168
Epoch 4, example 6000/60000. QE: 5 T: 1.82168
Epoch 4, example 7000/60000. QE: 8 T: 1.78412
Epoch 4, example 8000/60000. QE: 3 T: 1.75553
Epoch 4, example 9000/60000. QE: 2 T: 1.78125
Epoch 4, example 10000/60000. QE: 5 T: 1.78104
Epoch 4, example 11000/60000. QE: 6 T: 1.76579
Epoch 4, example 12000/60000. QE: 3 T: 1.77471
Epoch 4, example 13000/60000. QE: 4 T: 1.7799
Epoch 4, example 14000/60000. QE: 5 T: 1.77081
Epoch 4, example 15000/60000. QE: 3 T: 1.75345
Epoch 4, example 16000/60000. QE: 3 T: 1.77839
Epoch 4, example 17000/60000. QE: 1 T: 1

Epoch 6, example 53000/60000. QE: 3 T: 1.75486
Epoch 6, example 54000/60000. QE: 3 T: 1.78316
Epoch 6, example 55000/60000. QE: 0 T: 1.78519
Epoch 6, example 56000/60000. QE: 4 T: 1.76981
Epoch 6, example 57000/60000. QE: 5 T: 1.77404
Epoch 6, example 58000/60000. QE: 1 T: 1.78638
Epoch 6, example 59000/60000. QE: 6 T: 1.78312
FE: 3.18 TE: 2.595 T: 5.68971
Epoch 7, example 1000/60000. QE: 4 T: 2.50159
Epoch 7, example 2000/60000. QE: 1 T: 1.87346
Epoch 7, example 3000/60000. QE: 2 T: 1.88517
Epoch 7, example 4000/60000. QE: 5 T: 1.8785
Epoch 7, example 5000/60000. QE: 4 T: 1.89667
Epoch 7, example 6000/60000. QE: 2 T: 1.92301
Epoch 7, example 7000/60000. QE: 6 T: 1.94833
Epoch 7, example 8000/60000. QE: 7 T: 1.86915
Epoch 7, example 9000/60000. QE: 6 T: 1.87154
Epoch 7, example 10000/60000. QE: 4 T: 1.94041
Epoch 7, example 11000/60000. QE: 0 T: 1.86703
Epoch 7, example 12000/60000. QE: 5 T: 1.874
Epoch 7, example 13000/60000. QE: 9 T: 1.95623
Epoch 7, example 14000/60000. QE: 3 T: 1.8

Epoch 9, example 50000/60000. QE: 4 T: 1.88744
Epoch 9, example 51000/60000. QE: 3 T: 1.86931
Epoch 9, example 52000/60000. QE: 6 T: 1.90608
Epoch 9, example 53000/60000. QE: 3 T: 1.87525
Epoch 9, example 54000/60000. QE: 2 T: 1.86478
Epoch 9, example 55000/60000. QE: 4 T: 1.88399
Epoch 9, example 56000/60000. QE: 5 T: 1.87354
Epoch 9, example 57000/60000. QE: 2 T: 1.86467
Epoch 9, example 58000/60000. QE: 3 T: 1.92734
Epoch 9, example 59000/60000. QE: 3 T: 1.86378
FE: 2.56 TE: 1.79833 T: 5.79065
Epoch 10, example 1000/60000. QE: 2 T: 1.83968
Epoch 10, example 2000/60000. QE: 2 T: 1.75312
Epoch 10, example 3000/60000. QE: 5 T: 1.76371
Epoch 10, example 4000/60000. QE: 2 T: 1.86543
Epoch 10, example 5000/60000. QE: 5 T: 1.76292
Epoch 10, example 6000/60000. QE: 2 T: 1.75272
Epoch 10, example 7000/60000. QE: 2 T: 1.8146
Epoch 10, example 8000/60000. QE: 3 T: 1.74878
Epoch 10, example 9000/60000. QE: 3 T: 1.74216
Epoch 10, example 10000/60000. QE: 2 T: 1.80185
Epoch 10, example 11000/6000