In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import itertools
import scipy
import os
import nept

from matplotlib import animation, rc
from IPython.display import HTML

from loading_data import get_data
from exploring_swrs import plot_swr

In [None]:
import info.r068d6 as info

In [None]:
print(info.session_id)
events, position, spikes, lfp, _ = get_data(info)

# Remove putative interneurons: units whose mean firing rate over the whole session
# (spike count / info.session_length -- presumably Hz) is at or above max_mean_firing.
max_mean_firing = 5
mean_firing_rates = np.array([len(spike.time) for spike in spikes]) / info.session_length
interneurons = mean_firing_rates >= max_mean_firing
spikes = spikes[~interneurons]

# Find SWRs for the whole session.
swr_band = (140.0, 250.0)  # passed as `thresh`; presumably the ripple filter band in Hz
z_thresh = 2.0
power_thresh = 3.0
merge_thresh = 0.02
min_length = 0.05
swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=swr_band, z_thresh=z_thresh,
                               power_thresh=power_thresh, merge_thresh=merge_thresh,
                               min_length=min_length)
print("Total swrs for this session:", swrs.n_epochs)

# Restrict SWRs to those with at least `min_involved` participating neurons. The same
# constant feeds the message below, so the printed threshold cannot drift from the one used.
min_involved = 4
swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=min_involved)
print("N swrs for this session with at least", min_involved, "active neurons:", swrs.n_epochs)

In [None]:
# Find rest epochs for entire session
rest_epochs = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

task_times = ["prerecord", "phase1", "pauseA", "phase2", "pauseB", "phase3", "postrecord"]

n_swrs = np.zeros(len(task_times))
duration = np.zeros(len(task_times))

In [None]:
# Analysis switches:
#   resting_only          -- intersect each task phase with rest_epochs before counting SWRs
#   plot_swr_spike_counts -- NOTE(review): not referenced by any visible cell
resting_only, plot_swr_spike_counts = True, False

In [None]:
# For each task phase, record the total length of its (optionally rest-restricted) epochs
# and the number of SWRs associated with them. Every iteration writes both duration[i] and
# n_swrs[i], so re-running this cell (e.g. after toggling `resting_only`) cannot leave stale
# values from a previous run.
# NOTE: `epochs_of_interest` and `phase_swrs` stay bound after the loop and are inspected by
# later cells; `phase_swrs` comes from the last phase that had any epochs of interest.
for i, task_time in enumerate(task_times):
    # Restrict SWRs to those during epochs of interest
    epochs_of_interest = info.task_times[task_time]

    if resting_only:
        epochs_of_interest = epochs_of_interest.intersect(rest_epochs)

    if epochs_of_interest.n_epochs == 0:
        print("No epochs of interest identified for", task_time)
        duration[i] = 0.
        n_swrs[i] = 0
        continue

    # Divide by 60 to report minutes (assumes epoch durations are in seconds -- confirm).
    duration[i] = np.sum(epochs_of_interest.durations) / 60.

    phase_swrs = epochs_of_interest.overlaps(swrs)
    # Reuse the detection minimum length so the threshold is configured in one place.
    phase_swrs = phase_swrs[phase_swrs.durations >= min_length]

    n_swrs[i] = phase_swrs.n_epochs

In [None]:
# Debug query: indices of epochs of interest (as left bound by the loop cell) that stop
# before a fixed time.
# NOTE(review): 9313.804756 is unexplained -- presumably a time of interest in the session's
# time base; document it or remove this cell.
cutoff_time = 9313.804756
np.nonzero(epochs_of_interest.stops < cutoff_time)

In [None]:
# Debug display: start/stop of the first entry in `phase_swrs` (as left bound by the loop cell).
# NOTE(review): indexing [0] fails if that phase has no SWRs.
first_swr_start = phase_swrs.starts[0]
first_swr_stop = phase_swrs.stops[0]
first_swr_start, first_swr_stop

In [None]:
# Debug display: start/stop of epoch index 5 (the sixth) of `epochs_of_interest`.
# NOTE(review): hard-coded index; presumably fails if there are fewer than six epochs.
inspected_epoch = epochs_of_interest[5]
inspected_epoch.start, inspected_epoch.stop

In [None]:
# Debug display: full start and stop arrays of `epochs_of_interest` (as left bound by the loop cell).
epoch_starts = epochs_of_interest.starts
epoch_stops = epochs_of_interest.stops
epoch_starts, epoch_stops

In [None]:
# Visual check of `phase_swrs` together with the LFP, position and (interneuron-filtered) spikes.
# n_plots presumably caps how many SWRs plot_swr draws -- confirm in exploring_swrs.
n_swr_plots = 1
plot_swr(phase_swrs, lfp, position, spikes, n_plots=n_swr_plots)