In [None]:
zone_labels = ["u", "shortcut", "novel", "other"]

# Scratch container: a fresh, independent empty list for every zone label.
a = dict((label, []) for label in zone_labels)

In [None]:
a.update(u=[1.0, 2.0, 3.0])  # fill the "u" entry of the scratch dict

In [None]:
# Display the scratch dict `a`.
a

In [None]:
# Print each key of the scratch dict followed by its value, one per line.
for zone in a:
    print("{}\n{}".format(zone, a[zone]))

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import warnings
import random
import scipy
import pickle
import os
import nept

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_maze import get_zones
from analyze_decode_swrs import (plot_summary_individual,
                                 plot_likelihood_overspace,
                                 plot_combined,
                                 plot_stacked_summary)

thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "trials", "decoding", "shuffled")
# exist_ok avoids the check-then-create race of a separate os.path.exists() test
os.makedirs(output_filepath, exist_ok=True)

# Set random seeds so any random draws (e.g. shuffles) are reproducible across runs
random.seed(0)
np.random.seed(0)

In [None]:
class Session:
    """A collection of TaskTime objects, one per task time of a session.

        Parameters
        ----------
        task_labels : iterable of str
            Names of the task times. Each becomes an attribute of the session
            holding a TaskTime.
        zones : dict
            Maps each zone label to its boolean spatial mask; forwarded to
            every TaskTime.

    """

    def __init__(self, task_labels, zones):
        for task_label in task_labels:
            # Start each task time with an empty placeholder; likelihoods are
            # assigned once decoding has been run.
            setattr(self, task_label, TaskTime([], zones))


class TaskTime(object):
    """A set of decoded likelihoods for a given task time

        Parameters
        ----------
        likelihoods : np.array
            With shape (ntimebins, nxbins, nybins). An empty list may be passed
            as a placeholder and replaced later by assignment.
        zones : dict
            Maps each zone label to its boolean spatial mask. A Zone is created
            for every label and exposed as an attribute with that name.

        Attributes
        ----------
        likelihoods : np.array
            With shape (ntimebins, nxbins, nybins). Assigning a new value also
            points every Zone attribute at it, so zone summaries never reduce
            stale likelihoods.

    """

    def __init__(self, likelihoods, zones):
        # Remember which attributes are Zones so the setter can update them.
        self._zone_labels = list(zones)
        # Set the backing field directly: the property setter would try to
        # update Zone attributes that do not exist yet.
        self._likelihoods = likelihoods
        for label in self._zone_labels:
            setattr(self, label, Zone(label, zones[label], likelihoods))

    @property
    def likelihoods(self):
        return self._likelihoods

    @likelihoods.setter
    def likelihoods(self, value):
        # Each Zone holds its own reference to the likelihoods it was built
        # with, so rebinding here must be propagated or the zones go stale.
        self._likelihoods = value
        for label in self._zone_labels:
            getattr(self, label).likelihoods = value


class Zone:
    """Summary of decoded likelihoods for a given physical zone

        Parameters
        ----------
        label : str
        zone : np.array of bool
            Mask over the spatial bins of each likelihood map.
        likelihoods : np.array or sequence of np.array
            With shape (ntimebins, nxbins, nybins); an empty sequence means
            nothing has been decoded yet.

        Attributes
        ----------
        label : str
        zone : np.array of bool
        likelihoods : np.array or sequence of np.array

        Methods ``sum()``, ``mean()`` and ``max()`` reduce the likelihoods
        inside the zone mask separately for each time bin (NaNs ignored), or
        return ``np.nan`` when there are no likelihoods.

    """

    def __init__(self, label, zone, likelihoods):
        self.label = label
        self.zone = zone
        self.likelihoods = likelihoods

    def _reduce(self, func):
        """Apply NaN-aware reduction `func` over the zone's bins, per time bin."""
        if len(self.likelihoods) == 0:
            # Nothing decoded for this task time yet.
            return np.nan
        # asarray lets a list of 2D maps be indexed like a 3D array.
        likelihoods = np.asarray(self.likelihoods)
        return func(likelihoods[:, self.zone], axis=1)

    def sum(self):
        """Per-time-bin sum of the likelihood inside the zone."""
        return self._reduce(np.nansum)

    def mean(self):
        """Per-time-bin mean of the likelihood inside the zone."""
        return self._reduce(np.nanmean)

    def max(self):
        """Per-time-bin maximum of the likelihood inside the zone."""
        return self._reduce(np.nanmax)

In [None]:
import info.r068d8 as r068d8
# Bind the generic name `info` to this session's metadata module so the rest of
# the notebook does not hard-code the session.
info = r068d8

events, position, spikes, lfp, _ = get_data(info)

# Define zones
# `zones` maps zone label -> spatial mask; "other" is everything outside the
# u, shortcut and novel zones.
# NOTE(review): `+` / `~` act as logical OR / NOT only if get_zones returns
# boolean numpy arrays — confirm.
zones = dict()
zones["u"], zones["shortcut"], zones["novel"] = get_zones(info, position, subset=True)
combined_zones = zones["u"] + zones["shortcut"] + zones["novel"]
zones["other"] = ~combined_zones

# Tuning curves are computed with the phase3 task time passed in (presumably
# only phase3 data are used — confirm against get_only_tuning_curves).
tuning_curves_fromdata = get_only_tuning_curves(info,
                                           position,
                                           spikes,
                                           info.task_times["phase3"])

In [None]:
task_labels = "prerecord pauseA pauseB postrecord".split()
zone_labels = "u shortcut novel other".split()

In [None]:
# swr params
z_thresh = 2.0
power_thresh = 3.0
merge_thresh = 0.02
min_length = 0.05
swr_thresh = (140.0, 250.0)
min_involved = 4  # passed to nept.find_multi_in_epochs

# rest params (passed to nept.rest_threshold)
rest_thresh = 12.
rest_t_smooth = 0.8

# Minimum duration (presumably seconds) of an SWR kept after restricting to rest epochs
min_swr_duration = 0.05

swrs = nept.detect_swr_hilbert(lfp, fs=info.fs, thresh=swr_thresh, z_thresh=z_thresh,
                               power_thresh=power_thresh, merge_thresh=merge_thresh,
                               min_length=min_length)
swrs = nept.find_multi_in_epochs(spikes, swrs, min_involved=min_involved)

rest_epochs = nept.rest_threshold(position, thresh=rest_thresh, t_smooth=rest_t_smooth)

# Restrict SWRs to those during epochs of interest during rest
phase_swrs = dict()

for task_time in task_labels:
    epochs_of_interest = info.task_times[task_time].intersect(rest_epochs)

    phase_swrs[task_time] = epochs_of_interest.overlaps(swrs)
    phase_swrs[task_time] = phase_swrs[task_time][phase_swrs[task_time].durations >= min_swr_duration]

In [None]:
from analyze_classy_decode import bin_spikes

In [None]:
tc_shape = tuning_curves_fromdata.shape
# Flatten the two spatial dimensions so each row is one tuning curve
# (assumes tc_shape is (n_neurons, nxbins, nybins) — TODO confirm).
tuning_curves = tuning_curves_fromdata.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

session = Session(task_labels, zones)

# Decode each rest SWR of each task time into one (nxbins, nybins) likelihood map.
# NOTE(review): the loop index `i` is unused.
for i, task_label in enumerate(task_labels):
    # One likelihood map per SWR epoch; every row is filled in the inner loop.
    phase_likelihoods = np.empty((phase_swrs[task_label].n_epochs, tc_shape[1], tc_shape[2]))
    for j, (start, stop) in enumerate(zip(phase_swrs[task_label].starts, phase_swrs[task_label].stops)):
        t_window = stop-start  # bin width = whole epoch (presumably a single bin per SWR)

        sliced_spikes = [spiketrain.time_slice(start, stop) for spiketrain in spikes]

        counts = bin_spikes(sliced_spikes, np.array([start, stop]), dt=t_window, window=t_window,
                            gaussian_std=0.0075, normalized=False)

        likelihood = nept.bayesian_prob(counts, tuning_curves, binsize=t_window, min_neurons=3, min_spikes=1)

        # Restore the 2D spatial layout of the flattened decoder output.
        phase_likelihoods[j] = likelihood.reshape(tc_shape[1], tc_shape[2])

    tasktime = getattr(session, task_label)
    tasktime.likelihoods = phase_likelihoods
    
    # Point each Zone at the new array too: Zone objects store their own
    # reference to the likelihoods they were constructed with.
    for zone_label in zone_labels:
        zone = getattr(tasktime, zone_label)
        zone.likelihoods = phase_likelihoods

In [None]:
# Quick check: number of decoded SWR epochs in pauseA.
len(session.pauseA.likelihoods)

In [None]:
# Per-SWR summed likelihood inside the shortcut zone during pauseA.
session.pauseA.shortcut.sum()

In [None]:
# Pull out the raw pieces to recompute the shortcut-zone sum by hand.
trythis = session.pauseA.likelihoods
trythat = session.pauseA.shortcut.zone

In [None]:
# Expected: likelihoods (n_swrs, nxbins, nybins) and mask (nxbins, nybins) — verify.
trythis.shape, trythat.shape

In [None]:
# Manual equivalent of session.pauseA.shortcut.sum(): NaN-ignoring sum over masked bins per SWR.
np.nansum(trythis[:, trythat], axis=1)