In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import numpy as np
import scipy
import os
import scalebar
import nept

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_plotting import plot_over_space
from utils_maze import get_zones, get_bin_centers, get_matched_trials

In [None]:
# Resolve cache and figure locations relative to the directory the notebook is run from.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "supplemental", "n_swrs")
# exist_ok=True creates the directory tree only when it is missing, without the
# check-then-create race of a separate os.path.exists() test, and keeps the cell
# safe to re-run.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# Individual session info modules. In the code visible here, r066d1 is only
# referenced by the commented-out single-session line below and r063d3 is not
# referenced at all.
import info.r066d1 as r066d1
import info.r063d3 as r063d3
# infos = [r066d1]
from run import analysis_infos, day8_infos
# `infos` is the collection of sessions the analysis loop iterates over.
# Uncomment one of the alternative assignments to use a different set.
infos = analysis_infos
# infos = day8_infos

In [None]:
# SWR detection settings; each value is forwarded to nept.detect_swr_hilbert
# under the keyword named in its comment.
swr_thresh = (140.0, 250.0)  # `thresh`; presumably the ripple band in Hz -- TODO confirm
z_thresh = 2.0               # `z_thresh`
merge_thresh = 0.02          # `merge_thresh`; presumably seconds -- TODO confirm
min_length = 0.05            # `min_length`; presumably seconds -- TODO confirm

# Names of the task epochs, used as keys into `info.task_times`.
task_times = [
    "prerecord",
    "phase1",
    "pauseA",
    "phase2",
    "pauseB",
    "phase3",
    "postrecord",
]

# Maze segment labels.
maze_segments = [
    "u",
    "shortcut",
    "novel",
    "other",
]

In [None]:
# When True, figures are written to `output_filepath` and closed; when False
# they are shown inline instead.
savefig = True


def _plot_swrs_with_speed(info, swrs, speed, binned, xlabel, filename, title=None):
    """Plot one session's SWRs over time with speed overlaid on a twin y-axis.

    Parameters
    ----------
    info : session info module
        `info.task_times[name].start` / `.stop` are read for every name in the
        module-level `task_times` list to shade the task epochs.
    swrs : epoch-like object
        Detected SWRs; `.centers` and `.n_epochs` are read.
    speed : object exposing `.time` and `.data`
        Plotted in green on the right-hand axis.
    binned : bool
        True draws a 50-bin histogram of SWR centers; False draws one "|"
        marker per SWR at y=1.
    xlabel : str
        Label for the x-axis.
    filename : str
        File name inside `output_filepath`, used only when `savefig` is true.
    title : str, optional
        Figure title; no title is set when None.
    """
    fig, ax1 = plt.subplots()

    if binned:
        ax1.hist(swrs.centers, bins=50, color="k")
    else:
        ax1.plot(swrs.centers, np.ones(swrs.n_epochs), "|", color="k")

    # Shade each task epoch so SWR density can be compared across phases.
    for task_time in task_times:
        epoch = info.task_times[task_time]
        ax1.axvspan(epoch.start, epoch.stop, alpha=0.2)
    # NOTE(review): for the raster (binned=False) every event sits at y=1, so
    # this label is only literally accurate for the histogram.
    ax1.set_ylabel('Number of SWRs')

    ax2 = ax1.twinx()
    ax2.plot(speed.time, speed.data, color="g", alpha=0.5)
    ax2.set_ylim(0, 200)
    ax2.set_ylabel('Speed', color="g")

    ax1.set_xlabel(xlabel, color="tab:blue")
    # Hide the x ticks. The twin axes was the "current" axes targeted by the
    # former plt.xticks([], []) call, so the same axes is addressed explicitly.
    ax2.set_xticks([])
    ax2.set_xticklabels([])
    if title is not None:
        ax2.set_title(title)

    fig.tight_layout()

    if savefig:
        fig.savefig(os.path.join(output_filepath, filename))
        # Close this specific figure so figures do not accumulate across sessions.
        plt.close(fig)
    else:
        plt.show()


for info in infos:
    print(info.session_id)
    events, position, spikes, lfp, _ = get_data(info)

    speed = position.speed()

    # Find SWRs for the whole session, then keep those returned by
    # nept.find_multi_in_epochs with min_involved=4.
    swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=swr_thresh, z_thresh=z_thresh,
                                   merge_thresh=merge_thresh, min_length=min_length)
    swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=4)

    rest_epochs = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

    # Restrict SWRs to those during epochs of interest during rest
    phase_swrs = dict()
    n_swrs = dict()

    for task_time in task_times:
        epochs_of_interest = info.task_times[task_time].intersect(rest_epochs)

        rest_swrs = epochs_of_interest.overlaps(swrs)
        # Re-apply the detection minimum length after restricting to the rest
        # epochs; this reuses `min_length` instead of repeating the literal 0.05.
        phase_swrs[task_time] = rest_swrs[rest_swrs.durations >= min_length]

        n_swrs[task_time] = phase_swrs[task_time].n_epochs

    print(n_swrs)

    title = info.session_id+" n_swrs"

    # Histogram of SWR centers over the session.
    _plot_swrs_with_speed(info, swrs, speed, binned=True, xlabel="Time",
                          filename=title+"_binned.png", title=title)

    # One marker per SWR over the session (no title, as before).
    _plot_swrs_with_speed(info, swrs, speed, binned=False, xlabel="Time (s)",
                          filename=title+".png")

In [None]:
# Disabled cell: standalone copy of the per-SWR marker plot that the session
# loop already produces. If re-enabled it would use `info`, `swrs` and `speed`
# left over from the loop's last iteration.
# fig, ax1 = plt.subplots()

# ax1.plot(swrs.centers, np.ones(swrs.n_epochs), "|", color="k")
# for task_time in task_times:
#     ax1.axvspan(info.task_times[task_time].start, info.task_times[task_time].stop, alpha=0.2)
# ax1.set_ylabel('Number of SWRs')

# ax2 = ax1.twinx()
# ax2.plot(speed.time, speed.data, color="g", alpha=0.5)
# ax2.set_ylim(0, 200)
# ax2.set_ylabel('Speed', color="g")

# ax1.set_xlabel("Time (s)", color="tab:blue")
# plt.xticks([], [])

# plt.show()

In [None]:
# Disabled cell: standalone copy of the 50-bin SWR histogram that the session
# loop already produces. If re-enabled it would use `info`, `swrs` and `speed`
# left over from the loop's last iteration.
# fig, ax1 = plt.subplots()

# ax1.hist(swrs.centers, bins=50, color="k")
# for task_time in task_times:
#     ax1.axvspan(info.task_times[task_time].start, info.task_times[task_time].stop, alpha=0.2)
# ax1.set_ylabel('Number of SWRs')

# ax2 = ax1.twinx()
# ax2.plot(speed.time, speed.data, color="g", alpha=0.5)
# ax2.set_ylim(0, 200)
# ax2.set_ylabel('Speed', color="g")

# ax1.set_xlabel("Time", color="tab:blue")
# plt.xticks([], [])
# title = info.session_id+" n_swrs"
# plt.title(title)

# plt.tight_layout()

# plt.show()
# # plt.savefig(os.path.join(output_filepath, title+".png"))
# # plt.close()

In [None]:
# Recompute speed outside the loop. `position` here is whatever the session
# loop assigned on its final iteration, so this cell only works after that
# loop has run.
speed = position.speed()

In [None]:
# Bare expression: displays the repr of `speed` as the cell output.
speed