In [None]:
%matplotlib inline
import os

import matplotlib.pyplot as plt
import nept
import numpy as np
import scipy
import scipy.interpolate

from analyze_tuning_curves import get_only_tuning_curves
from loading_data import get_data
from utils_maze import get_bin_centers, get_zones, get_xy_idx, get_matched_trials
from utils_plotting import plot_over_space

In [None]:
# Resolve cache and output locations relative to the notebook's working directory.
thisdir = os.getcwd()
# Pickled intermediate data location (not referenced by the cells shown here).
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
# Decoding-error figure paths passed to plot_over_space are built under this directory.
output_filepath = os.path.join(thisdir, "plots", "trials", "decoding")
# exist_ok=True avoids the check-then-create race and keeps the cell idempotent on re-run.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# Session metadata modules. r063d3 / r068d3 are only used when running the
# quick two-session subset selected below.
import info.r063d3 as r063d3
import info.r068d3 as r068d3
from run import analysis_infos

# True: decode every session listed in run.analysis_infos.
# False: decode only the two example sessions (r063d3, r068d3).
USE_ALL_SESSIONS = True
infos = analysis_infos if USE_ALL_SESSIONS else [r063d3, r068d3]

In [None]:
# Leave-one-out decoding over phase-3 matched trials. For each trial:
#   1. build tuning curves from all *other* matched trials,
#   2. decode the held-out trial with nept's Bayesian decoder,
#   3. compare decoded vs. actual position.
# Per-session errors are then plotted over space.

# Decoding parameters (constant across sessions and trials).
t_window = 0.1  # seconds; 0.1 for running, 0.025 for swr
min_neurons = 3  # active-neuron threshold passed to nept.bayesian_prob
gaussian_std = 0.0075  # smoothing passed to nept.bin_spikes

for info in infos:
    print(info.session_id)
    events, position, spikes, _, _ = get_data(info)

    phase = info.task_times["phase3"]
    sliced_position = position.time_slice(phase.start, phase.stop)

    # Use subset=True to decode only a matched subset of trials.
    trials = get_matched_trials(info, sliced_position, subset=False)

    # Keep each start paired with its own stop. Leave-one-out then removes exactly the
    # held-out (start, stop) pair, even if another trial shares only a start or a stop time.
    trial_bounds = list(zip(trials.starts, trials.stops))
    if len(trial_bounds) < 2:
        print("  skipping: leave-one-out decoding needs at least 2 trials")
        continue

    # Spatial bin centers depend only on the session, so compute them once per session.
    xcenters, ycenters = get_bin_centers(info)
    xy_centers = nept.cartesian(xcenters, ycenters)

    session_n_active = []
    session_likelihoods = []
    session_decoded = []
    session_actual = []
    session_errors = []
    n_timebins = []

    for trial in trials:
        # Training epochs: every matched trial except the held-out one.
        training_bounds = [(start, stop) for start, stop in trial_bounds
                           if not (start == trial.start and stop == trial.stop)]
        epoch_of_interest = nept.Epoch([[start for start, _ in training_bounds],
                                        [stop for _, stop in training_bounds]])

        tuning_curves = get_only_tuning_curves(info,
                                               position,
                                               spikes,
                                               epoch_of_interest)

        # Flatten each neuron's 2D spatial map into one row for the decoder.
        # Assumes axis 0 indexes neurons — TODO confirm against get_only_tuning_curves.
        tc_shape = tuning_curves.shape
        decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

        trial_position = position.time_slice(trial.start, trial.stop)
        sliced_spikes = [spiketrain.time_slice(trial.start, trial.stop) for spiketrain in spikes]

        # Optional: restrict decoding to times when the subject runs faster than a threshold.
        # run_epoch = nept.run_threshold(trial_position, thresh=10., t_smooth=0.8)
        # trial_position = trial_position[run_epoch]
        # sliced_spikes = [spiketrain.time_slice(run_epoch.start, run_epoch.stop)
        #                  for spiketrain in sliced_spikes]

        counts = nept.bin_spikes(sliced_spikes, trial_position.time, dt=t_window, window=t_window,
                                 gaussian_std=gaussian_std, normalized=False)
        n_timebins.append(len(counts.time))

        likelihood = nept.bayesian_prob(counts, decoding_tc, binsize=t_window,
                                        min_neurons=min_neurons, min_spikes=1)

        # Decoded location: the spatial bin with maximum likelihood in each valid time bin.
        decoded = nept.decode_location(likelihood, xy_centers, counts.time)
        session_decoded.append(decoded)

        # Drop time bins whose likelihood is entirely NaN, then restore the 2D spatial shape for plotting.
        keep_idx = np.sum(np.isnan(likelihood), axis=1) < likelihood.shape[1]
        likelihood = likelihood[keep_idx]
        likelihood = likelihood.reshape(likelihood.shape[0], tc_shape[1], tc_shape[2])
        session_likelihoods.append(likelihood)

        # Active neurons per time bin, zeroed below min_neurons; rows aligned with `likelihood`.
        n_active = np.sum(counts.data >= 1, axis=1)
        n_active_neurons = np.where(n_active >= min_neurons, n_active, 0)[keep_idx]
        session_n_active.append(n_active_neurons)

        # Actual position at each decoded time bin: the nearest tracked sample.
        # Bin times slightly outside the tracked span are clamped to the nearest endpoint
        # instead of raising and aborting the whole session.
        f_xy = scipy.interpolate.interp1d(trial_position.time, trial_position.data.T,
                                          kind="nearest", bounds_error=False,
                                          fill_value="extrapolate")
        counts_xy = f_xy(decoded.time)
        true_position = nept.Position(counts_xy.T, decoded.time)
        session_actual.append(true_position)

        trial_errors = true_position.distance(decoded)
        session_errors.append(trial_errors)

    title = info.session_id + "_matched-trials_decoding-error"
    filepath = os.path.join(output_filepath, title + ".png")
    errors_byactual = plot_over_space(info, sliced_position, session_errors,
                                      session_actual, title, vmax=80., filepath=filepath)