In [None]:
%matplotlib inline
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.interpolate import InterpolatedUnivariateSpline
import random
import seaborn as sns

import vdmlab as vdm

from load_data import get_pos, get_spikes, get_lfp
from analyze_plotting import plot_cooccur, plot_cooccur_combined, plot_cooccur_weighted_pauses

import sys
sys.path.append('E:\\code\\python-vdmlab\\projects\\emily_shortcut\\info')
import info.r063d2 as r063d2
import info.r063d3 as r063d3

In [None]:
# Sessions whose cached co-occurrence pickles are loaded and combined further down.
infos = [r063d2]

In [None]:
# Cache (input pickles) and figure (output) directories.
# Override per machine with the EMI_SHORTCUT_PICKLE_DIR / EMI_SHORTCUT_OUTPUT_DIR
# environment variables instead of editing this cell. For example, the second
# workstation uses 'E:\\code\\emi_shortcut\\cache\\pickled\\' and
# 'E:\\code\\emi_shortcut\\plots\\'.
pickle_filepath = os.environ.get('EMI_SHORTCUT_PICKLE_DIR',
                                 'C:\\Users\\Emily\\Code\\emi_shortcut\\cache\\pickled\\')
output_filepath = os.environ.get('EMI_SHORTCUT_OUTPUT_DIR',
                                 'C:\\Users\\Emily\\Code\\emi_shortcut\\plots\\')

In [None]:
def combine_cooccur_weighted(cooccurs):
    """Combines probabilities from multiple sessions, weighted by number of sharp-wave ripple events.

    For every trajectory and measure, each session is first reduced to the
    NaN-ignoring mean of its values. The per-session means x_i are then combined
    using the session weights w_i (``n_epochs``) into a weighted mean and a
    weighted sample standard deviation::

        mean_w = sum(w_i * x_i) / sum(w_i)
        std_w = sqrt(sum(w_i * (x_i - mean_w)**2) / (((n - 1) * sum(w_i)) / n))

    Parameters
    ----------
    cooccurs: dict
        With keys:
        probs: list of dicts, one per session
            With u, shortcut, novel as keys,
            each a dict with expected, observed, active, shuffle, zscore as keys.
        n_epochs: list of numbers
            Weight of each session (number of sharp-wave ripple events),
            aligned index-for-index with ``probs``.

    Returns
    -------
    combined_weighted_mean: dict
        Same u/shortcut/novel -> measure nesting as the input; each value is a float.
    combined_weighted_std: dict
        Same nesting; each value is the weighted sample standard deviation.
        It is 0.0 when fewer than two sessions are given, because the estimator
        is undefined there.

    Raises
    ------
    ValueError
        If ``probs`` and ``n_epochs`` have different lengths.

    Notes
    -----
    A session whose values sum to zero or less (or are all NaN) contributes a
    mean of 0.0, matching ``combine_cooccur``. NOTE(review): for zscore, a
    session with a negative total is therefore also mapped to 0.0; confirm that
    this is intended. With no sessions, or a non-positive total weight, the
    combined mean and std are 0.0.
    """
    trajectories = ('u', 'shortcut', 'novel')
    measures = ('expected', 'observed', 'active', 'shuffle', 'zscore')

    if len(cooccurs['probs']) != len(cooccurs['n_epochs']):
        raise ValueError('probs and n_epochs must have one entry per session')

    weights = np.asarray(cooccurs['n_epochs'], dtype=float)
    total_weight = np.sum(weights)
    n_sessions = len(weights)

    combined_weighted_mean = {trajectory: dict() for trajectory in trajectories}
    combined_weighted_std = {trajectory: dict() for trajectory in trajectories}

    for trajectory in trajectories:
        for measure in measures:
            session_means = []
            for probs in cooccurs['probs']:
                values = np.asarray(probs[trajectory][measure], dtype=float)
                # nansum (not sum) so a few NaN entries don't discard a session
                # whose remaining values nanmean can still average.
                if np.nansum(values) > 0:
                    session_means.append(np.nanmean(values))
                else:
                    session_means.append(0.0)
            session_means = np.asarray(session_means, dtype=float)

            if n_sessions == 0 or total_weight <= 0:
                mean, std = 0.0, 0.0
            else:
                mean = np.sum(weights * session_means) / total_weight
                if n_sessions < 2:
                    # (n - 1) in the denominator would be zero.
                    std = 0.0
                else:
                    spread = np.sum(weights * (session_means - mean) ** 2)
                    std = np.sqrt(spread / (((n_sessions - 1) * total_weight) / n_sessions))

            combined_weighted_mean[trajectory][measure] = float(mean)
            combined_weighted_std[trajectory][measure] = float(std)

    return combined_weighted_mean, combined_weighted_std


def combine_cooccur(cooccurs):
    """Pools per-session co-occurrence values into one flat list per trajectory and measure.

    Parameters
    ----------
    cooccurs: dict
        Must contain 'probs': a list with one dict per session, keyed by
        u, shortcut, novel, each a dict keyed by expected, observed, active,
        shuffle, zscore. Any 'n_epochs' entry is ignored here.

    Returns
    -------
    combined: dict
        Same trajectory -> measure nesting. Each value is a list holding the
        sessions' values concatenated in session order. A session whose values
        sum to zero or less (including a NaN sum) contributes a single 0.0
        in place of its values.

    """
    trajectories = ('u', 'shortcut', 'novel')
    measures = ('expected', 'observed', 'active', 'shuffle', 'zscore')
    combined = {trajectory: {measure: [] for measure in measures} for trajectory in trajectories}

    for trajectory, by_measure in combined.items():
        for measure, pooled in by_measure.items():
            for session_probs in cooccurs['probs']:
                values = session_probs[trajectory][measure]
                # Sessions without positive mass are represented by one placeholder zero.
                pooled.extend(values if np.sum(values) > 0 else [0.0])
    return combined

In [None]:
all_tracks_tc = False

# Accumulators in the {'probs': [...], 'n_epochs': [...]} form that
# combine_cooccur / combine_cooccur_weighted read.
# NOTE(review): cooccurs_b is never filled in the visible notebook.
cooccurs_a = dict(probs=[], n_epochs=[])
cooccurs_b = dict(probs=[], n_epochs=[])
experiment_time = 'pauseA'
print('getting co-occurrence', experiment_time)

for info in infos:
    # With all_tracks_tc set, read the '_all-tracks' variant of the cache file.
    cooccur_filename = (info.session_id + '_cooccur-' + experiment_time + '_all-tracks.pkl'
                        if all_tracks_tc else
                        info.session_id + '_cooccur-' + experiment_time + '.pkl')
    pickled_cooccur = os.path.join(pickle_filepath, cooccur_filename)
    # pickle can execute code on load; only read caches this project wrote itself.
    with open(pickled_cooccur, 'rb') as fileobj:
        cooccur = pickle.load(fileobj)

    for field_name in ('probs', 'n_epochs'):
        cooccurs_a[field_name].append(cooccur[field_name])

combined_a = combine_cooccur(cooccurs_a)
combined_weighted_a, combined_weighted_a_std = combine_cooccur_weighted(cooccurs_a)

In [None]:
# Inspect the session-weighted result and its std for the U trajectory's 'expected' measure.
combined_weighted_a['u']['expected'], combined_weighted_a_std['u']['expected']

In [None]:
# NOTE(review): scratch check with no visible downstream use; candidate for removal.
np.sqrt(1)

In [None]:
# Toy data for sanity-checking a weighted mean by hand:
# gdp values are averaged using population as the weights.
country = ['usa', 'china', 'lux']
population = [309, 1350, 0.492]
gdp = [46000, 3920, 107000]

In [None]:
# Population-weighted mean of gdp: sum(population * gdp) / sum(population).
together = [population[idx] * gdp[idx] for idx in range(len(country))]
weighted = sum(together) / sum(population)
weighted

In [None]:
def weighted_sample_std(x, w):
    """Weighted sample standard deviation of ``x`` with weights ``w``.

    Implements::

        sqrt(sum(w_i * (x_i - mean_w)**2) / (((n - 1) * sum(w_i)) / n))

    where ``mean_w = sum(w_i * x_i) / sum(w_i)``. With equal weights this reduces
    to the ordinary sample standard deviation (ddof=1).

    Parameters
    ----------
    x: array-like of numbers
    w: array-like of numbers, same length as x

    Returns
    -------
    float
        NaN when fewer than two observations are given, because (n - 1) is zero.
    """
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    n = len(x)
    if n < 2:
        return float('nan')
    total_weight = np.sum(w)
    mean_w = np.sum(w * x) / total_weight
    return float(np.sqrt(np.sum(w * (x - mean_w) ** 2) / (((n - 1) * total_weight) / n)))

In [None]:
# Toy check of the weighted mean and weighted sample standard deviation,
# using n_epoch as the per-observation weights.
observation = np.array([0.1, 0.5, 1.0])
n_epoch = np.array([1., 2., 3.])
n_samples = len(observation)
total_weight = np.sum(n_epoch)

weighted_mean = np.sum(observation * n_epoch) / total_weight
squared_deviation = (observation - weighted_mean)**2
weighted_std = np.sqrt(np.sum(n_epoch * squared_deviation) / (((n_samples - 1) * total_weight) / n_samples))

In [None]:
# Display the toy weighted mean computed from `observation` and `n_epoch`.
weighted_mean

In [None]:
# sqrt(sum(w_i**2 * (x_i - mean_w)**2)) / sum(w_i), with the toy n_epoch weights.
# NOTE(review): presumably meant as the standard error of the weighted mean;
# confirm the intended formula before using it for error bars.
weighted_error = np.sqrt(np.sum(n_epoch**2 * (observation-weighted_mean)**2 )) / total_weight
weighted_error

In [None]:
# Placeholder bar heights / error-bar sizes; the bar-plot cell uses means and err.
# NOTE(review): means2 and err2 are not used by any visible cell.
means = [1.]
err = [0.1]
means2 = [1., 2., 3.]
err2 = [0.1, 0.1, 0.1]

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True, figsize=(4.5, 2))

ind = np.arange(1)
width = 0.5
# Bar centres for the two conditions. The bars are explicitly centre-aligned and
# the ticks below reuse these exact positions, so each label sits under its bar
# on any matplotlib version. The default `align` changed from 'edge' to 'center'
# in matplotlib 2.0, which shifted the old hand-offset ticks by half a bar.
pause_a_x = ind
pause_b_x = ind + width

for ax in [ax1, ax2, ax3]:
    # NOTE(review): both conditions plot the same placeholder `means`/`err`.
    condition1 = ax.bar(pause_b_x, means, width, color='r', yerr=err, ecolor='k', align='center')
    condition2 = ax.bar(pause_a_x, means, width, color='y', yerr=err, ecolor='k', align='center')

ax1.yaxis.set_ticks_position('left')

# Panels share y, so only the leftmost keeps a visible y axis.
for ax in [ax2, ax3]:
    ax.spines['left'].set_visible(False)
    ax.tick_params(axis='y', which='both', length=0)

for ax, trajectory in zip([ax1, ax2, ax3], ['U', 'Shortcut', 'Novel']):
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xlabel(trajectory)
    # Flat 1-D tick positions, one per bar, in PauseA, PauseB order.
    ax.set_xticks(np.concatenate([pause_a_x, pause_b_x]))
    ax.set_xticklabels(['PauseA', 'PauseB'])
    ax.xaxis.set_ticks_position('bottom')

# NOTE(review): placeholder y-axis label.
ax1.set_ylabel('this is it')
plt.subplots_adjust(wspace=0.1)
plt.show()

In [None]:
# Detect sharp-wave ripples (SWRs) during pause epochs and compute spike
# co-occurrence probabilities for neurons grouped by trajectory field.
# NOTE(review): find_zones, categorize_fields and get_unique_fields are neither
# defined nor imported anywhere in this notebook. On a fresh kernel this cell
# raises NameError unless they leaked from a deleted cell; TODO import them.
infos = [r063d3]

# Analysis thresholds. field_thresh goes to categorize_fields; the others go to
# vdm.detect_swr_hilbert.
field_thresh = 1.
power_thresh = 5.
z_thresh = 3.
merge_thresh = 0.02
min_length = 0.01

for info in infos:
    print(info.session_id)

    # Load this session's LFP (first entry of info.good_swr), position and spikes.
    lfp = get_lfp(info.good_swr[0])
    position = get_pos(info.pos_mat, info.pxl_to_cm)
    spikes = get_spikes(info.spike_mat)

    # Keep only position samples at or above the speed cutoff (units as returned
    # by position.speed; TODO confirm), so tuning curves use running periods.
    speed = position.speed(t_smooth=0.5)
    run_idx = np.squeeze(speed.data) >= 0.1
    run_pos = position[run_idx]

    # Tuning curves are built from the 'phase3' task epoch only.
    t_start_tc = info.task_times['phase3'].start
    t_stop_tc = info.task_times['phase3'].stop

    tc_pos = run_pos.time_slice(t_start_tc, t_stop_tc)

    tc_spikes = [spiketrain.time_slice(t_start_tc, t_stop_tc) for spiketrain in spikes]

    # Spatial bin width in position units (presumably cm after pxl_to_cm; confirm).
    binsize = 3
    xedges = np.arange(tc_pos.x.min(), tc_pos.x.max() + binsize, binsize)
    yedges = np.arange(tc_pos.y.min(), tc_pos.y.max() + binsize, binsize)

    tuning_curves = vdm.tuning_curve_2d(tc_pos, tc_spikes, xedges, yedges, gaussian_sigma=0.1)

    zones = find_zones(info)

    fields_tunings = categorize_fields(tuning_curves, zones, xedges, yedges, field_thresh=field_thresh)

    # NOTE(review): `keys` is not used below.
    keys = ['u', 'shortcut', 'novel']
    # Each call passes the target trajectory's fields first and the other two
    # trajectories after. Presumably this keeps neurons whose fields appear only
    # on the first; confirm against get_unique_fields.
    unique_fields = dict()
    unique_fields['u'] = get_unique_fields(fields_tunings['u'],
                                           fields_tunings['shortcut'],
                                           fields_tunings['novel'])
    unique_fields['shortcut'] = get_unique_fields(fields_tunings['shortcut'],
                                                  fields_tunings['novel'],
                                                  fields_tunings['u'])
    unique_fields['novel'] = get_unique_fields(fields_tunings['novel'],
                                               fields_tunings['u'],
                                               fields_tunings['shortcut'])

    # Gather the full-session spike trains of the neurons selected per trajectory.
    # The entries of unique_fields[field] are used as indices into `spikes`.
    field_spikes = dict(u=[], shortcut=[], novel=[])
    for field in unique_fields.keys():
        for key in unique_fields[field]:
            field_spikes[field].append(spikes[key])

    experiment_times = ['pauseA']
    for experiment_time in experiment_times:
        print(experiment_time)

        t_start = info.task_times[experiment_time].start
        t_stop = info.task_times[experiment_time].stop

        sliced_lfp = lfp.time_slice(t_start, t_stop)

        # NOTE(review): sliced_spikes is never used; the calls below pass the full `spikes`.
        sliced_spikes = [spiketrain.time_slice(t_start, t_stop) for spiketrain in spikes]

        # SWRs are detected on the LFP of this epoch only. thresh=(140.0, 250.0)
        # is presumably the ripple band in Hz; confirm in vdmlab.
        swrs = vdm.detect_swr_hilbert(sliced_lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                      power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)
        
        # Presumably keeps SWR epochs in which at least 3 neurons fire
        # (assumed meaning of min_involved; confirm in vdmlab).
        multi_swrs = vdm.find_multi_in_epochs(spikes, swrs, min_involved=3)

        # Per trajectory: spike counts of its neurons within the selected SWR epochs.
        count_matrix = dict()
        for key in field_spikes:
            count_matrix[key] = vdm.spike_counts(field_spikes[key], multi_swrs)

        tetrode_mask = dict()
        for key in field_spikes:
            tetrode_mask[key] = vdm.get_tetrode_mask(field_spikes[key])

        # Co-occurrence probabilities per trajectory, with 10000 shuffles each.
        probs = dict()
        for key in field_spikes:
            probs[key] = vdm.compute_cooccur(count_matrix[key], tetrode_mask[key], num_shuffles=10000)

        # NOTE(review): savepath is built, but savepath=None is what gets passed,
        # so this path is unused here.
        filename = 'testing_cooccur-' + experiment_time + '.png'
        savepath = os.path.join(output_filepath, filename)
        plot_cooccur(probs, savepath=None)

In [None]:
# A ragged list such as [array of 3, array of 1] cannot form a regular ndarray.
# NumPy >= 1.24 raises ValueError. Older versions built an object array (with a
# deprecation warning) whose mean did not pool the values. Concatenate first to
# take the mean over all pooled values.
pooled_mean = np.mean(np.concatenate([np.array([1., 2., 1.]), np.array([2.])]))
pooled_mean

In [None]:
a = np.array([np.array([1.]), np.array([2.])])
# Flatten row by row into one list, the same way combine_cooccur extends lists with arrays.
t = [value for row in a for value in row]
print(t)