In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept
import seaborn as sns

import scalebar

from loading_data import get_data
from plot_sequence_raster import plot_sequence
from run import spike_sorted_infos
from analyze_tuning_curves import get_tuning_curves
from utils_maze import speed_threshold
from analyze_decode_bytrial import get_trials

In [None]:
from matplotlib import animation
from IPython.display import HTML

In [None]:
# Cache and figure output locations, both relative to the working directory.
thisdir = os.getcwd()
pickle_filepath, output_filepath = (
    os.path.join(thisdir, *parts)
    for parts in (("cache", "pickled"), ("plots", "exploring_swrs"))
)

In [None]:
import info.r068d5 as r068d5
# Session analysed below; `infos` holds the same session as a one-element list.
info = r068d5
infos = [info]

In [None]:
# To analyse every spike-sorted session instead of only r068d5, uncomment:
# infos = spike_sorted_infos

In [None]:
# remove_interneurons: drop high-firing cells before SWR analysis.
# resting_only: restrict detection to low-speed (resting) periods.
remove_interneurons, resting_only = True, False

In [None]:
print(info.session_id)
events, position, spikes, lfp, _ = get_data(info)

if remove_interneurons:
    # A neuron whose spike count divided by the session length reaches this
    # value is treated as a putative interneuron and removed
    # (presumably spikes per second — confirm units of session_length).
    max_mean_firing = 5
    interneurons = np.array(
        [len(neuron.time) / info.session_length >= max_mean_firing for neuron in spikes],
        dtype=bool,
    )
    spikes = spikes[~interneurons]

# Phases to analyse. Full list:
# ["prerecord", "phase1", "pauseA", "phase2", "pauseB", "phase3", "postrecord"]
task_times = ["postrecord"]

In [None]:
# SWR detection parameters, the same for every phase.
z_thresh = 2.0
power_thresh = 3.0
merge_thresh = 0.02
min_length = 0.05
# Minimum number of neurons that must fire within an SWR for it to be kept.
min_involved = 4

# Per-phase results: number of multi-unit SWRs, and phase duration
# (first epoch only, divided by 60 — presumably seconds -> minutes).
n_swrs = np.zeros(len(task_times))
phase_duration = np.zeros(len(task_times))

for i, task_time in enumerate(task_times):
    if remove_interneurons:
        condition = "_no-interneurons"
    else:
        condition = ""

    epochs_of_interest = info.task_times[task_time]

    phase_duration[i] = epochs_of_interest.durations[0] / 60.

    if resting_only:
        # Replace the phase epoch with its low-speed sub-epochs.
        sliced_position = position.time_slice(epochs_of_interest.start, epochs_of_interest.stop)
        epochs_of_interest = speed_threshold(sliced_position, speed_limit=4., rest=True)
        condition = condition + "_rest"

    sliced_lfp = lfp.time_slice(epochs_of_interest.starts, epochs_of_interest.stops)
    sliced_spikes = [spiketrain.time_slice(epochs_of_interest.starts, epochs_of_interest.stops) for spiketrain in
                     spikes]

    swrs = nept.detect_swr_hilbert(sliced_lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                   power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)

    # NOTE: overwritten on each iteration; later cells see the last phase only.
    multi_swrs = nept.find_multi_in_epochs(sliced_spikes, swrs, min_involved=min_involved)

    # Record the per-phase SWR count (previously allocated but never filled).
    n_swrs[i] = multi_swrs.n_epochs

In [None]:
# Count spikes (all kept neurons) inside each multi-unit SWR and in windows of
# the same length immediately before ("pre") and after ("post") it.
swrs = multi_swrs

spike_counts = []

for i in range(swrs.n_epochs):
    start = swrs.starts[i]
    stop = swrs.stops[i]
    n_spikes_swr = np.sum([len(spiketrain.time_slice(start, stop).time) for spiketrain in spikes])

    len_swr = stop - start

    start_pre = start - len_swr
    stop_pre = start
    n_spikes_swr_pre = np.sum([len(spiketrain.time_slice(start_pre, stop_pre).time) for spiketrain in spikes])

    start_post = stop
    stop_post = stop + len_swr
    n_spikes_swr_post = np.sum([len(spiketrain.time_slice(start_post, stop_post).time) for spiketrain in spikes])

    spike_counts.append([n_spikes_swr_pre, n_spikes_swr, n_spikes_swr_post])

fig, ax = plt.subplots()
# plt.get_cmap is the supported accessor; plt.cm.get_cmap was removed in matplotlib 3.9.
cmap = plt.get_cmap('Greys')

# Colour scale is clipped at this many spikes.
vmax = 100.
pp = ax.pcolormesh(spike_counts, vmax=vmax, cmap=cmap)

# One tick centred on each of the three columns (pre, SWR, post).
ax.set_xticks(np.arange(3) + .5)
ax.set_xticklabels(['pre', 'SWR', 'post'])
ax.set_ylabel('SWR index')

# task_time is the last phase from the detection loop, which is also the
# phase that multi_swrs came from.
title = info.session_id + ' SWR spike count ' + task_time
plt.title(title)

fig.colorbar(pp, label='spike count')
plt.show()

In [None]:
def plotthis(idx, buffer=0.1):
    """Plot the spike raster, LFP and position around one multi-unit SWR.

    Uses the notebook-level ``spikes``, ``lfp``, ``position``,
    ``multi_swrs`` and ``spike_counts`` defined in earlier cells.

    Parameters
    ----------
    idx : int
        Index of the event in ``multi_swrs`` (and ``spike_counts``).
    buffer : float, optional
        Padding added before the SWR start and after its stop, in the same
        time units as ``multi_swrs`` (default 0.1).
    """
    print(spike_counts[idx])
    swr_start = multi_swrs.starts[idx]
    swr_stop = multi_swrs.stops[idx]
    print(swr_start, swr_stop)

    start_time = swr_start - buffer
    stop_time = swr_stop + buffer

    rows = len(spikes)
    # Grid rows reserved for the LFP under the raster. Keep at least one:
    # with fewer than 8 neurons int(rows / 8) is 0 and subplot2grid fails.
    add_rows = max(1, int(rows / 8))
    grid_shape = (rows + add_rows, 2)

    ms = 800 / rows
    mew = 0.7
    spike_loc = 1

    fig = plt.figure(figsize=(8, 8))

    # Plotting the spike raster (one row per neuron)
    ax1 = plt.subplot2grid(grid_shape, (0, 0), rowspan=rows)
    for neuron_idx, neuron_spikes in enumerate(spikes):
        ax1.plot(neuron_spikes.time, np.ones(len(neuron_spikes.time)) + (neuron_idx * spike_loc), '|',
                 color='k', ms=ms, mew=mew)

    ax1.set_xticks([])
    ax1.set_xlim([start_time, stop_time])
    ax1.set_ylim([0.5, len(spikes) * spike_loc + 0.5])

    # Plotting the LFP: padded window in blue, the SWR itself overlaid in red
    ax2 = plt.subplot2grid(grid_shape, (rows, 0), rowspan=add_rows, sharex=ax1)

    start_idx = nept.find_nearest_idx(lfp.time, start_time)
    stop_idx = nept.find_nearest_idx(lfp.time, stop_time)
    ax2.plot(lfp.time[start_idx:stop_idx], lfp.data[start_idx:stop_idx], '#3288bd', lw=0.3)

    # Highlight this event's own bounds, not globals left over from other cells.
    swr_start_idx = nept.find_nearest_idx(lfp.time, swr_start)
    swr_stop_idx = nept.find_nearest_idx(lfp.time, swr_stop)
    ax2.plot(lfp.time[swr_start_idx:swr_stop_idx], lfp.data[swr_start_idx:swr_stop_idx], 'r', lw=0.4)

    ax2.set_xticks([])
    ax2.set_xlim([start_time, stop_time])
    ax2.set_yticks([])

    scalebar.add_scalebar(ax2, matchy=False, bbox_transform=fig.transFigure,
                          bbox_to_anchor=(0.5, 0.05), units='ms')

    # Plotting the position: all samples in grey, this window in oranges
    ax3 = plt.subplot2grid(grid_shape, (0, 1), rowspan=max(1, int(0.5 * rows)))
    ax3.plot(position.x, position.y, '.', color="#bdbdbd", ms=2)

    start_idx = nept.find_nearest_idx(position.time, start_time)
    stop_idx = nept.find_nearest_idx(position.time, stop_time)

    cmap = plt.get_cmap('Oranges')
    # One colour per sample in the window, ramping along the colormap.
    colours = cmap(np.linspace(0.25, 0.75, stop_idx - start_idx))

    for j, (x, y) in enumerate(zip(position.x[start_idx:stop_idx], position.y[start_idx:stop_idx])):
        ax3.plot(x, y, ".", color=colours[j])

    ax3.set_xlim(np.min(position.x), np.max(position.x))
    ax3.set_ylim(np.min(position.y), np.max(position.y))
    ax3.axis('off')

    # Cleaning up the plot
    sns.despine(bottom=True)
    plt.tight_layout(h_pad=0.003)

    plt.show()

In [None]:
# Display the per-SWR [pre, SWR, post] spike counts built in the heatmap cell.
spike_counts

In [None]:
# Ratio of flanking (pre + post) spikes to spikes inside the SWR window;
# print the indices of events where it falls below 0.2.
ratio = []
for n_pre, n_swr, n_post in spike_counts:
    ratio.append((n_pre + n_post) / n_swr)
print(np.nonzero(np.asarray(ratio) < 0.2)[0])

In [None]:
# Index (into multi_swrs / spike_counts) of the event to inspect.
idx = 19
plotthis(idx=idx)