In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import scipy
import nept

from loading_data import get_data

# All paths are resolved relative to the directory the notebook is run from.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "zscore")
# exist_ok=True removes the check-then-create race and makes this cell safe to
# re-run when the directory already exists.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# Session metadata modules (project-local).
# NOTE(review): `info` is rebound to the r068d5 session module, shadowing the
# `info` package name; every later cell that reads `info.*` analyzes r068d5.
import info.r063d2 as r063d2
import info.r068d5 as info
# Collection of session modules; not referenced anywhere else in this notebook.
infos = [r063d2, info]

In [None]:
# Load the selected session's recordings with the project loader; the
# unpacking order below assumes get_data returns exactly these five items.
events, position, spikes, lfp, lfp_theta = get_data(info)

In [None]:
# Bounds of the prerecord epoch; events and spikes are restricted to it below.
start = info.task_times["prerecord"].start
stop = info.task_times["prerecord"].stop

# parameters
z_thresh = 1  # z-score threshold passed to nept.detect_swr_hilbert
merge_thresh = 0.01  # presumably: events closer than this are merged (cf. the local detector below)
min_length = 0.03  # presumably: events shorter than this are dropped (cf. the local detector below)
fs = info.fs
thresh = (140.0, 250.0)  # (lowcut, highcut) of the ripple band-pass
min_involved = 4  # passed to nept.find_multi_in_epochs; presumably minimum units per event

# Build an LFP restricted to the four rest epochs so SWR detection only sees rest data.
rest_labels = ["prerecord", "pauseA", "pauseB", "postrecord"]
rest_starts = [info.task_times[task_label].start for task_label in rest_labels]
rest_stops = [info.task_times[task_label].stop for task_label in rest_labels]
rest_lfp = lfp.time_slice(rest_starts, rest_stops)

# Detect candidate SWRs on the rest LFP, then filter them by unit participation.
swrs = nept.detect_swr_hilbert(rest_lfp, fs, thresh, z_thresh, merge_thresh=merge_thresh, min_length=min_length)
swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=min_involved)

# Keep only the events inside the prerecord epoch.
swrs = swrs.time_slice(start, stop)

print(swrs.n_epochs)
# NOTE(review): sliced_lfp is not used before a later cell reassigns it.
sliced_lfp = lfp.time_slice(start, stop)

# Pool the spike times of every unit into one sorted array, then keep the
# spikes that fall inside [start, stop].
all_spikes = np.sort(np.concatenate([spiketrain.time for spiketrain in spikes]))
sliced_all_spikes = all_spikes[(start <= all_spikes) & (all_spikes <= stop)]

dt = 0.02  # histogram bin width (same time units as the spike times)
std = 0.01  # Gaussian smoothing width passed with dt to nept.gaussian_filter
firing_thresh = 20  # NOTE(review): not used anywhere in this notebook
bin_edges = nept.get_edges(sliced_all_spikes, dt)

# Population (multi-unit) spike counts per bin, then Gaussian-smoothed.
convolved_spikes = np.histogram(sliced_all_spikes, bins=bin_edges)[0].astype(float)
convolved_spikes = nept.gaussian_filter(convolved_spikes, std=std, dt=dt)

In [None]:
# z-score the smoothed population rate so the MUA threshold below is in SD units.
zspikes = scipy.stats.zscore(convolved_spikes)

In [None]:
# Smoothed, z-scored population spike rate, one value per histogram bin.
plt.plot(zspikes)
plt.xlabel("Time bin (width dt = {:g})".format(dt))
plt.ylabel("MUA rate (z-score)")
plt.title("Z-scored multi-unit activity, prerecord")
plt.show()

In [None]:
z_spikes_thresh = 1
# Mark the bins whose z-scored MUA rate exceeds the threshold. A zero on each
# end guarantees that events touching the first or last bin still produce a
# rising (+1) and a falling (-1) edge in the difference below.
detect = np.concatenate(([0], (zspikes > z_spikes_thresh).astype(int), [0]))
signal_change = np.diff(detect)

In [None]:
# signal_change is indexed by spike-histogram bin (width dt), so bin indices
# must be converted to times through bin_edges. Indexing lfp.time here would
# treat bin numbers as LFP sample numbers and produce meaningless times.
assert len(zspikes) == len(bin_edges) - 1, "expected one z-scored value per histogram bin"

start_swr_idx = np.where(signal_change == 1)[0]
stop_swr_idx = np.where(signal_change == -1)[0] - 1  # last supra-threshold bin

# Drop events that last a single bin (first and last supra-threshold bin coincide).
multi_bin = stop_swr_idx > start_swr_idx

# Each event runs from the left edge of its first bin to the right edge of its last bin.
start_times = bin_edges[start_swr_idx[multi_bin]]
stop_times = bin_edges[stop_swr_idx[multi_bin] + 1]

mua = nept.Epoch(np.array([start_times, stop_times]))

In [None]:
# Prerecord SWR count before requiring overlap with high-MUA epochs.
print("n_swrs before mua thresh: {}".format(swrs.n_epochs))

In [None]:
# Presumably keeps the SWRs that overlap some high-MUA epoch in `mua`;
# confirm against nept.Epoch.overlaps.
swrs_mua = swrs.overlaps(mua)

In [None]:
# SWR count remaining after the MUA-overlap requirement.
print("n_swrs after mua thresh: {}".format(swrs_mua.n_epochs))

In [None]:


# SWR detection parameters for a pass over the full-session LFP.
swr_params = {
    "z_thresh": 3,
    "power_thresh": 3,
    "merge_thresh": 0.02,
    "min_length": 0.05,
    "swr_thresh": (140.0, 250.0),
    "min_involved": 4,
}

# NOTE: this rebinds `swrs`, replacing the prerecord-only events found earlier.
swrs = nept.detect_swr_hilbert(
    lfp,
    fs=info.fs,
    thresh=swr_params["swr_thresh"],
    z_thresh=swr_params["z_thresh"],
    power_thresh=swr_params["power_thresh"],
    merge_thresh=swr_params["merge_thresh"],
    min_length=swr_params["min_length"],
)
swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=swr_params["min_involved"])

In [None]:
# Number of SWRs detected on the full session.
swrs.n_epochs

In [None]:
def butter_bandpass(signal, thresh, fs, order=4):
    """Band-pass a signal with a zero-phase Butterworth filter.

    Parameters
    ----------
    signal : array-like
        Samples to filter. Singleton dimensions are squeezed out first, so a
        column-shaped array such as ``lfp.data`` is accepted.
    thresh : tuple
        ``(lowcut, highcut)``, in the same units as ``fs``.
        Typically (140.0, 250.0) for sharp-wave ripple detection.
    fs : int
        Sampling rate, e.g. 2000. Should come from the experiment-specific info.
    order : int
        Butterworth filter order. Default is 4.

    Returns
    -------
    filtered_butter : np.ndarray
        Filtered signal with the shape of the squeezed input.

    """
    samples = np.squeeze(signal)
    # scipy.signal.butter expects cutoffs normalised to the Nyquist frequency.
    nyquist = fs / 2.0
    band = [thresh[0] / nyquist, thresh[1] / nyquist]
    numerator, denominator = scipy.signal.butter(order, band, btype='band')
    # filtfilt applies the filter forward and backward, so there is no phase lag.
    return scipy.signal.filtfilt(numerator, denominator, samples)

def next_regular(target):
    """Finds the next regular number greater than or equal to target.

    Regular numbers are composites of the prime factors 2, 3, and 5.
    Also known as 5-smooth numbers or Hamming numbers, these are the optimal
    size for inputs to fast-fourier transforms (FFTPACK).

    Parameters
    ----------
    target : positive int

    Returns
    -------
    match : int

    Notes
    -----
    This function was taken from the scipy.signal.signaltools module.
    See http://scipy.org/scipylib/
    """
    # 1..6 are all 5-smooth, so they are their own answer. (The leftover debug
    # print that used to sit here has been removed.)
    if target <= 6:
        return target

    # Quickly check if it's already a power of 2
    if not (target & (target - 1)):
        return target

    match = float('inf')  # Anything found will be smaller
    p5 = 1
    while p5 < target:
        p35 = p5
        while p35 < target:
            # Ceiling integer division, avoiding conversion to float
            # (quotient = ceil(target / p35))
            quotient = -(-target // p35)

            # Quickly find next power of 2 >= quotient
            p2 = 2 ** ((quotient - 1).bit_length())

            N = p2 * p35
            if N == target:
                return N
            elif N < match:
                match = N
            p35 *= 3
            if p35 == target:
                return p35
        if p35 < match:
            match = p35
        p5 *= 5
        if p5 == target:
            return p5
    if p5 < match:
        match = p5
    return match

In [None]:
def detect_swr_hilbert_limited_zscore(lfp,
                                      fs,
                                      thresh,
                                      times_for_zscore,
                                      z_thresh=3,
                                      merge_thresh=0.02,
                                      min_length=0.01):
    """Finds sharp-wave ripple (SWR) epochs with a z-score threshold set on a restricted window.

    The LFP is band-pass filtered and its Hilbert envelope is computed over the
    whole recording. The power threshold, mean + z_thresh * SD of the envelope,
    is computed only from samples inside ``times_for_zscore``. That threshold is
    then applied to the whole recording.

    Parameters
    ----------
    lfp : nept.LocalFieldPotential
    fs : int
        Experiment-specific, something in the range of 2000 typical.
    thresh : tuple
        With format (lowcut, highcut).
        Typically (140.0, 250.0) for sharp-wave ripple detection.
    times_for_zscore : nept.Epoch
        Epoch whose envelope samples define the mean and SD for the threshold.
    z_thresh : int or float
        Threshold in SDs above the window mean. The default is set to 3.
    merge_thresh : int or float
        Events whose gap is shorter than this (in lfp.time units) are merged.
        The default is set to 0.02.
    min_length : float
        Events shorter than this (in lfp.time units) are not considered
        sharp-wave ripples. The default is set to 0.01.

    Returns
    -------
    swrs : nept.Epoch
        Start and stop times of each detected SWR.

    Raises
    ------
    ValueError
        If ``times_for_zscore`` contains no LFP samples.

    """
    # Filtering signal with butterworth filter
    filtered_butter = nept.butter_bandpass(lfp.data, thresh, fs)

    # Get LFP power (using Hilbert). Zero-pad to the nearest regular number
    # (a composite of the primes 2, 3 and 5) to speed up the FFT inside hilbert.
    hilbert_n = nept.next_regular(lfp.n_samples)
    power = np.abs(scipy.signal.hilbert(filtered_butter, N=hilbert_n))

    # Remove the zero padding now that the power is computed.
    power_lfp = nept.AnalogSignal(power[:lfp.n_samples], lfp.time)

    # A z-score of z_thresh corresponds exactly to mean + z_thresh * SD (population
    # SD, as in scipy.stats.zscore). Computing this directly replaces picking the
    # sample with the closest z-score, which only approximates it and silently caps
    # the threshold at the window's maximum when z_thresh exceeds every observed z.
    sliced_power_lfp = power_lfp.time_slice(times_for_zscore.start, times_for_zscore.stop)
    sliced_power = np.squeeze(sliced_power_lfp.data)
    if sliced_power.size == 0:
        raise ValueError("times_for_zscore does not contain any LFP samples")
    power_thresh = np.mean(sliced_power) + z_thresh * np.std(sliced_power)

    # Rising (+1) / falling (-1) edges of the supra-threshold mask; the zero pad
    # makes events touching either end of the recording still produce both edges.
    detect = np.squeeze(power_lfp.data) > power_thresh
    detect = np.hstack([0, detect, 0])
    signal_change = np.diff(detect.astype(int))

    start_swr_idx = np.where(signal_change == 1)[0]
    stop_swr_idx = np.where(signal_change == -1)[0] - 1

    # Getting times associated with these power changes
    start_time = lfp.time[start_swr_idx]
    stop_time = lfp.time[stop_swr_idx]

    # Merge events closer than merge_thresh: for each short gap, drop the later
    # event's start and the earlier event's stop.
    gaps = start_time[1:] - stop_time[:-1]
    merge_idx = np.where(gaps < merge_thresh)[0]
    start_merged = np.delete(start_time, merge_idx + 1)
    stop_merged = np.delete(stop_time, merge_idx)

    # Remove events shorter than min_length.
    long_enough = (stop_merged - start_merged) >= min_length
    start_merged = start_merged[long_enough]
    stop_merged = stop_merged[long_enough]

    return nept.Epoch(np.array([start_merged, stop_merged]))

In [None]:
# Detect SWRs on the full session, with the envelope z-scored only over the
# pauseB rest epoch.
pause_b = info.task_times["pauseB"]
temp_swrs = detect_swr_hilbert_limited_zscore(
    lfp,
    fs=info.fs,
    thresh=swr_params["swr_thresh"],
    times_for_zscore=nept.Epoch([pause_b.start, pause_b.stop]),
    z_thresh=swr_params["z_thresh"],
    merge_thresh=swr_params["merge_thresh"],
    min_length=swr_params["min_length"],
)
# Filter by unit participation (presumably at least min_involved units per event).
temp_swrs = nept.find_multi_in_epochs(spikes, temp_swrs, min_involved=swr_params["min_involved"])

In [None]:
# Number of SWRs found with the pauseB-restricted z-score threshold.
temp_swrs.n_epochs

In [None]:
buffer = 0.1
swr_highlight = "#fc4e2a"
n_examples = 10

# Plot up to n_examples events: the LFP with `buffer` of context on each side
# in black, and the event itself highlighted. Slicing starts/stops avoids the
# IndexError range(10) raised when fewer events were detected, and the
# swr_* loop names keep the notebook-level `start`/`stop` intact.
for swr_start, swr_stop in zip(temp_swrs.starts[:n_examples], temp_swrs.stops[:n_examples]):
    context_start_idx = nept.find_nearest_idx(lfp.time, swr_start - buffer)
    context_stop_idx = nept.find_nearest_idx(lfp.time, swr_stop + buffer)
    plt.plot(lfp.time[context_start_idx:context_stop_idx],
             lfp.data[context_start_idx:context_stop_idx],
             color="k", lw=0.3, alpha=0.9)

    swr_start_idx = nept.find_nearest_idx(lfp.time, swr_start)
    swr_stop_idx = nept.find_nearest_idx(lfp.time, swr_stop)
    plt.plot(lfp.time[swr_start_idx:swr_stop_idx],
             lfp.data[swr_start_idx:swr_stop_idx],
             color=swr_highlight, lw=0.6)
    plt.axis("off")
    plt.show()

In [None]:
buffer = 0.1
swr_highlight = "#fc4e2a"
n_examples = 10

# Plot up to n_examples events: the LFP with `buffer` of context on each side
# in black, and the event itself highlighted. Slicing starts/stops avoids the
# IndexError range(10) raised when fewer events were detected, and the
# swr_* loop names keep the notebook-level `start`/`stop` intact.
for swr_start, swr_stop in zip(swrs.starts[:n_examples], swrs.stops[:n_examples]):
    context_start_idx = nept.find_nearest_idx(lfp.time, swr_start - buffer)
    context_stop_idx = nept.find_nearest_idx(lfp.time, swr_stop + buffer)
    plt.plot(lfp.time[context_start_idx:context_stop_idx],
             lfp.data[context_start_idx:context_stop_idx],
             color="k", lw=0.3, alpha=0.9)

    swr_start_idx = nept.find_nearest_idx(lfp.time, swr_start)
    swr_stop_idx = nept.find_nearest_idx(lfp.time, swr_stop)
    plt.plot(lfp.time[swr_start_idx:swr_stop_idx],
             lfp.data[swr_start_idx:swr_stop_idx],
             color=swr_highlight, lw=0.6)
    plt.axis("off")
    plt.show()

In [None]:
# Raw LFP over the whole session.
plt.plot(lfp.time, lfp.data)
plt.xlabel("Time")
plt.ylabel("LFP")
plt.title("Raw LFP, full session")
plt.show()

In [None]:
# Raw LFP during pauseB, the window used for the restricted z-score.
sliced_lfp = lfp.time_slice(info.task_times["pauseB"].start, info.task_times["pauseB"].stop)
plt.plot(sliced_lfp.time, sliced_lfp.data)
plt.xlabel("Time")
plt.ylabel("LFP")
plt.title("Raw LFP, pauseB")
plt.show()

In [None]:
# Raw LFP during phase1, for comparison with the rest epochs.
sliced_lfp = lfp.time_slice(info.task_times["phase1"].start, info.task_times["phase1"].stop)
plt.plot(sliced_lfp.time, sliced_lfp.data)
plt.xlabel("Time")
plt.ylabel("LFP")
plt.title("Raw LFP, phase1")
plt.show()

In [None]:
fs = info.fs
thresh = swr_params["swr_thresh"]

# Filtering signal with the locally defined butterworth filter
filtered_butter = butter_bandpass(lfp.data, thresh, fs)

# Preview the first samples of the filtered signal against their timestamps.
n_preview = 1000
plt.plot(lfp.time[:n_preview], filtered_butter[:n_preview])
plt.xlabel("Time")
plt.ylabel("Band-passed LFP")
plt.title("Butterworth band-pass {}-{} (first {} samples)".format(thresh[0], thresh[1], n_preview))
plt.show()

In [None]:
times_for_zscore = nept.Epoch([info.task_times["pauseB"].start, info.task_times["pauseB"].stop])
# Reuse the detection parameter instead of a separate literal (same value, 3).
z_thresh = swr_params["z_thresh"]

# Get LFP power (using Hilbert). Zero-pad to the nearest regular number
# (a composite of the primes 2, 3 and 5) to speed up the FFT inside hilbert.
hilbert_n = next_regular(lfp.n_samples)
power = np.abs(scipy.signal.hilbert(filtered_butter, N=hilbert_n))

# Remove the zero padding now that the power is computed.
power_lfp = nept.AnalogSignal(power[:lfp.n_samples], lfp.time)

# Power threshold from the pauseB window. z = z_thresh corresponds exactly to
# mean + z_thresh * SD (population SD, as in scipy.stats.zscore). Picking the
# sample with the closest z-score only approximates this and caps the threshold
# at the window's maximum when z_thresh exceeds every observed z-score.
sliced_power_lfp = power_lfp.time_slice(times_for_zscore.start, times_for_zscore.stop)
sliced_power = np.squeeze(sliced_power_lfp.data)
power_thresh = np.mean(sliced_power) + z_thresh * np.std(sliced_power)

In [None]:
# Ripple-band envelope over the full session with the pauseB-derived threshold.
plt.plot(power_lfp.time, power_lfp.data, label="Hilbert envelope")
plt.axhline(power_thresh, color="r", label="power threshold")
plt.xlabel("Time")
plt.ylabel("Ripple-band power")
plt.legend()
plt.show()

In [None]:
# Display the envelope power threshold derived from the pauseB window.
power_thresh