In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import numpy as np
import random
import scipy
import pickle
import os
import scalebar
import nept

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_maze import get_zones, get_bin_centers
from analyze_decode_swrs import bin_spikes, plot_summary_individual, plot_likelihood_overspace, get_likelihood, plot_combined, plot_stacked_summary, get_likelihoods, save_likelihoods, pickle_likelihoods

In [None]:
# Resolve cache and plot locations relative to the notebook's working directory.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "trials", "decoding", "shuffled")

# Create every directory the analysis writes into up front: the pickle cache,
# the plot root, and the "percentiles" / "swr" subdirectories that the analysis
# cell saves figures into (savefig does not create missing directories).
# exist_ok=True avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(pickle_filepath, exist_ok=True)
os.makedirs(output_filepath, exist_ok=True)
os.makedirs(os.path.join(output_filepath, "percentiles"), exist_ok=True)
os.makedirs(os.path.join(output_filepath, "swr"), exist_ok=True)

In [None]:
# Seed both global random number generators used in this notebook (numpy's
# legacy global state and the stdlib `random` module) with 0, so any randomness
# drawn through either (presumably including the shuffles) repeats run-to-run.
np.random.seed(0)
random.seed(0)

In [None]:
import info.r063d2 as r063d2
import info.r063d3 as r063d3
infos = [r063d2, r063d3]
group = "test"
from run import (analysis_infos,
                 r063_infos, r066_infos, r067_infos, r068_infos,
                 days1234_infos, days5678_infos,
                 day1_infos, day2_infos, day3_infos, day4_infos, day5_infos, day6_infos, day7_infos, day8_infos)
# `import scipy` alone does not load `scipy.stats` on every SciPy version;
# import it explicitly instead of relying on another module having done so.
import scipy.stats
# infos = analysis_infos
# group = "All"

# infos = r068_infos
# group = "R068"

# Set True to delete and recompute every cached SWR / likelihood pickle.
update_cache = False

n_shuffles = 2
percentile_thresh = 99

colours = dict()
colours["u"] = "#2b8cbe"
colours["shortcut"] = "#31a354"
colours["novel"] = "#d95f0e"
colours["other"] = "#bdbdbd"

# swr params
z_thresh = 2.0
power_thresh = 3.0
merge_thresh = 0.02
min_length = 0.05
swr_thresh = (140.0, 250.0)

task_times = ["prerecord", "pauseA", "pauseB", "postrecord"]
maze_segments = ["u", "shortcut", "novel", "other"]

# Output locations written below; savefig/open do not create missing directories.
percentiles_dir = os.path.join(output_filepath, "percentiles")
swr_plots_dir = os.path.join(output_filepath, "swr")
os.makedirs(pickle_filepath, exist_ok=True)
os.makedirs(percentiles_dir, exist_ok=True)
os.makedirs(swr_plots_dir, exist_ok=True)


def _remove_if_exists(*paths):
    """Delete each file in ``paths`` that exists (used to invalidate caches)."""
    for path in paths:
        if os.path.exists(path):
            os.remove(path)


def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    Only use on cache files this notebook produced: unpickling untrusted data
    can execute arbitrary code.
    """
    with open(path, 'rb') as fileobj:
        return pickle.load(fileobj)


def _save_pickle(obj, path):
    """Write ``obj`` to ``path`` with pickle, overwriting any existing file."""
    with open(path, 'wb') as fileobj:
        pickle.dump(obj, fileobj)


def _fraction(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when ``denominator`` is 0.

    A task time with no SWRs (or with no SWR above threshold) would otherwise
    raise ZeroDivisionError and abort the whole multi-session run.
    """
    return numerator / denominator if denominator else np.nan


n_sessions = len(infos)
all_likelihoods_true = {task_time: {trajectory: [] for trajectory in maze_segments} for task_time in task_times}
all_likelihoods_shuff = {task_time: {trajectory: [] for trajectory in maze_segments} for task_time in task_times}
all_likelihoods_proportion = {task_time: {trajectory: [] for trajectory in maze_segments} for task_time in task_times}
all_likelihoods_true_passthresh = {task_time: {trajectory: [] for trajectory in maze_segments} for task_time in task_times}
all_likelihoods_true_passthresh_n_swr = {task_time: 0 for task_time in task_times}
all_compareshuffle = {task_time: {trajectory: 0 for trajectory in maze_segments} for task_time in task_times}

n_all_swrs = {task_time: 0 for task_time in task_times}

for info in infos:
    print(info.session_id)

    events, position, spikes, lfp, _ = get_data(info)

    # Define zones; "other" is the complement (~) of the combined trajectory zones.
    zones = dict()
    zones["u"], zones["shortcut"], zones["novel"] = get_zones(info, position, subset=True)
    combined_zones = zones["u"] + zones["shortcut"] + zones["novel"]
    zones["other"] = ~combined_zones

    # Find SWRs for the whole session, caching them so later runs skip detection.
    swrs_path = os.path.join(pickle_filepath, info.session_id+"_swrs.pkl")
    if update_cache:
        _remove_if_exists(swrs_path)

    if os.path.exists(swrs_path):
        print("Loading pickled SWRs...")
        swrs = _load_pickle(swrs_path)
    else:
        swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=swr_thresh, z_thresh=z_thresh,
                                       power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)
        # min_involved=4: presumably the minimum number of active units per SWR — confirm in nept.
        swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=4)
        _save_pickle(swrs, swrs_path)

    rest_epochs = nept.rest_threshold(position, thresh=12., t_smooth=0.8)

    # Restrict SWRs to those during epochs of interest during rest
    phase_swrs = dict()
    n_swrs = {task_time: 0 for task_time in task_times}

    for task_time in task_times:
        epochs_of_interest = info.task_times[task_time].intersect(rest_epochs)

        phase_swrs[task_time] = epochs_of_interest.overlaps(swrs)
        # Re-apply the minimum SWR duration after restricting to the epochs of interest.
        phase_swrs[task_time] = phase_swrs[task_time][phase_swrs[task_time].durations >= min_length]

        n_swrs[task_time] += phase_swrs[task_time].n_epochs
        n_all_swrs[task_time] += phase_swrs[task_time].n_epochs

    raw_path_true = os.path.join(pickle_filepath, info.session_id+"_raw-likelihoods_true.pkl")
    sum_path_true = os.path.join(pickle_filepath, info.session_id+"_sum-likelihoods_true.pkl")
    shuffle_paths = [(os.path.join(pickle_filepath,
                                   info.session_id+"_raw-likelihoods_shuffled-%03d.pkl" % i_shuffle),
                      os.path.join(pickle_filepath,
                                   info.session_id+"_sum-likelihoods_shuffled-%03d.pkl" % i_shuffle))
                     for i_shuffle in range(n_shuffles)]

    # Invalidate every cached likelihood file (all shuffles, not just the first).
    if update_cache:
        _remove_if_exists(raw_path_true, sum_path_true,
                          *[path for path_pair in shuffle_paths for path in path_pair])

    combined_likelihoods_shuff = {task_time: {trajectory: [] for trajectory in maze_segments} for task_time in task_times}
    raw_likelihoods_shuffs = {task_time: [] for task_time in task_times}

    # The cache is used only when the true likelihoods AND every shuffle are on
    # disk; if anything is missing, everything is recomputed by save_likelihoods.
    compute_likelihoods = not (os.path.exists(raw_path_true) and os.path.exists(sum_path_true))
    if not compute_likelihoods:
        print("Loading pickled true likelihoods...")
        raw_likelihoods_true = _load_pickle(raw_path_true)
        session_likelihoods_true = _load_pickle(sum_path_true)

        for raw_path_shuff, sum_path_shuff in shuffle_paths:
            if not (os.path.exists(raw_path_shuff) and os.path.exists(sum_path_shuff)):
                compute_likelihoods = True
                break
            print("Loading pickled shuffled likelihoods...")
            raw_likelihoods_shuff = _load_pickle(raw_path_shuff)
            session_likelihoods_shuff = _load_pickle(sum_path_shuff)

            for task_time in task_times:
                raw_likelihoods_shuffs[task_time].append(raw_likelihoods_shuff[task_time])
                for trajectory in maze_segments:
                    combined_likelihoods_shuff[task_time][trajectory].append(
                        np.array(session_likelihoods_shuff[task_time][trajectory]))

    if compute_likelihoods:
        session_likelihoods_true, raw_likelihoods_true, combined_likelihoods_shuff, raw_likelihoods_shuffs = save_likelihoods(
            info, position, spikes, phase_swrs, zones, task_times, maze_segments, n_shuffles)

    compareshuffle = {task_time: {trajectory: 0 for trajectory in maze_segments} for task_time in task_times}
    percentiles = {task_time: {trajectory: [] for trajectory in maze_segments} for task_time in task_times}
    passedshuffthresh = {task_time: {trajectory: [] for trajectory in maze_segments} for task_time in task_times}

    # Indices of SWRs that beat the shuffle threshold for at least one trajectory.
    keep_idx = {task_time: [] for task_time in task_times}

    for task_time in task_times:
        # Swap the first two axes (shuffle-major -> SWR-major) so that indexing by
        # SWR, as the summary plots below do, yields that SWR's shuffles.
        raw_likelihoods_shuffs[task_time] = np.swapaxes(raw_likelihoods_shuffs[task_time], 0, 1)
        for trajectory in maze_segments:
            true_sums = session_likelihoods_true[task_time][trajectory]
            # Presumably (n_shuffles, n_swrs): column idx is SWR idx's null distribution.
            shuff_sums = np.array(combined_likelihoods_shuff[task_time][trajectory])
            for idx in range(len(true_sums)):
                # percentileofscore does not require sorted input.
                percentile = scipy.stats.percentileofscore(shuff_sums[:, idx], true_sums[idx])
                percentiles[task_time][trajectory].append(percentile)
                if percentile >= percentile_thresh:
                    compareshuffle[task_time][trajectory] += 1
                    all_compareshuffle[task_time][trajectory] += 1
                    keep_idx[task_time].append(idx)

    morelikelythanshuffle_proportion = {task_time: {trajectory: [] for trajectory in maze_segments} for task_time in task_times}
    mean_combined_likelihoods_shuff = {task_time: {trajectory: [] for trajectory in maze_segments} for task_time in task_times}
    passedshuffthresh_n_swr = {task_time: 0 for task_time in task_times}

    for task_time in task_times:
        passed_idx = np.unique(keep_idx[task_time])  # sorted, de-duplicated SWR indices
        passedshuffthresh_n_swr[task_time] += len(passed_idx)
        all_likelihoods_true_passthresh_n_swr[task_time] += len(passed_idx)
        for trajectory in maze_segments:
            if len(passed_idx) > 0:
                passedshuffthresh[task_time][trajectory].append(
                    np.array(session_likelihoods_true[task_time][trajectory])[passed_idx])

            morelikelythanshuffle_proportion[task_time][trajectory].append(
                _fraction(compareshuffle[task_time][trajectory],
                          len(session_likelihoods_true[task_time][trajectory])))
            mean_combined_likelihoods_shuff[task_time][trajectory] = np.nanmean(
                combined_likelihoods_shuff[task_time][trajectory], axis=0)

            all_likelihoods_true[task_time][trajectory].extend(session_likelihoods_true[task_time][trajectory])
            all_likelihoods_true_passthresh[task_time][trajectory].append(passedshuffthresh[task_time][trajectory])
            all_likelihoods_shuff[task_time][trajectory].extend(mean_combined_likelihoods_shuff[task_time][trajectory])
            all_likelihoods_proportion[task_time][trajectory].extend(morelikelythanshuffle_proportion[task_time][trajectory])

            # plot sorted per-SWR percentiles against the threshold
            fig, ax = plt.subplots()
            ax.bar(np.arange(len(percentiles[task_time][trajectory])),
                   np.sort(percentiles[task_time][trajectory]), color=colours[trajectory])
            ax.axhline(percentile_thresh, ls="--", lw=1.5, color="k")
            title = info.session_id + " individual SWR percentile with shuffle" + str(n_shuffles) + " for " + task_time + " " + trajectory
            ax.set_title(title, fontsize=11)
            fig.tight_layout()
            fig.savefig(os.path.join(percentiles_dir, title))
            plt.close(fig)

        filepath = os.path.join(output_filepath, info.session_id+"-average-likelihood-overspace_"+task_time+".png")
        # Check the per-SWR likelihoods: session_likelihoods_true[task_time] is
        # keyed by trajectory, so its length never reflects the number of SWRs.
        if len(raw_likelihoods_true[task_time]) > 0:
            plot_likelihood_overspace(info, position, raw_likelihoods_true[task_time],
                                      zones, colours, filepath)

    filename = info.session_id + " proportion of SWRs above "+str(percentile_thresh)+" percentile"
    plot_combined(morelikelythanshuffle_proportion, passedshuffthresh_n_swr,
                  task_times, maze_segments, n_sessions=1, colours=colours, filename=filename)

    filename = info.session_id + " average posteriors during SWRs_sum-shuffled"+str(n_shuffles)
    plot_combined(mean_combined_likelihoods_shuff, n_swrs, task_times, maze_segments,
                  n_sessions=1, colours=colours, filename=filename)

    filename = info.session_id + " average posteriors during SWRs_sum-true"
    plot_combined(session_likelihoods_true, n_swrs, task_times, maze_segments,
                  n_sessions=1, colours=colours, filename=filename)

    # Label with the number of SWRs that passed threshold (matches the all-session plot).
    filename = info.session_id + " average posteriors during SWRs_sum-true_passthresh"
    plot_combined(passedshuffthresh, passedshuffthresh_n_swr, task_times, maze_segments,
                  n_sessions=1, colours=colours, filename=filename)

    for task_time in task_times:
        for idx in range(phase_swrs[task_time].n_epochs):
            filename = info.session_id + "_" + task_time + "_summary-swr" + str(idx) + ".png"
            filepath = os.path.join(swr_plots_dir, filename)
            plot_summary_individual(info, raw_likelihoods_true[task_time][idx],
                                    raw_likelihoods_shuffs[task_time][idx],
                                    position, lfp, spikes,
                                    phase_swrs[task_time].starts[idx],
                                    phase_swrs[task_time].stops[idx],
                                    zones, maze_segments, colours, filepath, savefig=True)

# Total number of (SWR, trajectory) hits above threshold per task time, across sessions.
n_total = {task_time: sum(all_compareshuffle[task_time].values()) for task_time in task_times}

# Share of the above-threshold hits belonging to each trajectory (NaN if none passed).
all_compareshuffles = {task_time: {trajectory: [_fraction(all_compareshuffle[task_time][trajectory],
                                                          n_total[task_time])]
                                   for trajectory in maze_segments}
                       for task_time in task_times}

filename = "Average posteriors during SWRs_sum-shuffled"+str(n_shuffles)
plot_combined(all_likelihoods_shuff, n_all_swrs, task_times, maze_segments,
              n_sessions=n_sessions, colours=colours, filename=filename)

filename = "Average posteriors during SWRs_sum-true"
plot_combined(all_likelihoods_true, n_all_swrs, task_times, maze_segments,
              n_sessions=n_sessions, colours=colours, filename=filename)

filename = "Average posteriors during SWRs_sum-true_passthresh"
plot_combined(all_likelihoods_true_passthresh, all_likelihoods_true_passthresh_n_swr,
              task_times, maze_segments, n_sessions=n_sessions, colours=colours, filename=filename)

filename = "Average posteriors during SWRs_sum-true_passthresh-overallproportion"
plot_combined(all_compareshuffles, all_likelihoods_true_passthresh_n_swr,
              task_times, maze_segments, n_sessions=n_sessions, colours=colours, filename=filename)

filename = "Average posteriors during SWRs_sum-stacked-shuffled"+str(n_shuffles)
plot_stacked_summary(all_likelihoods_shuff, n_all_swrs, task_times, maze_segments,
                     n_sessions=n_sessions, colours=colours, filename=filename)

filename = "Average posteriors during SWRs_sum-stacked-true"
plot_stacked_summary(all_likelihoods_true, n_all_swrs, task_times, maze_segments,
                     n_sessions=n_sessions, colours=colours, filename=filename)

# Label with the number of SWRs that passed threshold, consistent with the unstacked plot.
filename = "Average posteriors during SWRs_sum-stacked-true_passthresh"
plot_stacked_summary(all_likelihoods_true_passthresh, all_likelihoods_true_passthresh_n_swr, task_times,
                     maze_segments, n_sessions=n_sessions, colours=colours, filename=filename)

filename = "Proportion of SWRs above the "+str(percentile_thresh)+" percentile (shuffle" + str(n_shuffles) + ")"
plot_combined(all_likelihoods_proportion, n_all_swrs, task_times, maze_segments,
              n_sessions=n_sessions, colours=colours, filename=filename)


In [None]:
# Rebuild the cached true-likelihood paths for `info`, i.e. whichever session
# the analysis loop above processed last (or stopped on, if it raised).
# NOTE(review): relies on the loop variable `info` leaking out of that cell.
session_prefix = os.path.join(pickle_filepath, info.session_id)
raw_path_true = session_prefix + "_raw-likelihoods_true.pkl"
sum_path_true = session_prefix + "_sum-likelihoods_true.pkl"

In [None]:
# True only when both cached true-likelihood pickles are present on disk.
all(os.path.exists(cache_path) for cache_path in (raw_path_true, sum_path_true))