In [None]:
%matplotlib inline
import os
import warnings

import matplotlib.pyplot as plt
import nept
import numpy as np
import scipy
import scipy.stats

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_plotting import plot_over_space
from utils_maze import get_subset_zones, get_bin_centers

In [None]:
# Cache and figure locations, resolved relative to the notebook's working directory.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "decode-checks")
# exist_ok makes this a no-op when the folder already exists and avoids the
# check-then-create race of an os.path.exists() guard.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# Sessions to analyse. Only r066d5 is used by default; uncomment the last line
# to run every session listed in run.analysis_infos instead.
import info.r066d6 as r066d6
import info.r066d5 as r066d5
infos = [r066d5]
from run import analysis_infos
# infos = analysis_infos

In [None]:
def bin_spikes(spikes, time, dt, window=None, gaussian_std=None, normalized=True):
    """Bins spikes using a sliding (boxcar) window, optionally Gaussian-smoothed.

    Parameters
    ----------
    spikes: list
        Of nept.SpikeTrain (each must expose a ``.time`` array of spike times).
    time: np.array
        Bin edges; the output has ``len(time) - 1`` samples.
    dt: float
        Bin width, in seconds. Used to convert ``window`` into a whole number
        of bins and forwarded to ``nept.gaussian_filter``.
    window: float or None
        Length of the sliding window, in seconds. If None, will default to dt.
        Rounded to the nearest whole number of bins (a warning is issued when
        that rounding is not exact).
    gaussian_std: float or None
        Passed to ``nept.gaussian_filter`` together with ``dt``; None disables
        the Gaussian smoothing step.
    normalized: boolean
        If True the boxcar averages counts over the window (weights 1/n_bins);
        otherwise it sums them. Also forwarded to ``nept.gaussian_filter``.

    Returns
    -------
    binned_spikes: nept.AnalogSignal
        Built from a (len(spikes), len(time) - 1) count array, timestamped at
        the left bin edges.

    Raises
    ------
    ValueError
        If ``window`` is shorter than half of ``dt`` (it would round to zero bins).

    """
    if window is None:
        window = dt

    bin_edges = time

    given_n_bins = window / dt
    n_bins = int(round(given_n_bins))
    if n_bins < 1:
        # A zero-length filter would make 1 / n_bins divide by zero.
        raise ValueError("window (%g) must be at least half of dt (%g)" % (window, dt))
    if abs(n_bins - given_n_bins) > 0.01:
        warnings.warn("dt does not divide window evenly. "
                      "Using window %g instead." % (n_bins*dt))

    weight = 1.0 / n_bins if normalized else 1.0
    square_filter = np.full(n_bins, weight)

    counts = np.zeros((len(spikes), len(bin_edges) - 1))
    for idx, spiketrain in enumerate(spikes):
        spike_hist = np.histogram(spiketrain.time, bins=bin_edges)[0].astype(float)
        counts[idx] = np.convolve(spike_hist, square_filter, mode="same")

    if gaussian_std is not None:
        counts = nept.gaussian_filter(counts, gaussian_std, dt=dt, normalized=normalized, axis=1)

    return nept.AnalogSignal(counts, bin_edges[:-1])

In [None]:
def get_average_likelihoods(info):
    """Decode rest-period SWRs and sum each SWR's posterior within every maze zone.

    SWRs detected over the whole session are restricted to rest epochs within
    each task time ("prerecord", "pauseA", "pauseB", "postrecord"). Each SWR is
    decoded as a single time bin spanning its duration, using tuning curves
    from ``info.task_times["phase3"]``.

    Parameters
    ----------
    info: module
        Session info module (e.g. ``info.r066d5``); provides ``task_times`` and ``fs``.

    Returns
    -------
    data: dict
        ``data[task_time][segment]`` is a list with one value per decoded SWR:
        the ``np.nansum`` of that SWR's first-bin posterior inside the segment's
        zone. SWRs with no decodable time bin are skipped.
    n_swrs: dict
        ``n_swrs[task_time]`` is the number of candidate SWRs in that task time
        (including any that could not be decoded).
    likelihoods: list
        One reshaped posterior array per SWR, for the last task time processed
        ("postrecord") only.
    """
    _, position, spikes, lfp, _ = get_data(info)

    u_zone, shortcut_zone, novel_zone = get_subset_zones(info, position)
    combined_zones = u_zone + shortcut_zone + novel_zone
    other_zone = ~combined_zones
    zones = {"u": u_zone, "shortcut": shortcut_zone, "novel": novel_zone, "other": other_zone}

    tuning_curves = get_only_tuning_curves(info,
                                           position,
                                           spikes,
                                           info.task_times["phase3"])
    tc_shape = tuning_curves.shape
    # Flatten the two spatial axes into one for nept.bayesian_prob.
    decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

    # Find SWRs for the whole session
    z_thresh = 2.0
    power_thresh = 3.0
    merge_thresh = 0.02
    min_length = 0.05
    swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                   power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)
    swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=4)

    rest_epochs = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

    # Restrict SWRs to those during epochs of interest during rest
    task_times = ["prerecord", "pauseA", "pauseB", "postrecord"]
    maze_segments = ["u", "shortcut", "novel", "other"]
    data = {k: {key: [] for key in maze_segments} for k in task_times}

    n_swrs = dict()

    for task_time in task_times:
        epochs_of_interest = info.task_times[task_time].intersect(rest_epochs)

        phase_swrs = epochs_of_interest.overlaps(swrs)
        phase_swrs = phase_swrs[phase_swrs.durations >= min_length]

        n_swrs[task_time] = phase_swrs.n_epochs

        likelihoods = []

        for start, stop in zip(phase_swrs.starts, phase_swrs.stops):
            sliced_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]

            # Decode the whole SWR as one bin: dt == window == SWR duration.
            t_window = stop - start

            counts = bin_spikes(sliced_spikes, np.array([start, stop]), dt=t_window, window=t_window,
                                gaussian_std=0.0075, normalized=False)

            likelihood = nept.bayesian_prob(counts, decoding_tc, binsize=t_window, min_neurons=3, min_spikes=1)

            # Drop time bins that are entirely NaN (not decoded by bayesian_prob)
            # and reshape the remaining bins back onto the spatial grid.
            keep_idx = ~np.all(np.isnan(likelihood), axis=1)
            likelihood = likelihood[keep_idx]
            likelihoods.append(likelihood.reshape(np.shape(likelihood)[0], tc_shape[1], tc_shape[2]))

        for swr_likelihood in likelihoods:
            if len(swr_likelihood) == 0:
                # Every bin of this SWR was NaN; indexing [0] would raise IndexError.
                continue
            for segment in maze_segments:
                data[task_time][segment].append(np.nansum(swr_likelihood[0][zones[segment]]))

    return data, n_swrs, likelihoods

In [None]:
def plot_decoded_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions, session_id=None):
    """Save bar charts of the mean per-SWR zone posterior, pooled across sessions.

    Writes two families of PNGs to ``output_filepath``: one per task time
    (bars = maze segments, annotated with the pooled SWR count) and one per
    maze segment (bars = task times, each annotated with its SWR count).
    Error bars are ``scipy.stats.sem`` over the pooled per-SWR values.

    Parameters
    ----------
    all_data: list of dict
        One dict per session as returned by ``get_average_likelihoods``;
        ``session[task_time][segment]`` is a list of per-SWR values.
    n_all_swrs: list of dict
        One dict per session (same order as ``all_data``) mapping task time to SWR count.
    task_times: list of str
    maze_segments: list of str
    n_sessions: int
        When 1, titles (and therefore file names) are prefixed with the session id.
    session_id: str or None
        Id for single-session titles. Defaults to the notebook-level
        ``info.session_id`` so existing calls keep working.
    """
    if n_sessions == 1 and session_id is None:
        # Backward-compatible fallback to the notebook global `info`.
        session_id = info.session_id

    def _save(fig, ax, title):
        # Shared finish: title, open-frame axes, save under the title, free the figure.
        ax.set_title(title)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')
        fig.tight_layout()
        fig.savefig(os.path.join(output_filepath, title + ".png"))
        plt.close(fig)

    # One figure per task time: bars are maze segments.
    for task_time in task_times:
        pooled = {segment: [] for segment in maze_segments}
        n_swrs = 0
        for session, session_swrs in zip(all_data, n_all_swrs):
            for segment in maze_segments:
                pooled[segment].extend(session[task_time][segment])
            n_swrs += session_swrs[task_time]

        means = [np.nanmean(pooled[segment]) for segment in maze_segments]
        sems = [scipy.stats.sem(pooled[segment]) for segment in maze_segments]

        fig, ax = plt.subplots(figsize=(7, 5))
        n = np.arange(len(maze_segments))
        ax.bar(n, means, yerr=sems, color="#a6bddb")
        ax.set_xticks(n)
        ax.set_xticklabels(maze_segments)
        ax.text(0.95, 0.95, "n swrs: " + str(n_swrs),
                horizontalalignment='center',
                verticalalignment='center',
                transform=ax.transAxes,
                fontsize=14)
        ax.set_ylabel("Proportion")
        ax.set_ylim(0, 0.8)
        if n_sessions == 1:
            title = session_id + " average posteriors during SWRs in " + task_time
        else:
            title = "Average posteriors during SWRs in " + task_time
        _save(fig, ax, title)

    # One figure per maze segment: bars are task times.
    for trajectory in maze_segments:
        pooled = {task_time: [] for task_time in task_times}
        n_swrs = {task_time: 0 for task_time in task_times}
        for session, session_swrs in zip(all_data, n_all_swrs):
            for task_time in task_times:
                pooled[task_time].extend(session[task_time][trajectory])
                n_swrs[task_time] += session_swrs[task_time]

        trajectory_means = [np.nanmean(pooled[task_time]) for task_time in task_times]
        trajectory_sems = [scipy.stats.sem(pooled[task_time]) for task_time in task_times]

        fig, ax = plt.subplots(figsize=(7, 5))
        n = np.arange(len(task_times))
        ax.bar(n, trajectory_means, yerr=trajectory_sems, color="#3690c0")
        ax.set_xticks(n)
        ax.set_xticklabels(task_times)
        ax.set_ylabel("Proportion")
        for i, task_time in enumerate(task_times):
            ax.text(i, 0.01, str(n_swrs[task_time]), ha="center", fontsize=14)
        if n_sessions == 1:
            title = session_id + " average posteriors during SWRs for " + trajectory
        else:
            title = "Average posteriors during SWRs for " + trajectory
        _save(fig, ax, title)

In [None]:
# Task times (x-axis groups) and maze segments (zones) used by the summary plots.
task_times = [
    "prerecord",
    "pauseA",
    "pauseB",
    "postrecord",
]
maze_segments = [
    "u",
    "shortcut",
    "novel",
    "other",
]

In [None]:
# Plot each session on its own: one-element lists make the pooled plotting
# function produce single-session (session-id titled) figures.
for info in infos:
    all_data, n_all_swrs = [], []
    data, n_swrs, likelihoods = get_average_likelihoods(info)
    all_data += [data]
    n_all_swrs += [n_swrs]
    plot_decoded_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions=1)

In [None]:
# Plot all sessions pooled together.
all_data = []
n_all_swrs = []

for info in infos:
    data, n_swrs, likelihoods = get_average_likelihoods(info)
    # Append inside the loop so every session is pooled, not just the last one.
    all_data.append(data)
    n_all_swrs.append(n_swrs)

plot_decoded_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions=len(infos))

In [None]:
# Inspect the posteriors returned by the last get_average_likelihoods call above:
# one array per SWR, for the last session in `infos` and the "postrecord" task time only.
likelihoods

In [None]:
# Reload data for the last `info` from the loops above and plot the leftover
# `likelihoods` over the maze. Also (re)binds `position`, which later cells use.
# NOTE(review): plot_over_space's expected input isn't visible here — confirm it
# accepts a list of per-SWR (n_bins, ny, nx) posterior arrays.
events, position, spikes, lfp, _ = get_data(info)
plot_over_space(info, likelihoods, position, " ")

In [None]:
def get_max_likelihoods(info):
    """Assign each rest-period SWR to the maze zone(s) holding its peak posterior.

    SWRs detected over the whole session are restricted to rest epochs within
    each task time ("prerecord", "pauseA", "pauseB", "postrecord"). Each SWR is
    decoded as a single time bin spanning its duration, using tuning curves
    from ``info.task_times["phase3"]``.

    Parameters
    ----------
    info: module
        Session info module (e.g. ``info.r066d5``); provides ``task_times`` and ``fs``.

    Returns
    -------
    data: dict
        ``data[task_time][segment]`` is the fraction of decoded SWRs in that
        task time with at least one maximum-posterior bin inside the segment's
        zone (0 when no SWR was decoded). An SWR whose maximum is tied across
        zones counts for each of them, so fractions can sum above 1.
    n_swrs: dict
        ``n_swrs[task_time]`` is the number of candidate SWRs in that task time
        (including any that could not be decoded).
    likelihoods: list
        One reshaped posterior array per SWR, for the last task time processed
        ("postrecord") only.
    """
    _, position, spikes, lfp, _ = get_data(info)

    u_zone, shortcut_zone, novel_zone = get_subset_zones(info, position)
    combined_zones = u_zone + shortcut_zone + novel_zone
    other_zone = ~combined_zones
    zones = {"u": u_zone, "shortcut": shortcut_zone, "novel": novel_zone, "other": other_zone}

    tuning_curves = get_only_tuning_curves(info,
                                           position,
                                           spikes,
                                           info.task_times["phase3"])
    tc_shape = tuning_curves.shape
    # Flatten the two spatial axes into one for nept.bayesian_prob.
    decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

    # Find SWRs for the whole session
    z_thresh = 2.0
    power_thresh = 3.0
    merge_thresh = 0.02
    min_length = 0.05
    swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                   power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)
    swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=4)

    rest_epochs = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

    # Restrict SWRs to those during epochs of interest during rest
    task_times = ["prerecord", "pauseA", "pauseB", "postrecord"]
    maze_segments = ["u", "shortcut", "novel", "other"]
    data = {k: {key: 0 for key in maze_segments} for k in task_times}

    n_swrs = dict()

    for task_time in task_times:
        epochs_of_interest = info.task_times[task_time].intersect(rest_epochs)

        phase_swrs = epochs_of_interest.overlaps(swrs)
        phase_swrs = phase_swrs[phase_swrs.durations >= min_length]

        n_swrs[task_time] = phase_swrs.n_epochs

        likelihoods = []

        for start, stop in zip(phase_swrs.starts, phase_swrs.stops):
            sliced_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]

            # Decode the whole SWR as one bin: dt == window == SWR duration.
            t_window = stop - start

            counts = bin_spikes(sliced_spikes, np.array([start, stop]), dt=t_window, window=t_window,
                                gaussian_std=0.0075, normalized=False)

            likelihood = nept.bayesian_prob(counts, decoding_tc, binsize=t_window, min_neurons=3, min_spikes=1)

            # Drop time bins that are entirely NaN (not decoded by bayesian_prob)
            # and reshape the remaining bins back onto the spatial grid.
            keep_idx = ~np.all(np.isnan(likelihood), axis=1)
            likelihood = likelihood[keep_idx]
            likelihoods.append(likelihood.reshape(np.shape(likelihood)[0], tc_shape[1], tc_shape[2]))

        n_decoded = 0
        for swr_likelihood in likelihoods:
            if len(swr_likelihood) == 0:
                # Every bin was NaN: np.nanmax would raise on the empty array.
                continue
            n_decoded += 1
            is_peak = swr_likelihood == np.nanmax(swr_likelihood)
            for segment in maze_segments:
                data[task_time][segment] += int(np.any(zones[segment] & is_peak))

        # Normalise by SWRs actually decoded (equals n_epochs when all decode).
        if n_decoded > 0:
            for segment in maze_segments:
                data[task_time][segment] /= n_decoded

    return data, n_swrs, likelihoods

In [None]:
def plot_max_decoded_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions, session_id=None):
    """Save bar charts of the per-session fraction of SWRs peaking in each zone.

    Writes two families of PNGs to ``output_filepath``: one per task time
    (bars = maze segments, annotated with the number of sessions) and one per
    maze segment (bars = task times, each annotated with its pooled SWR count).
    Bars are the mean over sessions; error bars are ``scipy.stats.sem`` over sessions.

    Parameters
    ----------
    all_data: list of dict
        One dict per session as returned by ``get_max_likelihoods``;
        ``session[task_time][segment]`` is a single per-session fraction.
    n_all_swrs: list of dict
        One dict per session (same order as ``all_data``) mapping task time to SWR count.
    task_times: list of str
    maze_segments: list of str
    n_sessions: int
        When 1, titles (and therefore file names) are prefixed with the session id.
    session_id: str or None
        Id for single-session titles. Defaults to the notebook-level
        ``info.session_id`` so existing calls keep working.
    """
    if n_sessions == 1 and session_id is None:
        # Backward-compatible fallback to the notebook global `info`.
        session_id = info.session_id

    def _save(fig, ax, title):
        # Shared finish: title, open-frame axes, save under the title, free the figure.
        ax.set_title(title)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')
        fig.tight_layout()
        fig.savefig(os.path.join(output_filepath, title + ".png"))
        plt.close(fig)

    # One figure per task time: bars are maze segments.
    for task_time in task_times:
        per_session = {segment: [session[task_time][segment] for session in all_data]
                       for segment in maze_segments}

        means = [np.nanmean(per_session[segment]) for segment in maze_segments]
        sems = [scipy.stats.sem(per_session[segment]) for segment in maze_segments]

        fig, ax = plt.subplots(figsize=(7, 5))
        n = np.arange(len(maze_segments))
        ax.bar(n, means, yerr=sems, color="#99d8c9")
        ax.set_xticks(n)
        ax.set_xticklabels(maze_segments)
        ax.text(0.95, 0.95, "n sessions: " + str(len(all_data)),
                horizontalalignment='center',
                verticalalignment='center',
                transform=ax.transAxes,
                fontsize=14)
        if n_sessions == 1:
            title = session_id + " decoded zone during SWRs in " + task_time
        else:
            title = "Decoded zone during SWRs in " + task_time
        _save(fig, ax, title)

    # One figure per maze segment: bars are task times.
    for trajectory in maze_segments:
        per_session = {task_time: [] for task_time in task_times}
        n_swrs = {task_time: 0 for task_time in task_times}
        for session, session_swrs in zip(all_data, n_all_swrs):
            for task_time in task_times:
                per_session[task_time].append(session[task_time][trajectory])
                n_swrs[task_time] += session_swrs[task_time]

        trajectory_means = [np.nanmean(per_session[task_time]) for task_time in task_times]
        trajectory_sems = [scipy.stats.sem(per_session[task_time]) for task_time in task_times]

        fig, ax = plt.subplots(figsize=(7, 5))
        n = np.arange(len(task_times))
        ax.bar(n, trajectory_means, yerr=trajectory_sems, color="#41ae76")
        ax.set_xticks(n)
        ax.set_xticklabels(task_times)
        for i, task_time in enumerate(task_times):
            ax.text(i, 0.01, str(n_swrs[task_time]), ha="center", fontsize=14)
        if n_sessions == 1:
            title = session_id + " decoded zone during SWRs for " + trajectory
        else:
            title = "Decoded zone during SWRs for " + trajectory
        _save(fig, ax, title)

In [None]:
# Plot each session on its own: one-element lists make the pooled plotting
# function produce single-session (session-id titled) figures.
for info in infos:
    all_data, n_all_swrs = [], []
    data, n_swrs, likelihoods = get_max_likelihoods(info)
    all_data += [data]
    n_all_swrs += [n_swrs]
    plot_max_decoded_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions=1)

In [None]:
# Plot all sessions pooled together.
all_data = []
n_all_swrs = []

for info in infos:
    data, n_swrs, likelihoods = get_max_likelihoods(info)
    # Append inside the loop so every session is pooled, not just the last one.
    all_data.append(data)
    n_all_swrs.append(n_swrs)

plot_max_decoded_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions=len(infos))

In [None]:
# Average the decoded posterior across SWRs: the first time bin of each SWR in
# `likelihoods` (last session, "postrecord" task time of the most recent
# decoding call), skipping SWRs that had no decodable bin.
swr_posteriors = [swr_likelihood[0] for swr_likelihood in likelihoods if len(swr_likelihood) > 0]
average_likelihood = np.nanmean(swr_posteriors, axis=0) if swr_posteriors else None

In [None]:
# First SWR's posterior on the maze grid, with tracked positions overlaid in red.
xx, yy = np.meshgrid(info.xedges, info.yedges)
fig, ax = plt.subplots()
ax.plot(position.x, position.y, "r.", ms=1)
pp = ax.pcolormesh(xx, yy, likelihoods[0][0], cmap='bone_r')

fig.colorbar(pp, ax=ax)
ax.axis('off')

plt.show()

In [None]:
# Show the "other" zone (maze bins outside the u, shortcut and novel zones)
# with tracked positions overlaid. The zones are computed here because the
# ones in the decoding functions are local to those functions.
u_zone, shortcut_zone, novel_zone = get_subset_zones(info, position)
other_zone = ~(u_zone + shortcut_zone + novel_zone)

xx, yy = np.meshgrid(info.xedges, info.yedges)
fig, ax = plt.subplots()
ax.plot(position.x, position.y, "r.", ms=1)
zone_mesh = ax.pcolormesh(xx, yy, other_zone, cmap='bone_r')

# Colour bar for this cell's mesh (not the previous cell's `pp`).
fig.colorbar(zone_mesh, ax=ax)
ax.axis('off')

plt.show()

In [None]:
# Mean per-SWR zone posterior divided by the zone's bin count (average posterior
# per spatial bin) for "postrecord" SWRs of the last session in `infos`.
# The per-zone lists and zone masks are recomputed here because they only exist
# as locals inside the plotting/decoding functions. Note: this re-runs decoding.
u_zone, shortcut_zone, novel_zone = get_subset_zones(info, position)
other_zone = ~(u_zone + shortcut_zone + novel_zone)
zone_masks = {"u": u_zone, "shortcut": shortcut_zone, "novel": novel_zone, "other": other_zone}
avg_data = get_average_likelihoods(info)[0]
tuple(np.nanmean(avg_data["postrecord"][segment]) / np.sum(zone) for segment, zone in zone_masks.items())