In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

from loading_data import get_data
from analyze_tuning_curves import get_tuning_curves
from utils_maze import get_trial_idx, get_zones, align_to_event

In [None]:
# Resolve the cache and plot-output directories relative to the directory the
# notebook is run from.
thisdir = os.getcwd()
pickle_filepath, output_filepath = (
    os.path.join(thisdir, "cache", "pickled"),
    os.path.join(thisdir, "plots", "binned_swrs"),
)

In [None]:
# Session-info modules analysed by the cells below.
import info.r068d1 as r068d1
import info.r063d4 as r063d4
infos =[r068d1, r063d4]

# Alternative session list; currently unused because the reassignment below is
# commented out.
from run import spike_sorted_infos
# infos = spike_sorted_infos

In [None]:
def get_perievent(epochs, event, t_bins, t_before, t_after, binsize, std):
    """Count and smooth the epochs occurring around a single event.

    Parameters
    ----------
    epochs : object providing ``overlaps`` and ``centers``
        Presumably a ``nept.Epoch`` of detected SWRs -- confirm against callers.
    event : float
        Time of the aligning event.
    t_bins : np.array
        Bin edges handed to ``np.histogram``.
    t_before, t_after : float
        Extent of the window kept before / after ``event``.
    binsize : float
        Passed to ``nept.gaussian_filter`` as ``dt``.
    std : float
        Passed to ``nept.gaussian_filter`` as ``std``.

    Returns
    -------
    nept.AnalogSignal
        Smoothed counts whose time axis is ``t_bins[:-1] - event``.

    """
    # Keep only the epochs that overlap the perievent window
    window = nept.Epoch([event-t_before, event+t_after])
    in_window = epochs.overlaps(window)

    # Histogram the epoch centers, then smooth the counts
    counts = np.histogram(in_window.centers, bins=t_bins)[0]
    smoothed = nept.gaussian_filter(counts, std=std, dt=binsize, axis=0)

    # Express time relative to the event, one sample per bin start
    relative_time = t_bins[:-1]-event
    return nept.AnalogSignal(smoothed, relative_time)

In [None]:
def plot_binned_swr(info, binned_swr, filepath):
    """Plot a binned SWR signal with the start of every task phase marked.

    Parameters
    ----------
    info : module
        Session info; ``info.task_times[phase].start`` is read for each phase.
    binned_swr : object with ``time`` and ``data``
        Signal to draw.
    filepath : str
        Only referenced by the commented-out ``savefig`` call below, so it is
        currently unused.

    """
    plt.plot(binned_swr.time, binned_swr.data, ms=3)

    xtick_labels = ["prerecord", "phase1", "pauseA", "phase2", "pauseB", "phase3", "postrecord"]
    xtick_location = [info.task_times[phase].start for phase in xtick_labels]

    # One dashed vertical line at the start of each phase
    for phase_start in xtick_location:
        plt.axvline(x=phase_start, color="k", linestyle='--', ms=3)

    plt.xticks(xtick_location, xtick_labels, rotation=75)
    plt.tight_layout()
#     plt.savefig(os.path.join(filepath, info.session_id+"-binned_swr.png"))
#     plt.close()
    plt.show()

In [None]:
def find_means(list_of_analogsignals):
    """Average several signals sample-by-sample.

    Parameters
    ----------
    list_of_analogsignals : list
        Objects with ``time`` and ``data``; every ``data`` must squeeze to the
        length of the first signal's ``time``.

    Returns
    -------
    nept.AnalogSignal
        Mean across signals, on the time axis of the first signal.

    """
    times = list_of_analogsignals[0].time
    n_signals = len(list_of_analogsignals)

    # One column per signal
    stacked = np.zeros((len(times), n_signals))
    for column, signal in enumerate(list_of_analogsignals):
        stacked[:, column] = np.squeeze(signal.data)

    return nept.AnalogSignal(np.mean(stacked, axis=1), times)


def plot_perievent(perievent, speed, title):
    """Draw a perievent signal above the matching speed trace.

    Parameters
    ----------
    perievent : object with ``time`` and ``data``
        Drawn in blue on the top panel.
    speed : object with ``time`` and ``data``
        Drawn in green on the bottom panel.
    title : str
        Figure-level title (also the file stem in the commented-out savefig).

    """
    fig, (ax1, ax2) = plt.subplots(2, sharex=True)

    # Same treatment for both panels: the trace plus a dashed line at t=0
    for ax, signal, color in ((ax1, perievent, "b"), (ax2, speed, "g")):
        ax.plot(signal.time, signal.data, color=color)
        ax.axvline(x=0, color="k", linestyle='--')

    # plt.title applies to the current axes
    plt.title("Speed", fontsize=16)
    fig.suptitle(title, fontsize=16)
    plt.tight_layout()
#     plt.savefig(os.path.join(output_filepath, title+".png"))
#     plt.close()
    plt.show()

In [None]:
# Per-session accumulators (one entry appended per session in `infos`):
# smoothed SWR counts and speed aligned to the start and stop of the first
# shortcut trial and of the first U trial.
shortcut_perievents = []
u_perievents = []
u_end_perievents = []
shortcut_end_perievents = []
shortcut_speed = []
shortcut_end_speed = []
u_speed = []
u_end_speed = []

# Extent of the perievent window kept before / after each aligning event
t_before = 10
t_after = 20

# SWR detection parameters, identical for every session
z_thresh = 2.0
power_thresh = 3.0
merge_thresh = 0.02
min_length = 0.05


def find_perievent(epoch, position, time_of_interest, binsize=1., std=1.5, t_smooth=1.):
    """Return the perievent epoch counts and speed around ``time_of_interest``.

    The window is taken from the notebook-level ``t_before`` / ``t_after``.

    Parameters
    ----------
    epoch : object accepted by ``get_perievent`` (the SWR epochs here)
    position : object with ``time_slice``, ``time``, ``x`` and ``y``
    time_of_interest : float
        Aligning event time.
    binsize : float
        Bin width used for the counts and for resampling position.
    std : float
        Smoothing std handed to ``get_perievent``.
    t_smooth : float
        Passed to ``nept.Position.speed``.

    Returns
    -------
    (epoch_perievent, speed_perievent)
        Both have time expressed relative to ``time_of_interest``.

    """
    t_bins = np.arange(time_of_interest-t_before, time_of_interest+t_after+binsize, binsize)
    epoch_perievent = get_perievent(epoch, time_of_interest, t_bins,
                                    t_before, t_after, binsize, std)

    # Resample position at the bin starts so speed shares the counts' time base
    perievent_time = t_bins[:-1]
    sliced_position = position.time_slice(time_of_interest-t_before, time_of_interest+t_after)
    perievent_x = np.interp(perievent_time, sliced_position.time, sliced_position.x)
    perievent_y = np.interp(perievent_time, sliced_position.time, sliced_position.y)
    perievent_position = nept.Position(np.hstack((perievent_x[..., np.newaxis],
                                                  perievent_y[..., np.newaxis])),
                                       perievent_time-time_of_interest)
    speed_perievent = perievent_position.speed(t_smooth=t_smooth)

    return epoch_perievent, speed_perievent


for info in infos:
    print(info.session_id)
    events, position, spikes, lfp, _ = get_data(info)

    # Find SWRs for the whole session
    swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                   power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)

    # Restrict SWRs to those with 4 or more participating neurons
    swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=4)

    # Find rest epochs for entire session
    epochs_of_interest = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

    # Keep the SWRs that fall in rest epochs and are still long enough
    swrs = epochs_of_interest.overlaps(swrs)
    swrs = swrs[swrs.durations >= min_length]

    # Get times of interest
    t_start = info.task_times['phase3'].start
    t_stop = info.task_times['phase3'].stop

    sliced_pos = position.time_slice(t_start, t_stop)

    # Feeder events strictly inside phase 3
    feeder1_times = [feeder1 for feeder1 in events['feeder1'] if t_start < feeder1 < t_stop]
    feeder2_times = [feeder2 for feeder2 in events['feeder2'] if t_start < feeder2 < t_stop]

    path_pos = get_zones(info, sliced_pos)

    trials_idx, trial_epochs = get_trial_idx(path_pos['u'].time,
                                             path_pos['shortcut'].time,
                                             path_pos['novel'].time,
                                             feeder1_times,
                                             feeder2_times,
                                             t_stop)

    # Start / stop of the first shortcut trial and of the first U trial
    first_shortcut = trial_epochs[trials_idx['shortcut'][0][0]].start
    shortcut_end = trial_epochs[trials_idx['shortcut'][0][0]].stop
    first_u = trial_epochs[trials_idx['u'][0][0]].start
    u_end = trial_epochs[trials_idx['u'][0][0]].stop

    # Each stanza must align to its own event; previously all four were
    # aligned to first_shortcut, which made the four averages identical.

    # Shortcut start
    epoch_perievent, speed_perievent = find_perievent(swrs, position, first_shortcut)
    shortcut_perievents.append(epoch_perievent)
    shortcut_speed.append(speed_perievent)

    # Shortcut end
    epoch_perievent, speed_perievent = find_perievent(swrs, position, shortcut_end)
    shortcut_end_perievents.append(epoch_perievent)
    shortcut_end_speed.append(speed_perievent)

    # U start
    epoch_perievent, speed_perievent = find_perievent(swrs, position, first_u)
    u_perievents.append(epoch_perievent)
    u_speed.append(speed_perievent)

    # U end
    epoch_perievent, speed_perievent = find_perievent(swrs, position, u_end)
    u_end_perievents.append(epoch_perievent)
    u_end_speed.append(speed_perievent)

# Average across sessions and plot each alignment.
# NOTE: shortcut_end / u_end are rebound here from event times to averaged
# signals; the names are kept because later cells reassign them anyway.
shortcut_start = find_means(shortcut_perievents)
shortcut_speed_start = find_means(shortcut_speed)
plot_perievent(shortcut_start, shortcut_speed_start, "SWR perievent first shortcut start")

shortcut_end = find_means(shortcut_end_perievents)
shortcut_speed_end = find_means(shortcut_end_speed)
plot_perievent(shortcut_end, shortcut_speed_end, "SWR perievent first shortcut end")

u_start = find_means(u_perievents)
u_speed_start = find_means(u_speed)
plot_perievent(u_start, u_speed_start, "SWR perievent first U start")

u_end = find_means(u_end_perievents)
u_speed_end = find_means(u_end_speed)
plot_perievent(u_end, u_speed_end, "SWR perievent first U end")

In [None]:
# Single-session version of the SWR detection. `info` is whatever the loop
# in the previous cell left bound (its last session).
print(info.session_id)
events, position, spikes, lfp, _ = get_data(info)

# Remove interneurons: flag every unit whose mean rate over the session
# reaches max_mean_firing
max_mean_firing = 5
interneurons = np.array([len(spike.time) / info.session_length >= max_mean_firing
                         for spike in spikes], dtype=bool)
spikes = spikes[~interneurons]

# Find SWRs for the whole session, restricted to those with 4 or more
# participating neurons
z_thresh = 2.0
power_thresh = 3.0
merge_thresh = 0.02
min_length = 0.05
swrs = nept.find_multi_in_epochs(
    spikes,
    nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                            power_thresh=power_thresh, merge_thresh=merge_thresh,
                            min_length=min_length),
    min_involved=4,
)

# Find rest epochs for entire session
epochs_of_interest = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

# Keep the SWRs inside rest epochs that still last at least min_length
swrs = epochs_of_interest.overlaps(swrs)
swrs = swrs[swrs.durations >= min_length]

In [None]:
# Get times of interest: everything below is limited to phase 3
phase3 = info.task_times['phase3']
t_start, t_stop = phase3.start, phase3.stop

sliced_pos = position.time_slice(t_start, t_stop)

# Feeder events strictly inside phase 3
feeder1_times = [feeder1 for feeder1 in events['feeder1'] if t_start < feeder1 < t_stop]
feeder2_times = [feeder2 for feeder2 in events['feeder2'] if t_start < feeder2 < t_stop]

path_pos = get_zones(info, sliced_pos)

trials_idx, trial_epochs = get_trial_idx(path_pos['u'].time,
                                         path_pos['shortcut'].time,
                                         path_pos['novel'].time,
                                         feeder1_times,
                                         feeder2_times,
                                         t_stop)

# Start / stop of the first shortcut trial and of the first U trial
first_shortcut_trial = trial_epochs[trials_idx['shortcut'][0][0]]
first_u_trial = trial_epochs[trials_idx['u'][0][0]]
first_shortcut = first_shortcut_trial.start
shortcut_end = first_shortcut_trial.stop
first_u = first_u_trial.start
u_end = first_u_trial.stop

In [None]:
# Quick check: one dot per SWR (at its center time) in the window around the
# first shortcut trial start, which is marked with a dashed line.
plt.plot(swrs.centers, np.ones(swrs.n_epochs), ".")
plt.axvline(first_shortcut, color="k", linestyle="--")
plt.xlim(first_shortcut-t_before, first_shortcut+t_after)
plt.show()

In [None]:
# Bin width and smoothing std handed to get_perievent
# (presumably both in seconds -- TODO confirm).
binsize = 1.
std = 1.5
event = first_shortcut

# Bin edges from event - t_before up to event + t_after in steps of binsize
t_bins = np.arange(event-t_before, event+t_after+binsize, binsize)

# First shortcut start
shortcut_perievent = get_perievent(swrs, first_shortcut, t_bins, t_before, t_after, binsize=binsize, std=std)

In [None]:
# Speed around the first shortcut start, computed from the position samples in
# the window with time re-expressed relative to the event.
# NOTE(review): this rebinds `shortcut_speed` (a list in the session loop) and
# is itself recomputed from bin-aligned positions two cells below.
sliced_position = position.time_slice(first_shortcut-t_before, first_shortcut+t_after)
shortcut_position = nept.Position(sliced_position.data, sliced_position.time-first_shortcut)
shortcut_speed = shortcut_position.speed(t_smooth=1.)

In [None]:
# Position samples in the perievent window (same statement as the first line
# of the neighbouring cells).
sliced_position = position.time_slice(first_shortcut-t_before, first_shortcut+t_after)


In [None]:
# Speed on the same time base as the binned SWRs: interpolate x and y at the
# bin starts, rebuild a Position relative to the event, then take its speed.
sliced_position = position.time_slice(first_shortcut-t_before, first_shortcut+t_after)
perievent_time = t_bins[:-1]
perievent_x = np.interp(perievent_time, sliced_position.time, sliced_position.x)
perievent_y = np.interp(perievent_time, sliced_position.time, sliced_position.y)
perievent_xy = np.column_stack((perievent_x, perievent_y))
perievent_position = nept.Position(perievent_xy, perievent_time-first_shortcut)
shortcut_speed = perievent_position.speed(t_smooth=1.)

In [None]:
# Inspect the bin-start times used as the perievent time base.
perievent_time

In [None]:
# Interpolate x and y onto `decoded.time` and wrap the result in a Position.
# NOTE(review): `decoded` and `exp_position` are not defined anywhere in the
# visible notebook, so this cell raises NameError on a fresh kernel. It looks
# like a reference snippet for the interpolation pattern used in the cell
# above -- confirm and remove or define its inputs.
actual_x = np.interp(decoded.time, exp_position.time, exp_position.x)
actual_y = np.interp(decoded.time, exp_position.time, exp_position.y)
actual_position = nept.Position(np.hstack((actual_x[..., np.newaxis],
                                           actual_y[..., np.newaxis])), decoded.time)

In [None]:
# Smoothed SWR counts around the first shortcut start; dashed line at the event.
plt.plot(shortcut_perievent.time, shortcut_perievent.data)
plt.axvline(0, color="k", linestyle="--")
plt.show()

In [None]:
# Two-panel SWR / speed figure for this session (placeholder title).
plot_perievent(shortcut_perievent, shortcut_speed, " d")

In [None]:
# Bin the swrs over the full LFP time range.
# NOTE: this rebinds `t_bins`, replacing the perievent bin edges set earlier.
binsize = 1.
t_bins = np.arange(lfp.time.min(), lfp.time.max()+binsize, binsize)
swr_counts = np.histogram(swrs.centers, bins=t_bins)[0]

# Smooth the binned swrs with a gaussian filter
std = 1.5
filter_swr = nept.gaussian_filter(swr_counts, std=std, dt=binsize, axis=0)
# One sample per bin, timestamped at the bin start
smoothed_swrs = nept.AnalogSignal(filter_swr, t_bins[:-1])

In [None]:
# Plot the session-wide smoothed SWR counts with the task phase starts marked
plot_binned_swr(info, smoothed_swrs, output_filepath)

In [None]:
# Widen the perievent window for the cells below.
# NOTE: these are the same globals read by find_perievent, so re-running the
# session loop after this cell uses the wider window.
t_before = 50
t_after = 50

In [None]:
# Inspect the start time of the first shortcut trial.
first_shortcut 

In [None]:
# Number of bins that fit in the pre-event window: window length divided by
# bin width (multiplying is only right when binsize == 1).
n_before = int(round(t_before / binsize))

In [None]:
# Inspect the value computed in the previous cell.
n_before

In [None]:
# Inspect the session-wide smoothed SWR signal.
smoothed_swrs

In [None]:
# Cut the session-wide smoothed SWR signal around the first shortcut start.
# On a top-to-bottom run `align_to_event` is the one imported from utils_maze;
# a debug redefinition appears further down the notebook.
shortcut_perievent = align_to_event(smoothed_swrs, first_shortcut, t_before, t_after)
shortcut_perievent.time, shortcut_perievent.data.shape

In [None]:
# get_perievent requires explicit bin edges plus the bin size and smoothing
# std; build edges for the current window instead of omitting them (the
# previous 4-argument call raised TypeError).
testing_bins = np.arange(first_shortcut-t_before, first_shortcut+t_after+binsize, binsize)
testing = get_perievent(swrs, first_shortcut, testing_bins, t_before, t_after, binsize, std)
testing.time, testing.data.shape

In [None]:
# Align position to the first shortcut start, rewrap the result as a Position
# and compute its speed.
shortcut_position = align_to_event(position, first_shortcut, t_before, t_after)
shortcut_position = nept.Position(shortcut_position.data, shortcut_position.time)
shortcut_speed = shortcut_position.speed(t_smooth=1.)

In [None]:
# Two-panel SWR / speed figure for the aligned signals (placeholder title).
plot_perievent(shortcut_perievent, shortcut_speed, "testing123")

In [None]:
# Shape of the spikes container before the filtering cell below.
spikes.shape

In [None]:
# Session duration: from the start of prerecord to the end of postrecord.
session_length = info.task_times["postrecord"].stop - info.task_times["prerecord"].start

In [None]:
# Inspect the session duration.
session_length

In [None]:
# Remove neurons that have a mean rate of max_rate (Hz) or more over the session
max_rate = 5.
session_length = info.task_times["postrecord"].stop - info.task_times["prerecord"].start
spikes = [spiketrain for spiketrain in spikes if spiketrain.n_spikes/session_length < max_rate]

# Keep only neurons with more than min_spikes spikes in the session
# (the threshold now comes from the constant rather than a repeated literal)
min_spikes = 100
spikes = np.asarray([spiketrain for spiketrain in spikes if spiketrain.n_spikes > min_spikes])

In [None]:
# Shape of the spikes array after filtering.
spikes.shape

In [None]:
# Spike count of the first remaining unit.
spikes[0].n_spikes

In [None]:
def align_to_event(analogsignal, event, t_before, t_after):
    """Cut ``analogsignal`` around ``event`` and re-reference time to it.

    Debugging variant: it prints the number of samples actually found in the
    window and the number expected from the median sampling interval.
    NOTE: defining this shadows the ``align_to_event`` imported from
    ``utils_maze`` at the top of the notebook.

    Parameters
    ----------
    analogsignal : object with ``time``, ``data`` and ``time_slice``
    event : float
        Requested event time; the nearest sample time is used instead.
    t_before, t_after : float
        Extent of the window before / after the event.

    Returns
    -------
    nept.AnalogSignal
        Squeezed data in the window, with time relative to the event sample.

    """
    # Snap the event onto the closest existing sample
    nearest = nept.find_nearest_idx(analogsignal.time, event)
    event_of_interest = analogsignal.time[nearest]

    window = analogsignal.time_slice(event_of_interest - t_before, event_of_interest + t_after)

    # Samples actually present in the window
    print(len(window.time))

    relative_time = window.time - event_of_interest
    values = np.squeeze(window.data)

    # Samples expected from the median sampling interval
    dt = np.median(np.diff(analogsignal.time))
    print(len(np.arange(-t_before, t_after, dt)))

    return nept.AnalogSignal(values, relative_time)

In [None]:
# Align position to the first shortcut start using the debug align_to_event
# defined above, then report how many time samples came back.
shortcut_position = align_to_event(position, first_shortcut, t_before, t_after)
shortcut_position = nept.Position(shortcut_position.data, shortcut_position.time)
print(shortcut_position.time.shape)

In [None]:
# Median interval between position samples, and the first ten points of the
# regular time grid it implies for the window.
dt = np.median(np.diff(position.time))
np.arange(-t_before, t_after, dt)[:10]

In [None]:
# Inspect the start time of the first shortcut trial.
first_shortcut

In [None]:
# Inspect the current perievent window.
t_before, t_after

In [None]:
# Position sample closest to the first shortcut start, shown with 10 samples
# before and 20 after it. The slice is derived from `idx` so it follows the
# current session instead of using fixed sample numbers.
idx = nept.find_nearest_idx(position.time, first_shortcut)
position.time[max(idx-10, 0):idx+20]

In [None]:
# Sample closest to 10 time units before the event sample. The event sample is
# looked up for the current session rather than taken from a fixed index, and
# kept under its own name so re-running this cell does not shift the result.
event_idx = nept.find_nearest_idx(position.time, first_shortcut)
idx = nept.find_nearest_idx(position.time, position.time[event_idx] - 10)
position.time[idx]

In [None]:
# `event_of_interest` was only assigned in a later cell, so this cell failed
# on a fresh top-to-bottom run; define it here the same way that cell does.
event_of_interest = position.time[idx]

# Boolean mask of the position samples inside the window (both ends inclusive)
mask = (position.time >= event_of_interest-t_before) & (position.time <= event_of_interest+t_after)

In [None]:
# Apply the window mask from the previous cell to the position times and data.
times = position.time[mask]
data = position.data[mask]

In [None]:
# Number of position samples selected by the mask.
len(times)

In [None]:
# Use the sample time at `idx` as the event and slice position around it.
# NOTE(review): `idx` is whatever the most recently executed cell left bound --
# confirm it is the intended sample.
event_of_interest = position.time[idx]

sliced = position.time_slice(event_of_interest - t_before, event_of_interest + t_after)

In [None]:
# Inspect the event sample time.
event_of_interest

In [None]:
# Inspect the sample times inside the window.
sliced.time

In [None]:
# Step-by-step version of align_to_event on the position data, to compare the
# number of samples found in the window with the number a regular grid needs.
analogsignal = position
event = first_shortcut

# Snap the event onto the closest existing sample
idx = nept.find_nearest_idx(analogsignal.time, event)
event_of_interest = analogsignal.time[idx]

sliced = analogsignal.time_slice(event_of_interest - t_before, event_of_interest + t_after)

# Regular time grid built from the median sampling interval:
# int(t_before/dt) points from -t_before to 0, then int(t_after/dt) points
# from dt to t_after
dt = np.median(np.diff(analogsignal.time))
newtime = np.concatenate([np.linspace(-t_before, 0, num=int(t_before/dt)), 
                          np.linspace(0+dt, t_after, num=int(t_after/dt))])

# Fewer samples in the window than the grid needs
if sliced.n_samples < len(newtime):
    print("need to fix this")

In [None]:
# Expected number of samples after the event minus the number actually present.
n_stop_missing = int(t_after/dt) - sliced[event_of_interest < sliced.time].n_samples

In [None]:
# Expected number of samples before the event minus the number actually present.
n_start_missing = int(t_before/dt) - sliced[event_of_interest > sliced.time].n_samples

In [None]:
# Inspect how many samples are missing on each side of the event.
n_start_missing, n_stop_missing

In [None]:
# Zero-pad the missing samples along the sample axis (axis 0). np.insert
# without an axis flattens multi-column data, which mixed the columns together
# and made len(newdata) count scalars rather than samples.
pad_start = np.zeros((n_start_missing,) + sliced.data.shape[1:])
pad_stop = np.zeros((n_stop_missing,) + sliced.data.shape[1:])

newdata = np.concatenate([pad_start, sliced.data], axis=0)
print(len(newdata))
newdata = np.concatenate([newdata, pad_stop], axis=0)

In [None]:
# Number of samples in the window before padding.
len(sliced.data)

In [None]:
# Length of the padded data.
len(newdata)