In [None]:
%matplotlib inline
import os
import warnings

import matplotlib.pyplot as plt
import nept
import numpy as np
import scipy
import scipy.signal

from loading_data import get_data
from utils_maze import get_trials

In [None]:
# All notebook paths are rooted at the current working directory.
thisdir = os.getcwd()
dataloc, pickle_filepath = (
    os.path.join(thisdir, "cache", subdir) for subdir in ("data", "pickled")
)
output_filepath = os.path.join(thisdir, "plots", "speed")

In [None]:
import info.r066d4 as info
# import info.r068d8 as info
from run import spike_sorted_infos

# Sessions processed by the batch plotting loops at the end of the notebook.
# Defaults to the single session selected above; use
# `infos = spike_sorted_infos` to plot every spike-sorted session instead.
infos = [info]

In [None]:
# Load the selected session. The last two outputs of get_data are not used here.
events, position, spikes, _, _ = get_data(info)
# Spatial bin edges with binsize=8 (presumably in the position data's units — confirm).
xedges, yedges = nept.get_xyedges(position, binsize=8)

In [None]:
trials = get_trials(events, info.task_times["phase3"])
# Only the first phase3 trial is inspected here, not every trial.
trial = trials[0]


def _plot_y_over_time(pos):
    """Scatter the y coordinate of `pos` against its timestamps and show the figure."""
    plt.plot(pos.time, pos.y, "k.", ms=4)
    plt.show()


# y trace for the whole trial.
sliced_position = position.time_slice(trial.start, trial.stop)
_plot_y_over_time(sliced_position)

# Restrict the trial to the run epochs found by nept.run_threshold.
run_epoch = nept.run_threshold(sliced_position, thresh=10., t_smooth=0.8)
sliced_position = sliced_position[run_epoch]
_plot_y_over_time(sliced_position)

In [None]:
# Display the stop time of the first phase3 trial (the trial sliced above).
trials[0].stop

In [None]:
def gaussian_filter(signal, std, dt=1.0, normalized=True, axis=-1, n_stds=3):
    """Filters a signal with a gaussian kernel.

    The kernel spans ``n_stds`` standard deviations on each side of its
    centre and is sampled at spacing ``dt``. Every 1-D slice along ``axis``
    is convolved with it using ``mode="same"``, so the output has the same
    shape as ``signal``; values beyond the edges are treated as zero.

    Parameters
    ----------
    signal : np.array
    std : float
        Standard deviation of the gaussian, in the same units as ``dt``.
    dt : float
        Sample spacing. Defaults to 1.0
    normalized : bool
        If True (default), the kernel is scaled to sum to 1.
    axis : int
        Axis along which to filter. Defaults to -1
    n_stds : int or float
        Half-width of the kernel in standard deviations. Defaults to 3.

    Returns
    -------
    Filtered signal (np.ndarray). If the kernel would be a single sample,
    a warning is issued and ``signal`` is returned unchanged.

    """
    # Kernel length in samples, covering +/- n_stds standard deviations.
    n_points = int(round((n_stds * std * 2) / dt))
    # Force an odd length so the kernel has a centre sample.
    if n_points % 2 == 0:
        n_points += 1
    if n_points <= 1:
        warnings.warn("std is too small for given dt. Signal is unchanged.", stacklevel=2)
        return signal

    # scipy.signal.gaussian was removed in SciPy 1.13; windows.gaussian is the
    # supported location of the same window.
    kernel = scipy.signal.windows.gaussian(n_points, std / dt)
    if normalized:
        kernel /= np.sum(kernel)

    return np.apply_along_axis(
        lambda v: scipy.signal.convolve(v, kernel, mode="same"), axis=axis, arr=signal)

In [None]:
# Small, hand-checkable input for gaussian_filter.
signal, std, dt = np.array([1., 3., 7.]), 1.0, 1.0

gaussian_filter(signal, std, dt=dt)

In [None]:
# gaussian_filter's kernel length n_stds * std * 2 / dt for n_stds=3, std=1, dt=1.
2 * 3 * 1 / 1

In [None]:
# numpy's "same" mode returns max(len(a), len(v)) samples, so the output here is
# as long as the longer (second) argument.
np.convolve([3, 4], [1, 1, 5, 5], "same")

In [None]:
# Hand-computed terms 4*a + 3*b for each adjacent pair (a, b) of [1, 1, 5, 5].
[4 * a + 3 * b for a, b in ((1, 1), (1, 5), (5, 5))]

In [None]:
def plot_run_thresh(info, thresh, t_smooth, filepath=None):
    """Plot x, y and speed around the start of phase3, shading run epochs.

    Parameters
    ----------
    info : session info object (e.g. an ``info.*`` module)
        Must provide ``session_id`` and ``task_times`` with "pauseB" and
        "phase3" entries.
    thresh : float
        Threshold passed to ``nept.run_threshold``; also drawn as a
        horizontal line.
    t_smooth : float
        Smoothing passed to ``nept.run_threshold``.
    filepath : str or None
        Directory to save the figure into (created if missing). When None,
        the figure is shown instead.
    """
    _, position, _, _, _ = get_data(info)
    # NOTE(review): the plotted speed is smoothed with t_smooth=1. regardless of
    # the t_smooth used for thresholding — confirm this is intended.
    speed = position.speed(t_smooth=1.)

    # Position samples from 30 s before the end of pauseB to the end of phase3.
    s = nept.find_nearest_idx(position.time, info.task_times["pauseB"].stop-30)
    e = nept.find_nearest_idx(position.time, info.task_times["phase3"].stop)

    runs = nept.run_threshold(position, thresh=thresh, t_smooth=t_smooth)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(position[s:e].time, position.x[s:e], label="X")
    ax.plot(position[s:e].time, position.y[s:e], label="Y")
    ax.plot(speed.time, speed.data, color="k", label="Speed")
    # Shade each run epoch up to the largest x/y value of the whole session.
    top = np.max([np.max(position.x), np.max(position.y)])
    for start, stop in zip(runs.starts, runs.stops):
        ax.fill_between([start, stop], top, color="k", alpha=0.2)
    ax.axhline(thresh, color="g", label="thresh")
    ax.set_xlim(info.task_times["phase3"].start+10, info.task_times["phase3"].start+150)
    # Labels are attached to the artists themselves: a bare list of strings is
    # matched to artists in insertion order, which would label the first
    # shaded epoch as "thresh" instead of the threshold line.
    ax.legend(bbox_to_anchor=(1.0, 1.0))

    fig.tight_layout()

    if filepath is not None:
        os.makedirs(filepath, exist_ok=True)
        filename = info.session_id+"-run_thresh-"+str(thresh)+"-t_smooth-"+str(t_smooth)+".png"
        fig.savefig(os.path.join(filepath, filename))
        plt.close(fig)
    else:
        plt.show()

In [None]:
def plot_rest_thresh(info, thresh, t_smooth, filepath=None):
    """Plot x, y and speed around the start of phase3, shading rest epochs.

    Parameters
    ----------
    info : session info object (e.g. an ``info.*`` module)
        Must provide ``session_id`` and ``task_times`` with "pauseB" and
        "phase3" entries.
    thresh : float
        Threshold passed to ``nept.rest_threshold``; also drawn as a
        horizontal line.
    t_smooth : float
        Smoothing passed to ``nept.rest_threshold``.
    filepath : str or None
        Directory to save the figure into (created if missing). When None,
        the figure is shown instead.
    """
    print(info.session_id)
    print("Thresh:", thresh)
    print("t_smooth:", t_smooth)
    _, position, _, _, _ = get_data(info)
    # NOTE(review): the plotted speed is smoothed with t_smooth=1. regardless of
    # the t_smooth used for thresholding — confirm this is intended.
    speed = position.speed(t_smooth=1.)

    # Position samples from 30 s before the end of pauseB to the end of phase3.
    s = nept.find_nearest_idx(position.time, info.task_times["pauseB"].stop-30)
    e = nept.find_nearest_idx(position.time, info.task_times["phase3"].stop)

    rests = nept.rest_threshold(position, thresh=thresh, t_smooth=t_smooth)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(position[s:e].time, position.x[s:e], label="X")
    ax.plot(position[s:e].time, position.y[s:e], label="Y")
    ax.plot(speed.time, speed.data, color="k", label="Speed")
    # Shade each rest epoch up to the largest x/y value of the whole session.
    top = np.max([np.max(position.x), np.max(position.y)])
    for start, stop in zip(rests.starts, rests.stops):
        ax.fill_between([start, stop], top, color="k", alpha=0.2)
    ax.axhline(thresh, color="g", label="thresh")
    ax.set_xlim(info.task_times["phase3"].start+10, info.task_times["phase3"].start+150)
    # Labels are attached to the artists themselves: a bare list of strings is
    # matched to artists in insertion order, which would label the first
    # shaded epoch as "thresh" instead of the threshold line.
    ax.legend(bbox_to_anchor=(1.0, 0.9))

    fig.tight_layout()

    if filepath is not None:
        os.makedirs(filepath, exist_ok=True)
        filename = info.session_id+"-rest_thresh-"+str(thresh)+"-t_smooth-"+str(t_smooth)+".png"
        fig.savefig(os.path.join(filepath, filename))
        plt.close(fig)
    else:
        plt.show()

In [None]:
# Save a rest-threshold figure for every session in `infos`. The loop uses its
# own name so the global `info` that the earlier cells rely on is not overwritten.
for session_info in infos:
    plot_rest_thresh(session_info, thresh=12., t_smooth=0.8, filepath=output_filepath)

In [None]:
# Save a run-threshold figure for every session in `infos`. The loop uses its
# own name so the global `info` that the earlier cells rely on is not overwritten.
for session_info in infos:
    plot_run_thresh(session_info, thresh=10., t_smooth=0.8, filepath=output_filepath)