In [None]:
%matplotlib inline
import itertools
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.interpolate
import nept

from loading_data import get_data
from analyze_tuning_curves import get_tuning_curves

In [None]:
# Resolve the cache and plot-output directories relative to the notebook's working directory.
thisdir = os.getcwd()
pickle_filepath, output_filepath = (
    os.path.join(thisdir, "cache", "pickled"),
    os.path.join(thisdir, "plots", "check_tc"),
)

In [None]:
# Sessions to process. A single session module is also bound to the global name
# `info`; the plotting helpers read `info.session_id`, and the sweep loop below
# rebinds `info` to each session in turn.
from run import spike_sorted_infos
import info.r063d6 as info

infos = spike_sorted_infos  # use [info] instead to check just this one session

In [None]:
def plot_individual_tuning_curves(tuning_curves, n_colours, cmap, xx, yy, filepath=None,
                                  session_id=None):
    """Plot each neuron's 2D tuning curve as its own figure.

    Parameters
    ----------
    tuning_curves : iterable of 2D array-like
        One map per neuron; each is drawn with ``plt.pcolormesh(xx, yy, ...)``.
        The inputs are copied, not modified.
    n_colours : int or float
        Number of colours in ``cmap``. NaN bins are set to ``-nanmax / n_colours``
        so they fall below every real value and take the lowest colour of ``cmap``.
    cmap : matplotlib colormap
    xx, yy : np.array
        Bin-edge meshgrids for ``pcolormesh``.
    filepath : str or None
        Directory in which to save one PNG per neuron. When None, each figure
        is shown instead.
    session_id : str or None
        Prefix for the saved filenames. Defaults to the global ``info.session_id``.
        It is only looked up when saving.
    """
    for i, tuning_curve in enumerate(tuning_curves):
        # np.array copies, so the caller's tuning curve is not overwritten.
        tuning_curve = np.array(tuning_curve)
        tuning_curve[np.isnan(tuning_curve)] = -np.nanmax(tuning_curve) / n_colours

        plt.figure()
        pp = plt.pcolormesh(xx, yy, tuning_curve, cmap=cmap)
        plt.colorbar(pp)
        plt.tight_layout()
        if filepath is not None:
            if session_id is None:
                session_id = info.session_id
            # Save into the requested directory (previously the global
            # output_filepath was used, ignoring this argument).
            plt.savefig(os.path.join(filepath, session_id + "-tuning_curve" + str(i) + ".png"))
            plt.close()
        else:
            plt.show()

In [None]:
def plot_combined_tuning_curves(tuning_curves, n_colours, cmap, xx, yy, pad, n_below_thresh,
                                filepath=None, session_id=None):
    """Plot the normalised sum of all tuning curves in a session.

    Parameters
    ----------
    tuning_curves : array-like, shape (n_neurons, n_rows, n_cols)
        One 2D map per neuron. Each map is drawn with ``pcolormesh(xx, yy, ...)``.
    n_colours : int or float
        Number of colours in ``cmap``. NaN bins are set to ``-nanmax / n_colours``
        so they take the lowest colour of ``cmap``.
    cmap : matplotlib colormap
    xx, yy : np.array
        Bin-edge meshgrids for ``pcolormesh``.
    pad : int
        Edge offset used for this plot. It only appears in the filename.
    n_below_thresh : int
        Count annotated on the figure.
    filepath : str or None
        Directory in which to save the PNG. When None, the figure is shown instead.
    session_id : str or None
        Used in the title and filename. Defaults to the global ``info.session_id``.
    """
    if session_id is None:
        session_id = info.session_id

    # A plain sum (not nansum) keeps NaN in any bin where a curve is NaN,
    # so those bins can be recoloured below. np.sum returns a new array,
    # so the input is never modified.
    combined = np.sum(np.asarray(tuning_curves, dtype=float), axis=0)

    # Normalise so the map sums to 1. Skip this when the total is 0 (e.g. no
    # neurons kept); otherwise 0/0 would turn every bin into NaN.
    total = np.nansum(combined)
    if total != 0:
        combined = combined / total

    combined[np.isnan(combined)] = -np.nanmax(combined) / n_colours

    plt.figure()
    pp = plt.pcolormesh(xx, yy, combined, cmap=cmap)
    # NOTE(review): placed in data coordinates, so it may land outside the
    # axes depending on the position range.
    plt.text(-1., -30., "N occupied bins below thresh: " + str(n_below_thresh), fontsize=14)
    plt.colorbar(pp)
    plt.title(session_id + " tuning curves (normalized)")
    plt.tight_layout()
    if filepath is not None:
        filename = session_id + "-all-tuning_curves-min_shifted-" + str(pad) + ".png"
        plt.savefig(os.path.join(filepath, filename))
        plt.close()
    else:
        plt.show()

In [None]:
def plot_occupancy(occupancy, xx, yy, pad, filepath=None, session_id=None, vmax=10.):
    """Plot a 2D occupancy map in greyscale.

    Parameters
    ----------
    occupancy : 2D array-like
        Drawn with ``plt.pcolormesh(xx, yy, occupancy)``.
    xx, yy : np.array
        Bin-edge meshgrids for ``pcolormesh``.
    pad : int
        Edge offset used for this plot. It only appears in the filename.
    filepath : str or None
        Directory in which to save the PNG. When None, the figure is shown instead.
    session_id : str or None
        Filename prefix. Defaults to the global ``info.session_id``.
        It is only looked up when saving.
    vmax : float
        Upper colour limit. Occupancy above it saturates. Defaults to 10.
    """
    plt.figure()
    pp = plt.pcolormesh(xx, yy, occupancy, vmax=vmax, cmap="Greys")
    plt.colorbar(pp)
    plt.tight_layout()
    if filepath is not None:
        if session_id is None:
            session_id = info.session_id
        plt.savefig(os.path.join(filepath, session_id + "-occupancies_shifted-" + str(pad) + ".png"))
        plt.close()
    else:
        plt.show()

In [None]:
def get_xyedges(position, binsize, pad):
    """Build x and y bin edges that span the position data.

    Each edge array starts ``pad`` below the minimum coordinate and steps by
    ``binsize``. The stop value ``max + binsize`` is exclusive, so the edges
    reach the maximum coordinate.

    Parameters
    ----------
    position: 2D nept.Position
        Must expose ``dimensions``, ``x`` and ``y``.
    binsize: int
        Edge spacing, used for both axes.
    pad: int or float
        Amount subtracted from the minimum x and y before the first edge.

    Returns
    -------
    xedges: np.array
    yedges: np.array

    Raises
    ------
    ValueError
        If ``position`` has fewer than two dimensions.
    """
    if position.dimensions < 2:
        raise ValueError("position must be 2-dimensional")

    def _edges(coords):
        lower = coords.min() - pad
        upper = coords.max() + binsize
        return np.arange(lower, upper, binsize)

    return _edges(position.x), _edges(position.y)

In [None]:
# Sweep bin sizes and edge offsets (pads). For each combination, count the
# occupied bins at or below the occupancy threshold and save diagnostic plots.
binsizes = [8, 10, 12, 14, 16]
phase = "phase3"
occupied_thresh = 1.

# Colour map: white first (the plotting helpers push NaN bins below every real
# value, so they take this colour), then copper_r for the data.
n_colours = 15  # must be an int: np.linspace rejects a float sample count
colours = [(1., 1., 1.)]
colours.extend(matplotlib.cm.copper_r(np.linspace(0, 1, n_colours - 1)))
cmap = matplotlib.colors.ListedColormap(colours)

# One row per binsize, one column per session. Each entry is that session's
# mean count over all pads. The table is rectangular, so
# np.mean(overall_n_below_threshs, axis=1) gives one value per binsize.
overall_n_below_threshs = []
for binsize in binsizes:
    print("Binsize:", binsize)
    filepath = os.path.join(output_filepath, "binsize" + str(binsize) + "cm")
    os.makedirs(filepath, exist_ok=True)

    binsize_n_below_threshs = []
    for info in infos:
        _, position, spikes, _, _ = get_data(info)

        # None of this depends on pad, so it is computed once per session.
        task_start = info.task_times[phase].start
        task_stop = info.task_times[phase].stop
        sliced_position = position.time_slice(task_start, task_stop)
        sliced_spikes = [spiketrain.time_slice(task_start, task_stop) for spiketrain in spikes]

        # Limit position and spikes to only running times
        run_epoch = nept.run_threshold(sliced_position, thresh=10., t_smooth=0.8)
        run_position = sliced_position[run_epoch]
        track_spikes = np.asarray([spiketrain.time_slice(run_epoch.starts, run_epoch.stops)
                                   for spiketrain in sliced_spikes])

        # Remove neurons with too few or too many spikes: keep 0.4-5 spikes
        # per unit of total run duration (presumably seconds).
        len_epochs = np.sum(run_epoch.durations)
        min_n_spikes = 0.4 * len_epochs
        max_n_spikes = 5 * len_epochs
        keep_idx = np.array([min_n_spikes <= len(spiketrain.time) <= max_n_spikes
                             for spiketrain in track_spikes], dtype=bool)
        tuning_spikes = track_spikes[keep_idx]

        session_n_below_threshs = []
        for pad in range(binsize):
            xedges, yedges = get_xyedges(position, binsize=binsize, pad=pad)

            occupancy = nept.get_occupancy(run_position, yedges, xedges)
            # Bins that were visited, but with occupancy no greater than occupied_thresh.
            n_below_thresh = np.sum((occupancy > 0) & (occupancy <= occupied_thresh))
            print("N occupied bins below thresh:", n_below_thresh)
            session_n_below_threshs.append(n_below_thresh)

            tuning_curves = nept.tuning_curve_2d(run_position, tuning_spikes, xedges, yedges,
                                                 occupied_thresh=occupied_thresh, gaussian_std=0.3)

            xx, yy = np.meshgrid(xedges, yedges)

            plot_occupancy(occupancy, xx, yy, pad, filepath)
            plot_combined_tuning_curves(tuning_curves, n_colours, cmap, xx, yy,
                                        pad, n_below_thresh=n_below_thresh,
                                        filepath=filepath)
            # Uncomment to also save one figure per neuron:
            # plot_individual_tuning_curves(tuning_curves, n_colours, cmap, xx, yy, filepath)

        binsize_n_below_threshs.append(np.mean(session_n_below_threshs))

    # Append once per binsize. Previously this ran inside the session loop,
    # which duplicated rows and produced a ragged table.
    overall_n_below_threshs.append(binsize_n_below_threshs)

In [None]:
# Summary across bin sizes: each row of overall_n_below_threshs belongs to one
# bin size and is averaged to a single point.
plt.plot([8, 10, 12, 14, 16],
         np.mean(overall_n_below_threshs, axis=1))
plt.title("N below occupancy thresh")
plt.xlabel("Bin size (cm)")
plt.ylabel("Mean N occupied bins below thresh")
plt.show()

In [None]:
# Rebuild the last sweep iteration's tuning curves by hand: occupancy-normalised
# spike histograms, then Gaussian smoothing. This is presumably a sanity check
# against nept.tuning_curve_2d. Relies on run_position, tuning_spikes, xedges
# and yedges left over from the sweep cell.
position = run_position
spikes = tuning_spikes
gaussian_std = 0.3
occupied_thresh = 0.

position_2d = nept.get_occupancy(position, yedges, xedges)
shape = position_2d.shape
occupied_idx = position_2d > occupied_thresh

# Nearest-sample lookup of position at each spike time. Row 0 is used as x and
# row 1 as y below. The interpolant does not depend on the neuron, so it is
# built once rather than inside the loop.
f_xy = scipy.interpolate.interp1d(position.time, position.data.T, kind="nearest")

tuning_curves = np.full(((len(spikes),) + shape), np.nan)
for i, spiketrain in enumerate(spikes):
    spikes_xy = f_xy(spiketrain.time)

    # bins=[yedges, xedges] puts y on rows and x on columns, matching the
    # occupied_idx mask from position_2d. The returned edges are not needed.
    spikes_2d, _, _ = np.histogram2d(spikes_xy[1], spikes_xy[0], bins=[yedges, xedges])
    tuning_curves[i, occupied_idx] = spikes_2d[occupied_idx] / position_2d[occupied_idx]

if gaussian_std is not None:
    xbinsize = xedges[1] - xedges[0]
    ybinsize = yedges[1] - yedges[0]
    # tuning_curves is (neuron, y, x): axis 1 is y and axis 2 is x, so each
    # axis gets its own bin size. These were previously swapped; this only
    # mattered when the x and y bin sizes differ.
    tuning_curves = nept.gaussian_filter(tuning_curves, gaussian_std, dt=ybinsize, axis=1)
    tuning_curves = nept.gaussian_filter(tuning_curves, gaussian_std, dt=xbinsize, axis=2)