In [None]:
# %matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import numpy as np
import warnings
import random
import mpld3
import scipy
import pickle
import os
import nept
import scalebar

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_maze import get_zones
from utils_maze import get_bin_centers

# All cache/output paths are resolved relative to the notebook's working directory.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "trials", "decoding", "classy", "compare_task-times")
# exist_ok=True replaces the os.path.exists()/os.makedirs() pair, which could race
# (the directory can appear between the check and the create).
os.makedirs(output_filepath, exist_ok=True)

# Seed both the stdlib and NumPy global generators so random draws below are reproducible.
random.seed(0)
np.random.seed(0)

# mpld3.enable_notebook()

In [None]:
# Colour (hex RGB) assigned to each zone label.
colours = {
    "u": "#2b8cbe",
    "shortcut": "#31a354",
    "novel": "#d95f0e",
    "other": "#bdbdbd",
}

# Task-phase names and zone names used as keys elsewhere in the notebook.
task_labels = "prerecord pauseA pauseB postrecord".split()
zone_labels = ["u", "shortcut", "novel", "other"]

In [None]:
from analyze_classy_decode import Session, TaskTime

In [None]:
import info.r063d5 as info

In [None]:
# Load this session's data; get_data returns five objects, unpacked here in order.
(events,
 position,
 spikes,
 lfp,
 lfp_theta) = get_data(info)

In [None]:
def _load_pickled(path):
    """Return the object unpickled from the file at ``path``.

    NOTE(review): pickle.load can execute arbitrary code; only point this at cache
    files this project wrote itself.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


print(info.session_id)

# Decoded likelihoods from the true (unshuffled) data.
true_path = os.path.join(pickle_filepath, info.session_id + "_likelihoods_true.pkl")
true_session = _load_pickled(true_path)

passthresh_path = os.path.join(pickle_filepath, info.session_id + "_likelihoods_true_passthresh.pkl")
passthresh_session = _load_pickled(passthresh_path)

# The shuffle count is encoded (zero-padded to 3 digits) in the cache filename.
n_shuffles = 100
shuffled_path = os.path.join(pickle_filepath,
                             info.session_id + "_likelihoods_shuffled-%03d.pkl" % n_shuffles)
shuffled_session = _load_pickled(shuffled_path)

In [None]:
# Inspect a short window at the beginning of the prerecord phase: 2 time units after it
# starts (presumably seconds -- TODO confirm). To inspect the whole phase, set
# stop = info.task_times["prerecord"].stop instead.
prerecord_times = info.task_times["prerecord"]
start = prerecord_times.start
stop = start + 2
sliced_lfp = lfp.time_slice(start, stop)

In [None]:
# Raw LFP over the inspection window.
fig, ax = plt.subplots()
ax.plot(sliced_lfp.time, sliced_lfp.data)
plt.show()

In [None]:
# LFP snippet for every SWR epoch stored on the pickled true_session for the prerecord phase.
# The loop variables are deliberately NOT named start/stop: the original loop rebound those
# globals to the last SWR's bounds, so the later spike filter and nept.Epoch(start, stop)
# silently used that SWR instead of the inspection window.
swrs_epochs = true_session.prerecord.swrs
swr_lfps = [lfp.time_slice(swr_start, swr_stop)
            for swr_start, swr_stop in zip(swrs_epochs.starts, swrs_epochs.stops)]

In [None]:
# Overlay each SWR's LFP snippet on one set of axes.
fig, ax = plt.subplots()
for segment in swr_lfps:
    ax.plot(segment.time, segment.data)
plt.show()

In [None]:
# Pool the spike times of every spike train into one sorted array.
all_spikes = np.sort(np.concatenate([train.time for train in spikes]))

# Because all_spikes is sorted, the spikes with start <= t <= stop form one contiguous run;
# locate its ends by binary search. NOTE(review): start/stop are expected to still hold
# the inspection-window bounds.
first_idx = np.searchsorted(all_spikes, start, side="left")
end_idx = np.searchsorted(all_spikes, stop, side="right")
sliced_all_spikes = all_spikes[first_idx:end_idx].copy()

In [None]:
# Window LFP, the SWR snippets on top of it, and each spike as a dot on a constant baseline.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(sliced_lfp.time, sliced_lfp.data)
for segment in swr_lfps:
    ax.plot(segment.time, segment.data)
spike_baseline = np.full(len(sliced_all_spikes), 0.0002)
ax.plot(sliced_all_spikes, spike_baseline, ".")
plt.show()

In [None]:
from analyze_classy_decode import detect_swr_hilbert_limited_zscore

In [None]:
# Parameters for SWR detection.
# NOTE(review): swr_thresh is presumably the (low, high) ripple band in Hz -- confirm.
z_thresh = 3

swr_params = dict()
swr_params["z_thresh"] = z_thresh
swr_params["merge_thresh"] = 0.02
swr_params["min_length"] = 0.05
swr_params["swr_thresh"] = (140.0, 250.0)
swr_params["min_involved"] = 4

# Alternative (not used): analyze_classy_decode.detect_swr_hilbert_limited_zscore with the
# same parameters, computing the z-score only over the pauseB task time.

# Pass swr_thresh explicitly. Previously it was defined above but never passed, so the
# detector ran with its own default band regardless of swr_params.
swrs = nept.detect_swr_hilbert(lfp,
                               fs=info.fs,
                               thresh=swr_params["swr_thresh"],
                               z_thresh=swr_params["z_thresh"],
                               merge_thresh=swr_params["merge_thresh"],
                               min_length=swr_params["min_length"])
# Presumably keeps only candidate events with at least min_involved active units -- see nept docs.
swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=swr_params["min_involved"])

In [None]:
# Number of SWR epochs left after detection and the find_multi_in_epochs filter (shown as the cell output).
swrs.n_epochs

In [None]:
# Restrict the detected SWRs to the inspection window.
# To use the whole prerecord phase instead, set epochs_of_interest = info.task_times["prerecord"].
# NOTE(review): start/stop must still hold the window bounds when this cell runs.
window_bounds = (start, stop)
epochs_of_interest = nept.Epoch(*window_bounds)
phase_swrs = epochs_of_interest.overlaps(swrs)

In [None]:
# Number of epochs returned by epochs_of_interest.overlaps(swrs) (shown as the cell output).
phase_swrs.n_epochs

In [None]:
# LFP snippet for each SWR epoch in phase_swrs.
# The loop variables are named swr_start/swr_stop so the inspection-window globals
# start/stop are not overwritten with the last SWR's bounds (which broke re-running
# the cells that filter spikes / build an Epoch from start and stop).
swr_lfps = [lfp.time_slice(swr_start, swr_stop)
            for swr_start, swr_stop in zip(phase_swrs.starts, phase_swrs.stops)]

In [None]:
# Same overlay as before, now with only the SWRs from phase_swrs.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(sliced_lfp.time, sliced_lfp.data)
for snippet in swr_lfps:
    ax.plot(snippet.time, snippet.data)
marker_y = 0.0002 * np.ones(len(sliced_all_spikes))
ax.plot(sliced_all_spikes, marker_y, ".")
plt.show()

In [None]:
# TODO: this recomputes exactly what an earlier cell computes from phase_swrs; remove one copy.
# The loop variables are named swr_start/swr_stop so the inspection-window globals
# start/stop are not overwritten with the last SWR's bounds.
swr_lfps = [lfp.time_slice(swr_start, swr_stop)
            for swr_start, swr_stop in zip(phase_swrs.starts, phase_swrs.stops)]

In [None]:
# Histogram the window's pooled spikes, then sum each bin with its neighbours using a
# 3-bin boxcar. NOTE(review): binsize 0.025 assumes times are in seconds (25 ms) -- confirm.
binsize = 0.025
bin_edges = nept.get_edges(sliced_all_spikes, binsize)

n_bins = 3  # width of the boxcar filter, in histogram bins
square_filter = np.ones(n_bins)
spike_counts = np.histogram(sliced_all_spikes, bins=bin_edges)[0].astype(float)
# mode="same" returns max(len(spike_counts), n_bins) values, i.e. one per histogram bin.
shouldthisbesquare = np.convolve(spike_counts, square_filter, mode="same")

# Place each value at the centre of its own bin. Spreading the values evenly over the
# LFP time span would misalign them: the bins cover the spike-time range, not the LFP
# samples, and linspace spacing does not equal the bin width.
times = 0.5 * (bin_edges[:-1] + bin_edges[1:])

In [None]:
# Overlay for the inspection window: binned spike counts (black), LFP rescaled onto the
# count axis (blue), each spike at y=10 (cyan) and SWR start times at y=1 (red).
# NOTE(review): the scale factor only makes the LFP visible next to the counts -- confirm value.
scalelfpby = 20000

# Mark only the SWRs that start inside the window. Plotting every SWR in the session
# stretched the x-axis across the whole recording and hid the window being inspected.
window_start, window_stop = sliced_lfp.time[0], sliced_lfp.time[-1]
all_swr_starts = np.asarray(swrs.starts)
window_swr_starts = all_swr_starts[(all_swr_starts >= window_start) & (all_swr_starts <= window_stop)]

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(times, shouldthisbesquare, "k", label="binned spike count")
ax.plot(sliced_lfp.time, sliced_lfp.data * scalelfpby, "b", label="LFP x %d" % scalelfpby)
ax.plot(sliced_all_spikes, np.full(len(sliced_all_spikes), 10.0), "c.", label="spikes")
ax.plot(window_swr_starts, np.ones(len(window_swr_starts)), "r.", label="SWR starts")
ax.set_xlabel("time")
ax.legend(loc="upper right")
plt.show()

In [None]:
# TODO: duplicates the earlier binning cell; keep only one.
# Histogram the window's spikes (binsize 0.025, presumably seconds) and sum over a 3-bin boxcar.
n_bins = 3
square_filter = np.ones(n_bins)
bin_edges = nept.get_edges(sliced_all_spikes, 0.025)
counts_per_bin = np.histogram(sliced_all_spikes, bins=bin_edges)[0]
shouldthisbesquare = np.convolve(counts_per_bin.astype(float), square_filter, mode="same")

In [None]:
# Window LFP with every spike drawn as a dot on a constant baseline.
spike_row = 0.0002 * np.ones(len(sliced_all_spikes))
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(sliced_lfp.time, sliced_lfp.data)
ax.plot(sliced_all_spikes, spike_row, ".")
plt.show()

In [None]:
# Time stamp for each summed spike-count value: the centre of its histogram bin.
# Points spread evenly over the LFP time range would not line up with the bins,
# which are built from the spike times.
times = 0.5 * (bin_edges[:-1] + bin_edges[1:])
assert times.shape == shouldthisbesquare.shape, "bin centres and counts are misaligned"

In [None]:
# Summed spike counts with each spike drawn as a dot at y=1.
fig, ax = plt.subplots()
ax.plot(times, shouldthisbesquare)
ax.plot(sliced_all_spikes, np.ones_like(sliced_all_spikes), ".")
plt.show()