In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import numpy as np
import warnings
import random
import scipy
import pickle
import os
import nept

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_maze import get_zones
from analyze_decode_swrs import (plot_summary_individual,
                                 plot_likelihood_overspace,
                                 plot_combined,
                                 plot_stacked_summary)

from utils_maze import get_bin_centers
from analyze_classy_decode import bin_spikes

thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "trials", "decoding", "shuffled")

# Create both the pickle cache and the figure output directory up front.
# Session.pickle opens files inside pickle_filepath for writing, which fails
# on a fresh checkout where cache/pickled does not exist yet.
os.makedirs(pickle_filepath, exist_ok=True)
os.makedirs(output_filepath, exist_ok=True)

# Set random seeds (the tuning-curve shuffles draw from np.random)
random.seed(0)
np.random.seed(0)

In [None]:
class Session:
    """Per-session container holding one TaskTime attribute per task label.

        Parameters
        ----------
        task_labels : list of str
            Names of the task times; each becomes an attribute of the session
            holding an initially empty TaskTime.
        zones : dict
            Zone masks keyed by zone label, shared by every TaskTime.

        Attributes
        ----------
        task_labels : list of str
            The task labels this session was built with.

    """

    def __init__(self, task_labels, zones):
        # Keep our own copy so n_tasktimes does not depend on notebook globals.
        self.task_labels = list(task_labels)
        for task_label in self.task_labels:
            setattr(self, task_label, TaskTime([], [], [], zones))

    def pickle(self, save_path):
        """Serialize this session to ``save_path`` with pickle."""
        with open(save_path, 'wb') as fileobj:
            print("Saving " + save_path)
            pickle.dump(self, fileobj)

    def n_tasktimes(self):
        """Return the number of task times held by this session."""
        labels = getattr(self, "task_labels", None)
        if labels is None:
            # Sessions pickled before task_labels was stored: count the TaskTime attributes.
            return sum(isinstance(value, TaskTime) for value in vars(self).values())
        return len(labels)


class TaskTime:
    """Decoded SWR likelihoods, plus the inputs that produced them, for one task time.

        Parameters
        ----------
        tuning_curves : np.array or list
            Tuning curves used for decoding; get_likelihoods stores one set per
            decoding pass (leading axis = pass). Empty list until filled.
        swrs : epochs object or list
            The SWR epochs that were decoded. Empty list until filled.
        likelihoods : np.array or list
            Decoded likelihoods; get_likelihoods fills it with shape
            (n_passes, n_swrs, n_rows, n_cols), the last two axes being the
            spatial bins. An empty list means nothing has been decoded.
        zones : dict
            Masks over the spatial bins keyed by zone label; used to index
            ``likelihoods[:, :, mask]``.

    """

    def __init__(self, tuning_curves, swrs, likelihoods, zones):
        self.tuning_curves = tuning_curves
        self.swrs = swrs
        self.likelihoods = likelihoods
        self.zones = zones

    def _reduce_zone(self, zone_label, reducer):
        """Apply ``reducer`` over the bins selected by the zone mask.

        Returns NaN when no likelihoods have been stored; otherwise one value
        per leading (pass, SWR) index.
        """
        if len(self.likelihoods) == 0:
            return np.nan
        zone_values = self.likelihoods[:, :, self.zones[zone_label]]
        return reducer(zone_values, axis=2)

    def sums(self, zone_label):
        """NaN-ignoring total likelihood inside ``zone_label``."""
        return self._reduce_zone(zone_label, np.nansum)

    def means(self, zone_label):
        """NaN-ignoring mean likelihood inside ``zone_label``."""
        return self._reduce_zone(zone_label, np.nanmean)

    def maxs(self, zone_label):
        """NaN-ignoring peak likelihood inside ``zone_label``."""
        return self._reduce_zone(zone_label, np.nanmax)

In [None]:
# Session metadata module for the recording analysed below; change this import
# to run the notebook on a different session.
import info.r068d7 as r068d7
info = r068d7

In [None]:
# Number of shuffled decoding passes. percentile_thresh is presumably a
# significance percentile for the shuffle comparison (not used in the visible cells).
n_shuffles = 2
percentile_thresh = 99

# Task epochs to decode and maze zones to summarize.
task_labels = ["prerecord", "pauseA", "pauseB", "postrecord"]
zone_labels = ["u", "shortcut", "novel", "other"]

# SWR detection settings, read by get_likelihoods when calling
# nept.detect_swr_hilbert and nept.find_multi_in_epochs.
swr_params = {
    "z_thresh": 2.0,
    "power_thresh": 3.0,
    "merge_thresh": 0.02,
    "min_length": 0.05,
    "swr_thresh": (140.0, 250.0),
    "min_involved": 4,
}

# Plot colour for each zone label.
colours = {
    "u": "#2b8cbe",
    "shortcut": "#31a354",
    "novel": "#d95f0e",
    "other": "#bdbdbd",
}

In [None]:
def _build_zones(info, position):
    """Return zone masks keyed by label; "other" is the complement of u|shortcut|novel."""
    zones = dict()
    zones["u"], zones["shortcut"], zones["novel"] = get_zones(info, position, subset=True)
    combined_zones = zones["u"] + zones["shortcut"] + zones["novel"]
    zones["other"] = ~combined_zones
    return zones


def _detect_swrs(info, lfp, spikes, swr_params):
    """Detect SWRs in ``lfp`` and filter them with nept.find_multi_in_epochs."""
    swrs = nept.detect_swr_hilbert(lfp,
                                   fs=info.fs,
                                   thresh=swr_params["swr_thresh"],
                                   z_thresh=swr_params["z_thresh"],
                                   power_thresh=swr_params["power_thresh"],
                                   merge_thresh=swr_params["merge_thresh"],
                                   min_length=swr_params["min_length"])
    return nept.find_multi_in_epochs(spikes, swrs, min_involved=swr_params["min_involved"])


def _bin_swr_spikes(spikes, phase_swrs):
    """Bin each SWR's spikes with dt and window equal to the SWR's duration.

    Returns a list of (counts, t_window) pairs, one per SWR, in epoch order.
    """
    binned = []
    for start, stop in zip(phase_swrs.starts, phase_swrs.stops):
        t_window = stop - start  # 0.1 for running, 0.025 for swr
        sliced_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]
        counts = bin_spikes(sliced_spikes, np.array([start, stop]), dt=t_window, window=t_window,
                            gaussian_std=0.0075, normalized=False)
        binned.append((counts, t_window))
    return binned


def get_likelihoods(info, swr_params, task_labels, zone_labels, n_shuffles=0, save_path=None,
                    min_swr_duration=0.05):
    """Decode SWR activity into spatial likelihoods for each task time.

    Parameters
    ----------
    info : module
        Session info module (e.g. ``info.r068d7``); provides ``fs``,
        ``task_times`` and whatever ``get_data``/``get_zones`` read.
    swr_params : dict
        Keys "swr_thresh", "z_thresh", "power_thresh", "merge_thresh",
        "min_length" and "min_involved", forwarded to the nept SWR detectors.
    task_labels : list of str
        Keys into ``info.task_times``; one TaskTime is filled per label.
    zone_labels : list of str
        Unused; kept for interface compatibility.
    n_shuffles : int
        If > 0, decode ``n_shuffles`` times, each with an independent
        ``np.random.permutation`` of the tuning curves along their first axis.
        Otherwise decode once with the true tuning curves.
    save_path : str or None
        If given, the resulting Session is pickled to this path.
    min_swr_duration : float
        SWRs in a task time whose duration is below this are dropped.

    Returns
    -------
    Session
        Each task-label attribute holds ``likelihoods`` with shape
        (n_passes, n_swrs, tc_shape[1], tc_shape[2]), the per-pass
        ``tuning_curves``, and the decoded ``swrs``.
    """
    _, position, spikes, lfp, _ = get_data(info)

    zones = _build_zones(info, position)
    session = Session(task_labels, zones)

    tuning_curves_fromdata = get_only_tuning_curves(info, position, spikes, info.task_times["phase3"])
    tc_shape = tuning_curves_fromdata.shape

    swrs = _detect_swrs(info, lfp, spikes, swr_params)
    rest_epochs = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

    n_passes = max(n_shuffles, 1)

    for task_label in task_labels:
        epochs_of_interest = info.task_times[task_label].intersect(rest_epochs)

        phase_swrs = epochs_of_interest.overlaps(swrs)
        phase_swrs = phase_swrs[phase_swrs.durations >= min_swr_duration]

        # Spike counts do not depend on the tuning curves, so bin each SWR once
        # and reuse the counts for every shuffle pass.
        swr_counts = _bin_swr_spikes(spikes, phase_swrs)

        phase_likelihoods = np.zeros((n_passes, phase_swrs.n_epochs, tc_shape[1], tc_shape[2]))
        phase_tuningcurves = np.zeros((n_passes, tc_shape[0], tc_shape[1], tc_shape[2]))
        for n_pass in range(n_passes):
            if n_shuffles > 0:
                tuning_curves = np.random.permutation(tuning_curves_fromdata)
            else:
                tuning_curves = tuning_curves_fromdata

            phase_tuningcurves[n_pass] = tuning_curves
            flat_tuning_curves = tuning_curves.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

            for n_timebin, (counts, t_window) in enumerate(swr_counts):
                likelihood = nept.bayesian_prob(counts, flat_tuning_curves, binsize=t_window,
                                                min_neurons=3, min_spikes=1)
                phase_likelihoods[n_pass, n_timebin] = likelihood.reshape(tc_shape[1], tc_shape[2])

        tasktime = getattr(session, task_label)
        tasktime.likelihoods = phase_likelihoods
        tasktime.tuning_curves = phase_tuningcurves
        tasktime.swrs = phase_swrs

    if save_path is not None:
        session.pickle(save_path)

    return session

In [None]:
plot_individual = False
update_cache = True
dont_save_pickle = False


print(info.session_id)


def load_or_compute_session(cache_path, description, n_shuffles=0):
    """Load a pickled Session from ``cache_path``, or compute it with get_likelihoods.

    Reads the notebook flags at call time: with ``update_cache`` any existing
    pickle is deleted first (forcing a recompute); with ``dont_save_pickle`` the
    computed session is not written to disk.
    """
    # Remove previous pickle if update_cache
    if update_cache and os.path.exists(cache_path):
        os.remove(cache_path)

    if os.path.exists(cache_path):
        print("Loading pickled " + description + " likelihoods...")
        # Only unpickle files this notebook wrote: pickle.load can execute arbitrary code.
        with open(cache_path, 'rb') as fileobj:
            return pickle.load(fileobj)

    if dont_save_pickle:
        save_path = None
    else:
        # Session.pickle writes straight into this directory, so make sure it exists.
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        save_path = cache_path
    return get_likelihoods(info, swr_params, task_labels, zone_labels,
                           n_shuffles=n_shuffles, save_path=save_path)


# True (unshuffled) decoding
true_path = os.path.join(pickle_filepath, info.session_id+"_likelihoods_true.pkl")
true_session = load_or_compute_session(true_path, "true")

# Shuffled decoding; the cache file name encodes the number of shuffles
shuffled_path = os.path.join(pickle_filepath,
                             info.session_id+"_likelihoods_shuffled-%03d.pkl" % n_shuffles)
shuffled_session = load_or_compute_session(shuffled_path, "shuffled", n_shuffles=n_shuffles)

In [None]:
# Sanity check: the sum of all tuning curves for each decoding pass.
# get_likelihoods stores per-pass tuning curves with the pass as the leading axis.
# The shuffle permutes the next axis, so this sum is unaffected by it and every
# pass should look the same.
xcenters, ycenters = get_bin_centers(info)
xx, yy = np.meshgrid(xcenters, ycenters)

cmap = plt.get_cmap('bone_r', 25)
tc_session = shuffled_session

for n_pass, pass_tuning_curves in enumerate(tc_session.pauseA.tuning_curves):
    # Computed fresh for each pass rather than accumulated across passes
    multiple_tuning_curves = np.sum(pass_tuning_curves, axis=0)

    fig, ax = plt.subplots()
    pp = ax.pcolormesh(xx, yy, multiple_tuning_curves, cmap=cmap)
    fig.colorbar(pp, ax=ax)
    ax.set_title("Summed tuning curves, pass %d" % n_pass)
    ax.axis('off')
    fig.tight_layout()
    plt.show()

In [None]:
def plot_session(session, n_sessions, title, filepath=None):
    """Bar-plot, per zone, the mean decoded proportion for each task time.

    One panel per label in the notebook-global ``zone_labels``. Within a panel
    there is one bar per label in ``task_labels``, coloured from ``colours``.

    Bar height is the NaN-ignoring mean of ``TaskTime.sums(zone)`` over all
    passes and SWRs. The error bar is the SEM across SWRs, averaged over passes.
    Task times with nothing decoded get a bar of 0 with no error bar. The number
    printed at the base of each bar is the number of SWRs in that task time.

    Parameters
    ----------
    session : Session
    n_sessions : int
        Only used for the "n sessions" annotation.
    title : str
        Figure suptitle.
    filepath : str or None
        If given, the figure is saved there and closed; otherwise it is shown.
    """
    # `import scipy` alone does not guarantee the stats submodule is loaded.
    import scipy.stats

    trajectory_means = {zone_label: [] for zone_label in zone_labels}
    trajectory_sems = {zone_label: [] for zone_label in zone_labels}

    for zone_label in zone_labels:
        for task_label in task_labels:
            # TaskTime.sums gives a scalar NaN when nothing was decoded and a
            # (n_passes, n_swrs) array otherwise; compute it once per task time.
            zone_sums = np.asarray(getattr(session, task_label).sums(zone_label))
            if zone_sums.ndim < 2 or zone_sums.size == 0:
                trajectory_means[zone_label].append(0.0)
                trajectory_sems[zone_label].append(0.0)
                continue

            trajectory_means[zone_label].append(np.nanmean(zone_sums))
            if zone_sums.shape[1] > 1:
                # SEM across SWRs for each pass, then averaged over passes
                trajectory_sems[zone_label].append(
                    np.mean(scipy.stats.sem(zone_sums, axis=1, nan_policy="omit")))
            else:
                trajectory_sems[zone_label].append(0.0)

    # Number of decoded SWRs per task time (0 when likelihoods were never filled)
    n_swrs = []
    for task_label in task_labels:
        likelihoods = getattr(session, task_label).likelihoods
        n_swrs.append(np.shape(likelihoods)[1] if np.ndim(likelihoods) > 1 else 0)

    fig = plt.figure(figsize=(12, 6))

    gs1 = gridspec.GridSpec(1, len(zone_labels))
    gs1.update(wspace=0.3, hspace=0.)

    n = np.arange(len(task_labels))
    for i, zone_label in enumerate(zone_labels):
        ax = plt.subplot(gs1[i])
        ax.bar(n, trajectory_means[zone_label], yerr=trajectory_sems[zone_label], color=colours[zone_label])

        ax.set_ylim([0, 1.])

        ax.set_xticks(n)
        ax.set_xticklabels(task_labels, rotation=90)

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')

        for n_tasktime, n_swr in enumerate(n_swrs):
            ax.text(n_tasktime, 0.01, str(n_swr), ha="center", fontsize=14)

        if i > 0:
            ax.set_yticklabels([])

        if i == 0:
            ax.set_ylabel("Proportion")

    # Annotation and legend go on the last (right-most) panel
    ax.text(1., 1., "n sessions: " + str(n_sessions), horizontalalignment='left',
            verticalalignment='top', fontsize=14)

    fig.suptitle(title, fontsize=16)

    legend_elements = [Patch(facecolor=colours[zone_label], edgecolor='k', label=zone_label)
                       for zone_label in zone_labels]

    ax.legend(handles=legend_elements, bbox_to_anchor=(2.1, 0.95))

    gs1.tight_layout(fig)

    if filepath is not None:
        fig.savefig(filepath)
        # Close saved figures so batch runs do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

In [None]:
# Labels and target path for the session summary figure.
n_sessions = 1
title = "testingthis"
filepath = os.path.join(output_filepath, "{}.png".format(title))

# Use shuffled_session here instead to summarize the shuffled decoding.
session = true_session

In [None]:
# Show the summary figure inline; pass filepath=filepath to save it to disk instead.
plot_session(session, n_sessions, title)

In [None]:
# NOTE(review): placeholder - the same session object is listed twice; replace
# with the per-session objects to be combined.
sessions = [session, session]

In [None]:
# For each entry in `sessions`, collect the summed "u" likelihood during pauseA
# (one (n_passes, n_swrs) array per session, appended under the zone label).
# The loop variable is not named `session`, so the global `session` is left untouched.
zone_label = "u"
task_label = "pauseA"
combined_means = {label: [] for label in zone_labels}
for each_session in sessions:
    combined_means[zone_label].append(getattr(each_session, task_label).sums(zone_label))

In [None]:
# Inspect the collected per-session "u" sums.
combined_means["u"]

In [None]:
# Use the true (unshuffled) decoding for the inspection cells that follow.
session = true_session

In [None]:
# pauseA likelihoods as built by get_likelihoods: (n_passes, n_swrs, *spatial bins)
session.pauseA.likelihoods.shape

In [None]:
# Start time of the first decoded pauseA SWR
session.pauseA.swrs[0].starts

In [None]:
# Number of SWRs decoded in pauseA; should match likelihoods.shape[1]
session.pauseA.swrs.n_epochs

In [None]:
# Re-check the pauseA likelihood array shape (n_passes, n_swrs, *spatial bins)
session.pauseA.likelihoods.shape

In [None]:
# Summed "u" likelihood per pass and SWR: expected shape (n_passes, n_swrs)
session.pauseA.sums("u").shape

In [None]:
# Shape of the "u" zone mask; must match the two spatial axes of the likelihoods
session.pauseA.zones["u"].shape

In [None]:
# Same quantity as session.pauseA.sums("u"), computed directly
np.nansum(session.pauseA.likelihoods[:, :, session.pauseA.zones["u"]], axis=2)

In [None]:
# Pass-0 decoded likelihood of the last pauseA SWR, plotted by the pcolormesh cell.
# Selected directly; `a` stays undefined when pauseA has no SWRs.
n_pauseA_swrs = session.pauseA.swrs.n_epochs
if n_pauseA_swrs > 0:
    a = session.pauseA.likelihoods[0, n_pauseA_swrs - 1]

In [None]:
# 1-D bin centres of the maze grid. Note this rebinds xx/yy, which the
# tuning-curve check cell sets to a meshgrid.
xx, yy = get_bin_centers(info)
print(xx.shape, yy.shape)

In [None]:
# Quick look at the selected SWR likelihood `a` over the bin centres.
p = plt.pcolormesh(xx, yy, a)
plt.show()

In [None]:
def plot_summary_individual(info, session_true, session_shuffled, zone_labels, task_labels, colours, filepath=None):
    """Per-SWR summary figure: spike raster, LFP trace, decoded likelihood on the
    maze, and true vs. shuffled zone proportions.

    FIXME(review): this redefinition is unfinished and does not run as written.

    - It shadows the ``plot_summary_individual`` imported from
      ``analyze_decode_swrs`` at the top of the notebook.
    - Its signature does not match the call in the next cell, which passes
      per-SWR likelihoods, position, lfp, spikes, start, stop, zones and
      maze_segments.
    - The parameters ``session_true``, ``session_shuffled`` and ``zone_labels``
      are never used.
    - The body reads names that are neither parameters nor locals:
      ``session``, ``i`` and ``zone_label`` resolve to leftover notebook globals.
      ``likelihood_true``, ``likelihood_shuff``, ``zones``, ``maze_segments``,
      ``start``, ``stop`` and ``scalebar`` are not defined at global scope in
      any visible cell.

    Parameters
    ----------
    info : module
        Session info module passed to ``get_data``; also read for ``xedges``,
        ``yedges`` and ``task_times``.
    task_labels : list of str
        Iterated in the (incomplete) loop at the top of the body.
    colours : dict
        Colour per zone label for the contours and bars.
    filepath : str or None
        If given, save the figure there and close it; otherwise show it.
    """
    
    _, position, spikes, lfp, _ = get_data(info)
    
    # Padding (same time units as the spike/LFP timestamps) around the SWR window
    buffer = 0.1

    # FIXME(review): the expression below has no effect. It also indexes an array
    # with a string zone label, which numpy rejects. `means` is reset to an empty
    # list on every iteration.
    for task_label in task_labels:
        getattr(session, task_label).likelihoods[i][zone_label]
        means = [ ]

    # Mean / SEM over the leading (shuffle) axis of the summed likelihood in each zone
    means_shuff = [np.nanmean(np.nansum(likelihood_shuff[:, zones[trajectory]], axis=1)) for trajectory in maze_segments]
    sems_shuff = [scipy.stats.sem(np.nansum(likelihood_shuff[:, zones[trajectory]], axis=1), nan_policy="omit") for trajectory in maze_segments]

    # Spikes in the SWR window padded by `buffer` on each side
    sliced_spikes = [spiketrain.time_slice(start-buffer, stop+buffer) for spiketrain in spikes]

    rows = len(sliced_spikes)
    add_rows = int(rows / 8)  # NOTE(review): unused

    # Raster tick size scales inversely with the number of spike trains
    ms = 600 / rows
    mew = 0.7
    spike_loc = 1

    fig = plt.figure(figsize=(8, 8))
    gs1 = gridspec.GridSpec(3, 2)
    gs1.update(wspace=0.3, hspace=0.3)

    # Spike raster: one row per spike train
    ax1 = plt.subplot(gs1[1:, 0])
    for idx, neuron_spikes in enumerate(sliced_spikes):
        ax1.plot(neuron_spikes.time, np.ones(len(neuron_spikes.time)) + (idx * spike_loc), '|',
                 color='k', ms=ms, mew=mew)
    ax1.axis('off')

    # LFP trace above the raster, sharing its time axis
    ax2 = plt.subplot(gs1[0, 0], sharex=ax1)

    swr_highlight = "#fc4e2a"
    start_idx = nept.find_nearest_idx(lfp.time, start - buffer)
    stop_idx = nept.find_nearest_idx(lfp.time, stop + buffer)
    ax2.plot(lfp.time[start_idx:stop_idx], lfp.data[start_idx:stop_idx], color="k", lw=0.3, alpha=0.9)

    # Overdraw the SWR itself in the highlight colour
    start_idx = nept.find_nearest_idx(lfp.time, start)
    stop_idx = nept.find_nearest_idx(lfp.time, stop)
    ax2.plot(lfp.time[start_idx:stop_idx], lfp.data[start_idx:stop_idx], color=swr_highlight, lw=0.6)
    ax2.axis("off")

    ax1.axvline(lfp.time[start_idx], linewidth=1, color=swr_highlight)
    ax1.axvline(lfp.time[stop_idx], linewidth=1, color=swr_highlight)
    ax1.axvspan(lfp.time[start_idx], lfp.time[stop_idx], alpha=0.2, color=swr_highlight)

    # FIXME(review): `scalebar` is not imported in this notebook.
    scalebar.add_scalebar(ax2, matchy=False, bbox_transform=fig.transFigure,
                          bbox_to_anchor=(0.25, 0.05), units='ms')

    # Zero NaN bins in place before plotting; this mutates the caller's array.
    likelihood_true[np.isnan(likelihood_true)] = 0

    # Bin edges for pcolormesh, bin centres for the zone contours
    xx, yy = np.meshgrid(info.xedges, info.yedges)
    xcenters, ycenters = get_bin_centers(info)
    xxx, yyy = np.meshgrid(xcenters, ycenters)

    # Decoded likelihood over the phase3 trajectory, with zone outlines
    maze_highlight = "#fed976"
    ax3 = plt.subplot(gs1[0, 1])
    sliced_position = position.time_slice(info.task_times["phase3"].starts, info.task_times["phase3"].stops)
    ax3.plot(sliced_position.x, sliced_position.y, ".", color=maze_highlight, ms=1, alpha=0.2)
    pp = ax3.pcolormesh(xx, yy, likelihood_true, cmap='bone_r')
    ax3.contour(xxx, yyy, zones["u"], levels=0, colors=colours["u"])
    ax3.contour(xxx, yyy, zones["shortcut"], levels=0, colors=colours["shortcut"])
    ax3.contour(xxx, yyy, zones["novel"], levels=0, colors=colours["novel"])
    plt.colorbar(pp)
    ax3.axis('off')

    # FIXME(review): `means` is empty here, so its length cannot match `n`.
    ax4 = plt.subplot(gs1[1:2, 1])
    n = np.arange(len(maze_segments))
    ax4.bar(n, means,
            color=[colours["u"], colours["shortcut"], colours["novel"], colours["other"]], edgecolor='k')
    ax4.set_xticks(n)
    ax4.set_xticklabels([], rotation=90)
    ax4.set_ylim([0, 1.])
    ax4.set_title("True proportion", fontsize=14)

    ax5 = plt.subplot(gs1[2:, 1], sharey=ax4)
    n = np.arange(len(maze_segments))
    ax5.bar(n, means_shuff,
            yerr=sems_shuff,
            color=[colours["u"], colours["shortcut"], colours["novel"], colours["other"]], edgecolor='k')
    ax5.set_xticks(n)
    ax5.set_xticklabels(maze_segments, rotation=90)
    ax5.set_ylim([0, 1.])
    ax5.set_title("Shuffled proportion", fontsize=14)

    plt.tight_layout()

    if filepath is not None:
        plt.savefig(filepath)
        plt.close()
    else:
        plt.show()

In [None]:
# One summary figure per decoded SWR: true likelihood vs. the shuffled passes.
# NOTE(review): calls the plot_summary_individual imported from analyze_decode_swrs
# explicitly, because the notebook redefines that name with an incompatible
# signature. The argument order mirrors the original call - confirm it against
# analyze_decode_swrs.
import analyze_decode_swrs

save_swr_figures = False
swr_output_filepath = os.path.join(output_filepath, "swr")
if save_swr_figures:
    os.makedirs(swr_output_filepath, exist_ok=True)

_, position, spikes, lfp, _ = get_data(info)
maze_segments = zone_labels

for task_label in task_labels:
    true_tasktime = getattr(true_session, task_label)
    shuffled_tasktime = getattr(shuffled_session, task_label)
    # Both sessions were decoded from the same SWR detection, so SWR idx matches.
    phase_swrs = true_tasktime.swrs
    zones = true_tasktime.zones
    for idx in range(phase_swrs.n_epochs):
        filename = info.session_id + "_" + task_label + "_summary-swr" + str(idx) + ".png"
        filepath = os.path.join(swr_output_filepath, filename)
        # Pass copies: the plotter zeroes NaNs in place, which would otherwise
        # alter the cached session likelihoods.
        analyze_decode_swrs.plot_summary_individual(
            info,
            true_tasktime.likelihoods[0, idx].copy(),
            shuffled_tasktime.likelihoods[:, idx].copy(),
            position, lfp, spikes,
            phase_swrs.starts[idx],
            phase_swrs.stops[idx],
            zones, maze_segments, colours,
            filepath=filepath if save_swr_figures else None)