In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

from loading_data import get_data
from analyze_tuning_curves import get_tuning_curves
from utils_maze import get_trial_idx, get_zones, align_to_event

In [None]:
# Paths are resolved relative to the directory the notebook is run from.
thisdir = os.getcwd()
# NOTE(review): not referenced by any cell in this notebook; kept in case helpers
# such as get_data rely on the same layout — confirm.
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "binned_swrs")

# savefig does not create missing directories, so make sure the figure output
# directory exists before any cell tries to save into it.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
import info.r068d1 as r068d1
import info.r063d6 as r063d6
# Uncomment to restrict the analysis to these two sessions only.
# NOTE(review): r068d1 / r063d6 are otherwise unused.
# infos =[r068d1, r063d6]

# The analysis cells below loop over `infos`; by default that is every session
# listed in run.spike_sorted_infos.
from run import spike_sorted_infos
infos = spike_sorted_infos

In [None]:
def plot_binned_swr(info, binned_swr, filepath,
                    phases=("prerecord", "phase1", "pauseA", "phase2",
                            "pauseB", "phase3", "postrecord")):
    """Plot a binned SWR signal over a whole session and save it as a PNG.

    Parameters
    ----------
    info : session info object
        Must provide `session_id` (str) and `task_times`, a mapping from phase
        name to an epoch exposing a `.start` time.
    binned_swr : nept.AnalogSignal
        Signal with `.time` and `.data` (e.g. the smoothed SWR counts).
    filepath : str
        Output directory, created if missing. The figure is written to
        "<filepath>/<session_id>-binned_swr.png".
    phases : sequence of str, optional
        Task phases whose start times are marked with dashed vertical lines and
        used as x tick labels. Defaults to the seven phases of the task.
    """
    # Use a dedicated figure so nothing else left open in pyplot gets drawn on.
    fig, ax = plt.subplots()
    ax.plot(binned_swr.time, binned_swr.data)

    phase_starts = [info.task_times[phase].start for phase in phases]
    for start in phase_starts:
        ax.axvline(x=start, color="k", linestyle='--')

    ax.set_xticks(phase_starts)
    ax.set_xticklabels(list(phases), rotation=75)
    fig.tight_layout()

    os.makedirs(filepath, exist_ok=True)
    fig.savefig(os.path.join(filepath, info.session_id + "-binned_swr.png"))
    # Close explicitly: this is called once per session inside a loop.
    plt.close(fig)

In [None]:
shortcut_perievents = []
u_perievents = []
shortcut_end_perievents = []

t_before = 30
t_after = 60

for info in infos:
    print(info.session_id)
    events, position, spikes, lfp, _ = get_data(info)

    # Remove interneurons
    max_mean_firing = 5
    interneurons = np.zeros(len(spikes), dtype=bool)
    for i, spike in enumerate(spikes):
        if len(spike.time) / info.session_length >= max_mean_firing:
            interneurons[i] = True
    spikes = spikes[~interneurons]

    # Find SWRs for the whole session
    z_thresh = 2.0
    power_thresh = 3.0
    merge_thresh = 0.02
    min_length = 0.05
    swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                   power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)

    # Restrict SWRs to those with 4 or more participating neurons
    swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=4)

    # Find rest epochs for entire session
    epochs_of_interest = nept.rest_threshold(position, thresh=15., t_smooth=1.0)

    swrs = epochs_of_interest.overlaps(swrs)
    swrs = swrs[swrs.durations >= 0.05]

    # Bin the swrs
    binsize = 3.
    t_bins = np.arange(lfp.time.min(), lfp.time.max()+binsize, binsize)
    swr_counts = np.histogram(swrs.centers, bins=t_bins)[0]

    # Smooth the binned swrs with a gaussian filter
    std = binsize * 2
    filter_swr = nept.gaussian_filter(swr_counts, std=std, dt=binsize, axis=0)
    smoothed_swrs = nept.AnalogSignal(filter_swr, t_bins[:-1])
    
    s_swrs = nept.AnalogSignal(swr_counts, t_bins[:-1])

    # Plot
    plot_binned_swr(info, smoothed_swrs, output_filepath)
    
    
    # Get times of interest
    t_start = info.task_times['phase3'].start
    t_stop = info.task_times['phase3'].stop

    sliced_pos = position.time_slice(t_start, t_stop)

    feeder1_times = []
    for feeder1 in events['feeder1']:
        if t_start < feeder1 < t_stop:
            feeder1_times.append(feeder1)

    feeder2_times = []
    for feeder2 in events['feeder2']:
        if t_start < feeder2 < t_stop:
            feeder2_times.append(feeder2)

    path_pos = get_zones(info, sliced_pos)

    trials_idx, trial_epochs = get_trial_idx(path_pos['u'].time, 
                                             path_pos['shortcut'].time, 
                                             path_pos['novel'].time,
                                             feeder1_times, 
                                             feeder2_times, 
                                             t_stop)

    first_shortcut = trial_epochs[trials_idx['shortcut'][0][0]].start
    shortcut_end = trial_epochs[trials_idx['shortcut'][0][0]].stop
    first_u = trial_epochs[trials_idx['u'][0][0]].start

    shortcut_perievent = align_to_event(smoothed_swrs, first_shortcut, t_before, t_after)
    shortcut_perievents.append(shortcut_perievent)
    
    shortcut_end_perievent = align_to_event(smoothed_swrs, shortcut_end, t_before, t_after)
    shortcut_end_perievents.append(shortcut_end_perievent)

    u_perievent = align_to_event(smoothed_swrs, first_u, t_before, t_after)
    u_perievents.append(u_perievent)


shortcut_position = align_to_event(position, first_shortcut, t_before, t_after)
shortcut_position = nept.Position(shortcut_position.data, shortcut_position.time)
shortcut_speed = shortcut_position.speed(t_smooth=1.)

times = shortcut_perievents[0].time
datas = np.zeros((len(times), len(shortcut_perievents)))
for i, perievent in enumerate(shortcut_perievents):
    datas[:, i] = np.squeeze(perievent.data)
datas = np.mean(datas, axis=1)

fig, (ax1, ax2) = plt.subplots(2, sharex=True)
ax1.plot(times, datas, color="b")
ax1.axvline(x=0, color="k", linestyle='--')
ax2.plot(shortcut_speed.time, shortcut_speed.data, color="g")
ax2.axvline(x=0, color="k", linestyle='--')
plt.title("Speed", fontsize=16)
fig.suptitle("SWR perievent first shortcut start", fontsize=16)
plt.tight_layout()
plt.savefig(os.path.join(output_filepath, "combined_perievent-first-shortcut.png"))
plt.close()
# plt.show()

shortcut_end_position = align_to_event(position, shortcut_end, t_before, t_after)
shortcut_end_position = nept.Position(shortcut_position.data, shortcut_position.time)
shortcut_end_speed = shortcut_position.speed(t_smooth=1.)

times = shortcut_perievents[0].time
datas = np.zeros((len(times), len(shortcut_end_perievents)))
for i, perievent in enumerate(shortcut_end_perievents):
    datas[:, i] = np.squeeze(perievent.data)
datas = np.mean(datas, axis=1)

fig, (ax1, ax2) = plt.subplots(2, sharex=True)
ax1.plot(times, datas, color="b")
ax1.axvline(x=0, color="k", linestyle='--')
ax2.plot(shortcut_end_speed.time, shortcut_end_speed.data, color="g")
ax2.axvline(x=0, color="k", linestyle='--')
plt.title("Speed", fontsize=16)
fig.suptitle("SWR perievent first shortcut end", fontsize=16)
plt.tight_layout()
plt.savefig(os.path.join(output_filepath, "combined_perievent-first-shortcut-end.png"))
plt.close()
# plt.show()


u_position = align_to_event(position, first_u, t_before, t_after)
u_position = nept.Position(u_position.data, u_position.time)
u_speed = u_position.speed(t_smooth=1.)

times = u_perievents[0].time
datas = np.zeros((len(times), len(u_perievents)))
for i, perievent in enumerate(u_perievents):
    datas[:, i] = np.squeeze(perievent.data)
datas = np.mean(datas, axis=1)

fig, (ax1, ax2) = plt.subplots(2, sharex=True)
ax1.plot(times, datas, color="b")
ax1.axvline(x=0, color="k", linestyle='--')
ax2.plot(u_speed.time, u_speed.data, color="g")
ax2.axvline(x=0, color="k", linestyle='--')
plt.title("Speed", fontsize=16)
fig.suptitle("SWR perievent first U", fontsize=16)
plt.tight_layout()
plt.savefig(os.path.join(output_filepath, "combined_perievent-first-u.png"))
plt.close()
# plt.show()

In [None]:
perievent.time

In [None]:
# Inspection: redo the perievent alignment by hand for the U trial at index 1
# of the last session processed in the analysis loop, to check the number of
# samples in the window.
# NOTE(review): presumably mirrors utils_maze.align_to_event — confirm.
event = trial_epochs[trials_idx['u'][1][0]].start
analogsignal = smoothed_swrs

# Samples of the smoothed SWR signal from t_before before to t_after after the event.
sliced = analogsignal.time_slice(event - t_before, event + t_after)

# Index of the sample closest to the event; times are re-expressed relative to it.
idx = nept.find_nearest_idx(sliced.time, event)

time = sliced.time - sliced.time[idx]
data = np.squeeze(sliced.data)

In [None]:
# Boolean mask of samples strictly inside (event - t_before + 0.1, event + t_after).
# The +0.1 shifts the lower edge, apparently to compare the resulting sample
# count against `sliced` in the cells below.
in_window = (analogsignal.time > event - t_before + 0.1) & (analogsignal.time < event + t_after)
# Collapse to one flag per sample (a no-op reshape for a 1-D time axis).
indices = in_window.reshape(len(in_window), -1).any(axis=1)

In [None]:
# Number of samples selected by the hand-built window mask.
analogsignal[indices].n_samples

In [None]:
# Display the mask itself.
indices

In [None]:
# Sample count of the time_slice window, for comparison with the mask count.
sliced.n_samples

In [None]:
# Re-plot the first-U perievent figure.
# NOTE(review): this saves to "combined_perievent-first-u.png", the same file as
# the first-U figure of the analysis cell, and overwrites it.
# NOTE(review): `position` and `first_u` are left over from the last session of
# the analysis loop, so the speed panel shows that session only, while the SWR
# panel is averaged over all sessions.
u_position = align_to_event(position, first_u, t_before, t_after)
u_position = nept.Position(u_position.data, u_position.time)
u_speed = u_position.speed(t_smooth=1.)

# Sessions can yield perievent windows with different sample counts, so
# interpolate each onto the first session's time grid before averaging.
# Grid points outside a session's window are NaN and ignored by nanmean.
times = u_perievents[0].time
datas = np.full((len(times), len(u_perievents)), np.nan)
for i, perievent in enumerate(u_perievents):
    datas[:, i] = np.interp(times, perievent.time, np.atleast_1d(np.squeeze(perievent.data)),
                            left=np.nan, right=np.nan)
datas = np.nanmean(datas, axis=1)

fig, (ax1, ax2) = plt.subplots(2, sharex=True)
ax1.plot(times, datas, color="b")
ax1.axvline(x=0, color="k", linestyle='--')
ax2.plot(u_speed.time, u_speed.data, color="g")
ax2.axvline(x=0, color="k", linestyle='--')
fig.suptitle("Perievent first U", fontsize=16)
fig.tight_layout()
fig.savefig(os.path.join(output_filepath, "combined_perievent-first-u.png"))
plt.close(fig)

In [None]:
datas = np.zeros((len(times), len(u_perievents)))

In [None]:
times

In [None]:
perievent.time

In [None]:
u_perievents[1].time