In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import numpy as np
import scipy
import os
import scalebar
import nept

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_plotting import plot_over_space
from utils_maze import get_zones, get_bin_centers, get_matched_trials

In [None]:
# Resolve the cache and figure directories relative to the notebook's working directory.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "trials", "decoding")
# exist_ok avoids the check-then-create race and keeps the cell safe to re-run.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
import info.r063d2 as r063d2
import info.r063d3 as r063d3
# Sessions iterated by the `for info in infos` loops below; currently two hand-picked sessions.
infos = [r063d2, r063d3]
from run import analysis_infos
# Uncomment the next line to analyse `analysis_infos` (imported from `run`) instead.
# infos = analysis_infos

In [None]:
def bin_spikes(spikes, time, dt, window=None, gaussian_std=None, normalized=True):
    """Bins spikes using a sliding window.

    Parameters
    ----------
    spikes: list
        Of nept.SpikeTrain
    time: np.array
        Bin edges handed to np.histogram; produces len(time) - 1 bins.
    dt: float
        Width of a single bin. Used to convert ``window`` into a number of
        bins and forwarded to nept.gaussian_filter.
    window: float or None
        Length of the sliding window, in seconds. If None, will default to dt.
    gaussian_std: float or None
        If not None, the counts are additionally smoothed along the time axis
        with nept.gaussian_filter using this standard deviation.
    normalized: boolean
        If True the square filter is scaled to sum to one; the flag is also
        forwarded to nept.gaussian_filter.

    Returns
    -------
    binned_spikes: nept.AnalogSignal
        One row of counts per spike train, timestamped with the left bin edges.

    """
    # The notebook's import cell does not import `warnings`, so the original
    # warning path raised NameError. Import locally to keep this cell self-contained.
    import warnings

    if window is None:
        window = dt

    bin_edges = time

    given_n_bins = window / dt
    n_bins = int(round(given_n_bins))
    if abs(n_bins - given_n_bins) > 0.01:
        warnings.warn("dt does not divide window evenly. "
                      "Using window %g instead." % (n_bins*dt))

    if normalized:
        square_filter = np.ones(n_bins) * (1 / n_bins)
    else:
        square_filter = np.ones(n_bins)

    counts = np.zeros((len(spikes), len(bin_edges) - 1))
    for idx, spiketrain in enumerate(spikes):
        # Histogram the spike times, then apply the boxcar as a sliding window.
        counts[idx] = np.convolve(np.histogram(spiketrain.time, bins=bin_edges)[0].astype(float),
                                  square_filter, mode="same")

    if gaussian_std is not None:
        counts = nept.gaussian_filter(counts, gaussian_std, dt=dt, normalized=normalized, axis=1)

    return nept.AnalogSignal(counts, bin_edges[:-1])

In [None]:
# shuffled_id = False
# if shuffled_id:
#     tuning_curves = np.random.permutation(tuning_curves)

In [None]:
# Average decoded likelihood for each trajectory
def get_combined_likelihoods(info, position, spikes, lfp, zones):
    """Decode one likelihood map per rest-period SWR and summarise it per zone.

    Parameters
    ----------
    info: module
        Session info; ``info.task_times`` and ``info.fs`` are read here.
    position: nept.Position
    spikes: list
        Of nept.SpikeTrain
    lfp: nept.AnalogSignal
    zones: dict
        Maps a zone name to a mask used to index a (tc_shape[1], tc_shape[2])
        likelihood map (presumably boolean arrays - confirm against get_zones).

    Returns
    -------
    data_average: dict
        ``data_average[task_time][zone]`` is a list holding, for every decoded
        SWR, the nanmean of the likelihood inside that zone.
    data_max: dict
        Same layout as ``data_average`` but holding the nanmax.
    n_swrs: dict
        Number of candidate SWR epochs per task time (counted before decoding,
        so it may exceed the number of decoded maps).
    likelihood_array: dict
        Per task time, the decoded maps concatenated along axis 0, or an empty
        array when nothing was decoded.

    """
    tuning_curves = get_only_tuning_curves(info,
                                           position,
                                           spikes,
                                           info.task_times["phase3"])
    tc_shape = tuning_curves.shape
    # Flatten the two spatial axes so each neuron has a single row.
    decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

    # Find SWRs for the whole session
    z_thresh = 2.0
    power_thresh = 3.0
    merge_thresh = 0.02
    min_length = 0.05
    swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                   power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)
    swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=4)

    rest_epochs = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

    # Restrict SWRs to those during epochs of interest during rest
    task_times = ["prerecord", "pauseA", "pauseB", "postrecord"]
    data_average = {k: {key: [] for key in zones.keys()} for k in task_times}
    data_max = {k: {key: [] for key in zones.keys()} for k in task_times}
    likelihoods = {k: [] for k in task_times}
    likelihood_array = dict()
    n_swrs = dict()

    for task_time in task_times:
        epochs_of_interest = info.task_times[task_time].intersect(rest_epochs)

        phase_swrs = epochs_of_interest.overlaps(swrs)
        # Overlapping with the rest epochs can shorten events, so re-apply the
        # same minimum duration that was used for detection.
        phase_swrs = phase_swrs[phase_swrs.durations >= min_length]

        n_swrs[task_time] = phase_swrs.n_epochs

        for start, stop in zip(phase_swrs.starts, phase_swrs.stops):
            sliced_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]

            # The whole event is decoded as one time bin.
            t_window = stop-start # 0.1 for running, 0.025 for swr

            counts = bin_spikes(sliced_spikes, np.array([start, stop]), dt=t_window, window=t_window,
                                gaussian_std=0.0075, normalized=False)

            likelihood = nept.bayesian_prob(counts, decoding_tc, binsize=t_window, min_neurons=3, min_spikes=1)

            # Remove nans from likelihood and reshape for plotting
            keep_idx = np.sum(np.isnan(likelihood), axis=1) < likelihood.shape[1]
            likelihood = likelihood[keep_idx]
            if likelihood.shape[0] == 0:
                # Every row was NaN: nothing to reshape. The original code
                # raised ValueError here; skip the undecodable event instead.
                continue
            likelihoods[task_time].append(likelihood.reshape(1, tc_shape[1], tc_shape[2]))

        for swr_likelihood in likelihoods[task_time]:
            for maze_segment in zones.keys():
                zone_values = swr_likelihood[0][zones[maze_segment]]
                data_average[task_time][maze_segment].append(np.nanmean(zone_values))
                data_max[task_time][maze_segment].append(np.nanmax(zone_values))

        if likelihoods[task_time]:
            likelihood_array[task_time] = np.concatenate(likelihoods[task_time])
        else:
            likelihood_array[task_time] = np.array([])

    return data_average, data_max, n_swrs, likelihood_array

In [None]:
# stacked barplot
def plot_decoded_stacked_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions, savefig=False):
    """Stacked bar plot of the pooled per-zone likelihood summary for each task time.

    Parameters
    ----------
    all_data: list of dict
        One entry per session; ``session[task_time][zone]`` is a list of
        per-SWR values.
    n_all_swrs: list of dict
        One entry per session, parallel to ``all_data``; maps task time to its
        SWR count.
    task_times: list of str
        Task times to plot, in x-axis order.
    maze_segments: list of str
        Zones to aggregate. The stack itself always draws "u", "shortcut",
        "novel" and "other", so those four must be present.
    n_sessions: int
        When 1, the title is prefixed with ``info.session_id`` read from the
        notebook-level ``info`` variable.
    savefig: bool
        Save into the notebook-level ``output_filepath`` and close the figure
        instead of showing it.

    """
    trajectory_means = {}
    trajectory_sems = {}

    for trajectory in maze_segments:
        # Pool the per-SWR values of this zone across sessions.
        tt = {key: [] for key in task_times}
        for session in all_data:
            for task_time in task_times:
                tt[task_time].extend(session[task_time][trajectory])

        # Follow `task_times` rather than hard-coded keys so the bars always
        # line up with the tick labels set below.
        trajectory_means[trajectory] = [np.nanmean(tt[task_time]) for task_time in task_times]
        trajectory_sems[trajectory] = [scipy.stats.sem(tt[task_time]) for task_time in task_times]

    # Total number of SWRs per task time over the sessions in all_data.
    n_swrs = {task_time: sum(n_all_swrs[i][task_time] for i in range(len(all_data)))
              for task_time in task_times}

    fig, ax = plt.subplots(figsize=(7,5))
    n = np.arange(len(task_times))

    # Draw the zones bottom-up, carrying the running height as the next `bottom`.
    stack = [("u", "#2b8cbe"), ("shortcut", "#31a354"), ("novel", "#d95f0e"), ("other", "#bdbdbd")]
    bottom = np.zeros(len(task_times))
    for segment, color in stack:
        plt.bar(n, trajectory_means[segment], yerr=trajectory_sems[segment],
                bottom=bottom, color=color)
        bottom = bottom + np.array(trajectory_means[segment])

    plt.xticks(n, task_times)
    if n_sessions == 1:
        # NOTE: `info` is the notebook-level loop variable, not a parameter.
        title = info.session_id + " average posteriors during SWRs"
    else:
        title = "Average posteriors during SWRs"
    plt.title(title)
#     plt.ylabel("Proportion")

    # Annotate each bar with the number of SWRs behind it.
    for i, task_time in enumerate(task_times):
        ax.text(i, 0.01, str(n_swrs[task_time]), ha="center", fontsize=14)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    plt.tight_layout()
    if savefig:
        plt.savefig(os.path.join(output_filepath, title+"_stacked.png"))
        plt.close()
    else:
        plt.show()

In [None]:
def plot_decoded_combined_average_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions, savefig=False):
    """Four side-by-side bar plots (one per zone) of the pooled per-SWR averages.

    Parameters
    ----------
    all_data: list of dict
        One entry per session; ``session[task_time][zone]`` is a list of
        per-SWR values.
    n_all_swrs: list of dict
        One entry per session, parallel to ``all_data``; maps task time to its
        SWR count.
    task_times: list of str
        Task times to plot, in x-axis order.
    maze_segments: list of str
        Zones to aggregate. The panels always show "u", "shortcut", "novel"
        and "other", so those four must be present.
    n_sessions: int
        When 1, the title is prefixed with ``info.session_id`` read from the
        notebook-level ``info`` variable.
    savefig: bool
        Save into the notebook-level ``output_filepath`` and close the figure
        instead of showing it.

    """
    trajectory_means = {}
    trajectory_sems = {}

    for trajectory in maze_segments:
        # Pool the per-SWR values of this zone across sessions.
        tt = {key: [] for key in task_times}
        for session in all_data:
            for task_time in task_times:
                tt[task_time].extend(session[task_time][trajectory])

        # Follow `task_times` rather than hard-coded keys so the bars always
        # line up with the tick labels set below.
        trajectory_means[trajectory] = [np.nanmean(tt[task_time]) for task_time in task_times]
        trajectory_sems[trajectory] = [scipy.stats.sem(tt[task_time]) for task_time in task_times]

    # Total number of SWRs per task time over the sessions in all_data.
    n_swrs = {task_time: sum(n_all_swrs[i][task_time] for i in range(len(all_data)))
              for task_time in task_times}

    fig = plt.figure(figsize=(12,6))
    gs1 = gridspec.GridSpec(1, 4)
    gs1.update(wspace=0.3, hspace=0.)

    n = np.arange(len(task_times))
    panels = [("u", "#2b8cbe"), ("shortcut", "#31a354"), ("novel", "#d95f0e"), ("other", "#bdbdbd")]
    axes = []
    for idx, (segment, color) in enumerate(panels):
        ax = plt.subplot(gs1[idx])
        ax.bar(n, trajectory_means[segment], yerr=trajectory_sems[segment], color=color)
        axes.append(ax)

    for ax in axes:
        ax.set_ylim([0, 0.6])

        ax.set_xticks(np.arange(len(task_times)))
        ax.set_xticklabels(task_times, rotation = 90)

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')

        # Annotate each bar with the number of SWRs behind it.
        for i, task_time in enumerate(task_times):
            ax.text(i, 0.01, str(n_swrs[task_time]), ha="center", fontsize=14)

    # All panels share the same y-limits, so only the first keeps tick labels.
    for ax in axes[1:]:
        ax.set_yticklabels([])

    if n_sessions == 1:
        # NOTE: `info` is the notebook-level loop variable, not a parameter.
        title = info.session_id + " average posteriors during SWRs"
    else:
        title = "Average posteriors during SWRs"
    fig.suptitle(title, fontsize=18)
#     ax1.set_ylabel("Proportion")

    legend_elements = [Patch(facecolor='#2b8cbe', edgecolor='k', label="u"),
                       Patch(facecolor='#31a354', edgecolor='k', label="shortcut"),
                       Patch(facecolor='#d95f0e', edgecolor='k', label="novel"),
                       Patch(facecolor='#bdbdbd', edgecolor='k', label="other")]
    # plt.legend attaches to the current axes, i.e. the last panel created.
    plt.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1.0))

    gs1.tight_layout(fig)

    if savefig:
        plt.savefig(os.path.join(output_filepath, title+"_combined-average.png"))
        plt.close()
    else:
        plt.show()

In [None]:
def plot_decoded_combined_max_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions, savefig=False):
    """Four side-by-side bar plots (one per zone) of the pooled per-SWR maxima.

    Parameters
    ----------
    all_data: list of dict
        One entry per session; ``session[task_time][zone]`` is a list of
        per-SWR values (the caller passes the per-zone maxima).
    n_all_swrs: list of dict
        One entry per session, parallel to ``all_data``; maps task time to its
        SWR count.
    task_times: list of str
        Task times to plot, in x-axis order.
    maze_segments: list of str
        Zones to aggregate. The panels always show "u", "shortcut", "novel"
        and "other", so those four must be present.
    n_sessions: int
        When 1, the title is prefixed with ``info.session_id`` read from the
        notebook-level ``info`` variable.
    savefig: bool
        Save into the notebook-level ``output_filepath`` and close the figure
        instead of showing it.

    """
    trajectory_means = {}
    trajectory_sems = {}

    for trajectory in maze_segments:
        # Pool the per-SWR values of this zone across sessions.
        tt = {key: [] for key in task_times}
        for session in all_data:
            for task_time in task_times:
                tt[task_time].extend(session[task_time][trajectory])

        # Follow `task_times` rather than hard-coded keys so the bars always
        # line up with the tick labels set below.
        trajectory_means[trajectory] = [np.nanmean(tt[task_time]) for task_time in task_times]
        trajectory_sems[trajectory] = [scipy.stats.sem(tt[task_time]) for task_time in task_times]

    # Total number of SWRs per task time over the sessions in all_data.
    n_swrs = {task_time: sum(n_all_swrs[i][task_time] for i in range(len(all_data)))
              for task_time in task_times}

    fig = plt.figure(figsize=(12,6))
    gs1 = gridspec.GridSpec(1, 4)
    gs1.update(wspace=0.3, hspace=0.)

    n = np.arange(len(task_times))
    panels = [("u", "#2b8cbe"), ("shortcut", "#31a354"), ("novel", "#d95f0e"), ("other", "#bdbdbd")]
    axes = []
    for idx, (segment, color) in enumerate(panels):
        ax = plt.subplot(gs1[idx])
        ax.bar(n, trajectory_means[segment], yerr=trajectory_sems[segment], color=color)
        axes.append(ax)

    for ax in axes:
        ax.set_ylim([0, 0.6])

        ax.set_xticks(np.arange(len(task_times)))
        ax.set_xticklabels(task_times, rotation = 90)

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')

        # Annotate each bar with the number of SWRs behind it. The original
        # repeated this loop for the first panel further down, drawing every
        # label there twice; it is done once per panel here.
        for i, task_time in enumerate(task_times):
            ax.text(i, 0.01, str(n_swrs[task_time]), ha="center", fontsize=14)

    # All panels share the same y-limits, so only the first keeps tick labels.
    for ax in axes[1:]:
        ax.set_yticklabels([])

    if n_sessions == 1:
        # NOTE: `info` is the notebook-level loop variable, not a parameter.
        title = info.session_id + " average max posteriors during SWRs"
    else:
        # The original reused the title of the average plot here, mislabelling
        # the figure. The saved file name changes accordingly.
        title = "Average max posteriors during SWRs"
    fig.suptitle(title, fontsize=18)
#     ax1.set_ylabel("Proportion")

    legend_elements = [Patch(facecolor='#2b8cbe', edgecolor='k', label="u"),
                       Patch(facecolor='#31a354', edgecolor='k', label="shortcut"),
                       Patch(facecolor='#d95f0e', edgecolor='k', label="novel"),
                       Patch(facecolor='#bdbdbd', edgecolor='k', label="other")]
    # plt.legend attaches to the current axes, i.e. the last panel created.
    plt.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1.0))

    gs1.tight_layout(fig)

    if savefig:
        plt.savefig(os.path.join(output_filepath, title+"_combined-max.png"))
        plt.close()
    else:
        plt.show()

In [None]:
def plot_likelihood_overspace(info, position, likelihoods, zones, savefig=False, task_time=None):
    """Plot the mean decoded likelihood map of one task time over the maze.

    Parameters
    ----------
    info: module
        Session info; ``xedges``, ``yedges``, ``task_times`` and ``session_id``
        are read here.
    position: nept.Position
    likelihoods: dict
        Maps task time to a stack of likelihood maps; the stack is averaged
        over axis 0 with np.nanmean.
    zones: dict
        Must contain "u", "shortcut" and "novel"; each is drawn as a contour.
    savefig: bool
        Save into the notebook-level ``output_filepath`` and close the figure
        instead of showing it.
    task_time: str or None
        Task time to plot. If None, falls back to the notebook-level
        ``task_time`` variable, which the original implementation read
        implicitly.

    """
    if task_time is None:
        # Backward-compatible fallback for callers that rely on the loop
        # variable of the calling cell.
        try:
            task_time = globals()["task_time"]
        except KeyError:
            raise NameError("name 'task_time' is not defined; pass task_time=... explicitly")

    xx, yy = np.meshgrid(info.xedges, info.yedges)
    xcenters, ycenters = get_bin_centers(info)
    xxx, yyy = np.meshgrid(xcenters, ycenters)

    # Phase 3 positions are drawn underneath as a faint outline of the maze.
    sliced_position = position.time_slice(info.task_times["phase3"].starts, info.task_times["phase3"].stops)
    plt.plot(sliced_position.x, sliced_position.y, "b.", ms=1, alpha=0.2)
    pp = plt.pcolormesh(xx, yy, np.nanmean(likelihoods[task_time], axis=0), vmax=0.2, cmap='bone_r')
    plt.contour(xxx, yyy, zones["u"], levels=0, colors='#2b8cbe', corner_mask=False)
    plt.contour(xxx, yyy, zones["shortcut"], levels=0, colors='#31a354', corner_mask=False)
    plt.contour(xxx, yyy, zones["novel"], levels=0, colors='#d95f0e', corner_mask=False)

    plt.colorbar(pp)
    plt.axis('off')
    if savefig:
        # Look the ordered label ("1-prerecord", ...) up from the task time
        # itself instead of the notebook-level index `i`, so the file name
        # cannot drift from the data that was plotted.
        label = next((name for name in ordered_task_times
                      if name.split("-", 1)[-1] == task_time), task_time)
        filename = info.session_id+"-average-likelihood-overspace_"+label+".png"
        plt.savefig(os.path.join(output_filepath, filename))
        plt.close()
    else:
        plt.show()

In [None]:
# Keys of info.task_times that are analysed and plotted (x-axis order of the summary plots).
task_times = ["prerecord", "pauseA", "pauseB", "postrecord"]
# Zone names used as keys of the `zones` dict and of the per-zone summaries.
maze_segments = ["u", "shortcut", "novel", "other"]
# Numbered labels, parallel to task_times, used in saved file names.
ordered_task_times = ["1-prerecord", "2-pauseA", "3-pauseB", "4-postrecord"]

In [None]:
# Notebook-wide switch passed to the summary plots: True saves and closes figures, False shows them.
savefig = True

In [None]:
# plot individual sessions       
for info in infos:
    events, position, spikes, lfp, _ = get_data(info)

    # Build the per-zone masks; "other" is everything outside the three trajectories.
    zones = dict()
    zones["u"], zones["shortcut"], zones["novel"] = get_zones(info, position, subset=True)
    combined_zones = zones["u"] + zones["shortcut"] + zones["novel"]
    zones["other"] = ~combined_zones

    data_average, data_max, n_swrs, likelihoods = get_combined_likelihoods(info, position, spikes, lfp, zones)

    # Single-session summaries: wrap the results in one-element lists.
    plot_decoded_stacked_summary([data_average], [n_swrs], task_times, maze_segments, n_sessions=1, savefig=savefig)
    plot_decoded_combined_average_summary([data_average], [n_swrs], task_times, maze_segments, n_sessions=1, savefig=savefig)
    plot_decoded_combined_max_summary([data_max], [n_swrs], task_times, maze_segments, n_sessions=1, savefig=savefig)

    # The loop variables `i` and `task_time` are kept under these exact names
    # because plot_likelihood_overspace reads them from the notebook namespace.
    for i, task_time in enumerate(task_times):
        if likelihoods[task_time].size > 0:
            # Honour the notebook-level `savefig` switch like the plots above
            # (this call previously hard-coded savefig=True).
            plot_likelihood_overspace(info, position, likelihoods, zones, savefig=savefig)

In [None]:
# plot combined sessions
all_data_average = []
all_data_max = []
n_all_swrs = []

for info in infos:
    events, position, spikes, lfp, _ = get_data(info)

    # Build the per-zone masks; "other" is everything outside the three trajectories.
    zones = dict()
    zones["u"], zones["shortcut"], zones["novel"] = get_zones(info, position, subset=True)
    combined_zones = zones["u"] + zones["shortcut"] + zones["novel"]
    zones["other"] = ~combined_zones

    data_average, data_max, n_swrs, likelihoods = get_combined_likelihoods(info, position, spikes, lfp, zones)

    # Collect inside the loop: these appends used to sit after it, so only the
    # last session was pooled although n_sessions reported len(infos).
    all_data_average.append(data_average)
    all_data_max.append(data_max)
    n_all_swrs.append(n_swrs)

plot_decoded_stacked_summary(all_data_average, n_all_swrs, task_times, maze_segments, n_sessions=len(infos), savefig=savefig)
plot_decoded_combined_average_summary(all_data_average, n_all_swrs, task_times, maze_segments, n_sessions=len(infos), savefig=savefig)
plot_decoded_combined_max_summary(all_data_max, n_all_swrs, task_times, maze_segments, n_sessions=len(infos), savefig=savefig)

In [None]:
# Raises ZeroDivisionError on purpose; presumably a stop marker so that "Run All"
# halts before the cells below - confirm before removing.
1/0

In [None]:
def plot_summary_mean(info, position, lfp, spikes, start, stop, likelihood, tc_shape, zones, filepath=None, savefig=False):
    """Four-panel summary of one SWR: raster, LFP, decoded map and per-zone means.

    Parameters
    ----------
    info: module
        Session info; ``xedges``, ``yedges`` and ``task_times`` are read here.
    position: nept.Position
    lfp: nept.AnalogSignal
    spikes: list
        Of nept.SpikeTrain
    start, stop: float
        Event boundaries; raster and LFP are shown with 0.1 of padding on
        each side (same units as start/stop).
    likelihood: np.array
        Decoded likelihood with tc_shape[1] * tc_shape[2] elements. It is
        copied before NaNs are zeroed, so the caller's array is left untouched.
    tc_shape: tuple
        Shape of the tuning curves; elements 1 and 2 give the map shape.
    zones: dict
        Maps zone name to a mask over the map; must contain every name in the
        notebook-level ``maze_segments`` list, which is read here.
    filepath: str or None
        Destination passed to plt.savefig when ``savefig`` is True.
    savefig: bool
        Save to ``filepath`` and close the figure instead of showing it.

    """
    buffer=0.1

    sliced_spikes = [spiketrain.time_slice(start-buffer, stop+buffer) for spiketrain in spikes]

    rows = len(sliced_spikes)

    # Marker height shrinks as the number of neurons grows so rows do not overlap.
    ms = 600 / rows
    mew = 0.7
    spike_loc = 1

    fig = plt.figure(figsize=(8, 8))
    gs1 = gridspec.GridSpec(3, 2)
    gs1.update(wspace=0.3, hspace=0.3)

    # Raster: one row per neuron.
    ax1 = plt.subplot(gs1[1:, 0])
    for idx, neuron_spikes in enumerate(sliced_spikes):
        ax1.plot(neuron_spikes.time, np.ones(len(neuron_spikes.time)) + (idx * spike_loc), '|',
                 color='k', ms=ms, mew=mew)

    ax1.axis('off')

    # LFP: padded trace in colour, the event itself overlaid in black.
    ax2 = plt.subplot(gs1[0, 0], sharex=ax1)

    start_idx = nept.find_nearest_idx(lfp.time, start - buffer)
    stop_idx = nept.find_nearest_idx(lfp.time, stop + buffer)
    ax2.plot(lfp.time[start_idx:stop_idx], lfp.data[start_idx:stop_idx], '#3288bd', lw=0.3)

    start_idx = nept.find_nearest_idx(lfp.time, start)
    stop_idx = nept.find_nearest_idx(lfp.time, stop)
    ax2.plot(lfp.time[start_idx:stop_idx], lfp.data[start_idx:stop_idx], 'k', lw=0.4)

    ax2.axis("off")

    scalebar.add_scalebar(ax2, matchy=False, bbox_transform=fig.transFigure,
                          bbox_to_anchor=(0.25, 0.05), units='ms')

    # reshape may return a view of the caller's array; copy before zeroing the
    # NaNs so the caller (and anything else sharing the buffer) keeps its NaNs.
    likelihood = likelihood.reshape(tc_shape[1], tc_shape[2]).copy()
    likelihood[np.isnan(likelihood)] = 0

    xx, yy = np.meshgrid(info.xedges, info.yedges)
    xcenters, ycenters = get_bin_centers(info)
    xxx, yyy = np.meshgrid(xcenters, ycenters)

    # Decoded map over the phase 3 positions, with the zone outlines.
    ax3 = plt.subplot(gs1[0, 1])
    sliced_position = position.time_slice(info.task_times["phase3"].starts, info.task_times["phase3"].stops)
    ax3.plot(sliced_position.x, sliced_position.y, "y.", ms=1, alpha=0.2)
    pp = ax3.pcolormesh(xx, yy, likelihood, cmap='bone_r')
    ax3.contour(xxx, yyy, zones["u"], levels=0, colors='#2b8cbe')
    ax3.contour(xxx, yyy, zones["shortcut"], levels=0, colors='#31a354')
    ax3.contour(xxx, yyy, zones["novel"], levels=0, colors='#d95f0e')

    plt.colorbar(pp)
    ax3.axis('off')

    # Per-zone mean and SEM of the (NaN-zeroed) map.
    ax4 = plt.subplot(gs1[1:, 1])
    means = [np.nanmean(likelihood[zones[trajectory]]) for trajectory in maze_segments]
    sems = [scipy.stats.sem(likelihood[zones[trajectory]]) for trajectory in maze_segments]
    n = np.arange(len(maze_segments))
    ax4.bar(n, means, yerr=sems, color=['#2b8cbe', '#31a354', '#d95f0e', '#bdbdbd'], edgecolor='k')
    ax4.set_xticks(n)
    ax4.set_xticklabels(maze_segments, rotation=90)
    ax4.set_ylim([0, 0.1])

    plt.tight_layout()

    if savefig:
        plt.savefig(filepath)
        plt.close()
    else:
        plt.show()

In [None]:
def plot_summary_max(info, position, lfp, spikes, start, stop, likelihood, tc_shape, zones, filepath=None, savefig=False):
    """Four-panel summary of one SWR: raster, LFP, decoded map and per-zone maxima.

    Parameters
    ----------
    info: module
        Session info; ``xedges``, ``yedges`` and ``task_times`` are read here.
    position: nept.Position
    lfp: nept.AnalogSignal
    spikes: list
        Of nept.SpikeTrain
    start, stop: float
        Event boundaries; raster and LFP are shown with 0.1 of padding on
        each side (same units as start/stop).
    likelihood: np.array
        Decoded likelihood with tc_shape[1] * tc_shape[2] elements. It is
        copied before NaNs are zeroed, so the caller's array is left untouched.
    tc_shape: tuple
        Shape of the tuning curves; elements 1 and 2 give the map shape.
    zones: dict
        Maps zone name to a mask over the map; must contain every name in the
        notebook-level ``maze_segments`` list, which is read here.
    filepath: str or None
        Destination passed to plt.savefig when ``savefig`` is True.
    savefig: bool
        Save to ``filepath`` and close the figure instead of showing it.

    """
    buffer=0.1

    sliced_spikes = [spiketrain.time_slice(start-buffer, stop+buffer) for spiketrain in spikes]

    rows = len(sliced_spikes)

    # Marker height shrinks as the number of neurons grows so rows do not overlap.
    ms = 600 / rows
    mew = 0.7
    spike_loc = 1

    fig = plt.figure(figsize=(8, 8))
    gs1 = gridspec.GridSpec(3, 2)
    gs1.update(wspace=0.3, hspace=0.3)

    # Raster: one row per neuron.
    ax1 = plt.subplot(gs1[1:, 0])
    for idx, neuron_spikes in enumerate(sliced_spikes):
        ax1.plot(neuron_spikes.time, np.ones(len(neuron_spikes.time)) + (idx * spike_loc), '|',
                 color='k', ms=ms, mew=mew)

    ax1.axis('off')

    # LFP: padded trace in colour, the event itself overlaid in black.
    ax2 = plt.subplot(gs1[0, 0], sharex=ax1)

    start_idx = nept.find_nearest_idx(lfp.time, start - buffer)
    stop_idx = nept.find_nearest_idx(lfp.time, stop + buffer)
    ax2.plot(lfp.time[start_idx:stop_idx], lfp.data[start_idx:stop_idx], '#3288bd', lw=0.3)

    start_idx = nept.find_nearest_idx(lfp.time, start)
    stop_idx = nept.find_nearest_idx(lfp.time, stop)
    ax2.plot(lfp.time[start_idx:stop_idx], lfp.data[start_idx:stop_idx], 'k', lw=0.4)

    ax2.axis("off")

    scalebar.add_scalebar(ax2, matchy=False, bbox_transform=fig.transFigure,
                          bbox_to_anchor=(0.25, 0.05), units='ms')

    # reshape may return a view of the caller's array; copy before zeroing the
    # NaNs so the caller (and anything else sharing the buffer) keeps its NaNs.
    likelihood = likelihood.reshape(tc_shape[1], tc_shape[2]).copy()
    likelihood[np.isnan(likelihood)] = 0

    xx, yy = np.meshgrid(info.xedges, info.yedges)
    xcenters, ycenters = get_bin_centers(info)
    xxx, yyy = np.meshgrid(xcenters, ycenters)

    # Decoded map over the phase 3 positions, with the zone outlines.
    ax3 = plt.subplot(gs1[0, 1])
    sliced_position = position.time_slice(info.task_times["phase3"].starts, info.task_times["phase3"].stops)
    ax3.plot(sliced_position.x, sliced_position.y, "y.", ms=1, alpha=0.2)
    pp = ax3.pcolormesh(xx, yy, likelihood, cmap='bone_r')
    ax3.contour(xxx, yyy, zones["u"], levels=0, colors='#2b8cbe')
    ax3.contour(xxx, yyy, zones["shortcut"], levels=0, colors='#31a354')
    ax3.contour(xxx, yyy, zones["novel"], levels=0, colors='#d95f0e')

    plt.colorbar(pp)
    ax3.axis('off')

    # Per-zone maximum of the (NaN-zeroed) map.
    ax4 = plt.subplot(gs1[1:, 1])
    maxima = [np.nanmax(likelihood[zones[trajectory]]) for trajectory in maze_segments]
    n = np.arange(len(maze_segments))
    ax4.bar(n, maxima, color=['#2b8cbe', '#31a354', '#d95f0e', '#bdbdbd'], edgecolor='k')
    ax4.set_xticks(n)
    ax4.set_xticklabels(maze_segments, rotation=90)
    ax4.set_ylim([0, 1.])

    plt.tight_layout()

    if savefig:
        plt.savefig(filepath)
        plt.close()
    else:
        plt.show()

In [None]:
def get_likelihood(spikes, tuning_curves, start, stop):
    """Decode a single likelihood over position from the spikes in [start, stop].

    Parameters
    ----------
    spikes: list
        Of nept.SpikeTrain
    tuning_curves: np.array
        Tuning curves with neurons on axis 0; any further axes are flattened
        before decoding.
    start, stop: float
        Boundaries of the event; the whole event is decoded as one time bin.

    Returns
    -------
    likelihood: output of nept.bayesian_prob for the single bin.

    """
    # Use the tuning curves that were passed in. The original ignored this
    # argument and read a notebook-level `decoding_tc`, so it decoded with
    # whatever session happened to be left in the kernel.
    decoding_tc = tuning_curves.reshape(tuning_curves.shape[0], -1)

    sliced_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]
    t_window = stop-start # 0.1 for running, 0.025 for swr
    counts = bin_spikes(sliced_spikes, np.array([start, stop]), dt=t_window, window=t_window,
                        gaussian_std=0.0075, normalized=False)
    likelihood = nept.bayesian_prob(counts, decoding_tc, binsize=t_window, min_neurons=3, min_spikes=1)

    return likelihood

In [None]:
# The per-SWR figures are saved in a "swr" subfolder that no earlier cell
# creates; make sure it exists so plt.savefig does not fail.
swr_output_filepath = os.path.join(output_filepath, "swr")
os.makedirs(swr_output_filepath, exist_ok=True)

for info in infos:
    events, position, spikes, lfp, _ = get_data(info)

    tuning_curves = get_only_tuning_curves(info,
                                           position,
                                           spikes,
                                           info.task_times["phase3"])

    tc_shape = tuning_curves.shape
    # Flattened copy kept at notebook level because get_likelihood may read it.
    decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

    # Define zones
    zones = dict()
    zones["u"], zones["shortcut"], zones["novel"] = get_zones(info, position, subset=True)
    combined_zones = zones["u"] + zones["shortcut"] + zones["novel"]
    zones["other"] = ~combined_zones

    # Find SWRs for the whole session
    z_thresh = 2.0
    power_thresh = 3.0
    merge_thresh = 0.02
    min_length = 0.05
    swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                   power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)
    swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=4)

    rest_epochs = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

    # Restrict SWRs to those during epochs of interest during rest
    task_times = ["prerecord", "pauseA", "pauseB", "postrecord"]

    data = {k: {key: [] for key in zones.keys()} for k in task_times}
    likelihoods = {k: [] for k in task_times}
    likelihood_array = dict()
    n_swrs = dict()

    for task_time in task_times:
        epochs_of_interest = info.task_times[task_time].intersect(rest_epochs)

        phase_swrs = epochs_of_interest.overlaps(swrs)
        # Re-apply the detection minimum after clipping to the rest epochs.
        phase_swrs = phase_swrs[phase_swrs.durations >= min_length]

        n_swrs[task_time] = phase_swrs.n_epochs

        # Only the first five events per task time are plotted.
        for i, (start, stop) in enumerate(zip(phase_swrs.starts[:5], phase_swrs.stops[:5])):
            likelihood = get_likelihood(spikes, tuning_curves, start, stop)

            likelihoods[task_time].append(likelihood.reshape(tc_shape[1], tc_shape[2]))

            filename = info.session_id + "_" + task_time + "_summary-swr" + str(i) + "_average.png"
            filepath = os.path.join(swr_output_filepath, filename)
            plot_summary_mean(info, position, lfp, spikes, start, stop, likelihood,
                              tc_shape, zones, filepath, savefig=True)

            filename = info.session_id + "_" + task_time + "_summary-swr" + str(i) + "_max.png"
            filepath = os.path.join(swr_output_filepath, filename)
            plot_summary_max(info, position, lfp, spikes, start, stop, likelihood,
                             tc_shape, zones, filepath, savefig=True)

#         # Remove nans from likelihood and reshape for plotting
#         keep_idx = np.sum(np.isnan(likelihood), axis=1) < likelihood.shape[1]
#         likelihood = likelihood[keep_idx]
#         likelihoods[task_time].append(likelihood.reshape(1, tc_shape[1], tc_shape[2]))

#     for swr_likelihood in likelihoods[task_time]:
#         for maze_segment in zones.keys():
#             data[task_time][maze_segment].append(np.nansum(swr_likelihood[0][zones[maze_segment]]))
#     if likelihoods[task_time]:
#         likelihood_array[task_time] = np.concatenate(likelihoods[task_time])
#     else:
#         likelihood_array[task_time] = np.array([])

In [None]:
# NOTE(review): likelihoods["prerecord"] is built as a Python list in the cell above, so
# indexing it with zones[trajectory] (presumably a boolean array) likely raises TypeError - confirm intent.
[np.nanmax(likelihoods["prerecord"][zones[trajectory]], axis=0) for trajectory in maze_segments]

In [None]:
# Scratch: per task time, nanmax over axis 1 of the stored per-SWR maps.
# NOTE(review): presumably fails for task times with no stored maps (empty list) - verify.
letssee = [np.nanmax(likelihoods[task_time], axis=1) for task_time in task_times]

In [None]:
# Scratch: largest value inside the "shortcut" zone for the first stored prerecord map.
np.nanmax(likelihoods["prerecord"][0][zones["shortcut"]])

In [None]:
# Scratch: number of maps stored for the prerecord task time.
len(likelihoods["prerecord"])

In [None]:
# Raises ZeroDivisionError on purpose; presumably a stop marker so that "Run All"
# halts before the cells below - confirm before removing.
1/0

In [None]:
# Scratch cell: relies on `likelihood`, `tc_shape` and `zones` left in the kernel by an earlier cell.
# Reshape the flat likelihood into a 2D map and zero its NaNs (modifies the array in place).
likelihood = likelihood.reshape(tc_shape[1], tc_shape[2])
likelihood[np.isnan(likelihood)] = 0

# Per-zone mean and SEM of the map, drawn as one bar per maze segment.
means = [np.nanmean(likelihood[zones[trajectory]]) for trajectory in maze_segments]
sems = [scipy.stats.sem(likelihood[zones[trajectory]]) for trajectory in maze_segments]
n = np.arange(len(maze_segments))
plt.bar(n, means, yerr=sems, color=['#2b8cbe', '#31a354', '#d95f0e', '#bdbdbd'], edgecolor='k')
plt.xticks(n, maze_segments)
plt.show()

In [None]:
# NOTE(review): `plot_summary` is not defined in the visible notebook (only plot_summary_mean
# and plot_summary_max are); this call presumably predates that split - confirm or update.
plot_summary(info, position, lfp, spikes, start, stop, likelihood)

In [None]:
# Raises ZeroDivisionError on purpose; presumably a stop marker so that "Run All"
# halts before the cells below - confirm before removing.
1/0

In [None]:
# Assign swr to a trajectory
def get_max_likelihoods(info):
    """Assign each rest-period SWR to the zone(s) holding its peak likelihood.

    Parameters
    ----------
    info: module
        Session info; passed to get_data / get_zones, and ``info.task_times``
        and ``info.fs`` are read here.

    Returns
    -------
    data: dict
        ``data[task_time][zone]`` is the number of SWRs whose peak likelihood
        falls inside the zone, divided by the number of candidate SWRs of that
        task time (left at 0 when there are none).
    n_swrs: dict
        Number of candidate SWR epochs per task time.
    likelihoods: list
        Decoded maps of the LAST task time only (the list is reset for every
        task time); kept as-is for backward compatibility.

    """
    events, position, spikes, lfp, _ = get_data(info)

    u_zone, shortcut_zone, novel_zone = get_zones(info, position, subset=False)
    combined_zones = u_zone+shortcut_zone+novel_zone
    other_zone = ~combined_zones
    zone_masks = {"u": u_zone, "shortcut": shortcut_zone, "novel": novel_zone, "other": other_zone}

    tuning_curves = get_only_tuning_curves(info,
                                           position,
                                           spikes,
                                           info.task_times["phase3"])
    tc_shape = tuning_curves.shape
    # Flatten the two spatial axes so each neuron has a single row.
    decoding_tc = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

    # Find SWRs for the whole session
    z_thresh = 2.0
    power_thresh = 3.0
    merge_thresh = 0.02
    min_length = 0.05
    swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                   power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)
    swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=4)

    rest_epochs = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

    # Restrict SWRs to those during epochs of interest during rest
    task_times = ["prerecord", "pauseA", "pauseB", "postrecord"]
    maze_segments = ["u", "shortcut", "novel", "other"]
    data = {k: {key: 0 for key in maze_segments} for k in task_times}

    n_swrs = dict()

    for task_time in task_times:
        epochs_of_interest = info.task_times[task_time].intersect(rest_epochs)

        phase_swrs = epochs_of_interest.overlaps(swrs)
        # Re-apply the detection minimum after clipping to the rest epochs.
        phase_swrs = phase_swrs[phase_swrs.durations >= min_length]

        n_swrs[task_time] = phase_swrs.n_epochs

        likelihoods = []

        for start, stop in zip(phase_swrs.starts, phase_swrs.stops):
            sliced_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]

            # The whole event is decoded as one time bin.
            t_window = stop-start # 0.1 for running, 0.025 for swr

            counts = bin_spikes(sliced_spikes, np.array([start, stop]), dt=t_window, window=t_window,
                                gaussian_std=0.0075, normalized=False)

            likelihood = nept.bayesian_prob(counts, decoding_tc, binsize=t_window, min_neurons=3, min_spikes=1)

            # Remove nans from likelihood and reshape for plotting
            keep_idx = np.sum(np.isnan(likelihood), axis=1) < likelihood.shape[1]
            likelihood = likelihood[keep_idx]
            if likelihood.shape[0] == 0:
                # Nothing decodable in this event; np.nanmax below would raise
                # ValueError on the resulting empty array, so skip it.
                continue
            likelihoods.append(likelihood.reshape(np.shape(likelihood)[0], tc_shape[1], tc_shape[2]))

        for swr_likelihood in likelihoods:
            # Locate the peak once and test it against every zone mask.
            at_peak = swr_likelihood == np.nanmax(swr_likelihood)
            for maze_segment in maze_segments:
                data[task_time][maze_segment] += int(np.any(zone_masks[maze_segment] & at_peak))

        if phase_swrs.n_epochs > 0:
            # Convert counts to proportions of the candidate SWRs.
            for maze_segment in maze_segments:
                data[task_time][maze_segment] /= phase_swrs.n_epochs

    return data, n_swrs, likelihoods

In [None]:
def plot_max_decoded_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions):
    """Save bar plots of the decoded-zone proportions, per task time and per zone.

    One figure is written per task time (bars = zones) and one per zone
    (bars = task times), all into the notebook-level ``output_filepath``.

    Parameters
    ----------
    all_data: list of dict
        One entry per session; ``session[task_time][zone]`` is a single number.
    n_all_swrs: list of dict
        One entry per session, parallel to ``all_data``; maps task time to its
        SWR count.
    task_times: list of str
    maze_segments: list of str
    n_sessions: int
        When 1, titles (and therefore file names) are prefixed with
        ``info.session_id`` read from the notebook-level ``info`` variable.

    """
    def _style_axes(ax):
        # Drop the top/right spines and keep ticks on the remaining sides.
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')

    # One figure per task time: bars are the zones, in `maze_segments` order.
    for task_time in task_times:
        per_segment = {segment: [session[task_time][segment] for session in all_data]
                       for segment in maze_segments}
        means = [np.nanmean(per_segment[segment]) for segment in maze_segments]
        sems = [scipy.stats.sem(per_segment[segment]) for segment in maze_segments]

        fig, ax = plt.subplots(figsize=(7,5))
        n = np.arange(len(maze_segments))
        plt.bar(n, means, yerr=sems, color="#99d8c9")
        plt.xticks(n, maze_segments)
        plt.text(0.95, 0.95, "n sessions: "+str(len(all_data)),
                 horizontalalignment='center',
                 verticalalignment='center',
                 transform = ax.transAxes,
                 fontsize=14)
        if n_sessions == 1:
            # NOTE: `info` is the notebook-level loop variable, not a parameter.
            # A separating space was missing after the session id; adding it
            # also changes the saved file name.
            title = info.session_id + " decoded zone during SWRs in " + task_time
        else:
            title = "Decoded zone during SWRs in " + task_time
        plt.title(title)

        _style_axes(ax)

        plt.tight_layout()

        plt.savefig(os.path.join(output_filepath, title+".png"))
        plt.close()

    # Total number of SWRs per task time over the sessions in all_data.
    n_swrs = {task_time: sum(n_all_swrs[i][task_time] for i in range(len(all_data)))
              for task_time in task_times}

    # One figure per zone: bars are the task times, in `task_times` order.
    for trajectory in maze_segments:
        tt = {task_time: [session[task_time][trajectory] for session in all_data]
              for task_time in task_times}
        trajectory_means = [np.nanmean(tt[task_time]) for task_time in task_times]
        trajectory_sems = [scipy.stats.sem(tt[task_time]) for task_time in task_times]

        fig, ax = plt.subplots(figsize=(7,5))
        n = np.arange(len(task_times))
        plt.bar(n, trajectory_means, yerr=trajectory_sems, color="#41ae76")
        plt.xticks(n, task_times)
        if n_sessions == 1:
            title = info.session_id + " decoded zone during SWRs for " + trajectory
        else:
            title = "Decoded zone during SWRs for " + trajectory
        plt.title(title)

        # Annotate each bar with the number of SWRs behind it.
        for i, task_time in enumerate(task_times):
            ax.text(i, 0.01, str(n_swrs[task_time]), ha="center", fontsize=14)

        _style_axes(ax)

        plt.tight_layout()

        plt.savefig(os.path.join(output_filepath, title+".png"))
        plt.close()

In [None]:
# plot individual sessions
for info in infos:
    # The loop variable must stay named `info`: plot_max_decoded_summary reads
    # it from the notebook namespace when building single-session titles.
    data, n_swrs, likelihoods = get_max_likelihoods(info)

    # Each session is plotted on its own, so wrap its results in one-element lists.
    all_data = [data]
    n_all_swrs = [n_swrs]

    plot_max_decoded_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions=1)

In [None]:
# plot combined sessions
all_data = []
n_all_swrs = []

for info in infos:
    data, n_swrs, likelihoods = get_max_likelihoods(info)

    # Collect inside the loop: these appends used to sit after it, so only the
    # last session was pooled although n_sessions reported len(infos).
    all_data.append(data)
    n_all_swrs.append(n_swrs)

plot_max_decoded_summary(all_data, n_all_swrs, task_times, maze_segments, n_sessions=len(infos))