In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

from loading_data import get_data
from analyze_tuning_curves import get_tuning_curves
from utils_maze import get_trial_idx, get_zones, align_to_event

In [None]:
# All paths are resolved relative to the directory the kernel was started in.
thisdir = os.getcwd()
pickle_filepath, output_filepath = (
    os.path.join(thisdir, "cache", "pickled"),      # pickle cache directory
    os.path.join(thisdir, "plots", "binned_swrs"),  # where the figures below are saved
)

In [None]:
# Sessions to analyse. The analysis cell reads session_id, session_length, fs
# and task_times from each of these info modules.
import info.r063d5 as r063d5
import info.r063d6 as r063d6
infos =[r063d5, r063d6]

# Alternative session list (presumably every spike-sorted session). It is only
# used if the assignment below is uncommented; otherwise this import is unused.
from run import spike_sorted_infos
# infos = spike_sorted_infos

In [None]:
def plot_binned_swr(info, binned_swr, filepath):
    """Plot binned SWR counts over a session and save the figure as a PNG.

    A dashed vertical line marks the start of each task phase, and the x-axis
    ticks are placed at those starts and labelled with the phase names.

    Parameters
    ----------
    info : module
        Session info. Must provide ``session_id`` (used in the file name and
        title) and ``task_times``, a mapping from phase name to an object with
        a ``.start`` attribute.
    binned_swr : nept.AnalogSignal
        Signal with ``.time`` and ``.data``; the caller in this notebook passes
        SWR counts per time bin.
    filepath : str
        Output directory. It is created if missing. The figure is written to
        ``<filepath>/<session_id>-binned_swr.png``.
    """
    phases = ["prerecord", "phase1", "pauseA", "phase2", "pauseB", "phase3", "postrecord"]
    phase_starts = [info.task_times[phase].start for phase in phases]

    # Explicit figure/axes so we close exactly the figure we created.
    fig, ax = plt.subplots()
    ax.plot(binned_swr.time, binned_swr.data)
    for start in phase_starts:
        ax.axvline(x=start, color="k", linestyle="--")

    ax.set_xticks(phase_starts)
    ax.set_xticklabels(phases, rotation=75)
    ax.set_ylabel("SWRs per bin")
    ax.set_title(info.session_id)
    fig.tight_layout()

    # savefig does not create directories; a fresh checkout has no plots/ tree.
    os.makedirs(filepath, exist_ok=True)
    fig.savefig(os.path.join(filepath, info.session_id + "-binned_swr.png"))
    # Release the figure so figures do not accumulate when called in a loop.
    plt.close(fig)

In [None]:
def get_first_shortcut(info, position, events):
    """Return the start time of the first shortcut trial in phase 3.

    Position is restricted to phase 3, and feeder events strictly inside phase 3
    are collected. Both are passed to ``get_trial_idx`` to split the phase into
    trials. The start of the first trial listed under ``'shortcut'`` is
    returned.

    Parameters
    ----------
    info : module
        Session info; ``info.task_times['phase3']`` supplies ``.start`` and
        ``.stop``.
    position : nept.Position
        Full-session position; only the phase-3 slice is used.
    events : dict
        Must contain ``'feeder1'`` and ``'feeder2'`` sequences of event times.

    Returns
    -------
    float
        ``.start`` of the first shortcut trial epoch.

    Raises
    ------
    ValueError
        If ``get_trial_idx`` reports no shortcut trials.
    """
    t_start = info.task_times['phase3'].start
    t_stop = info.task_times['phase3'].stop

    sliced_pos = position.time_slice(t_start, t_stop)

    # Feeder events strictly inside phase 3 (both endpoints excluded).
    feeder1_times = [t for t in events['feeder1'] if t_start < t < t_stop]
    feeder2_times = [t for t in events['feeder2'] if t_start < t < t_stop]

    path_pos = get_zones(info, sliced_pos)

    trials_idx, trial_epochs = get_trial_idx(path_pos['u'].time, path_pos['shortcut'].time,
                                             path_pos['novel'].time, feeder1_times, feeder2_times,
                                             t_stop)

    # Only the first shortcut trial is needed. Indexing a second one (as
    # before) crashed sessions that have a single shortcut trial.
    shortcut_trials = trials_idx['shortcut']
    if len(shortcut_trials) == 0:
        raise ValueError("no shortcut trials found in phase3")

    return trial_epochs[shortcut_trials[0][0]].start

In [None]:
# Per session: detect SWRs, keep those during rest, bin them over time, save a
# per-session plot, and cut out a window of binned counts around the first
# shortcut trial. The windows are then averaged across sessions and plotted.

# Units whose mean rate (spike count / session_length) reaches this are
# treated as interneurons and dropped.
max_mean_firing = 5

# SWR detection parameters for nept.detect_swr_hilbert.
swr_band = (140.0, 250.0)
z_thresh = 2.0
power_thresh = 3.0
merge_thresh = 0.02
min_length = 0.05
min_involved = 4  # minimum participating units for an SWR to be kept

# Rest detection parameters for nept.rest_threshold.
rest_thresh = 4.
rest_t_smooth = 0.05

binsize = 3.    # SWR-count bin width, in the units of lfp.time
t_before = 100  # window passed to align_to_event around the first shortcut
t_after = 200

# savefig does not create directories; a fresh checkout has no plots/ tree.
os.makedirs(output_filepath, exist_ok=True)

perievents = []

for info in infos:
    print(info.session_id)
    events, position, spikes, lfp, _ = get_data(info)

    # Remove interneurons (high mean-firing units).
    mean_rates = np.array([len(spike.time) / info.session_length for spike in spikes])
    interneurons = mean_rates >= max_mean_firing
    spikes = spikes[~interneurons]

    # Find SWRs for the whole session.
    swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=swr_band, z_thresh=z_thresh,
                                   power_thresh=power_thresh, merge_thresh=merge_thresh,
                                   min_length=min_length)

    # Keep SWRs with at least `min_involved` participating neurons.
    swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=min_involved)

    # Restrict to SWRs overlapping rest epochs for the entire session.
    # NOTE(review): confirm that Epoch.overlaps returns the SWR epochs (the
    # argument) rather than the rest epochs (self). Also confirm that an SWR
    # overlapping several rest epochs is not counted more than once.
    epochs_of_interest = nept.rest_threshold(position, thresh=rest_thresh, t_smooth=rest_t_smooth)
    swrs = epochs_of_interest.overlaps(swrs)
    # Same cut-off as the detector's min_length (previously a duplicated literal).
    swrs = swrs[swrs.durations >= min_length]

    # Count SWR centers in fixed-width bins spanning the LFP recording.
    t_bins = np.arange(lfp.time.min(), lfp.time.max() + binsize, binsize)
    swr_counts = np.histogram(swrs.centers, bins=t_bins)[0]
    binned_swr = nept.AnalogSignal(swr_counts, t_bins[:-1])

    plot_binned_swr(info, binned_swr, output_filepath)

    first_shortcut = get_first_shortcut(info, position, events)
    perievent = align_to_event(binned_swr, first_shortcut, t_before=t_before, t_after=t_after)
    perievents.append(perievent)

# Average the aligned windows across sessions. This assumes every perievent
# has the same length as the first one (column_stack raises otherwise) and
# shares its time base.
times = perievents[0].time
mean_swr_counts = np.mean(
    np.column_stack([np.squeeze(perievent.data) for perievent in perievents]), axis=1)

fig, ax = plt.subplots()
ax.plot(times, mean_swr_counts)
ax.axvline(x=0, color="k", linestyle='--')
ax.set_xlabel("Time relative to first shortcut trial")
ax.set_ylabel("Mean SWRs per bin")
fig.tight_layout()
fig.savefig(os.path.join(output_filepath, "combined_perievent-first-shortcut.png"))
plt.close(fig)

In [None]:
# for perievent in perievents:
#     plt.plot(perievent.time, perievent.data)
#     plt.axvline(x=0, color="k", linestyle='--')
#     plt.show()

In [None]:
# epochs_of_interest = epochs_of_interest.intersect(rest_epochs)
# epochs_of_interest.n_epochs