In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy
import os
import nept

from loading_data import get_data
from analyze_tuning_curves import get_tuning_curves
from analyze_decode_bytrial import get_trials, decode_trial
from analyze_decode import get_decoded_zones

In [None]:
import info.r067d2 as info

In [None]:
# Paths are resolved relative to the directory the notebook kernel was started in.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "test")
# plt.savefig does not create missing directories, so make sure the figure output folder exists.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# When True, decode_trial is called with shuffled=True and "-shuffled" is appended to the saved filenames.
shuffled = False
# Debug limit on how many trials to decode; set to None to decode every trial in phase 3.
max_trials = 2

events, position, spikes, lfp, lfp_theta = get_data(info)

# Restrict position and spike data to the phase-3 time window.
phase3_times = info.task_times['phase3']
position = position.time_slice(phase3_times.start, phase3_times.stop)
spikes = [spiketrain.time_slice(phase3_times.start, phase3_times.stop) for spiketrain in spikes]

xedges, yedges = nept.get_xyedges(position)

trial_epochs = get_trials(events, phase3_times)

# Clamp to the number of available trials so a short session does not raise IndexError.
if max_trials is None:
    n_trials = trial_epochs.n_epochs
else:
    n_trials = min(max_trials, trial_epochs.n_epochs)

all_trials = []
for trial_idx in range(n_trials):
    trial_start = trial_epochs.starts[trial_idx]
    trial_stop = trial_epochs.stops[trial_idx]
    trial_times = nept.Epoch([trial_start, trial_stop])

    # Tuning curves are built with this trial's times/index passed in; the output filenames
    # ("decoding_wout_current_trial") suggest the current trial is held out — TODO confirm in get_tuning_curves.
    neurons = get_tuning_curves(info, position, spikes, xedges, yedges, speed_limit=0.4,
                                phase_id="phase3", trial_times=trial_times, trial_number=trial_idx,
                                cache=False)

    decode = decode_trial(info, neurons, trial_times, trial_idx, shuffled)
    all_trials.append(decode)

# Sum decoding error and count samples on the grid of position edges, assigning each
# sample to the nearest x edge and nearest y edge of the actual position.
error_byactual_position = np.zeros((len(yedges), len(xedges)))
n_byactual_position = np.zeros((len(yedges), len(xedges)))

for decode in all_trials:
    for error, x, y in zip(decode['errors'], decode['position'].x, decode['position'].y):
        x_idx = nept.find_nearest_idx(xedges, x)
        y_idx = nept.find_nearest_idx(yedges, y)
        error_byactual_position[y_idx, x_idx] += error
        n_byactual_position[y_idx, x_idx] += 1

# Mean error per bin. Bins with no samples stay NaN rather than being divided by a fake
# count of 1, which would also bias every visited bin's mean low.
error_byactual = np.full(error_byactual_position.shape, np.nan)
np.divide(error_byactual_position, n_byactual_position,
          out=error_byactual, where=n_byactual_position > 0)

xx, yy = np.meshgrid(xedges, yedges)


def save_position_heatmap(values, plot_name, **pcolormesh_kwargs):
    """Draw `values` on the (xx, yy) edge grid, add a colorbar, and save the figure.

    The file goes to `output_filepath` as
    "decoding_wout_current_trial-<session_id>-<plot_name>[-shuffled].png", where the
    "-shuffled" suffix is added when the notebook-level `shuffled` flag is True.
    Extra keyword arguments are forwarded to `plt.pcolormesh`.
    """
    pp = plt.pcolormesh(xx, yy, values, **pcolormesh_kwargs)
    plt.colorbar(pp)
    plt.axis('off')
    suffix = "-shuffled" if shuffled else ""
    filename = "decoding_wout_current_trial-" + info.session_id + "-" + plot_name + suffix + ".png"
    plt.savefig(os.path.join(output_filepath, filename))
    plt.close()


print("error")
# Mask empty (NaN) bins explicitly so they render blank and do not disturb the colour scale.
save_position_heatmap(np.ma.masked_invalid(error_byactual), "error", vmin=0., cmap='bone_r')

print("position occupancy")
# Colour scale is capped at 500 samples per bin.
save_position_heatmap(n_byactual_position, "occupancy", vmin=0., vmax=500., cmap="pink_r")

In [None]:
# Sanity check: display how many trials were decoded and stored in all_trials.
len(all_trials)

In [None]:
# Pool the per-zone decoding errors ("u", "shortcut", "novel") across every decoded trial.
trajectory_errors = {zone: [] for zone in ("u", "shortcut", "novel")}

for decode in all_trials:
    # decoded_zones / zone_errors / actual_position are rebound on every iteration, so the
    # cells below see the values from the LAST trial only.
    decoded_zones, zone_errors, actual_position = get_decoded_zones(
        info, decode["decoded"], decode["position"], "phase3")
    for zone, pooled_errors in trajectory_errors.items():
        pooled_errors.extend(zone_errors[zone])

In [None]:
# Scatter the decoded positions that fell in the shortcut zone.
# NOTE(review): decoded_zones comes from the last trial of the previous loop only.
shortcut_zone = decoded_zones['shortcut']
plt.plot(shortcut_zone.x, shortcut_zone.y, "g.")
plt.show()

In [None]:
# Bundle decoding results into one dict.
# NOTE(review): decode, decoded_zones, zone_errors and actual_position are whatever the
# previous loops left bound, i.e. they describe the last trial only — confirm this is intended.
output = {
    'zones': decoded_zones,
    'errors': decode["errors"],
    'zone_errors': zone_errors,
    # NOTE(review): the 'times' key holds the decoded series' n_samples, not timestamps — confirm.
    'times': decode["decoded"].n_samples,
    'actual': actual_position,
    'decoded': decode["decoded"],
    'epochs': decode["decoded_epochs"],
}

In [None]:
# Bar chart of mean decoding error per trajectory zone, pooled over all decoded trials.
zone_keys = ('u', 'shortcut', 'novel')
# Tick labels, index-aligned with zone_keys (the "novel" zone is displayed as "dead-end").
zone_labels = ('u', 'shortcut', 'dead-end')
zone_colors = ["#0072b2ff", "#009e73ff", "#d55e00ff"]

fig, ax = plt.subplots()
ind = np.arange(len(zone_keys))
width = 0.9

means = [np.mean(trajectory_errors[zone]) for zone in zone_keys]
# Error bars show the standard deviation across samples (not the standard error of the mean).
yerr = [np.std(trajectory_errors[zone]) for zone in zone_keys]

# With align='center' each bar is centred on its x position, so the ticks belong exactly at `ind`.
ax.bar(ind, means, width=width, color=zone_colors, yerr=yerr, align='center')
ax.set_xticks(ind)
ax.set_xticklabels(zone_labels)
ax.set_ylabel('Average error (cm)')
plt.show()