In [None]:
%matplotlib inline
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.interpolate import InterpolatedUnivariateSpline
import random
import seaborn as sns
from shapely.geometry import Point

import vdmlab as vdm

from load_data import get_pos, get_spikes, get_lfp
from field_functions import get_unique_fields, categorize_fields
from maze_functions import trajectory_fields, find_zones
from plotting_functions import plot_cooccur

import info.R063d2_info as r063d2
import info.R063d3_info as r063d3

In [None]:
# Machine-specific directories: cache for pickled intermediates, and figure output.
# Alternative locations used on another machine:
# pickle_filepath = 'C:\\Users\\Emily\\Code\\emi_shortcut\\cache\\pickled\\'
# output_filepath = 'C:\\Users\\Emily\\Code\\emi_shortcut\\plots\\'
# The cache path must bind `pickle_filepath`; assigning it to `output_filepath`
# would be overwritten by the next line and leave `pickle_filepath` undefined.
pickle_filepath = 'E:\\code\\emi_shortcut\\cache\\pickled\\'
output_filepath = 'E:\\code\\emi_shortcut\\plots\\'

In [None]:
# Per-session analysis:
#   1. Build 2D tuning curves from running-speed position samples in phase 3.
#   2. Split cells by the zone ('u', 'shortcut', 'novel') their fields fall in,
#      keeping only fields unique to one zone.
#   3. For each rest epoch, detect SWR events (vdm.detect_swr_hilbert), keep those
#      with enough active cells, and compute co-occurrence probabilities per zone.
#   4. Plot the probabilities and save the figure under `output_filepath`.
infos = [r063d3]

# Field / SWR-detection parameters.
field_thresh = 1.
power_thresh = 5.
z_thresh = 3.
merge_thresh = 0.02
min_length = 0.01
swr_band = (140.0, 250.0)  # `thresh` for vdm.detect_swr_hilbert; presumably the ripple band in Hz — confirm
binsize = 3                # tuning-curve bin width in position units after pxl_to_cm scaling (presumably cm)
min_involved = 3           # minimum number of active cells for an SWR to be kept
num_shuffles = 10000       # shuffle count passed to vdm.compute_cooccur

for info in infos:
    print(info.session_id)

    lfp = get_lfp(info.good_swr[0])
    position = get_pos(info.pos_mat, info.pxl_to_cm)
    spikes = get_spikes(info.spike_mat)

    # Keep only samples at or above the session's running-speed threshold.
    speed = position.speed(t_smooth=0.5)
    run_idx = np.squeeze(speed.data) >= info.run_threshold
    run_pos = position[run_idx]

    # Tuning curves are computed from phase 3 only.
    t_start_tc = info.task_times['phase3'].start
    t_stop_tc = info.task_times['phase3'].stop

    tc_pos = run_pos.time_slice(t_start_tc, t_stop_tc)
    tc_spikes = [spiketrain.time_slice(t_start_tc, t_stop_tc) for spiketrain in spikes]

    xedges = np.arange(tc_pos.x.min(), tc_pos.x.max() + binsize, binsize)
    yedges = np.arange(tc_pos.y.min(), tc_pos.y.max() + binsize, binsize)

    tuning_curves = vdm.tuning_curve_2d(tc_pos, tc_spikes, xedges, yedges, gaussian_sigma=0.1)

    zones = find_zones(info)
    fields_tunings = categorize_fields(tuning_curves, zones, xedges, yedges, field_thresh=field_thresh)

    # Each zone's fields come first, followed by the other two zones in cyclic order:
    # u -> (shortcut, novel), shortcut -> (novel, u), novel -> (u, shortcut).
    keys = ['u', 'shortcut', 'novel']
    unique_fields = dict()
    for i, key in enumerate(keys):
        rotated = keys[i:] + keys[:i]
        unique_fields[key] = get_unique_fields(*[fields_tunings[zone] for zone in rotated])

    # Full-session spike trains of the cells with a zone-unique field.
    field_spikes = {zone: [spikes[cell] for cell in cells] for zone, cells in unique_fields.items()}

    experiment_times = ['pauseA']
    for experiment_time in experiment_times:
        print(experiment_time)

        t_start = info.task_times[experiment_time].start
        t_stop = info.task_times[experiment_time].stop

        sliced_lfp = lfp.time_slice(t_start, t_stop)
        sliced_spikes = [spiketrain.time_slice(t_start, t_stop) for spiketrain in spikes]

        swrs = vdm.detect_swr_hilbert(sliced_lfp, fs=info.fs, thresh=swr_band, z_thresh=z_thresh,
                                      power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)

        # SWRs are detected on the sliced LFP, so (presumably) only spikes within the
        # same window can fall inside them; use the matching sliced spike trains.
        multi_swrs = vdm.find_multi_in_epochs(sliced_spikes, swrs, min_involved=min_involved)

        count_matrix = {zone: vdm.spike_counts(field_spikes[zone], multi_swrs) for zone in field_spikes}
        tetrode_mask = {zone: vdm.get_tetrode_mask(field_spikes[zone]) for zone in field_spikes}
        # NOTE(review): `num_shuffles` suggests a randomized shuffle test and no RNG seed
        # is set here, so results may differ between runs — TODO confirm / seed.
        probs = {zone: vdm.compute_cooccur(count_matrix[zone], tetrode_mask[zone], num_shuffles=num_shuffles)
                 for zone in field_spikes}

        filename = 'testing_cooccur-' + experiment_time + '.png'
        savepath = os.path.join(output_filepath, filename)
        # Pass the computed path; passing None left `savepath` unused and the figure unsaved.
        plot_cooccur(probs, savepath=savepath)