In [2]:
import string

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm as cm_mlib
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import scipy
from matplotlib import animation, rc, colors
import brian2.units as bunits
import matplotlib as mlib
from scipy import stats
from pprint import pprint as pp
from mpl_toolkits.axes_grid1 import make_axes_locatable, ImageGrid, AxesGrid
import traceback
import os
import copy
from datetime import datetime
from brian2.units import *
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# ensure we use viridis as the default cmap
plt.viridis()

In [3]:
# ensure we use the same rc parameters for all matplotlib outputs
mlib.rcParams.update({'font.size': 24})
mlib.rcParams.update({'errorbar.capsize': 5})
mlib.rcParams.update({'figure.autolayout': True})

In [4]:
def get_ordered_layer_names(dictionary):
    layer_names = list(dictionary)
    layer_names.sort()
    layer_names[0] = layer_names.pop(-1)
    return layer_names

def get_ordered_layer_names_from_data(data):
    dictionary = data['spikes_dict'][()].keys()
    return get_ordered_layer_names(dictionary)
    
def get_shape_from_name(name):
    shape_string = name.split('_')
    if len(shape_string) <2:
        return
    else:
        shape_string = shape_string[1]
    shape_list = shape_string.split('x')
    return tuple([int(item) for item in shape_list])   

In [5]:
# data = np.load("mnist_results__235237_08072019.npz")

# data = np.load("mnist_results_ed_mnist_network_test_100_full_recordings.npz")
#data = np.load("results/output_t_stim_80_testing_examples_50.npz")
data= np.load("results/output.npz")
#data= np.load("results/output_t_stim_800_testing_examples_20.npz")
#"..//pynn_object_serialisation/pynn_object_serialisation/"
#"experiments/mnist_testing/results/mnist_results_ed_mnist_network_test_1.npz")
# data = np.load("mnist_results_ed_mnist_network_test.npz")

In [6]:
input_layer_name = 'InputLayer'
output_layer_name = get_ordered_layer_names_from_data(data)[-1]
y_test = data['y_test']
spikes_dict = data['spikes_dict'].ravel()[0]
t_stim=data['t_stim']
runtime=int(data['runtime'])
N_layer=int(data['N_layer'])
neo_object = data['neo_spikes_dict']

data.close()

labels = np.load('/mnt/snntoolbox/snn_toolbox_private/examples/models/05-mobilenet_dwarf_v1/label_names.npz')
label_names = labels['arr_0']
labels.close()

In [7]:
def get_bounds(bin_number, t_stim =t_stim):
    lower_end_bin_time = bin_number * t_stim
    higher_end_bin_time = (bin_number+1) * t_stim
    return lower_end_bin_time, higher_end_bin_time

def get_bin_spikes(spikes, bin_number):
    lower_end_bin_time, higher_end_bin_time = get_bounds(bin_number)
    output = spikes[np.where((spikes[:,1] >= lower_end_bin_time) & (spikes[:,1] < higher_end_bin_time)),:]
    output = np.asarray(output).astype(int)
    return output

def get_counts(spikes, bin_number, t_stim=t_stim, minlength = 3*32**2):
    spikes = get_bin_spikes(spikes, bin_number)
    just_spikes = spikes.reshape((-1,2))[:,0]
    counts = np.bincount(just_spikes, minlength=minlength)
    return counts

def get_rates(spikes, bin_number, t_stim=t_stim):
    return get_counts(spikes, bin_number)/t_stim

def plot_rates(rates, shape=(32,32,3)):
    rates /= rates.max()
    plt.imshow(rates.reshape(shape))
    
def plot_bin(spikes, bin_number):
    plot_rates(get_rates(spikes, bin_number))
    

def get_prediction(spikes, bin_number, t_stim=t_stim, output_size=10):
    counts = get_counts(spikes, bin_number, t_stim, 10)
    if counts.max() > 0:
        return np.argmax(counts)
    else:
        return -1

In [8]:
input_layer_name = get_ordered_layer_names(spikes_dict)[0]
output_layer_name = get_ordered_layer_names(spikes_dict)[-1]
input_layer_shape = (32, 32, 3)
input_spikes = spikes_dict[input_layer_name]
output_spikes = spikes_dict[output_layer_name]

In [9]:
get_counts(input_spikes, 5).shape

(3072,)

In [10]:
number_of_examples = runtime // t_stim
actual_test_labels = y_test[:number_of_examples]
y_pred = np.ones(number_of_examples) * (-1)
output_spikes = spikes_dict[output_layer_name]
for bin_number in range(number_of_examples):
    y_pred[bin_number]=get_prediction(output_spikes, bin_number)

In [11]:
print("Accuracy", np.count_nonzero(y_pred==actual_test_labels)/float(number_of_examples))

Accuracy 0.0


In [12]:
for i in range(number_of_examples):
    prediction = get_prediction(output_spikes, i)
    predicted_label = label_names[prediction].decode('UTF-8')
    actual_label = label_names[actual_test_labels[i]][0].decode('UTF-8')
    print("Ppredicted: {}, Actual: {}".format(predicted_label, actual_label))
    #plot_bin(input_spikes, i)
    #plt.show()

Ppredicted: truck, Actual: cat
Ppredicted: truck, Actual: ship
Ppredicted: truck, Actual: ship
Ppredicted: truck, Actual: airplane
Ppredicted: truck, Actual: frog
Ppredicted: horse, Actual: frog
Ppredicted: horse, Actual: automobile
Ppredicted: horse, Actual: frog
Ppredicted: horse, Actual: cat
Ppredicted: horse, Actual: automobile


In [13]:
bin_number = 19

prediction = get_prediction(output_spikes, bin_number)
label_names[prediction]

b'truck'

In [14]:
nid_bincount = get_counts(output_spikes, 12, t_stim, 10)
print(nid_bincount)
print("Count: Min", np.min(nid_bincount), "Max", np.max(nid_bincount), "Mean", np.mean(nid_bincount))
nid_bincount = nid_bincount/(t_stim * ms)
print("Hz   : Min", np.min(nid_bincount), "Max", np.max(nid_bincount), "Mean", np.mean(nid_bincount))

[0 0 0 0 0 0 0 0 0 0]
Count: Min 0 Max 0 Mean 0.0
Hz   : Min 0. Hz Max 0. Hz Mean 0. Hz
