In [None]:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from collections import OrderedDict
import os
import numpy as np
import scipy
import nept

from loading_data import get_data

import info.r068d7 as r068d7
import info.r068d8 as r068d8

from run import analysis_infos

# Sessions to analyse; swap in the commented list for a quick two-session check.
infos = analysis_infos
# infos = [r068d7, r068d8]

# Paths are resolved from the notebook's working directory
# (`__file__` is not defined inside a notebook, hence os.getcwd()).
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "n_swrs")
# exist_ok makes this a no-op when the folder already exists and avoids the
# check-then-create race of `if not os.path.exists(...): os.makedirs(...)`.
os.makedirs(output_filepath, exist_ok=True)

# Per-phase SWR counts summed over every session in `infos`: before and after
# requiring that each SWR overlaps a high multi-unit-activity (MUA) epoch.
n_swr_before_mua = OrderedDict()
n_swr_after_mua = OrderedDict()

# Rest phases pooled for SWR detection, and the phases SWRs are counted in.
rest_labels = ["prerecord", "pauseA", "pauseB", "postrecord"]
task_times = ["prerecord", "pauseA", "pauseB", "postrecord"]
# task_times = ["pauseB", "postrecord"]

# SWR detection parameters
z_thresh = 1.
merge_thresh = 0.01
min_length = 0.03
thresh = (140.0, 250.0)
min_involved = 4

# MUA parameters: bin width and Gaussian width for the population spike count,
# and the z-score threshold that defines a high-MUA epoch.
dt = 0.02
std = 0.01
z_spikes_thresh = 2

# Set to True to draw the diagnostic figure for every session and phase.
plot_diagnostics = False

for info in infos:
    events, position, spikes, lfp, lfp_theta = get_data(info)
    fs = info.fs

    # Detection and the pooled spike train depend only on the session, not on the
    # phase being counted, so compute them once per session instead of once per phase.
    rest_starts = [info.task_times[task_label].start for task_label in rest_labels]
    rest_stops = [info.task_times[task_label].stop for task_label in rest_labels]
    rest_lfp = lfp.time_slice(rest_starts, rest_stops)

    session_swrs = nept.detect_swr_hilbert(rest_lfp, fs, thresh, z_thresh,
                                           merge_thresh=merge_thresh, min_length=min_length)
    session_swrs = nept.find_multi_in_epochs(spikes, session_swrs, min_involved=min_involved)

    all_spikes = np.sort(np.concatenate([spiketrain.time for spiketrain in spikes]))

    for task_time in task_times:
        phase_start = info.task_times[task_time].start
        phase_stop = info.task_times[task_time].stop

        swrs = session_swrs.time_slice(phase_start, phase_stop)
        print(info.session_id, task_time, "n_swrs before mua thresh:", swrs.n_epochs)
        # Accumulate across sessions; plain assignment kept only the last session's count.
        n_swr_before_mua[task_time] = n_swr_before_mua.get(task_time, 0) + swrs.n_epochs

        sliced_lfp = lfp.time_slice(phase_start, phase_stop)

        # Smoothed population spike count on bins spanning this phase's LFP.
        bin_edges = nept.get_edges(sliced_lfp.time, dt)
        convolved_spikes = np.histogram(all_spikes, bins=bin_edges)[0].astype(float)
        convolved_spikes = nept.gaussian_filter(convolved_spikes, std=std, dt=dt)

        multi_unit = nept.get_epoch_from_zscored_thresh(convolved_spikes, bin_edges, thresh=z_spikes_thresh)

        swrs_with_mua = multi_unit.overlaps(swrs)
        # swrs_with_mua = swrs.overlaps(multi_unit)
        print(info.session_id, task_time, "n_swrs after mua thresh:", swrs_with_mua.n_epochs)
        n_swr_after_mua[task_time] = n_swr_after_mua.get(task_time, 0) + swrs_with_mua.n_epochs

        if plot_diagnostics:
            # Values needed only for the figure.
            sliced_all_spikes = all_spikes[(phase_start <= all_spikes) & (all_spikes <= phase_stop)]
            zscored = scipy.stats.zscore(convolved_spikes)
            # Raw smoothed count whose z-score is closest to z_spikes_thresh, so the
            # MUA threshold can be drawn on the raw trace.
            zthresh_idx = (np.abs(zscored - z_spikes_thresh)).argmin()
            raw_thresh = convolved_spikes[zthresh_idx]

            # Rank MUA-confirmed SWRs by how close their peak z-scored LFP is to z_thresh.
            zscored_lfp = nept.AnalogSignal(scipy.stats.zscore(rest_lfp.data), rest_lfp.time)
            sliced_zscored_lfp = zscored_lfp.time_slice(phase_start, phase_stop)

            dist_from_thresh = np.zeros(swrs_with_mua.n_epochs)
            for i, (start, stop) in enumerate(zip(swrs_with_mua.starts, swrs_with_mua.stops)):
                this_swr_lfp = sliced_zscored_lfp.time_slice(start, stop)
                dist_from_thresh[i] = np.abs(z_thresh - np.max(this_swr_lfp.data))

            # argpartition requires 0 <= kth < len(array); with fewer than two SWRs
            # there is nothing to rank.
            n_near_thresh = min(10, swrs_with_mua.n_epochs - 1)
            if n_near_thresh > 0:
                idx_near_thresh = np.argpartition(dist_from_thresh, n_near_thresh)[:n_near_thresh]
                print("Mean distance of", str(n_near_thresh), "near thresh:",
                      str(np.round(np.mean(dist_from_thresh[idx_near_thresh]), 3)))
            else:
                idx_near_thresh = np.array([], dtype=int)

            lfp_max = np.max(lfp.data)
            lfp_min = np.min(lfp.data)

            fig, ax = plt.subplots()
            # Spike traces are rescaled/offset by hard-coded constants to share the LFP axis.
            ax.plot(bin_edges[:-1], convolved_spikes/(50*2000)+0.00025, "g")
            ax.axhline(raw_thresh/(50*2000)+0.00025, color="k")
            ax.plot(sliced_lfp.time, sliced_lfp.data)
            ax.plot(sliced_all_spikes, np.ones(len(sliced_all_spikes))*0.0002, ".")
            for start, stop in zip(swrs.starts, swrs.stops):
                this_swr_lfp = lfp.time_slice(start, stop)
                ax.plot(this_swr_lfp.time, this_swr_lfp.data, color="c")
            for start, stop in zip(multi_unit.starts, multi_unit.stops):
                this_mua_lfp = lfp.time_slice(start, stop)
                ax.plot(this_mua_lfp.time, this_mua_lfp.data, "y")
            for i, (start, stop) in enumerate(zip(swrs_with_mua.starts, swrs_with_mua.stops)):
                ax.fill_between([start, stop], lfp_max, lfp_min, color="#cccccc")
                color = "r" if i in idx_near_thresh else "b"
                this_swr_lfp = lfp.time_slice(start, stop)
                ax.plot(this_swr_lfp.time, this_swr_lfp.data, color=color)
            ax.text(0.01, 0.01, "n_swrs: " + str(swrs_with_mua.n_epochs), transform=ax.transAxes)

            custom_lines = [Line2D([0], [0], color="c", lw=2),
                            Line2D([0], [0], color="y", lw=2),
                            Line2D([0], [0], color="r", lw=2),
                            Line2D([0], [0], color="#cccccc", lw=2)]
            ax.legend(custom_lines, ['SWRs', 'MUA', 'SWRs-MUA-near_thresh', 'SWRs-MUA'], fontsize=12)
            ax.set_title(info.session_id+" " + task_time + " lfp_thresh:"+str(z_thresh)+", mua_thresh:"+str(z_spikes_thresh),
                         fontsize=12)
            plt.show()

# Total MUA-confirmed SWRs per phase, summed over all sessions.
title = "n_swrs_by_phase"
plt.bar(list(n_swr_after_mua.keys()), list(n_swr_after_mua.values()), width=1, color='g')
plt.title(title)
# plt.savefig(os.path.join(output_filepath, title+".png"))
# plt.close()
plt.show()


In [None]:
# Bar chart of the MUA-confirmed SWR count for each rest phase, with each bar labelled.
title = "n_swrs_by_phase"
phase_names = n_swr_after_mua.keys()
phase_counts = n_swr_after_mua.values()

fig, ax = plt.subplots()
ax.bar(phase_names, phase_counts, width=1, color='g')
# Write each count just above its bar.
for bar_pos, count in zip(range(len(phase_counts)), phase_counts):
    ax.text(bar_pos, count + 3, str(count))
plt.title(title)
# Leave headroom above the tallest bar for its label.
plt.ylim(0, max(phase_counts) + 50)
# plt.savefig(os.path.join(output_filepath, title+".png"))
# plt.close()
plt.show()

In [None]:
# Scratch: tallest phase count plus a small margin.
10 + max(n_swr_after_mua.values())

In [None]:
# Deliberate stop so "Run All" halts here; the cells below are exploratory scratch work.
raise RuntimeError("Intentional stop: the cells below are exploratory; run them individually.")

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import os
import numpy as np
import scipy
import nept

from loading_data import get_data

import info.r063d5 as info
events, position, spikes, lfp, lfp_theta = get_data(info)

task_time = "prerecord"
phase_start = info.task_times[task_time].start
phase_stop = info.task_times[task_time].stop

# SWR detection parameters
z_thresh = 1.5
merge_thresh = 0.01
min_length = 0.03
fs = info.fs
thresh = (140.0, 250.0)
min_involved = 4

# Detect over all rest phases pooled together, then keep the SWRs inside `task_time`.
rest_labels = ["prerecord", "pauseA", "pauseB", "postrecord"]
rest_epochs = [info.task_times[label] for label in rest_labels]
rest_starts = [epoch.start for epoch in rest_epochs]
rest_stops = [epoch.stop for epoch in rest_epochs]
rest_lfp = lfp.time_slice(rest_starts, rest_stops)

swrs = nept.find_multi_in_epochs(
    spikes,
    nept.detect_swr_hilbert(rest_lfp, fs, thresh, z_thresh, merge_thresh=merge_thresh, min_length=min_length),
    min_involved=min_involved,
).time_slice(phase_start, phase_stop)

print("n_swrs before mua thresh:", swrs.n_epochs)
sliced_lfp = lfp.time_slice(phase_start, phase_stop)

all_spikes = np.sort(np.concatenate([spiketrain.time for spiketrain in spikes]))

# Gaussian-smoothed population spike count, binned over the sliced LFP's time span.
dt = 0.02
std = 0.01
bin_edges = nept.get_edges(sliced_lfp.time, dt)
spike_counts = np.histogram(all_spikes, bins=bin_edges)[0].astype(float)
convolved_spikes = nept.gaussian_filter(spike_counts, std=std, dt=dt)

z_spikes_thresh = 2
multi_unit = nept.get_epoch_from_zscored_thresh(convolved_spikes, bin_edges, thresh=z_spikes_thresh)

# For plotting: the spikes inside the phase, and the raw count matching the z threshold.
in_phase = (phase_start <= all_spikes) & (all_spikes <= phase_stop)
sliced_all_spikes = all_spikes[in_phase]
zscored = scipy.stats.zscore(convolved_spikes)
zthresh_idx = np.argmin(np.abs(zscored - z_spikes_thresh))
raw_thresh = convolved_spikes[zthresh_idx]

these_swrs = multi_unit.overlaps(swrs)

In [None]:
# Number of SWRs in `task_time` returned by multi_unit.overlaps(swrs).
these_swrs.n_epochs

In [None]:
# z-score the pooled rest LFP, then cut out the window for `task_time`.
phase_window = info.task_times[task_time]
zscored_lfp = nept.AnalogSignal(scipy.stats.zscore(rest_lfp.data), rest_lfp.time)
sliced_zscored_lfp = zscored_lfp.time_slice(phase_window.start, phase_window.stop)

In [None]:
# Visual check of the z-scored pooled rest LFP.
fig_zlfp, ax_zlfp = plt.subplots()
ax_zlfp.plot(zscored_lfp.time, zscored_lfp.data)
plt.show()

In [None]:
# z-score the pooled rest LFP and restrict it to `task_time`.
phase_window = info.task_times[task_time]
zscored_lfp = nept.AnalogSignal(scipy.stats.zscore(rest_lfp.data), rest_lfp.time)
sliced_zscored_lfp = zscored_lfp.time_slice(phase_window.start, phase_window.stop)

# For every SWR in these_swrs: distance between its peak z-scored LFP and z_thresh.
dist_from_thresh = np.array(
    [
        np.abs(z_thresh - np.max(sliced_zscored_lfp.time_slice(swr_start, swr_stop).data))
        for swr_start, swr_stop in zip(these_swrs.starts, these_swrs.stops)
    ],
    dtype=float,
)
dist_from_thresh

In [None]:
# Pick the SWRs whose peak z-scored LFP lies closest to the detection threshold.
n_near_thresh = 53
# argpartition requires 0 <= kth < len(array), so cap at n_epochs - 1 and skip when
# there are fewer than two SWRs (kth would be out of bounds).
n_near_thresh = min(n_near_thresh, these_swrs.n_epochs - 1)
if n_near_thresh > 0:
    idx_near_thresh = np.argpartition(dist_from_thresh, n_near_thresh)[:n_near_thresh]
else:
    idx_near_thresh = np.array([], dtype=int)
# Previously indexed the undefined name `okletsdothis` (NameError); show the
# distances of the selected SWRs instead.
dist_from_thresh[idx_near_thresh]

In [None]:
# Deliberate stop so "Run All" halts here; the cells below are exploratory scratch work.
raise RuntimeError("Intentional stop: the cells below are exploratory; run them individually.")

In [None]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import scipy
# import mpld3
import nept

from loading_data import get_data

# mpld3.enable_notebook()

In [None]:
import info.r063d5 as info

In [None]:
# Load the session described by `info` through the project's get_data loader.
events, position, spikes, lfp, lfp_theta = get_data(info)

In [None]:
# Analysis window: the prerecord rest phase.
prerecord = info.task_times["prerecord"]
start, stop = prerecord.start, prerecord.stop

# SWR detection parameters
fs = info.fs
thresh = (140.0, 250.0)
z_thresh = 1.5
merge_thresh = 0.02
min_length = 0.05
min_involved = 4

In [None]:
swrs = nept.detect_swr_hilbert(lfp, fs, thresh, z_thresh, merge_thresh=merge_thresh, min_length=min_length)

# Keep only SWRs lying entirely inside [start, stop]. This must be one combined mask:
# masking starts and stops separately can keep one edge of an SWR that straddles the
# window boundary while dropping the other, which leaves start/stop arrays of unequal
# length or pairs the wrong start with the wrong stop.
in_window = (swrs.starts >= start) & (swrs.stops <= stop)
swrs = nept.Epoch([swrs.starts[in_window], swrs.stops[in_window]])

swrs.n_epochs

In [None]:
# One figure per SWR: the event (red) drawn over its surrounding LFP (black),
# padded by `buffer` on each side (same time units as the LFP).
buffer = 0.1

# Distinct loop names so the analysis window `start`/`stop` is not overwritten.
for swr_start, swr_stop in zip(swrs.starts, swrs.stops):
    context_lfp = lfp.time_slice(swr_start - buffer, swr_stop + buffer)
    plt.plot(context_lfp.time, context_lfp.data, "k")
    this_swr_lfp = lfp.time_slice(swr_start, swr_stop)
    plt.plot(this_swr_lfp.time, this_swr_lfp.data, "r")
    plt.show()

In [None]:
# Whole prerecord LFP (black) with the detected SWRs overlaid (red).
start = info.task_times["prerecord"].start
stop = info.task_times["prerecord"].stop

sliced_lfp = lfp.time_slice(start, stop)
plt.plot(sliced_lfp.time, sliced_lfp.data, "k")

# Distinct loop names: reusing `start`/`stop` here would rebind the prerecord window
# to the last SWR's bounds, and later cells slice spikes and LFP with `start`/`stop`.
for swr_start, swr_stop in zip(swrs.starts, swrs.stops):
    this_swr_lfp = lfp.time_slice(swr_start, swr_stop)
    plt.plot(this_swr_lfp.time, this_swr_lfp.data, "r")
plt.show()

In [None]:
# Gaussian-smoothed population spike count over the prerecord window.
# Re-derive the window here instead of trusting `start`/`stop`: plotting loops in
# earlier cells may have rebound those names to a single SWR's bounds.
start = info.task_times["prerecord"].start
stop = info.task_times["prerecord"].stop

all_spikes = np.sort(np.concatenate([spiketrain.time for spiketrain in spikes]))
sliced_all_spikes = all_spikes[(start <= all_spikes) & (all_spikes <= stop)]

dt = 0.025
bin_edges = nept.get_edges(sliced_all_spikes, dt)

std = 0.01
filtered_spikes = np.histogram(sliced_all_spikes, bins=bin_edges)[0].astype(float)
filtered_spikes = nept.gaussian_filter(filtered_spikes, std=std, dt=dt)

In [None]:
# Smoothed spike count against bin left edges, with a shaded reference span
# (2200-2210, hard-coded) for inspection.
fig, ax = plt.subplots()
ax.plot(bin_edges[:-1], filtered_spikes)
ax.fill_between((2200, 2210), 100, color="k", alpha=0.1)
plt.show()

In [None]:
# Boxcar-smoothed spike counts, as an alternative to the Gaussian smoothing above.
bin_width = 0.025
bin_edges = nept.get_edges(sliced_all_spikes, bin_width)

n_bins = 3  # boxcar width, in bins
square_filter = np.ones(n_bins)
spike_counts = np.histogram(sliced_all_spikes, bins=bin_edges)[0].astype(float)
# mode="same" keeps one output value per bin, so each value lines up with its bin.
square_smoothed_spikes = np.convolve(spike_counts, square_filter, mode="same")
# Timestamp each value by its bin's left edge, the convention used for plotting
# filtered_spikes. The previous np.linspace over sliced_lfp's span did not follow
# these spike-derived bins, whose span need not match the LFP's, so values were misplaced.
times = bin_edges[:-1]

In [None]:
# Slice the prerecord window explicitly: `start`/`stop` may have been rebound by an
# SWR plotting loop, which would silently shrink this to a single SWR.
sliced_lfp = lfp.time_slice(info.task_times["prerecord"].start, info.task_times["prerecord"].stop)

In [None]:
# The sliced LFP on its own.
fig_lfp, ax_lfp = plt.subplots()
ax_lfp.plot(sliced_lfp.time, sliced_lfp.data)
plt.show()

In [None]:
# Deliberate stop so "Run All" halts here; the cells below are exploratory scratch work.
raise RuntimeError("Intentional stop: the cells below are exploratory; run them individually.")

In [None]:
events, position, spikes, lfp, lfp_theta = get_data(info)

# Analysis window: the prerecord rest phase.
prerecord = info.task_times["prerecord"]
start, stop = prerecord.start, prerecord.stop

# SWR detection parameters
fs = info.fs
thresh = (140.0, 250.0)
z_thresh = 1.
merge_thresh = 0.01
min_length = 0.03
min_involved = 4

# Detect over the whole recording, filter with find_multi_in_epochs
# (presumably requiring >= min_involved active units), then keep the prerecord window.
swrs = nept.detect_swr_hilbert(lfp, fs, thresh, z_thresh, merge_thresh=merge_thresh, min_length=min_length)
swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=min_involved).time_slice(start, stop)

print(swrs.n_epochs)

sliced_lfp = lfp.time_slice(start, stop)

all_spikes = np.sort(np.concatenate([train.time for train in spikes]))
in_window = (start <= all_spikes) & (all_spikes <= stop)
sliced_all_spikes = all_spikes[in_window]

# Gaussian-smoothed population spike count in dt-wide bins.
dt = 0.02
std = 0.01
firing_thresh = 20  # threshold on convolved_spikes, applied to SWRs in a later cell
bin_edges = nept.get_edges(sliced_all_spikes, dt)
binned_counts = np.histogram(sliced_all_spikes, bins=bin_edges)[0].astype(float)
convolved_spikes = nept.gaussian_filter(binned_counts, std=std, dt=dt)

plt.plot(bin_edges[:-1], convolved_spikes)
plt.show()

In [None]:
# Smoothed population spike count, each value plotted at its bin's left edge.
bin_starts = bin_edges[:-1]
plt.plot(bin_starts, convolved_spikes)
plt.show()

In [None]:
# Flag SWRs during which the smoothed population spike count reaches firing_thresh.
firing_rate_above_thresh = convolved_spikes >= firing_thresh
# convolved_spikes has one value per spike bin (left edges bin_edges[:-1]), so SWR
# times must be mapped onto those bins; indices into lfp.time address LFP samples,
# which is a different index space.
bin_starts = bin_edges[:-1]
# Distinct loop names so the analysis window `start`/`stop` is not overwritten.
for swr_start, swr_stop in zip(swrs.starts, swrs.stops):
    start_idx = nept.find_nearest_idx(bin_starts, swr_start)
    stop_idx = nept.find_nearest_idx(bin_starts, swr_stop)
    # Include the stop bin so an SWR that maps to a single bin is still checked.
    if np.any(firing_rate_above_thresh[start_idx:stop_idx + 1]):
        print("huh")

In [None]:
# Deliberate stop so "Run All" halts here; the cells below are exploratory scratch work.
raise RuntimeError("Intentional stop: the cells below are exploratory; run them individually.")

In [None]:
# Filtering signal with butterworth fitler
filtered_butter = nept.butter_bandpass(lfp.data, thresh, fs)

# Get LFP power (using Hilbert) and z-score the power
# Zero padding to nearest regular number to speed up fast fourier transforms (FFT) computed in the hilbert function.
# Regular numbers are composites of the prime factors 2, 3, and 5.
hilbert_n = nept.next_regular(lfp.n_samples)
power = np.abs(scipy.signal.hilbert(filtered_butter, N=hilbert_n))
power = power[:lfp.n_samples]  # removing the zero padding now that the power is computed
zpower = scipy.stats.zscore(power)

# Finding locations where the power changes
detect = zpower > z_thresh
detect = np.hstack([0, detect, 0])  # pad to detect first or last element change
signal_change = np.diff(detect.astype(int))

start_swr_idx = np.where(signal_change == 1)[0]
stop_swr_idx = np.where(signal_change == -1)[0] - 1

# Getting times associated with these power changes
start_time = lfp.time[start_swr_idx]
stop_time = lfp.time[stop_swr_idx]

# Removing doubles
start_times = start_time[(stop_time - start_time) != 0]
stop_times = stop_time[(stop_time - start_time) != 0]

swrs = nept.Epoch(np.array([start_times, stop_times]))

# Merging epochs that are closer - in time - than the merge_threshold.
swrs = swrs.merge(gap=merge_thresh)

# Removing epochs that are shorter - in time - than the min_length value.
keep_indices = swrs.durations >= min_length
swrs = nept.Epoch([swrs.starts[keep_indices], swrs.stops[keep_indices]])

In [None]:
# Number of SWRs left after the manual detection, merge and minimum-length steps.
swrs.n_epochs

In [None]:
# LFP snippet for each detected SWR.
swr_lfps = [lfp.time_slice(swr_start, swr_stop) for swr_start, swr_stop in zip(swrs.starts, swrs.stops)]

In [None]:
# Full LFP with the first five SWR snippets overlaid.
plt.plot(lfp.time, lfp.data)
first_swrs = swr_lfps[:5]
for snippet in first_swrs:
    plt.plot(snippet.time, snippet.data)
plt.show()

In [None]:
# Epoch built from the pauseB start/stop; the next cell slices the power with its
# .start/.stop to use pauseB as the z-scoring reference period.
times_for_zscore = nept.Epoch([info.task_times["pauseB"].start, info.task_times["pauseB"].stop])

In [None]:
# `power_lfp` is not defined anywhere in this notebook (NameError); build it from the
# Hilbert envelope `power` computed in the manual-detection cell, sampled at lfp.time.
power_lfp = nept.AnalogSignal(power, lfp.time)

# z-score the envelope using only the pauseB window as the reference period.
sliced_power_lfp = power_lfp.time_slice(times_for_zscore.start, times_for_zscore.stop)
zpower = scipy.stats.zscore(np.squeeze(sliced_power_lfp.data))

In [None]:
# z-scored envelope over the pauseB reference window, by sample index.
fig_zpower, ax_zpower = plt.subplots()
ax_zpower.plot(zpower)
plt.show()

In [None]:
# Index of the pauseB sample whose z-scored power is closest to z_thresh
# (the next cell recomputes this and also reads the raw value).
zthresh_idx = (np.abs(zpower - z_thresh)).argmin()

In [None]:
# Raw envelope value at the sample whose z-score is closest to z_thresh:
# the threshold expressed in the signal's own units.
zthresh_idx = np.argmin(np.abs(zpower - z_thresh))
power_thresh = sliced_power_lfp.data[zthresh_idx, 0]

In [None]:
# Row of .data at zthresh_idx (presumably a length-1 row, given the [0] used above).
sliced_power_lfp.data[zthresh_idx]

In [None]:
# Whole-session envelope with the pauseB-derived threshold drawn in magenta.
fig_power, ax_power = plt.subplots()
ax_power.plot(power_lfp.time, power_lfp.data)
ax_power.axhline(power_thresh, color="m")
plt.show()

In [None]:
# Threshold crossings of the envelope. After zero-padding both ends, a +1 step marks
# the first above-threshold sample and a -1 step the sample after the last one.
above_thresh = np.squeeze(power_lfp.data) > power_thresh
detect = np.concatenate(([0], above_thresh, [0]))
signal_change = np.diff(detect.astype(int))

start_swr_idx = np.flatnonzero(signal_change == 1)
stop_swr_idx = np.flatnonzero(signal_change == -1) - 1

In [None]:
# Convert crossing indices to times and wrap them as epochs.
start_times, stop_times = lfp.time[start_swr_idx], lfp.time[stop_swr_idx]
these_swrs = nept.Epoch([start_times, stop_times])

In [None]:
# Sanity check: start and stop time of the first detected event.
start_times[0], stop_times[0]

In [None]:
# Merge events separated by less than merge_thresh.
start_times = lfp.time[start_swr_idx]
stop_times = lfp.time[stop_swr_idx]

# Gap between each event's stop and the next event's start. This must use the
# start_times/stop_times just computed: the singular start_time/stop_time come from the
# earlier manual-detection cell and do not correspond to start_swr_idx/stop_swr_idx.
no_double = start_times[1:] - stop_times[:-1]
merge_idx = np.where(no_double < merge_thresh)[0]
# Merging event i with event i+1 means dropping stop i and start i+1.
start_merged = np.delete(start_times, merge_idx + 1)
stop_merged = np.delete(stop_times, merge_idx)
start_merged_idx = np.delete(start_swr_idx, merge_idx + 1)
stop_merged_idx = np.delete(stop_swr_idx, merge_idx)

In [None]:
# Gaps between consecutive events (next start minus current stop).
no_double

In [None]:
# Number of events left after merging.
len(stop_merged_idx)

In [None]:
# Drop merged events shorter than min_length, from the times and the indices alike.
swr_len = stop_merged - start_merged
short_idx = np.flatnonzero(swr_len < min_length)
start_merged, stop_merged, start_merged_idx, stop_merged_idx = (
    np.delete(arr, short_idx)
    for arr in (start_merged, stop_merged, start_merged_idx, stop_merged_idx)
)

In [None]:
# Wrap the merged, length-filtered events as an Epoch built from a 2-row array
# (row 0 = starts, row 1 = stops).
swrs = nept.Epoch(np.array([start_merged, stop_merged]))

In [None]:
# Final number of SWRs from the manual threshold/merge/length pipeline.
swrs.n_epochs

In [None]:
# LFP snippet for each final SWR.
swr_lfps = [lfp.time_slice(swr_start, swr_stop) for swr_start, swr_stop in zip(swrs.starts, swrs.stops)]

In [None]:
# Prerecord window and the LFP inside it.
prerecord = info.task_times["prerecord"]
start, stop = prerecord.start, prerecord.stop
sliced_lfp = lfp.time_slice(start, stop)

In [None]:
# Prerecord LFP with every final SWR snippet overlaid.
fig_swrs, ax_swrs = plt.subplots()
ax_swrs.plot(sliced_lfp.time, sliced_lfp.data)
for snippet in swr_lfps:
    ax_swrs.plot(snippet.time, snippet.data)
plt.show()