In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept
import seaborn as sns

import scalebar

from loading_data import get_data
from plot_sequence_raster import plot_sequence
from run import spike_sorted_infos
from analyze_tuning_curves import get_tuning_curves
from utils_maze import get_trials
from exploring_swrs import plot_swr_stats

In [None]:
from exploring_swrs import plot_spike_counts, plot_swr, plot_swrs_stats

In [None]:
from matplotlib import animation
from IPython.display import HTML

In [None]:
thisdir = os.getcwd()
# Output locations, both under the current working directory:
# presumably pickled caches live in cache/pickled; figures go to plots/exploring_swrs.
pickle_filepath, output_filepath = (
    os.path.join(thisdir, *parts)
    for parts in (("cache", "pickled"), ("plots", "exploring_swrs"))
)

In [None]:
# Per-session info modules; the analysis cells below iterate over `infos`.
import info.r063d5 as r063d5
import info.r063d6 as r063d6
infos = [r063d5, r063d6]

In [None]:
# SWR counts and durations summed over every session, one entry per task time.
task_times = ["prerecord", "phase1", "pauseA", "phase2", "pauseB", "phase3", "postrecord"]
all_swrs = np.zeros(len(task_times))
all_durations = np.zeros_like(all_swrs)

# An explicit loop (not a comprehension) on purpose: `info`, `n_swrs` and
# `durations` stay bound to the last session afterwards, and later cells read them.
for info in infos:
    n_swrs, durations = plot_swr_stats(
        info,
        resting_only=True,
        plot_example_swr_rasters=False,
        plot_swr_spike_counts=False,
    )
    all_swrs += n_swrs
    all_durations += durations

In [None]:
# Counts and durations returned by plot_swr_stats for the last session in `infos`.
n_swrs, durations

In [None]:
# all_swrs already holds the SWR counts summed over every session in `infos`
# (accumulated in the loop cell above). Adding n_swrs again here would count the
# last session twice, and would add it once more every time this cell re-ran,
# so this cell intentionally does nothing.

In [None]:
# SWR counts per task time, accumulated over the sessions in `infos`.
all_swrs

In [None]:
# Summary across all sessions in `infos` (replaces the "title"/"ylabel" placeholders).
plot_swrs_stats(all_swrs, all_durations, task_times,
                "All sessions: " + ", ".join(session.session_id for session in infos),
                "SWR rate", savepath=None)

In [None]:
# Pick the session explicitly instead of relying on `info` leaking out of the
# loop cell above (it was left bound to the last entry of `infos`).
info = infos[-1]
events, position, spikes, lfp, _ = get_data(info)

fig, ax = plt.subplots()
ax.plot(position.x, position.y, "k.", ms=3)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title(info.session_id + " position")
plt.show()

In [None]:
def find_multi_in_epochs(spikes, epochs, min_involved, verbose=False):
    """Finds epochs with minimum number of participating neurons.

    A neuron participates in an epoch when its spike train, sliced with
    ``time_slice(start, stop)``, contains at least one spike.

    Parameters
    ----------
    spikes: np.array
        Of nept.SpikeTrain objects
    epochs: nept.Epoch
    min_involved: int
        Minimum number of participating neurons for an epoch to be kept.
    verbose: bool
        If True, print the number of active neurons and the start/stop of
        every epoch examined. Defaults to False (printing was previously
        unconditional and produced one line per epoch).

    Returns
    -------
    multi_epochs: nept.Epoch
        The epochs with at least ``min_involved`` participating neurons, in
        their original order.

    """
    multi_starts = []
    multi_stops = []

    for start, stop in zip(epochs.starts, epochs.stops):
        # Count spike trains with at least one spike inside [start, stop].
        n_active = sum(len(spiketrain.time_slice(start, stop).time) > 0
                       for spiketrain in spikes)

        if verbose:
            print(n_active, start, stop)

        if n_active >= min_involved:
            multi_starts.append(start)
            multi_stops.append(stop)

    # One [start, stop] row per kept epoch; column_stack yields shape (0, 2)
    # when no epoch qualifies, matching the previous hstack construction.
    multi_epochs = nept.Epoch(np.column_stack([multi_starts, multi_stops]))

    return multi_epochs

In [None]:
# Remove putative interneurons: units whose mean rate over the whole session
# (spike count / info.session_length) is at or above max_mean_firing.
# NOTE(review): assumes session_length is in seconds, i.e. a 5 Hz cut-off.
max_mean_firing = 5
interneurons = np.array(
    [len(spike.time) / info.session_length >= max_mean_firing for spike in spikes],
    dtype=bool,
)
spikes = spikes[~interneurons]

# Parameters for nept.detect_swr_hilbert.
swr_band = (140.0, 250.0)  # passed as `thresh`; presumably the ripple band in Hz
z_thresh = 2.0
power_thresh = 3.0
merge_thresh = 0.02
min_length = 0.05
swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=swr_band, z_thresh=z_thresh,
                               power_thresh=power_thresh, merge_thresh=merge_thresh,
                               min_length=min_length)
print("Total swrs for this session:", str(swrs.n_epochs))

# Restrict SWRs to those with at least min_involved participating neurons.
min_involved = 5
swrs = find_multi_in_epochs(spikes, swrs, min_involved=min_involved)
print("N swrs for this session with at least", min_involved, "active neurons:",
      str(swrs.n_epochs))

# Find rest epochs for entire session
rest_epochs = nept.rest_threshold(position, thresh=0.167, t_smooth=0.5)

# Only "prerecord" is analysed here; reuse the full task_times list from the
# all-sessions cell near the top to analyse every phase.
task_times = ["prerecord"]

# Per-task-time results, filled in by the SWR loop cell.
n_swrs = np.zeros(len(task_times))
duration = np.zeros(len(task_times))

In [None]:
# Switches for the SWR loop: restrict to rest epochs, and which figures to produce.
resting_only, plot_swr_spike_counts, plot_example_swr_rasters = True, True, True

In [None]:
# For each task time, count the SWRs overlapping its epochs (restricted to rest
# epochs when resting_only is set) and optionally save per-task-time figures.
# Fills n_swrs[i] and duration[i] (summed epoch durations / 60).
for i, task_time in enumerate(task_times):
    epochs_of_interest = info.task_times[task_time]

    if resting_only:
        epochs_of_interest = epochs_of_interest.intersect(rest_epochs)

    if epochs_of_interest.n_epochs == 0:
        print("No epochs of interest identified.")
        duration[i] = 0.
        # Set explicitly so the count printed below never comes from a previous
        # iteration's phase_swrs (or raises NameError on the first iteration).
        n_swrs[i] = 0
    else:
        duration[i] = np.sum(epochs_of_interest.durations) / 60.

        # Restrict SWRs to those during epochs of interest, dropping any shorter
        # than the detection min_length.
        phase_swrs = swrs.overlaps(epochs_of_interest)
        phase_swrs = phase_swrs[phase_swrs.durations >= min_length]

        n_swrs[i] = phase_swrs.n_epochs

        if phase_swrs.n_epochs > 0:
            if plot_swr_spike_counts:
                filename = info.session_id + "_" + str(i) + task_time + "_swr-spike-count"
                savepath = os.path.join(output_filepath, "summary", filename)
                os.makedirs(os.path.dirname(savepath), exist_ok=True)
                plot_spike_counts(info, phase_swrs, spikes, task_time, savepath=savepath)

            if plot_example_swr_rasters:
                sliced_spikes = [spiketrain.time_slice(epochs_of_interest.starts, epochs_of_interest.stops)
                                 for spiketrain in spikes]

                filename = info.session_id + "_" + str(i) + task_time + "_swr-raster"
                savepath = os.path.join(output_filepath, filename)
                os.makedirs(os.path.dirname(savepath), exist_ok=True)
                # Plot this task time's SWRs, matching the spikes sliced to the
                # same epochs (all session swrs were passed before).
                plot_swr(phase_swrs, lfp, position, sliced_spikes, savepath=savepath)

    print("N swrs for", task_time, ":", str(int(n_swrs[i])))

In [None]:
# Start times of the SWRs kept in the most recent loop iteration that found epochs of interest.
phase_swrs.starts

In [None]:
# Raster of every unit in a hand-picked time window (presumably around one SWR).
start = 3436.879396
stop = 3436.941396

# Slice each spike train to the window once. Only these spikes are visible after
# set_xlim below, so plotting them alone gives the same picture with far fewer points.
window_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]

rows = len(spikes)
add_rows = int(rows / 8)

# Marker size shrinks as the number of units grows so the rows stay distinct.
ms = 800 / rows
mew = 0.7
spike_loc = 1

fig = plt.figure(figsize=(8, 8))
ax1 = plt.subplot2grid((rows + add_rows, 2), (0, 0), rowspan=rows)

# One row of '|' markers per unit; unit idx is drawn at y = 1 + idx * spike_loc.
for idx, neuron_spikes in enumerate(window_spikes):
    ax1.plot(neuron_spikes.time, np.ones(len(neuron_spikes.time)) + (idx * spike_loc), '|',
             color='k', ms=ms, mew=mew)

ax1.set_xticks([])
ax1.set_xlim([start, stop])
ax1.set_ylim([0.5, len(spikes) * spike_loc + 0.5])
plt.show()

In [None]:
# Raster of every unit around SWR number i of phase_swrs, padded by `buffer` on each side.
i = 1
buffer = 0.0
start = phase_swrs.starts[i] - buffer
stop = phase_swrs.stops[i] + buffer

# Slice each spike train to the window once; only these spikes are visible
# after set_xlim below, so plotting them alone gives the same picture.
window_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]

rows = len(spikes)
add_rows = int(rows / 8)

# Marker size shrinks as the number of units grows so the rows stay distinct.
ms = 800 / rows
mew = 0.7
spike_loc = 1

fig = plt.figure(figsize=(8, 8))
ax1 = plt.subplot2grid((rows + add_rows, 2), (0, 0), rowspan=rows)

# One row of '|' markers per unit; unit idx is drawn at y = 1 + idx * spike_loc.
for idx, neuron_spikes in enumerate(window_spikes):
    ax1.plot(neuron_spikes.time, np.ones(len(neuron_spikes.time)) + (idx * spike_loc), '|',
             color='k', ms=ms, mew=mew)

ax1.set_xticks([])
ax1.set_xlim([start, stop])
ax1.set_ylim([0.5, len(spikes) * spike_loc + 0.5])
plt.show()

In [None]:
# Sanity check of nept.find_multi_in_epochs on toy data. The toy_* names keep the
# session-level `spikes` intact for the cells below (it used to be overwritten here).
# Presumably nept.Epoch reads the first row as starts and the second as stops.
toy_epochs = nept.Epoch(np.array([[1.0, 4.0, 6.0], [2.0, 5.0, 7.0]]))

toy_spikes = [nept.SpikeTrain(np.array([6.7])),
              nept.SpikeTrain(np.array([2.0, 6.5])),
              nept.SpikeTrain(np.array([2.0, 4.1])),
              nept.SpikeTrain(np.array([2.0, 4.3]))]

# Three trains spike at t=2.0, so only the first epoch reaches 3 active units.
toy_min_involved = 3
multi_epochs = nept.find_multi_in_epochs(toy_spikes, toy_epochs, toy_min_involved)

assert np.allclose(multi_epochs.starts, np.array([1.]))
assert np.allclose(multi_epochs.stops, np.array([2.]))

In [None]:
# Bounds of the epochs kept by the toy find_multi_in_epochs check.
multi_epochs.starts, multi_epochs.stops

In [None]:
def plot_swrs_stats(n_swrs, durations, task_times, title, ylabel, savepath=None):
    """Bar plot of SWR rate (count / duration) for each task time.

    This local definition shadows the ``plot_swrs_stats`` imported from
    ``exploring_swrs`` at the top of the notebook.

    Parameters
    ----------
    n_swrs : array-like of numbers
        SWR count per task time; written above each bar as "n = ...".
    durations : array-like of numbers
        Analysed time per task time, in the same order as ``n_swrs``; written
        near the bottom of each bar as "of ... m" (presumably minutes).
    task_times : list of str
        x tick labels, one per bar.
    title, ylabel : str
        Figure title and y-axis label.
    savepath : str or None
        If given, the figure is saved to ``savepath + ".png"`` and closed;
        otherwise it is shown.

    Notes
    -----
    A task time with a duration of 0 gets a rate of 0 instead of the inf/nan
    (plus RuntimeWarning) a plain division would produce.
    """
    n_swrs = np.asarray(n_swrs, dtype=float)
    durations = np.asarray(durations, dtype=float)
    rate = np.divide(n_swrs, durations,
                     out=np.zeros(np.broadcast(n_swrs, durations).shape),
                     where=durations > 0)

    fig, ax = plt.subplots()
    ind = np.arange(len(task_times))
    bars = ax.bar(ind, rate)

    # Count just above each bar; duration near the bottom at a fixed y of 0.1
    # (data units).
    for bar, count, duration in zip(bars, n_swrs, durations):
        x_center = bar.get_x() + bar.get_width() / 2
        height = bar.get_height()
        ax.text(x_center, height + height / 50, "n = {:.0f}".format(count),
                ha='center', va='bottom', size=10)
        ax.text(x_center, 0.1, "of {:.1f} m".format(duration),
                ha='center', va='bottom', size=10)

    ax.set_xticks(ind)
    ax.set_xticklabels(task_times, rotation=75, fontsize=14)

    ax.set_ylabel(ylabel)
    ax.set_title(title)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    fig.tight_layout()

    if savepath is not None:
        fig.savefig(savepath + ".png")
        # Close only this figure; closing "all" would also drop unrelated figures.
        plt.close(fig)
    else:
        plt.show()

In [None]:
# Analysed time per task time, as computed in the SWR loop (summed epoch durations / 60).
duration

In [None]:
# duration[i] is summed epoch durations / 60, so the rate is per minute
# (assuming nept epoch durations are in seconds).
plot_swrs_stats(n_swrs, duration, task_times,
                info.session_id + " SWR rate", "SWRs per minute")

In [None]:
# NOTE(review): scratch arithmetic with unexplained constants; nothing else uses this result.
5/30.

In [None]:
# NOTE(review): this cell displayed `u.n_epochs`, but `u` is never defined in this
# notebook, so it raised NameError on Restart & Run All. Re-add the intended
# epoch set explicitly here if this check is still needed.

In [None]:
# Position indexed by the rest epochs (presumably keeps only samples inside rest_epochs).
pos = position[rest_epochs]

In [None]:
fig, ax = plt.subplots()
ax.plot(pos.time, pos.y, "k.", ms=3)
ax.set_xlabel("Time")
ax.set_ylabel("y")
ax.set_title("Position (y) indexed by rest_epochs")
plt.show()

In [None]:
# Spike counts for the session-level swrs/spikes, with "pauseA" passed as the task-time argument.
plot_spike_counts(info, swrs, spikes, "pauseA")

In [None]:
# Example SWR plots for the session-level swrs (n_plots=5 presumably caps the number of figures).
plot_swr(swrs, lfp, position, spikes, buffer=0.15, n_plots=5)

In [None]:
# Compare all position samples of the first phase3 trial with the samples of
# `pos` (position indexed by rest_epochs) in the same time range.
trial_epochs = get_trials(events, info.task_times["phase3"])
start = trial_epochs[0].start
stop = trial_epochs[0].stop

full_trial = position.time_slice(start, stop)
trial = pos.time_slice(start, stop)

# One figure per coordinate, same styling as before plus labels and a legend.
for coord in ("x", "y"):
    fig, ax = plt.subplots()
    ax.plot(full_trial.time, getattr(full_trial, coord), "y.", label="all samples")
    ax.plot(trial.time, getattr(trial, coord), "k.", label="rest samples")
    ax.set_xlabel("Time")
    ax.set_ylabel(coord)
    ax.set_title("First phase3 trial: " + coord)
    ax.legend()
    plt.show()