# Behavenet Power Law Decoding

How populations of neurons in the visual cortex encode information is an open question. Some research has demonstrated that they strike a balance between a high-dimensional, uncorrelated representation (highly flexible, but susceptible to noise) and a low-dimensional, correlated representation (less flexible, but robust to noise) ([Stringer et al., 2019](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6642054/)).

Additional research has found that natural images are encoded by a sparse subset of neurons. [Yoshida and Ohki, 2020](https://www.ncbi.nlm.nih.gov/pmc/articles/pmid/32054847/) demonstrated that a single natural image is linearly decodable from a surprisingly small number of highly responsive neurons (~20), and that including the remaining neurons can even degrade decoding.

Do these properties hold for behavioral signals encoded in the visual cortex?

In [1]:
import scipy
import pickle
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

matplotlib.rcParams.update({'font.size': 18, 'figure.figsize': (16, 8)})

import behavenet
from behavenet import get_user_dir, make_dir_if_not_exists
from behavenet.data.utils import get_transforms_paths
from behavenet.data.utils import load_labels_like_latents
from behavenet.fitting.utils import get_expt_dir
from behavenet.fitting.utils import get_session_dir
from behavenet.fitting.utils import get_best_model_version
from behavenet.fitting.utils import get_lab_example

from behavenet.fitting.utils import get_best_model_and_data
from behavenet.fitting.eval import export_predictions
 

ImportError: numpy.core.multiarray failed to import

In [4]:
from behavenet.fitting.utils import get_subdirs
import os
          
def latent_results(expt_name=None, sess_id='', version='best'):
    """Score neural -> AE-latent decoding for one session against a mean baseline.

    Loads the session's AE latents and computes a "chance" MSE: the error of
    always predicting the mean training-trial latent vector. It then loads the
    decoder's latent predictions and scores them on the same test time points.
    If the predictions file does not exist yet, they are exported first.

    Args:
        expt_name: Name of the neuron sub-sampling index set (e.g. 'n20_t0').
            When given, prediction results are looked up with
            ``subsample_method='single'`` and this index name. When None, no
            sub-sampling hparams are set.
        sess_id: Session identifier passed to ``get_lab_example`` (lab 'dipoppa').
        version: Decoder model version to evaluate ('best' by default).

    Returns:
        Tuple ``(chance_mse, decoding_mse, r2)``. All three are computed on the
        test trials with ``n_max_lags`` samples trimmed from each end of every
        trial.
    """
    # set model info
    sess_idx = 0
    hparams = {
        'data_dir': get_user_dir('data'),
        'save_dir': get_user_dir('save'),
        'model_class': 'neural-ae',
        'ae_model_type': 'conv',
        'ae_experiment_name': 'latent_search',
        'n_ae_latents': 9,
        'experiment_name': 'grid_search',
        'model_type': 'ff',
        'n_max_lags': 8,
        'rng_seed_data': 0,
        'trial_splits': '8;1;1;0'
    }

    hparams['neural_ae_experiment_name'] = hparams['experiment_name']
    hparams['neural_ae_model_type'] = hparams['model_type']
    hparams['neural_ae_version'] = version

    get_lab_example(hparams, 'dipoppa', sess_id)

    hparams['session_dir'], sess_ids = get_session_dir(hparams)

    # Drop n_lags samples from both ends of each trial. A plain `[n:-n]` slice
    # would return an empty array when n == 0, so use None as the stop then.
    n_lags = hparams['n_max_lags']
    lag_trim = slice(n_lags, -n_lags if n_lags > 0 else None)

    ## Chance performance: MSE of always predicting the mean training latent
    _, latents_file = get_transforms_paths('ae_latents', hparams, sess_ids[sess_idx])
    with open(latents_file, 'rb') as f:
        all_latents = pickle.load(f)
    train_latents = np.concatenate(
        [all_latents['latents'][i] for i in all_latents['trials']['train']])
    mean_ae_latents = np.mean(train_latents, axis=0)

    test_trials = all_latents['trials']['test']
    all_test_latents = np.concatenate(
        [all_latents['latents'][i][lag_trim] for i in test_trials])
    chance_ae_performance = np.mean((all_test_latents - mean_ae_latents) ** 2)

    ## Decoding performance
    # If sub-sampling - make sure to get results labelled by sample method and index name
    if expt_name is not None:
        hparams['subsample_method'] = 'single'
        hparams['subsample_idxs_name'] = expt_name

    _, latent_predictions_file = get_transforms_paths(
        'neural_ae_predictions', hparams, sess_ids[sess_idx])

    if not os.path.exists(latent_predictions_file):
        # Decoder was previously referenced without ever being imported (NameError).
        from behavenet.models.decoders import Decoder
        model, data_generator = get_best_model_and_data(
            hparams, Decoder, load_data=True, version=version)
        # Presumably writes latent_predictions_file, which is read below.
        export_predictions(data_generator, model)

    with open(latent_predictions_file, 'rb') as f:
        all_latent_predictions = pickle.load(f)
    all_test_latent_predictions = np.concatenate(
        [all_latent_predictions['predictions'][i][lag_trim] for i in test_trials])
    decoding_ae_performance = np.mean((all_test_latents - all_test_latent_predictions) ** 2)
    r2 = r2_score(all_test_latents, all_test_latent_predictions)

    return chance_ae_performance, decoding_ae_performance, r2


In [None]:
# Neuron sample sizes and sessions to evaluate.
samples = [20, 60, 80, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300]
sess_ids = ['MD0ST5_1', 'MD0ST5_2', 'MD0ST5_3', 'MD0ST5_4']
# Number of index sets ('_t{i}' suffix) evaluated per sample size.
n_repeats = 10

# decoding_errors[sess_id][sample_size] -> list of (chance_mse, decoding_mse, r2)
decoding_errors = {}

failures = []         # expt_name of every failed run (kept for compatibility)
failure_details = []  # (sess_id, expt_name, exception) for every failed run
for sess_id in sess_ids:
    decoding_errors[sess_id] = {}
    for sample_size in samples:
        decoding_errors[sess_id][sample_size] = []
        for i in range(n_repeats):
            expt_name = 'n{}_t{}'.format(sample_size, i)
            try:
                chance, decoding, r2 = latent_results(expt_name=expt_name, sess_id=sess_id)
            except Exception as err:
                # Record the cause instead of silently swallowing it; a bare
                # `except:` also trapped KeyboardInterrupt and hid coding errors.
                print('Failure on session: ', sess_id, 'trial: ', expt_name, '-', repr(err))
                failures.append(expt_name)
                failure_details.append((sess_id, expt_name, err))
                continue
            decoding_errors[sess_id][sample_size].append((chance, decoding, r2))

In [5]:
def get_results(res, session, samples, signal):
    '''
    Collect one metric for every stored run of a session, grouped by sample size.

    Params:
        res: dict of results (from the sweep cell), indexed as
            res[session][sample] -> list of (chance, mse, r2) tuples
        session: Session ID in dict
        samples: array of # neurons used in decoding
        signal: 'r2' | 'mse' | 'chance' --> which value to extract

    Returns:
        List with one entry per element of `samples`. Each entry is the list
        of the requested metric across the runs stored for that sample size.

    Raises:
        ValueError: if `signal` is not one of 'r2', 'mse', 'chance'.
    '''
    # Position of each metric inside the stored (chance, mse, r2) tuples.
    tuple_layout = ('chance', 'mse', 'r2')
    if signal not in tuple_layout:
        raise ValueError('The only signals available are r2, mse, and chance')
    position = tuple_layout.index(signal)

    session_results = res[session]
    return [[run[position] for run in session_results[n]] for n in samples]

In [None]:
from scipy.stats import sem

# Plot decoding MSE (normalized by chance MSE) against the number of neurons
# used, one line per session with a standard-error band.
fig, ax = plt.subplots()
# Base 10 is the default for log scales. The old `basex`/`basey` keywords were
# removed in Matplotlib 3.5.
ax.set_yscale('log')
ax.set_xscale('log')
# `plt.grid = True` overwrote pyplot's grid *function*; this actually enables the grid.
ax.grid(True)

st = 0  # where to start on the x axis (ie; to start plot from 100 neurons, set st = 3)
x_vals = samples[st:]
for i, sess in enumerate(sess_ids):
    res = get_results(decoding_errors, sess, samples, 'mse')
    chances = get_results(decoding_errors, sess, samples, 'chance')

    # Chance MSE is computed before any sub-sampling is applied, so every
    # successful run of a session reports the same value. Take the first one
    # available rather than assuming the first sample size succeeded.
    available_chances = [c for runs in chances for c in runs]
    if not available_chances:
        print('No successful runs for session', sess, '- skipping')
        continue
    chance = available_chances[0]

    # Convert to arrays first: dividing a Python list by a float raises TypeError.
    means = np.array([np.mean(t) for t in res[st:]]) / chance
    std_error = np.array([sem(t) for t in res[st:]]) / chance

    line, = ax.plot(x_vals, means, label='Session %d' % (i + 1))
    # Match the band colour to its line; otherwise the colour cycle advances.
    ax.fill_between(x_vals, means - std_error, means + std_error,
                    color=line.get_color(), alpha=0.3)

# Values are divided by chance MSE, so 1.0 on this axis is chance performance.
ax.set_ylabel('MSE (relative to chance)')
ax.set_xlabel('Number of Neurons')

ax.axhline(1, 0.05, 0.95, color='black', linewidth=4, label='Chance Performance')
ax.legend()
ax.set_title('[MD0ST5, Sessions: 1-4] Decoding 9 Non-Linear Latent Variables')

plt.show()