In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pickle
import os
import pandas as pd
import seaborn as sns
import vdmlab as vdm
import scipy.stats as stats
from collections import OrderedDict

from loading_data import get_data
from utils_plotting import plot_decoded_compare
from analyze_decode import get_decoded_proportions, get_zone_proportion

In [None]:
# Location of cached pickles (decodes and tuning curves): ~/code/emi_shortcut/cache/pickled
home = os.path.expanduser("~")
emi_shortcut = os.path.join(home, "code", "emi_shortcut")
pickle_filepath = os.path.join(emi_shortcut, "cache", "pickled")

In [None]:
# import info

# spike_sorted_infos = [
#     info.r063d2, info.r063d3, info.r063d4, info.r063d5, info.r063d6, info.r063d7,
#     info.r066d1, info.r066d2, info.r066d3, info.r066d4, info.r066d5, info.r066d6, info.r066d7,
#     info.r067d1, info.r067d2, info.r067d3, info.r067d4, info.r067d5, info.r067d6, info.r067d7,
#     info.r068d1, info.r068d2, info.r068d3, info.r068d4, info.r068d5, info.r068d6, info.r068d7]

# infos = spike_sorted_infos

In [None]:
# Sessions pooled in the cells below (a subset of the spike-sorted sessions listed above).
import info.r063d2 as r063d2
import info.r063d3 as r063d3
infos = [r063d2, r063d3]

In [None]:
def get_decoded(info, experiment_times, pickle_filepath, f_combine):
    """Loads the pickled decode for each experiment time and summarizes it with f_combine.

    Parameters
    ----------
    info: module
        Session info; only info.session_id is used here.
    experiment_times: list of str
    pickle_filepath: str
        Directory holding '<session_id>_decode-<experiment_time>.pkl' files.
    f_combine: function
        Called as f_combine(decoded, experiment_time);
        either get_zone_proportion or get_errors.

    Returns
    -------
    decode_together: OrderedDict
        With experiment_time as keys (in the order given), each holding
        the value returned by f_combine for that experiment time.

    Raises
    ------
    ValueError
        If the pickled decode for any experiment time is missing.

    """
    decode_together = OrderedDict()

    for experiment_time in experiment_times:
        decode_filename = info.session_id + '_decode-' + experiment_time + '.pkl'
        pickled_decoded = os.path.join(pickle_filepath, decode_filename)

        if not os.path.isfile(pickled_decoded):
            # Include the full path so it is clear which phase's cache is missing.
            raise ValueError("pickled decoded not found for " + info.session_id +
                             " (" + pickled_decoded + ")")

        # pickle.load can execute arbitrary code: only load caches written by this project.
        with open(pickled_decoded, 'rb') as fileobj:
            decoded = pickle.load(fileobj)

        decode_together[experiment_time] = f_combine(decoded, experiment_time)

    return decode_together


def get_errors(decoded, experiment_time):
    """Computes the distance between decoded and actual position for each trajectory.

    Parameters
    ----------
    decoded: dict
        A pickled decode with 'actual' and 'zones' entries, each keyed by trajectory.
        'zones' is only read for track experiment times.
    experiment_time: str

    Returns
    -------
    decoded_error: dict
        Keyed like decoded['actual']. For track times ('phase1', 'phase2', 'phase3',
        'tracks') each value is decoded['zones'][key].distance(decoded['actual'][key]).
        For any other experiment time it is an empty list: there is no track position
        to compare against. A list (not 0) is used so that callers which pool errors
        with list.extend keep working.

    """
    # Same set of track times the decoding analysis uses when it computes errors.
    track_times = ['phase1', 'phase2', 'phase3', 'tracks']

    if experiment_time not in track_times:
        return {key: [] for key in decoded['actual']}

    return {key: decoded['zones'][key].distance(decoded['actual'][key])
            for key in decoded['actual']}

In [None]:
# Load the cached decodes of every session in `infos` and summarize them two ways:
# zone proportions and decoding errors, one OrderedDict per session keyed by experiment time.
# experiment_times = ['prerecord', 'phase1', 'pauseA', 'phase2', 'pauseB', 'phase3', 'postrecord']
experiment_times = ['phase1', 'phase2', 'phase3']
# experiment_times = ['pauseA', 'pauseB']

# NOTE(review): this loop leaves `info` bound to the last session in `infos`, and later
# cells may read `info` — be careful when running cells out of order.
decodes = []
errors = []
for info in infos:
    decodes.append(get_decoded(info, experiment_times, pickle_filepath, get_zone_proportion))
    errors.append(get_decoded(info, experiment_times, pickle_filepath, get_errors))

In [None]:
def combine_errors(errors):
    """Pools per-session decoding errors across sessions.

    Parameters
    ----------
    errors: list of dict
        One entry per session, as returned by get_decoded(..., get_errors):
        keyed by experiment_time, each a dict of lists keyed by trajectory.

    Returns
    -------
    combined: OrderedDict
        Keyed by the experiment times of the first session. Each value is a dict
        with u, shortcut and novel (always present), plus any other trajectory
        seen in the input. Each entry holds that trajectory's errors concatenated
        across sessions; 'together' holds all of them. Empty input gives an empty
        OrderedDict.

    """
    combined = OrderedDict()
    if not errors:
        return combined

    for experiment_time in errors[0]:
        pooled = dict(u=[], shortcut=[], novel=[], together=[])
        for session_errors in errors:
            for trajectory, values in session_errors[experiment_time].items():
                # setdefault so trajectories beyond u/shortcut/novel (e.g. 'other') don't KeyError.
                pooled.setdefault(trajectory, []).extend(values)
                pooled['together'].extend(values)
        combined[experiment_time] = pooled

    return combined

In [None]:
# NOTE(review): this rebinds `combine_errors` from the function to its result, so re-running
# this cell without first re-running the function definition raises TypeError.
combine_errors = combine_errors(errors)

In [None]:
# Mean decoding error during phase1, pooled over trajectories and sessions.
np.mean(combine_errors['phase1']['together'])

In [None]:
# Load one session's phase1 decode to inspect its structure.
# The session is chosen explicitly (the last one in `infos`) rather than relying on the
# `info` left over from an earlier loop, so this cell does not depend on execution order.
experiment_time = 'phase1'
session_info = infos[-1]
filename = '_decode-' + experiment_time + '.pkl'
decode_filename = session_info.session_id + filename
pickled_decoded = os.path.join(pickle_filepath, decode_filename)

if not os.path.isfile(pickled_decoded):
    raise ValueError("pickled decoded not found for " + session_info.session_id)

# pickle.load can execute arbitrary code: only load caches written by this project.
with open(pickled_decoded, 'rb') as fileobj:
    decoded = pickle.load(fileobj)

In [None]:
# Which entries the pickled decode contains.
decoded.keys()

In [None]:
def combine_decode(infos, filename, experiment_time, shuffle_id, tuning_curves=None):
    """Loads (or generates) each session's decode and pools the results across sessions.

    Parameters
    ----------
    infos: list of modules
    filename: str
        Suffix appended to info.session_id to find the pickled decode in the
        module-level ``pickle_filepath``.
    experiment_time: str
        Passed to analyze when a decode has to be generated.
    shuffle_id: bool
        Passed to analyze when a decode has to be generated.
    tuning_curves: list or None
        One per info, in the same order. Required only when a pickled decode is missing.

    Returns
    -------
    output: dict
        'combined_decoded': dict of lists. Per-session decoded positions for u,
            shortcut, novel and other; 'together' holds the per-session number of
            decoded samples.
        'combined_errors': dict of lists. Errors for u, shortcut and novel
            (and all of them under 'together'), concatenated across every session.
        'total_times': list of decoded['times'], one per session.
        'combined_lengths': dict of lists. Per-session lengths of the u, shortcut
            and novel trajectories; 'other' and 'together' stay empty.

    Raises
    ------
    ValueError
        If a pickled decode is missing and tuning_curves is None.

    """
    trajectories = ['u', 'shortcut', 'novel']
    zone_names = trajectories + ['other']

    total_times = []
    combined_lengths = dict(u=[], shortcut=[], novel=[], other=[], together=[])
    combined_decoded = dict(u=[], shortcut=[], novel=[], other=[], together=[])
    # Created once, outside the session loop, so errors accumulate over all sessions.
    combined_errors = dict(u=[], shortcut=[], novel=[], together=[])

    for i, info in enumerate(infos):
        decode_filename = info.session_id + filename
        pickled_decoded = os.path.join(pickle_filepath, decode_filename)

        if os.path.isfile(pickled_decoded):
            with open(pickled_decoded, 'rb') as fileobj:
                decoded = pickle.load(fileobj)
        else:
            if tuning_curves is None:
                raise ValueError("tuning curves required when generating decoded")
            # analyze is otherwise only imported much further down the notebook.
            from analyze_decode import analyze
            decoded = analyze(info, tuning_curves[i], experiment_time=experiment_time, shuffle_id=shuffle_id)

        total_times.append(decoded['times'])

        # NOTE(review): LineString is not imported anywhere in this notebook (presumably
        # shapely.geometry.LineString) — import it before calling this function.
        combined_lengths['u'].append(LineString(info.u_trajectory).length)
        combined_lengths['shortcut'].append(LineString(info.shortcut_trajectory).length)
        combined_lengths['novel'].append(LineString(info.novel_trajectory).length)

        for zone in zone_names:
            combined_decoded[zone].append(decoded['zones'][zone])
        combined_decoded['together'].append(sum(len(decoded['zones'][zone].time) for zone in zone_names))

        for trajectory in trajectories:
            combined_errors[trajectory].extend(decoded['errors'][trajectory])
            combined_errors['together'].extend(decoded['errors'][trajectory])

    output = dict()
    output['combined_decoded'] = combined_decoded
    output['combined_errors'] = combined_errors
    output['total_times'] = total_times
    output['combined_lengths'] = combined_lengths

    return output

In [None]:
def plot_errors(infos, tuning_curves, by_trajectory, all_tracks_tc=False, experiment_time='phase3',
                output_filepath=None):
    """Boxplots decoding error against ID-shuffled decoding error, pooled across sessions.

    Parameters
    ----------
    infos: list of modules
    tuning_curves: list
        One per info; only used when a pickled decode has to be generated.
    by_trajectory: bool
        Passed to plot_decoded_errors: one box per trajectory, or pooled.
    all_tracks_tc: bool
        Only changes the output filename.
    experiment_time: str
        Experiment time passed to combine_decode (default 'phase3', as before).
    output_filepath: str or None
        Directory for the saved figure. If None, a module-level ``output_filepath``
        is used when one exists; otherwise the figure is shown instead of saved.

    """
    print('getting decoded', experiment_time)
    decoded = combine_decode(infos, '_decode-tracks.pkl', experiment_time=experiment_time,
                             shuffle_id=False, tuning_curves=tuning_curves)

    print('getting decoded', experiment_time, 'shuffled')
    decoded_shuffle = combine_decode(infos, '_decode-tracks-shuffled.pkl', experiment_time=experiment_time,
                                     shuffle_id=True, tuning_curves=tuning_curves)

    if all_tracks_tc and by_trajectory:
        filename = 'combined-errors_decoded_all-tracks_by-trajectory.png'
    elif all_tracks_tc and not by_trajectory:
        filename = 'combined-errors_decoded_all-tracks.png'
    elif not all_tracks_tc and by_trajectory:
        filename = 'combined-errors_decoded_by-trajectory.png'
    else:
        filename = 'combined-errors_decoded.pdf'

    if output_filepath is None:
        # Backward compatible with the previous reliance on a module-level output_filepath.
        output_filepath = globals().get('output_filepath')
    savepath = None if output_filepath is None else os.path.join(output_filepath, filename)

    # combine_decode pools errors by trajectory only, whereas plot_decoded_errors indexes them
    # by experiment_time first and takes experiment_time as its third positional argument.
    plot_decoded_errors({experiment_time: decoded['combined_errors']},
                        {experiment_time: decoded_shuffle['combined_errors']},
                        experiment_time, by_trajectory=by_trajectory, fliersize=2,
                        savepath=savepath)

In [None]:
def plot_decoded_errors(decode_errors, shuffled_errors, experiment_time, by_trajectory=False, fliersize=1, savepath=None):
    """Boxplots the distance between decoded and actual position for decoded and ID-shuffled decodes.

    Parameters
    ----------
    decode_errors: dict
        Keyed by experiment_time. Each value is a dict of lists with u, shortcut
        and novel as keys, plus together when by_trajectory is False.
    shuffled_errors: dict
        Same layout as decode_errors, for the ID-shuffled decode.
    experiment_time: str
        Which key of decode_errors / shuffled_errors to plot.
    by_trajectory: boolean
        If True, one shuffled and one decoded box per trajectory (u, shortcut, novel).
        Otherwise one box each for the pooled 'together' errors, and their mean and
        SEM are printed.
    fliersize: int
        Marker size of outlier points.
    savepath : str or None
        Location and filename for the saved plot. If None, the plot is shown.

    """
    decoded_errors = decode_errors[experiment_time]
    shuffle_errors = shuffled_errors[experiment_time]

    if by_trajectory:
        frames = []
        for trajectory in ['u', 'shortcut', 'novel']:
            frames.append(pd.DataFrame(dict(error=shuffle_errors[trajectory],
                                            shuffled='ID-shuffle decoded_' + trajectory)))
            frames.append(pd.DataFrame(dict(error=decoded_errors[trajectory],
                                            shuffled='Decoded_' + trajectory)))
        data = pd.concat(frames)
        # Named seaborn palette, handed to seaborn as a whole. Indexing this string per box
        # would yield 'c', 'o', 'l', ..., which are not a valid colour list.
        colours = 'colorblind'
    else:
        decoded = pd.DataFrame(dict(error=decoded_errors['together'], shuffled='Decoded'))
        shuffled = pd.DataFrame(dict(error=shuffle_errors['together'], shuffled='ID-shuffle decoded'))
        data = pd.concat([shuffled, decoded])
        colours = ['#ffffff', '#bdbdbd']

        print('actual:', np.mean(decoded_errors['together']),
              stats.sem(decoded_errors['together']))
        print('shuffle:', np.mean(shuffle_errors['together']),
              stats.sem(shuffle_errors['together']))

    edge_colour = '#252525'
    line_props = dict(color=edge_colour)
    flierprops = dict(marker='o', markersize=fliersize, linestyle='none',
                      markerfacecolor=edge_colour, markeredgecolor=edge_colour)

    fig, ax = plt.subplots(figsize=(3, 2))
    # Box colours and edge styling are passed to seaborn directly. Patching ax.artists /
    # ax.lines afterwards is version-dependent: recent matplotlib no longer lists the
    # boxes in ax.artists, and the 6-lines-per-box layout is not guaranteed.
    # saturation=1 keeps the palette colours exact.
    sns.boxplot(x='shuffled', y='error', data=data, palette=colours, saturation=1, ax=ax,
                boxprops=dict(edgecolor=edge_colour), whiskerprops=line_props,
                capprops=line_props, medianprops=line_props, flierprops=flierprops)

    ax.set(xlabel=' ', ylabel="Error (cm)")

    fig.tight_layout()
    sns.despine(ax=ax)

    if savepath is not None:
        fig.savefig(savepath, transparent=True)
        plt.close(fig)
    else:
        plt.show()

In [None]:
# NOTE(review): the same pooled errors are passed as both the decoded and the "ID-shuffle"
# errors, so both boxes show identical data; presumably a placeholder until shuffled
# decodes are pooled the same way.
plot_decoded_errors(combine_errors, combine_errors, experiment_time='phase1')

In [None]:
# Compare the per-session zone proportions.
plot_decoded_compare(decodes)

In [None]:
# NOTE(review): `errors` has the get_errors layout (lists of distances), not the
# get_zone_proportion layout of `decodes`; confirm plot_decoded_compare accepts it.
plot_decoded_compare(errors)

In [None]:
# Decoded proportions per session for the three track phases.
experiment_times = ['phase1', 'phase2', 'phase3']

decodes = []
for info in infos:
    decodes.append(get_decoded_proportions(info, experiment_times, pickle_filepath))

# The plotting helper imported from utils_plotting is plot_decoded_compare;
# plot_compare_decoded is not defined anywhere.
plot_decoded_compare(decodes)

In [None]:
# Decoded proportions per session for every recording segment, in chronological order.
experiment_times = ['prerecord', 'phase1', 'pauseA', 'phase2', 'pauseB', 'phase3', 'postrecord']

decodes = []
for info in infos:
    decodes.append(get_decoded_proportions(info, experiment_times, pickle_filepath))

# The plotting helper imported from utils_plotting is plot_decoded_compare;
# plot_compare_decoded is not defined anywhere.
plot_decoded_compare(decodes)

In [None]:
# Decoded proportions per session for the pre- and post-recording segments only.
experiment_times = ['prerecord', 'postrecord']

decodes = []
for info in infos:
    decodes.append(get_decoded_proportions(info, experiment_times, pickle_filepath))

# The plotting helper imported from utils_plotting is plot_decoded_compare;
# plot_compare_decoded is not defined anywhere.
plot_decoded_compare(decodes)

In [None]:
import info.r066d1 as info

In [None]:
# Same paths as the configuration cell at the top (presumably repeated so this
# single-session section can be run on its own).
home = os.path.expanduser("~")
emi_shortcut = os.path.join(home, "code", "emi_shortcut")
pickle_filepath = os.path.join(emi_shortcut, "cache", "pickled")
# pickle_filepath = 'E:/code/emi_shortcut/cache/pickled'

In [None]:
# NOTE(review): use_all_tracks is not referenced anywhere else in this notebook.
use_all_tracks = False

In [None]:
from utils_maze import get_xyedges, speed_threshold

In [None]:
# Load the session's data and the spatial bin edges used for the tuning curves and the decoding.
events, position, spikes, lfp, lfp_theta = get_data(info)
xedges, yedges = get_xyedges(position)

## Get tuning curves (from phase 3)

In [None]:
# Tuning curves are built from phase 3 only, using the speed-thresholded position.
run_pos = speed_threshold(position)

phase3_times = info.task_times['phase3']
track_starts, track_stops = [phase3_times.start], [phase3_times.stop]

position_tc = run_pos.time_slices(track_starts, track_stops)
track_spikes = [train.time_slices(track_starts, track_stops) for train in spikes]

tuning_curve = vdm.tuning_curve_2d(position_tc, track_spikes, xedges, yedges, gaussian_sigma=0.1)

In [None]:
tuning_curve.shape

In [None]:
# Indices of neurons whose tuning curve, summed over both spatial axes, exceeds 3000
# (the same value used as high_thresh in the filter below).
tc_sums = np.sum(np.sum(tuning_curve, axis=2), axis=1)
np.where(tc_sums > 3000) 

In [None]:
# Keep only neurons whose summed tuning curve lies strictly between low_thresh and high_thresh.
# NOTE(review): this overwrites tuning_curve in place. Re-running the cell recomputes
# keep_neurons on the already-filtered array, which then no longer lines up with `spikes`.
low_thresh = 1
high_thresh = 3000
tc_sums = np.sum(np.sum(tuning_curve, axis=2), axis=1)
keep_neurons = (tc_sums > low_thresh) & (tc_sums < high_thresh)
tuning_curve = tuning_curve[keep_neurons]
tuning_curve.shape

In [None]:
xx, yy = np.meshgrid(xedges, yedges)

# Preview at most the first 10 tuning curves; after the firing-rate filter there may be fewer.
n_preview = min(10, len(tuning_curve))
for ii in range(n_preview):
    print(ii)
    pp = plt.pcolormesh(yy, xx, tuning_curve[ii], cmap='pink_r')
    plt.colorbar(pp)
    plt.axis('off')
    plt.show()

In [None]:
# Shape after the firing-rate filter (first axis is neurons).
tuning_curve.shape

## Decoding for phase 1

In [None]:
# Settings for this step-by-step decode.
# NOTE(review): shuffle_id is only referenced in commented-out code below.
shuffle_id = False
from analyze_decode import get_edges, point_in_zones
from utils_maze import find_zones

experiment_time = 'phase1'

In [None]:
track_times = ['phase1', 'phase2', 'phase3', 'tracks']
pedestal_times = ['pauseA', 'pauseB', 'prerecord', 'postrecord']

# Keep only the spike trains whose tuning curves survived the firing-rate filter, so that
# spikes[i] and tuning_curve[i] describe the same neuron. The comprehension works whether
# `spikes` is a list or an array. The length check catches re-running this cell after
# `spikes` was already filtered, which would otherwise misalign neurons.
if len(spikes) != len(keep_neurons):
    raise ValueError("spikes does not match keep_neurons (already filtered?)")
spikes = [spiketrain for spiketrain, keep in zip(spikes, keep_neurons) if keep]

# NOTE(review): speed_limit=0.4 here may differ from the default used when building the
# tuning curves — confirm this is intended.
if experiment_time in track_times:
    run_pos = speed_threshold(position, speed_limit=0.4)
else:
    run_pos = position

track_starts = [info.task_times[experiment_time].start]
track_stops = [info.task_times[experiment_time].stop]

track_pos = run_pos.time_slices(track_starts, track_stops)

# if shuffle_id:
#     random.shuffle(tuning_curve)

In [None]:
# Number of spike trains kept for decoding (one per row of tuning_curve).
len(spikes)

In [None]:
# Occupancy of all position samples on the same spatial grid as the tuning curves.
xx, yy = np.meshgrid(xedges, yedges)

# Bin with the actual edges (not just a bin count) so the image lines up with xx/yy.
# histogram2d(y, x) gives rows along y and columns along x, matching pcolormesh(xx, yy, C).
histogram, ys, xs = np.histogram2d(position.y, position.x, bins=[yedges, xedges])

pp = plt.pcolormesh(xx, yy, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

In [None]:
# Occupancy of the position samples used for decoding, on the tuning-curve grid.
xx, yy = np.meshgrid(xedges, yedges)

# Bin with the actual edges (not just a bin count) so the image lines up with xx/yy.
# histogram2d(y, x) gives rows along y and columns along x, matching pcolormesh(xx, yy, C).
histogram, ys, xs = np.histogram2d(track_pos.y, track_pos.x, bins=[yedges, xedges])

pp = plt.pcolormesh(xx, yy, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

In [None]:
# Number of position samples in the decoding window.
track_pos.n_samples

In [None]:
# Start and stop times of the decoding window.
info.task_times[experiment_time].start, info.task_times[experiment_time].stop 

In [None]:
# Spike trains restricted to the experiment_time window; these are binned for decoding below.
# The commented-out code is an alternative that decodes only within SWR epochs found by
# vdm.detect_swr_hilbert (and across all track epochs for 'tracks').
# if experiment_time == 'tracks':
#     decode_spikes = [spiketrain.time_slices(track_starts, track_stops) for spiketrain in spikes]
#     epochs_interest = vdm.Epoch(np.hstack([np.array(track_starts), np.array(track_stops)]))

# else:
decode_spikes = [spiketrain.time_slice(info.task_times[experiment_time].start,
                                       info.task_times[experiment_time].stop) for spiketrain in spikes]
#     sliced_lfp = lfp.time_slice(info.task_times[experiment_time].start, info.task_times[experiment_time].stop)
#     z_thresh = 3.0
#     power_thresh = 5.0
#     merge_thresh = 0.02
#     min_length = 0.01
#     swrs = vdm.detect_swr_hilbert(sliced_lfp, fs=info.fs, thresh=(140.0, 250.0), z_thresh=z_thresh,
#                                   power_thresh=power_thresh, merge_thresh=merge_thresh, min_length=min_length)

#     epochs_interest = vdm.find_multi_in_epochs(decode_spikes, swrs, min_involved=3)
#     if epochs_interest.n_epochs == 0:
#         epochs_interest = vdm.find_multi_in_epochs(decode_spikes, swrs, min_involved=1)

In [None]:
# Number of spike times in the first kept spike train within the decoding window.
len(decode_spikes[0].time)

In [None]:
# Bin spike counts on time edges spanning track_pos (bin width 0.025, smoothed with
# gaussian_std=0.005); units presumably seconds — TODO confirm.
counts_binsize = 0.025
time_edges = get_edges(track_pos, counts_binsize, lastbin=True)
counts = vdm.get_counts(decode_spikes, time_edges, gaussian_std=0.005)

In [None]:
counts.shape

In [None]:
# Spike-count matrix shown as an image.
pp = plt.pcolormesh(counts, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

In [None]:
# Decode position from spike counts:
# - flatten each neuron's 2D tuning curve into one row;
# - compute the likelihood with vdm.bayesian_prob;
# - convert it to positions at the time-bin centres with vdm.decode_location.
tc_shape = tuning_curve.shape
print(counts.shape)
decoding_tc = tuning_curve.reshape(tc_shape[0], tc_shape[1] * tc_shape[2])

likelihood = vdm.bayesian_prob(counts, decoding_tc, counts_binsize)

# Spatial bin centres, combined with vdm.cartesian (presumably all (x, y) pairs).
xcenters = (xedges[1:] + xedges[:-1]) / 2.
ycenters = (yedges[1:] + yedges[:-1]) / 2.
xy_centers = vdm.cartesian(xcenters, ycenters)

time_centers = (time_edges[1:] + time_edges[:-1]) / 2.

decoded = vdm.decode_location(likelihood, xy_centers, time_centers)
print(decoded.x.shape)
# Drop time bins with no decoded position (NaN in both x and y).
nan_idx = np.logical_and(np.isnan(decoded.x), np.isnan(decoded.y))
decoded = decoded[~nan_idx]
print(decoded.x.shape)

In [None]:
# Distribution of decoded positions on the tuning-curve grid.
xx, yy = np.meshgrid(xedges, yedges)

# Bin with the actual edges (not just a bin count) so the image lines up with xx/yy.
# histogram2d(y, x) gives rows along y and columns along x, matching pcolormesh(xx, yy, C).
histogram, ys, xs = np.histogram2d(decoded.y, decoded.x, bins=[yedges, xedges])

pp = plt.pcolormesh(xx, yy, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

In [None]:
# Keep only decoded samples that form sequences without teleports and fall inside this phase.
if decoded.isempty:
    raise ValueError("decoded cannot be empty.")

sequences = vdm.remove_teleports(decoded, speed_thresh=40, min_length=3)
phase_window = info.task_times[experiment_time]
decoded_epochs = sequences.intersect(vdm.Epoch(phase_window.start, phase_window.stop))
decoded = decoded[decoded_epochs]

In [None]:
# Decoded samples remaining after teleport removal and restriction to the phase.
decoded.x.shape

In [None]:
# Assign decoded samples to trajectory zones (expanded by 8, in find_zones' units).
zones = find_zones(info, expand_by=8)
decoded_zones = point_in_zones(decoded, zones)

# NOTE(review): `errors` here replaces the earlier list of per-session error dicts
# with a dict for this single decode.
keys = ['u', 'shortcut', 'novel']
errors = dict()
actual_position = dict()
if experiment_time in ['phase1', 'phase2', 'phase3', 'tracks']:
    for trajectory in keys:
        # Actual position linearly interpolated at each decoded time in this zone.
        actual_x = np.interp(decoded_zones[trajectory].time, track_pos.time, track_pos.x)
        actual_y = np.interp(decoded_zones[trajectory].time, track_pos.time, track_pos.y)
        actual_position[trajectory] = vdm.Position(np.hstack((actual_x[..., np.newaxis],
                                                              actual_y[..., np.newaxis])),
                                                   decoded_zones[trajectory].time)
        # Decoding error: Position.distance between actual and decoded positions.
        errors[trajectory] = actual_position[trajectory].distance(decoded_zones[trajectory])
else:
    # Off-track times have no track position to compare against, so errors stay empty.
    for trajectory in decoded_zones:
        errors[trajectory] = []
        actual_position[trajectory] = []

# Same keys the pickled decodes elsewhere in this notebook are read with.
output = dict()
output['zones'] = decoded_zones
output['errors'] = errors
output['times'] = len(time_centers)
output['actual'] = actual_position
output['decoded'] = decoded

In [None]:
# Fraction of time bins that produced a kept decoded sample.
output['decoded'].n_samples / output['times']

In [None]:
# Decoded samples that fell in the U trajectory zone.
plt.plot(output['zones']['u'].x, output['zones']['u'].y, 'b.', ms=2)
# plt.plot(output['zones']['other'].x, output['zones']['other'].y, 'r.', ms=4)
plt.show()

In [None]:
# Mean decoding error on the U trajectory.
np.mean(output['errors']['u'])

In [None]:
# Distribution of decoded positions in the U zone, on the tuning-curve grid.
pos = output['zones']['u']

xx, yy = np.meshgrid(xedges, yedges)

# Bin with the actual edges (not just a bin count) so the image lines up with xx/yy.
# histogram2d(y, x) gives rows along y and columns along x, matching pcolormesh(xx, yy, C).
histogram, ys, xs = np.histogram2d(pos.y, pos.x, bins=[yedges, xedges])

pp = plt.pcolormesh(xx, yy, histogram, cmap='pink_r')
plt.colorbar(pp)
plt.axis('off')
plt.show()

# Decoding via the scripted pipeline (`analyze_decode.analyze`)

In [None]:
import info.r066d1 as info

In [None]:
# Load this session's cached tuning curve.
# NOTE: pickle.load can execute arbitrary code — only load caches this project wrote.
tuning_curve_filename = info.session_id + '_tuning-curve.pkl'
pickled_tuning_curve = os.path.join(pickle_filepath, tuning_curve_filename)
with open(pickled_tuning_curve, 'rb') as fileobj:
    tuning_curve = pickle.load(fileobj)

In [None]:
tuning_curve.shape

In [None]:
from analyze_decode import analyze
# Run the full decode for postrecord with the cached tuning curve (no ID shuffle).
experiment_time = 'postrecord'
dec = analyze(info, tuning_curve, experiment_time=experiment_time, shuffle_id=False)

In [None]:
# Number of decoded samples vs. the 'times' count stored with the decode.
dec['decoded'].n_samples, dec['times']

# Loading a cached decode from pickle

In [None]:
import info.r066d1 as info

In [None]:
# Load this session's cached postrecord decode.
# NOTE: pickle.load can execute arbitrary code — only load caches this project wrote.
experiment_time = 'postrecord'
decode_filename = info.session_id + '_decode-' + experiment_time + '.pkl'
pickled_decode = os.path.join(pickle_filepath, decode_filename)
with open(pickled_decode, 'rb') as fileobj:
    decode = pickle.load(fileobj)

In [None]:
# Decoded samples per stored 'times' entry (presumably the number of time bins).
decode['decoded'].n_samples / decode['times']