### Load packages

In [13]:
import math
import numpy as np
from scipy import signal
from scipy.stats import norm
import matplotlib.pyplot as plt

# Function generators

In [14]:
def noisy_sine_wave(a, f, phi, sigma, time):
    mySin  = np.vectorize(math.sin)
    return a*mySin(2*math.pi*f*time + phi*np.ones(len(time))) + np.random.rand(len(time))*sigma


def sum_of_sine_waves(a, f, phi, sigma, time):
    signal = np.zeros(len(time))
    for j in range(len(a)):
        signal = signal + noisy_sine_wave(a[j], f[j], phi[j], sigma, time)
    return signal


def noisy_gaussian_wave(a, m, s, sigma, time):
    return a*norm.pdf(time, m, s) + sigma*np.random.rand(len(time))


def sum_of_gaussian_wave(a, m, s, sigma, time):
    signal = np.zeros(len(time))
    for j in range(len(a)):
        signal = signal + noisy_gaussian_wave(a[j], m[j], s[j], 0, time) + np.random.rand(len(time))*sigma
    return signal

# Encoders

In [27]:
def temporal_contrast_encoder(data, factor):
    # Based on algorithm provided in:
    #   Sengupta et al. (2017)
    #   Petro et al. (2020)
    diff = np.zeros(len(data)-1)
    spikes = np.zeros(len(data))
    for i in range(len(data)-1):
        diff[i] = data[i+1] - data[i]
    threshold = np.mean(diff) + factor * np.std(diff)
    diff = np.insert(diff, 0, diff[1])
    for i in range(len(data)):
        if diff[i] > threshold:
            spikes[i] = 1
        elif diff[i] < -threshold:
            spikes[i] = -1
    return spikes, threshold


def step_forward_encoder(data, threshold):
    # Based on algorithm provided in:
    #   Petro et al. (2020)
    startpoint = data[0]
    spikes = np.zeros(len(data))
    base = startpoint
    for i in range(1,len(data)):
        if data[i] > base + threshold:
            spikes[i] = 1
            base = base + threshold
        elif data[i] < base - threshold:
            spikes[i] = -1
            base = base - threshold
    return spikes, startpoint


def moving_window_encoder(data, threshold, window):
    # Based on algorithm provided in:
    #   Petro et al. (2020)
    startpoint = data[0]
    spikes = np.zeros(len(data))
    base = np.mean(data[0:window+1])
    for i in range(window+1):
        if data[i] > base + threshold:
            spikes[i] = 1
        elif data[i] < base - threshold:
            spikes[i] = -1
    for i in range(window+2, len(data)):
        base = np.mean(data[(i-window-1):(i-1)])
        if data[i] > base + threshold:
            spikes[i] = 1
        elif data[i] < base - threshold:
            spikes[i] = -1
    return spikes, startpoint


def hough_spike_encoder(data, fir):
    # Based on algorithm provided in:
    #   Schrauwen et al. (2003)
    spikes = np.zeros(len(data))
    shift = min(data)
    data = data - shift*np.ones(len(data))
    for i in range(len(data)):
        count = 0
        for j in range(len(fir)):
            if i+j < len(data):
                if data[i+j] >= fir[j]:
                    count = count + 1
        if count == len(fir):
            spikes[i] = 1
            for j in range(len(fir)):
                if i+j < len(data):
                    data[i+j] = data[i+j] - fir[j]
    return spikes, shift


def modified_hough_spike_encoder(data, fir, threshold):
    # Based on algorithm provided in:
    #   Schrauwen et al. (2003)
    spikes = np.zeros(len(data))
    shift = min(data)
    data = data - shift*np.ones(len(data))
    for i in range(len(data)):
        error = 0
        for j in range(len(fir)):
            if i+j < len(data):
                if data[i+j] < fir[j]:
                    error = error + fir[j] - data[i+j]
        if error <= threshold:
            spikes[i] = 1
            for j in range(len(fir)):
                if i+j < len(data):
                    data[i+j] = data[i+j] - fir[j]
    return spikes, shift


def ben_spike_encoder(data, fir, threshold):
    # Based on algorithm provided in:
    #   Petro et al. (2020)
    #   Sengupta et al. (2017)
    #   Schrauwen et al. (2003)
    spikes = np.zeros(len(data))
    shift = min(data)
    data = data - shift*np.ones(len(data))
    for i in range(len(data)-len(fir)+1):
        err1 = 0
        err2 = 0
        for j in range(len(fir)):
            err1 = err1 + abs(data[i+j] - fir[j])
            err2 = err2 + abs(data[i+j-1])
        if err1 <= err2*threshold:
            spikes[i] = 1
            for j in range(len(fir)):
                if i+j+1 < len(data):
                    data[i+j+1] = data[i+j+1] - fir[j]
    return spikes, shift


def grf_spike_encoder(data, m, min_input, max_input):
    # Adapted from algorithm provided in:
    #   Bohté et al. (2002)
    # Modifications: definition of sigma, removal of beta constant,
    #                and modified WTA process

    if np.isscalar(data):
        data = [data]

    spikes = np.zeros((len(data),m))
    neuron_outputs = np.zeros(m)

    for j in range(len(data)):
        for i in range(m):
            mu = min_input + (2*(i + 1)-3)/2*(max_input - min_input)/(m-2)
            sigma = (max_input - min_input)/(m-2)
            neuron_outputs[i] = norm.pdf(data[j], mu, sigma)

        spikes[j,np.argmax(neuron_outputs)] = 1
    return spikes

def one_hot_place_spike_encoder(data, m, min_input, max_input):
    # Simple population coding algorithm adapted from Stagsted et al. (2020) that represents inputs by a location. 
    # An input is assigned to the neuron that is closest to its value. 
    # Only one neuron fires at every timestep

    if np.isscalar(data):
        data = [data]

    spikes = np.zeros((len(data),m))

    for j in range(len(data)):
        size_change = 1/2*(max_input - min_input)/(m-2) # to make sure it has the same lower/upper bounds as the Bohte paper
        idx = int(np.round(((data[j] - (min_input - size_change)) / ((max_input + size_change) - (min_input - size_change))) * (m - 1)))
        spikes[j, idx] = 1
    
    return spikes


def grf_spike_with_internal_timesteps_encoder(data, min_input, max_input, neurons=10, timesteps=10, beta=1.5):
    """Create a series of spikes based on Gaussian Receptive Fields
    Adapted from algorithm provided in:
        Bohté et al. (2002)
    
    Keyword arguments:
    data -- 
    neurons -- numbers of neurons (default 10)
    timesteps -- number of timesteps (default 10)
    min_input -- minimal value
    max_input -- maximum value
    beta -- tuning parameter that determines the width of the receptive fields
    """

    if np.isscalar(data):
        data = [data]
        
    spikes = np.zeros((len(data), timesteps, neurons))
    responses = np.zeros(neurons)

    # Calculation of mu and sigma of the Gaussian receptive fields
    mu = min_input + (2*(np.arange(neurons)+1)-3)/2*(max_input - min_input)/(neurons-2)
    sigma = 1/beta*(max_input - min_input)/(neurons-2)
    max_prob = norm.pdf(mu[0], mu[0], sigma)

    for j in range(len(data)):
        for i in range(neurons):
            responses[i] = norm.pdf(data[j], mu[i], sigma)
            size_change = max_prob / (2 * timesteps)
            new = int(np.round(((responses[i] + size_change) / (max_prob + 2 * size_change) * (timesteps + 1)) + 0.0001)) # 0.0001 for roundoff errors...
            spiking_time = timesteps - new
            if spiking_time < timesteps - 1:
                spikes[j, spiking_time, i] = 1
    spikes = spikes.reshape([len(data) * timesteps, neurons])
    return spikes

# Decoder

In [28]:
def temporal_contrast_decoder(spikes, threshold):
    # Based on algorithm provided in:
    #   Sengupta et al. (2017)
    #   Petro et al. (2020)
    signal = np.zeros(len(spikes))
    for i in range(1, len(spikes)):
        if spikes[i] > 0:
            signal[i] = signal[i-1] + threshold
        elif spikes[i] < 0:
            signal[i] = signal[i-1] - threshold
        else:
            signal[i] = signal[i-1]
    return signal

def step_forward_decoder(spikes, threshold, startpoint):
    # Based on algorithm provided in:
    #   Petro et al. (2020)
    signal = np.zeros(len(spikes))
    signal[0] = startpoint
    for i in range(1,len(spikes)):
        if spikes[i] > 0:
            signal[i] = signal[i-1] + threshold
        elif spikes[i] < 0:
            signal[i] = signal[i-1] -threshold
        else:
            signal[i] = signal[i-1]
    return signal

def moving_window_decoder(spikes, threshold, startpoint):
    # Based on algorithm provided in:
    #   Petro et al. (2020)
    signal = np.zeros(len(spikes))
    signal[0] = startpoint
    for i in range(1,len(spikes)):
        if spikes[i] > 0:
            signal[i] = signal[i-1] + threshold
        elif spikes[i] < 0:
            signal[i] = signal[i-1] - threshold
        else:
            signal[i] = signal[i-1]
    return signal

def ben_spike_decoder(spikes, fir, shift):
    # Based on algorithm provided in:
    #   Petro et al. (2020)
    #   Sengupta et al. (2017)
    #   Schrauwen et al. (2003)
    signal = np.convolve(spikes, fir)
    signal = signal + shift*np.ones(len(signal))
    signal = signal[0:(len(signal)-len(fir)+1)]
    return signal

def grf_spike_decoder(spikes, min_input, max_input):
    shape = spikes.shape
    signal = np.zeros(shape[0])
    for i in range(shape[0]):
        signal[i] = min_input + (2*(np.argmax(spikes[i,:]) + 1)-3)/2*(max_input - min_input)/(shape[1]-2)
    return signal

def one_hot_place_spike_decoder(spikes, min_input, max_input):
    shape = spikes.shape
    signal = np.zeros(shape[0])
    for i in range(shape[0]):
        signal[i] = min_input + (2*(np.argmax(spikes[i,:]) + 1)-3)/2*(max_input - min_input)/(shape[1]-2)
    return signal

def grf_spike_with_internal_timesteps_decoder(spikes, n_timesteps, min_input, max_input):
    shape = spikes.shape
    spikes = spikes.reshape((int(shape[0]/n_timesteps), n_timesteps, shape[1]))
    signal = np.zeros(len(spikes))
    mu = np.zeros(shape[1])

    for i in range(shape[1]):
        mu[i] = min_input + (2*(i + 1)-3)/2*(max_input - min_input)/(shape[1]-2)

    for i in range(len(spikes)):
        spike_times = np.zeros(shape[1])
        for j in range(n_timesteps):
            for spike_idx in spikes[i, j, :].nonzero():
                spike_times[spike_idx] = n_timesteps - j

        weight_center  = np.sum(mu*spike_times)/np.sum(spike_times)
        signal[i] = weight_center

    return signal

In [45]:
dt = 0.01
T_max = 4

tbr_factors   = [1.005, 1.005, 1.005]
sf_thresholds = [0.35, 0.05, 0.35]
mw_thresholds = [0.325, 0.015, 0.225]
mw_window     = [3, 3, 3]

time                  = np.arange(0, T_max, dt)
signal                = sum_of_sine_waves([2, -0.5, 0.75], [1.0, 3.0, 5.0], [0.0, 0.0, 0.0], 0.0, time)
signal_guassian       = sum_of_gaussian_wave([1, 0.5], [0.2, 0.75], [0.1, 0.1], 0.1, time)
signal_guassian_noisy = noisy_gaussian_wave([1, 0.5], [0.2, 0.75], [0.1, 0.1], 0.1, time)
signal_noisy          = noisy_sine_wave([1, 0.5], [0.2, 0.75], [0.1, 0.1], 0.1, time)

(spikes, threshold) = temporal_contrast_encoder(signal, tbr_factors[0])
signal_TBR = temporal_contrast_decoder(spikes, threshold)

fig, axs = plt.subplots(2, 1,figsize=(15, 7))
axs[0].plot(time, signal)
axs[0].plot(time, signal_TBR)
axs[0].legend(['Original','Reconstructed'])
axs[0].set_ylabel('voltage(V)')
axs[0].set_title("Temporal Contrast Algorithm TBR")
axs[1].stem(time,spikes,use_line_collection=True)
axs[1].set_ylabel('Spikes')
plt.tight_layout()
plt.show()

ValueError: operands could not be broadcast together with shapes (400,) (2,) 

In [17]:
dt = 0.01
T_max = 4
time = np.arange(0, T_max, dt)

S = list()
S.append(sum_of_sine_waves([2, -0.5, 0.75], [1.0, 3.0, 5.0], [0.0, 0.0, 0.0], 0.0, time))
S.append(sum_of_sine_waves([-0.25], [1.0], [0.0], 0.05, time))
S.append(sum_of_gaussian_wave([1, 0.5], [0.2, 0.75], [0.1, 0.1], 0.1, time))

tbr_factors = [1.005, 1.005, 1.005]
    
sf_thresholds = [0.35, 0.05, 0.35]
    
mw_thresholds = [0.325, 0.015, 0.225]
mw_window = [3, 3, 3]

for i in range(len(S)):

    (spikes_TBR, threshold) = temporal_contrast(S[i], tbr_factors[i])
    signal_TBR = 2*temporal_contrast(spikes_TBR, threshold)

    spikes_SF, startpoint = step_forward(S[i], sf_thresholds[i])
    signal_SF = DS.step_forward(spikes_SF, sf_thresholds[i], startpoint)

    spikes_MW, startpoint = moving_window(S[i], mw_thresholds[i], mw_window[i])
    signal_MW = DS.moving_window(spikes_MW, mw_thresholds[i], startpoint)

    plt.subplot(3*len(S),3,(1+i*3*len(S),4+i*3*len(S)))
    plt.plot(time, S[i])
    plt.plot(time, signal_TBR)
    plt.gca().axes.get_xaxis().set_visible(False)
    plt.gca().axes.get_yaxis().set_visible(False)
    if i == 0:
        plt.title("Temporal Contrast Algorithm TBR")

        plt.subplot(3*len(S),3,7+i*3*len(S))
        plt.stem(time, spikes_TBR)
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)

        plt.subplot(3*len(S),3,(2+i*3*len(S),5+i*3*len(S)))
        plt.plot(time, S[i])
        plt.plot(time, signal_SF)
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)
        if i == 0:
            plt.title("Step Forward Algorithm SF")

        plt.subplot(3*len(S),3,8+i*3*len(S))
        plt.stem(time, spikes_SF)
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)

        plt.subplot(3*len(S),3,(3+i*3*len(S),6+i*3*len(S)))
        plt.plot(time, S[i])
        plt.plot(time, signal_MW)
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)
        if i == 0:
            plt.title("Moving Window Algorithm MW")

        plt.subplot(3*len(S),3,9+i*3*len(S))
        plt.stem(time, spikes_MW)
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)
        
    plt.show()

    hsa_window = [12, 15, 12]
    hsa_fir = list()
    hsa_fir.append(signal.triang(hsa_window[0]))
    hsa_fir.append(norm.pdf(np.linspace(1, hsa_window[1], hsa_window[1]), 0, 5))
    hsa_fir.append(signal.triang(hsa_window[2]))

    hsa_m_thresholds = [0.85, 0.05, 0.5]

    bsa_window = [9, 10, 8]
    bsa_fir = list()
    bsa_fir.append(signal.triang(bsa_window[0]))
    bsa_fir.append(norm.pdf(np.linspace(1, bsa_window[1], bsa_window[1]), 1.5, 3.5))
    bsa_fir.append(signal.triang(bsa_window[2]))

    bsa_thresholds = [1.175, 1.05, 1.2]

    for i in range(len(S)):

        spikes_HSA, shift = ES.hough_spike(S[i], hsa_fir[i])
        signal_HSA = DS.ben_spike(spikes_HSA, hsa_fir[i], shift)

        spikes_HSAm, shift = ES.modified_hough_spike(S[i], hsa_fir[i], hsa_m_thresholds[i])
        signal_HSAm = DS.ben_spike(spikes_HSAm, hsa_fir[i], shift)

        spikes_BSA, shift = ES.ben_spike(S[i], bsa_fir[i], bsa_thresholds[i])
        signal_BSA = DS.ben_spike(spikes_BSA, bsa_fir[i], shift)

        plt.subplot(3*len(S),3,(1+i*3*len(S),4+i*3*len(S)))
        plt.plot(time, S[i])
        plt.plot(time, signal_HSA)
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)
        if i == 0:
            plt.title("Hough Spike Algorithm HSA")

        plt.subplot(3*len(S),3,7+i*3*len(S))
        plt.stem(time, spikes_HSA)
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)

        plt.subplot(3*len(S),3,(2+i*3*len(S),5+i*3*len(S)))
        plt.plot(time, S[i])
        plt.plot(time, signal_HSAm)
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)
        if i == 0:
            plt.title("Threshold Hough Spike Algorithm T-HSA")

        plt.subplot(3*len(S),3,8+i*3*len(S))
        plt.stem(time, spikes_HSAm)
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)

        plt.subplot(3*len(S),3,(3+i*3*len(S),6+i*3*len(S)))
        plt.plot(time, S[i])
        plt.plot(time, signal_BSA)
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)
        if i == 0:
            plt.title("Ben Spike Algorithm BSA")

        plt.subplot(3*len(S),3,9+i*3*len(S))
        plt.stem(time, spikes_BSA)
        plt.gca().axes.get_xaxis().set_visible(False)
        plt.gca().axes.get_yaxis().set_visible(False)
        
    plt.show()

ValueError: too many values to unpack (expected 2)