In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import numpy as np
import warnings
import random
import scipy
import pickle
import os
import nept
import scalebar

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_maze import get_zones
from analyze_decode_swrs import (plot_summary_individual,
                                 plot_likelihood_overspace,
                                 plot_combined,
                                 plot_stacked_summary)

from utils_maze import get_bin_centers
from analyze_classy_decode import bin_spikes

# Cache and figure locations are resolved relative to the working directory.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "trials", "decoding", "classy")

# Create both directories up front. Pickles are written into pickle_filepath
# later in the notebook, and open(..., 'wb') fails when the directory is
# missing. exist_ok=True replaces the exists()-then-makedirs() check.
for _dirpath in (pickle_filepath, output_filepath):
    os.makedirs(_dirpath, exist_ok=True)

# Set random seeds
random.seed(0)
np.random.seed(0)

In [None]:
class Session:
    """A collection of TaskTime objects, one per task label, for a session

        Parameters
        ----------
        position : object
            Stored unchanged on ``self.position``.
        task_labels : iterable of str
            For each label an attribute of that name is created, holding an
            empty TaskTime.
        zones : dict
            Passed to every TaskTime that is created.

        Attributes
        ----------
        position : object
        task_labels : list of str
            The labels this session was built with.

    """

    def __init__(self, position, task_labels, zones):
        self.position = position
        # Keep our own copy of the labels so that n_tasktimes() does not depend
        # on a module-level `task_labels` variable, which is absent when the
        # session is unpickled in a fresh kernel.
        self.task_labels = list(task_labels)
        for task_label in self.task_labels:
            setattr(self, task_label, TaskTime([], [], [], zones))

    def pickle(self, save_path):
        """Pickle this session to ``save_path``, printing the destination."""
        with open(save_path, 'wb') as fileobj:
            print("Saving " + save_path)
            pickle.dump(self, fileobj)

    def n_tasktimes(self):
        """Return the number of task labels this session was built with."""
        return len(self.task_labels)


class TaskTime:
    """Decoded likelihoods, and the inputs that produced them, for one task time

        Parameters
        ----------
        tuning_curves : array-like
            Stored unchanged.
        swrs : object
            Stored unchanged.
        likelihoods : np.array or empty sequence
            The zone reductions index it as ``likelihoods[:, :, zone_mask]``
            and reduce over axis 2; an empty value makes them return np.nan.
        zones : dict
            Maps a zone label to the index used as ``zone_mask`` above.

        Attributes
        ----------
        tuning_curves, swrs, likelihoods, zones
            The constructor arguments, stored under the same names.

    """

    def __init__(self, tuning_curves, swrs, likelihoods, zones):
        self.zones = zones
        self.likelihoods = likelihoods
        self.swrs = swrs
        self.tuning_curves = tuning_curves

    def _reduce_over_zone(self, reducer, zone_label):
        """Apply ``reducer`` over the bins selected by ``zone_label``."""
        if len(self.likelihoods) == 0:
            return np.nan
        within_zone = self.likelihoods[:, :, self.zones[zone_label]]
        return reducer(within_zone, axis=2)

    def sums(self, zone_label):
        """NaN-ignoring sum of the likelihood inside the zone."""
        return self._reduce_over_zone(np.nansum, zone_label)

    def means(self, zone_label):
        """NaN-ignoring mean of the likelihood inside the zone."""
        return self._reduce_over_zone(np.nanmean, zone_label)

    def maxs(self, zone_label):
        """NaN-ignoring maximum of the likelihood inside the zone."""
        return self._reduce_over_zone(np.nanmax, zone_label)

In [None]:
# Select the session to analyze. `info` is bound to that session's info module
# and is what later cells pass to the loading, decoding and plotting functions.
import info.r068d7 as r068d7
info = r068d7

In [None]:
# Number of shuffled decoding passes requested for the shuffled session.
n_shuffles = 2
percentile_thresh = 99

task_labels = ["prerecord", "pauseA", "pauseB", "postrecord"]
zone_labels = ["u", "shortcut", "novel", "other"]

# SWR detection settings, read by key inside get_likelihoods().
swr_params = {
    "z_thresh": 2.0,
    "power_thresh": 3.0,
    "merge_thresh": 0.02,
    "min_length": 0.05,
    "swr_thresh": (140.0, 250.0),
    "min_involved": 4,
}

# One plotting colour per zone label.
colours = {
    "u": "#2b8cbe",
    "shortcut": "#31a354",
    "novel": "#d95f0e",
    "other": "#bdbdbd",
}

In [None]:
def get_likelihoods(info, swr_params, task_labels, zone_labels, n_shuffles=0, save_path=None):
    """Decode one likelihood map per resting SWR for each task time.

    Parameters
    ----------
    info : module
        Session info. ``info.task_times`` and ``info.fs`` are read here, and
        ``info`` is passed on to the data, zone and tuning-curve helpers.
    swr_params : dict
        Must contain "swr_thresh", "z_thresh", "power_thresh", "merge_thresh",
        "min_length" and "min_involved".
    task_labels : list of str
        Keys of ``info.task_times`` to decode.
    zone_labels : list of str
        Not used in this function.
    n_shuffles : int
        If > 0, that many passes are run, each decoding with
        ``np.random.permutation`` of the tuning curves (shuffled along their
        first axis). Otherwise one pass with the unshuffled tuning curves.
    save_path : str or None
        If not None, the session is pickled to this path.

    Returns
    -------
    session : Session
        For each task label, the TaskTime holds ``likelihoods`` with shape
        (n_passes, n_swrs, tc_shape[1], tc_shape[2]), ``tuning_curves`` with
        shape (n_passes, tc_shape[0], tc_shape[1], tc_shape[2]), and ``swrs``.

    """
    _, position, spikes, lfp, _ = get_data(info)

    u_zone, shortcut_zone, novel_zone = get_zones(info, position, subset=True)
    zones = {"u": u_zone, "shortcut": shortcut_zone, "novel": novel_zone}
    # Everything that belongs to none of the three trajectories.
    zones["other"] = ~(u_zone + shortcut_zone + novel_zone)

    phase3 = info.task_times["phase3"]
    session = Session(position.time_slice(phase3.starts, phase3.stops), task_labels, zones)

    tuning_curves_fromdata = get_only_tuning_curves(info, position, spikes, info.task_times["phase3"])
    tc_shape = tuning_curves_fromdata.shape
    n_neurons, n_xbins, n_ybins = tc_shape[0], tc_shape[1], tc_shape[2]

    detected = nept.detect_swr_hilbert(lfp,
                                       fs=info.fs,
                                       thresh=swr_params["swr_thresh"],
                                       z_thresh=swr_params["z_thresh"],
                                       power_thresh=swr_params["power_thresh"],
                                       merge_thresh=swr_params["merge_thresh"],
                                       min_length=swr_params["min_length"])
    swrs = nept.find_multi_in_epochs(spikes, detected, min_involved=swr_params["min_involved"])

    rest_epochs = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

    shuffling = n_shuffles > 0
    n_passes = n_shuffles if shuffling else 1

    for task_label in task_labels:
        resting = info.task_times[task_label].intersect(rest_epochs)

        phase_swrs = resting.overlaps(swrs)
        phase_swrs = phase_swrs[phase_swrs.durations >= 0.05]

        phase_likelihoods = np.zeros((n_passes, phase_swrs.n_epochs, n_xbins, n_ybins))
        phase_tuningcurves = np.zeros((n_passes, n_neurons, n_xbins, n_ybins))

        for n_pass in range(n_passes):
            if shuffling:
                pass_curves = np.random.permutation(tuning_curves_fromdata)
            else:
                pass_curves = tuning_curves_fromdata

            phase_tuningcurves[n_pass] = pass_curves
            # The decoder is given the spatial bins as one flat axis.
            flat_curves = pass_curves.reshape(n_neurons, n_xbins * n_ybins)

            swr_bounds = zip(phase_swrs.starts, phase_swrs.stops)
            for n_timebin, (start, stop) in enumerate(swr_bounds):
                # Each SWR is a single time bin, so the bin width is its duration.
                t_window = stop - start

                sliced_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]

                counts = bin_spikes(sliced_spikes, np.array([start, stop]),
                                    dt=t_window, window=t_window,
                                    gaussian_std=0.0075, normalized=False)

                likelihood = nept.bayesian_prob(counts, flat_curves, binsize=t_window,
                                                min_neurons=3, min_spikes=1)

                phase_likelihoods[n_pass, n_timebin] = likelihood.reshape(n_xbins, n_ybins)

        tasktime = getattr(session, task_label)
        tasktime.likelihoods = phase_likelihoods
        tasktime.tuning_curves = phase_tuningcurves
        tasktime.swrs = phase_swrs

    if save_path is not None:
        session.pickle(save_path)

    return session

In [None]:
plot_individual = False
update_cache = True
dont_save_pickle = False

print(info.session_id)

# --- True (unshuffled) likelihoods -----------------------------------------
true_path = os.path.join(pickle_filepath, info.session_id+"_likelihoods_true.pkl")

# With update_cache set, drop any stale pickle so it is recomputed below.
if update_cache and os.path.exists(true_path):
    os.remove(true_path)

if not os.path.exists(true_path):
    # Nothing cached: compute, and pickle unless saving is disabled.
    if dont_save_pickle:
        true_path = None
    true_session = get_likelihoods(info, swr_params, task_labels, zone_labels,
                                   save_path=true_path)
else:
    print("Loading pickled true likelihoods...")
    compute_likelihoods = False
    # pickle.load executes arbitrary code; only load pickles this notebook wrote.
    with open(true_path, 'rb') as fileobj:
        true_session = pickle.load(fileobj)

# --- Shuffled likelihoods ---------------------------------------------------
shuffled_filename = info.session_id + ("_likelihoods_shuffled-%03d.pkl" % n_shuffles)
shuffled_path = os.path.join(pickle_filepath, shuffled_filename)

# With update_cache set, drop any stale pickle so it is recomputed below.
if update_cache and os.path.exists(shuffled_path):
    os.remove(shuffled_path)

if not os.path.exists(shuffled_path):
    # Nothing cached: compute, and pickle unless saving is disabled.
    if dont_save_pickle:
        shuffled_path = None
    shuffled_session = get_likelihoods(info, swr_params, task_labels, zone_labels,
                                       n_shuffles=n_shuffles,
                                       save_path=shuffled_path)
else:
    print("Loading pickled shuffled likelihoods...")
    # pickle.load executes arbitrary code; only load pickles this notebook wrote.
    with open(shuffled_path, 'rb') as fileobj:
        shuffled_session = pickle.load(fileobj)

In [None]:
def plot_likelihood_overspace(info, session, task_labels):
    """Show, per task label, the likelihood averaged over axis 1, drawn on the maze.

    This definition shadows the ``plot_likelihood_overspace`` imported from
    ``analyze_decode_swrs`` at the top of the notebook. Zone outline colours
    are read from the module-level ``colours`` dict.

    Parameters
    ----------
    info : module
        Provides ``xedges`` and ``yedges``; also passed to ``get_bin_centers``.
    session : Session
        Provides ``position`` and one TaskTime attribute per task label.
    task_labels : list of str
        One figure is shown for each label.

    """
    maze_highlight = "#fed976"

    for task_label in task_labels:
        tasktime = getattr(session, task_label)
        zones = tasktime.zones

        # NaN-ignoring mean over axis 1; entries that are still NaN become 0.
        mean_likelihood = np.nanmean(np.array(tasktime.likelihoods[:]), axis=1)
        print(mean_likelihood.shape)
        mean_likelihood[np.isnan(mean_likelihood)] = 0

        edge_x, edge_y = np.meshgrid(info.xedges, info.yedges)
        xcenters, ycenters = get_bin_centers(info)
        center_x, center_y = np.meshgrid(xcenters, ycenters)

        plt.plot(session.position.x, session.position.y, ".",
                 color=maze_highlight, ms=1, alpha=0.2)
        # Only the first entry along axis 0 is drawn.
        mesh = plt.pcolormesh(edge_x, edge_y, mean_likelihood[0], cmap='bone_r')
        for label in ("u", "shortcut", "novel"):
            plt.contour(center_x, center_y, zones[label],
                        levels=0, linewidths=2, colors=colours[label])
        plt.colorbar(mesh)
        plt.axis('off')
        plt.show()

In [None]:
# Show the likelihood maps of the shuffled session, one figure per task label.
plot_likelihood_overspace(info, shuffled_session, task_labels)

In [None]:
# NOTE(review): raises ZeroDivisionError on purpose — presumably a stop marker
# so "Run All" halts before the exploratory cells below; confirm before removing.
1/0

In [None]:

# Exploratory: for each zone, pool the per-task zone sums over all sessions,
# then compute a mean and a SEM per task label.
# NOTE(review): `sessions` is only assigned in a later cell, so this cell fails
# when the notebook is run top to bottom on a fresh kernel — confirm cell order.
for zone_label in zone_labels:
    sums = {task_label: [] for task_label in task_labels}
    for session in sessions:
        for task_label in task_labels:
            sums[task_label].extend(getattr(session, task_label).sums(zone_label))
    # `task_label` is whatever the loop above left it as (the last label).
    print(np.array(sums[task_label]).shape)
    
    # 0.0 stands in for task labels that collected nothing.
    means = [np.nanmean(sums[task_label])
             if len(sums[task_label]) > 0 else 0.0
             for task_label in task_labels]
    print(means)
    # 0.0 stands in for task labels with at most one entry.
    sems = [np.nanmean(scipy.stats.sem(np.array(sums[task_label]), axis=0,
                                       nan_policy="omit"))
            if len(sums[task_label]) > 1 else 0.0
            for task_label in task_labels]

In [None]:
# Inspect `sums`; the zone loop rebuilds it per zone, so this is the last zone's.
sums

In [None]:
# Per-task-label SEM of the pooled sums (0.0 when a label has at most one
# entry); same expression as in the zone loop, minus the np.array wrapper.
[np.nanmean(scipy.stats.sem(sums[task_label], axis=0,
                                           nan_policy="omit"))
                if len(sums[task_label]) > 1 else 0.0
                for task_label in task_labels]

In [None]:
# Per-task-label NaN-ignoring mean of the pooled sums (0.0 when a label is empty).
[np.nanmean(sums[task_label])
                 if len(sums[task_label]) > 0 else 0.0
                 for task_label in task_labels]

In [None]:
# Scratch list holding the same session twice.
# NOTE(review): above this cell `session` is only bound at notebook scope as
# the loop variable over `sessions`, so this relies on kernel state — confirm.
sessions = [session, session]

In [None]:
# Exploratory: collect the pauseA sums for the "u" zone from every session.
# The loops over all zones and task labels are commented out, so only
# combined_means["u"] is filled; trajectory_means stays empty.
combined_means = {zone_label: [] for zone_label in zone_labels}
trajectory_means = {zone_label: [] for zone_label in zone_labels}
for session in sessions:
    zone_label = "u"
#     for zone_label in zone_labels:
    task_label = "pauseA"
#         for task_label in task_labels:
    combined_means[zone_label].append(getattr(session, task_label).sums(zone_label))
#         print([np.nanmean(getattr(session, task_label).sums(zone_label))
#                                         if len(getattr(session, task_label).sums(zone_label)) > 0 else 0.0
#                                         for task_label in task_labels])

In [None]:
# Inspect the collected pauseA "u"-zone sums (one list entry per session).
combined_means["u"]

In [None]:
# Shape check of the pauseA likelihoods.
session.pauseA.likelihoods.shape

In [None]:
# Load position, spikes and LFP at notebook scope for the plotting cells below.
_, position, spikes, lfp, _ = get_data(info)

In [None]:
# Time added on each side of the SWR when slicing the LFP.
buffer = 0.1

# Iterate over the start/stop *arrays* of the first pauseA SWR epoch. The
# singular `.start` / `.stop` are used as scalars elsewhere in this notebook
# and cannot be zipped over.
first_swr = session.pauseA.swrs[0]
for start, stop in zip(first_swr.starts, first_swr.stops):
    sliced_lfp = lfp.time_slice(start-buffer, stop+buffer)
    swr_trace = lfp.time_slice(start, stop)

    # Plot time against data explicitly, as the other LFP plots here do,
    # rather than handing the LFP object itself to plt.plot.
    plt.plot(sliced_lfp.time, sliced_lfp.data)

In [None]:
# One figure per pauseA SWR: the LFP around the event in black, with the SWR
# itself redrawn on top in the highlight colour.
swr_highlight = "#fc4e2a"
pause_swrs = session.pauseA.swrs
for swr_idx in range(pause_swrs.n_epochs):
    this_swr = pause_swrs[swr_idx]
    start, stop = this_swr.start, this_swr.stop

    # Surrounding trace, extended by `buffer` on both sides.
    start_idx = nept.find_nearest_idx(lfp.time, start - buffer)
    stop_idx = nept.find_nearest_idx(lfp.time, stop + buffer)
    context = slice(start_idx, stop_idx)
    plt.plot(lfp.time[context], lfp.data[context], color="k", lw=0.3, alpha=0.9)

    # The SWR itself.
    start_idx = nept.find_nearest_idx(lfp.time, start)
    stop_idx = nept.find_nearest_idx(lfp.time, stop)
    event = slice(start_idx, stop_idx)
    plt.plot(lfp.time[event], lfp.data[event], color=swr_highlight, lw=0.6)

    plt.axis("off")
    plt.show()

In [None]:
# Slice the spikes first: the values below are derived from the number of
# spike trains, so `sliced_spikes` must exist before they are computed
# (otherwise a fresh kernel raises NameError here).
sliced_spikes = [spiketrain.time_slice(start-buffer, stop+buffer) for spiketrain in spikes]

add_rows = int(len(sliced_spikes) / 8)

# Marker height shrinks as the number of spike trains grows.
ms = 600 / len(sliced_spikes)
mew = 0.7
spike_loc = 1

# Raster: one row of tick marks per spike train.
for idx, neuron_spikes in enumerate(sliced_spikes):
    plt.plot(neuron_spikes.time, np.ones(len(neuron_spikes.time)) + (idx * spike_loc), '|',
             color='k', ms=ms, mew=mew)
plt.axis('off')
plt.show()

In [None]:
# Shape check of the pauseA likelihoods.
session.pauseA.likelihoods.shape

In [None]:
# Start time(s) of the first pauseA SWR epoch.
session.pauseA.swrs[0].starts

In [None]:
# Number of pauseA SWR epochs.
session.pauseA.swrs.n_epochs

In [None]:
# Shape check of the pauseA likelihoods.
session.pauseA.likelihoods.shape

In [None]:
# Shape of the pauseA likelihood summed inside the "u" zone.
session.pauseA.sums("u").shape

In [None]:
# Shape of the "u" zone array that is used to index the likelihoods.
session.pauseA.zones["u"].shape

In [None]:
# Same expression that TaskTime.sums("u") evaluates for non-empty likelihoods.
np.nansum(session.pauseA.likelihoods[:, :, session.pauseA.zones["u"]], axis=2)

In [None]:
# Leaves `a` as likelihoods[0, i] for the last pauseA SWR; earlier iterations
# are overwritten, and `a` is not assigned at all when there are no SWRs.
for i in range(session.pauseA.swrs.n_epochs):
    a = session.pauseA.likelihoods[:, i][0]

In [None]:
# Bin centres from get_bin_centers (here bound to xx, yy) and their shapes.
xx, yy = get_bin_centers(info)
print(xx.shape, yy.shape)

In [None]:
# Quick look at the map `a` from the loop cell, drawn against xx and yy.
p = plt.pcolormesh(xx, yy, a)
plt.show()

In [None]:
# Shuffled likelihoods for one SWR, across axis 0. `task_label` and `swr_idx`
# are whatever earlier cells left them as.
likelihood_shuffled = getattr(shuffled_session, task_label).likelihoods[:, swr_idx]

In [None]:
# Zone dict of the current task label, used to index the likelihoods below.
zones = getattr(session, task_label).zones

In [None]:
# Shape check of the shuffled likelihoods selected above.
likelihood_shuffled.shape

In [None]:
# NOTE(review): despite the name, this holds the SEM (not the mean), per zone,
# of the NaN-ignoring summed shuffled likelihood.
shuffled_means = [scipy.stats.sem(np.nansum(likelihood_shuffled[:, zones[zone_label]], axis=1))
                  for zone_label in zone_labels]

In [None]:
# SEM of the summed shuffled likelihood in the "other" zone, omitting NaNs.
scipy.stats.sem(np.nansum(likelihood_shuffled[:, zones["other"]], axis=1), nan_policy="omit")

In [None]:
# Display the per-zone values computed above.
shuffled_means

In [None]:
# Bar chart of `means`. The set_* methods belong to an Axes object, not to the
# pyplot module (plt.set_xticks etc. raise AttributeError), so draw on an
# explicit Axes.
n = np.arange(len(means))
fig, ax = plt.subplots()
ax.bar(n, means,
       color=[colours["u"], colours["shortcut"], colours["novel"], colours["other"]], edgecolor='k')
ax.set_xticks(n)
ax.set_xticklabels([], rotation=90)
ax.set_ylim([0, 1.])
ax.set_title("True proportion", fontsize=14)

In [None]:
# Shape check of the prerecord likelihoods.
session.prerecord.likelihoods.shape

In [None]:
# Number of prerecord SWR epochs.
session.prerecord.swrs.n_epochs

In [None]:
def plot_summary_individual(info, session_true, session_shuffled, zone_labels, task_labels, colours, filepath=None):
    """Draw one summary figure per SWR for every task label.

    Each figure has five panels: a spike raster around the SWR, the LFP with
    the SWR highlighted, the true likelihood map over the maze, a bar chart of
    the true summed likelihood per zone, and a bar chart of the shuffled
    summed likelihood per zone (mean with SEM error bars).

    This definition shadows the ``plot_summary_individual`` imported from
    ``analyze_decode_swrs`` at the top of the notebook.

    Parameters
    ----------
    info : module
        Provides ``task_times``, ``xedges``, ``yedges`` and ``session_id``;
        also passed to ``get_data`` and ``get_bin_centers``.
    session_true : Session
        Supplies the SWRs, the zones and the true likelihoods.
    session_shuffled : Session
        Supplies the shuffled likelihoods, indexed by the same SWR index.
    zone_labels : list of str
        Zones shown in the bar charts, in order.
    task_labels : list of str
        Task times to plot.
    colours : dict
        Colour per zone label; must include "u", "shortcut" and "novel".
    filepath : str or None
        Directory to save the figures in. If None, figures are shown instead.

    """
    _, position, spikes, lfp, _ = get_data(info)

    buffer = 0.1
    mew = 0.7
    spike_loc = 1
    swr_highlight = "#fc4e2a"
    maze_highlight = "#fed976"

    # None of the following depends on the task label or the SWR, so compute
    # it once instead of once per figure.
    xx, yy = np.meshgrid(info.xedges, info.yedges)
    xcenters, ycenters = get_bin_centers(info)
    xxx, yyy = np.meshgrid(xcenters, ycenters)
    sliced_position = position.time_slice(info.task_times["phase3"].starts, info.task_times["phase3"].stops)
    n = np.arange(len(zone_labels))
    bar_colours = [colours[zone_label] for zone_label in zone_labels]

    for task_label in task_labels:
        print(task_label)
        # Read from the session passed in, not from a notebook-level `session`
        # variable that may be undefined or refer to a different session.
        tasktime_true = getattr(session_true, task_label)
        tasktime_shuffled = getattr(session_shuffled, task_label)
        swrs = tasktime_true.swrs
        zones = tasktime_true.zones

        for swr_idx in range(swrs.n_epochs):
            print("swr:" + str(swr_idx))
            start = swrs[swr_idx].start
            stop = swrs[swr_idx].stop

            sliced_spikes = [spiketrain.time_slice(start-buffer, stop+buffer) for spiketrain in spikes]

            # Marker height shrinks as the number of spike trains grows.
            ms = 600 / len(sliced_spikes)

            fig = plt.figure(figsize=(8, 8))
            gs1 = gridspec.GridSpec(3, 2)
            gs1.update(wspace=0.3, hspace=0.3)

            # Panel 1: spike raster, one row per spike train.
            ax1 = plt.subplot(gs1[1:, 0])
            for idx, neuron_spikes in enumerate(sliced_spikes):
                ax1.plot(neuron_spikes.time, np.ones(len(neuron_spikes.time)) + (idx * spike_loc), '|',
                         color='k', ms=ms, mew=mew)
            ax1.axis('off')

            # Panel 2: LFP around the SWR, with the SWR itself highlighted.
            ax2 = plt.subplot(gs1[0, 0], sharex=ax1)

            start_idx = nept.find_nearest_idx(lfp.time, start - buffer)
            stop_idx = nept.find_nearest_idx(lfp.time, stop + buffer)
            ax2.plot(lfp.time[start_idx:stop_idx], lfp.data[start_idx:stop_idx], color="k", lw=0.3, alpha=0.9)

            start_idx = nept.find_nearest_idx(lfp.time, start)
            stop_idx = nept.find_nearest_idx(lfp.time, stop)
            ax2.plot(lfp.time[start_idx:stop_idx], lfp.data[start_idx:stop_idx], color=swr_highlight, lw=0.6)
            ax2.axis("off")

            # Mark the SWR extent on the raster as well.
            ax1.axvline(lfp.time[start_idx], linewidth=1, color=swr_highlight)
            ax1.axvline(lfp.time[stop_idx], linewidth=1, color=swr_highlight)
            ax1.axvspan(lfp.time[start_idx], lfp.time[stop_idx], alpha=0.2, color=swr_highlight)

            scalebar.add_scalebar(ax2, matchy=False, bbox_transform=fig.transFigure,
                                  bbox_to_anchor=(0.25, 0.05), units='ms')

            # Panel 3: true likelihood map. Work on a copy with NaN -> 0 so the
            # stored likelihoods are not modified; only the first entry along
            # axis 0 is drawn.
            likelihood_true = tasktime_true.likelihoods[:, swr_idx]
            likelihood_map = np.array(likelihood_true)
            likelihood_map[np.isnan(likelihood_map)] = 0

            ax3 = plt.subplot(gs1[0, 1])
            ax3.plot(sliced_position.x, sliced_position.y, ".", color=maze_highlight, ms=1, alpha=0.2)
            pp = ax3.pcolormesh(xx, yy, likelihood_map[0], cmap='bone_r')
            for label in ["u", "shortcut", "novel"]:
                ax3.contour(xxx, yyy, zones[label], levels=0, linewidths=2, colors=colours[label])
            plt.colorbar(pp)
            ax3.axis('off')

            # Panel 4: true summed likelihood per zone.
            means_true = [np.nanmean(np.nansum(likelihood_true[:, zones[zone_label]], axis=1))
                          for zone_label in zone_labels]

            ax4 = plt.subplot(gs1[1:2, 1])
            ax4.bar(n, means_true, color=bar_colours, edgecolor='k')
            ax4.set_xticks(n)
            ax4.set_xticklabels([], rotation=90)
            ax4.set_ylim([0, 1.])
            ax4.set_title("True proportion", fontsize=14)

            # Panel 5: shuffled summed likelihood per zone, mean and SEM.
            likelihood_shuffled = tasktime_shuffled.likelihoods[:, swr_idx]
            zone_sums_shuffled = [np.nansum(likelihood_shuffled[:, zones[zone_label]], axis=1)
                                  for zone_label in zone_labels]
            means_shuffled = [np.nanmean(zone_sum) for zone_sum in zone_sums_shuffled]
            sems_shuffled = [scipy.stats.sem(zone_sum) for zone_sum in zone_sums_shuffled]

            ax5 = plt.subplot(gs1[2:, 1], sharey=ax4)
            ax5.bar(n, means_shuffled,
                    yerr=sems_shuffled,
                    color=bar_colours, edgecolor='k')
            ax5.set_xticks(n)
            ax5.set_xticklabels(zone_labels, rotation=90)
            ax5.set_ylim([0, 1.])
            ax5.set_title("Shuffled proportion", fontsize=14)

            plt.tight_layout()

            if filepath is not None:
                filename = info.session_id+"_"+task_label+"_summary-swr"+str(swr_idx)+".png"
                # Save into the directory the caller asked for, not into the
                # notebook-level output_filepath.
                plt.savefig(os.path.join(filepath, filename))
                plt.close(fig)
            else:
                plt.show()

In [None]:
# Per-SWR summary figures for this session; no filepath is given, so the
# figures are shown rather than saved.
plot_summary_individual(info, true_session, shuffled_session, zone_labels, task_labels, colours)