In [2]:
import numpy as np
import os
from oe_acute import trial_utils as tu
from oe_acute import MNE
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pickle
import tqdm
import glob

import sys
#from oe_acute import pyMNE
import pickle as pkl

In [None]:
# Load the list of per-jackknife MNE parameter vectors consumed by plot_MNE below.
# NOTE(review): "/mnt/cube/" is a directory, so open() will raise; point this at the
# actual pickle file. Only unpickle trusted files — pickle.load can execute arbitrary code.
with open("/mnt/cube/", "rb") as pfinals_file:  # pickle requires a binary-mode handle
    pfinals = pkl.load(pfinals_file)

In [None]:
def plot_MNE(pfinals, unit, bird, block, figure_output_path, n_eigvec_to_display=6, sdim=256, rf_shape=(16,16), color_map='jet', save_figure=False):
    '''
    Produce the MNE summary plot from a list of MNE output parameter vectors.

    Each entry of ``pfinals`` is one jackknife fit. The fits are averaged, and the
    figure shows:

    * row 0: eigenvectors of the averaged J matrix with the most negative
      eigenvalues (eigenvalue in the title),
    * row 1: eigenvectors with the most positive eigenvalues,
    * bottom left: all eigenvalues of J in ascending order,
    * bottom right: the averaged linear feature.

    Parameters
    ----------
    pfinals : sequence of 1-D array-like
        One parameter vector per jackknife. Elements ``[1:sdim+1]`` are read as the
        linear feature and the last ``sdim**2`` elements as the flattened J matrix.
        Element 0 is not used (presumably the scalar offset — confirm).
    unit, bird, block :
        Identifiers shown in the figure title; ``unit`` is also used in the file name.
    figure_output_path : str
        Directory that is created if missing; the figure is written here when
        ``save_figure`` is True.
    n_eigvec_to_display : int
        Number of eigenvectors shown in each of the negative / positive rows.
    sdim : int
        Stimulus dimensionality; must equal ``rf_shape[0] * rf_shape[1]``.
    rf_shape : tuple of int
        Shape used to display the linear feature and each eigenvector.
    color_map : str
        Matplotlib colormap name.
    save_figure : bool
        If True, save the figure as ``MNEs_unit_{unit}.png`` in ``figure_output_path``.

    Raises
    ------
    AssertionError
        If ``sdim`` does not match ``rf_shape``.
    ValueError
        If ``pfinals`` is empty or ``n_eigvec_to_display`` is not in ``[1, sdim]``.
    '''
    os.makedirs(figure_output_path, exist_ok=True)
    assert sdim == rf_shape[0] * rf_shape[1], 'sdim must equal rf_shape[0] * rf_shape[1]'
    if len(pfinals) == 0:
        raise ValueError('pfinals is empty: need at least one jackknife fit to average')
    n_show = n_eigvec_to_display
    if not 1 <= n_show <= sdim:
        raise ValueError('n_eigvec_to_display must be between 1 and sdim ({})'.format(sdim))

    # Average across jackknives: [1:sdim+1] is the linear feature, the last sdim**2
    # entries are the flattened J matrix.
    a_avg = np.mean([np.reshape(p[1:sdim + 1], rf_shape) for p in pfinals], axis=0)
    j_avg = np.mean([np.reshape(p[-sdim ** 2:], (sdim, sdim)) for p in pfinals], axis=0)

    # NOTE(review): assumes J enters the model only through a quadratic form s^T J s
    # (confirm), in which case only its symmetric part matters. Symmetrizing lets us use
    # eigh, which returns real eigenvalues already sorted in ascending order; eig can
    # return complex values, which cannot be sorted reliably or passed to imshow.
    j_sym = 0.5 * (j_avg + j_avg.T)
    sorted_eigval, sorted_eigvec = np.linalg.eigh(j_sym)

    def _show_rf(ax, image):
        # Draw one receptive-field-shaped image with tick marks and labels hidden.
        ax.imshow(image, cmap=color_map, interpolation="gaussian", origin='lower', aspect='equal')
        ax.tick_params(labelbottom=False, labelleft=False, direction='in', bottom=False, left=False)

    # The bottom two rows are split between the eigenvalue spectrum and the linear
    # feature; keep at least two columns so both panels get space when n_show == 1.
    n_cols = max(n_show, 2)
    split = n_cols // 2

    fig = plt.figure(constrained_layout=True, figsize=(13.3, 10))
    gs = fig.add_gridspec(4, n_cols)

    # Row 0: most negative eigenvalues (ascending); row 1: most positive (descending).
    most_negative = range(n_show)
    most_positive = range(sdim - 1, sdim - 1 - n_show, -1)
    for row, eig_indices in enumerate((most_negative, most_positive)):
        for col, eig_idx in enumerate(eig_indices):
            ax = fig.add_subplot(gs[row, col])
            _show_rf(ax, np.reshape(sorted_eigvec[:, eig_idx], rf_shape))
            ax.set_title('{:.3f}'.format(sorted_eigval[eig_idx]), fontsize=14)

    ax_eigs = fig.add_subplot(gs[2:, :split])
    ax_eigs.plot(sorted_eigval, 'k.')
    ax_eigs.set_title('Sorted Eigenvalues of J Matrix', fontsize=18)
    ax_eigs.set_ylabel('Value', fontsize=16)
    ax_eigs.set_xlabel('Index', fontsize=16)
    ax_eigs.tick_params(labelsize=14)

    ax_a = fig.add_subplot(gs[2:, split:])
    _show_rf(ax_a, a_avg)
    ax_a.set_title('Linear Feature', fontsize=18)

    fig.suptitle("{} Block: {} Unit: {}".format(bird, block, unit), fontsize=20)

    if save_figure:
        # Save through the figure handle before plt.show(): calling plt.savefig after
        # show() (as earlier commented-out code did) can write a blank image.
        fig.savefig(os.path.join(figure_output_path, 'MNEs_unit_{}.png'.format(unit)))
    plt.show()

In [None]:
# Plot the jackknife-averaged MNE summary using the plot_MNE defined in this notebook.
# NOTE(review): `unit`, `bird`, `block` and `figure_output_path` are not defined in any
# visible cell — define them (ideally next to the pickle path) before running.
plot_MNE(pfinals, unit=unit, bird=bird, block=block, figure_output_path=figure_output_path)