In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import nept

from loading_data import get_data
from analyze_tuning_curves import get_only_tuning_curves
from utils_maze import get_bin_centers, get_zones, get_xy_idx, get_matched_trials, get_all_trials

In [None]:
# Cache and figure directories live under the notebook's working directory;
# the figure directory is created the first time the notebook runs.
thisdir = os.getcwd()
pickle_filepath, output_filepath = (
    os.path.join(thisdir, *subdirs)
    for subdirs in (("cache", "pickled"), ("plots", "trials", "tuning_curves"))
)
if not os.path.exists(output_filepath):
    os.makedirs(output_filepath)

In [None]:
# Individual session info modules, kept imported so the analysis can be
# narrowed to specific sessions by hand (see the commented line below).
import info.r063d7 as r063d7
import info.r063d8 as r063d8
import info.r068d3 as r068d3
import info.r068d5 as r068d5
# infos = [r063d7]  # uncomment to restrict the run to a single session
from run import analysis_infos
# `infos` is the collection of session info objects iterated by the plotting
# cell; presumably every session configured in run.py — verify there.
infos = analysis_infos

In [None]:
# When True, each figure is written under `output_filepath` and closed;
# when False, figures are shown inline instead.
savefig: bool = False

In [None]:
# For every session: plot the summed tuning curves and the occupancy over the
# matched phase-3 trials, then save or display each figure according to `savefig`.
TUNING_CURVES_VMAX = 200.  # colour-scale ceiling for the summed tuning-curve map
OCCUPANCY_VMAX = 20.  # colour-scale ceiling for the occupancy map


def _plot_over_path(path, xx, yy, values, vmax, cmap):
    """Draw `values` as a colour mesh on bin edges (xx, yy) over a position trace; return the figure."""
    fig, ax = plt.subplots()
    ax.plot(path.x, path.y, "b.", ms=1, alpha=0.2)
    mesh = ax.pcolormesh(xx, yy, values, vmax=vmax, cmap=cmap)
    fig.colorbar(mesh, ax=ax)
    ax.axis('off')
    return fig


def _save_or_show(fig, session_id, suffix):
    """Save `fig` under `output_filepath` and close it if `savefig` is set; otherwise display it."""
    if savefig:
        fig.savefig(os.path.join(output_filepath, session_id + suffix))
        plt.close(fig)
    else:
        plt.show()


for info in infos:
    print(info.session_id)
    events, position, spikes, _, _ = get_data(info)

    phase = info.task_times["phase3"]
    sliced_position = position.time_slice(phase.start, phase.stop)

    # Swap in get_all_trials(info, sliced_position, subset=False) to use every trial.
    trials = get_matched_trials(info, sliced_position, subset=False)
    if len(trials.starts) == 0:
        # Skip explicitly rather than plotting tuning curves left over from a
        # previous session (or computing them from an empty epoch).
        print("  no matched trials; skipping")
        continue

    # Computed once over all matched trials, so the tuning curves cover exactly
    # the same epochs as the occupancy map below. Building the epoch from the
    # trial boundaries directly avoids value-based filtering, which misaligns
    # starts and stops when boundary times repeat.
    epoch_of_interest = nept.Epoch([trials.starts, trials.stops])
    tuning_curves = get_only_tuning_curves(info,
                                           position,
                                           spikes,
                                           epoch_of_interest)

    xx, yy = np.meshgrid(info.xedges, info.yedges)

    # Collapse the stack of maps along axis 0 (presumably one map per neuron —
    # confirm against get_only_tuning_curves) into a single summed map.
    all_tuning_curves = np.sum(tuning_curves, axis=0)

    fig = _plot_over_path(sliced_position, xx, yy, all_tuning_curves,
                          vmax=TUNING_CURVES_VMAX, cmap='pink_r')
    _save_or_show(fig, info.session_id, "_trials_tuning-curves.png")

    # Occupancy restricted to the same trial epochs.
    trial_positions = position.time_slice(trials.starts, trials.stops)
    occupancy = nept.get_occupancy(trial_positions, info.yedges, info.xedges)

    fig = _plot_over_path(sliced_position, xx, yy, occupancy,
                          vmax=OCCUPANCY_VMAX, cmap="Greys")
    _save_or_show(fig, info.session_id, "_trials_occupancy.png")

In [None]:
# Trial start times that occur more than once in the last session processed.
# Excluding a trial by value (e.g. `start != trial.start`) removes every trial
# sharing such a time, so any entries here mean value-based exclusion is unsafe.
unique_starts, start_counts = np.unique(trials.starts, return_counts=True)
unique_starts[start_counts > 1]

In [None]:
# Number of distinct trial start times (compare against the number of trials).
np.unique(trials.starts).size

In [None]:
# Number of distinct trial stop times (compare against the number of trials).
np.unique(trials.stops).size