In [None]:
from ephys import core, events
import numpy as np
import neuraltda.topology3 as tp3

In [None]:
def shuffle_trial_spikes(spikes, stim_start, stim_end, rng=None):
    """Return one trial's spikes with their sample times redrawn uniformly at random.

    The spikes in the trial window are gathered with ``tp3.spikes_in_interval``.
    Each one keeps its cluster id (``spike[1]``) but gets a new sample time
    drawn uniformly from the closed interval ``[stim_start, stim_end]``.
    Every cluster keeps its spike count, but the temporal structure is destroyed.

    Parameters
    ----------
    spikes : array-like
        Spike rows of ``(sample, cluster_id)``; presumably the array built in the
        loading cell -- confirm for other callers.
    stim_start, stim_end : int
        Trial window in samples; both ends are possible redraw values.
    rng : numpy.random.Generator, optional
        Source of randomness. When ``None`` (the default), the legacy global
        ``np.random`` state is used, so ``np.random.seed`` still controls the draw.

    Returns
    -------
    numpy.ndarray
        ``(n_spikes, 2)`` int array of ``(new_sample, cluster_id)`` rows sorted
        by sample. Returns a ``(0, 2)`` array when the window contains no spikes.
    """
    # NOTE(review): spikes_in_interval presumably fills the list it is given
    # (the result is only read back from it) -- confirm against neuraltda.
    trial_spikes = []
    tp3.spikes_in_interval(spikes, stim_start, stim_end, trial_spikes)
    n_spikes = len(trial_spikes)
    if n_spikes == 0:
        return np.empty((0, 2), dtype=np.int64)

    # One vectorised draw instead of a per-spike loop. int64 keeps large
    # sample indices safe on platforms where the default C long is 32-bit.
    if rng is None:
        new_times = np.random.randint(stim_start, stim_end + 1, size=n_spikes, dtype=np.int64)
    else:
        new_times = rng.integers(stim_start, stim_end, size=n_spikes, endpoint=True, dtype=np.int64)
    new_ids = np.array([spike[1] for spike in trial_spikes])

    out_spikes = np.column_stack((new_times, new_ids))
    return out_spikes[np.argsort(out_spikes[:, 0], kind="stable")]

def get_betti_curve(spikes, stim_name, stim_start, stim_end, win_len, fs, thresh, t, dim, betti_curve_dict, times):
    """Compute one trial's Betti curve and append it to per-stimulus accumulators.

    Parameters
    ----------
    spikes : array-like
        Spike rows passed straight to ``tp3.compute_bettis``.
    stim_name : hashable
        Key under which the curve and its time grid are stored.
    stim_start, stim_end : int
        Trial window in samples.
    win_len : number
        Window length forwarded to ``tp3.compute_bettis``.
    fs : number
        Sampling rate forwarded to the tp3 helpers.
    thresh : float
        Threshold forwarded to ``tp3.compute_bettis``.
    t : array-like
        Evaluation times in seconds (the curve is built with ``t_in_seconds=True``).
    dim : int
        Betti dimension whose curve is evaluated.
    betti_curve_dict, times : dict
        Accumulators, mutated in place. The curve is appended to
        ``betti_curve_dict[stim_name]`` and ``t`` to ``times[stim_name]``.

    Returns
    -------
    The evaluated Betti curve (``betti_func(t)``).
    """
    # Forward the caller's threshold instead of a fixed constant.
    betti_nums = tp3.compute_bettis(spikes, stim_start, stim_end, win_len, fs, thresh=thresh)

    betti_func = tp3.betti_curve_func(betti_nums, dim, stim_start, stim_end, fs, t_in_seconds=True)
    betti_curve = betti_func(t)

    # Initialise each accumulator on its own, so a key present in only one
    # of the two dicts cannot raise KeyError.
    betti_curve_dict.setdefault(stim_name, []).append(betti_curve)
    times.setdefault(stim_name, []).append(t)
    return betti_curve

In [None]:
# Load one recording: trial table, spike table, cluster metadata and sampling rate.
# NOTE(review): absolute, user-specific path -- consider a configurable DATA_DIR.
bp = '/home/brad/krista/B1083/P03S03/'
kwikfile = bp + 'B1083_cat_P03_S03_1.kwik'  # not used in the cells shown; kept for reference
trials = events.load_trials(bp)
spikes = core.load_spikes(bp)
clus = core.load_clusters(bp)
fs = core.load_fs(bp)

# Keep only spikes from clusters labelled 'Good', then pack them into an integer
# array with one (sample, cluster) row per spike.
gc = list(clus[clus['quality'] == 'Good']['cluster'])
spikes = spikes[spikes['cluster'].isin(gc)]
spike_t = np.array(spikes['time_samples']).astype(int)
spike_clu = np.array(spikes['cluster']).astype(int)
spikes = np.vstack((spike_t, spike_clu)).T

# Flatten the trial table into (stimulus, start sample, end sample) tuples.
stim_name = list(trials['stimulus'])
stim_start = list(trials['time_samples'])
stim_end = list(trials['stimulus_end'])
trials = list(zip(stim_name, stim_start, stim_end))

# Only a cell's last expression is displayed, so the shape check sits at the end.
spikes.shape

In [None]:
import tqdm.auto

# Analysis parameters.
SEED = 0                        # seeds the global np.random state used by the spike shuffle
trial_subset = trials
win_len = np.round(0.01 * fs)   # 10 ms expressed in samples
dim = 2
thresh = 4.0

# Per-stimulus accumulators: one Betti curve and one time grid per trial.
betti_curves = {}
shuffled_betti_curves = {}
times = {}
shuffled_times = {}

# Re-seeding on every run keeps the shuffled null model reproducible and the cell idempotent.
np.random.seed(SEED)
for stim_name, stim_start, stim_end in tqdm.auto.tqdm(trial_subset):
    # 2000-point grid spanning the trial, converted from samples to seconds.
    t = np.linspace(0, stim_end - stim_start, 2000) / fs
    get_betti_curve(spikes, stim_name, stim_start, stim_end, win_len, fs, thresh, t, dim, betti_curves, times)

    # Same trial with spike times redrawn uniformly -> null-model curve.
    shuffled_trial_spikes = shuffle_trial_spikes(spikes, stim_start, stim_end)
    get_betti_curve(shuffled_trial_spikes, stim_name, stim_start, stim_end, win_len, fs, thresh, t, dim, shuffled_betti_curves, shuffled_times)


In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# One figure per stimulus: trial-averaged Betti curve for the real data (solid)
# vs. the time-shuffled surrogate (dashed).
for stim in shuffled_betti_curves:
    mean_curve = np.vstack(betti_curves[stim]).mean(axis=0)
    mean_shuffled = np.vstack(shuffled_betti_curves[stim]).mean(axis=0)
    # Each trial's grid has the same number of points, so the first trial's grid
    # serves as the x-axis. NOTE(review): grids differ slightly if trial durations differ.
    t_axis = times[stim][0]

    fig, ax = plt.subplots()
    ax.plot(t_axis, mean_curve, 'k', label='data')
    ax.plot(t_axis, mean_shuffled, 'k--', label='shuffled')
    ax.set(title=stim, xlabel='Time (s)',
           ylabel='Betti-{} (mean over trials)'.format(dim))
    ax.legend()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across many stimuli