In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

In [None]:
# Analysis configuration. window / dt / gaussian_std / normalized are set to the
# values the decoding actually uses: the bin_spikes cell further down re-assigns
# these same names to these same values, so keeping them in sync here avoids a
# config cell that advertises settings which never take effect.
min_neurons = 1        # min_neurons passed to nept.bayesian_prob
min_spikes = 1         # min_spikes passed to nept.bayesian_prob
t_start = 0            # start of the time_slice applied to neurons and position
t_stop = 100           # end of the time_slice applied to neurons and position
normalized = True      # passed to nept.bin_spikes
sequence_speed = 3.    # speed_thresh passed to nept.remove_teleports
sequence_len = 1.      # min_length passed to nept.remove_teleports
min_epochs = 3         # NOTE(review): not referenced by any cell visible here
window = 0.125         # spike-count window (bin_spikes and bayesian_prob)
dt = 0.125             # step between spike-count bins (bin_spikes)
gaussian_std = 0.1     # smoothing std passed to bin_spikes

In [None]:
# Position
test_time = np.arange(0, 10, 0.034)
test_x = np.linspace(0, 200, len(test_time))
test_y = np.linspace(0, 200, len(test_time))

test_data = np.array([test_x, test_y]).T
test_position = nept.Position(test_data, test_time)

In [None]:
# x and y are both linspace(0, 200) over the same samples, so the two traces
# coincide exactly; the dashed y trace and legend make that visible instead of
# the plot looking like a single line.
fig, ax = plt.subplots()
ax.plot(test_position.time, test_position.x, label='x')
ax.plot(test_position.time, test_position.y, '--', label='y')
ax.set(xlabel='time', ylabel='position', title='Synthetic position over time')
ax.legend()
plt.show()

In [None]:
# Spatial view of the synthetic trajectory; equal aspect so the diagonal is not distorted.
fig, ax = plt.subplots()
ax.plot(test_position.x, test_position.y, 'o')
ax.set(xlabel='x', ylabel='y', title='Synthetic trajectory')
ax.set_aspect('equal')
plt.show()

In [None]:
# Spatial bin edges for the 2D tuning curves (binsize=2, presumably in the same
# units as test_position — confirm in nept.get_xyedges).
xedges, yedges = nept.get_xyedges(test_position, binsize=2)

In [None]:
def make_spikes(low_val, high_val, position=None):
    """Return a SpikeTrain that fires at every position sample with low_val < y < high_val.

    Both bounds are exclusive, so samples lying exactly on a boundary produce no spike.

    Parameters
    ----------
    low_val, high_val : float
        Exclusive lower / upper bounds on the y coordinate.
    position : nept.Position, optional
        Trajectory to sample. Defaults to the notebook-level ``test_position``,
        looked up at call time (as before), so re-running the position cell is picked up.
    """
    if position is None:
        position = test_position
    in_band = np.logical_and(position.y > low_val, position.y < high_val)
    return nept.SpikeTrain(position.time[in_band])

In [None]:
# One synthetic neuron per 25-unit band of y, tiling [0, 200]. make_spikes uses
# exclusive bounds, so position samples lying exactly on a band edge get no spike.
test_spikes = np.array([make_spikes(low, low + 25) for low in range(0, 200, 25)])

# test_tuning_curves is computed from these spikes in a later cell (with
# tuning_curve_2d's default arguments); computing it here as well would be
# overwritten before use.

In [None]:
# Mean inter-spike interval of the first neuron. Since y increases monotonically,
# its spikes are consecutive position samples, so this should be close to the 0.034 step.
np.diff(test_spikes[0].time).mean()

In [None]:
# Number of synthetic neurons (one per y band).
len(test_spikes)

In [None]:
# Raster over the first 0.2 time units: each neuron's spikes are drawn as ticks on
# its own row (offset by 30 per neuron) on top of the y trace they were generated from.
# Ending with plt.show() avoids echoing the (xmin, xmax) tuple returned by xlim.
fig, ax = plt.subplots()
ax.plot(test_position.time, test_position.y)
for idx, spiketrain in enumerate(test_spikes):
    row = np.full(len(spiketrain.time), 1.0 + idx * 30)
    ax.plot(spiketrain.time, row, '|', color='b', ms=20, mew=2)
ax.set_xlim(0, 0.2)
ax.set(xlabel='time', ylabel='y / neuron row', title='Spike raster vs. y position')
plt.show()

In [None]:
# 2D tuning curves from the synthetic spikes, using tuning_curve_2d's default
# keyword arguments (no occupied_thresh / gaussian_std override).
test_tuning_curves = nept.tuning_curve_2d(test_position, test_spikes, xedges, yedges)

In [None]:
# One heat map per neuron; pcolormesh gets the bin edges so each cell spans its spatial bin.
# The neuron number goes in the figure title (so the figure stands alone) and each
# figure is closed after display to avoid accumulating open figures in the loop.
xx, yy = np.meshgrid(xedges, yedges)
for idx, tc in enumerate(test_tuning_curves):
    fig, ax = plt.subplots()
    mesh = ax.pcolormesh(xx, yy, tc, vmin=0., cmap='pink_r')
    fig.colorbar(mesh, ax=ax)
    ax.set_title('Neuron {}'.format(idx + 1))
    ax.axis('off')
    plt.show()
    plt.close(fig)

In [None]:
# Bundle the spike trains with their tuning curves in a nept.Neurons container.
test_neurons = nept.Neurons(test_spikes, test_tuning_curves)

In [None]:
# Restrict spikes and position to [t_start, t_stop]. The name sliced_spikes only ever
# holds the spikes array (not the intermediate Neurons object). Tuning curves come
# from the unsliced Neurons object.
sliced_spikes = test_neurons.time_slice(t_start, t_stop).spikes
position = test_position.time_slice(t_start, t_stop)
tuning_curves = test_neurons.tuning_curves

In [None]:
# Epoch spanning the sliced position recording, from its first to its last timestamp
# (presumably a single [start, stop] epoch — confirm nept.Epoch's input layout).
epochs_interest = nept.Epoch(np.array([position.time[0], position.time[-1]]))

In [None]:
# Spike-count binning parameters for decoding. These re-assign the config-cell
# names, so these are the values bin_spikes and bayesian_prob actually see.
window, dt = 0.125, 0.125
gaussian_std = 0.1
normalized = True

counts = nept.bin_spikes(
    sliced_spikes,
    position.time,
    window=window,
    dt=dt,
    normalized=normalized,
    gaussian_std=gaussian_std,
)

In [None]:
# Binned spike counts as a heat map. counts.data is presumably (time bins, neurons);
# it is transposed so time runs along the x axis and neurons along y.
fig, ax = plt.subplots()
mesh = ax.pcolormesh(counts.data.T, cmap='pink_r')
fig.colorbar(mesh, ax=ax)
ax.set(xlabel='time bin', ylabel='neuron', title='Binned spike counts')
plt.show()

In [None]:
# Flatten each neuron's 2D tuning curve into one row, giving an array of shape
# (n_neurons, product of the two spatial axes) for bayesian_prob.
decoding_tc = tuning_curves.reshape(len(tuning_curves), -1)

In [None]:
# Decoding likelihood from the binned counts and the flattened tuning curves.
# window is passed positionally as the third argument (presumably the bin size
# used to scale tuning-curve rates — confirm in nept.bayesian_prob).
likelihood = nept.bayesian_prob(counts, decoding_tc, window, 
                                min_neurons=min_neurons, min_spikes=min_spikes)

In [None]:
# Bin centres are the midpoints of consecutive edges; xy_centers pairs every
# x centre with every y centre for decode_location.
xcenters, ycenters = ((edges[:-1] + edges[1:]) * 0.5 for edges in (xedges, yedges))
xy_centers = nept.cartesian(xcenters, ycenters)

In [None]:
# Decode position from the likelihood, then drop time bins that could not be decoded.
decoded = nept.decode_location(likelihood, xy_centers, counts.time)
# Drop a sample if EITHER coordinate is NaN. Requiring both (logical_and) would keep
# half-NaN samples and propagate NaN into the error computation further down.
nan_idx = np.logical_or(np.isnan(decoded.x), np.isnan(decoded.y))
decoded = decoded[~nan_idx]

print('decoded:', decoded.n_samples)

In [None]:
# Number of position samples after slicing, for comparison with the decoded count.
position.n_samples

In [None]:
# Split the decoded trajectory with nept.remove_teleports (speed_thresh=sequence_speed,
# min_length=sequence_len; presumably splits at implausibly fast jumps and drops short
# pieces — confirm in nept), intersect with the epoch of interest, and keep only
# decoded samples inside the resulting epochs. expand(0.0) adds no padding.
# NOTE(review): this cell reads and overwrites `decoded`, so re-running it filters an
# already-filtered trajectory.
sequences = nept.remove_teleports(decoded,
                                  speed_thresh=sequence_speed,
                                  min_length=sequence_len)
decoded_epochs = epochs_interest.intersect(sequences).expand(0.0)
decoded = decoded[decoded_epochs]

print('decoded sequences:', decoded.n_samples)

In [None]:
# True position at each decoded timestamp, linearly interpolated from the tracked position.
actual_x, actual_y = (np.interp(decoded.time, position.time, coord)
                      for coord in (position.x, position.y))
actual_position = nept.Position(np.column_stack((actual_x, actual_y)), decoded.time)

In [None]:
# Per-sample decoding error between the interpolated true position and the decoded
# position (presumably Euclidean distance — confirm in nept.Position.distance).
errors = actual_position.distance(decoded)

In [None]:
# NOTE(review): displays the full per-sample error array; for long recordings
# consider a slice (errors[:10]) or a summary instead.
errors

In [None]:
# Mean decoding error over all decoded samples, labeled so the printed number is
# self-explanatory (np.mean yields nan with a RuntimeWarning if nothing was decoded).
print('mean decoding error:', np.mean(errors))

In [None]:
# Overlay the true trajectory (green) and the decoded positions (red) within
# [decoded_epochs.start, decoded_epochs.stop], with a legend and axis labels so the
# figure stands alone.
start = decoded_epochs.start
stop = decoded_epochs.stop

dec = decoded.time_slice(start, stop)
pos = test_position.time_slice(start, stop)

fig, ax = plt.subplots()
ax.plot(pos.x, pos.y, '.', ms=7, color='g', markerfacecolor='none', label='actual')
ax.plot(dec.x, dec.y, '.', ms=7, color='r', markerfacecolor='none', label='decoded')
ax.set(xlabel='x', ylabel='y', title='Actual vs. decoded position')
ax.legend()
plt.show()

In [None]:
# Decoded vs. actual x coordinate over time.
fig, ax = plt.subplots()
ax.plot(decoded.time, decoded.x, '.', label='decoded')
ax.plot(test_position.time, test_position.x, '.', label='actual')
ax.set(xlabel='time', ylabel='x', title='Decoded vs. actual x over time')
ax.legend()
plt.show()

In [None]:
# Decoded vs. actual y coordinate over time.
fig, ax = plt.subplots()
ax.plot(decoded.time, decoded.y, '.', label='decoded')
ax.plot(test_position.time, test_position.y, '.', label='actual')
ax.set(xlabel='time', ylabel='y', title='Decoded vs. actual y over time')
ax.legend()
plt.show()