In [1]:
import matplotlib.pyplot as plt
import numpy as np

# Relative path to the Iris data file; resolved against the kernel's current working directory.
iris_data_location = "/".join(["..", "Data", "Iris", "iris.data"])

from snn_dpe import Encoder, Neuron, Synapse
from snn_dpe.tools.data import normalize_iris_data, read_iris_data
from snn_dpe.tools.network import (create_encoders, create_network,
                                   find_steady_state, reset_network,
                                   run_network, run_network_early_exit)
from snn_dpe.tools.plotting import *
from snn_dpe.tools.test import predict
from snn_dpe.tools.train import forward_pass, mse, train_all, update_weights

In [4]:
# Fix the global NumPy seed so the random synapse count and the initial DPE
# weights below are reproducible under Restart & Run All. If create_network /
# create_encoders draw from the global np.random state (TODO confirm), they
# become deterministic too.
SEED = 0
np.random.seed(SEED)

iris_data, labels, classes, attributes = read_iris_data(iris_data_location)

normalized_iris_data = normalize_iris_data(iris_data, attributes)

# pick a single sample to (over)fit
test_idx = 0
normalized_iris_data_sample = normalized_iris_data[test_idx]
label = labels[test_idx]

# create a test network and encoders
n_neurons = 16
# random integer synapse count in [n_neurons * 2, n_neurons * 3) -- upper bound exclusive
n_synapses = int(n_neurons * np.random.uniform(low=2, high=3))

neurons = create_network(n_neurons, n_synapses)

encoders = create_encoders(len(attributes))

# weights of shape (n_neurons, n_classes), initialized uniformly in [0, 1)
dpe_weights = np.random.rand(n_neurons, len(classes))

sim_time = 200
n_iterations = 10

# One-hot target for the chosen sample; it does not change between iterations,
# so build it once. NOTE: in this notebook `y_hat` is the *target* and `y` the
# network output (the reverse of the usual convention); the names are kept
# because later cells / update_weights' positional arguments rely on them.
y_hat = np.zeros(len(classes))
y_hat[label] = 1

# show we can reduce error for a single test sample
for i in range(n_iterations):
    # feed the sample into the network for up to sim_time steps
    fire_matrix = run_network_early_exit(neurons, encoders, normalized_iris_data_sample, sim_time)

    # reset the network and encoders so the next run starts from a clean state
    reset_network(neurons, encoders)

    x, y = forward_pass(fire_matrix, dpe_weights)

    print(f'Mean Squared Error {i} = {mse(y, y_hat)}')

    # update_weights' return value is not used; the decreasing error across
    # iterations indicates it modifies dpe_weights in place.
    update_weights(fire_matrix, dpe_weights, x, y, y_hat)

Mean Squared Error 0 = 12.949742371894132
Mean Squared Error 1 = 12.521111087477806
Mean Squared Error 2 = 12.106667326851849
Mean Squared Error 3 = 11.705941488662791
Mean Squared Error 4 = 11.31847951517382
Mean Squared Error 5 = 10.943842377777298
Mean Squared Error 6 = 10.581605579536628
Mean Squared Error 7 = 10.23135867419373
Mean Squared Error 8 = 9.892704801097235
Mean Squared Error 9 = 9.565260235524335
