In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pickle
import os
import pandas as pd
from collections import OrderedDict
from shapely.geometry import Point, LineString

import vdmlab as vdm

from loading_data import get_data
from utils_maze import get_xyedges, find_zones, speed_threshold
from analyze_decode import get_edges, point_in_zones

In [None]:
# Project root; set the EMI_SHORTCUT_ROOT environment variable to run this notebook
# on another machine instead of editing a hard-coded, machine-specific path.
project_root = os.environ.get('EMI_SHORTCUT_ROOT', 'E:/code/emi_shortcut')
output_filepath = project_root + '/plots/intermediate'
pickle_filepath = project_root + '/cache/pickled'

In [None]:
import info.r066d1 as info

In [None]:
# Load this session's cached `neurons` object from the local pickle cache.
# NOTE(review): pickle.load executes arbitrary code; only use trusted cache files.
neurons_filename = '{}_neurons.pkl'.format(info.session_id)
pickled_neurons = os.path.join(pickle_filepath, neurons_filename)
with open(pickled_neurons, 'rb') as fileobj:
    neurons = pickle.load(fileobj)

# Analysis parameters used by the cells below.
experiment_time = 'phase3'  # key into info.task_times
speed_limit = 0.4           # passed to speed_threshold on track phases
shuffle_id = False          # when True, tuning curves are shuffled (presumably a control)
min_length = 3              # passed as min_length to vdm.remove_teleports

In [None]:
print('decoding:', info.session_id)

# Task phases grouped by recording context; later cells branch on membership
# in these lists (speed thresholding and epoch selection).
track_times = [
    'phase1', 'phase2', 'phase3', 'tracks',
]
pedestal_times = [
    'pauseA', 'pauseB', 'prerecord', 'postrecord',
]

# Raw session data, then spatial bin edges derived from the position samples.
(events, position, spikes,
 lfp, lfp_theta) = get_data(info)
xedges, yedges = vdm.get_xyedges(position)

In [None]:
# Restrict the neurons to the start/stop window of the chosen task phase.
task_phase = info.task_times[experiment_time]
exp_start, exp_stop = task_phase.start, task_phase.stop

decode_spikes = neurons.time_slice(exp_start, exp_stop)

In [None]:
# Inspect the tuning curve(s) of selected neurons. The slice keeps the original
# behaviour of silently plotting nothing if the index is out of range, and
# `enumerate(..., start=...)` labels each plot with the real neuron index
# (previously it always printed 0).
tc_start, tc_stop = 123, 124
xx, yy = np.meshgrid(xedges, yedges)

for i, tc in enumerate(neurons.tuning_curves[tc_start:tc_stop], start=tc_start):
    print(i)
    pp = plt.pcolormesh(xx, yy, tc, cmap='pink_r')
    plt.colorbar(pp)
    plt.title('neuron {}'.format(i))
    plt.axis('off')
    plt.show()

In [None]:
# Population coverage: element-wise sum of every neuron's tuning curve.
all_tuning_curves = np.zeros(neurons.tuning_shape)
for neuron_idx in range(neurons.n_neurons):
    all_tuning_curves += neurons.tuning_curves[neuron_idx]

fig, ax = plt.subplots()
pp = ax.pcolormesh(xx, yy, all_tuning_curves, cmap='pink_r')
fig.colorbar(pp, ax=ax)
ax.set_axis_off()
plt.show()

In [None]:
# NOTE(review): this recomputes the same window as `decode_spikes` and rebinds
# `spikes` (previously the raw spikes returned by get_data) to the sliced neurons.
phase_window = info.task_times[experiment_time]
exp_start = phase_window.start
exp_stop = phase_window.stop

spikes = neurons.time_slice(exp_start, exp_stop)

In [None]:
# Track phases: apply speed_threshold (presumably keeping only running samples);
# any other phase keeps all position samples. Then restrict to the phase window.
run_position = (speed_threshold(position, speed_limit=speed_limit)
                if experiment_time in track_times
                else position)

exp_position = run_position.time_slice(exp_start, exp_stop)

In [None]:
# Occupancy during the phase, binned on the same spatial grid as the tuning
# curves (rows = y bins from yedges, columns = x bins from xedges) so this plot
# lines up with them. Passing the edges explicitly (rather than edge *counts*
# as bin counts) keeps the bins aligned to the maze grid; samples outside the
# edges are dropped by histogram2d.
xx, yy = np.meshgrid(xedges, yedges)
histogram, ys, xs = np.histogram2d(exp_position.y, exp_position.x,
                                   bins=[yedges, xedges])

pp = plt.pcolormesh(xx, yy, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

In [None]:
# Optional control: permute which neuron each tuning curve belongs to before decoding.
# numpy's shuffle permutes along the first axis correctly; the stdlib
# random.shuffle swaps views of a multi-dimensional array and duplicates rows.
# A fixed seed keeps the shuffled result reproducible across runs.
# NOTE(review): shuffles neurons.tuning_curves in place; assumes it is a mutable
# numpy array with neurons on the first axis — confirm.
if shuffle_id:
    shuffle_rng = np.random.RandomState(0)
    shuffle_rng.shuffle(neurons.tuning_curves)

In [None]:
# Choose the time epochs whose decoded positions are kept.
# - Track phases: the whole phase window.
# - Pedestal phases: events from vdm.detect_swr_hilbert (presumably sharp-wave
#   ripples) in which at least 4 neurons are involved, falling back to at least 1
#   if no event qualifies.
if experiment_time in track_times:
    epochs_interest = vdm.Epoch(np.hstack([exp_start, exp_stop]))
elif experiment_time in pedestal_times:
    sliced_lfp = lfp.time_slice(exp_start, exp_stop)

    z_thresh = 3.0
    power_thresh = 5.0
    merge_thresh = 0.02
    # Separate name so the notebook-level `min_length` (the remove_teleports
    # parameter) is not silently overwritten by this SWR setting.
    swr_min_length = 0.01
    swrs = vdm.detect_swr_hilbert(sliced_lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
                                  power_thresh=power_thresh, merge_thresh=merge_thresh,
                                  min_length=swr_min_length)

    epochs_interest = vdm.find_multi_in_epochs(decode_spikes, swrs, min_involved=4)
    if epochs_interest.n_epochs == 0:
        epochs_interest = vdm.find_multi_in_epochs(decode_spikes, swrs, min_involved=1)
else:
    raise ValueError("unrecognized experimental phase {!r}. Must be in {}."
                     .format(experiment_time, track_times + pedestal_times))

In [None]:
# Bin the phase-restricted spikes on time edges spanning the position samples
# (bin width 0.025, presumably seconds), smoothed with a Gaussian of std 0.005.
counts_binsize = 0.025
counts_smoothing_std = 0.005

time_edges = get_edges(exp_position, counts_binsize, lastbin=True)
counts = vdm.get_counts(decode_spikes, time_edges, gaussian_std=counts_smoothing_std)

In [None]:
# Quick visual check of the binned spike counts.
fig, ax = plt.subplots()
pp = ax.pcolormesh(counts, cmap='pink_r')
fig.colorbar(pp, ax=ax)
ax.set_axis_off()
plt.show()

In [None]:
# Flatten each tuning curve (axes 1 and 2) into a single row, giving one row per
# entry on axis 0, as expected by vdm.bayesian_prob.
tc_shape = neurons.tuning_curves.shape
n_spatial_bins = tc_shape[1] * tc_shape[2]
decoding_tc = neurons.tuning_curves.reshape(tc_shape[0], n_spatial_bins)

likelihood = vdm.bayesian_prob(counts, decoding_tc, counts_binsize)

In [None]:
# Spatial bin centres (midpoints of consecutive edges) and their Cartesian product.
xcenters = (xedges[1:] + xedges[:-1]) / 2.
ycenters = (yedges[1:] + yedges[:-1]) / 2.
# NOTE(review): confirm vdm.cartesian orders points to match the row-major
# (y rows, x columns) flattening of the tuning curves used for decoding.
xy_centers = vdm.cartesian(xcenters, ycenters)

time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

decoded = vdm.decode_location(likelihood, xy_centers, time_centers)
# Drop samples where *either* coordinate is NaN; requiring both to be NaN
# would keep half-defined positions.
nan_idx = np.logical_or(np.isnan(decoded.x), np.isnan(decoded.y))
decoded = decoded[~nan_idx]

In [None]:
# Histogram of decoded positions on the tuning-curve grid (rows = y bins from
# yedges, columns = x bins from xedges), plotted in the same orientation as the
# tuning curves. Out-of-grid samples are dropped by histogram2d.
xx, yy = np.meshgrid(xedges, yedges)
histogram, ys, xs = np.histogram2d(decoded.y, decoded.x, bins=[yedges, xedges])

pp = plt.pcolormesh(xx, yy, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

In [None]:
# Guard: nothing to filter if decoding produced no samples.
if decoded.isempty:
    raise ValueError("decoded cannot be empty.")

# Keep decoded samples that survive vdm.remove_teleports and fall inside the
# epochs of interest.
sequences = vdm.remove_teleports(decoded, speed_thresh=40, min_length=min_length)
decoded_epochs = sequences.intersect(epochs_interest)
decoded = decoded[decoded_epochs]

In [None]:
# Maze zones (find_zones with remove_feeder/expand_by options), then the decoded
# samples assigned to each zone by point_in_zones.
zone_expansion = 8
zones = find_zones(info, remove_feeder=True, expand_by=zone_expansion)
decoded_zones = point_in_zones(decoded, zones)

In [None]:
# For track phases, interpolate the animal's actual position at each decoded
# sample time and measure the decoded-vs-actual distance per trajectory zone.
# For any other phase, every zone gets empty placeholders.
keys = ['u', 'shortcut', 'novel']
errors = dict()
actual_position = dict()
if experiment_time in track_times:
    for trajectory in keys:
        decoded_times = decoded_zones[trajectory].time
        actual_x = np.interp(decoded_times, exp_position.time, exp_position.x)
        actual_y = np.interp(decoded_times, exp_position.time, exp_position.y)
        actual_position[trajectory] = vdm.Position(np.column_stack((actual_x, actual_y)),
                                                   decoded_times)

        if actual_position[trajectory].n_samples > 0:
            errors[trajectory] = actual_position[trajectory].distance(decoded_zones[trajectory])
        else:
            # Keep every key present so per-zone summaries don't raise KeyError.
            errors[trajectory] = []
else:
    for trajectory in decoded_zones:
        errors[trajectory] = []
        actual_position[trajectory] = []

output = dict()
output['zones'] = decoded_zones
output['errors'] = errors
output['times'] = len(time_centers)
output['actual'] = actual_position
output['decoded'] = decoded

In [None]:
# Histogram of the filtered decoded positions on the tuning-curve grid
# (rows = y bins, columns = x bins), oriented like the tuning-curve plots.
xx, yy = np.meshgrid(xedges, yedges)
histogram, ys, xs = np.histogram2d(output['decoded'].y, output['decoded'].x,
                                   bins=[yedges, xedges])

pp = plt.pcolormesh(xx, yy, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

In [None]:
# Mean decoding error per trajectory zone. A zone that is missing or has no
# errors prints nan rather than raising KeyError / an empty-mean warning.
for trajectory in ['u', 'shortcut', 'novel']:
    trajectory_errors = output['errors'].get(trajectory, [])
    mean_error = np.mean(trajectory_errors) if np.size(trajectory_errors) > 0 else np.nan
    print(trajectory, 'error mean', mean_error)

In [None]:
# Load a previously cached decoding result for this session and phase.
# NOTE(review): pickle.load executes arbitrary code; only load trusted cache files.
filename = '_decode-{}.pkl'.format(experiment_time)
decode_filename = ''.join([info.session_id, filename])
pickled_decoded = os.path.join(pickle_filepath, decode_filename)

with open(pickled_decoded, 'rb') as fileobj:
    decode = pickle.load(fileobj)

In [None]:
# Histogram of the cached decoded positions on the tuning-curve grid
# (rows = y bins, columns = x bins), oriented like the tuning-curve plots.
xx, yy = np.meshgrid(xedges, yedges)
histogram, ys, xs = np.histogram2d(decode['decoded'].y, decode['decoded'].x,
                                   bins=[yedges, xedges])

pp = plt.pcolormesh(xx, yy, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

In [None]:
from utils_fields import categorize_fields
from analyze_tuning_curves import get_odd_firing_idx
from plot_sequence_raster import plot_sequence

In [None]:
# Group the neurons' tuning curves by zone via categorize_fields; the result is
# indexed as fields_tunings[zone][neuron] further on.
field_thresh = 1.0
fields_tunings = categorize_fields(
    neurons.tuning_curves,
    zones,
    xedges,
    yedges,
    field_thresh=field_thresh,
)

In [None]:
# The 'u' trajectory as a shapely LineString; field peaks are projected onto it
# to get their distance along the trajectory.
u_line = LineString(info.u_trajectory)

In [None]:
# For each neuron with a field on the 'u' trajectory, locate its tuning-curve peak
# and, if that peak lies inside the 'u' zone, record (distance along u_line, neuron).
u_dist = []
for neuron in fields_tunings['u']:
    field = fields_tunings['u'][neuron]
    # First maximum in row-major order (rows index y, columns index x), computed once.
    # nanargmax ignores NaN bins, where `field == field.max()` would match nothing.
    peak_row, peak_col = np.unravel_index(np.nanargmax(field), field.shape)
    # Distinct names so the notebook's xx/yy meshgrids are not overwritten.
    peak_x = xcenters[peak_col]
    peak_y = ycenters[peak_row]

    pt = Point(peak_x, peak_y)
    if zones['u'].contains(pt):
        u_dist.append((u_line.project(pt), neuron))

In [None]:
# Neuron ids ordered by the position of their field peak along the 'u' trajectory.
ordered_dist_u = sorted(u_dist, key=lambda dist_and_neuron: dist_and_neuron[0])
sort_idx = [neuron_id for _, neuron_id in ordered_dist_u]

In [None]:
# 'u' tuning curves and full spike trains, both in the sorted neuron order.
sort_spikes = [neurons.spikes[neuron_id] for neuron_id in sort_idx]
sort_tuning_curves = [fields_tunings['u'][neuron_id] for neuron_id in sort_idx]

In [None]:
# Positions (in sort_tuning_curves order) flagged by get_odd_firing_idx with
# max_mean_firing=2; these neurons are excluded from the ordered raster.
odd_firing_idx = get_odd_firing_idx(sort_tuning_curves, max_mean_firing=2)

In [None]:
# Display the flagged indices for inspection.
odd_firing_idx

In [None]:
# Drop the flagged neurons while keeping spikes and fields aligned by position.
ordered_spikes = [spike_train for idx, spike_train in enumerate(sort_spikes)
                  if idx not in odd_firing_idx]
ordered_fields = [sort_tuning_curves[idx] for idx in range(len(sort_spikes))
                  if idx not in odd_firing_idx]

In [None]:
# Raster of the ordered 'u' neurons over the first 'u' run sequence window.
u_run = info.sequence['u']['run']
plot_sequence(ordered_spikes, lfp, u_run.starts[0], u_run.stops[0])