In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from shapely.geometry import Point, LineString
# import itertools
from scipy.interpolate import InterpolatedUnivariateSpline

from load_data import get_pos, get_spikes
from maze_functions import find_zones, trajectory_fields
from plotting_functions import plot_intersects, plot_zone
from decode_functions import get_edges

import vdmlab as vdm

In [None]:
xy = np.array([[2, 7],
               [4, 5],
               [6, 3],
               [8, 1],
               [2, 4]])
time = np.array([0., 1., 2., 3., 4.])
position = vdm.Position(xy, time)

In [None]:
plt.plot(position.x, position.y, 'b.', ms=10)
plt.xlim(0.5, 8.5)
plt.ylim(0.5, 7.5)
plt.show()

In [None]:
spikes = [vdm.SpikeTrain(np.array([3.6, 3.9])), 
          vdm.SpikeTrain(np.array([0., 0., 2.])),
          vdm.SpikeTrain(np.array([2., 2.4]))]

In [None]:
binsize = 2
xedges = np.arange(position.x.min(), position.x.max()+binsize, binsize)
yedges = np.arange(position.y.min(), position.y.max()+binsize, binsize)

tuning_curves = vdm.tuning_curve_2d(position, spikes, xedges, yedges, sampling_rate=1.)

In [None]:
plt.figure()
xx, yy = np.meshgrid(xedges, yedges)
for tuning_curve in tuning_curves:
    pp = plt.pcolormesh(xx, yy, tuning_curve, cmap='YlGn')
    plt.colorbar(pp)
    plt.axis('off')
    plt.show()

In [None]:
counts_binsize = 0.5

time_edges = get_edges(position, counts_binsize, lastbin=True)
counts = vdm.get_counts(spikes, time_edges, apply_filter=False)

In [None]:
print(time_edges)

In [None]:
print(counts)

In [None]:
# Flatten every 2D tuning curve into a 1D row and stack the rows, giving one
# row per tuning curve for the decoder.
decoding_tc = np.array([np.ravel(tuning_curve) for tuning_curve in tuning_curves])

In [None]:
shape = tuning_curves[0].shape

In [None]:
tuning_curves

In [None]:
decoding_tc

In [None]:
likelihood = vdm.bayesian_prob(counts, decoding_tc, counts_binsize)

In [None]:
likelihood

In [None]:
xcenters = (xedges[1:] + xedges[:-1]) / 2.
ycenters = (yedges[1:] + yedges[:-1]) / 2.
xy_centers = vdm.cartesian(xcenters, ycenters)

time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

In [None]:
time_centers

In [None]:
decoded = vdm.decode_location(likelihood, xy_centers, time_centers)

In [None]:
decoded.x, decoded.y, decoded.time

In [None]:
nan_idx = np.logical_and(np.isnan(decoded.x), np.isnan(decoded.y))
decoded = decoded[~nan_idx]

In [None]:
decoded.x, decoded.y, decoded.time

In [None]:
# Interpolate the true position at the decoded sample times. `decoded` had its
# NaN samples dropped in the previous step, so evaluating the splines at
# `time_centers` could yield a different number of samples than `decoded` and
# break the element-wise error computed afterwards.
x_spline = InterpolatedUnivariateSpline(position.time, position.x)
y_spline = InterpolatedUnivariateSpline(position.time, position.y)
# Two 1D arrays -> one (n_samples, 2) array of x/y columns.
actual_xy = np.column_stack((x_spline(decoded.time), y_spline(decoded.time)))
actual_position = vdm.Position(actual_xy, decoded.time)

In [None]:
actual_position.x, actual_position.y, actual_position.time

In [None]:
error = np.abs(decoded.data - actual_position.data)

In [None]:
avg_error = np.nanmean(error)
avg_error

## How does the 1D decoding work?

In [None]:
x = np.array([2, 4, 6, 8, 3])
time = np.array([0., 1., 2., 3., 4.])
position = vdm.Position(x, time)

In [None]:
spikes = [vdm.SpikeTrain(np.array([3.6, 3.9])), vdm.SpikeTrain(np.array([2.2, 2.43]))]

In [None]:
pos_binsize = 1
tuning_curves = vdm.tuning_curve(position, spikes, binsize=pos_binsize, sampling_rate=1., gaussian_std=None)

In [None]:
tuning_curves

In [None]:
plt.plot(tuning_curves[0], 'b')
plt.plot(tuning_curves[1], 'm')
plt.show()

In [None]:
counts_binsize = 0.5
time_edges = get_edges(position, counts_binsize, lastbin=True)
counts = vdm.get_counts(spikes, time_edges)

In [None]:
time_edges.shape

In [None]:
counts

In [None]:
likelihood = vdm.bayesian_prob(counts, tuning_curves, counts_binsize)

In [None]:
likelihood

In [None]:
pos_edges = vdm.binned_position(position, pos_binsize)
x_centers = (pos_edges[1:] + pos_edges[:-1]) / 2.
x_centers = x_centers[..., np.newaxis]

time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

decoded = vdm.decode_location(likelihood, x_centers, time_centers)
decoded

In [None]:
nan_idx = np.isnan(decoded.x)
decoded = decoded[~nan_idx]

In [None]:
decoded.x, decoded.time

In [None]:
spline = InterpolatedUnivariateSpline(position.time, position.x)
actual_position = vdm.Position(spline(decoded.time), decoded.time)

In [None]:
actual_position.x, actual_position.time

In [None]:
error = np.abs(decoded.x - actual_position.x)

In [None]:
decoded.x, actual_position.x

In [None]:
avg_error = np.mean(error)
avg_error

## Check velocity-based teleport removal in 1D

In [None]:
x = np.array([2, 4, 6, 8, 3])
time = np.array([0., 1., 2., 3., 4.])
position = vdm.Position(x, time)

spikes = [vdm.SpikeTrain(np.array([3.6, 3.9])), 
          vdm.SpikeTrain(np.array([2.2, 2.4])),
          vdm.SpikeTrain(np.array([0.6, 0.9])),
          vdm.SpikeTrain(np.array([1., 1.1])), 
          vdm.SpikeTrain(np.array([1.7, 1.9]))]

pos_binsize = 1
tuning_curves = vdm.tuning_curve(position, spikes, binsize=pos_binsize, sampling_rate=1., gaussian_std=None)

counts_binsize = 0.5
time_edges = get_edges(position, counts_binsize, lastbin=True)
counts = vdm.get_counts(spikes, time_edges)

likelihood = vdm.bayesian_prob(counts, tuning_curves, counts_binsize)

pos_edges = vdm.binned_position(position, pos_binsize)
x_centers = (pos_edges[1:] + pos_edges[:-1]) / 2.
x_centers = x_centers[..., np.newaxis]

time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

decoded = vdm.decode_location(likelihood, x_centers, time_centers)

nan_idx = np.isnan(decoded.x)
decoded = decoded[~nan_idx]

print(decoded.x, decoded.time)

In [None]:
decode_jumps = vdm.remove_teleports(decoded, speed_thresh=0, min_length=1)

In [None]:
decode_jumps.time, decode_jumps.x

## Counts filtering parameter check

In [None]:
std = [0.1, 0.025, 0.01, 0.002, None, 0.5, 1.0]
error = [29.9113296726, 29.3032199091, 45.8427697608, 45.8427697608, 45.8427697608, 28.623598705, 29.0339763118]

In [None]:
plt.plot(std, error, '.', ms=15)
plt.xlim(-0.1, 1.1)
plt.ylim(25, 48)
plt.show()

## Other stuff

In [None]:
from load_data import get_pos, get_spikes, get_lfp

import info.R063d2_info as r063d2
import info.R063d3_info as r063d3

In [None]:
%matplotlib inline
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.interpolate import InterpolatedUnivariateSpline
import random
import seaborn as sns

import vdmlab as vdm

from load_data import get_pos, get_spikes
from maze_functions import find_zones
from tuning_curves_functions import get_tc_1d, find_ideal
from decode_functions import get_edges, point_in_zones, compare_rates
from plotting_functions import plot_compare_decoded_track

In [None]:
# pickle_filepath = 'C:\\Users\\Emily\\Code\\emi_shortcut\\cache\\pickled\\'
# output_filepath = 'C:\\Users\\Emily\\Code\\emi_shortcut\\plots\\'
pickle_filepath = 'E:\\code\\emi_shortcut\\cache\\pickled\\'
output_filepath = 'E:\\code\\emi_shortcut\\plots\\'

In [None]:
# Sessions to decode. A longer, commented-out session list is kept below.
infos = [r063d2, r063d3]
# infos = [r063d2, r063d3, r063d4, r063d5, r063d6,
#          r066d1, r066d2, r066d3, r066d4, r066d5,
#          r067d1, r067d2, r067d3, r067d4, r067d5,
#          r068d1, r068d2, r068d3, r068d4, r068d5]

# shuffle_id: shuffle the order of the tuning curves before decoding.
# pauseB: decode spikes from the 'pauseB' epoch instead of the phase3 spikes.
shuffle_id = False
pauseB = False

# Per-session accumulators filled by the loop below: the mean decoding error,
# and the zone-sorted actual/decoded positions. The 'together' lists hold the
# per-session number of samples summed over the four zones.
combined_errors = []
combined_actual = dict(u=[], shortcut=[], novel=[], other=[], together=[])
combined_decoded = dict(u=[], shortcut=[], novel=[], other=[], together=[])

for info in infos:
    print(info.session_id)
    position = get_pos(info.pos_mat, info.pxl_to_cm)
    spikes = get_spikes(info.spike_mat)

    # Keep only samples whose smoothed speed is at least 0.5.
    # NOTE(review): a later cell in this notebook uses info.run_threshold for
    # the same purpose -- confirm whether the hard-coded 0.5 is intended here.
    speed = position.speed(t_smooth=0.5)
    run_idx = np.squeeze(speed.data) >= 0.5
    run_pos = position[run_idx]

    # track_starts = [info.task_times['phase1'].start, info.task_times['phase2'].start, info.task_times['phase3'].start]
    # track_stops = [info.task_times['phase1'].stop, info.task_times['phase2'].stop, info.task_times['phase3'].stop]

    # Tuning curves are built from the phase3 epoch only.
    track_start = info.task_times['phase3'].start
    track_stop = info.task_times['phase3'].stop

    track_pos = run_pos.time_slice(track_start, track_stop)

    track_spikes = [spiketrain.time_slice(track_start, track_stop) for spiketrain in spikes]

    # Spatial bin edges: start at the smallest sample and extend at least to
    # the largest one (the stop value is padded by one bin width).
    binsize = 3
    xedges = np.arange(track_pos.x.min(), track_pos.x.max() + binsize, binsize)
    yedges = np.arange(track_pos.y.min(), track_pos.y.max() + binsize, binsize)

    tuning_curves = vdm.tuning_curve_2d(track_pos, track_spikes, xedges, yedges, gaussian_sigma=0.1)
    if shuffle_id:
        # NOTE(review): random.shuffle swaps elements in place. If
        # tuning_curves is a multi-dimensional NumPy array rather than a list
        # this can duplicate rows -- confirm what vdm.tuning_curve_2d returns.
        # No random seed is set, so the shuffle is not reproducible.
        random.shuffle(tuning_curves)

    if pauseB:
        decode_spikes = [spiketrain.time_slice(info.task_times['pauseB'].start, info.task_times['pauseB'].stop)
                         for spiketrain in spikes]
    else:
        decode_spikes = track_spikes

    # Spike counts per time bin of width counts_binsize.
    # NOTE(review): the time bins are derived from run_pos (not limited to
    # phase3), presumably so the pauseB epoch is covered as well -- confirm.
    counts_binsize = 0.025
    time_edges = get_edges(run_pos, counts_binsize, lastbin=True)
    counts = vdm.get_counts(decode_spikes, time_edges, gaussian_std=0.025)

    # Flatten each 2D tuning curve into one row for the decoder.
    decoding_tc = []
    for tuning_curve in tuning_curves:
        decoding_tc.append(np.ravel(tuning_curve))
    decoding_tc = np.array(decoding_tc)

    likelihood = vdm.bayesian_prob(counts, decoding_tc, counts_binsize)

    # Bin centres are the midpoints between consecutive edges.
    xcenters = (xedges[1:] + xedges[:-1]) / 2.
    ycenters = (yedges[1:] + yedges[:-1]) / 2.
    xy_centers = vdm.cartesian(xcenters, ycenters)

    time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

    # Decode, then drop the samples where both coordinates are NaN.
    decoded = vdm.decode_location(likelihood, xy_centers, time_centers)
    nan_idx = np.logical_and(np.isnan(decoded.x), np.isnan(decoded.y))
    decoded = decoded[~nan_idx]

    if not decoded.isempty:
        decoded = vdm.remove_teleports(decoded, speed_thresh=10, min_length=3)

    # Actual position, linearly interpolated at the decoded sample times.
    # np.interp clamps to the first/last track sample for times outside the
    # range of track_pos.time.
    actual_x = np.interp(decoded.time, track_pos.time, track_pos.x)
    actual_y = np.interp(decoded.time, track_pos.time, track_pos.y)

    actual_position = vdm.Position(np.hstack((actual_x[..., np.newaxis], actual_y[..., np.newaxis])), decoded.time)

    errors = actual_position.distance(decoded)

    # Sort the actual and decoded samples into the maze zones.
    zones = find_zones(info, expand_by=7)
    actual_zones = point_in_zones(actual_position, zones)
    decoded_zones = point_in_zones(decoded, zones)


    combined_errors.append(np.mean(errors))

    combined_actual['u'].append(actual_zones['u'])
    combined_actual['shortcut'].append(actual_zones['shortcut'])
    combined_actual['novel'].append(actual_zones['novel'])
    combined_actual['other'].append(actual_zones['other'])
    # Total number of actual samples across the four zones for this session.
    combined_actual['together'].append(len(actual_zones['u'].time) + len(actual_zones['shortcut'].time) +
                                       len(actual_zones['novel'].time) + len(actual_zones['other'].time))

    combined_decoded['u'].append(decoded_zones['u'])
    combined_decoded['shortcut'].append(decoded_zones['shortcut'])
    combined_decoded['novel'].append(decoded_zones['novel'])
    combined_decoded['other'].append(decoded_zones['other'])
    # Total number of decoded samples across the four zones for this session.
    combined_decoded['together'].append(len(decoded_zones['u'].time) + len(decoded_zones['shortcut'].time) +
                                        len(decoded_zones['novel'].time) + len(decoded_zones['other'].time))


def compare_decoded_actual(combined_actual, combined_decoded, shuffle_id, pauseB, output_filepath, errors=None):
    """Plot, per zone, the proportion of decoded samples next to the actual ones.

    Parameters
    ----------
    combined_actual : dict
        Maps 'u', 'shortcut', 'novel' and 'other' to one entry per session
        (each exposing a ``.time`` sequence) and 'together' to the per-session
        sample totals used as denominators.
    combined_decoded : dict
        Same layout as ``combined_actual``, for the decoded positions.
    shuffle_id : bool
        Only used to choose the output filename.
    pauseB : bool
        Chooses the output filename; when True only the decoded proportions
        are plotted (no actual proportions, no distance label).
    output_filepath : str
        Directory the figure path is built from.
    errors : sequence of float, optional
        Per-session decoding errors; their mean is passed to the plot as the
        distance label. Defaults to the notebook-level ``combined_errors`` so
        existing five-argument calls keep working.

    Returns
    -------
    The ``errors`` sequence that was used.

    Raises
    ------
    ValueError
        If the actual and decoded inputs cover a different number of sessions.
    """
    keys = ['u', 'shortcut', 'novel', 'other']

    # Validate once up front; the check does not depend on the zone key.
    n_sessions = len(combined_actual['together'])
    if n_sessions != len(combined_decoded['together']):
        raise ValueError("must have same number of decoded and actual samples")

    if errors is None:
        # Backward-compatible fallback to the notebook-level accumulator.
        errors = combined_errors

    # 'together' stays in the dicts (left empty), exactly as before.
    actual = dict(u=[], shortcut=[], novel=[], other=[], together=[])
    decode = dict(u=[], shortcut=[], novel=[], other=[], together=[])

    for key in keys:
        for session in range(n_sessions):
            # Fraction of this session's samples that fall in zone `key`.
            actual[key].append(len(combined_actual[key][session].time) / combined_actual['together'][session])
            decode[key].append(len(combined_decoded[key][session].time) / combined_decoded['together'][session])

    if shuffle_id and not pauseB:
        filename = 'combined-phase3-id_shuffle_decoded.png'
    elif shuffle_id and pauseB:
        filename = 'combined-pauseB-id_shuffle_decoded.png'
    elif pauseB and not shuffle_id:
        filename = 'combined-pauseB-id_decoded.png'
    else:
        filename = 'combined-phase3_decoded.png'
    savepath = os.path.join(output_filepath, filename)

    if pauseB:
        plot_compare_decoded_track(decode, max_y=1.0, savepath=savepath)
    else:
        plot_compare_decoded_track(decode, actual, distance=str(round(np.mean(errors), 2)),
                                   max_y=1.0, savepath=savepath)

    return errors


def normalized_time_spent(combined_actual, combined_decoded, shuffle_id, output_filepath):
    """Plot decoded vs. actual zone values after normalising with ``compare_rates``.

    Parameters
    ----------
    combined_actual : dict
        Maps 'u', 'shortcut', 'novel' and 'other' to one entry per session;
        the length of 'together' gives the number of sessions.
    combined_decoded : dict
        Same layout as ``combined_actual``, for the decoded positions.
    shuffle_id : bool
        Only used to choose the output filename.
    output_filepath : str
        Directory the figure path is built from.
    """
    keys = ['u', 'shortcut', 'novel', 'other']

    normalized_actual = dict(u=[], shortcut=[], novel=[], other=[])
    normalized_decoded = dict(u=[], shortcut=[], novel=[], other=[])

    n_sessions = len(combined_actual['together'])

    for session in range(n_sessions):
        # Gather this session's per-zone entries, then normalise them together.
        actual = {key: combined_actual[key][session] for key in keys}
        decode = {key: combined_decoded[key][session] for key in keys}
        norm_actual = compare_rates(actual)
        norm_decoded = compare_rates(decode)
        for key in keys:
            normalized_actual[key].append(norm_actual[key])
            normalized_decoded[key].append(norm_decoded[key])

    if shuffle_id:
        filename = 'combined-norm_phase3-id_shuffle_decoded.png'
    else:
        filename = 'combined-norm_phase3_decoded.png'
    savepath = os.path.join(output_filepath, filename)

    y_label = 'Points normalized by time spent'
    # Decoded values go first, actual values second, matching the other calls
    # to plot_compare_decoded_track in this notebook (the previous order here
    # was actual, decoded).
    # NOTE(review): confirm against the signature of plot_compare_decoded_track.
    plot_compare_decoded_track(normalized_decoded, normalized_actual, y_label=y_label, max_y=60., savepath=savepath)


compare_decoded_actual(combined_actual, combined_decoded, shuffle_id, pauseB, output_filepath)
# plot_normalized = normalized_time_spent(combined_actual, combined_decoded, n_sessions, shuffle_id, output_filepath)


In [None]:
# NOTE(review): a definition of compare_decoded_actual also appears in an
# earlier cell of this notebook; keeping a single copy would avoid one
# silently shadowing the other.
def compare_decoded_actual(combined_actual, combined_decoded, shuffle_id, pauseB, output_filepath, errors=None):
    """Plot, per zone, the proportion of decoded samples next to the actual ones.

    Parameters
    ----------
    combined_actual : dict
        Maps 'u', 'shortcut', 'novel' and 'other' to one entry per session
        (each exposing a ``.time`` sequence) and 'together' to the per-session
        sample totals used as denominators.
    combined_decoded : dict
        Same layout as ``combined_actual``, for the decoded positions.
    shuffle_id : bool
        Only used to choose the output filename.
    pauseB : bool
        Chooses the output filename; when True only the decoded proportions
        are plotted (no actual proportions, no distance label).
    output_filepath : str
        Directory the figure path is built from.
    errors : sequence of float, optional
        Per-session decoding errors; their mean is passed to the plot as the
        distance label. Defaults to the notebook-level ``combined_errors`` so
        existing five-argument calls keep working.

    Returns
    -------
    The ``errors`` sequence that was used.

    Raises
    ------
    ValueError
        If the actual and decoded inputs cover a different number of sessions.
    """
    keys = ['u', 'shortcut', 'novel', 'other']

    # Validate once up front; the check does not depend on the zone key.
    n_sessions = len(combined_actual['together'])
    if n_sessions != len(combined_decoded['together']):
        raise ValueError("must have same number of decoded and actual samples")

    if errors is None:
        # Backward-compatible fallback to the notebook-level accumulator.
        errors = combined_errors

    # 'together' stays in the dicts (left empty), exactly as before.
    actual = dict(u=[], shortcut=[], novel=[], other=[], together=[])
    decode = dict(u=[], shortcut=[], novel=[], other=[], together=[])

    for key in keys:
        for session in range(n_sessions):
            # Fraction of this session's samples that fall in zone `key`.
            actual[key].append(len(combined_actual[key][session].time) / combined_actual['together'][session])
            decode[key].append(len(combined_decoded[key][session].time) / combined_decoded['together'][session])

    if shuffle_id and not pauseB:
        filename = 'combined-phase3-id_shuffle_decoded.png'
    elif shuffle_id and pauseB:
        filename = 'combined-pauseB-id_shuffle_decoded.png'
    elif pauseB and not shuffle_id:
        filename = 'combined-pauseB-id_decoded.png'
    else:
        filename = 'combined-phase3_decoded.png'
    savepath = os.path.join(output_filepath, filename)

    if pauseB:
        plot_compare_decoded_track(decode, max_y=1.0, savepath=savepath)
    else:
        plot_compare_decoded_track(decode, actual, distance=str(round(np.mean(errors), 2)),
                                   max_y=1.0, savepath=savepath)

    return errors

In [None]:
compare_decoded_actual(combined_actual, combined_decoded, shuffle_id, pauseB, output_filepath)

In [None]:
val = 0
key = 'novel'
# Summed length of the three trajectories. It is computed here because the
# only other assignment of total_tracks sits in a later cell, so on a fresh
# top-to-bottom run this cell would otherwise raise a NameError.
total_tracks = info.track_length['u'] + info.track_length['shortcut'] + info.track_length['novel']
# Proportion of actual samples in the zone plus that zone's share of track length.
((len(combined_actual[key][val].time)/combined_actual['together'][val]) + info.track_length[key]/total_tracks)

In [None]:
((len(combined_actual[key][val].time)/combined_actual['together'][val]))

In [None]:
len(combined_actual[key][val].time)

In [None]:
info.track_length[key]/total_tracks

In [None]:
total_tracks = info.track_length['u'] + info.track_length['shortcut'] + info.track_length['novel'] 

In [None]:
(1*2)/(4+5)

In [None]:
1/3 + 1/2

In [None]:
5/6

In [None]:
u_proportion = info.track_length['u']/total_tracks
shortcut_proportion = info.track_length['shortcut']/total_tracks
novel_proportion = info.track_length['novel']/total_tracks
u_proportion + shortcut_proportion + novel_proportion

In [None]:
(0.4 + 0.6)/2

In [None]:
(374-272)/info.pxl_to_cm[1] + (663-448)/info.pxl_to_cm[0] + (274-76)/info.pxl_to_cm[1]

In [None]:
(220-118)/info.pxl_to_cm[0] + (286-45)/info.pxl_to_cm[1]

In [None]:
info = r063d3

In [None]:
(480-364)/info.pxl_to_cm[1] + (551-367)/info.pxl_to_cm[0] + (382-51)/info.pxl_to_cm[1] + (691-228)/info.pxl_to_cm[0]

## Other

In [None]:
pos = np.random.rand(20, 2)

In [None]:
pos

In [None]:
split_idx = np.array([4, 6, 10])

In [None]:
all_idx = [idx for idx in np.split(np.arange(pos.shape[0]), split_idx) if idx.size > 3]

In [None]:
pos[np.hstack(all_idx)]

In [None]:
# One full sine period sampled at 201 evenly spaced time points.
time = np.linspace(0, np.pi*2, num=201)
data = np.sin(time)

In [None]:
plt.plot(time, data, '.')
plt.show()

In [None]:
position = vdm.Position(data, time)

In [None]:
speed = position.speed()

In [None]:
plt.plot(speed.time, speed.data)
plt.show()

In [None]:
run_idx = np.squeeze(speed.data) >= 0.5

In [None]:
run_idx

In [None]:
position = vdm.Position(data, time)
speed = position.speed()
run_idx = np.squeeze(speed.data) >= 0.7
run_position = position[run_idx]

len(run_position.x)

In [None]:
assert np.allclose(len(run_position.x), 100)

In [None]:
def position_speed(self, t_smooth=None):
    """Speed of a position-like object, optionally smoothed with a boxcar filter.

    This cell was a pasted method body (top-level ``return``, undefined
    ``self`` and ``t_smooth``) and could not be executed; wrapping it in a
    function keeps the code for reference while leaving the notebook runnable.

    Parameters
    ----------
    self : position-like
        Must support slicing, ``.distance(other)`` and a ``.time`` array.
    t_smooth : float or None
        Width of the smoothing window, in the units of ``self.time``.
        No smoothing is applied when None.

    Returns
    -------
    ``AnalogSignal(velocity, self.time)``.
    NOTE(review): ``AnalogSignal`` is not imported in this notebook; it must be
    in scope before this function is called.
    """
    # Distance between consecutive samples divided by the time step.
    velocity = self[1:].distance(self[:-1])
    velocity /= np.diff(self.time)
    # Pad with a leading zero so the result has one value per sample.
    velocity = np.hstack(([0], velocity))

    if t_smooth is not None:
        # Window length in samples, from the median sampling interval.
        dt = np.median(np.diff(self.time))
        filter_length = np.ceil(t_smooth / dt)
        velocity = np.convolve(velocity, np.ones(int(filter_length))/filter_length, 'same')

    return AnalogSignal(velocity, self.time)

In [None]:
t_smooth=0.5
velocity = np.diff(np.squeeze(position.data))
velocity /= np.diff(position.time)
velocity = np.hstack(([0], velocity))

dt = np.median(np.diff(position.time))
filter_length = np.ceil(t_smooth / dt)
velocity = np.convolve(velocity, np.ones(int(filter_length))/filter_length, 'same')

In [None]:
velocity

In [None]:
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
a

In [None]:
np.hstack((a[0:2], a[6:8]))

In [None]:
indices = []
for t_start, t_stop in zip(np.array([0, 6]), np.array([2, 8])):
    indices.append((a >= t_start) & (a <= t_stop))
indices = np.any(np.column_stack((indices)),axis=1)
a[indices]

In [None]:
res = np.any(np.column_stack((indices)),axis=1)
res

In [None]:
spikes = [vdm.SpikeTrain(np.array([0., 0., 1.])), vdm.SpikeTrain(np.array([3.6, 3.9]))]

In [None]:
one_line = LineString([[2, 0], [2, 2], [2, 4], [2, 6], [2, 8], [2, 10]])

one_start = Point([2, 2])
one_stop = Point([2, 8])

expand_by = 1

one_zone = vdm.expand_line(one_start, one_stop, one_line, expand_by)

In [None]:
this_idx = []
for pos_idx in range(len(position.time)):
    point = Point([position.x[pos_idx], position.y[pos_idx]])
    if one_zone.contains(point):
        this_idx.append(pos_idx)
        
this_pos = position[this_idx]
linear = this_pos.linearize(one_line, one_zone)

In [None]:
linear.x

In [None]:
plt.plot(position.x, position.y, 'g.', ms=10)
# plt.plot([2, 4, 2], [7, 5, 4], 'm.', ms=20)
plt.plot([2, 6], [4, 3], 'm.', ms=20)
plt.plot(one_zone.exterior.xy[0], one_zone.exterior.xy[1], 'b', lw=1)
plt.xlim(-1, 10)
plt.ylim(-1, 10)
plt.show()

In [None]:
def expand_line(start_pt, stop_pt, line, expand_by):
    """Return ``line`` buffered by ``expand_by``, unioned with both end points.

    The union is taken as ``start_pt`` with the buffered line first, then with
    ``stop_pt``.
    """
    return start_pt.union(line.buffer(expand_by)).union(stop_pt)

def find_zones(info, expand_by=6):
    """Build the maze zones for one session.

    Parameters
    ----------
    info : object
        Must provide ``u_trajectory``, ``shortcut_trajectory`` and
        ``novel_trajectory`` (sequences of points) and
        ``path_pts['pedestal']`` (an indexable pair of coordinates).
    expand_by : number
        Buffer distance applied to each trajectory line; the pedestal is a
        point buffered by ``expand_by*2.2``.

    Returns
    -------
    dict
        Keys 'u', 'shortcut', 'novel' (expanded trajectories), 'ushort',
        'unovel', 'uped', 'shortped', 'novelped' (pairwise intersections) and
        'pedestal'.
    """
    names = ('u', 'shortcut', 'novel')
    trajectories = {name: getattr(info, name + '_trajectory') for name in names}

    # Lines first, then the first/last point of each trajectory.
    lines = {name: LineString(trajectories[name]) for name in names}
    endpoints = {name: (Point(trajectories[name][0]), Point(trajectories[name][-1]))
                 for name in names}

    pedestal_xy = info.path_pts['pedestal']
    pedestal = Point(pedestal_xy[0], pedestal_xy[1]).buffer(expand_by*2.2)

    zone = dict()
    for name in names:
        start_pt, stop_pt = endpoints[name]
        zone[name] = expand_line(start_pt, stop_pt, lines[name], expand_by)

    # Overlap regions between the u trajectory and the other two.
    zone['ushort'] = zone['u'].intersection(zone['shortcut'])
    zone['unovel'] = zone['u'].intersection(zone['novel'])

    # Overlap of each trajectory with the pedestal.
    zone['uped'] = zone['u'].intersection(pedestal)
    zone['shortped'] = zone['shortcut'].intersection(pedestal)
    zone['novelped'] = zone['novel'].intersection(pedestal)
    zone['pedestal'] = pedestal

    return zone

In [None]:
def trajectory_fields(tuning_curves, zone, xedges, yedges, field_thresh):
    """Group tuning curves by the zone(s) in which they exceed ``field_thresh``.

    Parameters
    ----------
    tuning_curves : sequence of 2D arrays
        One tuning curve per neuron.
        NOTE(review): assumed to be indexed [y, x], as the pcolormesh calls in
        this notebook imply -- confirm.
    zone : dict
        Must provide 'u', 'ushort', 'unovel', 'shortcut', 'shortped', 'novel',
        'novelped' and 'pedestal', each with a ``contains(point)`` method.
    xedges, yedges : arrays
        Spatial bin edges; bin centres are the midpoints of consecutive edges.
    field_thresh : number
        A bin belongs to a field when its tuning-curve value is strictly
        greater than this.

    Returns
    -------
    dict
        Keys 'u', 'shortcut', 'novel', 'pedestal'; each maps to the list of
        tuning curves with at least one above-threshold bin centre inside that
        zone. A tuning curve is added at most once per key.
    """
    # Imported locally: the notebook-level `import itertools` is commented out,
    # so the name was undefined when this function ran.
    import itertools

    xcenters = np.array((xedges[1:] + xedges[:-1]) / 2.)
    ycenters = np.array((yedges[1:] + yedges[:-1]) / 2.)

    # Iterate y in the outer loop so the order matches neuron_tc.flatten(), but
    # store every centre as (x, y): Point takes x first, and the zones are
    # tested with (x, y) points elsewhere in this notebook. Previously the
    # centres were passed to Point as (y, x).
    tuning_points = np.array([(x, y) for y, x in itertools.product(ycenters, xcenters)]).reshape(-1, 2)

    # Output key -> zone keys that count towards it.
    membership = (
        ('u', ('u', 'ushort', 'unovel')),
        ('shortcut', ('shortcut', 'shortped')),
        ('novel', ('novel', 'novelped')),
        ('pedestal', ('pedestal',)),
    )

    fields_tc = dict(u=[], shortcut=[], novel=[], pedestal=[])
    # Neurons already recorded per output key, so each is added only once.
    recorded = dict(u=set(), shortcut=set(), novel=set(), pedestal=set())

    for neuron_idx, neuron_tc in enumerate(tuning_curves):
        field_idx = neuron_tc.flatten() > field_thresh
        for x, y in tuning_points[field_idx]:
            point = Point([x, y])
            for name, zone_keys in membership:
                if neuron_idx in recorded[name]:
                    continue
                if any(zone[zone_key].contains(point) for zone_key in zone_keys):
                    fields_tc[name].append(neuron_tc)
                    recorded[name].add(neuron_idx)

    return fields_tc

In [None]:
import sys
# sys.path.append('C:\\Users\\Emily\\Code\\emi_shortcut\\info')
sys.path.append('E:\\code\\emi_shortcut\\info')
import info.R063d3_info as r063d3
info = r063d3

In [None]:
pickle_filepath = 'E:\\code\\emi_shortcut\\cache\\pickled'
# pickle_filepath = 'C:\\Users\\Emily\\Code\\emi_shortcut\\cache\\pickled'

In [None]:
position = get_pos(info.pos_mat, info.pxl_to_cm)
spikes = get_spikes(info.spike_mat)

In [None]:
binsize = 3
xedges = np.arange(position.x.min(), position.x.max()+binsize, binsize)
yedges = np.arange(position.y.min(), position.y.max()+binsize, binsize)

speed = position.speed(t_smooth=0.5)
run_idx = np.squeeze(speed.data) >= info.run_threshold
run_pos = position[run_idx]

t_start = info.task_times['phase3'].start
t_stop = info.task_times['phase3'].stop

sliced_pos = run_pos.time_slice(t_start, t_stop)

sliced_spikes = [spiketrain.time_slice(t_start, t_stop) for spiketrain in spikes]

tuning_curves = vdm.tuning_curve_2d(sliced_pos, sliced_spikes, xedges, yedges, gaussian_sigma=0.2)

In [None]:
zones = find_zones(info)

In [None]:
type(zones['u'])

In [None]:
fields_tc = trajectory_fields(tuning_curves, zones, xedges, yedges, field_thresh=5)

In [None]:
print(len(fields_tc['novel']))

In [None]:
tuning_curves[5]

In [None]:
plt.figure()
xx, yy = np.meshgrid(xedges, yedges)
for tuning_curve in fields_tc['novel']:
    pp = plt.pcolormesh(xx, yy, tuning_curve, cmap='YlGn')
    plt.colorbar(pp)
    plt.axis('off')
    plt.show()