In [1]:
import matplotlib.pyplot as plt
import numpy as np

# Relative path to the Iris data file; passed to read_iris_data() when the dataset is loaded.
iris_data_location = '../../Datasets/Iris/iris.data'

from snn_dpe import Encoder, Neuron, Synapse
from snn_dpe.tools.data import normalize_iris_data, read_iris_data
from snn_dpe.tools.network import (create_encoders, create_network,
                                   find_steady_state, reset_network,
                                   run_network, run_network_early_exit)
from snn_dpe.tools.plotting import plot_spike_raster
from snn_dpe.tools.train.classification import predict, train_all
from snn_dpe.tools.train.utils import forward_pass, mse, update_weights

In [2]:
# Seed NumPy's global RNG so the synapse count and the initial readout weights
# drawn below are repeatable under Restart Kernel -> Run All.
# NOTE(review): create_network / create_encoders may draw from another RNG
# internally -- confirm whether they need seeding as well.
SEED = 0
np.random.seed(SEED)

# Load the Iris dataset and run it through the project's normalization helper.
iris_data, labels, classes, attributes = read_iris_data(iris_data_location)
normalized_iris_data = normalize_iris_data(iris_data, attributes)

# The single sample (and its label) used throughout this demonstration.
test_idx = 0
normalized_iris_data_sample = normalized_iris_data[test_idx]
label = labels[test_idx]

# create a test network and encoders
n_neurons = 16
n_synapses = int(n_neurons * np.random.uniform(low=2, high=3))  # random number from n_neurons * 2 to n_neurons * 3

neurons = create_network(n_neurons, n_synapses)
encoders = create_encoders(len(attributes))  # one encoder per input attribute

# Readout weights: one row per neuron, one column per class.
dpe_weights = np.random.rand(n_neurons, len(classes))

sim_time = 200
n_weight_updates = 10

# One-hot target vector for the chosen sample. It does not change between
# iterations, so it is built once instead of on every pass through the loop.
# NOTE(review): assumes mse / update_weights do not modify it in place.
y_hat = np.zeros(len(classes))
y_hat[label] = 1

# show we can reduce error for a single test sample
for i in range(n_weight_updates):
    # feed the test sample into the test network, then reset the network and
    # encoders so the next iteration starts from the same state
    spike_raster = run_network_early_exit(neurons, encoders, normalized_iris_data_sample, sim_time)
    reset_network(neurons, encoders)

    x, y = forward_pass(spike_raster, dpe_weights)

    # error is reported before the update, i.e. after i previous updates
    print(f'Mean Squared Error {i} weight updates = {mse(y, y_hat)}')

    # dpe_weights is passed in and no return value is used; presumably it is
    # updated in place (the printed error decreases across iterations)
    update_weights(dpe_weights, x, y, y_hat)

Mean Squared Error 0 weight updates = 31.825071420893405
Mean Squared Error 1 weight updates = 25.390240568034418
Mean Squared Error 2 weight updates = 20.256492360279
Mean Squared Error 3 weight updates = 16.16075600554292
Mean Squared Error 4 weight updates = 12.893151984339616
Mean Squared Error 5 weight updates = 10.286237106374527
Mean Squared Error 6 weight updates = 8.20642414958514
Mean Squared Error 7 weight updates = 6.54713639462572
Mean Squared Error 8 weight updates = 5.223346269763501
Mean Squared Error 9 weight updates = 4.167218247698045
