In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept
import seaborn as sns

import scalebar

from loading_data import get_data
from plot_sequence_raster import plot_sequence
from run import spike_sorted_infos
from analyze_tuning_curves import get_tuning_curves
from utils_maze import speed_threshold
from analyze_decode_bytrial import get_trials

In [None]:
from matplotlib import animation
from IPython.display import HTML

In [None]:
# Cache and plot directories are resolved relative to the notebook's working directory.
thisdir = os.getcwd()
pickle_filepath, output_filepath = [
    os.path.join(thisdir, subdir, leaf)
    for subdir, leaf in [("cache", "pickled"), ("plots", "exploring_swrs")]
]

In [None]:
# Session-info module(s) to analyse; this notebook reads session_id, task_times, fs and
# session_length from them.
import info.r068d5 as r068d5
infos = [r068d5]

In [None]:
# Uncomment to run the per-session loop further down over every spike-sorted session instead of only r068d5:
# infos = spike_sorted_infos

In [None]:
from exploring_swrs import plot_swr_stats, plot_swr

In [None]:
# Session and analysis options for the exploratory cells below. They are defined here so the
# notebook runs top-to-bottom on a fresh kernel (previously `info`, `remove_interneurons` and
# `condition` were only defined by later cells).
info = r068d5
remove_interneurons = False  # True drops units whose mean rate is >= max_mean_firing
condition = ""  # suffix describing the analysis variant; later cells append to it

print(info.session_id)
events, position, spikes, lfp, _ = get_data(info)

if remove_interneurons:
    # Units whose spike count divided by info.session_length reaches this value are flagged as
    # interneurons and excluded.
    max_mean_firing = 5
    interneurons = np.zeros(len(spikes), dtype=bool)
    for i, spike in enumerate(spikes):
        if len(spike.time) / info.session_length >= max_mean_firing:
            interneurons[i] = True
    spikes = spikes[~interneurons]
    condition = condition + "_no-interneurons"

task_times = ["prerecord", "phase1", "pauseA", "phase2", "pauseB", "phase3", "postrecord"]

# One slot per entry of task_times, filled in by the cells below.
n_swrs = np.zeros(len(task_times))
phase_duration = np.zeros(len(task_times))

In [None]:
task_time = "pauseA"

epochs_of_interest = info.task_times[task_time]

# Store the duration in the slot for this task time. The previous `phase_duration[i]` relied on a
# loop index leaked from another cell, which is undefined (or wrong) on a fresh kernel.
# The division by 60 presumably converts seconds to minutes.
phase_duration[task_times.index(task_time)] = epochs_of_interest.durations[0] / 60.



In [None]:
# Restrict the chosen task time to resting periods (speed below speed_limit).
# Start from info.task_times[task_time] rather than the current `epochs_of_interest` so that
# re-running this cell does not re-threshold epochs that were already speed-thresholded.
phase_epoch = info.task_times[task_time]
sliced_position = position.time_slice(phase_epoch.start, phase_epoch.stop)
epochs_of_interest = speed_threshold(sliced_position, speed_limit=4., rest=True)
# Rebuilt from scratch (same convention as the per-task-time loop later) instead of appending
# to a possibly undefined / already-suffixed `condition`.
condition = ("_no-interneurons" if remove_interneurons else "") + "_rest"


In [None]:
# Restrict LFP, every spike train, and position to the epochs of interest.
window = (epochs_of_interest.starts, epochs_of_interest.stops)
sliced_lfp = lfp.time_slice(*window)
sliced_spikes = [train.time_slice(*window) for train in spikes]
sliced_position = position.time_slice(*window)

In [None]:
# x/y positions sampled during the epochs of interest.
fig, ax = plt.subplots()
ax.plot(sliced_position.x, sliced_position.y, "g.")
plt.show()

In [None]:
# SWR detection parameters passed to nept.detect_swr_hilbert.
z_thresh = 2.0
power_thresh = 3.0
merge_thresh = 0.02
min_length = 0.05
swrs = nept.detect_swr_hilbert(sliced_lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                               power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)

# Presumably keeps only SWRs in which at least `min_involved` spike trains participate — confirm in nept.
multi_swrs = nept.find_multi_in_epochs(sliced_spikes, swrs, min_involved=4)

# Store the count in the slot for the current task time. The previous `n_swrs[i]` relied on a
# loop index leaked from another cell.
n_swrs[task_times.index(task_time)] = multi_swrs.n_epochs

In [None]:
# SWR counts per task time; only the slot for the task time selected above is filled in by these cells.
print(n_swrs)

In [None]:
# Inspect the multi-unit SWR epochs detected above.
multi_swrs

In [None]:

def _count_spikes(spiketrains, start, stop):
    """Total number of spikes, summed over ``spiketrains``, in ``time_slice(start, stop)``."""
    return np.sum([len(spiketrain.time_slice(start, stop).time) for spiketrain in spiketrains])


def plot_spike_counts(swrs, savepath=None, spiketrains=None, title=None, vmax=100.):
    """Plot a heatmap of spike counts before, during and after each SWR.

    For every SWR epoch, spikes are summed over all spike trains in three windows of equal
    length: the window immediately preceding the SWR, the SWR itself, and the window
    immediately following it. Each SWR is one row of the heatmap; the columns are
    pre / SWR / post.

    Parameters
    ----------
    swrs : epochs object exposing ``starts`` and ``stops`` (e.g. the ``nept.Epoch`` built above).
    savepath : str or None
        If given, the figure is saved there and then closed; otherwise it is shown.
    spiketrains : sequence or None
        Spike trains to count. Defaults to the notebook-level ``spikes``.
    title : str or None
        Figure title. Defaults to ``"<session_id> SWR spike count <task_time>"`` built from the
        notebook-level ``info`` and ``task_time``.
    vmax : float
        Upper limit of the colour scale.
    """
    if spiketrains is None:
        spiketrains = spikes
    if title is None:
        title = info.session_id + ' SWR spike count ' + task_time

    spike_counts = []
    for start, stop in zip(swrs.starts, swrs.stops):
        len_swr = stop - start
        spike_counts.append([
            _count_spikes(spiketrains, start - len_swr, start),  # pre-SWR window
            _count_spikes(spiketrains, start, stop),  # the SWR itself
            _count_spikes(spiketrains, stop, stop + len_swr),  # post-SWR window
        ])

    print(title)
    if not spike_counts:
        # pcolormesh cannot draw an empty grid.
        print("No SWRs to plot.")
        return

    fig, ax = plt.subplots()
    # A colormap name avoids plt.cm.get_cmap, which is deprecated/removed in recent Matplotlib.
    pp = ax.pcolormesh(spike_counts, vmax=vmax, cmap='Greys')

    # Centre one tick under each of the three columns.
    ax.set_xticks(np.arange(3) + .5)
    ax.set_xticklabels(['pre', 'SWR', 'post'])
    ax.set_ylabel('SWR index')
    ax.set_title(title)
    fig.colorbar(pp, ax=ax, label='spike count')

    if savepath is not None:
        fig.savefig(savepath)
        # Release the figure so saving many plots does not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

In [None]:
savepath = os.path.join(output_filepath, "summary", info.session_id + "_" + task_time + "_swr-spike-count")
# The summary directory may not exist yet; savefig would fail without it.
os.makedirs(os.path.dirname(savepath), exist_ok=True)
# savepath was previously computed but never passed, so nothing was saved.
plot_spike_counts(multi_swrs, savepath=savepath)

In [None]:
# 1-based string labels, one per element of `data` (e.g. for tick labels).
# The parentheses fix the original `+1.astype(str)`, which is a SyntaxError.
# NOTE(review): `data` is not defined anywhere in this notebook; define it before running this cell.
(np.arange(len(data)) + 1).astype(str)

In [None]:
# Index of the phase3 trial whose resting periods are examined below.
trial_idx = 10

# NOTE: every variable assigned in this loop is overwritten per session, so the cells below only
# see the last session in `infos`.
for info in infos:
    events, position, spikes, lfp, _ = get_data(info)

    phase = info.task_times["phase3"]
    trial_epochs = get_trials(events, phase)

    n_trials = len(trial_epochs.starts)
    if trial_idx >= n_trials:
        raise IndexError("trial_idx %d out of range: %s has %d phase3 trials"
                         % (trial_idx, info.session_id, n_trials))

    trial_start = trial_epochs.starts[trial_idx]
    trial_stop = trial_epochs.stops[trial_idx]
    trial_times = nept.Epoch([trial_start, trial_stop])

    # Position during the selected trial, and its resting (below speed_limit) sub-epochs.
    sliced_position = position.time_slice(trial_times.start, trial_times.stop)
    resting_epochs = speed_threshold(sliced_position, speed_limit=4., rest=True)
    rest_position = sliced_position[resting_epochs]

    sliced_lfp = lfp.time_slice(resting_epochs.starts, resting_epochs.stops)
    sliced_spikes = [spiketrain.time_slice(resting_epochs.starts, resting_epochs.stops) for spiketrain in spikes]

In [None]:
# y position over the selected trial, with the resting samples overplotted in blue.
fig, ax = plt.subplots()
ax.plot(sliced_position.time, sliced_position.y, "k.", label="all samples")
ax.plot(rest_position.time, rest_position.y, "b.", label="resting (speed-thresholded)")
ax.set(xlabel="time", ylabel="y position", title=info.session_id + " trial position, rest highlighted")
ax.legend()
plt.show()

In [None]:
# Figure for the resting-position animation. The previously computed (and never used)
# nept.get_xyedges / np.meshgrid results were dropped.
fig, ax = plt.subplots()

# Zoom onto where the rat rested, with a margin of pad_amount position units on every side.
pad_amount = 5
ax.set_xlim((np.floor(np.min(rest_position.x)) - pad_amount, np.ceil(np.max(rest_position.x)) + pad_amount))
ax.set_ylim((np.floor(np.min(rest_position.y)) - pad_amount, np.ceil(np.max(rest_position.y)) + pad_amount))

# Whole-trial trajectory as a light-grey background.
ax.plot(sliced_position.x, sliced_position.y, '.', color="#bdbdbd")
# Marker moved frame-by-frame by the animation callbacks.
rat_position, = ax.plot([], [], "<", color="b")

fig.tight_layout()


def init():
    """Clear the rat marker before the first animation frame.

    Returns a tuple of the modified artists so the callback also works with ``blit=True``
    (FuncAnimation expects an iterable of artists there).
    """
    rat_position.set_data([], [])
    return (rat_position,)


def animate(i):
    """Move the rat marker to the i-th resting position sample.

    ``Line2D.set_data`` expects sequences; passing bare scalars is deprecated since Matplotlib 3.7
    and rejected in later versions, so each coordinate is wrapped in a one-element list.
    Returns a tuple of the modified artists (required if ``blit=True`` is used).
    """
    rat_position.set_data([rest_position.x[i]], [rest_position.y[i]])
    return (rat_position,)

# One frame per resting position sample, 80 ms apart. init (previously defined but never passed)
# clears the marker before the first frame.
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=rest_position.n_samples,
                               interval=80, blit=False, repeat=False)

In [None]:
# Render the animation inline as an HTML5 video.
# NOTE(review): to_html5_video needs a movie writer (ffmpeg by default) installed — confirm it is available.
HTML(anim.to_html5_video())

In [None]:
# Session and options for the per-task-time SWR summary below.
info = r068d5
remove_interneurons = False  # drop units whose mean rate reaches max_mean_firing
resting_only = True  # restrict SWR detection to speed-thresholded resting periods

print(info.session_id)
events, position, spikes, lfp, _ = get_data(info)

condition = ""

if remove_interneurons:
    max_mean_firing = 5
    # Spike count of each unit divided by the session length, compared against the threshold.
    mean_rates = np.array([len(spike.time) for spike in spikes]) / info.session_length
    interneurons = mean_rates >= max_mean_firing
    spikes = spikes[~interneurons]
    condition += "_no-interneurons"

# Full list: ["prerecord", "phase1", "pauseA", "phase2", "pauseB", "phase3", "postrecord"]
task_times = ["pauseB"]

In [None]:
# SWR detection parameters, shared by every task time in the loop below (hoisted out of the loop).
z_thresh = 2.0
power_thresh = 3.0
merge_thresh = 0.02
min_length = 0.05

# Condition suffix without the per-task-time "_rest" part.
base_condition = "_no-interneurons" if remove_interneurons else ""

n_swrs = np.zeros(len(task_times))
phase_duration = np.zeros(len(task_times))

for i, task_time in enumerate(task_times):
    condition = base_condition
    epochs_of_interest = info.task_times[task_time]

    # Duration of the first epoch of this task time, divided by 60 (presumably seconds -> minutes).
    # NOTE(review): this is the whole task-time duration even when resting_only restricts SWR
    # detection to rest periods, so n_swrs / phase_duration is a rate per minute of task time,
    # not per minute of rest — confirm which is intended.
    phase_duration[i] = epochs_of_interest.durations[0] / 60.

    if resting_only:
        sliced_position = position.time_slice(epochs_of_interest.start, epochs_of_interest.stop)
        fig, ax = plt.subplots()
        ax.plot(sliced_position.x, sliced_position.y, "k.")
        ax.set(xlabel="x", ylabel="y", title=info.session_id + " " + task_time + " position")
        plt.show()
        epochs_of_interest = speed_threshold(sliced_position, speed_limit=4., rest=True)
        condition = condition + "_rest"
        print(condition)

    sliced_lfp = lfp.time_slice(epochs_of_interest.starts, epochs_of_interest.stops)
    sliced_spikes = [spiketrain.time_slice(epochs_of_interest.starts, epochs_of_interest.stops) for spiketrain in
                     spikes]

    swrs = nept.detect_swr_hilbert(sliced_lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                   power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)

    multi_swrs = nept.find_multi_in_epochs(sliced_spikes, swrs, min_involved=4)

    n_swrs[i] = multi_swrs.n_epochs

In [None]:
# Plot the multi-unit SWRs of the last task time processed by the loop above (plot_swr comes from exploring_swrs).
plot_swr(multi_swrs, sliced_lfp, position, sliced_spikes)

In [None]:
# SWR counts per task time, and the same counts divided by phase_duration.
for label, values in (("n_swrs:", n_swrs), ("swr_rate:", n_swrs / phase_duration)):
    print(label, values)