In [None]:
%matplotlib inline
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.interpolate import InterpolatedUnivariateSpline
import random
import seaborn as sns
from shapely.geometry import Point

import vdmlab as vdm

from load_data import get_pos, get_spikes, get_lfp
# from tuning_curves_functions import get_tc_1d
# from field_functions import unique_fields
from maze_functions import trajectory_fields, find_zones
from plotting_functions import plot_cooccur

import info.R063d2_info as r063d2
import info.R063d3_info as r063d3

In [None]:
# pickle_filepath = 'C:\\Users\\Emily\\Code\\emi_shortcut\\cache\\pickled\\'
# output_filepath = 'C:\\Users\\Emily\\Code\\emi_shortcut\\plots\\'
# Paths for the active machine. The cache path was previously assigned to
# output_filepath and immediately overwritten, so pickle_filepath was never set.
pickle_filepath = 'E:\\code\\emi_shortcut\\cache\\pickled\\'
output_filepath = 'E:\\code\\emi_shortcut\\plots\\'

In [None]:
# NOTE: this redefines (shadows) the trajectory_fields imported from maze_functions.
def trajectory_fields(tuning_curves, spikes, zone, xedges, yedges, field_thresh):
    """Finds track tuning curves that have firing above field_thresh.

    A neuron is assigned to every trajectory whose zones contain at least one
    bin centre where its tuning curve exceeds ``field_thresh``. A neuron can
    therefore appear under several trajectories.

    Parameters
    ----------
    tuning_curves : list of np.arrays
        One tuning curve per neuron, over the bins defined by xedges/yedges.
    spikes : list of vdmlab.SpikeTrain objects
        Same order as tuning_curves.
    zone : dict of shapely geometries
        Expected to contain 'u', 'ushort', 'unovel', 'shortcut', 'shortped',
        'novel', 'novelped' and 'pedestal'.
    xedges : np.array
        Bin edges along x used to build the tuning curves.
    yedges : np.array
        Bin edges along y used to build the tuning curves.
    field_thresh : float
        Threshold (in Hz) that determines whether the neuron has a field.

    Returns
    -------
    fields_tc : dict
        With u, shortcut, novel, pedestal as keys. Values are lists of the
        tuning curves of neurons with a field on that trajectory.
    fields_neuron : dict
        Same keys; values are the matching entries of ``spikes``.
    """
    xcenters = (xedges[1:] + xedges[:-1]) / 2.
    ycenters = (yedges[1:] + yedges[:-1]) / 2.

    # NOTE(review): assumes vdm.cartesian orders bin centres the same way
    # np.ravel flattens each tuning curve -- confirm against vdmlab.
    xy_centers = vdm.cartesian(xcenters, ycenters)

    # Each trajectory is the union of these zones.
    trajectory_zones = [
        ('u', ('u', 'ushort', 'unovel')),
        ('shortcut', ('shortcut', 'shortped')),
        ('novel', ('novel', 'novelped')),
        ('pedestal', ('pedestal',)),
    ]

    fields_tc = dict(u=[], shortcut=[], novel=[], pedestal=[])
    fields_neuron = dict(u=[], shortcut=[], novel=[], pedestal=[])
    for i, neuron_tc in enumerate(tuning_curves):
        field_idx = np.ravel(neuron_tc) > field_thresh
        field_points = [Point([pt[0], pt[1]]) for pt in xy_centers[field_idx]]
        for trajectory, zone_keys in trajectory_zones:
            # any() stops at the first field bin inside the trajectory, so each
            # neuron is appended at most once per trajectory without a seen-list.
            if any(zone[key].contains(point) for point in field_points for key in zone_keys):
                fields_tc[trajectory].append(neuron_tc)
                fields_neuron[trajectory].append(spikes[i])

    return fields_tc, fields_neuron

In [None]:
infos = [r063d3]
experiment_times = ['pauseA']

binsize = 3  # tuning-curve bin width, in position units (presumably cm after pxl_to_cm)

for info in infos:
    print(info.session_id)

    # Loading, the phase3 tuning curves, zones and fields depend only on the
    # session, not on experiment_time, so compute them once per session.
    lfp = get_lfp(info.good_swr[0])
    position = get_pos(info.pos_mat, info.pxl_to_cm)
    spikes = get_spikes(info.spike_mat)

    # Keep only position samples at or above the session's run threshold.
    speed = position.speed(t_smooth=0.5)
    run_idx = np.squeeze(speed.data) >= info.run_threshold
    run_pos = position[run_idx]

    # Tuning curves (and therefore fields) are built from phase3 running data.
    t_start_tc = info.task_times['phase3'].start
    t_stop_tc = info.task_times['phase3'].stop

    tc_pos = run_pos.time_slice(t_start_tc, t_stop_tc)
    tc_spikes = [spiketrain.time_slice(t_start_tc, t_stop_tc) for spiketrain in spikes]

    xedges = np.arange(tc_pos.x.min(), tc_pos.x.max() + binsize, binsize)
    yedges = np.arange(tc_pos.y.min(), tc_pos.y.max() + binsize, binsize)

    tuning_curves = vdm.tuning_curve_2d(tc_pos, tc_spikes, xedges, yedges, gaussian_sigma=0.1)

    zones = find_zones(info)

    fields_tc, fields_spikes = trajectory_fields(tuning_curves, tc_spikes, zones, xedges, yedges, field_thresh=1.)

    for experiment_time in experiment_times:
        print(experiment_time)

        t_start = info.task_times[experiment_time].start
        t_stop = info.task_times[experiment_time].stop

        sliced_lfp = lfp.time_slice(t_start, t_stop)

        sliced_spikes = [spiketrain.time_slice(t_start, t_stop) for spiketrain in spikes]

        # TODO: co-occurrence analysis below is currently disabled.
        # NOTE(review): fields_spikes holds phase3-sliced spike trains (tc_spikes); if
        # experiment_time does not overlap phase3, the SWR-interval counts below would be
        # all zeros -- confirm which spike trains are intended.

        # swrs = vdm.detect_swr_hilbert(sliced_lfp, fs=info.fs, thresh=(140.0, 250.0), power_thres=5)

        # swr_intervals = []
        # for swr in swrs:
        #     swr_intervals.append([swr.time[0], swr.time[-1]])
        # swr_intervals = np.array(swr_intervals).T

        # count_matrix = dict()
        # for key in fields_spikes:
        #     count_matrix[key] = vdm.spike_counts(fields_spikes[key], swr_intervals, window=0.1)

        # tetrode_mask = dict()
        # for key in fields_spikes:
        #     tetrode_mask[key] = vdm.get_tetrode_mask(fields_spikes[key])

        # probs = dict()
        # for key in fields_spikes:
        #     probs[key] = vdm.compute_cooccur(count_matrix[key], tetrode_mask[key], num_shuffles=10000)

        # filename = 'testing_cooccur-' + experiment_time + '.png'
        # savepath = os.path.join(output_filepath, filename)
        # plot_cooccur(probs, savepath)

In [None]:
# Number of neurons with a field on each trajectory (all four, labelled).
{trajectory: len(trains) for trajectory, trains in fields_spikes.items()}

In [None]:
field_thresh = 1.
zone = zones

# Exploratory variant of trajectory_fields: each above-threshold bin is credited
# to exactly ONE trajectory -- the first match in the priority order below.
xcenters = (xedges[1:] + xedges[:-1]) / 2.
ycenters = (yedges[1:] + yedges[:-1]) / 2.

xy_centers = vdm.cartesian(xcenters, ycenters)

in_u, in_shortcut, in_novel, in_pedestal = [], [], [], []

# (trajectory, zone keys forming it, neuron indices already recorded)
priority = [
    ('novel', ('novel', 'novelped'), in_novel),
    ('shortcut', ('shortcut', 'shortped'), in_shortcut),
    ('u', ('u', 'ushort', 'unovel'), in_u),
    ('pedestal', ('pedestal',), in_pedestal),
]

fields_tc = {trajectory: [] for trajectory in ('u', 'shortcut', 'novel', 'pedestal')}
fields_neuron = {trajectory: [] for trajectory in ('u', 'shortcut', 'novel', 'pedestal')}

# NOTE(review): this stores the full-session spikes[i], whereas the earlier
# trajectory_fields call was given the phase3-sliced tc_spikes -- confirm intent.
for neuron_idx, curve in enumerate(tuning_curves):
    for center in xy_centers[np.ravel(curve) > field_thresh]:
        point = Point([center[0], center[1]])
        for trajectory, zone_keys, seen in priority:
            if not any(zone[key].contains(point) for key in zone_keys):
                continue
            if neuron_idx not in seen:
                seen.append(neuron_idx)
                fields_tc[trajectory].append(curve)
                fields_neuron[trajectory].append(spikes[neuron_idx])
            break

In [None]:
tuple(len(fields_neuron[trajectory]) for trajectory in ('u', 'shortcut', 'novel', 'pedestal'))

In [None]:
from plotting_functions import plot_intersects, plot_zone

# Draw each zone with the plotter matching its geometry type; other types are skipped.
zone_plotters = {'MultiPolygon': plot_intersects, 'Polygon': plot_zone}
for zone_name, geometry in zones.items():
    plotter = zone_plotters.get(geometry.geom_type)
    if plotter is not None:
        plotter(geometry)

# NOTE(review): hard-coded marker, presumably a probe point being checked against the zones.
plt.plot(77.56975984316288, 66.92093904767174, '.', ms=20)

In [None]:
from scipy.ndimage import filters
from scipy.ndimage.morphology import generate_binary_structure, binary_erosion
def find_fields_2d(tuning_curves, neighborhood_size, threshold):
    """Mark local firing-rate peaks in each 2D tuning curve.

    A bin is marked when it equals the maximum of the ``neighborhood_size``
    window around it AND that window's (max - min) range exceeds
    ``threshold``. The second condition rejects flat neighborhoods, where
    every bin would otherwise count as a trivial maximum.

    Parameters
    ----------
    tuning_curves : list of np.arrays
        One tuning curve per neuron.
    neighborhood_size : int or sequence of int
        Window size, in bins, passed as ``size`` to the max/min filters.
    threshold : float
        Minimum (max - min) range within the window, in the same units as the
        tuning curves, for a maximum to count as a peak.

    Returns
    -------
    fields_mask : list of np.arrays of bool
        Same length and order as ``tuning_curves``; True at peak bins.
    """
    # scipy.ndimage.filters is a deprecated alias namespace; use the public
    # scipy.ndimage functions directly.
    from scipy.ndimage import maximum_filter, minimum_filter

    fields_mask = []
    for tuning_curve in tuning_curves:
        max_points = maximum_filter(tuning_curve, size=neighborhood_size)
        min_points = minimum_filter(tuning_curve, size=neighborhood_size)
        is_peak = tuning_curve == max_points
        # Drop maxima whose neighborhood is too flat to be a real field.
        is_peak &= (max_points - min_points) > threshold
        fields_mask.append(is_peak)
    return fields_mask

In [None]:
# Peak-detection settings for find_fields_2d: window size in bins, and the
# minimum max-min range in tuning-curve units (presumably Hz).
neighborhood_size = 100
threshold = 5
fields_mask = find_fields_2d(tuning_curves, neighborhood_size=neighborhood_size, threshold=threshold)

In [None]:
example_neuron = 9  # index of the neuron whose peak mask is inspected
fields_mask[example_neuron]