In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

from loading_data import get_data
from analyze_classy_decode import detect_swr_hilbert_limited_zscore

In [None]:
# Resolve cache and plot locations relative to the directory the notebook runs from.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "oksequenceless")

# exist_ok avoids the check-then-create race and makes re-running this cell a no-op.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# Session-info modules available to this notebook.
from info import r063d2, r063d3

infos = [r063d2, r063d3]

# Alternative (disabled): take the session list from the main analysis script.
# from run import analysis_infos
# infos = analysis_infos

In [None]:
# Session-info module used by every cell below; point this at another entry
# of `infos` to analyze a different session.
info = r063d2

In [None]:
# SWR detection parameters, forwarded to detect_swr_hilbert_limited_zscore below.
merge_thresh = 0.02  # presumably seconds: gap under which neighbouring events are merged -- TODO confirm
min_length = 0.05  # presumably seconds: shortest event that is kept -- TODO confirm
swr_thresh = (140.0, 250.0)  # presumably the (low, high) frequency band in Hz -- TODO confirm

# Keys into info.task_times for the phases analyzed below.
task_times = ["prerecord", "pauseA", "pauseB", "postrecord"]
# Maze segment labels.
maze_segments = ["u", "shortcut", "novel", "other"]

In [None]:
print(info.session_id)
events, position, spikes, lfp, _ = get_data(info)

# Detect SWRs over the whole session, z-scoring against the pauseB task time.
pause_b = info.task_times["pauseB"]
zscore_epoch = nept.Epoch([pause_b.start, pause_b.stop])
swrs = detect_swr_hilbert_limited_zscore(
    info,
    lfp=lfp,
    fs=info.fs,
    thresh=swr_thresh,
    times_for_zscore=zscore_epoch,
    merge_thresh=merge_thresh,
    min_length=min_length,
)

# Presumably keeps only events in which at least `min_involved` units fire -- TODO confirm.
swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=4)

# Presumably epochs where the animal's speed stays below `thresh` -- TODO confirm.
rest_epochs = nept.rest_threshold(position, thresh=12.0, t_smooth=0.8)

In [None]:
# Restrict SWRs to those found during rest within each task time of interest,
# and count how many remain per task time.
phase_swrs = dict()
n_swrs = dict()

for task_time in task_times:
    # Portion of this task time that was also classified as rest.
    epochs_of_interest = info.task_times[task_time].intersect(rest_epochs)

    task_swrs = epochs_of_interest.overlaps(swrs)
    # Re-apply the minimum SWR duration using the shared `min_length` parameter
    # (rather than a repeated literal) so detection and filtering cannot drift apart.
    phase_swrs[task_time] = task_swrs[task_swrs.durations >= min_length]

    n_swrs[task_time] = phase_swrs[task_time].n_epochs

print(n_swrs)