In [36]:
import numpy as np
from collections import Counter
from itertools import repeat
import os
import subprocess

In [42]:
def get_windows(t_start, t_end, win_len, skip):
    '''
    Split the interval starting at t_start into fixed-length windows.

    Window start times are t_start, t_start + skip, t_start + 2*skip, ...,
    all strictly below t_end (numpy.arange semantics). Each window is
    win_len samples long, so its inclusive end time is start + win_len - 1;
    the last window may therefore run past t_end.

    Returns
    -------
    (win_starts, win_ends) : tuple of numpy arrays
    '''
    assert(skip > 0)
    starts = np.arange(t_start, t_end, skip)
    ends = starts + win_len - 1
    return starts, ends

def spikes_in_interval(spikes, t_lo, t_hi, cell_group):
    '''
    Collect the cluster ids of all spikes whose time lies in [t_lo, t_hi].

    Parameters
    ----------
    spikes : numpy array, at least 2 columns
        Column 0 holds spike times, column 1 holds cluster ids.
        Rows must be sorted by spike time (ascending).
    t_lo, t_hi : number
        Inclusive bounds of the time interval.
    cell_group : list
        Output list. The cluster id of every spike in the interval is
        appended to it in time order; the list is modified in place.

    Returns
    -------
    None
    '''
    if len(spikes) == 0:
        return
    times = spikes[:, 0]
    # side='left' / side='right' make both bounds inclusive and keep *every*
    # spike whose time equals a bound, even when several spikes share it.
    first = np.searchsorted(times, t_lo, side='left')
    stop = np.searchsorted(times, t_hi, side='right')
    # If t_lo > t_hi then stop <= first and the slice is empty.
    cell_group.extend(spikes[first:stop, 1])
        
def total_firing_rates(spikes, stim_start, stim_end):
    '''
    Count each cluster's spikes over the whole stimulus.

    Despite the name, the values are raw spike counts over the inclusive
    interval [stim_start, stim_end], not rates; spike_list_to_cell_group
    divides them by the stimulus duration T itself.

    Returns
    -------
    collections.Counter mapping cluster id -> spike count
    '''
    clusters_during_stim = []
    spikes_in_interval(spikes, stim_start, stim_end, clusters_during_stim)
    return Counter(clusters_during_stim)
    
def spike_list_to_cell_group(spike_list, clu_rates, thresh, dt, T):
    '''
    Decide which clusters were active in one window.

    A cluster counts as active when its in-window rate, count / dt, is at
    least thresh times its stimulus-wide rate, clu_rates[clu] / T.

    Parameters
    ----------
    spike_list : iterable
        Cluster id of every spike that fell in the window (one entry per spike).
    clu_rates : mapping
        Cluster id -> total spike count over the stimulus (e.g. the Counter
        from total_firing_rates). Only these clusters are considered.
    thresh : number
        Multiplier applied to the stimulus-wide rate.
    dt : number
        Window length in seconds.
    T : number
        Stimulus duration, in the same time unit as dt.

    Returns
    -------
    set of cluster ids that reach the threshold
    '''
    in_window = Counter(spike_list)
    return {
        clu
        for clu in clu_rates
        if (in_window[clu] / dt) >= (thresh * clu_rates[clu] / T)
    }

def spikes_to_cell_groups(spikes, stim_start, stim_end, win_len, fs, thresh=1):
    '''
    Cut the stimulus into consecutive, non-overlapping windows and find the
    active cluster set (cell group) of each window.

    Parameters
    ----------
    spikes : numpy array
        Rows of (spike time, cluster id), sorted by time.
    stim_start, stim_end : int
        Stimulus bounds, in samples.
    win_len : int
        Window length in samples; windows are also spaced win_len apart.
    fs : number
        Sampling rate (samples per second), used to turn win_len and the
        stimulus length into seconds.
    thresh : number, optional
        Activity threshold forwarded to spike_list_to_cell_group (default 1).

    Returns
    -------
    list of (window index, window midpoint relative to stim_start,
    sorted list of active cluster ids)
        One entry per window containing at least one spike; the active list
        can be empty if no cluster reaches the threshold.
    '''
    # Derive the durations from the arguments instead of reading `dt` / `T`
    # from notebook globals, so the result is correct for any window length
    # or stimulus bounds.
    T = (stim_end - stim_start) / fs
    dt = win_len / fs

    total_frs = total_firing_rates(spikes, stim_start, stim_end)
    win_starts, win_ends = get_windows(stim_start, stim_end, win_len, win_len)
    cell_groups = []
    for ind, (ws, we) in enumerate(zip(win_starts, win_ends)):
        spike_list = []
        spikes_in_interval(spikes, ws, we, spike_list)
        if not spike_list:
            continue
        window_time = (we + ws) / 2 - stim_start
        # Keyword arguments: a positional call here previously passed dt as
        # `thresh`, which only gave the right answer when thresh == 1.
        active = spike_list_to_cell_group(
            spike_list, total_frs, thresh=thresh, dt=dt, T=T
        )
        cell_groups.append((ind, window_time, sorted(active)))
    return cell_groups

def cell_groups_to_bin_mat(cell_groups, ncells, nwins):
    '''
    Build a binary (ncells x nwins) activity matrix.

    Entry [c, w] is 1.0 when cluster c is in the cell group of window w and
    0.0 otherwise. Cluster ids are used directly as row indices and window
    indices as column indices, so both must be valid indices into the matrix.

    Parameters
    ----------
    cell_groups : list
        Entries as returned by spikes_to_cell_groups: element 0 is the window
        index, element 2 the active cluster ids.
    ncells, nwins : int
        Matrix dimensions.
    '''
    activity = np.zeros((ncells, nwins))
    for group in cell_groups:
        window_idx = group[0]
        active_cells = group[2]
        activity[active_cells, window_idx] = 1
    return activity

def build_perseus_persistent_input(cell_groups, savefile):
    """
    Formats cell group information as an input file
    for the Perseus persistent homology software, but assigns filtration
    levels for each cell group based on the time order of their appearance
    in the signal.

    The file starts with the line "1". Each non-empty cell group then becomes
    one line "<dim> <v1> ... <vk> <birth>", where dim = k - 1 and birth is
    the cell group's window index + 1. Empty cell groups are skipped.

    Parameters
    ----------
    cell_groups : list
        cell_group information returned by spikes_to_cell_groups
        (element 0: window index, element 2: cluster ids)
    savefile : str
        File in which to put the formatted cellgroup information

    Returns
    -------
    savefile : str
        Path of the written file, suitable for running perseus on
    """
    with open(savefile, "w") as pfile:
        pfile.write("1\n")
        for cell_group in cell_groups:
            grp = list(cell_group[2])
            grp_dim = len(grp) - 1
            if grp_dim < 0:
                continue
            # Format each vertex with str() instead of str(list): str(list)
            # uses repr() of its elements, which for NumPy >= 2 integer
            # scalars is "np.int64(3)" and would corrupt the Perseus input.
            vert_str = " ".join(str(v) for v in grp)
            birth = cell_group[0] + 1
            pfile.write("{} {} {}\n".format(grp_dim, vert_str, birth))
    return savefile

def run_perseus(pfile, perseus_command="perseus"):
    """
    Runs Perseus persistent homology software on the data in pfile

    Perseus is invoked as ``<perseus_command> nmfsimtop <pfile> <prefix>``,
    where prefix is pfile without its extension. The betti numbers are then
    expected in ``<prefix>_betti.txt``.

    Parameters
    ----------
    pfile : str
        File on which to compute homology
    perseus_command : str, optional
        Name or path of the Perseus executable (default "perseus").

    Returns
    -------
    betti_file : str
        File containing resultant betti numbers

    Warns
    -----
    RuntimeWarning
        If Perseus exits with a non-zero status; the betti file may then be
        missing or left over from an earlier run.
    """
    import warnings

    of_string = os.path.splitext(pfile)[0]
    return_code = subprocess.call(
        [perseus_command, "nmfsimtop", pfile, of_string]
    )
    if return_code != 0:
        # Stay best-effort (no exception; read_perseus_result handles a
        # missing file), but make the failure visible instead of discarding it.
        warnings.warn(
            "perseus exited with status {} for {}".format(return_code, pfile),
            RuntimeWarning,
        )
    return of_string + "_betti.txt"

def read_perseus_result(betti_file):
    """
    Parse a Perseus betti-number file.

    Each non-blank line is expected to hold an integer filtration time
    followed by integer betti numbers.

    Parameters
    ----------
    betti_file : str
        Path of the betti file (as returned by run_perseus).

    Returns
    -------
    list of [filtration_time, [betti_0, betti_1, ...]]
        One entry per non-blank line. If the file cannot be opened/read or a
        line cannot be parsed, the sentinel [-1, [-1]] is appended after
        whatever was parsed up to that point.
    """
    bettis = []
    try:
        with open(betti_file, "r") as bf:
            for bf_line in bf:
                betti_data = bf_line.split()
                if not betti_data:
                    # Blank or whitespace-only line.
                    continue
                filtration_time = int(betti_data[0])
                betti_numbers = [int(b) for b in betti_data[1:]]
                bettis.append([filtration_time, betti_numbers])
    except (OSError, ValueError):
        # Only I/O and parse failures become the sentinel; other exceptions
        # (e.g. KeyboardInterrupt) propagate instead of being swallowed.
        bettis.append([-1, [-1]])
    return bettis

def compute_bettis(spikes, stim_start, stim_end, win_len, fs, savefile='./test.betti'):
    '''
    Compute Betti numbers of the time-ordered cell-group filtration.

    Builds cell groups with spikes_to_cell_groups, writes them to savefile in
    Perseus format, runs Perseus and reads back its betti file.

    Parameters
    ----------
    spikes, stim_start, stim_end, win_len, fs
        Forwarded to spikes_to_cell_groups.
    savefile : str, optional
        Path of the Perseus input file (default './test.betti'); the betti
        file is written next to it.

    Returns
    -------
    list of [time, filtration_step, betti_numbers]
        filtration_step is the Perseus filtration value, i.e. the window
        index + 1 assigned by build_perseus_persistent_input, and
        time = filtration_step * win_len is the offset in samples from
        stim_start at which the next window starts. If the Perseus output
        could not be read, the single sentinel row is [-1, -1, [-1]].
    '''
    cell_groups = spikes_to_cell_groups(spikes, stim_start, stim_end, win_len, fs)
    build_perseus_persistent_input(cell_groups, savefile)
    betti_file = run_perseus(savefile)
    betti_nums = read_perseus_result(betti_file)

    # Windows start every win_len samples, so win_starts[k] - stim_start is
    # k * win_len. Computing it directly (instead of indexing win_starts)
    # avoids an IndexError when the last window -- filtration step equal to
    # the number of windows -- changes the Betti numbers, and keeps the -1
    # sentinel from being mapped to the last window's time.
    return [
        [step * win_len if step >= 0 else -1, step, betti]
        for step, betti in betti_nums
    ]
    
    
    

In [31]:
# Synthetic data: n_t samples (30 kHz assumed, matching fs below), with
# spikes at ~fr per second in total, assigned uniformly to 10 clusters.
# Seeded so a Restart & Run All reproduces every downstream result.
SEED = 0
rng = np.random.default_rng(SEED)
n_t = 1048563
fr = 100
n_sp = int(fr * (n_t / 30000))
ts = np.sort(rng.integers(n_t, size=n_sp))   # spike times, ascending
s = rng.integers(10, size=n_sp)              # cluster id of each spike
spikes = np.vstack((ts, s))
spikes = spikes.T                            # rows: (time, cluster id)
print(n_sp/n_t * 30000)
print(n_sp)

99.9939917773181
3495


In [51]:
# Analysis parameters (times in samples unless noted).
fs = 30000                          # sampling rate, samples per second
stim_start = 500                    # stimulus onset
stim_end = 123000                   # stimulus offset
win_len = 1239                      # window length
T = (stim_end - stim_start) / fs    # stimulus duration, seconds
dt = win_len / fs                   # window duration, seconds

In [5]:
%%time

total_frs = total_firing_rates(spikes, stim_start, stim_end)

win_starts, win_ends = get_windows(stim_start, stim_end, win_len, win_len)

cell_groups = []
for ws, we in zip(win_starts, win_ends):
    spike_list = []
    spikes_in_interval(spikes, ws, we, spike_list)
    if spike_list:
        cell_groups.append(spike_list_to_cell_group(spike_list, total_frs, dt, 1, T))

CPU times: user 14 ms, sys: 0 ns, total: 14 ms
Wall time: 13.9 ms


In [33]:
# Same computation via the packaged function. This rebinds cell_groups to
# (window index, window midpoint relative to stim_start, sorted cluster ids)
# tuples -- the format used by cell_groups_to_bin_mat and
# build_perseus_persistent_input below.
%time cell_groups = spikes_to_cell_groups(spikes, stim_start, stim_end, win_len, fs)

CPU times: user 11 ms, sys: 63 µs, total: 11 ms
Wall time: 10.9 ms


In [34]:
ncells = 10              # synthetic cluster ids are drawn from range(10)
nwin = win_starts.size   # one column per window returned by get_windows
cell_groups_to_bin_mat(cell_groups, ncells, nwin)

array([[0., 1., 0., ..., 0., 0., 1.],
       [1., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.]])

In [38]:
# Chain the returned paths instead of re-typing the derived betti filename.
perseus_input = build_perseus_persistent_input(cell_groups, './test.betti')
betti_path = run_perseus(perseus_input)
read_perseus_result(betti_path)



'./test.betti'

In [52]:
# End to end: cell groups -> Perseus input file -> Perseus -> Betti numbers.
# Each row is [time offset from stim_start in samples, filtration step,
# list of Betti numbers].
%time compute_bettis(spikes, stim_start, stim_end, win_len, fs)

CPU times: user 7.6 ms, sys: 3.9 ms, total: 11.5 ms
Wall time: 54.6 ms


[[1239, 1, [1, 0, 0, 0, 0]],
 [4956, 4, [1, 1, 0, 0, 0]],
 [7434, 6, [1, 0, 0, 0, 0]],
 [11151, 9, [1, 1, 0, 0, 0]],
 [14868, 12, [1, 1, 1, 0, 0]],
 [16107, 13, [1, 0, 1, 0, 0]],
 [18585, 15, [1, 0, 2, 0, 0]],
 [24780, 20, [1, 0, 3, 0, 0]],
 [30975, 25, [1, 0, 5, 0, 0]],
 [45843, 37, [1, 0, 6, 0, 0]],
 [63189, 51, [1, 0, 7, 0, 0]],
 [66906, 54, [1, 0, 6, 0, 0]],
 [73101, 59, [1, 1, 6, 0, 0]],
 [74340, 60, [1, 1, 5, 0, 0]],
 [75579, 61, [1, 0, 4, 0, 0]],
 [78057, 63, [1, 0, 5, 0, 0]],
 [83013, 67, [1, 0, 5, 2, 0]],
 [97881, 79, [1, 0, 2, 2, 0]],
 [106554, 86, [1, 0, 1, 4, 0]],
 [115227, 93, [1, 0, 1, 5, 0]],
 [121422, 98, [1, 0, 1, 3, 0]]]

In [59]:
# Windows spaced win_len - 15 samples apart, so consecutive windows overlap
# by 15 samples.
windows = get_windows(stim_start, stim_end, win_len, win_len - 15)
a, b = windows
print(*windows)

[   500   1724   2948   4172   5396   6620   7844   9068  10292  11516
  12740  13964  15188  16412  17636  18860  20084  21308  22532  23756
  24980  26204  27428  28652  29876  31100  32324  33548  34772  35996
  37220  38444  39668  40892  42116  43340  44564  45788  47012  48236
  49460  50684  51908  53132  54356  55580  56804  58028  59252  60476
  61700  62924  64148  65372  66596  67820  69044  70268  71492  72716
  73940  75164  76388  77612  78836  80060  81284  82508  83732  84956
  86180  87404  88628  89852  91076  92300  93524  94748  95972  97196
  98420  99644 100868 102092 103316 104540 105764 106988 108212 109436
 110660 111884 113108 114332 115556 116780 118004 119228 120452 121676
 122900] [  1738   2962   4186   5410   6634   7858   9082  10306  11530  12754
  13978  15202  16426  17650  18874  20098  21322  22546  23770  24994
  26218  27442  28666  29890  31114  32338  33562  34786  36010  37234
  38458  39682  40906  42130  43354  44578  45802  47026  48250  494