In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pickle
import os
import pandas as pd
import seaborn as sns
import vdmlab as vdm
import scipy.stats as stats
from collections import OrderedDict
from shapely.geometry import Point, LineString

from loading_data import get_data
from utils_plotting import plot_decoded_compare
from plot_decode import get_decoded, get_zone_proportion, get_combined, compare_rates

In [None]:
# Locations of the emi_shortcut checkout and of its pickle cache directory.
home = os.path.expanduser("~")
emi_shortcut, pickle_filepath = (
    os.path.join(home, "code", "emi_shortcut"),
    os.path.join(home, "code", "emi_shortcut", "cache", "pickled"),
)

In [None]:
import info.r066d1 as info

In [None]:
# Paths to the emi_shortcut checkout and its pickle cache.
# To use a checkout somewhere else (e.g. 'E:/code/emi_shortcut' on Windows),
# set the EMI_SHORTCUT_DIR environment variable instead of editing this cell.
home = os.path.expanduser("~")
emi_shortcut = os.environ.get("EMI_SHORTCUT_DIR", os.path.join(home, "code", "emi_shortcut"))
pickle_filepath = os.path.join(emi_shortcut, "cache", "pickled")

In [None]:
# NOTE(review): `use_all_tracks` is not referenced by any cell in this notebook;
# confirm whether something else still reads it, otherwise it can be removed.
use_all_tracks = False

In [None]:
from utils_maze import get_xyedges, speed_threshold

In [None]:
# Load this session's recordings and derive the spatial bin edges from position.
# `spikes` is iterated as one spike train per neuron in later cells.
events, position, spikes, lfp, lfp_theta = get_data(info)
xedges, yedges = get_xyedges(position)

## Get tuning curves

In [None]:
# Build 2D tuning curves from speed-thresholded position samples in phase 3 only.
# NOTE(review): this uses speed_threshold's default limit, while the decoding
# setup cell passes speed_limit=0.4 -- confirm the difference is intended.
run_pos = speed_threshold(position)

track_starts = [info.task_times['phase3'].start]
track_stops = [info.task_times['phase3'].stop]

position_tc = run_pos.time_slices(track_starts, track_stops)

# One sliced spike train per neuron, restricted to the same phase-3 window.
track_spikes = [spiketrain.time_slices(track_starts, track_stops) for spiketrain in spikes]

# Result is 3-D with neurons on the first axis (later cells index and reshape it that way).
tuning_curve = vdm.tuning_curve_2d(position_tc, track_spikes, xedges, yedges, gaussian_sigma=0.1)

In [None]:
# Dimensions of the freshly computed tuning curves (neurons first).
np.shape(tuning_curve)

In [None]:
# Indices of neurons whose summed tuning curve exceeds 3000 (very high rate).
tc_sums = tuning_curve.sum(axis=2).sum(axis=1)
np.nonzero(tc_sums > 3000)

In [None]:
# Keep only neurons whose summed tuning curve lies strictly between the two
# thresholds. This drops (near-)silent neurons and very high-rate neurons.
# `keep_neurons` is reused to filter `spikes` so both stay aligned.
# NOTE: this reassigns `tuning_curve`. Re-running the cell without re-running
# the tuning-curve cell filters an already-filtered array.
low_thresh = 1
high_thresh = 3000
tc_sums = np.sum(np.sum(tuning_curve, axis=2), axis=1)
keep_neurons = (tc_sums > low_thresh) & (tc_sums < high_thresh)
if not np.any(keep_neurons):
    raise ValueError("no neurons have summed tuning curves between low_thresh and high_thresh.")
tuning_curve = tuning_curve[keep_neurons]
tuning_curve.shape

In [None]:
# Spot-check up to the first 10 kept tuning curves. Capping at the number of
# neurons avoids an IndexError when fewer than 10 survive the filter.
xx, yy = np.meshgrid(xedges, yedges)

n_examples = min(10, len(tuning_curve))
for ii in range(n_examples):
    pp = plt.pcolormesh(yy, xx, tuning_curve[ii], cmap='pink_r')
    plt.colorbar(pp)
    plt.title('neuron {}'.format(ii))
    plt.axis('off')
    plt.show()

In [None]:
# Dimensions of the tuning curves after the firing-rate filter.
np.shape(tuning_curve)

## Decoding for phase 1

In [None]:
from analyze_decode import get_edges, point_in_zones
from utils_maze import find_zones

# Epoch to decode, and whether to shuffle neuron identities
# (shuffle_id is only read by commented-out code in the next cell).
experiment_time = 'phase1'
shuffle_id = False

In [None]:
# Epochs recorded on the track (speed-thresholded before decoding) vs. on the pedestal.
track_times = ['phase1', 'phase2', 'phase3', 'tracks']
pedestal_times = ['pauseA', 'pauseB', 'prerecord', 'postrecord']

# Keep the same neurons that were kept in `tuning_curve`, so spike trains and
# tuning-curve rows stay aligned. Filtering only while the lengths still match
# makes the cell safe to re-run. The comprehension works whether `spikes` is a
# list or an object array; the result is a list.
if len(spikes) == len(keep_neurons):
    spikes = [spiketrain for spiketrain, keep in zip(spikes, keep_neurons) if keep]
assert len(spikes) == np.count_nonzero(keep_neurons), \
    "spikes and tuning_curve are out of sync; re-run the tuning-curve cells."

if experiment_time in track_times:
    run_pos = speed_threshold(position, speed_limit=0.4)
else:
    run_pos = position

track_starts = [info.task_times[experiment_time].start]
track_stops = [info.task_times[experiment_time].stop]

track_pos = run_pos.time_slices(track_starts, track_stops)

# NOTE(review): if this is re-enabled, `random` is not imported, and
# `random.shuffle` on a NumPy array swaps row views and can duplicate rows.
# `np.random.shuffle(tuning_curve)` shuffles along the first axis correctly.
# if shuffle_id:
#     random.shuffle(tuning_curve)

In [None]:
# Number of spike trains after the keep_neurons filter; should equal tuning_curve.shape[0].
len(spikes)

In [None]:
# Occupancy over the whole session, binned on the same spatial grid as the
# tuning curves. histogram2d(y, x, bins=[yedges, xedges]) returns counts indexed
# [y bin, x bin], which matches the (yy, xx) corner arrays passed to pcolormesh.
# (Previously the counts used equal-width bins over the data range, sized from
# the edge counts in swapped order, so they did not line up with this grid.)
xx, yy = np.meshgrid(xedges, yedges)

histogram, ys, xs = np.histogram2d(position.y, position.x, bins=[yedges, xedges])

pp = plt.pcolormesh(yy, xx, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

In [None]:
# Occupancy during the decoding epoch (speed-thresholded for track epochs),
# binned on the tuning-curve grid. Counts are indexed [y bin, x bin] to match
# the (yy, xx) corner arrays given to pcolormesh.
xx, yy = np.meshgrid(xedges, yedges)

histogram, ys, xs = np.histogram2d(track_pos.y, track_pos.x, bins=[yedges, xedges])

pp = plt.pcolormesh(yy, xx, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

In [None]:
# Number of position samples in the decoding epoch (speed-thresholded for track epochs).
track_pos.n_samples

In [None]:
# Start and stop times of the epoch being decoded.
info.task_times[experiment_time].start, info.task_times[experiment_time].stop

In [None]:
# Restrict every (filtered) spike train to the epoch that is being decoded.
decode_spikes = [
    neuron_spikes.time_slice(info.task_times[experiment_time].start,
                             info.task_times[experiment_time].stop)
    for neuron_spikes in spikes
]

# Disabled alternatives, kept for reference:
# - for 'tracks', slice with the per-track epochs instead:
# if experiment_time == 'tracks':
#     decode_spikes = [spiketrain.time_slices(track_starts, track_stops) for spiketrain in spikes]
#     epochs_interest = vdm.Epoch(np.hstack([np.array(track_starts), np.array(track_stops)]))
# - otherwise, restrict decoding to events found by vdm.detect_swr_hilbert
#   (presumably sharp-wave ripples):
#     sliced_lfp = lfp.time_slice(info.task_times[experiment_time].start, info.task_times[experiment_time].stop)
#     z_thresh = 3.0
#     power_thresh = 5.0
#     merge_thresh = 0.02
#     min_length = 0.01
#     swrs = vdm.detect_swr_hilbert(sliced_lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
#                                   power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)

#     epochs_interest = vdm.find_multi_in_epochs(decode_spikes, swrs, min_involved=3)
#     if epochs_interest.n_epochs == 0:
#         epochs_interest = vdm.find_multi_in_epochs(decode_spikes, swrs, min_involved=1)

In [None]:
# Sanity check: number of spike times the first spike train has in the decoding epoch.
len(decode_spikes[0].time)

In [None]:
# Bin the epoch's spikes for decoding. counts_binsize is also the bin width passed
# to bayesian_prob below (presumably seconds, i.e. 25 ms -- TODO confirm units).
counts_binsize = 0.025
time_edges = get_edges(track_pos, counts_binsize, lastbin=True)
# Counts are smoothed with a Gaussian of std 0.005 (same time units as the bins).
counts = vdm.get_counts(decode_spikes, time_edges, gaussian_std=0.005)

In [None]:
# Dimensions of the binned spike-count matrix.
np.shape(counts)

In [None]:
# Heat map of the smoothed spike-count matrix used for decoding.
pp = plt.pcolormesh(counts, cmap='pink_r')
plt.colorbar(mappable=pp)
plt.gca().set_axis_off()
plt.show()

In [None]:
# Bayesian decoding of position from the binned spike counts.
tc_shape = tuning_curve.shape
# Spike trains and tuning-curve rows must describe the same neurons, in the same order.
assert len(decode_spikes) == tc_shape[0], \
    "decode_spikes and tuning_curve have different numbers of neurons."
print(counts.shape)
# Flatten each neuron's 2D tuning curve into a single row of spatial bins.
decoding_tc = tuning_curve.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

likelihood = vdm.bayesian_prob(counts, decoding_tc, counts_binsize)

xcenters = (xedges[1:] + xedges[:-1]) / 2.
ycenters = (yedges[1:] + yedges[:-1]) / 2.
xy_centers = vdm.cartesian(xcenters, ycenters)

time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

decoded = vdm.decode_location(likelihood, xy_centers, time_centers)
print(decoded.x.shape)
# Drop time bins without a usable estimate. A sample is unusable if EITHER
# coordinate is NaN (previously only samples with both coordinates NaN were dropped).
nan_idx = np.isnan(decoded.x) | np.isnan(decoded.y)
decoded = decoded[~nan_idx]
print(decoded.x.shape)

In [None]:
# Distribution of decoded positions, binned on the tuning-curve grid. Counts are
# indexed [y bin, x bin] to match the (yy, xx) corner arrays given to pcolormesh.
xx, yy = np.meshgrid(xedges, yedges)

histogram, ys, xs = np.histogram2d(decoded.y, decoded.x, bins=[yedges, xedges])

pp = plt.pcolormesh(yy, xx, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

In [None]:
# Nothing to clean up if decoding produced no samples; fail loudly instead.
if decoded.isempty:
    raise ValueError("decoded cannot be empty.")

# Keep only decoded sequences without implausible jumps, clipped to the decoding epoch.
sequences = vdm.remove_teleports(decoded, speed_thresh=40, min_length=3)
decoded_epochs = sequences.intersect(
    vdm.Epoch(info.task_times[experiment_time].start,
              info.task_times[experiment_time].stop))
decoded = decoded[decoded_epochs]

In [None]:
# Number of decoded samples left after teleport removal.
np.shape(decoded.x)

In [None]:
# Assign decoded samples to maze zones.
zones = find_zones(info, expand_by=8)
decoded_zones = point_in_zones(decoded, zones)

keys = ['u', 'shortcut', 'novel']
errors = {}
actual_position = {}
if experiment_time not in ['phase1', 'phase2', 'phase3', 'tracks']:
    # Off-track epochs: no error is computed, so every zone gets empty placeholders.
    for trajectory in decoded_zones:
        errors[trajectory] = []
        actual_position[trajectory] = []
else:
    # On-track epochs: interpolate the real position at each decoded timestamp
    # and measure the distance to the decoded position.
    for trajectory in keys:
        actual_x = np.interp(decoded_zones[trajectory].time, track_pos.time, track_pos.x)
        actual_y = np.interp(decoded_zones[trajectory].time, track_pos.time, track_pos.y)
        actual_position[trajectory] = vdm.Position(
            np.hstack((actual_x[..., np.newaxis], actual_y[..., np.newaxis])),
            decoded_zones[trajectory].time)
        errors[trajectory] = actual_position[trajectory].distance(decoded_zones[trajectory])

output = {
    'zones': decoded_zones,
    'errors': errors,
    'times': len(time_centers),
    'actual': actual_position,
    'decoded': decoded,
}

In [None]:
# Fraction of decoding time bins that survive NaN removal and teleport filtering.
output['decoded'].n_samples / output['times']

In [None]:
# Decoded positions that point_in_zones assigned to the 'u' zone.
plt.plot(output['zones']['u'].x, output['zones']['u'].y,
         linestyle='none', marker='.', color='b', ms=2)
# plt.plot(output['zones']['other'].x, output['zones']['other'].y, 'r.', ms=4)
plt.show()

In [None]:
# Mean distance between actual and decoded position for samples in the 'u' zone.
# For off-track epochs errors['u'] is an empty list, so this is nan (with a warning).
np.mean(output['errors']['u'])

In [None]:
# Distribution of decoded positions in the 'u' zone, binned on the tuning-curve
# grid. Counts are indexed [y bin, x bin] to match the (yy, xx) corner arrays
# given to pcolormesh.
pos = output['zones']['u']

xx, yy = np.meshgrid(xedges, yedges)

histogram, ys, xs = np.histogram2d(pos.y, pos.x, bins=[yedges, xedges])

pp = plt.pcolormesh(yy, xx, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

# Decoding with the script's `analyze` function (`analyze_decode.analyze`)

In [None]:
import info.r066d1 as info

In [None]:
# Load this session's cached tuning curves.
# NOTE: pickle.load can execute arbitrary code; only open cache files this project wrote.
tuning_curve_filename = ''.join([info.session_id, '_tuning-curve.pkl'])
pickled_tuning_curve = os.path.join(pickle_filepath, tuning_curve_filename)
with open(pickled_tuning_curve, mode='rb') as fileobj:
    tuning_curve = pickle.load(fileobj)

In [None]:
# Dimensions of the tuning curves loaded from the pickle cache.
np.shape(tuning_curve)

In [None]:
from analyze_decode import analyze

# Decode the post-record epoch with the cached tuning curves, without shuffling.
experiment_time = 'postrecord'
dec = analyze(info,
              tuning_curve,
              experiment_time=experiment_time,
              shuffle_id=False)

In [None]:
# Decoded sample count alongside the number of time bins reported by `analyze`
# (presumably the same 'times' convention as the notebook's own output dict).
dec['decoded'].n_samples, dec['times']

# Loading pickled decoding results

In [None]:
import info.r066d1 as info

In [None]:
# Load the cached decoding results for the post-record epoch.
# NOTE: pickle.load can execute arbitrary code; only open cache files this project wrote.
experiment_time = 'postrecord'
decode_filename = ''.join([info.session_id, '_decode-', experiment_time, '.pkl'])
pickled_decode = os.path.join(pickle_filepath, decode_filename)
with open(pickled_decode, mode='rb') as fileobj:
    decode = pickle.load(fileobj)

In [None]:
# Fraction of time bins with a decoded position in the cached decode
# (assumes the pickle uses the same 'decoded'/'times' keys as the output dict above).
decode['decoded'].n_samples / decode['times']