In [None]:
%matplotlib inline
import itertools
import os

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib import animation, rc
import numpy as np
import scipy
import scipy.interpolate  # explicit: on older SciPy `import scipy` does not load submodules
from IPython.display import HTML
import nept

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from analyze_decode_bytrial import decode_trial
from analyze_decode import get_decoded_zones
from utils_maze import find_zones, get_trials

In [None]:
# All cache/output locations are resolved relative to the notebook's working directory.
thisdir = os.getcwd()
output_filepath, pickle_filepath = (
    os.path.join(thisdir, "plots", "decode-video"),
    os.path.join(thisdir, "cache", "pickled"),
)

In [None]:
import info.r068d6 as info

In [None]:
# Task phase analysed throughout the notebook.
phase = "phase3"
events, position, spikes, lfp, _ = get_data(info)

# Restrict the position samples and every spike train to the chosen phase.
# `position` itself stays session-wide; the phase-limited copy is position_initial.
position_initial = position.time_slice(info.task_times[phase].start,
                                       info.task_times[phase].stop)
spikes = [train.time_slice(info.task_times[phase].start, info.task_times[phase].stop)
          for train in spikes]

In [None]:
# Trial epochs within the analysed phase.
trial_epochs = get_trials(events, info.task_times[phase])

# Spatial bin edges spanning the phase's position samples (binsize=4).
xedges, yedges = nept.get_xyedges(position_initial, binsize=4)

In [None]:
# Decode a single trial with a permissive running-speed threshold (0.167) to
# inspect the raw posterior. To decode every trial use:
# trial_indices = range(trial_epochs.n_epochs)
trial_indices = [18]

speed_limit = 0.167
t_smooth = 0.5
dt = 0.025
window = 0.025
gaussian_std = 0.0075
normalized = False
min_neurons = 2
min_spikes = 1

# Tuning curves do not depend on the trial, so build them once, outside the loop.
# Pass the phase-sliced `position_initial` (never reassigned) rather than
# `position`, which other cells overwrite with one trial's run-filtered samples;
# using it made the tuning curves depend on cell execution order.
# NOTE(review): assumes get_only_tuning_curves restricts to `phase` itself, so
# slicing to the phase beforehand changes nothing — confirm.
sliced_spikes, tuning_curves = get_only_tuning_curves(info, position_initial, spikes, xedges, yedges, phase=phase)
tc_shape = tuning_curves.shape
decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

for trial_idx in trial_indices:
    trial_times = nept.Epoch([trial_epochs.starts[trial_idx], trial_epochs.stops[trial_idx]])

    # Keep only samples where the subject moves faster than speed_limit. A separate
    # name avoids overwriting the session-wide `position` loaded at the top.
    trial_position = position_initial.time_slice(trial_times.start, trial_times.stop)
    run_epoch = nept.run_threshold(trial_position, thresh=speed_limit, t_smooth=t_smooth)
    trial_position = trial_position[run_epoch]

    counts = nept.bin_spikes(sliced_spikes, trial_position.time, dt=dt, window=window,
                             gaussian_std=gaussian_std, normalized=normalized)
    # Same "active" criterion that bayesian_prob receives via min_spikes.
    n_active_neurons = np.sum(counts.data >= min_spikes, axis=1)

    likelihood = nept.bayesian_prob(counts, decoding_tc, window, min_neurons=min_neurons, min_spikes=min_spikes)

In [None]:
# Posterior returned by nept.bayesian_prob in the cell above (time bins x flattened spatial bins)
likelihood

In [None]:
# Number of trials found by get_trials; valid trial_idx values are 0 .. n_epochs - 1
trial_epochs.n_epochs

In [None]:
# Decode the selected trials above a running-speed threshold, print the mean
# decoding error, and build an animation of the posterior beside histograms of
# per-bin decoding error and active-neuron count.
# To render every trial use: trial_indices = range(trial_epochs.n_epochs)
trial_indices = [18]

speed_limit = 4.
t_smooth = 0.5
dt = 0.025
window = 0.025
gaussian_std = 0.0075
normalized = False
min_neurons = 2
min_spikes = 1
pad_amount = 5  # axis padding around the true trajectory
fontsize = 14

# Tuning curves do not depend on the trial, so build them once, outside the loop.
# Pass the phase-sliced `position_initial` (never reassigned) rather than
# `position`, which other cells overwrite with one trial's run-filtered samples;
# using it made the tuning curves depend on cell execution order.
# NOTE(review): assumes get_only_tuning_curves restricts to `phase` itself, so
# slicing to the phase beforehand changes nothing — confirm.
sliced_spikes, tuning_curves = get_only_tuning_curves(info, position_initial, spikes, xedges, yedges, phase=phase)
tc_shape = tuning_curves.shape
decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

xcenters = (xedges[1:] + xedges[:-1]) / 2.
ycenters = (yedges[1:] + yedges[:-1]) / 2.
xy_centers = nept.cartesian(xcenters, ycenters)

# Corner grids for pcolormesh. With the default flat shading, X/Y must be one
# larger than the colour array in each dimension, so pass the full edge grids
# (xx[:-1], yy[:-1] dropped a row, which recent matplotlib rejects).
# NOTE(review): assumes each reshaped likelihood frame is
# (len(yedges) - 1, len(xedges) - 1) — confirm against the tuning-curve layout.
xx, yy = np.meshgrid(xedges, yedges)

for trial_idx in trial_indices:
    trial_times = nept.Epoch([trial_epochs.starts[trial_idx], trial_epochs.stops[trial_idx]])

    # Keep only samples where the subject moves faster than speed_limit. A separate
    # name avoids overwriting the session-wide `position` loaded at the top.
    trial_position = position_initial.time_slice(trial_times.start, trial_times.stop)
    run_epoch = nept.run_threshold(trial_position, thresh=speed_limit, t_smooth=t_smooth)
    trial_position = trial_position[run_epoch]

    counts = nept.bin_spikes(sliced_spikes, trial_position.time, dt=dt, window=window,
                             gaussian_std=gaussian_std, normalized=normalized)
    # Same "active" criterion that bayesian_prob receives via min_spikes.
    n_active_neurons = np.sum(counts.data >= min_spikes, axis=1)

    likelihood = nept.bayesian_prob(counts, decoding_tc, window, min_neurons=min_neurons, min_spikes=min_spikes)
    decoded_position = nept.decode_location(likelihood, xy_centers, counts.time)
    likelihood = likelihood.reshape(likelihood.shape[0], tc_shape[1], tc_shape[2])

    if decoded_position.n_samples == 0:
        print("Trial {}: no time bins could be decoded; skipping.".format(trial_idx))
        continue

    # True position at each decoded time: nearest recorded sample.
    f_xy = scipy.interpolate.interp1d(trial_position.time, trial_position.data.T, kind="nearest")
    true_xy = f_xy(decoded_position.time)
    true_position = nept.Position(np.column_stack((true_xy[0], true_xy[1])), decoded_position.time)

    errors = np.asarray(true_position.distance(decoded_position))
    print(np.nanmean(errors))

    # Animation frames follow decoded_position. It may hold fewer samples than
    # counts/likelihood (presumably decode_location drops undecodable bins — the
    # n_samples vs. likelihood.shape checks below hint at this), so map every
    # frame to its counts row by timestamp to keep all panels aligned.
    # NOTE(review): assumes counts.time is sorted and decoded times are a subset of it.
    frame_rows = np.searchsorted(counts.time, decoded_position.time)
    n_timebins = decoded_position.n_samples

    fig = plt.figure(figsize=(12, 10))

    # --- Posterior over space with true (blue) and decoded (red) positions ---
    ax1 = plt.subplot2grid((5, 4), (0, 0), colspan=3, rowspan=3)
    ax1.set_xlim((np.floor(np.min(true_position.x)) - pad_amount, np.ceil(np.max(true_position.x)) + pad_amount))
    ax1.set_ylim((np.floor(np.min(true_position.y)) - pad_amount, np.ceil(np.max(true_position.y)) + pad_amount))
    ax1.set_title("Trial {}".format(trial_idx), fontsize=fontsize)

    # plt.get_cmap works across matplotlib versions; plt.cm.get_cmap was removed in 3.9.
    cmap = plt.get_cmap('bone_r')
    posterior_position = ax1.pcolormesh(xx, yy, likelihood[frame_rows[0]], vmin=0, vmax=0.1, cmap=cmap)
    colorbar = fig.colorbar(posterior_position, ax=ax1)

    estimated_position, = ax1.plot([], [], "<", color="r")
    rat_position, = ax1.plot([], [], "<", color="b")

    # --- Decoding-error histogram ---
    # Bin [-binwidth, 0) is a "nan" bin for time bins without a finite error.
    # The max is taken over finite errors only (np.max would return NaN), and the
    # last edge lies strictly above the largest error so digitize never indexes
    # past the final patch.
    ax2 = plt.subplot2grid((5, 4), (3, 0), colspan=3)
    error_binwidth = 5.
    finite_errors = errors[np.isfinite(errors)]
    max_error = np.max(finite_errors) if finite_errors.size else 0.
    n_error_bins = int(np.floor(max_error / error_binwidth)) + 1
    error_bins = error_binwidth * np.arange(-1, n_error_bins + 1)
    errors_plot = np.where(np.isfinite(errors), errors, -error_binwidth / 2.)

    _, _, errors_bin = ax2.hist(errors_plot, bins=error_bins, rwidth=0.9, color="k")
    errors_idx = np.clip(np.digitize(errors_plot, error_bins) - 1, 0, len(errors_bin) - 1)

    # Label the nan bin at its centre, then every other bin edge from 0 upward;
    # ticks and labels are set together so they cannot drift apart.
    edge_ticks = error_bins[1::2]
    ax2.set_xticks(np.concatenate(([-error_binwidth / 2.], edge_ticks)))
    ax2.set_xticklabels(["nan"] + [str(int(edge)) for edge in edge_ticks])
    ax2.set_xlabel("Error (cm)", fontsize=fontsize)
    ax2.set_ylabel("# bins", fontsize=fontsize)
    ax2.tick_params(labelsize=fontsize)
    ax2.spines['right'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax2.yaxis.set_ticks_position('left')
    ax2.xaxis.set_ticks_position('bottom')

    # --- Active-neuron histogram ---
    # Unit-width bins centred on integers: bar k counts time bins with exactly k
    # active neurons, so the bar index equals the neuron count when highlighting.
    ax3 = plt.subplot2grid((5, 4), (4, 0), colspan=3)
    max_active = int(np.max(n_active_neurons))
    n_active_bins = np.arange(-0.5, max_active + 1.5, 1.)

    _, _, n_neurons_bin = ax3.hist(n_active_neurons, bins=n_active_bins, rwidth=0.9, color="k")

    ax3.set_xticks(np.arange(max_active + 1))
    ax3.set_xlabel("Number of active neurons", fontsize=fontsize)
    ax3.set_ylabel("# bins", fontsize=fontsize)
    ax3.tick_params(labelsize=fontsize)
    ax3.spines['right'].set_visible(False)
    ax3.spines['top'].set_visible(False)
    ax3.yaxis.set_ticks_position('left')
    ax3.xaxis.set_ticks_position('bottom')

    fig.tight_layout()

    # NOTE: animate closes over this iteration's variables by name; animations from
    # earlier iterations are only valid if rendered/saved inside the same iteration.
    def animate(i):
        """Draw frame i: posterior, both positions, and the matching histogram bars."""
        row = frame_rows[i]
        posterior_position.set_array(likelihood[row].ravel())
        # Line2D.set_data expects sequences, not scalars.
        estimated_position.set_data([decoded_position.x[i]], [decoded_position.y[i]])
        rat_position.set_data([true_position.x[i]], [true_position.y[i]])

        for patch in errors_bin:
            patch.set_fc('k')
        errors_bin[errors_idx[i]].set_fc('r')

        for patch in n_neurons_bin:
            patch.set_fc('k')
        n_neurons_bin[n_active_neurons[row]].set_fc('r')

        return (posterior_position, estimated_position, rat_position, errors_bin, n_neurons_bin)

    anim = animation.FuncAnimation(fig, animate, frames=n_timebins, interval=80,
                                   blit=False, repeat=False)

    # To save instead of displaying inline (requires ffmpeg):
    # writer = animation.writers['ffmpeg'](fps=18)
    # filename = 'errors_' + info.session_id + '_trial' + str(trial_idx) + '.mp4'
    # anim.save(os.path.join(output_filepath, filename), writer=writer, dpi=600)
    # plt.close(fig)

In [None]:
# Sample counts of the true vs. decoded trajectories (true_position is built on decoded_position.time)
true_position.n_samples, decoded_position.n_samples

In [None]:
# Number of time bins in the posterior (first axis); presumably larger than
# decoded_position.n_samples when some bins could not be decoded
likelihood.shape

In [None]:
# Optional: save the last animation to disk instead of displaying it (needs an ffmpeg writer).
# writer = animation.writers['ffmpeg'](fps=18)
# dpi = 600
# filename = '/errors_'+info.session_id+'_trial'+str(trial_idx)+'.mp4'
# anim.save(output_filepath+filename, writer=writer, dpi=dpi)

In [None]:
# Render the animation built in the decoding loop inline as an HTML5 video.
print("Blue is true position; Red is estimated location.")
HTML(anim.to_html5_video())

In [None]:
# Per-time-bin decoding error: distance between true and decoded positions
true_position.distance(decoded_position)

In [None]:
# First 20 x-coordinates of the true trajectory (compare with the decoded ones below)
true_position.x[:20]

In [None]:
# First 20 decoded x-coordinates
decoded_position.x[:20]

In [None]:
# Posterior after the reshape to (time, tc_shape[1], tc_shape[2]) in the decoding loop
likelihood

In [None]:
# Step-by-step re-derivation of the posterior for the current trial, to inspect
# where values overflow to inf or become NaN. Reads counts, decoding_tc, window,
# min_spikes and min_neurons from the decoding cells above.
# `tuning_curves`, `idx` and `valid_idx` are deliberately left as globals: the
# inspection cells that follow read them.
tuning_curves = decoding_tc
binsize = window

n_time_bins = np.shape(counts.time)[0]
n_position_bins = np.shape(tuning_curves)[1]

likelihood = np.full((n_time_bins, n_position_bins), np.nan)

# Ignore overflow warnings when inf is created in this loop. np.errstate restores
# the previous error settings on exit, even if the loop raises.
with np.errstate(over='ignore'):
    for idx in range(n_position_bins):
        # Only neurons whose rate at this bin exceeds 1 contribute.
        # NOTE(review): log of a rate in (0, 1] is <= 0 but still valid; confirm this
        # exclusion is intended rather than just a guard against log(0).
        valid_idx = tuning_curves[:, idx] > 1
        if np.any(valid_idx):
            # rate ** count term of the Poisson likelihood, shape (time bins, valid neurons)
            event_rate = tuning_curves[valid_idx, idx, np.newaxis].T ** counts.data[:, valid_idx]
            # Named `prior`, but this is the exp(-binsize * sum of rates) Poisson term;
            # the uniform spatial prior is the 1/n_position_bins factor below.
            prior = np.exp(-binsize * np.nansum(tuning_curves[valid_idx, idx]))

            # Below is the same as
            # likelihood[:, idx] = np.prod(event_rate, axis=1) * prior * (1/n_position_bins)
            # only less likely to have floating point issues, though slower
            likelihood[:, idx] = np.exp(np.nansum(np.log(event_rate), axis=1)) * prior * (1/n_position_bins)

print(likelihood)

# Set any inf value to be largest float
largest_float = np.finfo(float).max
likelihood[np.isinf(likelihood)] = largest_float
# Normalise each time bin to sum to 1 over space
likelihood /= np.nansum(likelihood, axis=1)[..., np.newaxis]

print(likelihood)

# Remove bins with too few neurons that are active
# a neuron is considered active by having at least min_spikes in a bin
n_active_neurons = np.sum(counts.data >= min_spikes, axis=1)
likelihood[n_active_neurons < min_neurons] = np.nan

In [None]:
# Number of flattened spatial bins (columns of decoding_tc, aliased as tuning_curves above)
n_position_bins

In [None]:

# Print every flattened position bin whose rates, summed over neurons, exceed 1
for idx in range(n_position_bins):
    if np.nansum(tuning_curves[:, idx]) > 1:
        print(idx)

In [None]:
# Rates of all neurons at one hand-picked bin.
# NOTE(review): 1113 is hard-coded — presumably chosen from the indices printed above.
tuning_curves[:, 1113]

In [None]:
# Neuron mask (rate > 1) left over from the last bin of the manual likelihood loop
# (idx == n_position_bins - 1 after either loop)
valid_idx

In [None]:
# Rates of the masked neurons at that last bin
tuning_curves[valid_idx, idx]

In [None]:
# Binned spike counts (time bins x neurons)
counts.data

In [None]:
# One active-neuron count per time bin
n_active_neurons.shape

In [None]:
# Standalone decode of trial 18 at speed_limit = 4, leaving its results as
# globals for the cells that follow.
trial_idx = 18
trial_times = nept.Epoch([trial_epochs.starts[trial_idx], trial_epochs.stops[trial_idx]])

speed_limit = 4.
t_smooth = 0.5
dt = 0.025
window = 0.025
gaussian_std = 0.0075
normalized = False
min_neurons = 2
min_spikes = 1

# Build tuning curves from the phase-sliced `position_initial`, which is never
# reassigned (unlike `position`, rebound just below), so re-running this cell
# gives the same tuning curves.
# NOTE(review): assumes get_only_tuning_curves restricts to `phase` itself — confirm.
sliced_spikes, tuning_curves = get_only_tuning_curves(info, position_initial, spikes, xedges, yedges, phase=phase)
tc_shape = tuning_curves.shape
decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

# `position` is rebound to this trial's samples where the subject moves faster
# than speed_limit; the next cell interpolates the true position from it.
position = position_initial.time_slice(trial_times.start, trial_times.stop)
run_epoch = nept.run_threshold(position, thresh=speed_limit, t_smooth=t_smooth)
position = position[run_epoch]

counts = nept.bin_spikes(sliced_spikes, position.time, dt=dt, window=window,
                         gaussian_std=gaussian_std, normalized=normalized)
# Same "active" criterion that bayesian_prob receives via min_spikes.
n_active_neurons = np.sum(counts.data >= min_spikes, axis=1)

likelihood = nept.bayesian_prob(counts, decoding_tc, window, min_neurons=min_neurons, min_spikes=min_spikes)

In [None]:
# Decode the most likely location per time bin, recover the true position at the
# same times, and print the mean decoding error.
xcenters = (xedges[1:] + xedges[:-1]) / 2.
ycenters = (yedges[1:] + yedges[:-1]) / 2.
xy_centers = nept.cartesian(xcenters, ycenters)

# Flatten explicitly so this cell is idempotent: it rebinds `likelihood` to its
# 3-D (time, tc_shape[1], tc_shape[2]) form, and a re-run would otherwise hand
# that 3-D array to decode_location. (An explicit column count, not -1, also
# works when there are zero time bins.)
likelihood_flat = likelihood.reshape(likelihood.shape[0], tc_shape[1] * tc_shape[2])
decoded_position = nept.decode_location(likelihood_flat, xy_centers, counts.time)
likelihood = likelihood_flat.reshape(likelihood_flat.shape[0], tc_shape[1], tc_shape[2])

# True position at each decoded time: nearest recorded sample of this trial.
f_xy = scipy.interpolate.interp1d(position.time, position.data.T, kind="nearest")
true_xy = f_xy(decoded_position.time)
true_position = nept.Position(np.column_stack((true_xy[0], true_xy[1])),
                              decoded_position.time)

errors = true_position.distance(decoded_position)
print(np.nanmean(errors))

In [None]:
# Total posterior mass per spatial bin across all time bins; NaN entries count as
# zero (identical to np.nansum over axis 0).
sum_likelihood = np.where(np.isnan(likelihood), 0., likelihood).sum(axis=0)

In [None]:
# Posterior summed over all time bins of the current trial, over the same padded
# extent as the animation.
fig, ax = plt.subplots()

# Full edge grids: flat shading needs X/Y one larger than the colour array in each
# dimension (xx[:-1], yy[:-1] dropped a row, which recent matplotlib rejects).
# NOTE(review): assumes sum_likelihood is (len(yedges) - 1, len(xedges) - 1) — confirm.
xx, yy = np.meshgrid(xedges, yedges)

pad_amount = 5
ax.set_xlim((np.floor(np.min(true_position.x)) - pad_amount, np.ceil(np.max(true_position.x)) + pad_amount))
ax.set_ylim((np.floor(np.min(true_position.y)) - pad_amount, np.ceil(np.max(true_position.y)) + pad_amount))

# plt.get_cmap works across matplotlib versions; plt.cm.get_cmap was removed in 3.9.
cmap = plt.get_cmap('bone_r')
# Each time bin's posterior is presumably normalised to sum to 1 over space, so
# their sum can far exceed 0.1: keep vmin at 0 and let the upper limit autoscale
# instead of saturating.
posterior_position = ax.pcolormesh(xx, yy, sum_likelihood, vmin=0, cmap=cmap)
colorbar = fig.colorbar(posterior_position, ax=ax)
ax.set_title("Posterior summed over time bins")