In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib
import numpy as np
import itertools
import scipy
import pandas as pd
import seaborn as sns
from scipy import stats
import os
import nept

from matplotlib import animation, rc
from IPython.display import HTML

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from analyze_decode_bytrial import decode_trial
from analyze_decode import get_decoded_zones
from utils_maze import find_zones, get_trials

In [None]:
# All cache and output locations are resolved relative to the directory the
# notebook is running in.
thisdir = os.getcwd()
pickle_filepath, output_filepath = (
    os.path.join(thisdir, *subdirs)
    for subdirs in (("cache", "pickled"), ("plots", "decode-video"))
)

In [None]:
# Session metadata modules. `infos` is the list of sessions analysed by the
# decoding loop below; swap in `spike_sorted_infos` (presumably every
# spike-sorted session defined in run.py — confirm) for a full run.
from run import spike_sorted_infos
import info.r063d5 as r063d5
import info.r063d6 as r063d6
infos = [r063d5, r063d6]
# infos = spike_sorted_infos

In [None]:
def get_decoded(info, position, spikes, xedges, yedges, shuffled_id, random_shuffle,
                events=None):
    """Decode position during running on each phase3 trial of one session.

    For every trial, tuning curves are computed from ``phase.excludes(trial)``
    (presumably the phase3 time outside that trial). Spikes are binned only
    while the subject moves faster than a speed threshold
    (``nept.run_threshold``). Position is decoded as the spatial bin centre
    with maximum likelihood.

    Parameters
    ----------
    info : session info module
        Must provide ``task_times["phase3"]``.
    position : nept.Position
        Position for the whole session.
    spikes : list
        Spike trains for the whole session (objects providing ``time_slice``).
    xedges, yedges : np.ndarray
        Spatial bin edges for the tuning curves and the decoding grid.
    shuffled_id : bool
        If True, permute the tuning curves along their first axis (neurons);
        this is the ID-shuffle control.
    random_shuffle : bool
        If True, replace the decoded x and y with values drawn independently,
        with replacement, from that trial's decoded x and y; this is the
        random-position control.
    events : optional
        Session events passed to ``get_trials``. When None they are taken from
        ``get_data(info)``. Pass them in to avoid reloading the session.

    Returns
    -------
    tuple of five lists, one entry per trial
        ``(decoded, actual, likelihoods, errors, n_active)``:
        - decoded: the decoded nept.Position;
        - actual: the true position sampled (nearest sample) at the decoded times;
        - likelihoods: reshaped to ``(n_bins, tc_shape[1], tc_shape[2])``;
        - errors: the per-bin decoding error;
        - n_active: the number of neurons with at least one spike per bin.
        ``likelihoods`` and ``n_active`` keep only time bins whose likelihood
        is not entirely NaN, so they are index-aligned.
    """
    from scipy.interpolate import interp1d

    min_n_spikes = 100
    max_n_spikes = 5000
    dt = 0.025  # time-bin width for the spike counts and the decoder

    if events is None:
        events, _, _, _, _ = get_data(info)

    phase = info.task_times["phase3"]
    trials = get_trials(events, phase)

    # The decoding grid (spatial bin centres) is identical for every trial.
    xcenters = (xedges[1:] + xedges[:-1]) / 2.
    ycenters = (yedges[1:] + yedges[:-1]) / 2.
    xy_centers = nept.cartesian(xcenters, ycenters)

    session_decoded = []
    session_actual = []
    session_likelihoods = []
    session_errors = []
    session_n_active = []

    for trial in trials:
        epoch_of_interest = phase.excludes(trial)
        tuning_curves = get_only_tuning_curves(position, spikes, xedges, yedges,
                                               epoch_of_interest, min_n_spikes, max_n_spikes)

        if shuffled_id:
            # Control: shuffle which neuron each tuning curve belongs to.
            tuning_curves = np.random.permutation(tuning_curves)

        # Limit the trial to times when the subject moves faster than a threshold.
        trial_position = position.time_slice(trial.start, trial.stop)
        run_epoch = nept.run_threshold(trial_position, thresh=15., t_smooth=1.)
        run_position = trial_position[run_epoch]
        run_spikes = [spiketrain.time_slice(run_epoch.start, run_epoch.stop)
                      for spiketrain in spikes]

        counts = nept.bin_spikes(run_spikes, run_position.time, dt=dt, window=dt,
                                 gaussian_std=0.0075, normalized=False)
        # Number of neurons with at least one spike in each time bin.
        n_active = np.sum(counts.data >= 1, axis=1)

        tc_shape = tuning_curves.shape
        decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])
        likelihood = nept.bayesian_prob(counts, decoding_tc, binsize=dt,
                                        min_neurons=2, min_spikes=1)

        # Decoded location: the bin centre with maximum likelihood per time bin.
        decoded = nept.decode_location(likelihood, xy_centers, counts.time)

        if random_shuffle:
            # Control: resample decoded x and y independently, with replacement.
            n_decoded = len(decoded.x)
            random_x = np.random.choice(decoded.x, size=n_decoded)
            random_y = np.random.choice(decoded.y, size=n_decoded)
            decoded = nept.Position(np.array([random_x, random_y]).T, decoded.time)

        # Drop time bins whose likelihood is NaN everywhere (presumably bins below
        # the decoder's min_neurons / min_spikes). Apply the same mask to n_active
        # so both stay index-aligned. Then restore the spatial grid shape for plotting.
        keep_idx = np.sum(np.isnan(likelihood), axis=1) < likelihood.shape[1]
        likelihood = likelihood[keep_idx].reshape(-1, tc_shape[1], tc_shape[2])

        # True position at each decoded time (nearest recorded sample).
        f_xy = interp1d(run_position.time, run_position.data.T, kind="nearest")
        xy = f_xy(decoded.time)
        true_position = nept.Position(np.column_stack((xy[0], xy[1])), decoded.time)

        session_decoded.append(decoded)
        session_actual.append(true_position)
        session_likelihoods.append(likelihood)
        session_errors.append(true_position.distance(decoded))
        session_n_active.append(n_active[keep_idx])

    return session_decoded, session_actual, session_likelihoods, session_errors, session_n_active

In [None]:
def _pool_errors(nested_errors):
    """Flatten ``[session][trial] -> per-bin errors`` into one 1D array."""
    return np.concatenate([np.concatenate(session, axis=0) for session in nested_errors], axis=0)


def plot_errors(all_errors, all_errors_id_shuffled, all_errors_random_shuffled, n_sessions):
    """Summarise decoding errors against the two shuffled controls.

    Pools the errors over every session and trial for each condition. Prints
    the median, then the mean and SEM, and draws a box plot of the three pooled
    distributions.

    Parameters
    ----------
    all_errors, all_errors_id_shuffled, all_errors_random_shuffled : list
        Per-session lists of per-trial error arrays.
    n_sessions : int
        Number of sessions, annotated in the top-right corner of the figure.
    """
    errors = _pool_errors(all_errors)
    errors_id = _pool_errors(all_errors_id_shuffled)
    errors_random = _pool_errors(all_errors_random_shuffled)

    summaries = [('Actual:', errors), ('ID shuffle:', errors_id), ('Random shuffle:', errors_random)]
    for name, values in summaries:
        print(name, np.median(values))
    for name, values in summaries:
        print(name, np.mean(values), stats.sem(values))

    # Left-to-right plot order and the face colour of each box.
    plot_order = ['ID shuffled', 'Random shuffled', 'Decoded']
    colours = ['#ffffff', '#ffffff', '#bdbdbd']
    edge_colour = '#252525'
    fliersize = 1

    data = pd.concat(
        [pd.DataFrame(dict(error=values, label=label))
         for label, values in zip(plot_order, [errors_id, errors_random, errors])],
        ignore_index=True)

    # Style the boxes through keyword arguments. Recent matplotlib no longer
    # stores the box patches in `ax.artists`, so restyling via that list after
    # plotting silently does nothing.
    line_props = dict(color=edge_colour)
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.boxplot(x='label', y='error', data=data, order=plot_order, palette=colours, ax=ax,
                boxprops=dict(edgecolor=edge_colour),
                whiskerprops=line_props, capprops=line_props, medianprops=line_props,
                flierprops=dict(marker='o', markersize=fliersize, linestyle='none',
                                markerfacecolor=edge_colour, markeredgecolor=edge_colour))

    ax.text(1., 1., 'n_sessions = ' + str(n_sessions),
            verticalalignment='bottom',
            horizontalalignment='right',
            transform=ax.transAxes,
            color='k', fontsize=10)

    ax.set(xlabel=' ', ylabel="Error (cm)")
    ax.tick_params(axis='x', labelsize=14)

    sns.despine(ax=ax)
    fig.tight_layout()
    plt.show()

In [None]:
# Decode every session three ways:
# - with the real tuning curves;
# - with tuning curves permuted across neurons ("ID shuffled");
# - with decoded positions resampled at random ("Random shuffled").
# The two shuffles serve as chance-level baselines for the decoding error.

# Both shuffles draw from NumPy's global RNG; seed it so the baselines are reproducible.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)

n_sessions = 0
xxs = []
yys = []

all_decoded = []
all_actual = []
all_likelihoods = []
all_n_active = []

all_errors = []
all_errors_id_shuffled = []
all_errors_random_shuffled = []

for info in infos:
    print(info.session_id)
    n_sessions += 1
    # Keep `events` as a notebook-level name: get_decoded may read it from the global scope.
    events, position, spikes, _, _ = get_data(info)
    xedges, yedges = nept.get_xyedges(position, binsize=8)

    xx, yy = np.meshgrid(xedges, yedges)
    xxs.append(xx)
    yys.append(yy)

    decode_args = (info, position, spikes, xedges, yedges)
    decoded, actual, likelihoods, errors, n_active = get_decoded(
        *decode_args, shuffled_id=False, random_shuffle=False)
    # Only the errors (4th return value) are needed from the control decodings.
    errors_id_shuffled = get_decoded(*decode_args, shuffled_id=True, random_shuffle=False)[3]
    errors_random_shuffled = get_decoded(*decode_args, shuffled_id=False, random_shuffle=True)[3]

    all_decoded.append(decoded)
    all_actual.append(actual)
    all_likelihoods.append(likelihoods)
    all_n_active.append(n_active)

    all_errors.append(errors)
    all_errors_id_shuffled.append(errors_id_shuffled)
    all_errors_random_shuffled.append(errors_random_shuffled)

plot_errors(all_errors, all_errors_id_shuffled, all_errors_random_shuffled, n_sessions)

In [None]:
# Animate the first decoded time bins of one trial. The top panel shows the
# posterior over position, with the true (blue) and decoded (red) positions.
# Below it, the current bin is highlighted in the decoding-error and
# active-neuron histograms.
session_idx = 0
trial_idx = 1
max_frames = 10  # None animates every decoded time bin of the trial
fontsize = 14

decoded = all_decoded[session_idx][trial_idx]
true_position = all_actual[session_idx][trial_idx]
likelihoods = all_likelihoods[session_idx][trial_idx]
n_active = all_n_active[session_idx][trial_idx]
errors = all_errors[session_idx][trial_idx]
xx = xxs[session_idx]
yy = yys[session_idx]

# Never request more frames than the trial has decoded bins.
n_decoded_bins = len(decoded.time)
n_timebins = n_decoded_bins if max_frames is None else min(max_frames, n_decoded_bins)


def style_histogram_axis(ax, labelsize):
    """Hide the top/right spines, keep ticks left/bottom, and set the tick-label size."""
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(labelsize=labelsize)


fig = plt.figure(figsize=(12, 10))

# Posterior over position.
ax1 = plt.subplot2grid((5, 4), (0, 0), colspan=3, rowspan=3)

pad_amount = 5
ax1.set_xlim((np.floor(np.min(true_position.x)) - pad_amount,
              np.ceil(np.max(true_position.x)) + pad_amount))
ax1.set_ylim((np.floor(np.min(true_position.y)) - pad_amount,
              np.ceil(np.max(true_position.y)) + pad_amount))

# xx/yy are the full bin-edge meshgrid. Flat-shaded pcolormesh needs edges one
# larger than the data in each dimension. This assumes the likelihood grid is
# (len(yedges) - 1, len(xedges) - 1) — TODO confirm.
cmap = plt.get_cmap('bone_r')
posterior_position = ax1.pcolormesh(xx, yy, likelihoods[0], vmin=0, vmax=0.1, cmap=cmap)
fig.colorbar(posterior_position, ax=ax1)

estimated_position, = ax1.plot([], [], "<", color="r")
rat_position, = ax1.plot([], [], "<", color="b")

# Histogram of decoding errors.
ax2 = plt.subplot2grid((5, 4), (3, 0), colspan=3)

error_binwidth = 5.
error_bins = np.arange(-error_binwidth, np.max(errors) + error_binwidth, error_binwidth)
_, _, errors_bin = ax2.hist(np.clip(errors, error_bins[0], error_bins[-1]),
                            bins=error_bins, rwidth=0.9, color="k")
# np.digitize returns i with bins[i-1] <= x < bins[i], so x sits in bar i-1.
# hist puts a value equal to the last edge in the last bar, so clip i to keep
# it inside the bars.
errors_idx = np.clip(np.digitize(errors, error_bins), 1, len(error_bins) - 1)

ax2.set_xlabel("Error (cm)", fontsize=fontsize)
ax2.set_ylabel("# bins", fontsize=fontsize)
ax2.set_xticks(error_binwidth * np.arange(0, len(error_bins), 4))
style_histogram_axis(ax2, fontsize)

# Histogram of the number of active neurons per time bin.
ax3 = plt.subplot2grid((5, 4), (4, 0), colspan=3)

# Unit-width bins centred on the integers 0..max, so bar k counts the time
# bins with exactly k active neurons (and n_neurons_bin[k] is that bar).
max_active = int(np.max(n_active))
n_active_bins = np.arange(-0.5, max_active + 1.5, 1.)
_, _, n_neurons_bin = ax3.hist(n_active, bins=n_active_bins, rwidth=0.9, color="k")

ax3.set_xlabel("Number of active neurons", fontsize=fontsize)
ax3.set_ylabel("# bins", fontsize=fontsize)
ax3.set_xticks(np.arange(max_active + 1))
style_histogram_axis(ax3, fontsize)

fig.tight_layout()


def animate(i):
    """Draw frame i: the posterior, the decoded/true positions, and the highlighted histogram bars."""
    posterior_position.set_array(likelihoods[i].ravel())
    # Pass sequences: newer matplotlib rejects scalar x/y in Line2D.set_data.
    estimated_position.set_data([decoded.x[i]], [decoded.y[i]])
    rat_position.set_data([true_position.x[i]], [true_position.y[i]])

    for patch in errors_bin:
        patch.set_facecolor('k')
    errors_bin[errors_idx[i] - 1].set_facecolor('r')

    for patch in n_neurons_bin:
        patch.set_facecolor('k')
    n_neurons_bin[n_active[i]].set_facecolor('r')

    return posterior_position, estimated_position, rat_position


anim = animation.FuncAnimation(fig, animate, frames=n_timebins, interval=100,
                               blit=False, repeat=False)

In [None]:
# Render the animation from the previous cell as an inline HTML5 video
# (needs a matplotlib movie writer such as ffmpeg).
legend_text = "Blue is true position; Red is estimated location."
print(legend_text)
video_html = anim.to_html5_video()
HTML(video_html)

In [None]:
# Plot the first few tuning curves of the last session loaded by the decoding loop.
# NOTE(review): this cell previously used an undefined `tuning_curves` (it exists
# only as a local inside get_decoded), so it raised NameError on a fresh kernel.
# They are recomputed here over the whole phase3 epoch, with the same
# spike-count limits get_decoded uses. Confirm this is the intended epoch.
n_tuning_curves_to_plot = 3
tuning_curves = get_only_tuning_curves(position, spikes, xedges, yedges,
                                       info.task_times["phase3"], 100, 5000)

xx, yy = np.meshgrid(xedges, yedges)

# White first (used for NaN bins below), followed by a copper ramp.
n_colours = 9  # must be an int: np.linspace rejects a float sample count
colours = [(1., 1., 1.)]
colours.extend(plt.get_cmap('copper_r')(np.linspace(0, 1, n_colours - 1)))
cmap = matplotlib.colors.ListedColormap(colours)

for i, tuning_curve in enumerate(tuning_curves[:n_tuning_curves_to_plot]):
    # Copy as float so the NaN replacement never touches the original array.
    tuning_curve = np.array(tuning_curve, dtype=float)
    # Give NaN bins the (negative) minimum value so they take the first, white, colour.
    tuning_curve[np.isnan(tuning_curve)] = -np.nanmax(tuning_curve) / n_colours

    # Own names, so the animation's `fig` is not clobbered.
    tc_fig, tc_ax = plt.subplots()
    mesh = tc_ax.pcolormesh(xx, yy, tuning_curve, cmap=cmap)
    tc_fig.colorbar(mesh, ax=tc_ax)
    tc_ax.set_title('Tuning curve {}'.format(i))
    tc_fig.tight_layout()
    plt.show()

In [None]:
# writer = animation.writers['ffmpeg'](fps=18)
# dpi = 600
# filename = '/errors_'+info.session_id+'_trial'+str(trial_idx)+'.mp4'
# anim.save(output_filepath+filename, writer=writer, dpi=dpi)