In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import os
import numpy as np
import random
import pickle
from shapely.geometry import Point, LineString

import vdmlab as vdm

from load_data import get_pos, get_spikes, get_lfp
from analyze_maze import find_zones
from analyze_decode import get_edges, point_in_zones

In [2]:
# Folders for cached pickles (read) and decoding figures (written). The defaults are the
# original author's machine; set EMI_SHORTCUT_PICKLES / EMI_SHORTCUT_PLOTS to run elsewhere.
pickle_filepath = os.environ.get('EMI_SHORTCUT_PICKLES', 'E:/code/emi_shortcut/cache/pickled')
output_filepath = os.environ.get('EMI_SHORTCUT_PLOTS', 'E:/code/emi_shortcut/plots/decode/')

In [3]:
import info.r063d3 as r063d3

In [4]:
info = r063d3

# Analysis switches read by the decoding cells below.
shuffle_id = False  # True: shuffle the tuning curves across neurons as a control
experiment_time = 'tracks'  # 'tracks', or a key of info.task_times (SWR-restricted decoding)

# Load this session's cached tuning curves.
# NOTE: pickle.load can execute arbitrary code; only load caches this project produced.
tuning_curve_filename = '{}_tuning-curve.pkl'.format(info.session_id)
pickled_tuning_curve = os.path.join(pickle_filepath, tuning_curve_filename)
with open(pickled_tuning_curve, mode='rb') as fileobj:
    tuning_curve = pickle.load(fileobj)
    


In [15]:


print('decoding:', info.session_id)
position = get_pos(info.pos_mat, info.pxl_to_cm)
spikes = get_spikes(info.spike_mat)
lfp = get_lfp(info.good_swr[0])

# Keep only samples where the animal is running. The threshold is in the units returned by
# position.speed() (presumably cm/s, since positions are scaled by pxl_to_cm -- TODO confirm).
run_speed_thresh = 0.1
speed = position.speed(t_smooth=0.5)
run_idx = np.squeeze(speed.data) >= run_speed_thresh
run_pos = position[run_idx]

# Task phases whose running positions are decoded (earlier versions of this analysis
# also used 'phase1' and 'phase2').
track_phases = ['phase3']
track_starts = [info.task_times[phase].start for phase in track_phases]
track_stops = [info.task_times[phase].stop for phase in track_phases]

track_pos = run_pos.time_slices(track_starts, track_stops)

if shuffle_id:
    # Control analysis: break the pairing between neurons and their tuning curves.
    # Index with a permutation instead of random.shuffle. random.shuffle swaps elements via
    # views, which duplicates and loses rows when tuning_curve is a multi-dimensional ndarray.
    # It would also mutate the object loaded from the pickle in place.
    # TODO: no seed is set, so shuffled results are not reproducible run to run.
    shuffled_order = np.random.permutation(len(tuning_curve))
    tuning_curve = [tuning_curve[idx] for idx in shuffled_order]

if experiment_time == 'tracks':
    # Decode the track phases: restrict each spike train to them and use the phases
    # themselves as the epochs of interest.
    decode_spikes = [spiketrain.time_slices(track_starts, track_stops) for spiketrain in spikes]
    # One (start, stop) row per phase. np.hstack would flatten several phases into
    # [start1, start2, ..., stop1, stop2, ...]; column_stack keeps each pair together.
    epochs_interest = vdm.Epoch(np.column_stack((track_starts, track_stops)))

else:
    # Decode one task period (a key of info.task_times), restricted to ripple events.
    experiment_epoch = info.task_times[experiment_time]
    decode_spikes = [spiketrain.time_slice(experiment_epoch.start, experiment_epoch.stop)
                     for spiketrain in spikes]
    sliced_lfp = lfp.time_slice(experiment_epoch.start, experiment_epoch.stop)
    z_thresh = 3.0
    power_thresh = 5.0
    merge_thresh = 0.02
    min_length = 0.01
    # thresh is presumably the ripple frequency band in Hz -- confirm against vdm docs.
    swrs = vdm.detect_swr_hilbert(sliced_lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                  power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)

    # Presumably keeps only SWRs with spikes from at least min_involved neurons -- confirm.
    epochs_interest = vdm.find_multi_in_epochs(decode_spikes, swrs, min_involved=3)

# Spike counts per time bin for every neuron in decode_spikes.
counts_binsize = 0.025
# NOTE(review): the time grid spans run_pos (all running samples), not only the decoded
# window; bins without spikes presumably decode to NaN and are dropped below -- confirm.
time_edges = get_edges(run_pos, counts_binsize, lastbin=True)
counts = vdm.get_counts(decode_spikes, time_edges, gaussian_std=0.025)

# One flattened spatial tuning curve per row. This assumes tuning_curve is in the same
# neuron order as spikes -- TODO confirm.
decoding_tc = np.array([np.ravel(tuning) for tuning in tuning_curve])

likelihood = vdm.bayesian_prob(counts, decoding_tc, counts_binsize)

# Spatial bin grid over the running positions on the track. This presumably must match the
# binning used to build tuning_curve, so that len(xy_centers) equals the flattened
# tuning-curve length -- confirm.
binsize = 3
xedges = np.arange(track_pos.x.min(), track_pos.x.max() + binsize, binsize)
yedges = np.arange(track_pos.y.min(), track_pos.y.max() + binsize, binsize)

xcenters = (xedges[1:] + xedges[:-1]) / 2.
ycenters = (yedges[1:] + yedges[:-1]) / 2.
xy_centers = vdm.cartesian(xcenters, ycenters)

time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

decoded = vdm.decode_location(likelihood, xy_centers, time_centers)
# Drop time bins without a usable estimate: a sample is invalid if EITHER coordinate is NaN.
nan_idx = np.logical_or(np.isnan(decoded.x), np.isnan(decoded.y))
decoded = decoded[~nan_idx]


decoding: R063d3


In [16]:
# Break the decoded trajectory into sequences with vdm.remove_teleports
# (speed_thresh=10, min_length=3), then keep those related to the epochs of interest.
if decoded.isempty:
    # Fail loudly. Without this check, `sequences` below would be undefined on a fresh
    # kernel (NameError) or silently left over from an earlier run of this notebook.
    raise ValueError('no positions were decoded for {}'.format(info.session_id))
sequences = vdm.remove_teleports(decoded, speed_thresh=10, min_length=3)

decoded_epochs = sequences.contains(epochs_interest)

# Presumably the decoded samples that fall within the kept sequence epochs -- confirm.
sequence_decoded = vdm.epoch_position(decoded, decoded_epochs)

In [23]:
# Number of sequence epochs returned by vdm.remove_teleports.
sequences.n_epochs

6503

In [24]:
# Number of samples in sequence_decoded (decoded positions restricted to decoded_epochs).
len(sequence_decoded.time)

54630

In [19]:
# Sequence epochs kept by sequences.contains(epochs_interest); compare with sequences.n_epochs.
decoded_epochs.n_epochs

6503

In [None]:
# Assign decoded positions to maze zones, then (for 'tracks' decoding) measure how far each
# decoded sample is from the animal's actual position at that time.
zones = find_zones(info, expand_by=7)
# NOTE(review): this uses every decoded sample, not `sequence_decoded` from the sequence
# cell above -- confirm that is intended.
decoded_zones = point_in_zones(decoded, zones)

keys = ['u', 'shortcut', 'novel']
errors = {}
actual_position = {}
if experiment_time != 'tracks':
    # Non-'tracks' decoding: no actual position is computed, so every zone gets empty entries.
    for trajectory in decoded_zones:
        errors[trajectory] = []
        actual_position[trajectory] = []
else:
    for trajectory in keys:
        # Actual position at each decoded time, linearly interpolated from track_pos.
        # NOTE(review): np.interp holds the first/last value for times outside track_pos
        # and interpolates straight across gaps between running periods.
        actual_x = np.interp(decoded_zones[trajectory].time, track_pos.time, track_pos.x)
        actual_y = np.interp(decoded_zones[trajectory].time, track_pos.time, track_pos.y)
        actual_position[trajectory] = vdm.Position(np.hstack((actual_x[..., np.newaxis],
                                                              actual_y[..., np.newaxis])),
                                                   decoded_zones[trajectory].time)
        errors[trajectory] = actual_position[trajectory].distance(decoded_zones[trajectory])

output = {
    'zones': decoded_zones,
    'errors': errors,
    'times': len(time_centers),
    'actual': actual_position,
    'decoded': decoded,
}

In [None]:
# Mean decoding error for the 'u' zone. For non-'tracks' runs errors['u'] is an empty list,
# so np.mean returns nan (with a RuntimeWarning).
np.mean(output['errors']['u'])

In [None]:
# Decoded vs. actual y position over a hand-picked time window. The "actual" trace is the
# interpolated position at the decoded samples assigned to the 'u' zone.
fig, ax = plt.subplots()
ax.plot(decoded.time, decoded.y, label='decoded')
ax.plot(actual_position['u'].time, actual_position['u'].y, lw=4, label="actual ('u' zone)")
ax.set_xlim(6530, 6730)
ax.set_xlabel('time')
ax.set_ylabel('y position')
ax.legend()
# Must be called: the bare name `plt.show` only displays the function object.
plt.show()