In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import itertools
import scipy
import os
import nept

from loading_data import get_data
from analyze_tuning_curves import get_tuning_curves

In [None]:
# Paths are resolved relative to the notebook's current working directory.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "check_tc")

# savefig does not create missing directories; make sure the figure output
# folder exists so a fresh "Restart & Run All" does not fail on the first save.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
# Session metadata objects for the spike-sorted recordings, taken from the
# project's `run` module. The processing loop below calls get_data(info) and
# reads info.task_times[<phase>].start/.stop and info.session_id from each one.
from run import spike_sorted_infos
infos = spike_sorted_infos

In [None]:
# Per-session tuning-curve check: restrict each session to the running periods
# of one task phase, drop neurons whose spike count is implausibly low or high,
# compute 2D tuning curves, and save one figure per neuron plus one summed map.
PHASE = "phase3"
BINSIZE = 3            # spatial bin size passed to nept.get_xyedges
RUN_THRESH = 0.167     # speed threshold for nept.run_threshold (units as in position data -- TODO confirm)
RUN_T_SMOOTH = 0.5     # smoothing window for nept.run_threshold
MIN_RATE = 0.4         # min spikes per unit of total running time (presumably Hz)
MAX_RATE = 5           # max spikes per unit of total running time (presumably Hz)

for info in infos:
    events, position, spikes, lfp, lfp_theta = get_data(info)
    xedges, yedges = nept.get_xyedges(position, binsize=BINSIZE)

    # Restrict position and spikes to the selected task phase
    phase = PHASE
    phase_start = info.task_times[phase].start
    phase_stop = info.task_times[phase].stop
    sliced_position = position.time_slice(phase_start, phase_stop)
    sliced_spikes = [spiketrain.time_slice(phase_start, phase_stop) for spiketrain in spikes]

    # Limit position and spikes to only running times
    run_epoch = nept.run_threshold(sliced_position, thresh=RUN_THRESH, t_smooth=RUN_T_SMOOTH)
    run_position = sliced_position[run_epoch]
    track_spikes = [spiketrain.time_slice(run_epoch.starts, run_epoch.stops) for spiketrain in sliced_spikes]

    # Remove neurons with too few or too many spikes. The bounds scale with the
    # total running duration, so they act as firing-rate limits.
    len_epochs = np.sum(run_epoch.durations)
    min_n_spikes = MIN_RATE * len_epochs
    max_n_spikes = MAX_RATE * len_epochs
    tuning_spikes = [spiketrain for spiketrain in track_spikes
                     if min_n_spikes <= len(spiketrain.time) <= max_n_spikes]

    tuning_curves = nept.tuning_curve_2d(run_position, tuning_spikes, xedges, yedges,
                                         occupied_thresh=0.5, gaussian_std=0.3)
    # NaN bins (presumably under-occupied) are drawn as zero firing.
    tuning_curves[np.isnan(tuning_curves)] = 0.

    # Bin-edge grid for pcolormesh; assumes each tuning curve is laid out as
    # (y-bins, x-bins) so it matches meshgrid(xedges, yedges) -- TODO confirm.
    xx, yy = np.meshgrid(xedges, yedges)

    # Plot individual tuning curves
    for i, tuning_curve in enumerate(tuning_curves):
        fig, ax = plt.subplots()
        pp = ax.pcolormesh(xx, yy, tuning_curve, cmap="Greys")
        fig.colorbar(pp, ax=ax)
        ax.set_title(info.session_id + " neuron " + str(i))
        fig.tight_layout()
        fig.savefig(os.path.join(output_filepath, info.session_id+"-tuning_curve"+str(i)+".png"))
        plt.close(fig)

    # Plot all tuning curves in the session: sum over neurons, then normalize to
    # unit total. Skip normalization when the sum is zero (e.g. no neuron passed
    # the spike-count filter) to avoid a 0/0 all-NaN map.
    multiple_tuning_curves = np.sum(tuning_curves, axis=0)
    total = np.sum(multiple_tuning_curves)
    if total > 0:
        multiple_tuning_curves = multiple_tuning_curves / total

    fig, ax = plt.subplots(figsize=(6, 5))
    pp = ax.pcolormesh(xx, yy, multiple_tuning_curves, cmap="Greys")
    fig.colorbar(pp, ax=ax)
    ax.set_title(info.session_id + " tuning curves (normalized)")
    fig.tight_layout()
    fig.savefig(os.path.join(output_filepath, info.session_id+"-all-tuning_curves.png"))
    plt.close(fig)