In [23]:
import numpy as np
from collections import Counter
from itertools import repeat
import os
import subprocess
import h5py as h5

  from ._conv import register_converters as _register_converters


In [41]:
def get_windows(t_start, t_end, win_len, skip):
    '''
    Given a time interval [t_start, t_end], 
    compute lists of window start and end times.
    Windows start every 'skip' samples and are 'win_len' samples long
    -checked
    '''
    assert(skip > 0)
    win_starts = np.arange(t_start, t_end, skip)
    win_ends = win_starts + win_len -1
    return (win_starts, win_ends)

def spikes_in_interval(spikes, t_lo, t_hi, cell_group):
    ''' 
    Given an array of spikes, find all the spikes whose time is between t_lo and t_hi inclusive
    store them in the 
    '''
    if len(spikes) == 0:
        return
    
    m = int(np.floor((len(spikes)) / 2))
    
    if t_lo < spikes[m, 0]:
        spikes_in_interval(spikes[0:m], t_lo, t_hi, cell_group)
    if t_lo <= spikes[m, 0] and t_hi >= spikes[m,0]:
        cell_group.append(spikes[m, 1])  
    if t_hi > spikes[m, 0]:
        spikes_in_interval(spikes[m+1:], t_lo, t_hi, cell_group)
        
def total_firing_rates(spikes, stim_start, stim_end):
    
    stim_spikes = []
    spikes_in_interval(spikes, stim_start, stim_end, stim_spikes)
    c = Counter(stim_spikes)
    return c
    
def spike_list_to_cell_group(spike_list, clu_rates, thresh, dt, T):
    ''' 
    Given a spike list, first counts number of each spikes from each cluster
    then given a dictionary that maps clusters to total firing rates, computes whether the cluster's firing
    rate in the spike list exceeded some threshold by dividing the cluster's number of in-window spikes by the 
    window length in seconds and then checking if that value is greater than the threshold times the cluster total firing rate givne in clu_rates
    '''
    
    c = Counter(spike_list)
    cg = set()
    
    for clu in clu_rates.keys():
        if (c[clu] / dt) >= (thresh*clu_rates[clu] / T):
            cg.add(clu)
    return cg

def spikes_to_cell_groups(spikes, stim_start, stim_end, win_len, fs):
    
    total_frs = total_firing_rates(spikes, stim_start, stim_end)
    win_starts, win_ends = get_windows(stim_start, stim_end, win_len, win_len)
    cell_groups = []
    for ind, (ws, we) in enumerate(zip(win_starts, win_ends)):
        spike_list = []
        spikes_in_interval(spikes, ws, we, spike_list)
        if spike_list:
            window_time = (we + ws) / 2 - stim_start
            cell_groups.append((ind, window_time, sorted(tuple(spike_list_to_cell_group(spike_list, total_frs, dt, 1, T)))))
    return cell_groups

def cell_groups_to_bin_mat(cell_groups, ncells, nwins):
    
    bin_mat = np.zeros((ncells, nwins))
    for cg in cell_groups:
        bin_mat[cg[2], cg[0]] = 1
    return bin_mat

def build_perseus_persistent_input(cell_groups, savefile):
    """
    Formats cell group information as an input file
    for the Perseus persistent homology software, but assigns filtration
    levels for each cell group based on the time order of their appearance
    in the signal.

    Parameters
    ----------
    cell_groups : list
        cell_group information returned by spikes_to_cell_groups
    savefile : str
        File in which to put the formatted cellgroup information

    Returns
    ------
    savefile : text File
        input file for perseus
    """
    with open(savefile, "w+") as pfile:
        pfile.write("1\n")
        for cell_group in cell_groups:
            grp = list(cell_group[2])
            grp_dim = len(grp) - 1
            if grp_dim < 0:
                continue
            vert_str = str(grp)
            vert_str = vert_str.replace("[", "")
            vert_str = vert_str.replace("]", "")
            vert_str = vert_str.replace(" ", "")
            vert_str = vert_str.replace(",", " ")
            out_str = str(grp_dim) + " " + vert_str + " {}\n".format(str(cell_group[0] + 1))
            pfile.write(out_str)
    return savefile

def run_perseus(pfile):
    """
    Runs Perseus persistent homology software on the data in pfile

    Parameters
    ------
    pfile : str
        File on which to compute homology

    Returns
    ------
    betti_file : str
        File containing resultant betti numbers

    """
    pfile_split = os.path.splitext(pfile)
    of_string = pfile_split[0]
    perseus_command = "perseus"
    perseus_return_code = subprocess.call(
        [perseus_command, "nmfsimtop", pfile, of_string]
    )

    betti_file = of_string + "_betti.txt"
    # betti_file = os.path.join(os.path.split(pfile)[0], betti_file)

    return betti_file

def read_perseus_result(betti_file):
    bettis = []
    f_time = []
    try:
        with open(betti_file, "r") as bf:
            for bf_line in bf:
                if len(bf_line) < 2:
                    continue
                betti_data = bf_line.split()
                filtration_time = int(betti_data[0])
                betti_numbers = list(map(int, betti_data[1:]))
                bettis.append([filtration_time, betti_numbers])
    except:
        bettis.append([-1, [-1]])
    return bettis

def compute_bettis(spikes, stim_start, stim_end, win_len, fs):
    
    win_starts, win_ends = get_windows(stim_start, stim_end, win_len, win_len)
    cell_groups = spikes_to_cell_groups(spikes, stim_start, stim_end, win_len, fs)
    build_perseus_persistent_input(cell_groups, './test.betti')
    betti_file = run_perseus('./test.betti')
    betti_nums = read_perseus_result(betti_file)
    betti_nums = [[win_starts[x[0]-1] - stim_start, x[0], x[1]] for x in betti_nums]
    
    return betti_nums
    
from scipy.interpolate import interp1d
def betti_curve_func(betti_nums, dim, stim_start, stim_end, fs, t_in_seconds=False):
    
    betti_ts = [x[0] for x in betti_nums]
    betti_vals = [x[2][dim] for x in betti_nums]
    if t_in_seconds:
        betti_ts = list(map(lambda x: x / fs, betti_ts))
    f = interp1d(betti_ts, betti_vals, kind='zero', bounds_error = False, fill_value=(0, betti_vals[-1]))
    return f

def kwik_get_trials(kwikfile):
    with h5.File(kwikfile, 'r') as f:
        stim_names = list(f['/event_types/Stimulus/text'])
        stim_names = [x.decode('utf-8') for x in stim_names]
        stim_start_times = list(f['/event_types/Stimulus/time_samples'])
        stim_end_times = list(f['/event_types/Stimulus/stimulus_end'])
    return list(zip(stim_names, stim_start_times, stim_end_times))
        
    
def kwik_get_spikes(kwikfile):
    with h5.File(kwikfile, 'r') as f:
        spikes_clus = np.array(f['/channel_groups/0/spikes/clusters/main'])
        spikes_times = np.array(f['/channel_groups/0/spikes/time_samples'])
        
    spikes = np.vstack((spikes_times, spikes_clus)).T
    return spikes

In [6]:
n_t = 1048563
fr = 100
n_sp = int(fr * (n_t / 30000))
ts = sorted(np.random.randint(n_t, size=n_sp))
s = np.random.randint(10, size=n_sp)
spikes = np.vstack((ts, s))
spikes =spikes.T
print(n_sp/n_t * 30000)
print(n_sp)

99.9939917773181
3495


In [7]:
# get total firing rates:

stim_start = 500
stim_end = 123000
fs = 30000
T = (stim_end - stim_start) / fs
win_len = 1239
dt = win_len / fs

In [8]:
%time cell_groups = spikes_to_cell_groups(spikes, stim_start, stim_end, win_len, fs)

CPU times: user 8.43 ms, sys: 42 µs, total: 8.48 ms
Wall time: 8.38 ms


In [34]:
nwin = len(win_starts)
ncells = 10
cell_groups_to_bin_mat(cell_groups, ncells, nwin)

array([[0., 1., 0., ..., 0., 0., 1.],
       [1., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.]])

In [38]:
build_perseus_persistent_input(cell_groups, './test.betti')

run_perseus('./test.betti')

read_perseus_result('./test_betti.txt')



'./test.betti'

In [11]:
%time betti_nums = compute_bettis(spikes, stim_start, stim_end, win_len, fs)
print(betti_nums)

CPU times: user 6.97 ms, sys: 3.9 ms, total: 10.9 ms
Wall time: 49.6 ms
[[1239, 1, [1, 0, 0, 0, 0]], [3717, 3, [1, 1, 0, 0, 0]], [4956, 4, [1, 0, 0, 0, 0]], [13629, 11, [1, 1, 0, 0, 0]], [21063, 17, [1, 1, 1, 0, 0]], [22302, 18, [1, 1, 2, 0, 0]], [27258, 22, [1, 0, 3, 0, 0]], [33453, 27, [1, 0, 4, 0, 0]], [38409, 31, [1, 0, 6, 0, 0]], [40887, 33, [1, 0, 8, 0, 0]], [50799, 41, [1, 0, 7, 0, 0]], [52038, 42, [1, 0, 8, 0, 0]], [58233, 47, [1, 0, 7, 0, 0]], [61950, 50, [1, 0, 7, 0, 0]], [68145, 55, [1, 0, 8, 0, 0]], [76818, 62, [1, 0, 7, 1, 0]], [85491, 69, [1, 0, 5, 1, 0]], [95403, 77, [1, 0, 6, 1, 0]], [100359, 81, [1, 0, 5, 1, 0]], [102837, 83, [1, 0, 4, 1, 0]], [106554, 86, [1, 0, 3, 1, 0]], [113988, 92, [1, 0, 3, 2, 0]]]


In [20]:
f = betti_curve_func(betti_nums, 2, stim_start, stim_end, fs, t_in_seconds=True)
    

In [30]:
trials = kwik_get_trials('/home/brad/experiments/B1146/sorted/block-4-AP-2300-ML-400-Z-1750/experiment1_101.kwik')

In [36]:
spikes = kwik_get_spikes('/home/brad/experiments/B1146/sorted/block-4-AP-2300-ML-400-Z-1750/experiment1_101.kwik')

In [39]:
spikes

array([[      165,        51],
       [      359,         9],
       [      570,        83],
       ...,
       [543635798,       159],
       [543636109,        32],
       [543636282,       146]])

In [42]:
trial_subset = trials[:10]
fs = 30000
win_len = np.round(0.01 * 30000)
for tr in trial_subset:
    stim_start = tr[1]
    stim_end = tr[2]
    stim_name = tr[0]
    betti_nums = compute_bettis(spikes, stim_start, stim_end, win_len, fs)
    print(stim_name, betti_nums)

2_N_G105_s_04_@2___N1211_s_04@1.wav.sine [[0.0, 1, [1, 0, 0]], [300.0, 2, [2, 0, 0]], [600.0, 3, [3, 0, 0]], [1200.0, 5, [4, 0, 0]], [1500.0, 6, [5, 0, 0]], [1800.0, 7, [6, 0, 0]], [2400.0, 9, [4, 0, 0]], [2700.0, 10, [5, 0, 0]], [3000.0, 11, [4, 0, 0]], [3300.0, 12, [3, 0, 0]], [3600.0, 13, [2, 0, 0]], [4800.0, 17, [3, 0, 0]], [5100.0, 18, [3, 1, 0]], [5700.0, 20, [3, 2, 0]], [6000.0, 21, [3, 4, 0]], [6300.0, 22, [2, 4, 0]], [6600.0, 23, [2, 5, 0]], [7500.0, 26, [2, 6, 0]], [8100.0, 28, [2, 7, 0]], [8700.0, 30, [2, 8, 0]], [9300.0, 32, [2, 9, 0]], [10200.0, 35, [2, 11, 0]], [10500.0, 36, [3, 11, 0]], [11100.0, 38, [3, 12, 0]], [14100.0, 48, [2, 13, 0]], [14400.0, 49, [2, 14, 0]], [15300.0, 52, [2, 15, 0]], [17700.0, 60, [2, 16, 0]], [18300.0, 62, [2, 15, 0]], [18600.0, 63, [2, 16, 0]], [18900.0, 64, [2, 17, 0]], [19200.0, 65, [2, 18, 0]], [19800.0, 67, [2, 19, 0]], [20100.0, 68, [2, 20, 0]], [20400.0, 69, [2, 21, 0]], [22200.0, 75, [2, 23, 0]], [22800.0, 77, [2, 25, 0]], [23100.0, 78,

3_M_G105_s_06_@1___M_G105_s_06_@1.wav.sine [[0.0, 1, [1, 0, 0]], [300.0, 2, [2, 0, 0]], [600.0, 3, [3, 0, 0]], [1500.0, 6, [4, 0, 0]], [2100.0, 8, [3, 0, 0]], [3000.0, 11, [2, 0, 0]], [3600.0, 13, [2, 1, 0]], [5100.0, 18, [2, 2, 0]], [5400.0, 19, [2, 3, 0]], [5700.0, 20, [1, 3, 0]], [6000.0, 21, [1, 4, 0]], [6300.0, 22, [1, 5, 0]], [6600.0, 23, [1, 6, 0]], [7800.0, 27, [1, 8, 0]], [9000.0, 31, [1, 9, 0]], [10200.0, 35, [1, 10, 0]], [10500.0, 36, [1, 11, 0]], [12600.0, 43, [1, 12, 0]], [12900.0, 44, [1, 11, 0]], [13500.0, 46, [2, 11, 0]], [14700.0, 50, [2, 12, 0]], [15900.0, 54, [2, 13, 0]], [16200.0, 55, [3, 13, 0]], [16800.0, 57, [3, 14, 0]], [17700.0, 60, [3, 16, 0]], [18000.0, 61, [3, 17, 0]], [18300.0, 62, [3, 18, 0]], [19800.0, 67, [3, 19, 0]], [21900.0, 74, [3, 20, 0]], [22200.0, 75, [3, 21, 0]], [24300.0, 82, [3, 22, 0]], [25200.0, 85, [2, 22, 0]], [25500.0, 86, [2, 23, 0]], [26700.0, 90, [2, 24, 0]], [27600.0, 93, [2, 25, 0]], [28800.0, 97, [2, 26, 0]], [32700.0, 110, [2, 29, 0