In [None]:
# %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString
import pickle
import seaborn as sns
import os

import vdmlab as vdm

from tuning_curves_functions import get_tc, get_odd_firing_idx, linearize

import info.R063d2_info as r063d2
import info.R063d3_info as r063d3
import info.R063d4_info as r063d4
import info.R063d5_info as r063d5
import info.R063d6_info as r063d6
import info.R066d1_info as r066d1
import info.R066d2_info as r066d2
import info.R066d3_info as r066d3
import info.R066d4_info as r066d4

In [None]:
# Project root for cached pickles and plot output. Set the EMILY_SHORTCUT_DIR
# environment variable (e.g. to C:\Users\Emily\Code\python-vdmlab\projects\emily_shortcut)
# instead of editing hard-coded paths per machine.
project_dir = os.environ.get('EMILY_SHORTCUT_DIR',
                             'E:\\code\\python-vdmlab\\projects\\emily_shortcut')
# Joining with a final '' keeps a trailing separator, matching the original string
# paths in case a helper concatenates filenames directly onto these directories.
pickle_filepath = os.path.join(project_dir, 'cache', 'pickled', '')
output_filepath = os.path.join(project_dir, 'plots', 'sequence', '')

In [None]:
# Session info module used by every cell below (session_id, task_times, sequence,
# data loaders). Swap in another imported info module to analyse a different session.
info = r063d4

In [None]:
# Load position (converted with info.pxl_to_cm), the first good-SWR LFP channel
# and all spike trains for the selected session.
print(info.session_id)
pos, csc, spikes = (
    info.get_pos(info.pxl_to_cm),
    info.get_csc(info.good_swr[0]),
    info.get_spikes(),
)

In [None]:
# Load the session, its tuning curves and per-neuron spatial heatmaps, then
# detect sharp-wave ripples (SWRs) on the chosen LFP channel.
print(info.session_id)
pos = info.get_pos(info.pxl_to_cm)
csc = info.get_csc(info.good_swr[0])
spikes = info.get_spikes()

tc = get_tc(info, pos, pickle_filepath, expand_by=2)

# Heatmaps are cached per session as a pickle.
# NOTE(review): pickle.load can execute arbitrary code -- only load caches this project wrote.
heatmap_filename = info.session_id + '_spike_heatmaps.pkl'
pickled_spike_heatmaps = os.path.join(pickle_filepath, heatmap_filename)
if os.path.isfile(pickled_spike_heatmaps):
    with open(pickled_spike_heatmaps, 'rb') as fileobj:
        spike_heatmaps = pickle.load(fileobj)
else:
    # `spikes` was loaded at the top of this cell; no need to read it from disk again.
    # NOTE(review): the range starts at 1, so neuron 0 gets no heatmap although
    # spikes['time'] is indexed from 0 elsewhere -- confirm this is intended.
    all_neurons = list(range(1, len(spikes['time'])))
    spike_heatmaps = vdm.get_heatmaps(all_neurons, spikes, pos)
    with open(pickled_spike_heatmaps, 'wb') as fileobj:
        pickle.dump(spike_heatmaps, fileobj)

swr_times, swr_idx, filtered_butter = vdm.detect_swr_hilbert(csc, fs=info.fs)

In [None]:
def get_speed(pos, smooth=True, t_smooth=0.5):
    """Finds the per-sample speed of the animal from 2D position.

    Parameters
    ----------
    pos : dict
        With x, y, time as keys (array-like, one entry per position sample).
    smooth : bool
        Whether smoothing occurs. Default is True. When False, 'smoothed'
        is an unsmoothed copy of 'velocity' so callers can always read it.
    t_smooth : float
        Range over which smoothing occurs in seconds. Default is 0.5 seconds.

    Returns
    -------
    speed : dict
        With time, velocity and smoothed as keys.
        'velocity' is the Euclidean distance moved since the previous sample
        (position units per sample -- it is NOT divided by the sample interval);
        the first sample is 0. 'smoothed' is 'velocity' after a boxcar filter
        spanning about t_smooth seconds.

    """
    speed = dict()
    speed['time'] = pos['time']
    step = np.sqrt(np.diff(pos['x']) ** 2 + np.diff(pos['y']) ** 2)
    speed['velocity'] = np.hstack(([0], step))

    # Without at least two samples there is no sample interval to smooth over.
    if not smooth or len(speed['velocity']) < 2:
        speed['smoothed'] = speed['velocity'].copy()
        return speed

    dt = np.median(np.diff(speed['time']))

    # At least one sample wide, so t_smooth < dt (or 0) cannot produce an empty kernel.
    filter_length = max(1, int(np.ceil(t_smooth / dt)))
    speed['smoothed'] = np.convolve(speed['velocity'], np.ones(filter_length) / filter_length, 'same')

    return speed

In [None]:
# Keep only position samples whose smoothed per-sample speed is at or above
# run_threshold (position units per sample; presumably cm, given pxl_to_cm).
speed = get_speed(pos)

run_threshold = 0.45
t_run = speed['time'][speed['smoothed'] >= run_threshold]

# np.isin marks every position sample whose timestamp appears in t_run -- the same
# mask as OR-ing one equality test per timestamp, without the O(n*m) Python loop.
run_idx = np.isin(pos['time'], t_run)

run_pos = dict()
run_pos['x'] = pos['x'][run_idx]
run_pos['y'] = pos['y'][run_idx]
run_pos['time'] = pos['time'][run_idx]

In [None]:
# Sanity check: 2D positions that survived the running-speed filter.
fig, ax = plt.subplots()
ax.plot(run_pos['x'], run_pos['y'], 'b.')
plt.show()

In [None]:
# Recompute tuning curves from running-only positions. This replaces the `tc` built
# in the load cell from all positions with expand_by=2.
# NOTE(review): get_tc receives pickle_filepath; if it caches by session, this may
# return the earlier cached curves rather than run-only ones -- confirm.
tc = get_tc(info, run_pos, pickle_filepath)

In [None]:
# Long (60-entry) palette.
# NOTE(review): the next cell reassigns `colours`, so this list is currently unused.
colours = ['#bd0026', '#fc4e2a', '#ef3b2c', '#ec7014', '#fe9929', 
           '#78c679', '#41ab5d', '#238443', '#66c2a4', '#41b6c4', 
           '#1d91c0', '#8c6bb1', '#225ea8', '#88419d', '#ae017e', 
           '#dd3497', '#f768a1', '#fcbba1', '#fc9272', '#fb6a4a', 
           '#e31a1c', '#fb6a4a', '#993404', '#b30000', '#800026',
           '#bd0026', '#fc4e2a', '#fb6a4a', '#ef3b2c', '#ec7014', 
           '#fe9929', '#78c679', '#41ab5d', '#238443', '#66c2a4', 
           '#41b6c4', '#1d91c0', '#8c6bb1', '#225ea8', '#88419d', 
           '#ae017e', '#dd3497', '#f768a1', '#fcbba1', '#fc9272', 
           '#fb6a4a', '#e31a1c', '#fb6a4a', '#993404', '#b30000', 
           '#800026', 'k', 'k', 'k', 'k', 'k', 'k', 'k', 'k', 'k']

In [None]:
# Active palette for the U sequence figure, indexed by neuron.
# NOTE(review): only 9 colours -- indexing `colours` by neuron raises IndexError
# once more neurons are selected; check against len(field_spikes).
colours = ['#1d91c0', '#045a8d', '#8c96c6', 
           '#238443', '#74c476', '#7fcdbb', 
           '#d7301f', '#dd3497', '#ec7014']

In [None]:
def select_single_field_neurons(trajectory_tc, single_fields, spike_times):
    """Return spike trains and tuning curves of single-field neurons for one trajectory.

    Neurons are visited in vdm.get_sort_idx order. Neurons flagged by
    get_odd_firing_idx are dropped, and only neurons contained in
    `single_fields` are kept.

    Parameters
    ----------
    trajectory_tc : sequence of tuning curves for one trajectory, indexed by neuron.
    single_fields : container of neuron indices (output of vdm.get_single_field).
    spike_times : sequence of per-neuron spike-time arrays (e.g. spikes['time']).

    Returns
    -------
    field_spikes, field_tc : lists in the same neuron order.
    """
    sort_idx = vdm.get_sort_idx(trajectory_tc)
    odd_firing_idx = get_odd_firing_idx(trajectory_tc)

    field_spikes = []
    field_tc = []
    for idx in sort_idx:
        if idx not in odd_firing_idx and idx in single_fields:
            field_spikes.append(spike_times[idx])
            field_tc.append(trajectory_tc[idx])
    return field_spikes, field_tc


all_u_fields = vdm.find_fields(tc['u'])
all_shortcut_fields = vdm.find_fields(tc['shortcut'])
all_novel_fields = vdm.find_fields(tc['novel'])

# Alternative (currently unused): keep only fields unique to one trajectory via
# vdm.unique_fields against find_fields(..., hz_thres=3) of the other two trajectories.

u_fields_sized = vdm.sized_fields(all_u_fields, max_length=15)
shortcut_fields_sized = vdm.sized_fields(all_shortcut_fields, max_length=15)
novel_fields_sized = vdm.sized_fields(all_novel_fields, max_length=15)

u_fields_single = vdm.get_single_field(u_fields_sized)
shortcut_fields_single = vdm.get_single_field(shortcut_fields_sized)
novel_fields_single = vdm.get_single_field(novel_fields_sized)

u_field_spikes, u_tc = select_single_field_neurons(tc['u'], u_fields_single, spikes['time'])
shortcut_field_spikes, shortcut_tc = select_single_field_neurons(
    tc['shortcut'], shortcut_fields_single, spikes['time'])
novel_field_spikes, novel_tc = select_single_field_neurons(
    tc['novel'], novel_fields_single, spikes['time'])

In [None]:
# Overlay the single-field novel-arm tuning curves. The loop variable must not be
# named `tc`: that would overwrite the tuning-curve dict later cells index as tc['u'].
c = ['k', 'r', 'k', 'k', 'c', 'm', 'k', 'r', 'y', 'b', 'g', 'm', 'c', 'y', '#e7298a', '#41ab5d']
for idx, neuron_tc in enumerate(novel_tc):
    # Cycle through the palette so more than len(c) neurons cannot raise IndexError.
    plt.plot(neuron_tc, color=c[idx % len(c)])
plt.show()


In [None]:
# Selected single-field neurons per trajectory, in order: U, shortcut, novel.
print(*map(len, (u_field_spikes, shortcut_field_spikes, novel_field_spikes)))

In [None]:
# Hand-picked example neurons per trajectory (positions in the *_field_spikes lists).
# Earlier picks: u (0, 1, 3), shortcut (2, 4, 13), novel (0, 1).
for_u = [u_field_spikes[pick] for pick in (0, 4, 8, 7)]
for_shortcut = [shortcut_field_spikes[pick] for pick in (1, 3, 11, 12)]
for_novel = [novel_field_spikes[pick] for pick in (0, 1, 2, 3)]

In [None]:
# Stricter field detection (hz_thres=8), keeping fields with min_length=3 and
# max_length=15, then neurons with a single such field per trajectory.
# (Unique-field filtering via vdm.unique_fields with hz_thres=3 was tried and is unused.)
all_u_fields, all_shortcut_fields, all_novel_fields = (
    vdm.find_fields(tc[trajectory], hz_thres=8)
    for trajectory in ('u', 'shortcut', 'novel')
)

u_fields_size, shortcut_fields_size, novel_fields_size = (
    vdm.sized_fields(detected, min_length=3, max_length=15)
    for detected in (all_u_fields, all_shortcut_fields, all_novel_fields)
)

u_fields, shortcut_fields, novel_fields = (
    vdm.get_single_field(sized)
    for sized in (u_fields_size, shortcut_fields_size, novel_fields_size)
)

In [None]:
# U-trajectory neurons with a single field in u_fields, visited in vdm.get_sort_idx
# order and skipping neurons flagged by get_odd_firing_idx.
this_tc = tc['u']

sort_idx = vdm.get_sort_idx(this_tc)
odd_firing_idx = get_odd_firing_idx(this_tc)

these_fields = list(u_fields)

kept_neurons = [neuron for neuron in sort_idx
                if neuron not in odd_firing_idx and neuron in these_fields]
field_spikes = [spikes['time'][neuron] for neuron in kept_neurons]
field_tc = [this_tc[neuron] for neuron in kept_neurons]

In [None]:
# Keep only position samples whose smoothed speed meets the session's run threshold.
speed = vdm.get_speed(pos)

t_run = speed['time'][speed['smoothed'] >= info.run_threshold]

# np.isin yields the same mask as OR-ing one equality test per t_run timestamp,
# without the O(n*m) Python loop.
run_idx = np.isin(pos['time'], t_run)

run_pos = dict()
run_pos['x'] = pos['x'][run_idx]
run_pos['y'] = pos['y'][run_idx]
run_pos['time'] = pos['time'][run_idx]

In [None]:
# Linearize running positions from the start of prerecord to the end of postrecord.
t_start, t_stop = info.task_times['prerecord'][0], info.task_times['postrecord'][1]
linear, zone = linearize(info, run_pos, t_start, t_stop, expand_by=2)

In [None]:
# Linearized U-trajectory position, used by the U sequence figure below.
this_linear = linear['u']

In [None]:
# Number of single-field U neurons collected in these_fields.
print(len(these_fields))

In [None]:
# 'ticks' is the 'white' style plus axis ticks; set_style replaces the whole style,
# so a preceding set_style('white') call has no lasting effect.
sns.set_style('ticks')

In [None]:
# Sequence figure for one U-trajectory run epoch and one SWR window:
#   column 0    -- tuning curve per selected neuron (top row = last neuron),
#   columns 1-4 -- linear position (/90), LFP and spike raster during the run epoch,
#   columns 5-6 -- LFP, filtered LFP and spike raster around the SWR.
# Times are hand-picked (s). Previously tried: bounds from info.sequence['u'][...][1]
# and an SWR window of 8206.82-8207.38.
ms = 10  # raster tick size
loc = 1  # scalebar location passed to vdm.add_scalebar

start_time = 3632
stop_time = 3667

start_time_swr = 8965.4
stop_time_swr = 8966

spike_loc = 2  # vertical spacing between raster rows


def neuron_colour(k):
    """Colour of the k-th selected neuron, shared by its raster row and tuning curve."""
    return colours[k % len(colours)]


def plot_swr_lfp(ax, scale):
    """Plot the LFP scaled by `scale` and offset to y=1.25, with detected SWRs in red."""
    ax.plot(csc['time'], csc['data'] * scale + 1.25, 'k', lw=1)
    for sw_start, sw_stop in zip(swr_idx['start'], swr_idx['stop']):
        ax.plot(csc['time'][sw_start:sw_stop], csc['data'][sw_start:sw_stop] * scale + 1.25, 'r', lw=2)


def plot_raster(ax):
    """Plot one row of spike ticks per selected neuron, the first at y=4."""
    for k, neuron_spikes in enumerate(field_spikes):
        ax.plot(neuron_spikes, np.ones(len(neuron_spikes)) + (k * spike_loc + 3), '|',
                color=neuron_colour(k), ms=ms, mew=2)
    ax.set_ylim([0, len(field_spikes) * spike_loc + 3])


rows = len(field_spikes) + 1
cols = 7
fig = plt.figure()
ax1 = plt.subplot2grid((rows, cols), (0, 1), rowspan=rows, colspan=4)
ax2 = plt.subplot2grid((rows, cols), (0, 5), rowspan=rows, colspan=2)

# Run epoch: linear position /90 (grey line = its maximum), LFP and raster.
max_position = np.zeros(len(this_linear['time']))
max_position.fill(np.max(this_linear['position'] / 90))
ax1.plot(this_linear['time'], max_position, color='#bdbdbd', lw=1)
ax1.plot(this_linear['time'], this_linear['position'] / 90, 'b.', ms=4)
plot_swr_lfp(ax1, 800)
plot_raster(ax1)
ax1.set_xlim([start_time, stop_time])
vdm.add_scalebar(ax1, matchy=False, loc=loc)

# SWR window: LFP, the filtered signal from detect_swr_hilbert (offset to 0.5), raster.
plot_swr_lfp(ax2, 1000)
ax2.plot(csc['time'], filtered_butter * 1000 + 0.5, 'b', lw=1)
plot_raster(ax2)
ax2.set_xlim([start_time_swr, stop_time_swr])
vdm.add_scalebar(ax2, matchy=False, loc=loc)

# Tuning curves: the top grid row shows the last neuron, matching the raster where
# later neurons sit higher. Colours now match the raster (previously offset by one).
x = list(range(0, len(field_tc[0])))
for ax_loc in range(rows - 1):
    k = rows - 2 - ax_loc
    neuron_tc = field_tc[k]
    ax = plt.subplot2grid((rows, cols), (ax_loc, 0))
    ax.plot(neuron_tc, color=neuron_colour(k))
    ax.fill_between(x, 0, neuron_tc, facecolor=neuron_colour(k))
    max_loc = np.argmax(neuron_tc)  # first bin at the peak rate
    ax.text(max_loc - 3, 1, str(int(np.ceil(np.max(neuron_tc)))))
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.setp(ax, xticks=[], xticklabels=[], yticks=[])

sns.despine()
plt.show()
# To save instead: plt.savefig(os.path.join(output_filepath, info.session_id + '_sequence-swr.png'),
#                              dpi=300, bbox_inches='tight'); plt.close()

In [None]:
for trajectory in ['shortcut']:

    print(info.session_id, trajectory)
    pos = info.get_pos(info.pxl_to_cm)
    # NOTE(review): the load cell uses info.get_csc(info.good_swr[0]); confirm the
    # no-argument call picks the intended channel. swr_idx/swr_times below this cell
    # still come from the SWR detection on the earlier csc.
    csc = info.get_csc()
    spikes = info.get_spikes()

    tc = get_tc(info, pos, pickle_filepath)

    # NOTE(review): pickle.load can execute arbitrary code -- only load caches this project wrote.
    filename = info.session_id + '_spike_heatmaps.pkl'
    pickled_spike_heatmaps = os.path.join(pickle_filepath, filename)
    if os.path.isfile(pickled_spike_heatmaps):
        with open(pickled_spike_heatmaps, 'rb') as fileobj:
            spike_heatmaps = pickle.load(fileobj)
    else:
        # `spikes` was loaded above; no need to read it from disk again.
        all_neurons = list(range(1, len(spikes['time'])))
        spike_heatmaps = vdm.get_heatmaps(all_neurons, spikes, pos)
        with open(pickled_spike_heatmaps, 'wb') as fileobj:
            pickle.dump(spike_heatmaps, fileobj)

    t_start = info.task_times['prerecord'][0]
    t_stop = info.task_times['postrecord'][1]
    # NOTE(review): t_start/t_stop are not passed here, unlike the earlier
    # linearize(info, run_pos, t_start, t_stop, expand_by=2) call -- confirm the signature.
    linear, zone = linearize(info, pos)

    sort_idx = vdm.get_sort_idx(tc[trajectory])
    odd_firing_idx = get_odd_firing_idx(tc[trajectory])

    all_fields = vdm.find_fields(tc[trajectory])
    fields_size = vdm.sized_fields(all_fields, max_length=15)
    with_fields = vdm.get_single_field(fields_size)

    sequence = info.sequence[trajectory]
    this_linear = linear[trajectory]

    these_fields = list(with_fields)

    # Single-field neurons in sort order, excluding odd-firing neurons.
    field_spikes = []
    field_tc = []
    for idx in sort_idx:
        if idx not in odd_firing_idx and idx in these_fields:
            field_spikes.append(spikes['time'][idx])
            field_tc.append(tc[trajectory][idx])

In [None]:
# Sequence figure for the first shortcut epoch listed in info.sequence:
#   column 0    -- tuning curve per selected neuron,
#   columns 1-4 -- one raster axis per neuron over the run epoch, negated linear
#                  position underneath (grey zero line),
#   columns 5-6 -- one raster axis per neuron over the SWR window, LFP underneath.
sequence = info.sequence['shortcut']

start_time = sequence['run_start'][0]
stop_time = sequence['run_stop'][0]

start_time_swr = sequence['swr_start'][0]
stop_time_swr = sequence['swr_stop'][0]

n_neurons = len(field_spikes)
lfp_pos_y = 2  # extra grid rows; position/LFP axes sit in row n_neurons (= rows - lfp_pos_y)
rows = n_neurons + lfp_pos_y
cols = 7

spike_loc = lfp_pos_y + 1/rows


def sequence_colour(k):
    """Colour for neuron k from info.sequence, reused cyclically if the list is short."""
    return sequence['colours'][k % len(sequence['colours'])]


def plot_raster_column(col, colspan, share_ax, xlim):
    """Draw one raster axis per neuron (rows 0..n_neurons-1) in grid column `col`."""
    for ax_loc in range(n_neurons):
        ax = plt.subplot2grid((rows, cols), (ax_loc, col), colspan=colspan, sharex=share_ax)
        spike_y = ax_loc * spike_loc + lfp_pos_y
        ax.plot(field_spikes[ax_loc], np.ones(len(field_spikes[ax_loc])) + spike_y, '|',
                color=sequence_colour(ax_loc), ms=sequence['ms'], mew=1)
        ax.set_xlim(xlim)
        if ax_loc == 0:
            vdm.add_scalebar(ax, matchy=False, bbox_transform=ax.transAxes, bbox_to_anchor=(0.9, 1.1))
        # Only the bottom raster keeps its x spine.
        if ax_loc == n_neurons - 1:
            sns.despine(ax=ax)
        else:
            sns.despine(ax=ax, bottom=True)
        plt.setp(ax, xticks=[], xticklabels=[], yticks=[])


fig = plt.figure()

ax1 = plt.subplot2grid((rows, cols), (rows - lfp_pos_y, 1), colspan=4)
ax1.plot(this_linear['time'], np.zeros(len(this_linear['time'])), color='#bdbdbd', lw=1)
ax1.plot(this_linear['time'], -this_linear['position'], 'k', lw=1)
ax1.set_xlim([start_time, stop_time])
plt.setp(ax1, xticks=[], xticklabels=[], yticks=[])
sns.despine(ax=ax1)

# Previously range(rows - lfp_pos_y - 1), which skipped the last neuron.
plot_raster_column(1, 4, ax1, [start_time, stop_time])

ax2 = plt.subplot2grid((rows, cols), (rows - lfp_pos_y, 5), colspan=2)
ax2.plot(csc['time'], csc['data']*1000*(1*1/rows), 'k', lw=1)
ax2.set_xlim([start_time_swr, stop_time_swr])
plt.setp(ax2, xticks=[], xticklabels=[], yticks=[])
sns.despine(ax=ax2)

plot_raster_column(5, 2, ax2, [start_time_swr, stop_time_swr])

x = list(range(0, np.shape(field_tc)[1]))

for ax_loc in range(n_neurons):
    ax = plt.subplot2grid((rows, cols), (ax_loc, 0))
    ax.plot(field_tc[ax_loc], color=sequence_colour(ax_loc))
    ax.fill_between(x, 0, field_tc[ax_loc], facecolor=sequence_colour(ax_loc))
    max_loc = np.argmax(field_tc[ax_loc])  # first bin at the peak rate
    ax.text(max_loc - 3, 1, str(int(np.ceil(np.max(field_tc[ax_loc])))), fontsize=8)
    plt.setp(ax, xticks=[], xticklabels=[], yticks=[])
    sns.despine(ax=ax)

fig.subplots_adjust(hspace=0, wspace=0.1)
plt.show()

In [None]:
# Scratch check: with rows = len(field_spikes) + lfp_pos_y this equals len(field_spikes) - 1.
rows-lfp_pos_y-1

In [None]:
# Colours that the shortcut sequence figure assigns to the selected neurons
# (index explicitly rather than via a loop variable left over from another cell).
sequence['colours'][:len(field_spikes)]

In [None]:
# Number of linearized position samples for the current trajectory.
print(len(this_linear['time']))

In [None]:
# Scratch arithmetic (presumably a time offset of 30 s); TODO replace with named values or remove.
4882+30

In [None]:
# Linear U-trajectory position (/90) over the run epoch, with its maximum as a grey line.
u_linear = linear['u']
max_position = np.full(len(u_linear['time']), np.max(u_linear['position'] / 90), dtype=float)
plt.plot(u_linear['time'], max_position, color='#969696', lw=1)
plt.plot(u_linear['time'], np.array(u_linear['position']) / 90, 'k')
plt.xlim(start_time, stop_time)
plt.show()

In [None]:
# Constant array holding the maximum linear U position (/90), one entry per sample.
peak_u_position = np.max(linear['u']['position'] / 90)
max_position = np.full(len(linear['u']['time']), peak_u_position, dtype=float)
print(max_position)

In [None]:
# Raw LFP (offset 1.25) and the detect_swr_hilbert filtered signal (offset 0.5)
# inside the SWR window.
for trace, offset in ((csc['data'], 1.25), (filtered_butter, 0.5)):
    plt.plot(csc['time'], trace*1000 + offset, 'b')
plt.xlim(start_time_swr, stop_time_swr)
plt.ylim(0, 2)
plt.show()


In [None]:
# Current run-epoch and SWR window bounds.
for label, (t0, t1) in (('run:', (start_time, stop_time)),
                        ('swr:', (start_time_swr, stop_time_swr))):
    print(label, t0, t1)

In [None]:
# First peak bin of each selected neuron's tuning curve. The old field_tc[idx-1]
# depended on whatever `idx` an unrelated loop happened to leave behind.
print([int(np.argmax(neuron_tc)) for neuron_tc in field_tc])

In [None]:
# Scratch arithmetic (presumably the length in s of a hand-picked window); TODO remove.
14609.0 - 14579.0

In [None]:
# Spike trains of all neurons in sort order, excluding odd-firing neurons
# (no single-field restriction).
ordered_spikes = [spikes['time'][neuron] for neuron in sort_idx
                  if neuron not in odd_firing_idx]

In [None]:
# Spike count of each selected neuron inside every detected SWR window [start, stop]:
# all_neurons[n][e] is the count for neuron n during SWR event e.
# Counting by time avoids slicing between nearest-spike indices. Assuming
# vdm.find_nearest_idx returns the closest spike, that slicing could include a spike
# just before the window. It also always dropped the spike nearest swr_stop.
all_neurons = []
for neuron_spikes in field_spikes:
    neuron_spikes = np.asarray(neuron_spikes)
    this_neuron = []
    for swr_start, swr_stop in zip(swr_times['start'], swr_times['stop']):
        in_window = (neuron_spikes >= swr_start) & (neuron_spikes <= swr_stop)
        this_neuron.append(int(np.count_nonzero(in_window)))
    all_neurons.append(this_neuron)

In [None]:
# SWR events in which more than two selected neurons fire at least one spike.
count_mult_neurons = []
for i, _ in enumerate(swr_times['start']):
    n_active = sum(1 for swr_spike_num in all_neurons if swr_spike_num[i] > 0)
    if n_active > 2:
        count_mult_neurons.append(i)
print('Swr events with multiple neurons:', len(count_mult_neurons))
print('Number of neurons:', len(field_spikes))

In [None]:
# For each multi-neuron SWR: scaled LFP (SWR segments in red) plus a raster of the
# selected neurons, zoomed to the event +/- 0.1 s.
for mult_idx in count_mult_neurons:
    print(mult_idx)

    plt.plot(csc['time'], csc['data']*10500)
    for swr_start, swr_stop in zip(swr_idx['start'], swr_idx['stop']):
        plt.plot(csc['time'][swr_start:swr_stop], csc['data'][swr_start:swr_stop]*10500, 'r')

    for i, neuron_spikes in enumerate(field_spikes):
        plt.plot(neuron_spikes, np.ones(len(neuron_spikes))+i*2+1, '|', color='k', ms=10, mew=1)

    # The window comes from mult_idx. The old xlim based on a stale `idx` was
    # overridden by this call anyway.
    plt.ylim(-5, len(field_spikes)*2+1)
    plt.xlim(swr_times['start'][mult_idx]-0.1, swr_times['stop'][mult_idx]+0.1)
    plt.show()

In [None]:
# Sample indices (into csc) where detected SWRs start.
swr_idx['start']