In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pickle
import os
import nept

from loading_data import get_data

In [None]:
import info.r066d1 as info

In [None]:
# Root directory that contains code/emi_shortcut. Defaults to the original
# hard-coded location "E:/"; set the EMI_SHORTCUT_HOME environment variable
# (e.g. to os.path.expanduser("~")) to run on another machine without editing
# the notebook.
home = os.environ.get("EMI_SHORTCUT_HOME", "E:/")
emi_shortcut = os.path.join(home, "code", "emi_shortcut")
# Cached pickles (e.g. <session_id>_neurons.pkl) are read from here.
pickle_filepath = os.path.join(emi_shortcut, "cache", "pickled")
output_filepath = os.path.join(emi_shortcut, "plots")

In [None]:
# Load every data stream for the session described by `info`.
session_data = get_data(info)
events, position, spikes, lfp, lfp_theta = session_data
# x/y edges computed by nept from the tracked position (presumably spatial
# bin edges — see nept.get_xyedges).
xedges, yedges = nept.get_xyedges(position)

In [None]:
neurons_filename = info.session_id + '_neurons.pkl'
pickled_neurons = os.path.join(pickle_filepath, neurons_filename)
# NOTE: pickle.load can execute arbitrary code. This cache is presumably
# written by this project's own pipeline; never point it at untrusted files.
with open(pickled_neurons, mode='rb') as neurons_file:
    neurons = pickle.load(neurons_file)

In [None]:
# Task phase to analyse and its start/stop times from the session info.
experiment_time = 'phase3'
phase_window = info.task_times[experiment_time]
t_start, t_stop = phase_window.start, phase_window.stop

# Keyword arguments forwarded to analyze_decode.get_decoded.
args = {
    'min_swr': 3,
    'min_neurons': 2,
    'min_spikes': 2,
    't_start': t_start,
    't_stop': t_stop,
    'neurons': neurons,
    'info': info,
    'normalized': False,
    'sequence_speed': 10.,
    'sequence_len': 4,
    'min_epochs': 3,
    'window': 0.025,
    'dt': 0.025,
    'gaussian_std': 0.,
    'shuffle_id': False,
    'run_time': True,
}

In [None]:
from analyze_decode import get_decoded

In [None]:
# Run the decoder with the parameters assembled in `args`. The four outputs are
# named after what analyze_decode.get_decoded presumably returns (confirm there):
# decoded positions, the decoded epochs kept for analysis, decoding errors, and
# the matching actual positions.
decoded, decoded_epochs, errors, actual_position = get_decoded(**args)

In [None]:
# Tracked position restricted to the experiment_time window (t_start..t_stop);
# drawn as the background track in the trajectory plots below.
phase3 = position.time_slice(t_start, t_stop)

In [None]:
# Inspect the timestamps covered by the phase3 slice (e.g. to pick a window).
phase3.time

In [None]:
# Hand-picked window, in the same time units as position.time. An alternative
# is to use the start/stop of the longest decoded epoch, i.e. the one whose
# duration equals np.max(decoded_epochs.durations).
start, stop = 25000., 25100.

pos = position.time_slice(start, stop)
checkit = decoded.time_slice(start, stop)

# One figure per series: the actual (pos), then the decoded (checkit)
# positions in the window, each drawn over the full phase3 track.
for window_positions in (pos, checkit):
    plt.plot(phase3.x, phase3.y, '.')
    plt.plot(window_positions.x, window_positions.y, '.')
    plt.show()

In [None]:
# Compare the y-coordinate over time for the actual (pos) and decoded
# (checkit) positions in the hand-picked window defined above.
plt.plot(pos.time, pos.y, '.', label='actual')
plt.plot(checkit.time, checkit.y, '.', label='decoded')
plt.xlabel('time')
plt.ylabel('y')
plt.legend()
# plt.show() renders the figure and keeps the Line2D list repr out of the output.
plt.show()

In [None]:
# Indices of decoded epochs whose duration is within `duration_margin` of the
# longest decoded epoch.
duration_margin = 7.
longest_duration = np.max(decoded_epochs.durations)
is_large = decoded_epochs.durations >= longest_duration - duration_margin
large_idx = np.where(is_large)[0]
print(len(large_idx))

In [None]:
def find_distance_adjacent(position):
    """Euclidean distance between each pair of consecutive samples.

    Parameters
    ----------
    position : object with a ``data`` attribute
        ``position.data`` is indexed as (sample, dimension), so it must have
        at least two axes.

    Returns
    -------
    np.ndarray
        1-D array of length ``len(position.data) - 1`` (empty when there are
        fewer than two samples). Element ``k`` is the distance between samples
        ``k`` and ``k + 1``.
    """
    samples = position.data
    before, after = samples[:-1], samples[1:]
    n_dims = before.shape[1]
    # Accumulate squared per-dimension differences onto a float zero vector.
    squared = sum(((before[:, dim] - after[:, dim]) ** 2 for dim in range(n_dims)),
                  np.zeros(len(before)))
    return np.sqrt(squared)

In [None]:
def printthis(idx):
    """Print summary statistics for the decoded trajectory in one epoch.

    Reads the notebook globals ``decoded_epochs`` and ``decoded``. For epoch
    ``idx`` it prints the number of decoded samples, the mean and total
    distance between adjacent decoded samples, the epoch duration, and then a
    separator line.

    Parameters
    ----------
    idx : int
        Index into ``decoded_epochs``.
    """
    print('Idx:', idx)

    epoch = decoded_epochs[idx]
    trajectory = decoded.time_slice(epoch.start, epoch.stop)
    steps = find_distance_adjacent(trajectory)

    print('Number of points:', len(trajectory.data))
    # With fewer than two samples there are no steps. The mean would be
    # undefined, so it is skipped; the sum of an empty array is just 0.
    if len(steps):
        print('Mean distance between adjacent points:', np.mean(steps))
    print('Total distance:', np.sum(steps))
    print('Time duration:', epoch.durations[0])
    print(' ')

In [None]:
def find_idx(idx, n_points=10., max_dist=3., total_dist=20., t_duration=0.5):
    """Return ``idx`` if decoded epoch ``idx`` passes every trajectory criterion.

    The decoded positions inside ``decoded_epochs[idx]`` (notebook globals
    ``decoded_epochs`` and ``decoded``) are accepted when all of these hold:

    * more than ``n_points`` decoded samples,
    * mean distance between adjacent samples below ``max_dist``,
    * total path length (sum of adjacent distances) above ``total_dist``,
    * epoch duration above ``t_duration``.

    Parameters
    ----------
    idx : int
        Index into ``decoded_epochs``.
    n_points, max_dist, total_dist, t_duration : float
        Thresholds for the criteria above.

    Returns
    -------
    int or float
        ``idx`` when every criterion passes, otherwise ``np.nan``. Every
        rejection returns the same sentinel.
    """
    epoch = decoded_epochs[idx]
    trajectory = decoded.time_slice(epoch.start, epoch.stop)
    steps = find_distance_adjacent(trajectory)

    # Checks short-circuit in order. The empty-steps guard prevents np.mean([])
    # (a RuntimeWarning plus nan) when n_points < 1 allows a 1-sample epoch.
    passes = (len(trajectory.data) > n_points
              and len(steps) > 0
              and np.mean(steps) < max_dist
              and np.sum(steps) > total_dist
              and epoch.durations[0] > t_duration)
    return idx if passes else np.nan

In [None]:
# Collect the indices of decoded epochs accepted by find_idx. A rejection is
# signalled by None or np.nan, so both are filtered out.
passed = []
for epoch_idx in range(decoded_epochs.n_epochs):
    verdict = find_idx(epoch_idx)
    if verdict is None or np.isnan(verdict):
        continue
    passed.append(epoch_idx)

In [None]:
# Number of decoded epochs that passed every find_idx criterion.
len(passed)

In [None]:
def plot_epoch_overlay(epoch_idx, decoded_cmap='RdPu'):
    """Overlay the actual and decoded trajectories for one decoded epoch.

    The full ``phase3`` track is drawn as a pale background. Actual positions
    (grey shades) and decoded positions (``decoded_cmap`` shades) are drawn as
    hollow markers. Their colour runs from light to dark across each series,
    so the direction of travel can be read from the gradient.

    Parameters
    ----------
    epoch_idx : int
        Index into the notebook global ``decoded_epochs``.
    decoded_cmap : str
        Matplotlib colormap name for the decoded trajectory.
    """
    epoch = decoded_epochs[epoch_idx]
    decoded_epoch = decoded.time_slice(epoch.start, epoch.stop)
    actual_epoch = position.time_slice(epoch.start, epoch.stop)

    plt.plot(phase3.x, phase3.y, '.', color='#f7fbff')
    # Plot each series over all of its own samples. Decoded bins and tracked
    # positions need not have the same number of samples, and zipping them
    # together would silently drop the tail of the longer series.
    for series, cmap_name in ((actual_epoch, 'Greys'), (decoded_epoch, decoded_cmap)):
        colours = plt.get_cmap(cmap_name)(np.linspace(0.25, 0.75, series.n_samples))
        for x, y, colour in zip(series.x, series.y, colours):
            plt.plot(x, y, '.', ms=7, color=colour, markerfacecolor='none')
    plt.show()


for i in passed:
    print(i)
    plot_epoch_overlay(i)

In [None]:
# Inspect a single hand-picked decoded epoch in detail.
for i in [621]:
    print(i)

    start = decoded_epochs[i].start
    stop = decoded_epochs[i].stop

    checkit = decoded.time_slice(start, stop)
    pos = position.time_slice(start, stop)

    # The two series can differ in length; report both sample counts.
    print('decoded:', checkit.n_samples)
    print('actual:', pos.n_samples)

    plt.plot(phase3.x, phase3.y, '.', color='#f7fbff')
    # Draw each series over all of its own samples with its own colour
    # gradient. Zipping actual and decoded samples together would truncate
    # both to the shorter series.
    for series, cmap_name in ((pos, 'Greys'), (checkit, 'YlGn')):
        colours = plt.get_cmap(cmap_name)(np.linspace(0.25, 0.75, series.n_samples))
        for x, y, colour in zip(series.x, series.y, colours):
            plt.plot(x, y, '.', ms=7, color=colour, markerfacecolor='none')
    plt.show()