In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import nept

In [None]:
# Toy dataset: one spike train with three spikes (t = 1.0, 1.1, 1.2) and five
# position samples stepping down the diagonal from (6, 6) to (2, 2) at t = 1..5.
spikes = np.array([nept.SpikeTrain(np.array([1., 1.1, 1.2]))])

time = np.arange(1, 6)
x = np.arange(6, 1, -1)
y = x.copy()
position = nept.Position([x, y], time)

# Bin edges along each axis, derived from the position samples.
xedges, yedges = nept.get_xyedges(position)

# NOTE(review): occupied_thresh=0. and gaussian_sigma=None presumably disable
# occupancy filtering and smoothing -- confirm against the nept documentation.
tuning_curves = nept.tuning_curve_2d(
    position,
    spikes,
    xedges,
    yedges,
    occupied_thresh=0.,
    gaussian_sigma=None,
)

# Bundle the spike trains with their tuning curves for the cells below.
neurons = nept.Neurons(spikes, tuning_curves)

In [None]:
# Display the tuning curves stored on the Neurons object as this cell's output.
neurons.tuning_curves

In [None]:
# Binning parameters; window_size is also passed on when the likelihood is computed.
window_size = 1.0
window_advance = 1

# Time-bin edges spanning the position data, stepped by window_advance.
# NOTE(review): lastbin=True presumably keeps the final edge -- confirm in nept docs.
time_edges = nept.get_edges(position, window_advance, lastbin=True)

# Spike counts per neuron per time window, with no Gaussian smoothing requested.
counts = nept.bin_spikes(
    neurons.spikes,
    position,
    window_size,
    window_advance,
    gaussian_std=None,
    normalized=True,
)

In [None]:
# Display the raw binned spike-count data as this cell's output.
counts.data

In [None]:
# Thresholds forwarded to nept.bayesian_prob.
# NOTE(review): presumably the minimum active neurons / spikes a time bin needs
# in order to be decoded -- confirm against the nept documentation.
min_neurons = 1
min_spikes = 1

# Collapse the two spatial axes of each neuron's tuning curve into one flat axis,
# giving a (neuron, spatial bin) matrix for the decoder.
tc_shape = tuning_curves.shape
n_neurons, n_rows, n_cols = tc_shape
decoding_tc = tuning_curves.reshape((n_neurons, n_rows * n_cols))

likelihood = nept.bayesian_prob(
    counts,
    decoding_tc,
    window_size,
    min_neurons=min_neurons,
    min_spikes=min_spikes,
)

In [None]:
# Display the likelihood returned by nept.bayesian_prob as this cell's output.
likelihood

In [None]:
# NOTE(review): presumably a single epoch from t=1 to t=5 -- confirm how
# nept.Epoch interprets a flat two-element array.
epochs_interest = nept.Epoch(np.array([1., 5.]))

# Midpoint of every spatial bin along each axis.
xcenters = 0.5 * (xedges[:-1] + xedges[1:])
ycenters = 0.5 * (yedges[:-1] + yedges[1:])
xy_centers = nept.cartesian(xcenters, ycenters)

decoded = nept.decode_location(likelihood, xy_centers, time_edges)

# Drop samples whose x AND y are both NaN; a sample with only one NaN
# coordinate is kept.
# NOTE(review): confirm that partially-NaN samples cannot occur or are wanted.
nan_idx = np.isnan(decoded.x) & np.isnan(decoded.y)
decoded = decoded[np.logical_not(nan_idx)]

# Show the timestamps of the samples that survived the NaN filter.
decoded.time

In [None]:
# NOTE(review): everything in this cell is commented out, so it has no effect
# when the notebook runs. It sketches an optional step that would restrict
# `decoded` to sequences found by nept.remove_teleports within `epochs_interest`.
# sequence_speed = 0.5
# sequence_len = 4

# sequences = nept.remove_teleports(decoded, speed_thresh=sequence_speed, min_length=sequence_len)
# decoded_epochs = epochs_interest.intersect(sequences)
# decoded_epochs = decoded_epochs.expand(0.05)

# decoded = decoded[decoded_epochs]

In [None]:
# Linearly interpolate the recorded position at each decoded timestamp so the
# actual and decoded samples line up one-to-one.
actual_x = np.interp(decoded.time, position.time, position.x)
actual_y = np.interp(decoded.time, position.time, position.y)

# Pair the interpolated coordinates as (x, y) columns, one row per decoded sample.
actual_xy = np.column_stack((actual_x, actual_y))
actual_position = nept.Position(actual_xy, decoded.time)

# Per-sample distance between the actual and the decoded position.
errors = actual_position.distance(decoded)

In [None]:
# Report the per-sample decoding errors, followed by their mean.
print(errors)
mean_error = np.mean(errors)
print(mean_error)