In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

from loading_data import get_data
from analyze_tuning_curves import get_tuning_curves
from utils_maze import get_trial_idx, get_zones, align_to_event

In [None]:
# Cache and figure directories are resolved relative to the notebook's working directory.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "binned_swrs")
# exist_ok=True avoids the race between a separate exists-check and the create call.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# Individual session-info modules. In the code shown here they are only
# referenced by the commented-out selection below, which would restrict the
# analysis to these two sessions.
import info.r068d2 as info
import info.r063d4 as r063d4
# infos =[info, r063d4]

# Sessions actually analysed: every session listed in run.spike_sorted_infos.
from run import spike_sorted_infos
infos = spike_sorted_infos

In [None]:
def get_perievent(epochs, event, t_bins, std=1.5):
    """Smoothed rate of epoch occurrences in bins around an event.

    Each epoch is counted once, in the bin that contains its center. The counts
    are smoothed with ``nept.gaussian_filter`` and divided by the bin width to
    give a rate.

    Parameters
    ----------
    epochs : nept.Epoch
        Epochs (here, SWRs) whose ``centers`` are binned. Only read, never modified.
    event : float
        Reference time. The returned time axis is expressed relative to it.
    t_bins : np.ndarray
        Bin edges in absolute time. Assumed evenly spaced, since only the first
        width is used as ``dt``.
    std : float
        Passed through to ``nept.gaussian_filter``
        (NOTE(review): confirm whether nept interprets this in time units or bins).

    Returns
    -------
    nept.AnalogSignal
        Rate per bin, time-stamped at each bin's left edge minus ``event``.
    """
    dt = t_bins[1] - t_bins[0]

    counts, _ = np.histogram(epochs.centers, bins=t_bins)
    smoothed = nept.gaussian_filter(counts, std=std, dt=dt, axis=0)

    return nept.AnalogSignal(smoothed / dt, t_bins[:-1] - event)

In [None]:
# def plot_binned_swr(info, binned_swr, filepath):
#     plt.plot(binned_swr.time, binned_swr.data, ms=3)
    
#     xtick_labels = ["prerecord", "phase1", "pauseA", "phase2", "pauseB", "phase3", "postrecord"]
#     xtick_location = []
#     for phase in xtick_labels:
#         plt.axvline(x=info.task_times[phase].start, color="k", linestyle='--', ms=3)
#         xtick_location.append(info.task_times[phase].start)
    
#     plt.xticks(xtick_location, xtick_labels, rotation=75)
#     plt.tight_layout()
#     if filepath is not None:
#         plt.savefig(os.path.join(filepath, info.session_id+"-binned_swr.png"))
#         plt.close()
#     else:
#         plt.show()

def find_perievent(epoch, position, event, t_before, t_after, binsize, t_smooth=0.8):
    """Peri-event epoch rate and running speed around a single event time.

    Parameters
    ----------
    epoch : nept.Epoch
        Epochs (e.g. SWRs) whose rate is computed via ``get_perievent``.
    position : nept.Position
        Position for the session. It is interpolated onto the bin left edges
        before speed is computed.
    event : float
        Event time. Both returned signals are time-aligned so that the event is at 0.
    t_before, t_after : float
        Window extent before and after ``event``.
    binsize : float
        Bin width.
    t_smooth : float
        Passed to ``nept.Position.speed``.

    Returns
    -------
    (nept.AnalogSignal, speed signal)
        Both are sampled at the same relative times, so results from different
        events can be averaged sample-by-sample.
    """
    # Build the edges from an integer bin count instead of np.arange with a float
    # step. arange's length is ceil((stop - start) / step), so floating-point error
    # in ``event +/- t`` can add a spurious extra edge for some events but not for
    # others. Every event must yield the same number of bins so that find_means
    # can stack the results.
    n_bins = (t_before + t_after) / binsize
    if np.isclose(n_bins, round(n_bins)):
        n_bins = int(round(n_bins))
    else:
        n_bins = int(np.ceil(n_bins))
    t_bins = (event - t_before) + binsize * np.arange(n_bins + 1)

    epoch_perievent = get_perievent(epoch, event, t_bins)

    # Sample position at the bin left edges, the same time base as the rate signal.
    perievent_time = t_bins[:-1]
    sliced_position = position.time_slice(perievent_time[0], perievent_time[-1])
    perievent_x = np.interp(perievent_time, sliced_position.time, sliced_position.x)
    perievent_y = np.interp(perievent_time, sliced_position.time, sliced_position.y)
    perievent_position = nept.Position(np.column_stack((perievent_x, perievent_y)),
                                       perievent_time - event)
    speed_perievent = perievent_position.speed(t_smooth=t_smooth)

    return epoch_perievent, speed_perievent

In [None]:
def find_means(list_of_analogsignals):
    """Sample-by-sample mean of several AnalogSignals.

    All signals must have the same number of samples. The time axis of the first
    signal is used for the result.

    Parameters
    ----------
    list_of_analogsignals : sequence of nept.AnalogSignal
        Each signal's ``data`` is squeezed to one dimension before averaging.

    Returns
    -------
    nept.AnalogSignal
        The element-wise mean across signals.

    Raises
    ------
    ValueError
        If the sequence is empty or the signals differ in length.
    """
    if len(list_of_analogsignals) == 0:
        raise ValueError("find_means needs at least one AnalogSignal")

    times = list_of_analogsignals[0].time
    columns = [np.squeeze(signal.data) for signal in list_of_analogsignals]
    lengths = {np.size(column) for column in columns}
    if lengths != {len(times)}:
        raise ValueError("find_means: signals have mismatched lengths %s (expected %d)"
                         % (sorted(lengths), len(times)))

    # Shape (n_samples, n_signals); average across signals.
    stacked = np.column_stack(columns)
    return nept.AnalogSignal(stacked.mean(axis=1), times)


def plot_perievent(perievent, speed, title, savefig=True, filepath=None):
    """Plot a peri-event SWR rate above the matching speed trace.

    Parameters
    ----------
    perievent, speed : nept.AnalogSignal
        Signals on an event-relative time axis. A dashed line marks time 0.
    title : str
        Figure title. Also used as the file name (``title + ".png"``).
    savefig : bool
        If True, save to ``filepath`` and close the figure. Otherwise show it.
    filepath : str, optional
        Output directory. Defaults to the notebook-level ``output_filepath``.
    """
    if filepath is None:
        filepath = output_filepath

    fig, (ax_rate, ax_speed) = plt.subplots(2, sharex=True)

    ax_rate.plot(perievent.time, perievent.data, color="b")
    ax_rate.axvline(x=0, color="k", linestyle='--')
    ax_rate.set_ylabel("SWR rate")

    ax_speed.plot(speed.time, speed.data, color="g")
    ax_speed.axvline(x=0, color="k", linestyle='--')
    ax_speed.set_title("Speed", fontsize=16)
    ax_speed.set_ylabel("Speed")
    ax_speed.set_xlabel("Time from event")

    fig.suptitle(title, fontsize=16)
    # Reserve the top strip so tight_layout does not place axes under the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    if savefig:
        fig.savefig(os.path.join(filepath, title + ".png"))
        # Close this specific figure, not whatever pyplot considers "current".
        plt.close(fig)
    else:
        plt.show()

In [None]:
def get_event_times(info, position, events):
    """Start and stop times of each phase-3 trial, labelled by path.

    Trials are delimited by ``get_trial_idx`` from the zone-assigned position and
    the feeder events that fall strictly inside phase 3.

    Parameters
    ----------
    info : module
        Session info. Must provide ``task_times['phase3']``.
    position : nept.Position
        Session position. Sliced to phase 3 before zones are assigned.
    events : dict
        Must contain ``'feeder1'`` and ``'feeder2'`` event times.

    Returns
    -------
    list of (str, int, float)
        ``(label, trial_number, time)`` tuples.
        - ``label`` is one of ``shortcut_start``, ``shortcut_stop``, ``u_start``,
          ``u_stop``, ``novel_start`` or ``novel_stop``.
        - ``trial_number`` is the 0-based index among trials on that path.
    """
    t_start = info.task_times['phase3'].start
    t_stop = info.task_times['phase3'].stop

    sliced_pos = position.time_slice(t_start, t_stop)

    feeder1_times = [t for t in events['feeder1'] if t_start < t < t_stop]
    feeder2_times = [t for t in events['feeder2'] if t_start < t < t_stop]

    path_pos = get_zones(info, sliced_pos)

    trials_idx, trial_epochs = get_trial_idx(path_pos['u'].time,
                                             path_pos['shortcut'].time,
                                             path_pos['novel'].time,
                                             feeder1_times,
                                             feeder2_times,
                                             t_stop)

    events_of_interest = []
    for path in ("shortcut", "u", "novel"):
        # Each path must use its own trials. Previously the u and novel labels
        # were built from the shortcut trials.
        path_epochs = [trial_epochs[idx] for idx in trials_idx[path]]
        events_of_interest.extend((path + "_start", i, epoch.start)
                                  for i, epoch in enumerate(path_epochs))
        events_of_interest.extend((path + "_stop", i, epoch.stop)
                                  for i, epoch in enumerate(path_epochs))

    return events_of_interest

def get_swrs(info, position, spikes, lfp, z_thresh=2.0, power_thresh=3.0,
             merge_thresh=0.02, min_length=0.05, min_involved=4, rest_thresh=12.):
    """Detect sharp-wave ripples (SWRs) that coincide with rest epochs.

    Parameters
    ----------
    info : module
        Session info. ``info.fs`` is passed to the detector as ``fs``.
    position : nept.Position
        Used by ``nept.rest_threshold`` to find rest epochs.
    spikes : sequence
        Spike trains passed to ``nept.find_multi_in_epochs``.
    lfp : nept.AnalogSignal
        Signal searched for SWRs.
    z_thresh, power_thresh, merge_thresh, min_length : float
        Passed unchanged to ``nept.detect_swr_hilbert``. ``min_length`` is also
        the minimum duration applied after the rest restriction.
    min_involved : int
        Minimum number of participating neurons for an SWR to be kept.
    rest_thresh : float
        ``thresh`` for ``nept.rest_threshold``
        (NOTE(review): presumably a speed threshold; confirm units).

    Returns
    -------
    nept.Epoch
        The surviving SWR epochs.
    """
    # NOTE(review): (140, 250) is presumably the ripple band in Hz; confirm against nept.
    swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                   power_thresh=power_thresh, merge_thresh=merge_thresh,
                                   min_length=min_length)

    # Restrict SWRs to those with ``min_involved`` or more participating neurons.
    swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=min_involved)

    # Rest epochs for the entire session.
    rest_epochs = nept.rest_threshold(position, thresh=rest_thresh, t_smooth=0.8)

    swrs = rest_epochs.overlaps(swrs)
    # Re-apply the same minimum duration used by the detector (previously a
    # separate hard-coded 0.05 that could drift from ``min_length``).
    swrs = swrs[swrs.durations >= min_length]

    return swrs

In [None]:
def combine_events(swrs, positions, events, t_before, t_after, binsize, title, savefig=True):
    """Average peri-event SWR rate and speed over several events, then plot.

    ``swrs``, ``positions`` and ``events`` are parallel sequences: element k of
    each must come from the same session.

    Parameters
    ----------
    swrs : sequence of nept.Epoch
    positions : sequence of nept.Position
    events : sequence of float
        One event time per (swrs, positions) pair.
    t_before, t_after, binsize : float
        Peri-event window, forwarded to ``find_perievent``.
    title : str
        Figure title and file name.
    savefig : bool
        Forwarded to ``plot_perievent``.
    """
    import warnings

    if len(events) == 0:
        # find_means cannot average zero signals; skip rather than crash the run.
        print("Skipping '%s': no events to average." % title)
        return

    if not (len(swrs) == len(positions) == len(events)):
        # zip() below truncates silently, which would pair events with the wrong session.
        warnings.warn("combine_events('%s'): got %d swrs, %d positions, %d events; "
                      "extra items are ignored and pairing may be wrong"
                      % (title, len(swrs), len(positions), len(events)))

    perievents = []
    speeds = []
    for swr, position, event in zip(swrs, positions, events):
        epoch_perievent, speed_perievent = find_perievent(swr, position, event,
                                                          t_before, t_after, binsize)
        perievents.append(epoch_perievent)
        speeds.append(speed_perievent)

    mean_swrs = find_means(perievents)
    mean_speed = find_means(speeds)
    plot_perievent(mean_swrs, mean_speed, title, savefig=savefig)

In [None]:
# Peri-event window, in the same time units as the event timestamps.
t_before = 20.
t_after = 40.
binsize = 1.

swrs = []
positions = []
# One list of (label, trial_number, time) tuples per session. Kept parallel to
# ``swrs`` and ``positions`` so that every event stays paired with its own session.
events_by_session = []

# ``session_info`` rather than ``info``, so the loop does not clobber the
# ``info`` module imported earlier in the notebook.
for session_info in infos:
    print(session_info.session_id)
    session_events, position, spikes, lfp, _ = get_data(session_info)
    events_by_session.append(get_event_times(session_info, position, session_events))
    swrs.append(get_swrs(session_info, position, spikes, lfp))
    positions.append(position)

# Flat list of all sessions' events, kept for any later cells that use it.
events_of_interest = [event for session in events_by_session for event in session]


def select_event(label, trial_idx):
    """Gather the matching event from each session, with that session's SWRs and position.

    Sessions without a matching ``(label, trial_idx)`` event are dropped, so the
    remaining events cannot shift onto another session's data.
    """
    sel_swrs, sel_positions, sel_times = [], [], []
    for session_swrs, session_position, session_events in zip(swrs, positions, events_by_session):
        for name, idx, t in session_events:
            if name == label and idx == trial_idx:
                sel_swrs.append(session_swrs)
                sel_positions.append(session_position)
                sel_times.append(t)
    return sel_swrs, sel_positions, sel_times


def plot_event(label, description, trial_idx):
    """Average and plot one event type across sessions. Skips it if no session has it."""
    title = "SWR rate for " + description + " trial" + str(trial_idx + 1)
    sel_swrs, sel_positions, sel_times = select_event(label, trial_idx)
    if not sel_times:
        print("Skipping '" + title + "': no session has this event.")
        return
    combine_events(sel_swrs, sel_positions, sel_times, t_before, t_after, binsize, title)


for trial_idx in range(5):
    plot_event("shortcut_start", "shortcut start", trial_idx)
    plot_event("shortcut_stop", "shortcut end", trial_idx)

# First trial only for the u and novel paths.
for label, description in [("u_start", "u start"), ("u_stop", "u end"),
                           ("novel_start", "novel start"), ("novel_stop", "novel end")]:
    plot_event(label, description, 0)