In [5]:
from json import encoder
import matplotlib.pyplot as plt
import numpy as np

# Relative path to the raw Iris data file that read_iris_data() loads below.
# NOTE(review): being relative, it resolves against the current working
# directory — confirm the notebook is launched from the expected folder.
iris_data_location = '../../Datasets/Iris/iris.data'

from snn_dpe import Encoder, Neuron, Synapse
from snn_dpe.tools.data import normalize_iris_data, read_iris_data
from snn_dpe.tools.network import (create_encoders, create_network,
                                   find_steady_state, reset_network,
                                   run_network, run_network_early_exit)
from snn_dpe.tools.plotting import plot_spike_raster
from snn_dpe.tools.train.classification import predict, train_all
from snn_dpe.tools.train.utils import forward_pass, mse, update_weights

# --- Data ---------------------------------------------------------------
# Load the Iris dataset and run it through the project's normalisation helper.
iris_data, labels, classes, attributes = read_iris_data(iris_data_location)
normalized_iris_data = normalize_iris_data(iris_data, attributes)

# Use a single sample (and its label) for this experiment.
test_idx = 0
label = labels[test_idx]
normalized_iris_data_sample = normalized_iris_data[test_idx]

# --- Simulation parameters ------------------------------------------------
sim_time = 200   # simulation length handed to run_network
n_neurons = 16

# Draw the synapse count at random, somewhere between 2x and 3x the neuron
# count (truncated to an integer).
n_synapses = int(np.random.uniform(2, 3) * n_neurons)

# --- Test network, encoders and readout weights ----------------------------
neurons = create_network(n_neurons, n_synapses)

# One encoder per input attribute.
encoders = create_encoders(len(attributes))

# Randomly initialised weight matrix with one row per neuron and one column
# per class; np.random.rand samples uniformly from [0, 1).
dpe_weights = np.random.rand(n_neurons, len(classes))

# Show that repeated weight updates reduce the error for a single test sample.
for i in range(10):
    # Feed the test sample into the test network for sim_time steps, then
    # reset the network and encoders before the next pass (presumably so each
    # iteration starts from the same network state — confirm in reset_network).
    spike_raster, encoder_fires = run_network(neurons, encoders, normalized_iris_data_sample, sim_time)
    # plot_spike_raster(spike_raster)  # uncomment to inspect the spiking activity
    reset_network(neurons, encoders)

    # x, y come from the readout stage; y is compared against the target below.
    x, y = forward_pass(spike_raster, dpe_weights)

    # One-hot target vector: 1 at the sample's class index, 0 elsewhere.
    # NOTE: despite its name, y_hat holds the target here, not a prediction.
    y_hat = np.zeros(len(classes))
    y_hat[label] = 1

    # The outer f-string uses double quotes so the inner 's' / '' literals do
    # not reuse its delimiter; nesting the same quote character inside a
    # replacement field is a SyntaxError on Python versions before 3.12.
    print(f"Mean Squared Error after {i} weight update{'s' if i != 1 else ''} = {mse(y, y_hat)}")

    # The return value is unused, so the update is assumed to modify
    # dpe_weights in place — verify against update_weights.
    update_weights(dpe_weights, x, y, y_hat)

Mean Squared Error after 0 weight updates = 11.935134663768153
Mean Squared Error after 1 weight update = 10.858986900870725
Mean Squared Error after 2 weight updates = 9.517418368718788
Mean Squared Error after 3 weight updates = 8.34159330259093
Mean Squared Error after 4 weight updates = 7.311035002362387
Mean Squared Error after 5 weight updates = 6.407796552388359
Mean Squared Error after 6 weight updates = 5.616148280446289
Mean Squared Error after 7 weight updates = 4.922303829419114
Mean Squared Error after 8 weight updates = 4.314180071326163
Mean Squared Error after 9 weight updates = 3.781186682664456
