In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from scipy import signal
import warnings
import nept

In [None]:
def get_counts(spikes, edges, gaussian_std=None, n_gaussian_std=5):
    """Histogram each spike train into ``edges``, optionally smoothing with a Gaussian.

    Parameters
    ----------
    spikes : list
        Of nept.SpikeTrain; only the ``.time`` attribute is read.
    edges : array-like
        Bin edges passed to ``np.histogram``. The bin width ``dt`` used for the
        Gaussian is the median spacing between consecutive edges.
    gaussian_std : float or None
        Standard deviation (same units as ``edges``) of the smoothing kernel.
        Smoothing is applied only when ``gaussian_std > dt``.
    n_gaussian_std : int
        The kernel extends this many standard deviations on each side.

    Returns
    -------
    counts : nept.AnalogSignal
        Built from ``counts.T`` (one row per bin, one column per spike train)
        with the left edge of each bin as its time.

    Raises
    ------
    ValueError
        If the smoothing kernel would be longer than the number of bins.
    """
    dt = np.median(np.diff(edges))
    n_bins = len(edges) - 1

    gaussian_filter = None
    if gaussian_std is not None and gaussian_std > dt:
        # Kernel spans +/- n_gaussian_std standard deviations. Window functions
        # need an integer length, and an odd length gives a single centre sample.
        n_points = max(int(round(n_gaussian_std * gaussian_std * 2 / dt)), 1)
        if n_points % 2 == 0:
            n_points += 1
        # np.convolve(mode='same') returns max(len(a), len(v)) samples, so a
        # kernel longer than the counts row would not fit back into it.
        if n_points > n_bins:
            raise ValueError("gaussian_std is too large for these times")
        gaussian_filter = signal.windows.gaussian(n_points, gaussian_std / dt)
        gaussian_filter /= np.sum(gaussian_filter)

    counts = np.zeros((len(spikes), n_bins))
    for idx, spiketrain in enumerate(spikes):
        counts[idx] = np.histogram(spiketrain.time, bins=edges)[0]
        if gaussian_filter is not None:
            counts[idx] = np.convolve(counts[idx], gaussian_filter, mode='same')

    return nept.AnalogSignal(counts.T, edges[:-1])

In [None]:
# Toy example: one spike train binned with irregular edges (no smoothing).
edges = [0, 2.5, 4, 5, 6, 10]
spikes = [nept.SpikeTrain([1., 1., 1., 5., 5., 7.])]
counts = get_counts(spikes, edges)

In [None]:
# Show the counts from get_counts above and their timestamps (get_counts passes the left bin edges as times).
counts.data, counts.time

In [None]:
# Toy position with two samples; the second array presumably holds the timestamps (0 to 4.1) that get_edges uses below.
position = nept.Position(np.array([1., 6.]), np.array([0., 4.1]))

In [None]:
# Bin edges 0.5 apart across the position's time range; lastbin=True is expected to
# add the final timestamp as the last edge (compare with the expected array below).
edges = nept.get_edges(position, binsize=0.5, lastbin=True)

In [None]:
# Display the edges computed above.
edges

In [None]:
# Expected edges: 0.5 steps from 0 up to 4.0, plus the final timestamp 4.1
# (presumably appended because lastbin=True). Asserted so Run-All checks it.
np.testing.assert_allclose(edges, np.array([0., 0.5, 1., 1.5, 2., 2.5, 3., 3.5, 4., 4.1]))

In [None]:
# Two identical spike trains (so every output column should match), and a
# position whose second array presumably holds timestamps running from 0 to 4.
spikes = [nept.SpikeTrain([0.8, 1.1, 1.2, 1.2, 2.1, 3.1]) for _ in range(2)]
position = nept.Position(np.array([1., 6.]), np.array([0., 4.]))

In [None]:
# Scratch check of the rounding bin_spikes performs: a 2.2-wide window is not a
# whole number of 0.5 steps, so this warns with the size that would be used.
window = 2.2
dt = 0.5

n_bins = window / dt
rounded_n_bins = round(n_bins)
if abs(n_bins - rounded_n_bins) > 0.01:
    warnings.warn(
        "window advance does not divide the window size evenly. "
        f"Using window size {rounded_n_bins * dt} instead."
    )

In [None]:
def _convolve_centered(values, kernel):
    """Convolve ``values`` with ``kernel`` and keep ``len(values)`` samples aligned on the kernel centre.

    Identical to ``np.convolve(values, kernel, mode='same')`` when
    ``len(kernel) <= len(values)``. Unlike ``mode='same'``, it still returns
    ``len(values)`` samples when the kernel is the longer of the two.
    """
    full = np.convolve(values, kernel, mode='full')
    start = (len(kernel) - 1) // 2
    return full[start:start + len(values)]


def bin_spikes(spikes, position, window_size, window_advance, gaussian_std=None, n_gaussian_std=5, normalized=True):
    """Bins spikes using a sliding window.

    Spike times are histogrammed into bins of width ``window_advance``. The
    edges come from ``np.arange(position.time[0], position.time[-1],
    window_advance)``, so any interval after the last edge is not binned.
    Each spike train's counts are optionally smoothed with a Gaussian. They are
    then summed (or averaged, if ``normalized``) over a sliding window of
    ``window_size`` centred on each bin.

    Parameters
    ----------
    spikes: list
        Of nept.SpikeTrain; only ``.time`` is read.
    position: nept.Position
        Only ``.time`` is read; its first and last values bound the bins.
    window_size: float
        Width of the sliding window, in the same units as the spike times.
        Rounded to a whole number (at least one) of ``window_advance`` steps.
    window_advance: float
        Bin width and step between consecutive windows.
    gaussian_std: float or None
        Standard deviation of the Gaussian applied to the binned counts. If
        None, no Gaussian smoothing is applied.
    n_gaussian_std: int
        The Gaussian kernel extends this many standard deviations on each side.
    normalized: boolean
        If True, average over the window; otherwise sum.

    Returns
    -------
    binned_spikes: nept.AnalogSignal
        Built from an array with one row per bin and one column per spike
        train, timestamped by the left edge of each bin.

    Warns
    -----
    UserWarning
        When ``window_size`` is not (within 0.01 steps) a whole multiple of
        ``window_advance``.

    """
    bin_edges = np.arange(position.time[0], position.time[-1], window_advance)

    given_n_bins = window_size / window_advance
    # Keep at least one bin: a zero-length window would divide by zero below.
    n_bins = max(int(round(given_n_bins)), 1)
    if abs(n_bins - given_n_bins) > 0.01:
        warnings.warn("window advance does not divide the window size evenly. "
                      "Using window size " + str(n_bins*window_advance) + " instead.",
                      stacklevel=2)
    square_filter = np.ones(n_bins)
    if normalized:
        square_filter /= n_bins

    gaussian_filter = None
    if gaussian_std is not None:
        # Window functions need an integer length; an odd length gives a single
        # centre sample so the smoothing does not shift the counts.
        n_points = max(int(round(n_gaussian_std * gaussian_std * 2 / window_advance)), 1)
        if n_points % 2 == 0:
            n_points += 1
        gaussian_filter = signal.windows.gaussian(n_points, gaussian_std / window_advance)
        gaussian_filter /= np.sum(gaussian_filter)

    counts = np.zeros((len(spikes), len(bin_edges) - 1))
    for idx, spiketrain in enumerate(spikes):
        binned = np.histogram(spiketrain.time, bins=bin_edges)[0].astype(float)
        # Smooth the binned counts (not the spike times themselves).
        if gaussian_filter is not None:
            binned = _convolve_centered(binned, gaussian_filter)
        counts[idx] = _convolve_centered(binned, square_filter)

    return nept.AnalogSignal(counts.T, bin_edges[:-1])

In [None]:
# Unsmoothed (gaussian_std defaults to None): a 2-unit window stepped every 0.5, i.e. a 4-bin moving average.
t = bin_spikes(spikes, position, window_size=2, window_advance=0.5)

In [None]:
# Display the unsmoothed binned data (built from counts.T: one column per spike train).
t.data

In [None]:
# Expected output of the unsmoothed call: a 4-bin (2 / 0.5) averaging window over
# two identical spike trains, so both columns agree. Asserted so Run-All checks it.
np.testing.assert_allclose(
    t.data,
    np.array([[0.25, 0.25], [1., 1.], [1., 1.], [1.25, 1.25], [1., 1.], [0.5, 0.5], [0.5, 0.5]]),
)

In [None]:
# Same binning plus Gaussian smoothing (std 1); normalized defaults to True, so the window averages.
t = bin_spikes(spikes, position, window_size=2, window_advance=0.5, gaussian_std=1)

In [None]:
# Display the Gaussian-smoothed binned data.
t.data

In [None]:
# Smoothed output: two identical spike trains binned into 7 bins, so expect a
# (7, 2) array whose columns agree. A single-column expected value (as was
# previously recorded here) cannot come from binning two spike trains.
assert t.data.shape == (7, 2)
np.testing.assert_allclose(t.data[:, 0], t.data[:, 1])