In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib
import numpy as np
import itertools
import scipy
import pandas as pd
import pickle
import seaborn as sns
from scipy import stats
import os
import nept

from matplotlib import animation, rc
from IPython.display import HTML

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from analyze_decode_bytrial import decode_trial
from analyze_decode import get_decoded_zones
from utils_maze import find_zones, get_trials, get_zones, get_trial_idx

In [None]:
# Every path is resolved relative to the kernel's current working directory:
# cached decoding pickles are read from cache/pickled, figures go under plots/decode-video.
thisdir = os.getcwd()
pickle_filepath, output_filepath = (
    os.path.join(thisdir, "cache", "pickled"),
    os.path.join(thisdir, "plots", "decode-video"),
)

In [None]:
# Session metadata. `info` (r063d2) is used directly by the single-session cells;
# note that the `for info in infos:` loops further down rebind this name.
import info.r063d2 as info
import info.r063d6 as r063d6
# infos = [r063d6]

# Sessions iterated over by the per-session analysis loops below.
from run import spike_sorted_infos
infos = spike_sorted_infos

In [None]:
def plot_errors(all_errors, all_errors_id_shuffled, n_sessions, filename=None):
    """Box-plot pooled decoding errors next to an ID-shuffled control.

    Parameters
    ----------
    all_errors : list of list of array-like
        Decoding errors nested by session, then by trial. Every trial array is
        concatenated into one pooled distribution.
    all_errors_id_shuffled : list of list of array-like
        Errors from decoding with shuffled neuron identities, same nesting.
    n_sessions : int
        Number of pooled sessions; only used in the text annotation.
    filename : str, optional
        If given, the figure is saved to this path and closed; otherwise it is shown.
    """
    all_errors = np.concatenate([np.concatenate(errors, axis=0) for errors in all_errors], axis=0)
    all_errors_id_shuffled = np.concatenate(
        [np.concatenate(errors, axis=0) for errors in all_errors_id_shuffled], axis=0)

    fliersize = 1

    decoded_errors = pd.DataFrame(dict(error=all_errors, label='Decoded'))
    shuffled_id = pd.DataFrame(dict(error=all_errors_id_shuffled, label='ID shuffled'))
    # Shuffled control first so it is the left-hand box (and gets colours[0]).
    data = pd.concat([shuffled_id, decoded_errors])
    colours = ['#ffffff', '#bdbdbd']

    plt.figure(figsize=(6, 4))
    flierprops = dict(marker='o', markersize=fliersize, linestyle='none')
    ax = sns.boxplot(x='label', y='error', data=data, flierprops=flierprops)

    edge_colour = '#252525'
    # Older matplotlib keeps the box patches in ax.artists; matplotlib >= 3.5 puts them in
    # ax.patches and leaves ax.artists empty, which silently skipped all box styling.
    boxes = ax.artists if len(ax.artists) > 0 else ax.patches
    # Assumes seaborn draws 6 line artists per box (the original indexing relied on this);
    # slicing instead of indexing avoids an IndexError if fewer lines exist.
    lines_per_box = 6
    for i, artist in enumerate(boxes):
        artist.set_edgecolor(edge_colour)
        artist.set_facecolor(colours[i % len(colours)])

        for line in ax.lines[i * lines_per_box:(i + 1) * lines_per_box]:
            line.set_color(edge_colour)
            line.set_mfc(edge_colour)
            line.set_mec(edge_colour)

    ax.text(1., 1., "N sessions: %d \nmean-error: %.1f cm \nmedian-error: %.1f cm" % (n_sessions,
                                                                                      np.mean(all_errors),
                                                                                      np.median(all_errors)),
            horizontalalignment='right',
            verticalalignment='top',
            transform=ax.transAxes,
            fontsize=10)

    ax.set(xlabel=' ', ylabel="Error (cm)")
    plt.xticks(fontsize=14)

    sns.despine()
    plt.tight_layout()

    if filename is not None:
        plt.savefig(filename)
        plt.close()
    else:
        plt.show()

In [None]:
def plot_over_space(values, positions, xedges, yedges):
    """Average per-sample values into the spatial bin nearest each sample's position.

    Parameters
    ----------
    values : sequence of sequences
        Per-trial values. Each element is either a scalar (e.g. a decoding error)
        or a 2D array indexed ``[y_idx][x_idx]`` (e.g. a posterior map), in which
        case only the entry at the sample's own bin is accumulated.
    positions : sequence
        Per-trial objects exposing ``.x`` and ``.y``, aligned sample-for-sample
        with ``values``.
    xedges, yedges : np.ndarray
        Bin edges; each sample is assigned to the nearest bin centre.

    Returns
    -------
    np.ndarray of shape (len(yedges), len(xedges))
        Mean value per bin, 0 where no sample fell. The last row and column never
        receive samples (there is one fewer centre than edges) and stay 0; the
        shape is kept for compatibility with existing callers.
    """
    xcenters = xedges[:-1] + (xedges[1:] - xedges[:-1]) / 2
    ycenters = yedges[:-1] + (yedges[1:] - yedges[:-1]) / 2

    count_position = np.zeros((len(yedges), len(xedges)))
    # Count from zero. Initialising with ones added a phantom sample to every bin,
    # which biased averages toward 0 (e.g. a single sample of 10 was reported as 5).
    n_position = np.zeros((len(yedges), len(xedges)))

    for trial_values, trial_positions in zip(values, positions):
        for these_values, x, y in zip(trial_values, trial_positions.x, trial_positions.y):
            x_idx = nept.find_nearest_idx(xcenters, x)
            y_idx = nept.find_nearest_idx(ycenters, y)
            if np.isscalar(these_values):
                count_position[y_idx][x_idx] += these_values
            else:
                count_position[y_idx][x_idx] += these_values[y_idx][x_idx]
            n_position[y_idx][x_idx] += 1

    # Empty bins report 0 (as before) instead of dividing by zero.
    return np.divide(count_position, n_position,
                     out=np.zeros_like(count_position), where=n_position > 0)

In [None]:
def make_animation(session_id, decoded, trial_idx, xedge, yedge, filepath, binsize=None):
    """Render an mp4 of the decoded posterior for one trial.

    The top panel shows the posterior over space for each time bin, with the decoded
    position (cyan circle) and the true position (blue triangle). The two lower panels
    are histograms of this trial's decoding errors and active-neuron counts, with the
    bar for the current time bin highlighted.

    Parameters
    ----------
    session_id : str
        Prefix of the output filename.
    decoded : dict
        Per-trial results under the keys "decoded", "actual", "likelihoods",
        "n_active" and "errors".
    trial_idx : int
        Index of the trial to animate.
    xedge, yedge : np.ndarray
        Spatial bin edges the likelihoods were computed on.
    filepath : str
        Directory the mp4 is written to.
    binsize : float, optional
        Spatial bin size used to pad the axis limits (two bins each side).
        Defaults to the spacing of ``xedge``; previously this was read from a
        global ``binsize`` variable.
    """
    decoded_position = decoded["decoded"][trial_idx]
    true_position = decoded["actual"][trial_idx]
    likelihoods = np.array(decoded["likelihoods"][trial_idx])
    n_active = decoded["n_active"][trial_idx]
    errors = decoded["errors"][trial_idx]

    if binsize is None:
        binsize = xedge[1] - xedge[0]

    fig = plt.figure(figsize=(12, 10))

    xx, yy = np.meshgrid(xedge, yedge)

    # --- Top panel: posterior over space plus decoded and true positions ---
    ax1 = plt.subplot2grid((5, 4), (0, 0), colspan=3, rowspan=3)

    pad_amount = binsize*2
    ax1.set_xlim((np.floor(np.min(true_position.x))-pad_amount, np.ceil(np.max(true_position.x))+pad_amount))
    ax1.set_ylim((np.floor(np.min(true_position.y))-pad_amount, np.ceil(np.max(true_position.y))+pad_amount))

    n_timebins = decoded_position.n_samples

    # Must be an int: np.linspace rejects a float number of samples on modern numpy.
    n_colours = 20
    # White first (lowest values, including the NaN fill below), then a copper ramp.
    colours = [(1., 1., 1.)]
    colours.extend(matplotlib.cm.copper_r(np.linspace(0, 1, n_colours-1)))
    cmap = matplotlib.colors.ListedColormap(colours)

    # Keep a NaN-preserving copy for the text readout; the plotted copy pushes NaN
    # bins just below 0 so they fall in the white bottom colour.
    likelihoods_withnan = np.array(likelihoods)
    likelihoods[np.isnan(likelihoods)] = -0.01

    xcenters = xedge[:-1] + (xedge[1:] - xedge[:-1]) / 2
    ycenters = yedge[:-1] + (yedge[1:] - yedge[:-1]) / 2

    # Spatial bin indices of the true and decoded positions at every time bin.
    x_idx = [nept.find_nearest_idx(xcenters, true_position.x[timestep]) for timestep in range(true_position.n_samples)]
    y_idx = [nept.find_nearest_idx(ycenters, true_position.y[timestep]) for timestep in range(true_position.n_samples)]

    x_dec_idx = [nept.find_nearest_idx(xcenters, decoded_position.x[timestep]) for timestep in range(decoded_position.n_samples)]
    y_dec_idx = [nept.find_nearest_idx(ycenters, decoded_position.y[timestep]) for timestep in range(decoded_position.n_samples)]

    posterior_position = ax1.pcolormesh(xx[:-1], yy[:-1], likelihoods[0], vmax=0.2, cmap=cmap)
    fig.colorbar(posterior_position, ax=ax1)

    estimated_position, = ax1.plot([], [], "o", color="c")
    rat_position, = ax1.plot([], [], "<", color="b")

    # --- Middle panel: histogram of this trial's decoding errors ---
    ax2 = plt.subplot2grid((5, 4), (3, 0), colspan=3)

    binwidth = 5.
    error_bins = np.arange(-binwidth, np.max(errors)+binwidth, binwidth)

    _, _, errors_bin = ax2.hist([np.clip(errors, error_bins[0], error_bins[-1])], bins=error_bins, rwidth=0.9, color="k")
    # 1-based bar index per time bin (np.digitize convention). digitize puts a value equal
    # to the last edge past the final bar while hist (after the clip above) counts it in
    # that bar, so clip to the valid bar range to avoid an IndexError in animate().
    errors_idx = np.clip(np.digitize(errors, error_bins), 1, len(errors_bin))

    fontsize = 14
    likelihood_at_actual = ax2.text(0.6, 1, "",
                                    horizontalalignment='left',
                                    verticalalignment='top',
                                    transform=ax2.transAxes,
                                    fontsize=fontsize)

    ax2.set_xlabel("Error (cm)", fontsize=fontsize)
    ax2.set_ylabel("# bins", fontsize=fontsize)
    ax2.spines['right'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax2.yaxis.set_ticks_position('left')
    ax2.xaxis.set_ticks_position('bottom')
    xticks = binwidth * np.arange(0, len(error_bins), 4)
    plt.xticks(xticks, fontsize=fontsize)
    plt.yticks(fontsize=fontsize)

    # --- Bottom panel: histogram of active-neuron counts (one bar per integer) ---
    ax3 = plt.subplot2grid((5, 4), (4, 0), colspan=3)

    n_active_bins = np.arange(-0.5, np.max(n_active)+1)

    _, _, n_neurons_bin = ax3.hist(n_active, bins=n_active_bins, rwidth=0.9, color="k", align="mid")

    ax3.set_xlabel("Number of active neurons", fontsize=fontsize)
    ax3.set_ylabel("# bins", fontsize=fontsize)
    ax3.spines['right'].set_visible(False)
    ax3.spines['top'].set_visible(False)
    ax3.yaxis.set_ticks_position('left')
    ax3.xaxis.set_ticks_position('bottom')
    plt.yticks(fontsize=fontsize)

    fig.tight_layout()

    def animate(i):
        posterior_position.set_array(likelihoods[i].ravel())
        # Line2D.set_data needs sequences: scalars are deprecated since matplotlib 3.7
        # and rejected by later releases.
        estimated_position.set_data([decoded_position.x[i]], [decoded_position.y[i]])
        rat_position.set_data([true_position.x[i]], [true_position.y[i]])

        for patch in errors_bin:
            patch.set_fc('k')
        errors_bin[errors_idx[i]-1].set_fc('r')

        for patch in n_neurons_bin:
            patch.set_fc('k')
        n_neurons_bin[n_active[i]].set_fc('b')

        likelihood_at_actual.set_text("posterior at true position: %.3f \nposterior at decoded position: %.3f " %
                                      (likelihoods_withnan[i][y_idx[i]][x_idx[i]],
                                       likelihoods_withnan[i][y_dec_idx[i]][x_dec_idx[i]]))

        return (posterior_position, estimated_position, rat_position, likelihood_at_actual)

    anim = animation.FuncAnimation(fig, animate, frames=n_timebins, interval=200,
                                   blit=False, repeat=False)

    writer = animation.writers['ffmpeg'](fps=10)
    dpi = 600
    filename = session_id+'_decoded_trial'+str(trial_idx)+'.mp4'
    anim.save(os.path.join(filepath, filename), writer=writer, dpi=dpi)
    plt.close(fig)

In [None]:
# Spatial bin sizes (cm) whose cached decoding results are analysed in the cells below.
binsizes = [12]

In [None]:
# Sanity check for the single session bound to `info`: print the number of decoded
# samples in the first trial, then the total over all trials for each bin size.
_, position, _, _, _ = get_data(info)
for binsize in binsizes:
    decoded_filename = info.session_id + '_decoded_binsize' + str(binsize) + 'cm.pkl'
    with open(os.path.join(pickle_filepath, decoded_filename), 'rb') as fileobj:
        decoded = pickle.load(fileobj)
    print(decoded["decoded"][0].n_samples)

    huh = sum(trial_decoded.n_samples for trial_decoded in decoded["decoded"])
    print(binsize, huh)

In [None]:
# Display which session the single-session cells are using.
info.session_id

In [None]:
# Scratch inspection: the "actual" position entry of the first trial in the loaded
# results (presumably the true position at each decoded time bin).
letstrythis = decoded["actual"][0]

In [None]:
# Centre of the explored area: midpoint between the extreme x and y positions.
# (The previous (max - min) / 2 is half the range, which is only the midpoint
# when the minimum happens to be 0.)
middlex = (np.max(position.x) + np.min(position.x)) / 2
middley = (np.max(position.y) + np.min(position.y)) / 2

# A constant "always at the centre" trajectory sampled at the same times as
# `position`, laid out as (n_samples, 2) like the positions built in get_decoded.
# Its distance to the recorded path serves as a naive baseline error.
data = np.column_stack([np.full(position.n_samples, middlex),
                        np.full(position.n_samples, middley)])

middle_position = nept.Position(data, position.time)

In [None]:
# Recorded path (dots) with the computed centre point in red.
fig, ax = plt.subplots()
ax.plot(position.x, position.y, ".")
ax.plot(middlex, middley, "r.")
plt.show()

In [None]:
# Mean distance from the recorded path to the fixed centre point, a naive reference
# against which decoding errors can be compared.
baseline_distance = position.distance(middle_position)
np.mean(baseline_distance)

In [None]:
# Individual errors: one box plot per session and bin size, comparing the decoding
# error with the ID-shuffled control.
for info in infos:
    for binsize in binsizes:
        pickled_decoded = os.path.join(
            pickle_filepath, info.session_id + '_decoded_binsize' + str(binsize) + 'cm.pkl')
        with open(pickled_decoded, 'rb') as fileobj:
            decoded = pickle.load(fileobj)

        shuffled_decoded = os.path.join(
            pickle_filepath, info.session_id + '_decoded-shuffled_binsize' + str(binsize) + 'cm.pkl')
        with open(shuffled_decoded, 'rb') as fileobj:
            shuffled = pickle.load(fileobj)

        filepath = os.path.join(output_filepath, "errors")
        os.makedirs(filepath, exist_ok=True)

        figure_path = os.path.join(filepath, info.session_id+"-errors_binsize"+str(binsize)+"cm.png")
        plot_errors([decoded["errors"]], [shuffled["errors"]], n_sessions=1, filename=figure_path)

In [None]:
# Combined errors: pool the errors of every session into a single box plot per bin size.
for binsize in binsizes:
    combined_decoded_errors, combined_shuffled_errors = [], []
    n_sessions = 0
    for n_sessions, info in enumerate(infos, start=1):
        with open(os.path.join(pickle_filepath,
                               info.session_id + '_decoded_binsize' + str(binsize) + 'cm.pkl'), 'rb') as fileobj:
            decoded = pickle.load(fileobj)
        combined_decoded_errors.append(decoded["errors"])

        with open(os.path.join(pickle_filepath,
                               info.session_id + '_decoded-shuffled_binsize' + str(binsize) + 'cm.pkl'), 'rb') as fileobj:
            shuffled = pickle.load(fileobj)
        combined_shuffled_errors.append(shuffled["errors"])

    filepath = os.path.join(output_filepath, "errors")
    os.makedirs(filepath, exist_ok=True)

    figure_path = os.path.join(filepath, "combined-errors_binsize"+str(binsize)+"cm.png")
    plot_errors(combined_decoded_errors, combined_shuffled_errors, n_sessions=n_sessions,
                filename=figure_path)

In [None]:
# Individual mean/median errors: for each session, plot the mean and the median
# decoding error against bin size. Shuffled-control summaries are computed as well
# (kept in the *_shuffled lists) but not plotted.
for info in infos:
    mean_errors = []
    median_errors = []

    mean_errors_shuffled = []
    median_errors_shuffled = []

    for binsize in binsizes:
        combine_errors = []
        combine_errors_shuffled = []

        decoded_filename = info.session_id + '_decoded_binsize' + str(binsize) + 'cm.pkl'
        pickled_decoded = os.path.join(pickle_filepath, decoded_filename)
        with open(pickled_decoded, 'rb') as fileobj:
            decoded = pickle.load(fileobj)

        shuffled_filename = info.session_id + '_decoded-shuffled_binsize' + str(binsize) + 'cm.pkl'
        shuffled_decoded = os.path.join(pickle_filepath, shuffled_filename)
        with open(shuffled_decoded, 'rb') as fileobj:
            shuffled = pickle.load(fileobj)

        # Pool every trial's errors before summarising.
        for error in decoded["errors"]:
            combine_errors.extend(error)
        mean_errors.append(np.mean(combine_errors))
        median_errors.append(np.median(combine_errors))

        for error in shuffled["errors"]:
            combine_errors_shuffled.extend(error)
        mean_errors_shuffled.append(np.mean(combine_errors_shuffled))
        median_errors_shuffled.append(np.median(combine_errors_shuffled))

    filepath = os.path.join(output_filepath, "errors", "average")
    if not os.path.exists(filepath):
        os.makedirs(filepath)

    # marker="o": with a single bin size a marker-less line plot draws nothing.
    filename = info.session_id+"-mean-errors.png"
    plt.plot(binsizes, mean_errors, marker="o")
    plt.xlabel("Binsize (cm)")
    plt.ylabel("Mean error")
    plt.title(info.session_id+" mean decode error (cm)")
    plt.savefig(os.path.join(filepath, filename))
    plt.close()

    filename = info.session_id+"-median-errors.png"
    plt.plot(binsizes, median_errors, marker="o")
    plt.xlabel("Binsize (cm)")
    plt.ylabel("Median error")
    plt.title(info.session_id+" median decode error (cm)")
    plt.savefig(os.path.join(filepath, filename))
    plt.close()

In [None]:
# All mean/median errors: for each bin size, average the per-session mean error
# (and the per-session median error) across sessions, then plot both against bin size.
all_mean_errors = []
all_median_errors = []
all_mean_errors_shuffled = []
all_median_errors_shuffled = []

for binsize in binsizes:
    mean_errors = []
    median_errors = []

    mean_errors_shuffled = []
    median_errors_shuffled = []
    for info in infos:
        combine_errors = []
        combine_errors_shuffled = []

        decoded_filename = info.session_id + '_decoded_binsize' + str(binsize) + 'cm.pkl'
        pickled_decoded = os.path.join(pickle_filepath, decoded_filename)
        with open(pickled_decoded, 'rb') as fileobj:
            decoded = pickle.load(fileobj)

        shuffled_filename = info.session_id + '_decoded-shuffled_binsize' + str(binsize) + 'cm.pkl'
        shuffled_decoded = os.path.join(pickle_filepath, shuffled_filename)
        with open(shuffled_decoded, 'rb') as fileobj:
            shuffled = pickle.load(fileobj)

        # Pool every trial's errors of this session before summarising.
        for error in decoded["errors"]:
            combine_errors.extend(error)
        mean_errors.append(np.mean(combine_errors))
        median_errors.append(np.median(combine_errors))

        for error in shuffled["errors"]:
            combine_errors_shuffled.extend(error)
        mean_errors_shuffled.append(np.mean(combine_errors_shuffled))
        median_errors_shuffled.append(np.median(combine_errors_shuffled))

    all_mean_errors.append(np.mean(mean_errors))
    # Mean across sessions of each session's median error.
    all_median_errors.append(np.mean(median_errors))
    all_mean_errors_shuffled.append(np.mean(mean_errors_shuffled))
    all_median_errors_shuffled.append(np.mean(median_errors_shuffled))

# Both summary figures go into errors/average, next to the per-session mean/median
# plots; the median figure was previously written to errors/ by mistake.
filepath = os.path.join(output_filepath, "errors", "average")
if not os.path.exists(filepath):
    os.makedirs(filepath)

# marker="o": with a single bin size a marker-less line plot draws nothing.
filename = "all-mean-errors.png"
plt.plot(binsizes, all_mean_errors, marker="o")
plt.xlabel("Binsize (cm)")
plt.ylabel("Mean error")
plt.title("All mean decode error (cm)")
plt.savefig(os.path.join(filepath, filename))
plt.close()

filename = "all-median-errors.png"
plt.plot(binsizes, all_median_errors, marker="o")
plt.xlabel("Binsize (cm)")
plt.ylabel("Median error")
plt.title("All median decode error (cm)")
plt.savefig(os.path.join(filepath, filename))
plt.close()

In [None]:
# Proportion decoded: for each session, the number of decoded samples divided by the
# stored "session_n_running" count, drawn as one bar chart per bin size.
for binsize in binsizes:
    proportion_decoded = []
    session_ids = []
    for info in infos:
        session_ids.append(info.session_id)
        decoded_filename = info.session_id + '_decoded_binsize' + str(binsize) + 'cm.pkl'
        pickled_decoded = os.path.join(pickle_filepath, decoded_filename)
        with open(pickled_decoded, 'rb') as fileobj:
            decoded = pickle.load(fileobj)

        n_decoded = sum(trial.n_samples for trial in decoded["decoded"])
        proportion_decoded.append(n_decoded/decoded["session_n_running"])

    filename = "proportion-decoded_binsize"+str(binsize)+"cm.png"
    filepath = os.path.join(output_filepath, "proportion")
    if not os.path.exists(filepath):
        os.makedirs(filepath)

    # One bar per session loaded in this cell. This previously used the global
    # `n_sessions` left over from another cell, which breaks when run on its own.
    y_pos = np.arange(len(session_ids))
    plt.bar(y_pos, proportion_decoded, align='center', alpha=0.7)
    plt.xticks(y_pos, session_ids, rotation=90, fontsize=10)
    plt.ylabel('Proportion')
    plt.title("Samples decoded with %d cm bins" % binsize)
    plt.tight_layout()
    plt.savefig(os.path.join(filepath, filename))
    plt.close()

In [None]:
# Individual likelihoods/errors over space: for each session and bin size, average the
# posterior and the decoding error onto the true position and save one map of each.
for info in infos:
    events, position, spikes, _, _ = get_data(info)

    for binsize in binsizes:
        xedge, yedge = nept.get_xyedges(position, binsize=binsize)
        xx, yy = np.meshgrid(xedge, yedge)

        pickled_decoded = os.path.join(
            pickle_filepath, info.session_id + '_decoded_binsize' + str(binsize) + 'cm.pkl')
        with open(pickled_decoded, 'rb') as fileobj:
            decoded = pickle.load(fileobj)

        # (output subdirectory, figure file name, per-sample values, title)
        map_specs = (
            ("likelihoods", info.session_id+"-likelihoods_byactual-"+str(binsize)+"cm.png",
             decoded["likelihoods"], info.session_id+" posterior"),
            ("errors", info.session_id+"-errors_byactual-"+str(binsize)+"cm.png",
             decoded["errors"], info.session_id+" decoding error (cm)"),
        )
        for subdir, filename, per_sample_values, title in map_specs:
            filepath = os.path.join(output_filepath, subdir)
            if not os.path.exists(filepath):
                os.makedirs(filepath)

            spatial_map = plot_over_space(per_sample_values, decoded["actual"], xedge, yedge)
            pp = plt.pcolormesh(xx, yy, spatial_map, vmin=0., cmap='bone_r')
            plt.colorbar(pp)
            plt.title(title)
            plt.axis('off')
            plt.savefig(os.path.join(filepath, filename))
            plt.close()

In [None]:
# Animations: render the decoded posterior of selected trials of every session as mp4.
binsize = 8
for info in infos:
    pickled_decoded = os.path.join(
        pickle_filepath, info.session_id + '_decoded_binsize' + str(binsize) + 'cm.pkl')
    with open(pickled_decoded, 'rb') as fileobj:
        decoded = pickle.load(fileobj)

    position = get_data(info)[1]
    xedge, yedge = nept.get_xyedges(position, binsize=binsize)

    filepath = os.path.join(output_filepath, "animations")
    os.makedirs(filepath, exist_ok=True)
    for trial_idx in [1]:
        make_animation(info.session_id, decoded, trial_idx=trial_idx,
                       xedge=xedge, yedge=yedge, filepath=filepath)

In [None]:
# Deliberate stop for "Run All": the cells below are exploratory scratch work
# (they re-define helpers and re-run decoding) and are meant to be run by hand.
raise RuntimeError("Intentional stop: the cells below are exploratory and meant to be run manually.")

In [None]:
def get_only_tuning_curves(position, spikes, xedges, yedges, epoch_of_interest):
    """Compute 2D tuning curves from the running periods inside ``epoch_of_interest``.

    Local re-definition that shadows the ``get_only_tuning_curves`` imported from
    ``analyze_tuning_curves``.

    Parameters
    ----------
    position : nept position object, sliced to the epoch and to running periods.
    spikes : list of spike trains, sliced the same way.
    xedges, yedges : np.ndarray
        Spatial bin edges passed to ``nept.tuning_curve_2d``.
    epoch_of_interest : epoch with ``.start`` / ``.stop``.

    Returns
    -------
    Whatever ``nept.tuning_curve_2d`` returns (one curve per spike train).
    """
    epoch_start, epoch_stop = epoch_of_interest.start, epoch_of_interest.stop
    epoch_position = position.time_slice(epoch_start, epoch_stop)

    # Keep only running periods (speed threshold 10., smoothing 0.8 — units as nept defines them).
    running = nept.run_threshold(epoch_position, thresh=10., t_smooth=0.8)
    running_position = epoch_position[running]

    running_spikes = []
    for spiketrain in spikes:
        epoch_spikes = spiketrain.time_slice(epoch_start, epoch_stop)
        running_spikes.append(epoch_spikes.time_slice(running.starts, running.stops))

    return nept.tuning_curve_2d(running_position, running_spikes, xedges, yedges,
                                occupied_thresh=0.5, gaussian_std=0.3)

In [None]:
# Plot the first three tuning curves (phase 3, running periods only) of whichever
# session `info` is currently bound to, once per bin size.
# The colour map is defined here: this cell used `n_colours` and `cmap` before the
# next cell created them, so it raised a NameError on a fresh kernel.
n_colours = 15
colours = [(1., 1., 1.)]
colours.extend(matplotlib.cm.copper_r(np.linspace(0, 1, n_colours-1)))
cmap = matplotlib.colors.ListedColormap(colours)

# Session data and the phase epoch do not depend on bin size; load them once.
events, position, spikes, _, _ = get_data(info)
epoch_of_interest = info.task_times["phase3"]

for binsize in binsizes:
    xedge, yedge = nept.get_xyedges(position, binsize=binsize)
    xx, yy = np.meshgrid(xedge, yedge)

    tuning_curves = get_only_tuning_curves(position, spikes, xedge, yedge, epoch_of_interest)

    for i, tuning_curve in enumerate(tuning_curves[:3]):
        tuning_curve = np.array(tuning_curve)
        # NaN bins get a small negative value so they fall into the white lowest colour.
        tuning_curve[np.isnan(tuning_curve)] = -np.nanmax(tuning_curve) / n_colours

        plt.figure()
        pp = plt.pcolormesh(xx, yy, tuning_curve, cmap=cmap)
        plt.colorbar(pp)
        plt.tight_layout()
        plt.show()

In [None]:
# Sum the tuning curves from the previous cell into a single map, normalise it so its
# non-NaN entries sum to 1, and plot it with NaN bins drawn in white.
n_colours = 15
colours = [(1., 1., 1.)] + list(matplotlib.cm.copper_r(np.linspace(0, 1, n_colours-1)))
cmap = matplotlib.colors.ListedColormap(colours)

map_shape = (tuning_curves.shape[1], tuning_curves.shape[2])
multiple_tuning_curves = sum(tuning_curves, np.zeros(map_shape))

multiple_tuning_curves = np.array(multiple_tuning_curves / np.nansum(multiple_tuning_curves))
nan_fill = -np.nanmax(multiple_tuning_curves) / n_colours
multiple_tuning_curves[np.isnan(multiple_tuning_curves)] = nan_fill

plt.figure()
pp = plt.pcolormesh(xx, yy, multiple_tuning_curves, cmap=cmap)
plt.colorbar(pp)
plt.title(info.session_id + " tuning curves (normalized)")
plt.tight_layout()
plt.show()

In [None]:
def get_decoded(info, position, spikes, xedges, yedges, shuffled_id, events=None):
    """Decode position trial by trial during phase 3 of a session.

    For every phase-3 trial, tuning curves are built on ``phase.excludes(trial)``
    (presumably phase 3 without the held-out trial). Position is then decoded in
    0.025-wide time bins during the running periods of that trial.

    Parameters
    ----------
    info : module
        Session info; ``info.task_times["phase3"]`` is the phase epoch.
    position : nept position object for the whole session.
    spikes : list of spike trains for the whole session.
    xedges, yedges : np.ndarray
        Spatial bin edges for tuning curves and decoding.
    shuffled_id : bool
        If True, permute the tuning curves along their first axis (neuron-identity
        shuffle control) using the global, unseeded NumPy RNG.
    events : optional
        Session events passed to ``get_trials``. Previously read from a global
        variable; when omitted they are loaded with ``get_data(info)``.

    Returns
    -------
    tuple
        ``(decoded, actual, likelihoods, errors, n_active)`` as per-trial lists,
        followed by ``session_n_running``, the total number of running position
        samples over all trials.
    """
    # Explicit submodule import: a bare `import scipy` does not guarantee that
    # scipy.interpolate is loaded on older SciPy versions.
    from scipy.interpolate import interp1d

    if events is None:
        events, _, _, _, _ = get_data(info)

    phase = info.task_times["phase3"]
    trials = get_trials(events, phase)

    min_neurons = 2
    min_spikes = 2

    # Bin centres are identical for every trial, so compute them once.
    xcenters = (xedges[1:] + xedges[:-1]) / 2.
    ycenters = (yedges[1:] + yedges[:-1]) / 2.
    xy_centers = nept.cartesian(xcenters, ycenters)

    session_n_active = []
    session_likelihoods = []
    session_decoded = []
    session_actual = []
    session_errors = []
    session_n_running = 0

    for trial in trials:
        epoch_of_interest = phase.excludes(trial)

        tuning_curves = get_only_tuning_curves(position,
                                               spikes,
                                               xedges,
                                               yedges,
                                               epoch_of_interest)

        if shuffled_id:
            tuning_curves = np.random.permutation(tuning_curves)

        sliced_position = position.time_slice(trial.start, trial.stop)
        sliced_spikes = [spiketrain.time_slice(trial.start, trial.stop) for spiketrain in spikes]

        # Limit position to times when the subject moves faster than the threshold.
        run_epoch = nept.run_threshold(sliced_position, thresh=8., t_smooth=0.8)
        sliced_position = sliced_position[run_epoch]

        session_n_running += sliced_position.n_samples

        sliced_spikes = [spiketrain.time_slice(run_epoch.start, run_epoch.stop)
                         for spiketrain in sliced_spikes]

        counts = nept.bin_spikes(sliced_spikes, sliced_position.time, dt=0.025, window=0.025,
                                 gaussian_std=0.0075, normalized=False)

        # Flatten each tuning curve's spatial grid into a single axis for decoding.
        tc_shape = tuning_curves.shape
        decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

        likelihood = nept.bayesian_prob(counts, decoding_tc, binsize=0.025, min_neurons=min_neurons,
                                        min_spikes=min_spikes)

        # Decoded location = spatial bin with maximum likelihood for each valid time bin.
        decoded = nept.decode_location(likelihood, xy_centers, counts.time)
        session_decoded.append(decoded)

        # Drop all-NaN time bins and restore the 2D spatial shape for plotting.
        keep_idx = np.sum(np.isnan(likelihood), axis=1) < likelihood.shape[1]
        likelihood = likelihood[keep_idx]
        likelihood = likelihood.reshape(np.shape(likelihood)[0], tc_shape[1], tc_shape[2])
        session_likelihoods.append(likelihood)

        # Active-neuron count per kept time bin; counts below min_neurons are reported as 0.
        n_active_neurons = np.asarray([n_active if n_active >= min_neurons else 0
                                       for n_active in np.sum(counts.data >= 1, axis=1)])
        n_active_neurons = n_active_neurons[keep_idx]
        session_n_active.append(n_active_neurons)

        # True position at each decoded time bin (nearest recorded sample).
        f_xy = interp1d(sliced_position.time, sliced_position.data.T, kind="nearest")
        counts_xy = f_xy(decoded.time)
        true_position = nept.Position(np.hstack((counts_xy[0][..., np.newaxis],
                                                 counts_xy[1][..., np.newaxis])),
                                      decoded.time)
        session_actual.append(true_position)

        session_errors.append(true_position.distance(decoded))

    return session_decoded, session_actual, session_likelihoods, session_errors, session_n_active, session_n_running

In [None]:
def plot_over_space(values, positions, xedges, yedges):
    """Average per-sample values into the spatial bin nearest each position.

    Parameters
    ----------
    values : list
        One entry per trial, aligned sample-for-sample with ``positions``.
        Each item is either a scalar (e.g. a decoding error), accumulated as
        is, or a 2D array indexed ``[y_idx][x_idx]`` (e.g. a posterior), in
        which case only the element at the sample's own bin is accumulated.
    positions : list of nept.Position
        One per trial.
    xedges, yedges : np.ndarray
        Bin edges; each sample is assigned to the nearest bin centre.

    Returns
    -------
    np.ndarray
        Shape ``(len(yedges), len(xedges))``, i.e. one larger than the number
        of bin centres in each dimension, so the last row and column never
        receive samples. Each visited bin holds the mean of its values;
        unvisited bins are 0.
    """
    xcenters = xedges[:-1] + (xedges[1:] - xedges[:-1]) / 2
    ycenters = yedges[:-1] + (yedges[1:] - yedges[:-1]) / 2

    # Edge-sized arrays are kept for compatibility with existing callers.
    sum_position = np.zeros((len(yedges), len(xedges)))
    n_position = np.zeros((len(yedges), len(xedges)))

    for trial_values, trial_positions in zip(values, positions):
        for these_values, x, y in zip(trial_values, trial_positions.x, trial_positions.y):
            x_idx = nept.find_nearest_idx(xcenters, x)
            y_idx = nept.find_nearest_idx(ycenters, y)
            if np.isscalar(these_values):
                sum_position[y_idx][x_idx] += these_values
            else:
                sum_position[y_idx][x_idx] += these_values[y_idx][x_idx]
            n_position[y_idx][x_idx] += 1

    # Divide only where samples exist. Empty bins stay 0, and visited bins get
    # a true mean (sum / n) instead of sum / (n + 1).
    mean_position = np.zeros_like(sum_position)
    np.divide(sum_position, n_position, out=mean_position, where=n_position > 0)
    return mean_position

In [None]:
# Decode every session at `binsize` cm, with and without shuffled cell IDs,
# and save per-session and combined error box plots.
binsize = 8
n_sessions = 0
session_ids = []
xedges = []
yedges = []

all_decoded = []
all_actual = []
all_likelihoods = []
all_n_active = []

all_errors = []
all_errors_id_shuffled = []

os.makedirs(output_filepath, exist_ok=True)

for info in infos:
    print(info.session_id)
    session_ids.append(info.session_id)
    n_sessions += 1
    events, position, spikes, _, _ = get_data(info)

    xedge, yedge = nept.get_xyedges(position, binsize=binsize)
    xedges.append(xedge)
    yedges.append(yedge)

    decoded, actual, likelihoods, errors, n_active, _ = get_decoded(info,
                                             position,
                                             spikes,
                                             xedge,
                                             yedge,
                                             shuffled_id=False)

    # Same decoding with shuffled_id=True: the "ID shuffled" baseline errors.
    _, _, _, errors_id_shuffled, _, _ = get_decoded(info,
                                             position,
                                             spikes,
                                             xedge,
                                             yedge,
                                             shuffled_id=True)

    all_decoded.append(decoded)
    all_actual.append(actual)
    all_likelihoods.append(likelihoods)
    all_n_active.append(n_active)

    all_errors.append(errors)
    all_errors_id_shuffled.append(errors_id_shuffled)

    filename = os.path.join(output_filepath, info.session_id+"_errors-binsize"+str(binsize)+".png")
    # plot_errors expects a list of sessions, each a list of per-trial error
    # arrays (it concatenates twice), so wrap this single session in a list.
    plot_errors([errors], [errors_id_shuffled], n_sessions=1, filename=filename)

filename = os.path.join(output_filepath, "combined_errors-binsize"+str(binsize)+".png")
plot_errors(all_errors, all_errors_id_shuffled, n_sessions, filename)

In [None]:
# Proportion of samples decoded per session, with optional likelihood/error
# maps over space.
# Results use their own names so this cell does not reset the all_* lists,
# session_ids, xedges/yedges or n_sessions filled by the decoding cell above.
# The animation cell below reads those.
binsize = 8
# Set True to also save posterior and error maps averaged by true position.
plot_by_position = False

proportion_session_ids = []
proportion_decoded = []

for info in infos:
    print(info.session_id)
    proportion_session_ids.append(info.session_id)
    events, position, spikes, _, _ = get_data(info)

    xedge, yedge = nept.get_xyedges(position, binsize=binsize)

    decoded, actual, likelihoods, errors, n_active, session_n_running = get_decoded(info,
                                             position,
                                             spikes,
                                             xedge,
                                             yedge,
                                             shuffled_id=False)

    if plot_by_position:
        xx, yy = np.meshgrid(xedge, yedge)
        for values, label, suffix in [(likelihoods, " posterior", "_posterior-byactual.png"),
                                      (errors, " decoding error (cm)", "_errors-byactual.png")]:
            values_byactual = plot_over_space(values, actual, xedge, yedge)
            # plot_over_space returns an edge-sized array whose last row and
            # column never receive samples. Drop them so C is one smaller
            # than the edge mesh, as flat shading expects.
            pp = plt.pcolormesh(xx, yy, values_byactual[:-1, :-1], vmin=0., cmap='bone_r')
            plt.colorbar(pp)
            plt.title(info.session_id+label)
            plt.axis('off')
            plt.savefig(os.path.join(output_filepath, info.session_id+suffix))
            plt.close()

    n_decoded = sum(trial.n_samples for trial in decoded)
    # session_n_running comes from get_decoded (presumably the number of
    # running samples). Guard against dividing by zero.
    proportion_decoded.append(n_decoded/session_n_running if session_n_running else np.nan)

print(proportion_decoded)
y_pos = np.arange(len(proportion_decoded))
plt.figure()
plt.bar(y_pos, proportion_decoded, align='center', alpha=0.7)
plt.xticks(y_pos, proportion_session_ids, rotation=90, fontsize=10)
plt.ylabel('Proportion')
plt.title("Samples decoded with %d cm bins" % binsize)
plt.tight_layout()
# To save instead, call this before plt.show():
# plt.savefig(os.path.join(output_filepath, "proportion-decoded_"+str(binsize)+"cm.png"))
plt.show()

In [None]:
# Individual session errors by binsize: decode each session at every size in
# `binsizes` and save its mean and median error curves.
binsizes = [2, 5, 8, 10, 12, 15, 20, 50, 100]
os.makedirs(output_filepath, exist_ok=True)

for info in infos:

    print(info.session_id)
    mean_errors = []
    median_errors = []

    # A dedicated loop name prevents the global `binsize`, which later cells
    # read, from being left at the last entry of `binsizes`.
    for sweep_binsize in binsizes:
        events, position, spikes, _, _ = get_data(info)

        xedge, yedge = nept.get_xyedges(position, binsize=sweep_binsize)

        decoded, actual, likelihoods, errors, n_active, session_n_running = get_decoded(info,
                                                 position,
                                                 spikes,
                                                 xedge,
                                                 yedge,
                                                 shuffled_id=False)

        # Pool trial errors for THIS bin size only, so each point on the
        # curve is not contaminated by errors from earlier bin sizes.
        combined_errors = []
        for error in errors:
            combined_errors.extend(error)
        mean_errors.append(np.mean(combined_errors))
        median_errors.append(np.median(combined_errors))

    plt.plot(binsizes, mean_errors)
    plt.xlabel("Binsize (cm)")
    plt.ylabel("Mean error")
    plt.title(info.session_id+" mean decode error (cm)")
    plt.savefig(os.path.join(output_filepath, info.session_id+"_mean-error.png"))
    plt.close()

    plt.plot(binsizes, median_errors)
    plt.xlabel("Binsize (cm)")
    plt.ylabel("Median error")
    plt.title(info.session_id+" median decode error (cm)")
    plt.savefig(os.path.join(output_filepath, info.session_id+"_median-error.png"))
    plt.close()

In [None]:
# Per-binsize mean decoding errors (cm) for the last session in `infos`,
# one entry per value in `binsizes`.
mean_errors # NOTE(review): original note read "100" (presumably the largest bin size); confirm

In [None]:
# Smallest mean error across the bin-size sweep (np.min is an alias of np.amin).
np.min(mean_errors)

In [None]:
# Trajectory of the last session loaded above, labelled so the figure stands
# alone. plt.show() also suppresses the Line2D repr.
plt.plot(position.x, position.y)
plt.xlabel("x")
plt.ylabel("y")
plt.title(info.session_id + " position")
plt.show()

In [None]:
# Animate decoding for the first two trials of every session. Each figure has:
# - the posterior, with the true (blue) and decoded (cyan) positions;
# - a histogram of the trial's decoding errors;
# - a histogram of active-neuron counts.
# The current time bin is highlighted in both histograms.
# NOTE: `animate` reads this cell's module-level names when it is called, so
# only the last `anim` created (rendered by the next cell) shows consistent data.
for trial_idx in range(2):
    for session_idx in range(n_sessions):
        decoded = all_decoded[session_idx][trial_idx]
        true_position = all_actual[session_idx][trial_idx]
        likelihoods = np.array(all_likelihoods[session_idx][trial_idx])
        n_active = all_n_active[session_idx][trial_idx]
        errors = all_errors[session_idx][trial_idx]
        xedge = xedges[session_idx]
        yedge = yedges[session_idx]

        fig = plt.figure(figsize=(12, 10))

        xx, yy = np.meshgrid(xedge, yedge)

        ax1 = plt.subplot2grid((5, 4), (0, 0), colspan=3, rowspan=3)

        # Pad by two bins of this session's own grid rather than the global
        # `binsize`, which other cells may leave at a different value.
        pad_amount = 2 * (xedge[1] - xedge[0])
        ax1.set_xlim((np.floor(np.min(true_position.x))-pad_amount, np.ceil(np.max(true_position.x))+pad_amount))
        ax1.set_ylim((np.floor(np.min(true_position.y))-pad_amount, np.ceil(np.max(true_position.y))+pad_amount))

        n_timebins = decoded.n_samples

        # White lowest colour (the -0.01 fill for NaN bins below renders as
        # background), then a copper ramp. np.linspace needs an integer count.
        n_colours = 20
        colours = [(1., 1., 1.)]
        colours.extend(matplotlib.cm.copper_r(np.linspace(0, 1, n_colours-1)))
        cmap = matplotlib.colors.ListedColormap(colours)

        # Unmodified copy for the text readout; the plotted array has NaNs filled.
        likelihoods_withnan = np.array(likelihoods)
        likelihoods[np.isnan(likelihoods)] = -0.01

        xcenters = xedge[:-1] + (xedge[1:] - xedge[:-1]) / 2
        ycenters = yedge[:-1] + (yedge[1:] - yedge[:-1]) / 2

        # Bin indices of the true and decoded positions at every time bin.
        x_idx = [nept.find_nearest_idx(xcenters, x) for x in true_position.x]
        y_idx = [nept.find_nearest_idx(ycenters, y) for y in true_position.y]

        x_dec_idx = [nept.find_nearest_idx(xcenters, x) for x in decoded.x]
        y_dec_idx = [nept.find_nearest_idx(ycenters, y) for y in decoded.y]

        # The posterior has one value per bin, so use the full edge mesh, which
        # is one larger than the posterior in each dimension.
        posterior_position = ax1.pcolormesh(xx, yy, likelihoods[0], vmax=0.2, cmap=cmap)
        colorbar = fig.colorbar(posterior_position, ax=ax1)

        estimated_position, = ax1.plot([], [], "o", color="c")
        rat_position, = ax1.plot([], [], "<", color="b")

        ax2 = plt.subplot2grid((5, 4), (3, 0), colspan=3)

        binwidth = 5.
        error_bins = np.arange(-binwidth, np.max(errors)+binwidth, binwidth)

        _, _, errors_bin = ax2.hist([np.clip(errors, error_bins[0], error_bins[-1])], bins=error_bins, rwidth=0.9, color="k")
        # Bar index of each time bin's error. np.digitize puts a value equal to
        # the last edge one past the end, while hist counts it in the last bar,
        # so clamp into the valid range.
        errors_idx = np.clip(np.digitize(errors, error_bins) - 1, 0, len(errors_bin) - 1)

        fontsize = 14
        likelihood_at_actual = ax2.text(0.6, 1, "",
                 horizontalalignment='left',
                 verticalalignment='top',
                 transform = ax2.transAxes,
                 fontsize=fontsize)

        ax2.set_xlabel("Error (cm)", fontsize=fontsize)
        ax2.set_ylabel("# bins", fontsize=fontsize)
        ax2.spines['right'].set_visible(False)
        ax2.spines['top'].set_visible(False)
        ax2.yaxis.set_ticks_position('left')
        ax2.xaxis.set_ticks_position('bottom')
        xticks = binwidth * np.arange(0, len(error_bins), 4)
        plt.xticks(xticks, fontsize=fontsize)
        plt.yticks(fontsize=fontsize)

        ax3 = plt.subplot2grid((5, 4), (4, 0), colspan=3)

        # One unit-wide bar per neuron count, centred on the integer.
        n_active_bins = np.arange(-0.5, np.max(n_active)+1)

        _, _, n_neurons_bin = ax3.hist(n_active, bins=n_active_bins, rwidth=0.9, color="k", align="mid")

        ax3.set_xlabel("Number of active neurons", fontsize=fontsize)
        ax3.set_ylabel("# bins", fontsize=fontsize)
        ax3.spines['right'].set_visible(False)
        ax3.spines['top'].set_visible(False)
        ax3.yaxis.set_ticks_position('left')
        ax3.xaxis.set_ticks_position('bottom')
        plt.yticks(fontsize=fontsize)

        fig.tight_layout()

        def animate(i):
            """Draw time bin ``i``: posterior, markers, highlighted bars and readout."""
            posterior_position.set_array(likelihoods[i].ravel())
            # Line2D.set_data takes sequences; newer matplotlib rejects scalars.
            estimated_position.set_data([decoded.x[i]], [decoded.y[i]])
            rat_position.set_data([true_position.x[i]], [true_position.y[i]])

            for patch in errors_bin:
                patch.set_fc('k')
            errors_bin[errors_idx[i]].set_fc('r')

            for patch in n_neurons_bin:
                patch.set_fc('k')
            n_neurons_bin[n_active[i]].set_fc('b')

            likelihood_at_actual.set_text("posterior at true position: %.3f \nposterior at decoded position: %.3f " % 
                                          (likelihoods_withnan[i][y_idx[i]][x_idx[i]], 
                                           likelihoods_withnan[i][y_dec_idx[i]][x_dec_idx[i]]))

            return (posterior_position, estimated_position, rat_position, likelihood_at_actual)

        anim = animation.FuncAnimation(fig, animate, frames=n_timebins, interval=200, 
                                       blit=False, repeat=False)

        # To save each animation instead of displaying only the last one:
        # writer = animation.writers['ffmpeg'](fps=10)
        # dpi = 600
        # filename = session_ids[session_idx]+'_decoded_trial'+str(trial_idx)+'.mp4'
        # anim.save(os.path.join(output_filepath, filename), writer=writer, dpi=dpi)
        # plt.close()

In [None]:
print("Blue triangle is the true position; Cyan circle is the estimated location.")
# Encode the most recently created animation (last session, second trial) with
# matplotlib's configured movie writer, then render it inline.
anim_video = anim.to_html5_video()
HTML(anim_video)

In [None]:
# Reload the last session in `infos` and bin it at 8 cm for the occupancy
# analysis below.
session_data = get_data(info)
events, position, spikes, _, _ = session_data
xedge, yedge = nept.get_xyedges(position, binsize=8)

In [None]:
# Phase-1 positions and their occupancy on the grid from the cell above.
# get_occupancy receives the y edges before the x edges.
phase1_epoch = info.task_times["phase1"]
phase1_position = position.time_slice(phase1_epoch.start, phase1_epoch.stop)
phase1_occupancy = nept.get_occupancy(phase1_position, yedge, xedge)

In [None]:
# Phase-3 positions and their occupancy on the same grid as phase 1.
phase3_epoch = info.task_times["phase3"]
phase3_position = position.time_slice(phase3_epoch.start, phase3_epoch.stop)
phase3_occupancy = nept.get_occupancy(phase3_position, yedge, xedge)

In [None]:
# Both occupancy grids come from the same edges, so their shapes are expected to match.
(phase3_occupancy.shape, phase1_occupancy.shape)

In [None]:
# Occupancy values of the phase-3 bins that were visited at all.
phase3_visited = phase3_occupancy > 0
temp = phase3_occupancy[phase3_visited]

In [None]:
# Smallest non-zero phase-3 occupancy value.
temp.min()

In [None]:
# Phase-1 positions with the u / shortcut / novel zone outlines overlaid.
# NOTE(review): `zones` is not defined in this part of the notebook; it is
# presumably created in an earlier cell; confirm.
plt.plot(phase1_position.x, phase1_position.y, ".")
for zone_name, zone_colour in (("u", 'b'), ("shortcut", 'g'), ("novel", 'r')):
    zone_x, zone_y = zones[zone_name].exterior.xy
    plt.plot(zone_x, zone_y, zone_colour, lw=1)
plt.show()

In [None]:
# All-False boolean masks over the binned maze, filled in by the next cell.
binned_maze_shape = phase1_occupancy.shape
u_pos = np.zeros(binned_maze_shape, dtype=bool)
novel_pos = np.zeros(binned_maze_shape, dtype=bool)

In [None]:
# u_pos: bins occupied at all during phase 1.
# novel_pos: bins with phase-3 occupancy above 4 that are not in u_pos.
# NOTE(review): the 4. threshold is in get_occupancy's units (not shown here); confirm.
u_pos |= phase1_occupancy > 0
novel_pos |= (phase3_occupancy > 4.) & ~u_pos

In [None]:
xx, yy = np.meshgrid(xedge, yedge)

# Colour scale saturates at 10 so sparsely visited bins stay visible.
plt.figure()
pp = plt.pcolormesh(xx, yy, phase1_occupancy, vmax=10., cmap="Greys")
colourbar = plt.colorbar(pp)
plt.title("Phase 1 occupancy")
plt.tight_layout()

In [None]:
xx, yy = np.meshgrid(xedge, yedge)

plt.figure()
pp = plt.pcolormesh(xx, yy, u_pos, cmap="Greys")
colourbar = plt.colorbar(pp)
plt.title("Bins occupied in phase 1 (u_pos)")
plt.tight_layout()

In [None]:
xx, yy = np.meshgrid(xedge, yedge)

# Same saturation as the phase-1 plot so the two are directly comparable.
plt.figure()
pp = plt.pcolormesh(xx, yy, phase3_occupancy, vmax=10., cmap="Greys")
colourbar = plt.colorbar(pp)
plt.title("Phase 3 occupancy")
plt.tight_layout()

In [None]:
# Split phase-3 trials into u / shortcut / novel epochs using get_trial_idx.
phase = "phase3"
t_start = info.task_times[phase].start
t_stop = info.task_times[phase].stop

sliced_pos = position.time_slice(t_start, t_stop)

# Feeder events strictly inside the phase window.
feeder1_times = [event_time for event_time in events['feeder1'] if t_start < event_time < t_stop]
feeder2_times = [event_time for event_time in events['feeder2'] if t_start < event_time < t_stop]

path_pos = get_zones(info, sliced_pos)

trials_idx, trial_epochs = get_trial_idx(path_pos['u'].time,
                                         path_pos['shortcut'].time,
                                         path_pos['novel'].time,
                                         feeder1_times,
                                         feeder2_times,
                                         t_stop)

epochs_by_path = {path: [trial_epochs[idx] for idx in trials_idx[path]]
                  for path in ("shortcut", "u", "novel")}
shortcut_epochs = epochs_by_path["shortcut"]
u_epochs = epochs_by_path["u"]
novel_epochs = epochs_by_path["novel"]

In [None]:
def point_in_zones(position, zones):
    """Assign each position sample to the first zone that contains it.

    Zones are tested in the order u, shortcut, novel; a sample contained in
    none of them is put in "other".

    Parameters
    ----------
    position : nept.Position
    zones : dict
        With u, shortcut, novel, pedestal as keys; only u, shortcut and novel
        are tested here. Values must provide ``.contains(point)``.

    Returns
    -------
    sorted_zones : dict
        With u, shortcut, novel, other as keys, each a nept.Position object
        holding its samples in their original order.

    """
    # NOTE(review): `Point` (presumably shapely.geometry.Point) is not imported
    # in this part of the notebook; confirm it is imported earlier.
    tested_zones = ('u', 'shortcut', 'novel')
    all_keys = tested_zones + ('other',)
    buckets = {name: ([], []) for name in all_keys}

    if not position.isempty:
        for x, y, time in zip(position.x, position.y, position.time):
            point = Point([x, y])
            owner = next((name for name in tested_zones if zones[name].contains(point)), 'other')
            coords, times = buckets[owner]
            coords.append([x, y])
            times.append(time)

    sorted_zones = dict()
    for name in all_keys:
        coords, times = buckets[name]
        sorted_zones[name] = nept.Position(coords, times)

    return sorted_zones

In [None]:
# Posterior value at the true position for one time bin of one trial of the
# last session, read from the decoding cell's all_* results. This keeps the
# grid, posterior and position from the same decode.
session_idx = n_sessions - 1
trial_idx = 1
timestep = 0

trial_actual = all_actual[session_idx][trial_idx]
trial_likelihoods = all_likelihoods[session_idx][trial_idx]
# xedges/yedges hold one edge array per session; pick this session's first.
session_xedge = xedges[session_idx]
session_yedge = yedges[session_idx]

xcenters = session_xedge[:-1] + (session_xedge[1:] - session_xedge[:-1]) / 2
ycenters = session_yedge[:-1] + (session_yedge[1:] - session_yedge[:-1]) / 2

x_idx = [nept.find_nearest_idx(xcenters, x) for x in trial_actual.x]
y_idx = [nept.find_nearest_idx(ycenters, y) for y in trial_actual.y]

print(trial_likelihoods[timestep][y_idx[timestep]][x_idx[timestep]])

In [None]:
# Posterior at `timestep` of trial `trial_idx` for the last session, with the
# true position overlaid as a red dot. Grid, posterior and position all come
# from the same decode in the all_* results.
plot_session_idx = n_sessions - 1
plot_xx, plot_yy = np.meshgrid(xedges[plot_session_idx], yedges[plot_session_idx])
trial_posterior = all_likelihoods[plot_session_idx][trial_idx]
trial_true_position = all_actual[plot_session_idx][trial_idx]

plt.figure()
# Pass the full edge mesh, which has one value per bin in the posterior. A
# string cmap avoids plt.cm.get_cmap, which newer matplotlib removed.
pp = plt.pcolormesh(plot_xx, plot_yy, trial_posterior[timestep], cmap='bone_r')
plt.plot(trial_true_position.x[timestep], trial_true_position.y[timestep], "r.", ms=10)
colorbar = plt.colorbar(pp)
plt.show()

In [None]:
# combined_errors = []
# binsizes = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
# for binsize in binsizes:
#     n_sessions = 0
#     xxs = []
#     yys = []

#     all_decoded = []
#     all_actual = []
#     all_likelihoods = []
#     all_n_active = []

#     all_errors = []


#     for info in infos:
#         print(info.session_id)
#         n_sessions += 1
#         events, position, spikes, _, _ = get_data(info)
#         xedges, yedges = nept.get_xyedges(position, binsize=binsize)

#         xx, yy = np.meshgrid(xedges, yedges)
#         xxs.append(xx)
#         yys.append(yy)

#         phase = "phase3"
#         trial_epochs = get_trials(events, info.task_times[phase])

#         decoded, actual, likelihoods, errors, n_active = get_decoded(info, 
#                                                  position, 
#                                                  spikes, 
#                                                  xedges, 
#                                                  yedges, 
#                                                  shuffled_id=False, 
#                                                  random_shuffle=False)

#         all_decoded.append(decoded)
#         all_actual.append(actual)
#         all_likelihoods.append(likelihoods)
#         all_n_active.append(n_active)

#         all_errors.append(errors)

#     combined_errors.append(np.concatenate([np.concatenate(errors, axis=0) for errors in all_errors], axis=0))
    
# plt.plot(binsizes, np.mean(combined_errors, axis=1))
# plt.xticks(binsizes)
# plt.xlabel("Binsize (cm)")
# plt.show()