In [None]:
%matplotlib inline
import itertools
import os

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib import animation, rc
import numpy as np
import pandas as pd
import scipy
import scipy.interpolate
from scipy import stats
import seaborn as sns
from IPython.display import HTML
import nept

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from analyze_decode_bytrial import decode_trial
from analyze_decode import get_decoded_zones
from utils_maze import find_zones, get_trials

In [None]:
# Cache and output locations live under the current working directory.
thisdir = os.getcwd()
pickle_filepath, output_filepath = (
    os.path.join(thisdir, *subdirs)
    for subdirs in (("cache", "pickled"), ("plots", "decode-video"))
)

In [None]:
from run import spike_sorted_infos
import info.r063d5 as r063d5
import info.r063d6 as r063d6
# Sessions to analyse: all of `spike_sorted_infos` imported from run.
# Uncomment the next line (and drop the assignment below) to restrict the
# analysis to the r063d5 / r063d6 sessions imported above.
# infos = [r063d5, r063d6]
infos = spike_sorted_infos

In [None]:
def get_decoded(info, position, spikes, xedges, yedges, shuffled_id, random_shuffle,
                dt=0.025, window=0.025, gaussian_std=0.0075,
                min_neurons=2, min_spikes=1, phase="phase3"):
    """Decode position during running in one task phase with nept's Bayesian decoder.

    Tuning curves are built for ``phase`` (via ``get_only_tuning_curves``) and
    the same phase is decoded. Position is restricted to the phase window and
    then to the epochs ``nept.run_threshold`` (thresh=15., t_smooth=1.) returns.

    Parameters
    ----------
    info : session info module
        Must provide ``task_times[phase]`` with ``.start`` / ``.stop``.
    position : nept.Position
        Position for the session.
    spikes : list of nept.SpikeTrain
    xedges, yedges : np.ndarray
        Spatial bin edges for the tuning curves and decoding grid.
    shuffled_id : bool
        If True, permute the tuning curves along axis 0 (presumably the unit
        axis) before decoding -- an identity-shuffle control.
    random_shuffle : bool
        If True, replace decoded x and y with values drawn independently, with
        replacement, from the decoded x and y values.
    dt, window, gaussian_std : float
        Passed to ``nept.bin_spikes``; ``window`` is also passed to
        ``nept.bayesian_prob``.
    min_neurons, min_spikes : int
        Passed to ``nept.bayesian_prob``.
    phase : str
        Key into ``info.task_times``; defaults to "phase3".

    Returns
    -------
    decoded : nept.Position
    true_position : nept.Position
        Actual position at each decoded time (nearest position sample).
    errors
        ``true_position.distance(decoded)``.
    likelihood : np.ndarray
        Decoder likelihood reshaped to (n_rows, tc_shape[1], tc_shape[2]).
    """
    sliced_spikes, tuning_curves = get_only_tuning_curves(info,
                                                          position,
                                                          spikes,
                                                          xedges,
                                                          yedges,
                                                          phase=phase)

    if shuffled_id:
        tuning_curves = np.random.permutation(tuning_curves)

    decoding_times = info.task_times[phase]
    position = position.time_slice(decoding_times.start, decoding_times.stop)

    # limit position to only times when the subject is moving faster than a certain threshold
    run_epoch = nept.run_threshold(position, thresh=15., t_smooth=1.)
    position = position[run_epoch]

    counts = nept.bin_spikes(sliced_spikes, position.time, dt=dt, window=window,
                             gaussian_std=gaussian_std, normalized=False)

    # Flatten each tuning curve's 2D spatial grid into one axis for the decoder
    tc_shape = tuning_curves.shape
    decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

    # `window` is an explicit parameter; it used to be read from a notebook
    # global that does not exist on a fresh kernel (NameError).
    likelihood = nept.bayesian_prob(counts, decoding_tc, window,
                                    min_neurons=min_neurons, min_spikes=min_spikes)

    xcenters = (xedges[1:] + xedges[:-1]) / 2.
    ycenters = (yedges[1:] + yedges[:-1]) / 2.
    xy_centers = nept.cartesian(xcenters, ycenters)
    decoded = nept.decode_location(likelihood, xy_centers, counts.time)

    if random_shuffle:
        # Control: x and y are drawn independently (with replacement) from the
        # decoded values, so the pairs no longer come from the same time bin.
        n_samples = len(decoded.x)
        random_x = np.random.choice(decoded.x, size=n_samples)
        random_y = np.random.choice(decoded.y, size=n_samples)
        decoded = nept.Position(np.column_stack((random_x, random_y)), decoded.time)

    likelihood = likelihood.reshape(np.shape(likelihood)[0], tc_shape[1], tc_shape[2])

    # True position at each decoded time, taken from the nearest position sample
    f_xy = scipy.interpolate.interp1d(position.time, position.data.T, kind="nearest")
    counts_xy = f_xy(decoded.time)
    true_position = nept.Position(np.column_stack((counts_xy[0], counts_xy[1])),
                                  decoded.time)

    errors = true_position.distance(decoded)

    return decoded, true_position, errors, likelihood

In [None]:
all_errors = []
all_errors_id_shuffled = []
all_errors_random_shuffled = []
n_sessions = 0

for info in infos:
    print(info.session_id)
    n_sessions += 1
    events, position, spikes, _, _ = get_data(info)
    xedges, yedges = nept.get_xyedges(position, binsize=8)

    phase = "phase3"
    trial_epochs = get_trials(events, info.task_times[phase])

    # get_decoded returns (decoded, true_position, errors, likelihood); unpacking
    # only three values raised ValueError. The likelihood is not needed here.
    decoded, true_position, errors, _ = get_decoded(
        info, position, spikes, xedges, yedges,
        shuffled_id=False, random_shuffle=False)

    # Control 1: tuning-curve identities shuffled before decoding
    decoded_id_shuffled, true_position_id_shuffled, errors_id_shuffled, _ = get_decoded(
        info, position, spikes, xedges, yedges,
        shuffled_id=True, random_shuffle=False)

    # Control 2: decoded x / y replaced by random draws from the decoded values
    decoded_random_shuffled, true_position_random_shuffled, errors_random_shuffled, _ = get_decoded(
        info, position, spikes, xedges, yedges,
        shuffled_id=False, random_shuffle=True)

    all_errors.extend(errors)
    all_errors_id_shuffled.extend(errors_id_shuffled)
    all_errors_random_shuffled.extend(errors_random_shuffled)

# Box plot of decoding errors pooled over all sessions
fliersize = 1
edge_colour = '#252525'

decoded_dict = dict(error=all_errors, label='Decoded')
shuffled_id_dict = dict(error=all_errors_id_shuffled, label='ID shuffled')
shuffled_random_dict = dict(error=all_errors_random_shuffled, label='Random shuffled')
decoded_errors = pd.DataFrame(decoded_dict)
shuffled_id = pd.DataFrame(shuffled_id_dict)
shuffled_random = pd.DataFrame(shuffled_random_dict)
data = pd.concat([shuffled_id, shuffled_random, decoded_errors])
order = ['ID shuffled', 'Random shuffled', 'Decoded']
colours = ['#ffffff', '#ffffff', '#bdbdbd']  # box face colour, one per entry of `order`

fig, ax = plt.subplots(figsize=(6, 4))
# Colour lines through the boxplot props rather than indexing ax.lines: the
# number of Line2D objects per box differs between seaborn/matplotlib versions.
line_props = dict(color=edge_colour)
flierprops = dict(marker='o', markersize=fliersize, linestyle='none',
                  markerfacecolor=edge_colour, markeredgecolor=edge_colour)
sns.boxplot(x='label', y='error', data=data, order=order, ax=ax,
            boxprops=dict(edgecolor=edge_colour),
            whiskerprops=line_props, capprops=line_props,
            medianprops=line_props, flierprops=flierprops)

# Box patches are in ax.artists on older matplotlib and in ax.patches on newer
boxes = list(ax.artists) or list(ax.patches)
for box, colour in zip(boxes, colours):
    box.set_edgecolor(edge_colour)
    box.set_facecolor(colour)

ax.text(1., 1., 'n_sessions = ' + str(n_sessions),
        verticalalignment='bottom',
        horizontalalignment='right',
        transform=ax.transAxes,
        color='k', fontsize=10)

ax.set(xlabel=' ', ylabel="Error (cm)")
ax.tick_params(axis='x', labelsize=14)

sns.despine(ax=ax)
fig.tight_layout()

plt.show()

In [None]:
# Median decoding error for the real decoder and each shuffle control
for median_label, error_values in (('Actual:', all_errors),
                                   ('ID shuffle:', all_errors_id_shuffled),
                                   ('Random shuffle:', all_errors_random_shuffled)):
    print(median_label, np.median(error_values))

In [None]:
# Mean decoding error and its standard error of the mean, per condition
error_sets = [('Actual:', all_errors),
              ('ID shuffle:', all_errors_id_shuffled),
              ('Random shuffle:', all_errors_random_shuffled)]
for mean_label, error_values in error_sets:
    print(mean_label, np.mean(error_values), stats.sem(error_values))

In [None]:
# Reload data for `info` (the last session from the loop above) for single-session inspection
phase = "phase3"
events, position, spikes, _, _ = get_data(info)
xedges, yedges = nept.get_xyedges(position, binsize=8)
trial_epochs = get_trials(events, info.task_times[phase])

In [None]:
# get_decoded returns (decoded, true_position, errors, likelihood); all four
# values must be unpacked (unpacking three raised ValueError).
decoded, true_position, errors, likelihood = get_decoded(info,
                                                         position,
                                                         spikes,
                                                         xedges,
                                                         yedges,
                                                         shuffled_id=False,
                                                         random_shuffle=False)

In [None]:
# Identity-shuffle control. get_decoded returns four values, including the
# likelihood; unpacking three raised ValueError.
(decoded_id_shuffled, true_position_id_shuffled,
 errors_id_shuffled, likelihood_id_shuffled) = get_decoded(info,
                                                           position,
                                                           spikes,
                                                           xedges,
                                                           yedges,
                                                           shuffled_id=True,
                                                           random_shuffle=False)

In [None]:
# Random-position control. get_decoded returns four values, including the
# likelihood; unpacking three raised ValueError.
(decoded_random_shuffled, true_position_random_shuffled,
 errors_random_shuffled, likelihood_random_shuffled) = get_decoded(info,
                                                                   position,
                                                                   spikes,
                                                                   xedges,
                                                                   yedges,
                                                                   shuffled_id=False,
                                                                   random_shuffle=True)

In [None]:
sliced_spikes, tuning_curves = get_only_tuning_curves(info,
                                                      position,
                                                      spikes,
                                                      xedges,
                                                      yedges,
                                                      phase="phase3")

# Decoding parameters (same values get_decoded uses)
decoding_times = info.task_times["phase3"]
dt = 0.025
window = 0.025
gaussian_std = 0.0075
normalized = False
min_neurons = 2
min_spikes = 1

# Slice into new names: overwriting the global `position` made later cells
# (and re-runs of this one) operate on an already-sliced, run-filtered copy.
phase_position = position.time_slice(decoding_times.start, decoding_times.stop)

# limit position to only times when the subject is moving faster than a certain threshold
run_epoch = nept.run_threshold(phase_position, thresh=15., t_smooth=1.)
run_position = phase_position[run_epoch]

counts = nept.bin_spikes(sliced_spikes, run_position.time, dt=dt, window=window,
                         gaussian_std=gaussian_std, normalized=normalized)

# Number of units with a count >= 1 in each time bin
n_active_neurons = np.sum(counts.data >= 1, axis=1)

tc_shape = tuning_curves.shape
decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

likelihood = nept.bayesian_prob(counts, decoding_tc, window,
                                min_neurons=min_neurons, min_spikes=min_spikes)

xcenters = (xedges[1:] + xedges[:-1]) / 2.
ycenters = (yedges[1:] + yedges[:-1]) / 2.
xy_centers = nept.cartesian(xcenters, ycenters)
decoded = nept.decode_location(likelihood, xy_centers, counts.time)

likelihood = likelihood.reshape(np.shape(likelihood)[0], tc_shape[1], tc_shape[2])

# NOTE(review): decode_location may drop time bins without a usable likelihood.
# Keep only the rows whose time appears in decoded.time, so that frame i of
# every per-frame array matches decoded[i]. This is an identity selection when
# nothing was dropped.
frame_keep = np.isin(counts.time, decoded.time)
frame_likelihood = likelihood[frame_keep]
frame_n_active = n_active_neurons[frame_keep]

# True position at each decoded time, taken from the nearest position sample
f_xy = scipy.interpolate.interp1d(run_position.time, run_position.data.T, kind="nearest")
counts_xy = f_xy(decoded.time)
true_position = nept.Position(np.column_stack((counts_xy[0], counts_xy[1])),
                              decoded.time)

errors = true_position.distance(decoded)

# ---------------------------------------------------------------- figure ---
fontsize = 14

fig = plt.figure(figsize=(12, 10))
gs = gridspec.GridSpec(5, 4)

# Posterior map with true (blue) and decoded (red) positions
ax1 = fig.add_subplot(gs[0:3, 0:3])

xx, yy = np.meshgrid(xedges, yedges)

pad_amount = 5
ax1.set_xlim((np.floor(np.min(true_position.x)) - pad_amount,
              np.ceil(np.max(true_position.x)) + pad_amount))
ax1.set_ylim((np.floor(np.min(true_position.y)) - pad_amount,
              np.ceil(np.max(true_position.y)) + pad_amount))

n_timebins = decoded.n_samples

# plt.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap remains.
cmap = plt.get_cmap('bone_r')
# Each likelihood frame holds one value per spatial bin, so the mesh takes the
# full edge grids. This assumes tuning curves are (n, len(yedges)-1,
# len(xedges)-1), as the tuning-curve plots later in the notebook also assume.
posterior_position = ax1.pcolormesh(xx, yy, frame_likelihood[0], vmin=0, vmax=0.1, cmap=cmap)
colorbar = fig.colorbar(posterior_position, ax=ax1)

estimated_position, = ax1.plot([], [], "<", color="r")
rat_position, = ax1.plot([], [], "<", color="b")

# Histogram of decoding errors; the current frame's bar is highlighted red
ax2 = fig.add_subplot(gs[3, 0:3])

binwidth = 5.
error_bins = np.arange(-binwidth, np.max(errors) + binwidth, binwidth)

_, _, errors_bin = ax2.hist(np.clip(errors, error_bins[0], error_bins[-1]),
                            bins=error_bins, rwidth=0.9, color="k")
# Bar number (1-based) for each frame. np.digitize returns len(error_bins) for
# a value equal to the last edge, which hist counts in the last bar, so clip.
errors_idx = np.clip(np.digitize(errors, error_bins), 1, len(error_bins) - 1)

# Label every other edge (0, 10, 20, ...) with its own value. Ticks and labels
# are set together so the labels cannot drift onto the wrong positions.
error_ticks = error_bins[1::2]
ax2.set_xticks(error_ticks)
ax2.set_xticklabels([str(int(b)) for b in error_ticks], fontsize=fontsize)
ax2.set_xlabel("Error (cm)", fontsize=fontsize)
ax2.set_ylabel("# bins", fontsize=fontsize)
ax2.spines['right'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.yaxis.set_ticks_position('left')
ax2.xaxis.set_ticks_position('bottom')
ax2.tick_params(axis='y', labelsize=fontsize)

# Histogram of active-neuron counts per frame; current frame highlighted red
ax3 = fig.add_subplot(gs[4, 0:3])

max_active = int(np.max(frame_n_active))
# Half-integer edges give one bar per count, so bar k holds exactly count k
n_active_bins = np.arange(-0.5, max_active + 1.5, 1.)

_, _, n_neurons_bin = ax3.hist(frame_n_active, bins=n_active_bins, rwidth=0.9, color="k")

ax3.set_xlabel("Number of active neurons", fontsize=fontsize)
ax3.set_ylabel("# bins", fontsize=fontsize)
ax3.spines['right'].set_visible(False)
ax3.spines['top'].set_visible(False)
ax3.yaxis.set_ticks_position('left')
ax3.xaxis.set_ticks_position('bottom')
ax3.set_xticks(np.arange(max_active + 1))
ax3.tick_params(labelsize=fontsize)

fig.tight_layout()


def animate(i):
    """Redraw the posterior, both position markers and the highlighted bars for frame ``i``."""
    posterior_position.set_array(frame_likelihood[i].ravel())
    # Line2D.set_data expects sequences; recent matplotlib rejects bare scalars
    estimated_position.set_data([decoded.x[i]], [decoded.y[i]])
    rat_position.set_data([true_position.x[i]], [true_position.y[i]])

    for patch in errors_bin:
        patch.set_fc('k')
    errors_bin[errors_idx[i] - 1].set_fc('r')

    for patch in n_neurons_bin:
        patch.set_fc('k')
    n_neurons_bin[frame_n_active[i]].set_fc('r')

    return (posterior_position, estimated_position, rat_position,
            *errors_bin, *n_neurons_bin)


anim = animation.FuncAnimation(fig, animate, frames=n_timebins, interval=80,
                               blit=False, repeat=False)


#     writer = animation.writers['ffmpeg'](fps=18)
#     dpi = 600
#     filename = '/errors_'+info.session_id+'_trial'+str(trial_idx)+'.mp4'
#     anim.save(output_filepath+filename, writer=writer, dpi=dpi)

#     plt.close()

In [None]:
print("Blue is true position; Red is estimated location.")
# Render the animation as an HTML5 video and display it as the cell output
html_video = anim.to_html5_video()
HTML(html_video)

In [None]:
# true_position is sampled at decoded.time, so both counts should be equal
(true_position.n_samples, decoded.n_samples)

In [None]:
# Recompute the phase-3 tuning curves for `info`
sliced_spikes, tuning_curves = get_only_tuning_curves(
    info, position, spikes, xedges, yedges, phase="phase3"
)

In [None]:
xx, yy = np.meshgrid(xedges, yedges)

# Colour map: white first, then n_colours - 1 levels of copper_r.
# n_colours must be an int: np.linspace rejects a float sample count
# (TypeError on numpy >= 1.18).
n_colours = 9
colours = [(1., 1., 1.)]
colours.extend(matplotlib.cm.copper_r(np.linspace(0, 1, n_colours - 1)))
cmap = matplotlib.colors.ListedColormap(colours)

# Plot individual tuning curves (set n_plot = len(tuning_curves) to plot all)
n_plot = 3
for tuning_curve in tuning_curves[:n_plot]:
    # Copy so that filling NaNs below does not modify tuning_curves
    tuning_curve = np.array(tuning_curve)
    # Give NaN bins a value just below zero so they fall in the white first
    # colour (assumes the tuning-curve values themselves are non-negative)
    tuning_curve[np.isnan(tuning_curve)] = -np.nanmax(tuning_curve) / n_colours

    fig, ax = plt.subplots()
    pp = ax.pcolormesh(xx, yy, tuning_curve, cmap=cmap)

    colourbar = fig.colorbar(pp, ax=ax)
    fig.tight_layout()
    plt.show()

In [None]:
# writer = animation.writers['ffmpeg'](fps=18)
# dpi = 600
# filename = '/errors_'+info.session_id+'_trial'+str(trial_idx)+'.mp4'
# anim.save(output_filepath+filename, writer=writer, dpi=dpi)

In [None]:
def get_only_tuning_curves(info, position, spikes, xedges, yedges, phase,
                           speed_thresh=15., t_smooth=1.,
                           min_rate=0.4, max_rate=5.,
                           occupied_thresh=0.5, gaussian_std=0.3):
    """Compute 2D tuning curves, during running in one phase, for units with moderate spike counts.

    NOTE(review): this redefines the ``get_only_tuning_curves`` imported from
    ``analyze_tuning_curves`` at the top of the notebook; calls in later cells
    use this version instead.

    Parameters
    ----------
    info : session info module
        Must provide ``task_times[phase]`` with ``.start`` / ``.stop``.
    position : nept.Position
    spikes : list of nept.SpikeTrain
    xedges, yedges : np.ndarray
        Spatial bin edges passed to ``nept.tuning_curve_2d``.
    phase : str
        Key into ``info.task_times``.
    speed_thresh, t_smooth : float
        Passed to ``nept.run_threshold`` as ``thresh`` / ``t_smooth``.
    min_rate, max_rate : float
        A unit is kept when its spike count during running lies within
        ``[min_rate, max_rate] * sum(run_epoch.durations)``.
    occupied_thresh, gaussian_std : float
        Passed to ``nept.tuning_curve_2d``.

    Returns
    -------
    tuning_spikes : np.ndarray
        Run-restricted spike trains of the kept units.
    tuning_curves
        Result of ``nept.tuning_curve_2d`` for those units.
    """
    start = info.task_times[phase].start
    stop = info.task_times[phase].stop
    sliced_position = position.time_slice(start, stop)
    sliced_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]

    # Limit position and spikes to only running times
    run_epoch = nept.run_threshold(sliced_position, thresh=speed_thresh, t_smooth=t_smooth)
    run_position = sliced_position[run_epoch]
    track_spikes = np.asarray([spiketrain.time_slice(run_epoch.starts, run_epoch.stops)
                               for spiketrain in sliced_spikes])

    # Remove neurons with too few or too many spikes relative to the running time
    len_epochs = np.sum(run_epoch.durations)
    n_spikes = np.array([len(spiketrain.time) for spiketrain in track_spikes], dtype=int)
    keep_idx = (n_spikes >= min_rate * len_epochs) & (n_spikes <= max_rate * len_epochs)
    tuning_spikes = track_spikes[keep_idx]

    tuning_curves = nept.tuning_curve_2d(run_position, tuning_spikes, xedges, yedges,
                                         occupied_thresh=occupied_thresh,
                                         gaussian_std=gaussian_std)

    return tuning_spikes, tuning_curves

In [None]:
# Step-by-step version of get_only_tuning_curves, keeping every intermediate
# as a global for inspection.
phase = "phase3"
phase_times = info.task_times[phase]
sliced_position = position.time_slice(phase_times.start, phase_times.stop)
sliced_spikes = [train.time_slice(phase_times.start, phase_times.stop) for train in spikes]

# Restrict position and spikes to running periods
run_epoch = nept.run_threshold(sliced_position, thresh=15., t_smooth=1.)
run_position = sliced_position[run_epoch]
track_spikes = np.asarray([train.time_slice(run_epoch.starts, run_epoch.stops)
                           for train in sliced_spikes])

# Keep only units whose spike count during running is within the bounds
len_epochs = np.sum(run_epoch.durations)
min_n_spikes, max_n_spikes = 0.4 * len_epochs, 5 * len_epochs
keep_idx = np.array([min_n_spikes <= len(train.time) <= max_n_spikes
                     for train in track_spikes], dtype=bool)
tuning_spikes = track_spikes[keep_idx]

tuning_curves = nept.tuning_curve_2d(run_position, tuning_spikes, xedges, yedges,
                                     occupied_thresh=0.5, gaussian_std=0.3)


In [None]:
# Reload position and spikes for `info` and rebuild the spatial bin edges
(_, position, spikes, _, _) = get_data(info)
xedges, yedges = nept.get_xyedges(position, binsize=8)

In [None]:
# Uses the get_only_tuning_curves defined in the cell above (shadows the import)
sliced_spikes, tuning_curves = get_only_tuning_curves(
    info, position, spikes, xedges, yedges,
    phase="phase3",
)

In [None]:
xx, yy = np.meshgrid(xedges, yedges)

# Colour map: white first, then n_colours - 1 levels of copper_r.
# n_colours must be an int: np.linspace rejects a float sample count
# (TypeError on numpy >= 1.18).
n_colours = 9
colours = [(1., 1., 1.)]
colours.extend(matplotlib.cm.copper_r(np.linspace(0, 1, n_colours - 1)))
cmap = matplotlib.colors.ListedColormap(colours)

# Plot individual tuning curves (set n_plot = len(tuning_curves) to plot all)
n_plot = 3
for tuning_curve in tuning_curves[:n_plot]:
    # Copy so that filling NaNs below does not modify tuning_curves
    tuning_curve = np.array(tuning_curve)
    # Give NaN bins a value just below zero so they fall in the white first
    # colour (assumes the tuning-curve values themselves are non-negative)
    tuning_curve[np.isnan(tuning_curve)] = -np.nanmax(tuning_curve) / n_colours

    fig, ax = plt.subplots()
    pp = ax.pcolormesh(xx, yy, tuning_curve, cmap=cmap)

    colourbar = fig.colorbar(pp, ax=ax)
    fig.tight_layout()
    plt.show()