In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib
import numpy as np
import itertools
import scipy
import pandas as pd
import pickle
import seaborn as sns
from scipy import stats
import os
import nept

from matplotlib import animation, rc
from IPython.display import HTML

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from analyze_decode_bytrial import decode_trial
from analyze_decode import get_decoded_zones
from utils_maze import find_zones, get_trials, get_zones, get_trial_idx

In [None]:
# Paths, relative to the notebook's working directory, for cached pickles and for figures.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "check_decode")
# exist_ok makes the cell idempotent and avoids the check-then-create race.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# Session selection. `info` is the single session whose data the following cells load
# with get_data(info); `infos` is the list the decoding loop iterates over.
# NOTE(review): `info` (r063d2) and `infos` ([r063d6]) currently refer to different
# sessions — confirm this is intended.
import info.r063d2 as info
import info.r063d6 as r063d6
infos = [r063d6]

# Uncomment the assignment below to iterate over `spike_sorted_infos` from run.py instead.
from run import spike_sorted_infos
# infos = spike_sorted_infos

In [None]:
def plot_tuning_curves(info, tuning_curves):
    """Plot the sum of all neurons' tuning curves as a single heat map.

    Parameters
    ----------
    info : module
        Session info; ``xedges``, ``yedges`` and ``session_id`` are read.
    tuning_curves : sequence of np.array
        One 2D array per neuron, all with the same shape. Presumably
        (len(info.yedges) - 1, len(info.xedges) - 1) to match the mesh — TODO confirm.

    Raises
    ------
    ValueError
        If ``tuning_curves`` is empty.

    """
    if len(tuning_curves) == 0:
        raise ValueError("tuning_curves must contain at least one tuning curve")

    # Element-wise sum over neurons.
    summed_tuning_curves = np.sum(np.asarray(tuning_curves, dtype=float), axis=0)
    xx, yy = np.meshgrid(info.xedges, info.yedges)
    # pyplot.get_cmap replaces plt.cm.get_cmap, which was removed in matplotlib 3.9.
    cmap = plt.get_cmap('bone_r', 25)

    plt.figure()
    pp = plt.pcolormesh(xx, yy, summed_tuning_curves, cmap=cmap)
    plt.colorbar(pp)
    plt.axis('off')
    plt.title(info.session_id + '-tuning_curve-all')
    plt.tight_layout()
    plt.show()

In [None]:
def plot_counts(counts):
    """Show binned spike counts as a heat map.

    Parameters
    ----------
    counts : nept.AnalogSignal
        ``counts.data`` is transposed before plotting. It is presumably
        (n_time_bins, n_neurons), giving one image row per neuron — TODO confirm.

    """
    fig, ax = plt.subplots(figsize=(6, 7))
    mesh = ax.pcolormesh(counts.data.T, cmap='bone_r')
    fig.colorbar(mesh, ax=ax)

    # Bare x axis; only the left spine and left ticks are kept.
    ax.set_xticks([])
    for side in ('top', 'right', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.yaxis.tick_left()

    plt.show()

In [None]:
# When True, the decoding loop randomly permutes the tuning curves across neurons
# (shuffled-identity control).
shuffled_id = False

In [None]:
events, position, spikes, _, _ = get_data(info)

phase = info.task_times["phase3"]
trials = get_trials(events, phase)

In [None]:
# Leave-one-trial-out decoding check. For each session in `infos` and each of its first
# N_TRIALS phase-3 trials:
#   - build tuning curves from the rest of phase 3;
#   - decode the held-out trial while the subject is running;
#   - record the decoding errors and the proportion of time bins that could be decoded.
# `counts`, `decoding_tc`, `likelihood`, `decoded`, `xy_centers` and `min_neurons` are left
# bound to the last decoded trial for the inspection cells further down.
N_TRIALS = 10       # phase-3 trials decoded per session
BINSIZE = 0.025     # bin width / window for spike counts and the decoder (same time units as positions)
SHUFFLE_SEED = 0    # seed for the shuffled-identity control (used only if shuffled_id)
min_neurons = 3     # minimum active neurons for a time bin to be decoded

all_errors = []
all_proportions = []
session_ids = []

for session_info in infos:
    session_ids.append(session_info.session_id)

    # Load this session's own data so trials, positions and spikes match its bin edges.
    events, position, spikes, _, _ = get_data(session_info)
    phase = session_info.task_times["phase3"]
    trials = get_trials(events, phase)

    # Position-bin centers are identical for every trial of the session.
    xcenters = (session_info.xedges[1:] + session_info.xedges[:-1]) / 2.
    ycenters = (session_info.yedges[1:] + session_info.yedges[:-1]) / 2.
    xy_centers = nept.cartesian(xcenters, ycenters)

    # Seeded so the shuffled control is reproducible.
    shuffle_rng = np.random.RandomState(SHUFFLE_SEED)

    session_errors = []
    session_proportions = []

    for trial in trials[:N_TRIALS]:
        # Tuning curves from the phase with the held-out trial excluded.
        epoch_of_interest = phase.excludes(trial)
        tuning_curves = get_only_tuning_curves(position,
                                               spikes,
                                               session_info.xedges,
                                               session_info.yedges,
                                               epoch_of_interest)
        if shuffled_id:
            # Control: shuffle which tuning curve belongs to which neuron.
            tuning_curves = shuffle_rng.permutation(tuning_curves)

        # Restrict position and spikes to the trial, then to running periods only.
        sliced_position = position.time_slice(trial.start, trial.stop)
        sliced_spikes = [spiketrain.time_slice(trial.start, trial.stop) for spiketrain in spikes]

        run_epoch = nept.run_threshold(sliced_position, thresh=8., t_smooth=0.8)
        sliced_position = sliced_position[run_epoch]
        if sliced_position.n_samples == 0:
            # No running samples: nothing to decode, but keep one entry per trial.
            session_proportions.append(np.nan)
            continue
        sliced_spikes = [spiketrain.time_slice(run_epoch.start, run_epoch.stop)
                         for spiketrain in sliced_spikes]

        counts = nept.bin_spikes(sliced_spikes, sliced_position.time, dt=BINSIZE, window=BINSIZE,
                                 gaussian_std=0.0075, normalized=False)
        n_time_bins = len(counts.time)
        if n_time_bins == 0:
            session_proportions.append(np.nan)
            continue

        # Flatten each neuron's 2D tuning curve into one row of position bins.
        tc_shape = tuning_curves.shape
        decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

        likelihood = nept.bayesian_prob(counts, decoding_tc, binsize=BINSIZE, min_neurons=min_neurons)
        decoded = nept.decode_location(likelihood, xy_centers, counts.time)

        session_proportions.append(decoded.n_samples / n_time_bins)
        if decoded.n_samples == 0:
            continue

        # Actual position at each decoded time (nearest recorded running sample).
        f_xy = scipy.interpolate.interp1d(sliced_position.time, sliced_position.data.T, kind="nearest")
        counts_xy = f_xy(decoded.time)
        true_position = nept.Position(np.column_stack((counts_xy[0], counts_xy[1])), decoded.time)

        trial_errors = true_position.distance(decoded)
        session_errors.extend(trial_errors)

    all_errors.append(session_errors)
    all_proportions.append(session_proportions)

In [None]:
# Mean proportion of decoded time bins per session. A per-session nanmean tolerates trials
# recorded as NaN and sessions with different numbers of trials (ragged lists), where
# np.mean(..., axis=1) would fail.
np.array([np.nanmean(proportions) for proportions in all_proportions])

In [None]:
# Bar chart of the mean proportion of decoded time bins per session.
# NOTE(review): the bin size in the title is hard-coded; confirm it matches info.xedges/yedges.
bin_size_cm = 12
# Per-session nanmean: works with NaN trials and with sessions of different trial counts.
mean_proportions = [np.nanmean(proportions) for proportions in all_proportions]
y_pos = np.arange(len(session_ids))
plt.bar(y_pos, mean_proportions, align='center', alpha=0.7)
plt.xticks(y_pos, session_ids, rotation=90, fontsize=10)
plt.ylabel('Proportion')
plt.title("Samples decoded with %d cm bins" % bin_size_cm)
plt.tight_layout()
plt.show()

In [None]:
# Number of time bins in `counts` from the last trial processed by the decoding loop.
counts.time.shape

In [None]:
# (n_time_bins, n_position_bins) of the most recently computed `likelihood`.
likelihood.shape

In [None]:
def bayesian_prob(counts, tuning_curves, binsize, min_neurons, min_spikes=1):
    """Computes the bayesian probability of location based on spike counts.

    The Poisson log-likelihood of every position bin is accumulated in log space. Each time
    bin is then normalized over position bins after subtracting its maximum (log-sum-exp),
    so large spike counts or firing rates cannot overflow to inf or underflow to 0. Such
    overflow or underflow would otherwise collapse a whole time bin to 0s or NaNs.

    Parameters
    ----------
    counts : nept.AnalogSignal
        ``counts.data`` has one row per time bin and one column per neuron (the spike count
        of that neuron in that bin). ``counts.time`` has one entry per time bin.
    tuning_curves : np.array
        Shape (n_neurons, n_position_bins); each row is a neuron's firing rate per position bin.
    binsize : float
        Size of the time bins.
    min_neurons : int
        Minimum number of neurons active in a given bin.
    min_spikes : int
        Minimum number of spikes for a neuron to count as active in a bin.

    Returns
    -------
    prob : np.array
        Shape (n_time_bins, n_position_bins). Each row is the normalized probability over
        position bins for one time bin.

    Notes
    -----
    Only neurons whose tuning-curve value at a position bin is > 1 contribute to that bin.
    A position bin with no such neuron is nan for every time bin. Time bins that do not meet
    the min_neurons/min_spikes requirement are set to nan. To convert them to 0s instead,
    use ``prob[np.isnan(prob)] = 0`` on the output.

    """
    counts_data = np.asarray(counts.data, dtype=float)
    tuning_curves = np.asarray(tuning_curves, dtype=float)
    n_time_bins = np.shape(counts.time)[0]
    n_position_bins = tuning_curves.shape[1]

    # Unnormalized log-likelihood per position bin:
    #   sum_i n_i * log(f_i) - binsize * sum_i f_i + log(1 / n_position_bins)
    log_likelihood = np.full((n_time_bins, n_position_bins), np.nan)
    log_uniform_prior = -np.log(n_position_bins) if n_position_bins else 0.0
    for idx in range(n_position_bins):
        rates = tuning_curves[:, idx]
        # Rates <= 1 would give a non-positive log; those neurons are left out at this bin.
        # NOTE(review): this also drops the information carried by low-rate neurons; confirm intended.
        valid_idx = rates > 1
        if np.any(valid_idx):
            valid_rates = rates[valid_idx]
            log_likelihood[:, idx] = (counts_data[:, valid_idx] @ np.log(valid_rates)
                                      - binsize * np.sum(valid_rates)
                                      + log_uniform_prior)

    # Normalize each time bin over position bins. Shifting by the row max keeps exp() finite;
    # nan position bins stay nan and are ignored by nanmax/nansum.
    likelihood = np.full_like(log_likelihood, np.nan)
    has_value = ~np.all(np.isnan(log_likelihood), axis=1)
    if np.any(has_value):
        rows = log_likelihood[has_value]
        probs = np.exp(rows - np.nanmax(rows, axis=1, keepdims=True))
        likelihood[has_value] = probs / np.nansum(probs, axis=1, keepdims=True)

    # Remove bins with too few active neurons.
    # A neuron is considered active by having at least min_spikes in a bin.
    n_active_neurons = np.sum(counts_data >= min_spikes, axis=1)
    likelihood[n_active_neurons < min_neurons] = np.nan

    return likelihood

In [None]:
def decode_location(likelihood, pos_centers, time_centers):
    """Decode position as the center of the most likely position bin in each time bin.

    Time bins whose likelihood row is entirely NaN are dropped before decoding.

    Parameters
    ----------
    likelihood : np.array
        With shape (n_timebins, n_positionbins).
    pos_centers : np.array
        Position-bin centers, indexed by position-bin index.
    time_centers : np.array
        Time of each time bin, indexed like the rows of ``likelihood``.

    Returns
    -------
    decoded : nept.Position
        Estimate of decoded position.

    """
    has_estimate = ~np.all(np.isnan(likelihood), axis=1)
    best_bins = np.nanargmax(likelihood[has_estimate], axis=1)
    return nept.Position(pos_centers[best_bins], time_centers[has_estimate])


In [None]:
# Recompute the last trial's likelihood with the notebook-local bayesian_prob
# (overwrites the global `likelihood` left by the decoding loop).
likelihood = bayesian_prob(counts, decoding_tc, binsize=0.025, min_neurons=min_neurons)

In [None]:
# Decode with the notebook-local decode_location, using the last trial's bin centers
# and count times.
pos_centers = xy_centers
time_centers = counts.time
dl = decode_location(likelihood, pos_centers, time_centers)

In [None]:
# Number of samples decoded by the notebook-local decode_location.
dl.n_samples

In [None]:
# Time bins with at least one non-NaN position bin (the bins decode_location keeps).
keep_idx = np.sum(np.isnan(likelihood), axis=1) < likelihood.shape[1]
ll = likelihood[keep_idx]
# Last expression so the count is displayed (it was previously computed mid-cell and discarded).
np.sum(keep_idx)

In [None]:
# Number of samples decoded by nept.decode_location for the loop's last trial;
# compare with dl.n_samples from the notebook-local decoder.
decoded.n_samples

In [None]:
# (n_kept_time_bins, n_position_bins)
ll.shape

In [None]:
# Index of the most likely position bin for each kept time bin.
max_decoded_idx = np.nanargmax(ll, axis=1)

In [None]:
# Should equal the number of kept time bins, ll.shape[0].
max_decoded_idx.shape

In [None]:
# Spike counts for time bins 10-19 only, for a close-up plot.
cc = nept.AnalogSignal(counts.data[10:20], counts.time[10:20])

In [None]:
# Close-up of the counts for time bins 10-19; reuses plot_counts instead of duplicating its body.
plot_counts(cc)

In [None]:
# Number of time bins in which at least 2 neurons have a count >= 1.
# NOTE(review): the decoding loop uses min_neurons = 3, not 2.
np.sum(np.sum(counts.data >=1, axis=1) >= 2)

In [None]:
# Number of neurons with a count >= 1 in each time bin.
np.sum(counts.data >=1, axis=1)

In [None]:
# (n_time_bins, n_position_bins) of the current `likelihood`.
likelihood.shape

In [None]:
# Number of time bins with at least two position bins of positive probability
# (NaN entries compare as False).
np.sum(np.sum(likelihood >0, axis=1) >= 2)