In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

In [None]:
def speed_threshold(position, speed_thresh):
    """Finds times where position is above a certain speed threshold

    Parameters
    ----------
    position: nept.Position
    speed_thresh: float
        Samples whose speed is greater than or equal to this value count
        as running.

    Returns
    -------
    run_epoch: nept.Epoch
        One epoch per contiguous stretch of above-threshold samples, passed
        through ``merge(gap=0.0)``.

    Notes
    -----
    Crossings are located with ``np.diff`` on the boolean mask, so crossing
    index ``i`` lies between samples ``i`` and ``i + 1`` and the epoch edge
    is reported as ``position.time[i]``. A stretch already in progress at
    the first sample starts at ``position.time[0]``; a stretch still in
    progress at the last sample stops at ``position.time[-1]``.

    """
    speed = position.speed()
    above = np.atleast_1d(np.squeeze(speed.data)) >= speed_thresh
    crossings = np.where(np.diff(above))[0]

    # np.diff only reports that the mask changed, not in which direction.
    # When the first sample is already above threshold the first crossing is
    # a stop, so the start of that initial stretch is added explicitly;
    # without it every start/stop pair below would be swapped.
    if above.size > 0 and above[0]:
        crossings = np.hstack([[0], crossings])

    starts = position.time[crossings[::2]]
    stops = position.time[crossings[1::2]]

    # Crossings now alternate start, stop, start, ... so at most the final
    # stop can be missing: the stretch was still going at the last sample.
    if len(starts) > len(stops):
        stops = np.hstack([stops, position.time[-1]])

    # Discard zero-length epochs (e.g. a start that coincides with its stop).
    keep = stops > starts
    starts, stops = starts[keep], stops[keep]

    return nept.Epoch(starts, stops-starts).merge(gap=0.0)

In [None]:
def rest_threshold(position, rest_thresh):
    """Finds times where position is below a certain rest threshold

    Parameters
    ----------
    position: nept.Position
    rest_thresh: float
        Samples whose speed is less than or equal to this value count as
        resting.

    Returns
    -------
    rest_epoch: nept.Epoch
        One epoch per contiguous stretch of resting samples, passed through
        ``merge(gap=0.0)``.

    Notes
    -----
    Crossings are located with ``np.diff`` on the boolean mask, so crossing
    index ``i`` lies between samples ``i`` and ``i + 1`` and the epoch edge
    is reported as ``position.time[i]``. A rest already in progress at the
    first sample starts at ``position.time[0]``; a rest still in progress at
    the last sample stops at ``position.time[-1]``.

    """
    speed = position.speed()
    resting = np.atleast_1d(np.squeeze(speed.data)) <= rest_thresh
    crossings = np.where(np.diff(resting))[0]

    # np.diff only reports that the mask changed, not in which direction.
    # A trace that begins at rest has a *stop* as its first crossing, so the
    # start of that initial rest is added explicitly; without it the epochs
    # returned would be the moving stretches instead of the resting ones.
    if resting.size > 0 and resting[0]:
        crossings = np.hstack([[0], crossings])

    starts = position.time[crossings[::2]]
    stops = position.time[crossings[1::2]]

    # Crossings now alternate start, stop, start, ... so at most the final
    # stop can be missing: the rest was still going at the last sample.
    if len(starts) > len(stops):
        stops = np.hstack([stops, position.time[-1]])

    # Discard zero-length epochs (e.g. a start that coincides with its stop).
    keep = stops > starts
    starts, stops = starts[keep], stops[keep]

    return nept.Epoch(starts, stops-starts).merge(gap=0.0)

In [None]:
# Small hand-made example: six samples of a one-dimensional position,
# sampled at times 0.0 through 5.0.
times = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
data = np.array([0.0, 0.5, 1.0, 0.7, 1.7, 2.0])

position = nept.Position(data, times)

# Split the example into running and resting epochs using the two
# threshold functions defined earlier in the notebook.
run_epoch = speed_threshold(position, speed_thresh=0.5)
rest_epoch = rest_threshold(position, rest_thresh=0.6)

# NOTE(review): these checks are disabled. The expected stop of 6. is not
# one of the sample times above (they end at 5.0), so the expected values
# need to be re-derived before the assertions are re-enabled.
# assert np.allclose(run_epoch.starts, np.array([2., 5.]))
# assert np.allclose(run_epoch.stops, np.array([3., 6.]))

In [None]:
# Display the start and stop times of the detected run epochs.
run_epoch.starts, run_epoch.stops

In [None]:
# Display the start and stop times of the detected rest epochs.
rest_epoch.starts, rest_epoch.stops

In [None]:
# Compute the speed trace of the example position.
# NOTE(review): the following cell recomputes the same value before
# plotting, so this cell looks redundant.
speed = position.speed()

In [None]:
# Recompute the speed trace and draw it against time as black dots.
speed = position.speed()

fig, ax = plt.subplots()
ax.plot(speed.time, speed.data, "k.")
plt.show()

In [None]:
# Display the raw values of the speed trace.
speed.data

In [None]:
def rest_thresh(position, thresh):
    """Finds epochs where speed is at or below ``thresh``.

    Thin wrapper that forwards to ``speed_thresh`` with
    ``direction="lesser"``.

    Parameters
    ----------
    position: nept.Position
    thresh: float

    Returns
    -------
    rest_epoch: nept.Epoch

    """
    rest_epoch = speed_thresh(position=position, thresh=thresh, direction="lesser")
    return rest_epoch
   
def run_thresh(position, thresh):
    """Finds epochs where speed is at or above ``thresh``.

    Thin wrapper that forwards to ``speed_thresh`` with
    ``direction="greater"``.

    Parameters
    ----------
    position: nept.Position
    thresh: float

    Returns
    -------
    run_epoch: nept.Epoch

    """
    run_epoch = speed_thresh(position=position, thresh=thresh, direction="greater")
    return run_epoch
   
def speed_thresh(position, thresh, direction):
    """Finds epochs where speed lies on one side of a threshold.

    Parameters
    ----------
    position: nept.Position
    thresh: float
    direction: str
        ``"lesser"`` selects samples with speed <= thresh,
        ``"greater"`` selects samples with speed >= thresh.

    Returns
    -------
    epoch: nept.Epoch
        Built from an (n_epochs, 2) array of start and stop times. Each
        epoch starts at the time of the first selected sample of a stretch
        and stops at the time of the first sample after it that is not
        selected; a stretch that reaches the end of the data stops at the
        time of the last sample. Zero-length epochs are dropped, and the
        array has zero rows when no sample is selected.

    Raises
    ------
    ValueError
        If ``direction`` is neither ``"lesser"`` nor ``"greater"``.

    """
    speed = position.speed()
    values = np.atleast_1d(np.squeeze(speed.data))

    if direction == "lesser":
        selected = values <= thresh
    elif direction == "greater":
        selected = values >= thresh
    else:
        raise ValueError("Must be 'lesser' or 'greater'")

    # Prepending 0 makes a stretch that is already in progress at the first
    # sample show up as a +1 change, so starts and stops stay paired.
    changes = np.diff(np.hstack(([0], selected.astype(int))))
    starts = np.where(changes == 1)[0]
    stops = np.where(changes == -1)[0]

    # Starts and stops alternate beginning with a start, so the only
    # possible mismatch is one missing stop: close it at the last sample.
    if len(starts) > len(stops):
        stops = np.hstack((stops, position.n_samples - 1))

    # Drop zero-length epochs. Only the closing epoch added above can be
    # zero-length, and unlike indexing with [-1] this mask is also safe
    # when no epochs were found at all.
    keep = stops > starts
    starts, stops = starts[keep], stops[keep]

    data = np.vstack([position.time[starts], position.time[stops]]).T
    return nept.Epoch(data)