In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy
import os
import nept

from loading_data import get_data

In [None]:
# Resolve every path against the directory the notebook is launched from.
thisdir = os.getcwd()
pickle_filepath = os.path.join(thisdir, "cache", "pickled")
output_filepath = os.path.join(thisdir, "plots", "shift_bins")

# Ensure the figure output directory exists before anything is saved into it.
output_dir_missing = not os.path.exists(output_filepath)
if output_dir_missing:
    os.makedirs(output_filepath)

In [None]:
import info.r063d2 as r063d2
import info.r063d3 as r063d3
# Sessions to analyse: every spike-sorted session by default.
# For a quicker run, substitute a subset, e.g. `infos = [r063d2, r063d3]`.
from run import spike_sorted_infos

infos = spike_sorted_infos

In [None]:
def plot_occupancy(occupancy, xx, yy, pad, binsize, min_occupied, filepath=None, session_id=None):
    """Plot an occupancy map and annotate the fraction of occupied bins.

    Parameters
    ----------
    occupancy : np.ndarray
        2D occupancy values, drawn with ``pcolormesh`` over the ``xx``/``yy`` grid
        (colour scale capped at 10).
    xx, yy : np.ndarray
        Bin-edge grids, e.g. from ``np.meshgrid(xedges, yedges)``.
    pad : int
        Grid shift that produced these edges; only used in the saved filename.
    binsize : int
        Bin size that produced these edges; only used in the saved filename.
    min_occupied : int
        Number of occupied bins; divided by ``occupancy.size`` for the
        "Occupied" label.
    filepath : str, optional
        Directory to save the figure into. If None, the figure is shown instead.
    session_id : str, optional
        Filename prefix used when saving. If None, falls back to the
        notebook-global ``info.session_id`` (the original behaviour).
    """
    fig, ax = plt.subplots()
    mesh = ax.pcolormesh(xx, yy, occupancy, vmax=10., cmap="Greys")

    # float() keeps this a true division even for Python 2 / numpy integer inputs.
    proportion_occupied = float(min_occupied) / occupancy.size
    ax.text(0.95, 0.05, "Occupied: %.2f" % proportion_occupied,
            horizontalalignment='right',
            verticalalignment='top',
            transform=fig.transFigure,  # position in figure coordinates
            fontsize=12)

    fig.colorbar(mesh, ax=ax)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    fig.tight_layout()
    if filepath is not None:
        if session_id is None:
            # Legacy fallback: the session id comes from the notebook-global `info`
            # set by the analysis loop.
            session_id = info.session_id
        filename = "%s-occupancies_%scm_shifted-%s.png" % (session_id, binsize, pad)
        fig.savefig(os.path.join(filepath, filename))
        # Close explicitly: this is called many times in a loop.
        plt.close(fig)
    else:
        plt.show()

In [None]:
def get_xyedges(position, binsize, pad):
    """Build x and y bin edges that span the range of a 2D position.

    Each axis' grid starts ``pad`` below that axis' minimum and steps by
    ``binsize``. The final edge is at or beyond the axis' maximum, so every
    sample falls inside the grid. Varying ``pad`` shifts the grid without
    changing the bin size.

    Parameters
    ----------
    position: 2D nept.Position
        Anything exposing ``dimensions`` plus ``x`` and ``y`` arrays.
    binsize: int
        Width of each bin, in the units of the position data.
    pad: int or float
        Amount subtracted from each axis minimum before the first edge.

    Returns
    -------
    xedges: np.array
    yedges: np.array

    Raises
    ------
    ValueError
        If ``position`` has fewer than 2 dimensions.

    """
    if position.dimensions < 2:
        raise ValueError("position must be 2-dimensional")

    def _axis_edges(values):
        # Upper bound of max + binsize makes the last edge reach (or pass) max.
        return np.arange(values.min() - pad, values.max() + binsize, binsize)

    return _axis_edges(position.x), _axis_edges(position.y)

In [None]:
def find_min_occupancy(position, run_position, binsize):
    """Find the grid shift that leaves the fewest occupied bins for one bin size.

    Every shift ``pad`` in ``range(binsize)`` is tried. Bin edges are built from
    the full-session ``position`` via ``get_xyedges``, and occupancy is computed
    from ``run_position`` with ``nept.get_occupancy``.

    Returns
    -------
    occupancy : np.ndarray
        Occupancy map for the best shift.
    n_occupied : int
        Number of bins with non-zero occupancy for that shift.
    pad : int
        The best shift; the smallest such pad wins ties.
    xedges, yedges : np.ndarray
        Bin edges for that shift.
    """
    best = None
    for pad in range(binsize):
        xedges, yedges = get_xyedges(position, binsize=binsize, pad=pad)
        occupancy = nept.get_occupancy(run_position, yedges, xedges)
        n_occupied = np.count_nonzero(occupancy > 0)
        # The first shift is always accepted, so results never carry over from a
        # previous bin size. The strict "<" keeps the smallest pad on ties.
        if best is None or n_occupied < best[1]:
            best = (occupancy, n_occupied, pad, xedges, yedges)
    return best


phase = "phase3"
binsizes = [6, 8, 10, 12, 14, 16, 18, 20]

for info in infos:
    events, position, spikes, _, _ = get_data(info)

    # Restricting to the task phase and to running epochs depends on neither the
    # bin size nor the shift, so do it once per session instead of per (binsize, pad).
    sliced_position = position.time_slice(info.task_times[phase].start, info.task_times[phase].stop)
    run_epoch = nept.run_threshold(sliced_position, thresh=10., t_smooth=0.8)
    run_position = sliced_position[run_epoch]

    for binsize in binsizes:
        min_occupancy, min_occupied, min_pad, min_xedges, min_yedges = find_min_occupancy(
            position, run_position, binsize)
        min_binsize = binsize

        xx, yy = np.meshgrid(min_xedges, min_yedges)
        # plot_occupancy names the saved file from the global `info` set by this loop.
        plot_occupancy(min_occupancy, xx, yy, min_pad, min_binsize, min_occupied,
                       filepath=output_filepath)