In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy
%matplotlib inline
plt.viridis()
import time
from brian2.units import *
from scipy.spatial.distance import minkowski

<matplotlib.figure.Figure at 0x6af5f30>

In [2]:
def distance(s, t, grid_shape, dimensions):
        '''
        Function that computes distance in a grid of neurons taking into account periodic boundry conditions.

        First, translate source into the center of the grid.
        Second, translate target by the same amount.
        Finally, perform desired distance computation.
        '''
        s = np.asarray(s)
        t = np.asarray(t)
        _grid_size = np.asarray(grid_shape)
        trans = s - (_grid_size // 2)
        s = np.mod(s - trans, _grid_size)
        t = np.mod(t - trans, _grid_size)
        return minkowski(s, t, dimensions)

In [3]:
def generate_rates(s, grid_shape, dimensions):
    '''
    Function that generates an array the same shape as the input layer so that 
    each cell has a value corresponding to the firing rate for the neuron
    at that position.
    '''
    _rates = np.zeros(grid_shape)
    for x, y in np.ndindex(grid_shape):
        _d = distance(s, (x,y), grid_shape=grid_shape, dimensions=dimensions)
        _rates[x, y] = f_base + f_peak * np.e ** (-_d/(2 * sigma_stim**2))
    return _rates * Hz

In [4]:
duration = 100 * ms
d = 2
# Wiring
n = 16
N_layer = n ** 2
S = (n, n)

s = (n//2, n//2)
sigma_form_forward = 2.5
sigma_form_lateral = 1

# Inputs
f_mean = 20 * Hz
f_base = 5 * Hz
f_peak = 152.8  * Hz
sigma_stim = 2
t_stim = 0.02 * second

In [5]:
_chunk = 20 * ms
_dt = .1 * ms
_chunk / _dt

200.0

In [6]:
smth = [[],]

In [7]:
smth.append([3,2,1])

In [8]:
smth

[[], [3, 2, 1]]

The idea is to generate a 2D spike time array (different times for different spike source). To obtain a set a times I need to consider the rate of each spiking source, then generate a poisson series, and convert nonzero indices to times.

In [192]:
def generate_spike_times(s, dt=.1 * ms, chunk = 20 * ms, time_offset=0. * ms, grid_shape=(16,16), dimensions=2):
    rates = generate_rates(s, grid_shape, dimensions)
    spike_times = []
    for _, rate in np.ndenumerate(rates/Hz):
        spikes = np.random.poisson(rate/1000., int(chunk/dt))
        fucking_indices = np.nonzero(spikes)[0]
        fucking_indices_as_times = fucking_indices * dt
        shifted_fucking_indices = fucking_indices_as_times + time_offset
        spike_times.append((( shifted_fucking_indices)/ms).tolist())
    return spike_times

In [198]:
spike_times = generate_spike_times((13,8), time_offset=20*ms, grid_shape=S, dimensions=2)

In [199]:
def give_me_spike_times(N = 256, duration=25*ms, t_stim = 20*ms, dimensions=2):
    spike_times = [[], ] * N
    time_slot = 0
    for time_slot in range(int(duration / t_stim)):
        _sp = generate_spike_times(np.random.randint(0, np.sqrt(N), 2),
                                chunk=t_stim,
                                time_offset=time_slot * t_stim,
                                dimensions=dimensions,
                                  grid_shape=(np.sqrt(N),np.sqrt(N)))
        for index, value in np.ndenumerate(_sp):
            if hasattr(value , '__iter__'):
                for v in value:
                    spike_times[index[0]].append(v)
            else:
                spike_times[index[0]].append(value)
    print "time slot", time_slot
    if not np.isclose(duration % t_stim / ms, 0):
        print duration % t_stim
        _sp = generate_spike_times(np.random.randint(0, 16, 2),
                                chunk=duration % t_stim,
                                time_offset=(time_slot+1) * t_stim,
                                dimensions=dimensions,
                                  grid_shape=(np.sqrt(N),np.sqrt(N)))
        for index, value in np.ndenumerate(_sp):
            if hasattr(value, '__iter__'):
                for v in value:
                    spike_times[index[0]].append(v)
            else:
                spike_times[index[0]].append(value)
    return spike_times

In [200]:
spike_times = give_me_spike_times(4)



time slot 0
5. ms


In [201]:
spike_times[0]

[0.1,
 0.4,
 2.6,
 3.0,
 7.800000000000001,
 8.5,
 10.3,
 11.0,
 11.7,
 12.100000000000001,
 12.4,
 13.8,
 13.9,
 14.0,
 15.1,
 15.5,
 16.3,
 17.2,
 18.4,
 19.700000000000003,
 0.0,
 0.9,
 1.5,
 2.5,
 4.4,
 4.8999999999999995,
 5.800000000000001,
 6.4,
 6.7,
 6.800000000000001,
 8.700000000000001,
 9.2,
 9.7,
 12.3,
 12.8,
 13.4,
 13.8,
 14.9,
 15.4,
 15.600000000000001,
 16.099999999999998,
 16.7,
 16.900000000000002,
 17.400000000000002,
 18.6,
 19.0,
 19.900000000000002,
 0.1,
 0.4,
 0.8,
 1.3,
 1.9,
 2.0,
 2.8,
 3.7,
 3.9000000000000004,
 4.0,
 6.9,
 7.9,
 8.9,
 9.5,
 10.3,
 13.0,
 13.8,
 13.9,
 14.200000000000001,
 14.6,
 15.700000000000001,
 16.0,
 16.2,
 17.8,
 18.2,
 18.9,
 3.2,
 4.8999999999999995,
 5.3,
 6.0,
 6.4,
 7.3,
 7.6,
 8.0,
 9.3,
 9.7,
 10.0,
 10.5,
 10.6,
 11.2,
 11.600000000000001,
 11.9,
 12.3,
 12.4,
 14.700000000000001,
 15.0,
 16.099999999999998,
 16.400000000000002,
 16.5,
 16.6,
 17.7,
 18.4,
 18.500000000000004,
 19.200000000000003,
 19.4,
 20.6,
 22.9,
 24.

In [202]:
len(spike_times[0])

124

In [114]:
np.argsort(spike_times[0])

array([2786, 2255,  671, ..., 4822, 4828, 4731])