In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib
import numpy as np
import itertools
import scipy
import pandas as pd
import seaborn as sns
from scipy import stats
import os
import nept

from matplotlib import animation, rc
from IPython.display import HTML

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from analyze_decode_bytrial import decode_trial
from analyze_decode import get_decoded_zones
from utils_maze import find_zones, get_trials

In [None]:
thisdir = os.getcwd()
# Cached intermediate data and generated decoding figures/videos both live
# under the current working directory.
pickle_filepath, output_filepath = (
    os.path.join(thisdir, "cache", "pickled"),
    os.path.join(thisdir, "plots", "decode-video"),
)

In [None]:
from run import spike_sorted_infos
import info.r063d5 as r063d5
import info.r063d6 as r063d6
# Sessions to decode; swap in the commented line to process every spike-sorted session.
infos = [r063d5]
# infos = spike_sorted_infos

In [None]:
def get_decoded(info, position, spikes, xedges, yedges, shuffled_id,
                events=None, rng=None, min_neurons=2, min_spikes=2):
    """Bayesian-decode position for every phase-3 trial of one session.

    For each trial, tuning curves are computed from ``phase.excludes(trial)``
    (phase 3 without the held-out trial), and decoding is limited to the
    running periods found by ``nept.run_threshold``.

    Parameters
    ----------
    info : session info module
        Must provide ``task_times["phase3"]``.
    position : nept.Position
    spikes : list of spike trains (each supporting ``time_slice``)
    xedges, yedges : np.ndarray
        Spatial bin edges used for the tuning curves.
    shuffled_id : bool
        If True, permute the tuning curves along axis 0 (across neurons)
        before decoding -- the neuron-identity-shuffle control.
    events : optional
        Event data passed to ``get_trials``. If omitted, falls back to the
        notebook-level ``events`` global (the previous implicit behaviour);
        prefer passing it explicitly.
    rng : np.random.Generator, optional
        Randomness source for the identity shuffle; defaults to the global
        ``np.random`` state. Pass a seeded generator for reproducibility.
    min_neurons, min_spikes : int
        Thresholds forwarded to ``nept.bayesian_prob``. ``min_neurons`` also
        zeroes the active-neuron count of time bins below it.

    Returns
    -------
    session_decoded : list of nept.Position
    session_actual : list of nept.Position
        True position (nearest sample) at each decoded time bin.
    session_likelihoods : list of np.ndarray
        Per trial, shape (n_valid_bins, tc_shape[1], tc_shape[2]) with
        all-NaN time bins removed.
    session_errors : list of np.ndarray
        Distance between true and decoded position per time bin.
    session_n_active : list of np.ndarray
        Active-neuron count per valid time bin (0 where below min_neurons).

    Trials with no running samples are skipped, so every returned list has
    one entry per *decoded* trial (all lists stay aligned).
    """
    from scipy.interpolate import interp1d

    if events is None:
        # Backward compatibility: this function used to read the notebook
        # global `events` (set by get_data in the driver cell) implicitly.
        try:
            events = globals()["events"]
        except KeyError:
            raise NameError("get_decoded needs `events`; pass events=... explicitly") from None

    permute = np.random.permutation if rng is None else rng.permutation

    phase = info.task_times["phase3"]
    trials = get_trials(events, phase)

    # Bin centres (and their x-y grid) are the same for every trial.
    xcenters = (xedges[1:] + xedges[:-1]) / 2.
    ycenters = (yedges[1:] + yedges[:-1]) / 2.
    xy_centers = nept.cartesian(xcenters, ycenters)

    session_n_active = []
    session_likelihoods = []
    session_decoded = []
    session_actual = []
    session_errors = []

    for trial in trials:
        # Tuning curves from phase 3 with this trial held out.
        epoch_of_interest = phase.excludes(trial)
        tuning_curves = get_only_tuning_curves(position, spikes, xedges, yedges, epoch_of_interest)

        if shuffled_id:
            tuning_curves = permute(tuning_curves)

        # Limit position to times when the subject moves faster than the threshold.
        sliced_position = position.time_slice(trial.start, trial.stop)
        run_epoch = nept.run_threshold(sliced_position, thresh=10., t_smooth=0.8)
        sliced_position = sliced_position[run_epoch]
        if sliced_position.isempty:
            # No running in this trial: nothing to decode (previously an IndexError).
            continue
        sliced_spikes = [spiketrain.time_slice(run_epoch.start, run_epoch.stop)
                         for spiketrain in spikes]

        counts = nept.bin_spikes(sliced_spikes, sliced_position.time, dt=0.025, window=0.025,
                                 gaussian_std=0.0075, normalized=False)

        # Flatten the spatial dimensions for decoding.
        tc_shape = tuning_curves.shape
        decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

        likelihood = nept.bayesian_prob(counts, decoding_tc, binsize=0.025,
                                        min_neurons=min_neurons, min_spikes=min_spikes)

        # Decoded location = max-likelihood bin centre for each valid time bin.
        decoded = nept.decode_location(likelihood, xy_centers, counts.time)
        session_decoded.append(decoded)

        # Drop all-NaN time bins and restore the 2-D spatial layout for plotting.
        keep_idx = np.sum(np.isnan(likelihood), axis=1) < likelihood.shape[1]
        likelihood = likelihood[keep_idx]
        likelihood = likelihood.reshape(likelihood.shape[0], tc_shape[1], tc_shape[2])
        session_likelihoods.append(likelihood)

        active_counts = np.sum(counts.data >= 1, axis=1)
        n_active_neurons = np.where(active_counts >= min_neurons, active_counts, 0)[keep_idx]
        session_n_active.append(n_active_neurons)

        # True position = nearest position sample at each decoded time bin.
        f_xy = interp1d(sliced_position.time, sliced_position.data.T, kind="nearest")
        counts_xy = f_xy(decoded.time)
        true_position = nept.Position(np.column_stack((counts_xy[0], counts_xy[1])), decoded.time)
        session_actual.append(true_position)

        session_errors.append(true_position.distance(decoded))

    return session_decoded, session_actual, session_likelihoods, session_errors, session_n_active

In [None]:
def plot_errors(all_errors, all_errors_id_shuffled, all_errors_random_shuffled, n_sessions, filename=None):
    """Box-plot decoding errors next to the neuron-ID-shuffled control.

    Parameters
    ----------
    all_errors, all_errors_id_shuffled : list of list of np.ndarray
        Nested session -> trial -> per-bin error arrays; flattened before plotting.
    all_errors_random_shuffled : any
        Unused. Kept only so existing positional calls keep working; pass
        ``n_sessions``/``filename`` by keyword to avoid shifting arguments.
    n_sessions : int
        Reported in the annotation (formatted with ``%d``).
    filename : str, optional
        If given, save the figure there (creating the parent directory if
        needed) and close it; otherwise show it.
    """
    def _flatten(nested):
        # session -> trial -> per-bin errors  ==>  one flat array
        return np.concatenate([np.concatenate(errors, axis=0) for errors in nested], axis=0)

    decoded_errors = _flatten(all_errors)
    shuffled_errors = _flatten(all_errors_id_shuffled)

    # Shuffled first so it is the left box.
    data = pd.concat([pd.DataFrame(dict(error=shuffled_errors, label='ID shuffled')),
                      pd.DataFrame(dict(error=decoded_errors, label='Decoded'))])
    colours = ['#ffffff', '#bdbdbd']
    edge_colour = '#252525'
    fliersize = 1

    fig, ax = plt.subplots(figsize=(6, 4))
    flierprops = dict(marker='o', markersize=fliersize, linestyle='none')
    sns.boxplot(x='label', y='error', data=data, flierprops=flierprops, ax=ax)

    # Matplotlib >= 3.5 stores the box patches in ax.patches instead of
    # ax.artists; looping over ax.artists alone silently recoloured nothing.
    boxes = list(ax.artists) or list(ax.patches)
    for i, box in enumerate(boxes[:len(colours)]):
        box.set_edgecolor(edge_colour)
        box.set_facecolor(colours[i])
        # Each box owns 6 Line2D objects: 2 whiskers, 2 caps, median, fliers.
        for line in ax.lines[i * 6:(i + 1) * 6]:
            line.set_color(edge_colour)
            line.set_mfc(edge_colour)
            line.set_mec(edge_colour)

    ax.text(1., 1., "N sessions: %d \nmean-error: %.1f cm \nmedian-error: %.1f cm" % (n_sessions,
                                                                                      np.mean(decoded_errors),
                                                                                      np.median(decoded_errors)),
            horizontalalignment='right',
            verticalalignment='top',
            transform=ax.transAxes,
            fontsize=10)

    ax.set(xlabel=' ', ylabel="Error (cm)")
    ax.tick_params(axis='x', labelsize=14)

    sns.despine(ax=ax)
    fig.tight_layout()

    if filename is not None:
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        fig.savefig(filename)
        plt.close(fig)
    else:
        plt.show()

In [None]:
binsize = 8
n_sessions = 0
session_ids = []
xedges = []
yedges = []

all_decoded = []
all_actual = []
all_likelihoods = []
all_n_active = []

all_errors = []
all_errors_id_shuffled = []
all_errors_random_shuffled = []  # never filled; plot_errors ignores it

for info in infos:
    print(info.session_id)
    session_ids.append(info.session_id)
    n_sessions += 1
    events, position, spikes, _, _ = get_data(info)

    xedge, yedge = nept.get_xyedges(position, binsize=binsize)
    xedges.append(xedge)
    yedges.append(yedge)

    # Actual decoding and the neuron-identity-shuffled control for this session.
    decoded, actual, likelihoods, errors, n_active = get_decoded(info,
                                                                 position,
                                                                 spikes,
                                                                 xedge,
                                                                 yedge,
                                                                 shuffled_id=False)

    _, _, _, errors_id_shuffled, _ = get_decoded(info,
                                                 position,
                                                 spikes,
                                                 xedge,
                                                 yedge,
                                                 shuffled_id=True)

    all_decoded.append(decoded)
    all_actual.append(actual)
    all_likelihoods.append(likelihoods)
    all_n_active.append(n_active)

    all_errors.append(errors)
    all_errors_id_shuffled.append(errors_id_shuffled)

# Flat array of every decoding error across sessions and trials.
combined_errors = np.concatenate([np.concatenate(errors, axis=0) for errors in all_errors], axis=0)

os.makedirs(output_filepath, exist_ok=True)
filename = os.path.join(output_filepath, "combined_errors-binsize" + str(binsize) + ".png")
# The third positional slot of plot_errors is all_errors_random_shuffled; passing
# n_sessions/filename positionally there shifted them and broke the "%d" format.
plot_errors(all_errors, all_errors_id_shuffled, all_errors_random_shuffled,
            n_sessions=n_sessions, filename=filename)

In [None]:
filename = os.path.join(output_filepath, "combined_errors-binsize" + str(binsize) + ".png")
# Keyword arguments: the third positional parameter is all_errors_random_shuffled.
plot_errors(all_errors, all_errors_id_shuffled, all_errors_random_shuffled,
            n_sessions=n_sessions, filename=filename)

In [None]:
# all_errors is ragged (session -> trial -> per-bin errors); flatten before averaging.
np.mean(np.concatenate([np.concatenate(errors, axis=0) for errors in all_errors], axis=0))

In [None]:
for trial_idx in range(2):
    for session_idx in range(n_sessions):
        # Per-trial data for this session (all lists are session -> trial).
        decoded = all_decoded[session_idx][trial_idx]
        true_position = all_actual[session_idx][trial_idx]
        likelihoods = np.array(all_likelihoods[session_idx][trial_idx])
        n_active = all_n_active[session_idx][trial_idx]
        errors = all_errors[session_idx][trial_idx]
        xedge = xedges[session_idx]
        yedge = yedges[session_idx]

        fig = plt.figure(figsize=(12, 10))

        xx, yy = np.meshgrid(xedge, yedge)

        # --- Posterior map ---
        ax1 = plt.subplot2grid((5, 4), (0, 0), colspan=3, rowspan=3)

        pad_amount = binsize * 2
        ax1.set_xlim((np.floor(np.min(true_position.x)) - pad_amount, np.ceil(np.max(true_position.x)) + pad_amount))
        ax1.set_ylim((np.floor(np.min(true_position.y)) - pad_amount, np.ceil(np.max(true_position.y)) + pad_amount))

        n_timebins = decoded.n_samples

        # Copper colormap whose lowest colour is white, so the lowest values
        # (including the -0.01 NaN fill below) render blank. np.linspace needs
        # an integer sample count, hence an int here.
        n_colours = 20
        colours = [(1., 1., 1.)]
        colours.extend(matplotlib.cm.copper_r(np.linspace(0, 1, n_colours - 1)))
        cmap = matplotlib.colors.ListedColormap(colours)

        # NaN-preserving copy for the on-screen readout; the plotted copy gets a fill value.
        likelihoods_withnan = np.array(likelihoods)
        likelihoods[np.isnan(likelihoods)] = -0.01

        xcenters = xedge[:-1] + (xedge[1:] - xedge[:-1]) / 2
        ycenters = yedge[:-1] + (yedge[1:] - yedge[:-1]) / 2

        # Spatial-bin indices of the true and decoded positions at each time bin.
        x_idx = [nept.find_nearest_idx(xcenters, true_position.x[timestep]) for timestep in range(true_position.n_samples)]
        y_idx = [nept.find_nearest_idx(ycenters, true_position.y[timestep]) for timestep in range(true_position.n_samples)]

        x_dec_idx = [nept.find_nearest_idx(xcenters, decoded.x[timestep]) for timestep in range(decoded.n_samples)]
        y_dec_idx = [nept.find_nearest_idx(ycenters, decoded.y[timestep]) for timestep in range(decoded.n_samples)]

        # NOTE(review): assumes likelihoods[i] holds one value per spatial bin,
        # i.e. (len(yedge) - 1, len(xedge) - 1); flat shading then needs the full
        # edge grid. The former xx[:-1]/yy[:-1] mesh is rejected by recent matplotlib.
        posterior_position = ax1.pcolormesh(xx, yy, likelihoods[0], vmax=0.2, cmap=cmap)
        colorbar = fig.colorbar(posterior_position, ax=ax1)

        estimated_position, = ax1.plot([], [], "o", color="c")
        rat_position, = ax1.plot([], [], "<", color="b")

        # --- Error histogram ---
        ax2 = plt.subplot2grid((5, 4), (3, 0), colspan=3)

        binwidth = 5.
        error_bins = np.arange(-binwidth, np.max(errors) + binwidth, binwidth)

        _, _, errors_bin = ax2.hist([np.clip(errors, error_bins[0], error_bins[-1])], bins=error_bins, rwidth=0.9, color="k")
        # 0-based bar index for each time bin's error. np.digitize sends values at or
        # beyond the last edge past the end (hist puts them in the last bar), so clip.
        errors_idx = np.clip(np.digitize(errors, error_bins) - 1, 0, len(errors_bin) - 1)

        fontsize = 14
        likelihood_at_actual = ax2.text(0.6, 1, "",
                                        horizontalalignment='left',
                                        verticalalignment='top',
                                        transform=ax2.transAxes,
                                        fontsize=fontsize)

        ax2.set_xlabel("Error (cm)", fontsize=fontsize)
        ax2.set_ylabel("# bins", fontsize=fontsize)
        ax2.spines['right'].set_visible(False)
        ax2.spines['top'].set_visible(False)
        ax2.yaxis.set_ticks_position('left')
        ax2.xaxis.set_ticks_position('bottom')
        xticks = binwidth * np.arange(0, len(error_bins), 4)
        plt.xticks(xticks, fontsize=fontsize)
        plt.yticks(fontsize=fontsize)

        # --- Active-neuron histogram ---
        ax3 = plt.subplot2grid((5, 4), (4, 0), colspan=3)

        # Integer-centred bins so bar k corresponds to k active neurons.
        n_active_bins = np.arange(-0.5, np.max(n_active) + 1)

        _, _, n_neurons_bin = ax3.hist(n_active, bins=n_active_bins, rwidth=0.9, color="k", align="mid")

        ax3.set_xlabel("Number of active neurons", fontsize=fontsize)
        ax3.set_ylabel("# bins", fontsize=fontsize)
        ax3.spines['right'].set_visible(False)
        ax3.spines['top'].set_visible(False)
        ax3.yaxis.set_ticks_position('left')
        ax3.xaxis.set_ticks_position('bottom')
        plt.yticks(fontsize=fontsize)

        fig.tight_layout()

        def animate(i):
            """Draw frame i: posterior, markers, highlighted histogram bars, readout."""
            posterior_position.set_array(likelihoods[i].ravel())
            # set_data expects sequences; bare scalars are deprecated in newer matplotlib.
            estimated_position.set_data([decoded.x[i]], [decoded.y[i]])
            rat_position.set_data([true_position.x[i]], [true_position.y[i]])

            for patch in errors_bin:
                patch.set_fc('k')
            errors_bin[errors_idx[i]].set_fc('r')

            for patch in n_neurons_bin:
                patch.set_fc('k')
            n_neurons_bin[n_active[i]].set_fc('b')

            likelihood_at_actual.set_text("posterior at true position: %.3f \nposterior at decoded position: %.3f " % 
                                          (likelihoods_withnan[i][y_idx[i]][x_idx[i]], 
                                           likelihoods_withnan[i][y_dec_idx[i]][x_dec_idx[i]]))

            return (posterior_position, estimated_position, rat_position, likelihood_at_actual)

        anim = animation.FuncAnimation(fig, animate, frames=n_timebins, interval=200, 
                                       blit=False, repeat=False)


#         writer = animation.writers['ffmpeg'](fps=10)
#         dpi = 600
#         filename = session_ids[session_idx]+'_decoded_trial'+str(trial_idx)+'.mp4'
#         anim.save(os.path.join(output_filepath, filename), writer=writer, dpi=dpi)

#         plt.close()

In [None]:
# Render the most recently built animation inline as an HTML5 video.
legend_note = "Blue triangle is the true position; Cyan circle is the estimated location."
print(legend_note)
HTML(anim.to_html5_video())

In [None]:
from shapely.geometry import Point, LineString

In [None]:
def expand_line(start_pt, stop_pt, line, expand_by):
    """Widen a shapely line into a zone.

    The line is buffered by ``expand_by``; the result is then unioned with the
    start point and then with the stop point.

    Parameters
    ----------
    start_pt : shapely.Point
        First endpoint of the trajectory.
    stop_pt : shapely.Point
        Last endpoint of the trajectory.
    line : shapely.LineString
        Trajectory to widen.
    expand_by : int or float
        Buffer distance applied to ``line``.

    Returns
    -------
    zone : shapely geometry
        ``start_pt`` union ``line.buffer(expand_by)`` union ``stop_pt``.

    """
    widened = line.buffer(expand_by)
    return start_pt.union(widened).union(stop_pt)


def find_zones(info, remove_feeder, expand_by=6):
    """Build spatial zones around the ideal trajectories of a session.

    NOTE: this definition shadows ``utils_maze.find_zones`` imported at the
    top of the notebook; later cells call this version.

    Parameters
    ----------
    info : shortcut module
        Must provide ``u_trajectory``, ``shortcut_trajectory`` and
        ``novel_trajectory`` (sequences of (x, y) points) and ``path_pts``
        with 'pedestal', 'feeder1' and 'feeder2' entries.
    remove_feeder : boolean
        If True, the feeder discs are cut out of every zone.
    expand_by : int or float
        Buffer distance around each trajectory. The pedestal disc uses
        ``2.2 * expand_by`` and each feeder disc ``1.2 * expand_by``.

    Returns
    -------
    zone : dict
        shapely geometries under the keys 'u', 'shortcut', 'novel' and
        'pedestal'. 'u' excludes the pedestal; 'novel' excludes u and the
        pedestal; 'shortcut' excludes u, novel and the pedestal.

    """
    def _trajectory_zone(trajectory):
        # Buffered trajectory unioned with its two endpoints.
        return expand_line(Point(trajectory[0]),
                           Point(trajectory[-1]),
                           LineString(trajectory), expand_by)

    def _disc(name, scale):
        # Circle of radius expand_by * scale around a named path point.
        pt = info.path_pts[name]
        return Point(pt[0], pt[1]).buffer(expand_by * scale)

    pedestal = _disc('pedestal', 2.2)
    feeder1 = _disc('feeder1', 1.2)
    feeder2 = _disc('feeder2', 1.2)

    u_zone = _trajectory_zone(info.u_trajectory)
    shortcut_zone = _trajectory_zone(info.shortcut_trajectory)
    novel_zone = _trajectory_zone(info.novel_trajectory)

    zone = dict()
    zone['u'] = u_zone.difference(pedestal)
    zone['shortcut'] = shortcut_zone.difference(u_zone).difference(novel_zone).difference(pedestal)
    zone['novel'] = novel_zone.difference(u_zone).difference(pedestal)
    zone['pedestal'] = pedestal

    if remove_feeder:
        for feeder in (feeder1, feeder2):
            for key in ('u', 'shortcut', 'novel', 'pedestal'):
                zone[key] = zone[key].difference(feeder)

    return zone

In [None]:
# Name the session explicitly instead of relying on `info` leaking out of the
# decoding loop; it is the last session processed there.
info = infos[-1]
zones = find_zones(info, remove_feeder=True, expand_by=20)

In [None]:
zones  # shapely geometries keyed 'u', 'shortcut', 'novel', 'pedestal'

In [None]:
zone = zones["u"]
# Outline of the u zone (exterior ring coordinates).
u_outline_x, u_outline_y = zone.exterior.xy
plt.plot(u_outline_x, u_outline_y, 'b', lw=1)

In [None]:
# Bin centres: left edge plus half of each bin's width.
xcenters = xedge[:-1] + np.diff(xedge) / 2
ycenters = yedge[:-1] + np.diff(yedge) / 2
xcenters

In [None]:
# One boolean flag per x / y bin centre for each zone, all initially False.
u_xbins, shortcut_xbins, novel_xbins, pedestal_xbins = [
    np.zeros(len(xcenters), dtype=bool) for _ in range(4)
]
u_ybins, shortcut_ybins, novel_ybins, pedestal_ybins = [
    np.zeros(len(ycenters), dtype=bool) for _ in range(4)
]

In [None]:
# NOTE(review): x and y bins are flagged independently, so u_xbins/u_ybins are
# marginal projections of the u zone rather than a 2-D mask.
u_zone = zones['u']
for x_idx, xbin in enumerate(xcenters):
    for y_idx, ybin in enumerate(ycenters):
        point = Point([xbin, ybin])
        if not u_zone.contains(point):
            continue
        u_xbins[x_idx] = True
        u_ybins[y_idx] = True

In [None]:
u_xbins, u_ybins  # True where some bin centre in that column / row lies inside the u zone

In [None]:
# phase1_position was otherwise only defined in a later cell; build it here so
# the notebook runs top to bottom (uses the current `position` and `info`).
phase1_position = position.time_slice(info.task_times["phase1"].start, info.task_times["phase1"].stop)
plt.plot(phase1_position.x, phase1_position.y, ".")
plt.plot(zones["u"].exterior.xy[0], zones["u"].exterior.xy[1], 'b', lw=1)
# The shortcut zone may be a Polygon or a MultiPolygon; shapely 2 no longer
# allows iterating a MultiPolygon directly, so go through `.geoms`.
shortcut_parts = getattr(zones["shortcut"], "geoms", [zones["shortcut"]])
for intersect in shortcut_parts:
    plt.plot(intersect.exterior.xy[0], intersect.exterior.xy[1], 'g', lw=1)
plt.plot(zones["novel"].exterior.xy[0], zones["novel"].exterior.xy[1], 'r', lw=1)
plt.show()

In [None]:
shortcut_xbins  # NOTE(review): never filled in above -- still all False from initialisation

In [None]:
# Reload the data for `info` and rebuild its spatial bin edges at the current binsize.
events, position, spikes, _, _ = get_data(info)
xedge, yedge = nept.get_xyedges(position, binsize=binsize)

In [None]:
# Restrict position to phase 1 and compute its spatial occupancy.
phase1 = info.task_times["phase1"]
phase1_position = position.time_slice(phase1.start, phase1.stop)
occupancy = nept.get_occupancy(phase1_position, yedge, xedge)

In [None]:
occupancy  # phase-1 occupancy per spatial bin, from nept.get_occupancy

In [None]:
u_pos = np.zeros(occupancy.shape, dtype=bool)  # all-False mask, same grid as occupancy

In [None]:
u_pos |= occupancy > 0  # flag every visited bin

In [None]:
xx, yy = np.meshgrid(xedge, yedge)

# Phase-1 occupancy heat map (colour saturates at 10).
occ_fig, occ_ax = plt.subplots()
pp = occ_ax.pcolormesh(xx, yy, occupancy, vmax=10., cmap="Greys")
colourbar = occ_fig.colorbar(pp, ax=occ_ax)
occ_fig.tight_layout()

In [None]:
def point_in_zones(position, zones):
    """Split a trajectory's samples by the zone each one falls in.

    Zones are tested in the order u, shortcut, novel; a sample is assigned to
    the first zone that contains it, or to 'other' if none does. Any
    'pedestal' entry in ``zones`` is not consulted.

    Parameters
    ----------
    position : nept.Position
    zones : dict
        Must provide shapely geometries under 'u', 'shortcut' and 'novel'.

    Returns
    -------
    sorted_zones : dict
        Keys 'u', 'shortcut', 'novel', 'other', each a nept.Position built
        from the matching samples in their original order.

    """
    ordered_keys = ('u', 'shortcut', 'novel')
    samples = {key: ([], []) for key in ordered_keys + ('other',)}

    if not position.isempty:
        for x, y, time in zip(position.x, position.y, position.time):
            point = Point([x, y])
            # First containing zone wins; unmatched samples go to 'other'.
            label = next((key for key in ordered_keys if zones[key].contains(point)), 'other')
            data, times = samples[label]
            data.append([x, y])
            times.append(time)

    return {key: nept.Position(data, times) for key, (data, times) in samples.items()}

In [None]:
trial_idx = 1
timestep = 0

# Take the trial explicitly from the per-session results (last session) instead of
# the `actual` / `likelihoods` globals, whose contents depend on which cell ran last.
# `xedges` / `yedges` are per-session lists, so index them rather than doing
# array arithmetic on the lists (which raised TypeError).
session_idx = -1
trial_actual = all_actual[session_idx][trial_idx]
likelihoods = np.array(all_likelihoods[session_idx][trial_idx])
session_xedge = xedges[session_idx]
session_yedge = yedges[session_idx]

xcenters = session_xedge[:-1] + (session_xedge[1:] - session_xedge[:-1]) / 2
ycenters = session_yedge[:-1] + (session_yedge[1:] - session_yedge[:-1]) / 2

# Spatial bin holding the true position at this time step.
x_idx = nept.find_nearest_idx(xcenters, trial_actual.x[timestep])
y_idx = nept.find_nearest_idx(ycenters, trial_actual.y[timestep])

print(likelihoods[timestep][y_idx][x_idx])

In [None]:
# Posterior for one time bin (trial_idx / timestep from the previous cell) of the
# last session, with the true position overlaid. Uses its own figure: the old
# version attached a colorbar to the animation figure's ax1 and then added a second one.
session_idx = -1
trial_actual = all_actual[session_idx][trial_idx]
trial_likelihoods = np.array(all_likelihoods[session_idx][trial_idx])
# NOTE(review): assumes one likelihood value per spatial bin, so flat shading
# takes the full edge grid (the former xx[:-1]/yy[:-1] mismatched in shape).
mesh_x, mesh_y = np.meshgrid(xedges[session_idx], yedges[session_idx])

snap_fig, snap_ax = plt.subplots()
# plt.cm.get_cmap was removed in matplotlib 3.9; plt.get_cmap remains available.
pp = snap_ax.pcolormesh(mesh_x, mesh_y, trial_likelihoods[timestep], cmap=plt.get_cmap('bone_r'))
snap_ax.plot(trial_actual.x[timestep], trial_actual.y[timestep], "r.", ms=10)
colorbar = snap_fig.colorbar(pp, ax=snap_ax)
plt.show()

In [None]:
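# NOTE(review): stale binsize sweep kept for reference -- get_decoded has no
# `random_shuffle` argument, so drop it from the call below before re-enabling.
# combined_errors = []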
# binsizes = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
# for binsize in binsizes:
#     n_sessions = 0
#     xxs = []
#     yys = []

#     all_decoded = []
#     all_actual = []
#     all_likelihoods = []
#     all_n_active = []

#     all_errors = []


#     for info in infos:
#         print(info.session_id)
#         n_sessions += 1
#         events, position, spikes, _, _ = get_data(info)
#         xedges, yedges = nept.get_xyedges(position, binsize=binsize)

#         xx, yy = np.meshgrid(xedges, yedges)
#         xxs.append(xx)
#         yys.append(yy)

#         phase = "phase3"
#         trial_epochs = get_trials(events, info.task_times[phase])

#         decoded, actual, likelihoods, errors, n_active = get_decoded(info, 
#                                                  position, 
#                                                  spikes, 
#                                                  xedges, 
#                                                  yedges, 
#                                                  shuffled_id=False, 
#                                                  random_shuffle=False)

#         all_decoded.append(decoded)
#         all_actual.append(actual)
#         all_likelihoods.append(likelihoods)
#         all_n_active.append(n_active)

#         all_errors.append(errors)

#     combined_errors.append(np.concatenate([np.concatenate(errors, axis=0) for errors in all_errors], axis=0))
    
# plt.plot(binsizes, np.mean(combined_errors, axis=1))
# plt.xticks(binsizes)
# plt.xlabel("Binsize (cm)")
# plt.show()