In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from shapely.geometry import Point, LineString
# import itertools
from scipy.interpolate import InterpolatedUnivariateSpline

from load_data import get_pos, get_spikes
from maze_functions import find_zones, trajectory_fields
from plotting_functions import plot_intersects, plot_zone
from decode_functions import get_edges

import vdmlab as vdm

In [None]:
xy = np.array([[2, 7],
               [4, 5],
               [6, 3],
               [8, 1],
               [2, 4]])
time = np.array([0., 1., 2., 3., 4.])
position = vdm.Position(xy, time)

In [None]:
plt.plot(position.x, position.y, 'b.', ms=10)
plt.xlim(0.5, 8.5)
plt.ylim(0.5, 7.5)
plt.show()

In [None]:
spikes = [vdm.SpikeTrain(np.array([3.6, 3.9])), 
          vdm.SpikeTrain(np.array([0., 0., 2.])),
          vdm.SpikeTrain(np.array([2., 2.4]))]

In [None]:
binsize = 2
xedges = np.arange(position.x.min(), position.x.max()+binsize, binsize)
yedges = np.arange(position.y.min(), position.y.max()+binsize, binsize)

tuning_curves = vdm.tuning_curve_2d(position, spikes, xedges, yedges, sampling_rate=1.)

In [None]:
plt.figure()
xx, yy = np.meshgrid(xedges, yedges)
for tuning_curve in tuning_curves:
    pp = plt.pcolormesh(xx, yy, tuning_curve, cmap='YlGn')
    plt.colorbar(pp)
    plt.axis('off')
    plt.show()

In [None]:
counts_binsize = 0.5

time_edges = get_edges(position, counts_binsize, lastbin=True)
counts = vdm.get_counts(spikes, time_edges, apply_filter=False)

In [None]:
print(time_edges)

In [None]:
print(counts)

In [None]:
decoding_tc = []
for tuning_curve in tuning_curves:
    decoding_tc.append(np.ravel(tuning_curve))
decoding_tc = np.array(decoding_tc)

In [None]:
shape = tuning_curves[0].shape

In [None]:
tuning_curves

In [None]:
decoding_tc

In [None]:
likelihood = vdm.bayesian_prob(counts, decoding_tc, counts_binsize)

In [None]:
likelihood

In [None]:
xcenters = (xedges[1:] + xedges[:-1]) / 2.
ycenters = (yedges[1:] + yedges[:-1]) / 2.
xy_centers = vdm.cartesian(xcenters, ycenters)

time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

In [None]:
time_centers

In [None]:
decoded = vdm.decode_location(likelihood, xy_centers, time_centers)

In [None]:
decoded.x, decoded.y, decoded.time

In [None]:
nan_idx = np.logical_and(np.isnan(decoded.x), np.isnan(decoded.y))
decoded = decoded[~nan_idx]

In [None]:
decoded.x, decoded.y, decoded.time

In [None]:
# Interpolate the true trajectory at the decoded time stamps. `decoded` had its
# NaN rows removed in an earlier cell, so evaluating the splines at
# `decoded.time` (instead of every `time_centers` entry) keeps the actual and
# decoded samples paired one-to-one for the error computation below.
x_spline = InterpolatedUnivariateSpline(position.time, position.x)
y_spline = InterpolatedUnivariateSpline(position.time, position.y)
actual_xy = np.column_stack((x_spline(decoded.time), y_spline(decoded.time)))
actual_position = vdm.Position(actual_xy, decoded.time)

In [None]:
actual_position.x, actual_position.y, actual_position.time

In [None]:
error = np.abs(decoded.data - actual_position.data)

In [None]:
avg_error = np.nanmean(error)
avg_error

## How does the 1D decoding work?

In [None]:
x = np.array([2, 4, 6, 8, 3])
time = np.array([0., 1., 2., 3., 4.])
position = vdm.Position(x, time)

In [None]:
spikes = [vdm.SpikeTrain(np.array([3.6, 3.9])), vdm.SpikeTrain(np.array([2.2, 2.43]))]

In [None]:
pos_binsize = 1
tuning_curves = vdm.tuning_curve(position, spikes, binsize=pos_binsize, sampling_rate=1., gaussian_std=None)

In [None]:
tuning_curves

In [None]:
plt.plot(tuning_curves[0], 'b')
plt.plot(tuning_curves[1], 'm')
plt.show()

In [None]:
counts_binsize = 0.5
time_edges = get_edges(position, counts_binsize, lastbin=True)
counts = vdm.get_counts(spikes, time_edges)

In [None]:
time_edges.shape

In [None]:
counts

In [None]:
likelihood = vdm.bayesian_prob(counts, tuning_curves, counts_binsize)

In [None]:
likelihood

In [None]:
pos_edges = vdm.binned_position(position, pos_binsize)
x_centers = (pos_edges[1:] + pos_edges[:-1]) / 2.
x_centers = x_centers[..., np.newaxis]

time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

decoded = vdm.decode_location(likelihood, x_centers, time_centers)
decoded

In [None]:
nan_idx = np.isnan(decoded.x)
decoded = decoded[~nan_idx]

In [None]:
decoded.x, decoded.time

In [None]:
spline = InterpolatedUnivariateSpline(position.time, position.x)
actual_position = vdm.Position(spline(decoded.time), decoded.time)

In [None]:
actual_position.x, actual_position.time

In [None]:
error = np.abs(decoded.x - actual_position.x)

In [None]:
decoded.x, actual_position.x

In [None]:
avg_error = np.mean(error)
avg_error

## Velocity check for 1D decoding

In [None]:
x = np.array([2, 4, 6, 8, 3])
time = np.array([0., 1., 2., 3., 4.])
position = vdm.Position(x, time)

spikes = [vdm.SpikeTrain(np.array([3.6, 3.9])), 
          vdm.SpikeTrain(np.array([2.2, 2.4])),
          vdm.SpikeTrain(np.array([0.6, 0.9])),
          vdm.SpikeTrain(np.array([1., 1.1])), 
          vdm.SpikeTrain(np.array([1.7, 1.9]))]

pos_binsize = 1
tuning_curves = vdm.tuning_curve(position, spikes, binsize=pos_binsize, sampling_rate=1., gaussian_std=None)

counts_binsize = 0.5
time_edges = get_edges(position, counts_binsize, lastbin=True)
counts = vdm.get_counts(spikes, time_edges)

likelihood = vdm.bayesian_prob(counts, tuning_curves, counts_binsize)

pos_edges = vdm.binned_position(position, pos_binsize)
x_centers = (pos_edges[1:] + pos_edges[:-1]) / 2.
x_centers = x_centers[..., np.newaxis]

time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

decoded = vdm.decode_location(likelihood, x_centers, time_centers)

nan_idx = np.isnan(decoded.x)
decoded = decoded[~nan_idx]

print(decoded.x, decoded.time)

In [None]:
decode_jumps = vdm.remove_teleports(decoded, speed_thresh=0, min_length=1)

In [None]:
decode_jumps.time, decode_jumps.x

## Counts filtering parameter check

In [None]:
std = [0.1, 0.025, 0.01, 0.002, None, 0.5, 1.0]
error = [29.9113296726, 29.3032199091, 45.8427697608, 45.8427697608, 45.8427697608, 28.623598705, 29.0339763118]

In [None]:
plt.plot(std, error, '.', ms=15)
plt.xlim(-0.1, 1.1)
plt.ylim(25, 48)
plt.show()

## Decoding on recorded session data

In [None]:

from load_data import get_pos, get_spikes, get_lfp

import info.R063d2_info as r063d2
import info.R063d3_info as r063d3
info = r063d2

In [None]:
%matplotlib inline
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.interpolate import InterpolatedUnivariateSpline
import random
import seaborn as sns

import vdmlab as vdm

from load_data import get_pos, get_spikes
from maze_functions import find_zones
from tuning_curves_functions import get_tc_1d, find_ideal
from decode_functions import get_edges
from plotting_functions import plot_compare_decoded_track

In [None]:
pickle_filepath = 'C:\\Users\\Emily\\Code\\emi_shortcut\\cache\\pickled\\'
output_filepath = 'C:\\Users\\Emily\\Code\\emi_shortcut\\plots\\'

In [None]:
# Sessions to decode; each `info` module supplies file names, thresholds and task times.
infos = [r063d2]
# infos = [r063d2, r063d3, r063d4, r063d5, r063d6, r066d1, r066d2, r066d3, r066d4, r067d1]


# For every session: build 2D tuning curves from running data in phase3,
# decode position from binned spike counts, and compare against the tracked
# position interpolated at the decoded time stamps.
for info in infos:
    print(info.session_id)
    position = get_pos(info.pos_mat, info.pxl_to_cm)
    spikes = get_spikes(info.spike_mat)

    # Keep only samples whose smoothed speed reaches the session's run threshold.
    speed = position.speed(t_smooth=0.5)
    run_idx = np.squeeze(speed.data) >= info.run_threshold
    run_pos = position[run_idx]

    track_start = info.task_times['phase3'].start
    track_stop = info.task_times['phase3'].stop

    # Restrict running position and every spike train to the phase3 window.
    track_pos = run_pos.time_slice(track_start, track_stop)

    track_spikes = [spiketrain.time_slice(track_start, track_stop) for spiketrain in spikes]

    # Spatial bin edges spanning the phase3 running positions, in steps of `binsize`.
    binsize = 3
    xedges = np.arange(track_pos.x.min(), track_pos.x.max() + binsize, binsize)
    yedges = np.arange(track_pos.y.min(), track_pos.y.max() + binsize, binsize)

    tuning_curves = vdm.tuning_curve_2d(track_pos, track_spikes, xedges, yedges, gaussian_sigma=0.25)
#     random.shuffle(tuning_curves)

    # NOTE(review): the time bins are built from `run_pos` (not time-sliced),
    # while the spikes and tuning curves are restricted to phase3 -- confirm
    # whether `track_pos` was intended here.
    counts_binsize = 0.025
    time_edges = get_edges(run_pos, counts_binsize, lastbin=True)
    counts = vdm.get_counts(track_spikes, time_edges, gaussian_std=counts_binsize)

    # Flatten each 2D tuning curve into one row so the decoder sees a
    # (neuron, spatial bin) matrix.
    decoding_tc = []
    for tuning_curve in tuning_curves:
        decoding_tc.append(np.ravel(tuning_curve))
    decoding_tc = np.array(decoding_tc)

    likelihood = vdm.bayesian_prob(counts, decoding_tc, counts_binsize)

    # Bin centers: midpoints of consecutive edges, paired into (x, y) locations.
    xcenters = (xedges[1:] + xedges[:-1]) / 2.
    ycenters = (yedges[1:] + yedges[:-1]) / 2.
    xy_centers = vdm.cartesian(xcenters, ycenters)

    time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

    # Drop time bins where both decoded coordinates are NaN.
    decoded_pos = vdm.decode_location(likelihood, xy_centers, time_centers)
    nan_idx = np.logical_and(np.isnan(decoded_pos.x), np.isnan(decoded_pos.y))
    decoded_pos = decoded_pos[~nan_idx]

    decoded = vdm.remove_teleports(decoded_pos, speed_thresh=10, min_length=3)
    
    # Linearly interpolate the tracked phase3 position at the decoded times;
    # np.interp returns the endpoint values for times outside track_pos.time.
    actual_x = np.interp(decoded.time, track_pos.time, track_pos.x)
    actual_y = np.interp(decoded.time, track_pos.time, track_pos.y)
    
    actual_position = vdm.Position(np.hstack((actual_x[..., np.newaxis], actual_y[..., np.newaxis])),
                                   decoded.time)

    # Per-sample distance between tracked and decoded position; report its mean.
    errors = actual_position.distance(decoded)
    print('Actual distance:', np.mean(errors))

    # Tracked samples in red (small markers), decoded samples in blue.
    plt.plot(actual_position.x, actual_position.y, 'r.', ms=0.7)
    plt.plot(decoded.x, decoded.y, 'b.')
    plt.show()

In [None]:
from decode_functions import point_in_zones
from shapely.geometry import Point

In [None]:
def point_in_zones(position, zones):
    """Sort position samples by the maze zone that contains them.

    Each (x, y) sample is tested against the zone groups in priority order
    (u, then shortcut, then novel); the first group with a containing zone
    wins, and samples inside none of them are labelled 'other'.

    Parameters
    ----------
    position : vdmlab.Position
        Samples to classify; its x, y and time attributes are iterated together.
    zones : dict
        Must provide the keys u, ushort, unovel, shortcut, shortped, novel and
        novelped, each an object exposing ``contains(point)``.

    Returns
    -------
    sorted_zones : dict
        With u, shortcut, novel, other as keys, each a vdmlab.Position object
        built from the samples assigned to that label.

    """
    # (output label, zone keys that count towards it), in priority order.
    priority = (
        ('u', ('u', 'ushort', 'unovel')),
        ('shortcut', ('shortcut', 'shortped')),
        ('novel', ('novel', 'novelped')),
    )
    labels = ('u', 'shortcut', 'novel', 'other')
    coords = {label: [] for label in labels}
    stamps = {label: [] for label in labels}

    for px, py, pt in zip(position.x, position.y, position.time):
        here = Point([px, py])
        assigned = 'other'
        for label, zone_keys in priority:
            if any(zones[key].contains(here) for key in zone_keys):
                assigned = label
                break
        coords[assigned].append([px, py])
        stamps[assigned].append(pt)

    return {label: vdm.Position(coords[label], stamps[label]) for label in labels}

In [None]:
zones = find_zones(info, expand_by=7)
actual_zones = point_in_zones(actual_position, zones)
decoded_zones = point_in_zones(decoded, zones)

In [None]:
len(actual_zones['other'].time), len(decoded_zones['other'].time)

In [None]:
print(len(decoded.time), len(actual_position.time))

In [None]:
plot_compare_decoded_track(actual_zones, decoded_zones, str(round(np.mean(errors), 2)), savefig=False)

In [None]:
plt.plot(actual_zones['u'].x, actual_zones['u'].y, 'b.')
plt.plot(actual_zones['shortcut'].x, actual_zones['shortcut'].y, 'g.')
plt.plot(actual_zones['novel'].x, actual_zones['novel'].y, 'r.')
plt.plot(actual_zones['other'].x, actual_zones['other'].y, 'c.')
plt.show()

In [None]:
plt.plot(decoded_zones['u'].x, decoded_zones['u'].y, '.', color='b')
plt.plot(decoded_zones['shortcut'].x, decoded_zones['shortcut'].y, '.', color='g')
plt.plot(decoded_zones['novel'].x, decoded_zones['novel'].y, '.', color='r')
plt.plot(decoded_zones['other'].x, decoded_zones['other'].y, '.', color='c')
plt.show()

In [None]:
actual = [11.84, 16.73, 14.19, 5.20, 7.10, 8.79, 8.23, 10.16, 9.22, 13.20, 21.13, 11.32, 20.54, 17.58, 17.63, 7.79, 7.66, 22.53, ]

In [None]:
import numpy as np
np.mean(actual)

In [None]:
std = 24
test_xy = actual_position.data + np.random.normal(0, std, actual_position.data.shape)
test_pos = vdm.Position(test_xy, actual_position.time)
test_errors = actual_position.distance(test_pos)
print('Test distance:', np.mean(test_errors))

In [None]:
difference = 32
test_xy = actual_position.data + difference
test_pos = vdm.Position(test_xy, actual_position.time)
test_errors = actual_position.distance(test_pos)
print('Test distance:', np.mean(test_errors))

In [None]:
plt.plot(test_pos.x, test_pos.y, 'm.', ms=1)
plt.plot(actual_position.x, actual_position.y, 'b.')
plt.show()

In [None]:
plt.boxplot(errors)
plt.show()

In [None]:
from scipy.interpolate import InterpolatedUnivariateSpline

In [None]:
x = np.linspace(-np.pi, np.pi, 10)
y = np.sin(x)

In [None]:
plt.plot(x, y)
plt.show()

In [None]:
spline = InterpolatedUnivariateSpline(x, y)
xs = np.linspace(-np.pi, np.pi, 100)
plt.plot(xs, spline(xs), 'g')
plt.show()

In [None]:
spline.get_residual()

## Smoothing parameter check r063d3

In [None]:
results = [45.848643426859908, 29.302109714304461, 29.865521819738152, 29.913643399713909, 
           29.31095963377842, 28.624072044945866, 29.030549355726453, 30.077038030159233, 
           35.687684936555392, 45.848643426859908, 29.302109714304461, 29.865521819738152, 
           29.913643399713909, 29.31095963377842, 28.624072044945866, 29.030549355726453, 
           30.077038030159233, 35.687684936555392, 45.837111423306133, 29.298067119359189, 
           29.865976461880123, 29.906030473851345, 29.315789159414265, 28.624741085514444, 
           29.038110154705642, 30.079648033878346, 35.671749554313877, 31.61798486482013, 
           28.246634578841597, 29.647769711117515, 29.759165760476954, 29.411293543594194, 
           29.032704072802389, 29.042793425889936, 30.065880405200712, 35.645934498839758, 
           35.53800513968023, 28.658373941308831, 29.9207542421798, 29.708641610463751, 
           29.15494984997331, 28.578793132568546, 27.936605145242357, 29.258432270602221, 
           36.575977104419799, 36.941733764866761, 28.565957749459606, 31.102959353810512, 
           31.652782412930737, 31.239089639161165, 30.370535330601069, 30.036002190935022, 
           30.47971725975156, 34.304025317596022, 42.40508424723307, 36.370085788326456, 
           35.665673178845203, 35.775161545441925, 35.072970980897509, 33.78477541516596, 
           32.971823463940467, 32.494082917186063, 36.476158353888927, 45.262874125229665, 
           39.137571198808182, 38.514080294292121, 38.145109980432188, 37.304017745509341, 
           36.004485787827981, 34.679346348606181, 34.496813560336243, 39.817394108676588, 
           46.179512100491614, 43.733904525349814, 43.343444851581296, 42.579898778998576, 
           41.358252936093763, 39.791099877994974, 37.98070494210117, 36.890839324932948, 
           40.813685068781943, 49.4741802086211, 48.591570003593411, 47.833202625053076, 
           46.995084807712587, 46.078177190708551, 44.893523028409724, 43.235332282122172, 
           41.63670146397692, 42.87855295235785]

In [None]:
inputs = [[None, None], [None, 0.025], [None, 0.05], [None, 0.1], [None, 0.2], [None, 0.5], 
          [None, 1.0], [None, 2.0], [None, 5.0], [0.1, None], [0.1, 0.025], [0.1, 0.05], 
          [0.1, 0.1], [0.1, 0.2], [0.1, 0.5], [0.1, 1.0], [0.1, 2.0], [0.1, 5.0], [0.25, None], 
          [0.25, 0.025], [0.25, 0.05], [0.25, 0.1], [0.25, 0.2], [0.25, 0.5], [0.25, 1.0], 
          [0.25, 2.0], [0.25, 5.0], [0.5, None], [0.5, 0.025], [0.5, 0.05], [0.5, 0.1], [0.5, 0.2], 
          [0.5, 0.5], [0.5, 1.0], [0.5, 2.0], [0.5, 5.0], [1.0, None], [1.0, 0.025], [1.0, 0.05], 
          [1.0, 0.1], [1.0, 0.2], [1.0, 0.5], [1.0, 1.0], [1.0, 2.0], [1.0, 5.0], [3.0, None], 
          [3.0, 0.025], [3.0, 0.05], [3.0, 0.1], [3.0, 0.2], [3.0, 0.5], [3.0, 1.0], [3.0, 2.0], 
          [3.0, 5.0], [5.0, None], [5.0, 0.025], [5.0, 0.05], [5.0, 0.1], [5.0, 0.2], [5.0, 0.5], 
          [5.0, 1.0], [5.0, 2.0], [5.0, 5.0], [7.5, None], [7.5, 0.025], [7.5, 0.05], [7.5, 0.1], 
          [7.5, 0.2], [7.5, 0.5], [7.5, 1.0], [7.5, 2.0], [7.5, 5.0], [10.0, None], [10.0, 0.025], 
          [10.0, 0.05], [10.0, 0.1], [10.0, 0.2], [10.0, 0.5], [10.0, 1.0], [10.0, 2.0], [10.0, 5.0], 
          [15.0, None], [15.0, 0.025], [15.0, 0.05], [15.0, 0.1], [15.0, 0.2], [15.0, 0.5], 
          [15.0, 1.0], [15.0, 2.0], [15.0, 5.0]]

In [None]:
results = np.array(results)
inputs = np.array(inputs)

In [None]:
np.where(results == min(results))

In [None]:
val = 46
inputs[val], results[val]

In [None]:
val = 37
inputs[val], results[val]

In [None]:
val = 28
inputs[val], results[val]

In [None]:
inputs[42]

In [None]:
inputs[10]

In [None]:
results[10]

## Scratch checks: index splitting, speed, zones, and trajectory fields

In [None]:
pos = np.random.rand(20, 2)

In [None]:
pos

In [None]:
split_idx = np.array([4, 6, 10])

In [None]:
all_idx = [idx for idx in np.split(np.arange(pos.shape[0]), split_idx) if idx.size > 3]

In [None]:
pos[np.hstack(all_idx)]

In [None]:
# One full sine period sampled at 201 points. The original wrapped the sine in
# np.hstack((...)), but the inner parentheses do not make a tuple, so the stack
# was a no-op on an already 1-D array.
time = np.linspace(0, np.pi*2, 201)
data = np.sin(time)

In [None]:
plt.plot(time, data, '.')
plt.show()

In [None]:
position = vdm.Position(data, time)

In [None]:
speed = position.speed()

In [None]:
plt.plot(speed.time, speed.data)
plt.show()

In [None]:
run_idx = np.squeeze(speed.data) >= 0.5

In [None]:
run_idx

In [None]:
position = vdm.Position(data, time)
speed = position.speed()
run_idx = np.squeeze(speed.data) >= 0.7
run_position = position[run_idx]

len(run_position.x)

In [None]:
assert np.allclose(len(run_position.x), 100)

In [None]:
def position_speed(pos, t_smooth=None):
    """Reference copy of a speed computation, wrapped so the cell is valid Python.

    The original cell was a bare method body (it used ``self`` and a top-level
    ``return``), which cannot be executed on its own.

    Parameters
    ----------
    pos : object
        Must support slicing, ``distance(other)`` and a ``time`` attribute.
    t_smooth : float or None
        Width of the moving-average window, in the same units as ``pos.time``.
        No smoothing is applied when None.

    Returns
    -------
    AnalogSignal built from the speed values and ``pos.time``.

    """
    # Distance between consecutive samples divided by the elapsed time.
    velocity = pos[1:].distance(pos[:-1])
    velocity /= np.diff(pos.time)
    # Leading zero so there is one speed value per time stamp.
    velocity = np.hstack(([0], velocity))

    if t_smooth is not None:
        # Window length in samples, from the median sampling interval.
        dt = np.median(np.diff(pos.time))
        filter_length = np.ceil(t_smooth / dt)
        velocity = np.convolve(velocity, np.ones(int(filter_length))/filter_length, 'same')

    # NOTE(review): AnalogSignal is not imported anywhere in this notebook --
    # presumably it is the vdmlab class; confirm and import before calling.
    return AnalogSignal(velocity, pos.time)

In [None]:
t_smooth=0.5
velocity = np.diff(np.squeeze(position.data))
velocity /= np.diff(position.time)
velocity = np.hstack(([0], velocity))

dt = np.median(np.diff(position.time))
filter_length = np.ceil(t_smooth / dt)
velocity = np.convolve(velocity, np.ones(int(filter_length))/filter_length, 'same')

In [None]:
velocity

In [None]:
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
a

In [None]:
np.hstack((a[0:2], a[6:8]))

In [None]:
indices = []
for t_start, t_stop in zip(np.array([0, 6]), np.array([2, 8])):
    indices.append((a >= t_start) & (a <= t_stop))
indices = np.any(np.column_stack((indices)),axis=1)
a[indices]

In [None]:
# `indices` may be either a list of per-window boolean masks or the already
# combined 1-D mask (the previous cell overwrites it). Stacking as rows and
# reducing over axis 0 gives the element-wise OR in both cases; the original
# column_stack/any(axis=1) collapsed a 1-D mask to a single value.
res = np.any(np.atleast_2d(indices), axis=0)
res

In [None]:
spikes = [vdm.SpikeTrain(np.array([0., 0., 1.])), vdm.SpikeTrain(np.array([3.6, 3.9]))]

In [None]:
one_line = LineString([[2, 0], [2, 2], [2, 4], [2, 6], [2, 8], [2, 10]])

one_start = Point([2, 2])
one_stop = Point([2, 8])

expand_by = 1

one_zone = vdm.expand_line(one_start, one_stop, one_line, expand_by)

In [None]:
this_idx = []
for pos_idx in range(len(position.time)):
    point = Point([position.x[pos_idx], position.y[pos_idx]])
    if one_zone.contains(point):
        this_idx.append(pos_idx)
        
this_pos = position[this_idx]
linear = this_pos.linearize(one_line, one_zone)

In [None]:
linear.x

In [None]:
plt.plot(position.x, position.y, 'g.', ms=10)
# plt.plot([2, 4, 2], [7, 5, 4], 'm.', ms=20)
plt.plot([2, 6], [4, 3], 'm.', ms=20)
plt.plot(one_zone.exterior.xy[0], one_zone.exterior.xy[1], 'b', lw=1)
plt.xlim(-1, 10)
plt.ylim(-1, 10)
plt.show()

In [None]:
def expand_line(start_pt, stop_pt, line, expand_by):
    """Buffer a line by `expand_by` and union it with its start and stop points.

    Parameters
    ----------
    start_pt, stop_pt : geometry objects exposing ``union``
    line : geometry object exposing ``buffer``
    expand_by : distance handed to ``line.buffer``

    Returns
    -------
    The result of ``start_pt.union(buffered_line).union(stop_pt)``.

    """
    buffered = line.buffer(expand_by)
    return start_pt.union(buffered).union(stop_pt)

def find_zones(info, expand_by=6):
    """Build the maze zones for a session from its trajectory definitions.

    Parameters
    ----------
    info : module or object
        Must provide ``u_trajectory``, ``shortcut_trajectory`` and
        ``novel_trajectory`` (sequences of points) and
        ``path_pts['pedestal']`` (an x, y pair).
    expand_by : number
        Buffer distance for each trajectory; the pedestal is buffered by
        ``expand_by * 2.2``.

    Returns
    -------
    zone : dict
        Keys u, shortcut, novel (expanded trajectories), ushort, unovel, uped,
        shortped, novelped (pairwise intersections) and pedestal.

    """
    def _trajectory_zone(trajectory):
        # Expanded line running from the first to the last trajectory point.
        first, last = Point(trajectory[0]), Point(trajectory[-1])
        return expand_line(first, last, LineString(trajectory), expand_by)

    pedestal_xy = info.path_pts['pedestal']
    pedestal = Point(pedestal_xy[0], pedestal_xy[1]).buffer(expand_by*2.2)

    u_zone = _trajectory_zone(info.u_trajectory)
    shortcut_zone = _trajectory_zone(info.shortcut_trajectory)
    novel_zone = _trajectory_zone(info.novel_trajectory)

    zone = {
        'u': u_zone,
        'shortcut': shortcut_zone,
        'novel': novel_zone,
        'ushort': u_zone.intersection(shortcut_zone),
        'unovel': u_zone.intersection(novel_zone),
        'uped': u_zone.intersection(pedestal),
        'shortped': shortcut_zone.intersection(pedestal),
        'novelped': novel_zone.intersection(pedestal),
        'pedestal': pedestal,
    }

    return zone

In [None]:
def trajectory_fields(tuning_curves, zone, xedges, yedges, field_thresh):
    """Group tuning curves by the maze zones in which they have a field.

    A neuron has a field at a spatial bin when its tuning-curve value there
    exceeds `field_thresh`. Its tuning curve is added once to every zone group
    that contains at least one of its field bin centers.

    Parameters
    ----------
    tuning_curves : iterable of np.ndarray
        One 2D array per neuron. Assumed to have y along rows and x along
        columns (consistent with how this notebook plots them against
        ``np.meshgrid(xedges, yedges)``) -- TODO confirm against
        vdm.tuning_curve_2d.
    zone : dict
        Must provide the keys u, ushort, unovel, shortcut, shortped, novel,
        novelped and pedestal, each exposing ``contains(point)``.
    xedges, yedges : array-like
        Bin edges along x and y.
    field_thresh : number
        Values strictly above this count as part of a field.

    Returns
    -------
    fields_tc : dict
        With u, shortcut, novel, pedestal as keys, each a list of the tuning
        curves (in input order) that have a field in that zone group.

    """
    xedges = np.asarray(xedges)
    yedges = np.asarray(yedges)
    xcenters = (xedges[1:] + xedges[:-1]) / 2.
    ycenters = (yedges[1:] + yedges[:-1]) / 2.

    # Bin centers as (x, y) rows, ordered like a row-major flatten of a
    # (len(ycenters), len(xcenters)) array. The previous version produced
    # (y, x) tuples and passed them to Point as (x, y), swapping the
    # coordinates; it also relied on itertools, which is not imported.
    grid_x, grid_y = np.meshgrid(xcenters, ycenters)
    tuning_points = np.column_stack((grid_x.ravel(), grid_y.ravel()))

    # (output label, zone keys that count towards it).
    zone_groups = (
        ('u', ('u', 'ushort', 'unovel')),
        ('shortcut', ('shortcut', 'shortped')),
        ('novel', ('novel', 'novelped')),
        ('pedestal', ('pedestal',)),
    )

    fields_tc = dict(u=[], shortcut=[], novel=[], pedestal=[])
    for neuron_tc in tuning_curves:
        field_idx = neuron_tc.flatten() > field_thresh
        matched = set()  # labels this neuron has already been added to
        for x, y in tuning_points[field_idx]:
            point = Point([x, y])
            for label, zone_keys in zone_groups:
                if label in matched:
                    continue
                if any(zone[key].contains(point) for key in zone_keys):
                    fields_tc[label].append(neuron_tc)
                    matched.add(label)
            if len(matched) == len(zone_groups):
                # Nothing left to learn from the remaining field bins.
                break

    return fields_tc

In [None]:
import sys
# sys.path.append('C:\\Users\\Emily\\Code\\emi_shortcut\\info')
sys.path.append('E:\\code\\emi_shortcut\\info')
import info.R063d3_info as r063d3
info = r063d3

In [None]:
pickle_filepath = 'E:\\code\\emi_shortcut\\cache\\pickled'
# pickle_filepath = 'C:\\Users\\Emily\\Code\\emi_shortcut\\cache\\pickled'

In [None]:
position = get_pos(info.pos_mat, info.pxl_to_cm)
spikes = get_spikes(info.spike_mat)

In [None]:
binsize = 3
xedges = np.arange(position.x.min(), position.x.max()+binsize, binsize)
yedges = np.arange(position.y.min(), position.y.max()+binsize, binsize)

speed = position.speed(t_smooth=0.5)
run_idx = np.squeeze(speed.data) >= info.run_threshold
run_pos = position[run_idx]

t_start = info.task_times['phase3'].start
t_stop = info.task_times['phase3'].stop

sliced_pos = run_pos.time_slice(t_start, t_stop)

sliced_spikes = [spiketrain.time_slice(t_start, t_stop) for spiketrain in spikes]

tuning_curves = vdm.tuning_curve_2d(sliced_pos, sliced_spikes, xedges, yedges, gaussian_sigma=0.2)

In [None]:
zones = find_zones(info)

In [None]:
type(zones['u'])

In [None]:
fields_tc = trajectory_fields(tuning_curves, zones, xedges, yedges, field_thresh=5)

In [None]:
print(len(fields_tc['novel']))

In [None]:
tuning_curves[5]

In [None]:
plt.figure()
xx, yy = np.meshgrid(xedges, yedges)
for tuning_curve in fields_tc['novel']:
    pp = plt.pcolormesh(xx, yy, tuning_curve, cmap='YlGn')
    plt.colorbar(pp)
    plt.axis('off')
    plt.show()