In [None]:
%matplotlib inline
import copy
import os
import pickle

import matplotlib.pyplot as plt
import nept
import numpy as np
import scipy
import scipy.stats

from loading_data import get_data
from analyze_sequenceless import Session, TaskTime

In [None]:
# Cache and output locations, resolved relative to the notebook's working directory.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "decoding", "sequenceless", "stats")
# exist_ok=True is idempotent on re-run and avoids the race between an
# exists() check and makedirs().
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# Per-session metadata modules from the project-local `info` package.
import info.r066d1 as r066d1
import info.r066d2 as r066d2
# Sessions processed by the multi-session loop later in the notebook.
infos = [r066d1, r066d2]
# Single session used by the exploratory cells that follow. That later loop
# (`for info in infos:`) rebinds `info` to each entry of `infos`.
info = r066d1
# Alternative session list kept for reference:
# from run import analysis_infos
# infos = analysis_infos

In [None]:
# for info in infos:
print(info.session_id)
events, position, spikes, lfp, _ = get_data(info)

In [None]:
# Load the cached pass-threshold likelihoods for the selected session.
# NOTE: pickle.load can execute arbitrary code; only load cache files that
# this project wrote itself.
passthresh_path = os.path.join(pickle_filepath, info.session_id + "_likelihoods_true_passthresh.pkl")

if os.path.exists(passthresh_path):
    print("Loading pickled passthresh likelihoods...")
    with open(passthresh_path, 'rb') as fileobj:
        passthresh_session = pickle.load(fileobj)
else:
    # Fail here with a clear message rather than with a NameError on
    # `passthresh_session` in a later cell.
    raise FileNotFoundError("No cached passthresh likelihoods at %s" % passthresh_path)

In [None]:
# Number of epochs in the pauseA phase's `swrs` for the cached pass-threshold session.
passthresh_session.pauseA.swrs.n_epochs

In [None]:
# Per-phase SWR epoch counts and per-zone likelihood sums for a subset of
# phases and zones of the cached pass-threshold session.
session = passthresh_session
task_labels = ["prerecord", "pauseA"]
zone_labels = ["u", "shortcut"]

# The SWR count depends only on the phase, so take it once per phase.
# Accumulating it inside the zone loop would multiply it by len(zone_labels).
n_swrs = {task_label: getattr(session, task_label).swrs.n_epochs for task_label in task_labels}

for zone_label in zone_labels:
    for task_label in task_labels:
        zone_sums = getattr(session, task_label).sums(zone_label)
        # Print every (zone, phase) result, not just the last phase per zone.
        print(zone_label, task_label)
        print(zone_sums)

In [None]:
# SWR epoch count per phase, as built in the previous cell.
n_swrs

In [None]:
# Mean of pauseA's "shortcut" sums (presumably one value per SWR epoch — confirm).
# np.mean propagates NaN, so any NaN sum makes this NaN.
np.mean(passthresh_session.pauseA.sums("shortcut"))

In [None]:
# Toy 2 x 3 table of counts used to sanity-check the chi-square calculations below.
observed = np.array(
    [
        [100, 150, 200],
        [50, 100, 150],
    ]
)
print(observed)

In [None]:
# Null model: every cell expects the grand mean count.
expected_equal = np.full(observed.shape, observed.mean())
expected_equal

In [None]:
# Null model: every cell expects the mean of its own column (column totals preserved).
expected_bypath = np.tile(observed.mean(axis=0), (observed.shape[0], 1))
expected_bypath

In [None]:
# Null model: every cell expects the mean of its own row (row totals preserved).
# Repeat along the column axis using the table's real width instead of a
# hard-coded 3, so the cell works for any number of columns.
expected_byphase = np.repeat(observed.mean(axis=1, keepdims=True), observed.shape[1], axis=1)
expected_byphase

In [None]:
# Goodness of fit against the equal-counts null, treating all cells as one data set.
chisq, p = scipy.stats.chisquare(observed.ravel(), f_exp=expected_equal.ravel())
print(chisq, p, p < 0.05)

In [None]:
# Same statistic computed by hand, as a cross-check of scipy's value.
chi_squared_stat = np.sum(np.square(observed - expected_equal) / expected_equal)
chi_squared_stat

In [None]:
# Goodness of fit against the column-mean null. Those expected counts are
# fitted to every column total, which costs (n_columns - 1) degrees of
# freedom beyond the grand total. Hence ddof = n_columns - 1, giving
# df = n_cells - n_columns.
chisq, p = scipy.stats.chisquare(observed, f_exp=expected_bypath, axis=None,
                                 ddof=observed.shape[1] - 1)
print(chisq, p, p<0.05)

In [None]:
# Goodness of fit against the row-mean null. Those expected counts are fitted
# to every row total, which costs (n_rows - 1) degrees of freedom beyond the
# grand total. Hence ddof = n_rows - 1, giving df = n_cells - n_rows.
chisq, p = scipy.stats.chisquare(observed, f_exp=expected_byphase, axis=None,
                                 ddof=observed.shape[0] - 1)
print(chisq, p, p<0.05)

In [None]:
# p-value for the hand-computed equal-counts statistic. It must use the same
# degrees of freedom scipy.stats.chisquare uses for that test
# (k - 1 = observed.size - 1), not 1. chi2.sf is also more accurate than
# 1 - cdf for small p-values.
p_value = scipy.stats.chi2.sf(x=chi_squared_stat, df=observed.size - 1)
p_value

In [None]:
import copy

def limit_by_n_swr(session, task_labels, n_swr_thresh, zone_label="u"):
    """Return a deep copy of `session` with under-sampled phases blanked out.

    For each phase name in `task_labels` whose ``swrs.n_epochs`` is below
    `n_swr_thresh`, the copy's ``likelihoods`` are replaced by an all-NaN
    array of shape ``(1, 1, zones[zone_label].shape[0], zones[zone_label].shape[1])``.
    The input `session` is left untouched.

    Parameters
    ----------
    session : object
        Has one attribute per phase name. Each phase exposes ``swrs.n_epochs``,
        ``zones`` (a mapping of zone name to array) and ``likelihoods``.
    task_labels : iterable of str
        Phase attribute names to check.
    n_swr_thresh : int
        Minimum number of SWR epochs a phase needs to be kept.
    zone_label : str
        Zone whose array shape sets the trailing dims of the NaN placeholder.

    Returns
    -------
    object
        The limited deep copy.
    """
    limited = copy.deepcopy(session)
    for label in task_labels:
        task = getattr(limited, label)
        if task.swrs.n_epochs >= n_swr_thresh:
            continue
        dims = task.zones[zone_label].shape
        task.likelihoods = np.full((1, 1, dims[0], dims[1]), np.nan)
    return limited

update_cache = False   # True: delete cached likelihood pickles so they are recomputed
n_swr_thresh = 10      # phases with fewer SWR epochs are blanked by limit_by_n_swr
n_shuffles = 100       # shuffle count; also part of the shuffled-cache filename

In [None]:
# --- Run options ------------------------------------------------------------
dont_save_pickle = False            # True: compute with save_path=None (no cache written)
plot_individual = False
plot_individual_passthresh = False
plot_overspace = False
plot_summary = True

# A true sum must reach this percentile of its shuffle distribution to count.
percentile_thresh = 95

# Plot colour for each zone.
colours = {
    "u": "#2b8cbe",
    "shortcut": "#31a354",
    "novel": "#d95f0e",
    "other": "#bdbdbd",
}

# swr params
swr_params = {
    "merge_thresh": 0.02,
    "min_length": 0.05,
    "swr_thresh": (140.0, 250.0),
    "min_involved": 4,
}

task_labels = ["prerecord", "pauseA", "pauseB", "postrecord"]
zone_labels = ["u", "shortcut", "novel", "other"]

# Per-session accumulators filled by the loop in the next cell.
true_sessions = []
shuffled_sessions = []
passthresh_sessions = []
passthresh_counts = []
# Pass-threshold counts summed over sessions: {task_label: {zone_label: count}}.
combined_passthresh_count = {
    task_label: dict.fromkeys(zone_labels, 0) for task_label in task_labels
}

for info in infos:
    print(info.session_id)

    # NOTE(review): get_likelihoods, plot_summary_individual and
    # plot_likelihood_overspace are not imported anywhere in this notebook;
    # they must already exist in the kernel or come from a project module — confirm.

    # --- True likelihoods: reuse the pickle cache unless update_cache is set ---
    true_path = os.path.join(pickle_filepath, info.session_id + "_likelihoods_true.pkl")

    if update_cache and os.path.exists(true_path):
        os.remove(true_path)

    if os.path.exists(true_path):
        print("Loading pickled true likelihoods...")
        with open(true_path, 'rb') as fileobj:
            true_session = pickle.load(fileobj)
    else:
        if dont_save_pickle:
            true_path = None
        # NOTE(review): n_shuffles is passed positionally for the *true* data
        # too; confirm this argument is intended for the unshuffled case.
        true_session = get_likelihoods(info,
                                       swr_params,
                                       task_labels,
                                       n_shuffles,
                                       save_path=true_path)

    # --- Shuffled likelihoods: same caching scheme, keyed by n_shuffles ---
    shuffled_path = os.path.join(pickle_filepath,
                                 info.session_id + "_likelihoods_shuffled-%03d.pkl" % n_shuffles)

    if update_cache and os.path.exists(shuffled_path):
        os.remove(shuffled_path)

    if os.path.exists(shuffled_path):
        print("Loading pickled shuffled likelihoods...")
        with open(shuffled_path, 'rb') as fileobj:
            shuffled_session = pickle.load(fileobj)
    else:
        if dont_save_pickle:
            shuffled_path = None
        shuffled_session = get_likelihoods(info,
                                           swr_params,
                                           task_labels,
                                           n_shuffles=n_shuffles,
                                           save_path=shuffled_path)

    # Blank out phases with fewer than n_swr_thresh SWR epochs. Each session is
    # limited exactly once (no re-copying of earlier sessions), and
    # shuffled_sessions stores limited copies of the *shuffled* sessions.
    true_limited = limit_by_n_swr(true_session, task_labels, n_swr_thresh)
    true_sessions.append(true_limited)
    shuffled_sessions.append(limit_by_n_swr(shuffled_session, task_labels, n_swr_thresh))

    if plot_individual:
        filepath = os.path.join(output_filepath, "individual")
        os.makedirs(filepath, exist_ok=True)
        plot_summary_individual(info, true_session, shuffled_session,
                                zone_labels, task_labels, colours, filepath)

    if plot_overspace:
        filepath = os.path.join(output_filepath, "overspace")
        os.makedirs(filepath, exist_ok=True)
        plot_likelihood_overspace(info, true_session, task_labels, colours, filepath)

    # --- Count epochs whose true zone sum reaches percentile_thresh of its shuffles ---
    keep_idx = {task_label: [] for task_label in task_labels}
    passthresh_count = {task_label: {zone_label: 0 for zone_label in zone_labels} for task_label in task_labels}

    for task_label in task_labels:
        # Use the limited true session so phases blanked by limit_by_n_swr are
        # caught by the all-NaN check below and skipped.
        true_task = getattr(true_limited, task_label)
        shuffled_task = getattr(shuffled_session, task_label)
        for zone_label in zone_labels:
            true_sums = np.array(true_task.sums(zone_label))
            if true_sums.size <= 1 and np.isnan(true_sums).all():
                continue
            if true_task.swrs.n_epochs == 0:
                continue
            shuffled_sums = np.array(shuffled_task.sums(zone_label))
            # presumably true_sums is (1, n_epochs) and shuffled_sums is
            # (n_shuffles, n_epochs), one column per epoch — confirm.
            for idx in range(true_sums.shape[1]):
                # percentileofscore does not require sorted input.
                percentile = scipy.stats.percentileofscore(shuffled_sums[:, idx], true_sums[0, idx])
                if percentile >= percentile_thresh:
                    keep_idx[task_label].append(idx)
                    passthresh_count[task_label][zone_label] += 1
                    combined_passthresh_count[task_label][zone_label] += 1

    passthresh_counts.append(passthresh_count)

In [None]:
# Pass-threshold counts summed over all sessions: {task_label: {zone_label: count}}.
combined_passthresh_count

In [None]:
# One {task_label: {zone_label: count}} dict per session, in `infos` order.
passthresh_counts

In [None]:
# Pass-threshold counts per phase (in task_labels order) for each zone.
# NOTE(review): only the first session (passthresh_counts[0]) is used here.
# combined_passthresh_count holds the totals over all sessions — confirm which
# one the chi-square tests below are meant to use.
session_counts = passthresh_counts[0]
shortcut_n = [session_counts[task_label]["shortcut"] for task_label in task_labels]
u_n = [session_counts[task_label]["u"] for task_label in task_labels]
# novel_n and other_n were declared but never filled; populate them the same way.
novel_n = [session_counts[task_label]["novel"] for task_label in task_labels]
other_n = [session_counts[task_label]["other"] for task_label in task_labels]

In [None]:
# 2 x len(task_labels) table: row 0 = shortcut counts, row 1 = u counts.
np.array([shortcut_n, u_n])

In [None]:
# Chi-square goodness-of-fit tests on the pass-threshold counts
# (rows: shortcut, u; columns: phases in task_labels order).
# NOTE(review): a column or row of all zeros gives a zero expected count and a
# 0/0 = nan cell, which makes the statistic nan. Consider dropping such cells
# (and adjusting ddof) if that occurs.
observed = np.array([shortcut_n, u_n])
print(observed)


def _chisquare_report(observed, expected, ddof):
    """Print `expected`, then the chi-square statistic, p-value and p < 0.05.

    All cells are pooled (axis=None). `ddof` must equal the number of
    parameters fitted to build `expected` beyond the grand total, so the
    p-value uses df = n_cells - 1 - ddof.

    Returns
    -------
    (float, float)
        The chi-square statistic and p-value.
    """
    print(expected)
    chisq, p = scipy.stats.chisquare(observed, f_exp=expected, axis=None, ddof=ddof)
    print(chisq, p, p<0.05)
    return chisq, p


n_rows, n_cols = observed.shape

# Equal-counts null: only the grand total is fitted.
expected_equal = np.ones(observed.shape) * np.mean(observed)
chisq, p = _chisquare_report(observed, expected_equal, ddof=0)

# Column-mean null: every column total is fitted (n_cols - 1 extra parameters).
expected_bypath = np.ones(observed.shape) * np.mean(observed, axis=0)
chisq, p = _chisquare_report(observed, expected_bypath, ddof=n_cols - 1)

# Row-mean null: every row total is fitted (n_rows - 1 extra parameters).
expected_byphase = np.repeat(np.mean(observed, axis=1, keepdims=True), n_cols, axis=1)
chisq, p = _chisquare_report(observed, expected_byphase, ddof=n_rows - 1)

In [None]:
print(observed)
# Column-mean null again: each cell expects the mean of its column.
expected = np.tile(observed.mean(axis=0), (observed.shape[0], 1))
print(expected)

In [None]:
# Hand-computed statistic. nansum skips nan cells (0/0, where a column's
# observed and expected counts are both zero) instead of returning nan.
chi_squared_stat = np.nansum(np.square(observed - expected) / expected)
chi_squared_stat

In [None]:
# Per-cell contributions to the chi-square statistic.
np.square(observed - expected) / expected

In [None]:
# Expected counts under the column-mean null, as computed in the cell that defines `expected`.
expected