## Fit and analyze decoder models
The next step of the BehaveNet pipeline is using simultaneously recorded neural activity to predict behavior. Specifically, we can predict either of our compressed descriptions of behavior: the convolutional autoencoder latents or the ARHMM states.

We use linear models or feedforward neural networks to predict the ARHMM state or the AE latents of each frame from a window of neural activity. We then compare these predictions to a baseline (chance) level of performance. We can also use the convolutional autoencoder to decode the predicted latents into a full predicted behavioral video and compare it to the original video.


<br>

### Contents
* [Decoding discrete states](#Decoding-discrete-(ARHMM)-states)
* [Decoding continuous latents](#Decoding-continuous-states-(AE-latents))
* [Assess decoding performance](#Assess-decoding-performance)
* [Plot true vs predicted latents](#Plot-true-vs-predicted-latents)
* [Make real vs predicted movies](#Make-real-vs-predicted-movies)


In [None]:
import pickle
import scipy.io as sio
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

matplotlib.rcParams['font.size'] = 22  # enlarge the default font used by every figure below
from behavenet import get_user_dir, make_dir_if_not_exists
from behavenet.data.utils import get_transforms_paths
from behavenet.data.utils import load_labels_like_latents
from behavenet.fitting.utils import get_expt_dir
from behavenet.fitting.utils import get_session_dir
from behavenet.fitting.utils import get_best_model_version
from behavenet.fitting.utils import get_lab_example
from behavenet.plotting.arhmm_utils import *

# Set to False to only display figures inline instead of also writing them to
# the user's figure directory.
save_outputs = True
# File format for saved figures: 'png' | 'jpeg' | 'pdf' (movies are saved as mp4).
format = 'png'

## Decoding discrete (ARHMM) states

First copy the example json files ``decoding_ae_model.json``, ``decoding_arhmm_model.json``, ``decoding_training.json`` and ``decoding_compute.json`` into your ``.behavenet`` directory, ``cd`` to the ``behavenet`` directory in the terminal, and run:

```console
$: python behavenet/fitting/decoder_grid_search.py --data_config ~/.behavenet/musall_vistrained_params.json --model_config ~/.behavenet/decoding_arhmm_model.json --training_config ~/.behavenet/decoding_training.json --compute_config ~/.behavenet/decoding_compute.json
```

[Back to contents](#Contents)

## Decoding continuous states (AE latents)

To decode the continuous AE latents instead, run the same script with the AE model config:

```console
$: python behavenet/fitting/decoder_grid_search.py --data_config ~/.behavenet/musall_vistrained_params.json --model_config ~/.behavenet/decoding_ae_model.json --training_config ~/.behavenet/decoding_training.json --compute_config ~/.behavenet/decoding_compute.json
```

[Back to contents](#Contents)

## Assess decoding performance
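We compare each decoder's performance to a chance baseline. For the discrete ARHMM states, chance is the accuracy of always predicting the most common state in the training data. For the continuous AE latents, chance is the mean squared error of always predicting the mean of the training latents. In both cases, the first and last `n_max_lags` frames of each test trial are excluded.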

[Back to contents](#Contents)

In [None]:
from behavenet.fitting.utils import get_subdirs

# Assess ARHMM state decoding: accuracy of the neural decoder on held-out test
# trials vs. a chance baseline that always predicts the most common training state.

# set model info
hparams = {
    'data_dir': get_user_dir('data'),
    'save_dir': get_user_dir('save'),
    'model_class': 'neural-arhmm',
    'experiment_name': 'grid_search',
    'model_type': 'mlp',
    'noise_type': 'gaussian',
    'n_arhmm_lags': 1,
    'transitions': 'stationary',
    'arhmm_experiment_name': 'state_number_search',
    'ae_experiment_name': 'ae-example',
    'n_ae_latents': 9,
    'n_arhmm_states': 4,
    'n_max_lags': 8,
}

# the decoder-specific keys mirror the generic experiment name / model type
hparams['neural_arhmm_experiment_name'] = hparams['experiment_name']
hparams['neural_arhmm_model_type'] = hparams['model_type']

get_lab_example(hparams, 'musall', 'vistrained')
sess_idx = 0
hparams['session_dir'], sess_ids = get_session_dir(hparams)
hparams['expt_dir'] = get_expt_dir(hparams)


def _strip_edge_lags(arr, n_lags):
    """Drop `n_lags` time points from both ends of `arr`.

    Unlike ``arr[n_lags:-n_lags]``, this keeps the whole array when
    ``n_lags == 0`` (``arr[0:-0]`` would be empty).
    """
    return arr[n_lags:len(arr) - n_lags]


n_lags = hparams['n_max_lags']

## Chance performance: accuracy of always predicting the most common training state
_, states_file = get_transforms_paths('states', hparams, sess_ids[sess_idx])
with open(states_file, 'rb') as f:
    all_states = pickle.load(f)
train_states = np.concatenate([all_states['states'][i] for i in all_states['trials']['train']])
# np.unique instead of scipy.stats.mode: `scipy` itself is never imported here, and
# `.mode[0]` fails on SciPy >= 1.11 where mode returns a scalar by default. Ties
# resolve to the smallest state id, matching scipy.stats.mode.
state_ids, state_counts = np.unique(train_states, return_counts=True)
most_common_train_state = state_ids[np.argmax(state_counts)]

# use a single list of test trials so true states and predictions stay aligned
test_trials = all_states['trials']['test']
all_test_states = np.concatenate(
    [_strip_edge_lags(all_states['states'][i], n_lags) for i in test_trials])
chance_arhmm_performance = np.mean(all_test_states == most_common_train_state)

## Decoding performance: accuracy of the neural decoder's most likely state
_, predictions_file = get_transforms_paths('neural_arhmm_predictions', hparams, sess_ids[sess_idx])
with open(predictions_file, 'rb') as f:
    all_state_predictions = pickle.load(f)
# each trial's predictions are 2D (argmax over axis 1 picks the predicted state);
# presumably (n_frames, n_states) per-state scores -- TODO confirm
all_test_state_predictions = np.concatenate([
    np.argmax(_strip_edge_lags(all_state_predictions['predictions'][i], n_lags), axis=1)
    for i in test_trials])
decoding_arhmm_performance = np.mean(all_test_states == all_test_state_predictions)



In [None]:
from behavenet.fitting.utils import get_subdirs

# Assess AE latent decoding: mean squared error of the neural decoder on held-out
# test trials vs. a chance baseline that always predicts the mean training latents.

# set model info
sess_idx = 0
hparams = {
    'data_dir': get_user_dir('data'),
    'save_dir': get_user_dir('save'),
    'model_class': 'neural-ae',
    'ae_model_type': 'conv',
    'ae_experiment_name': 'ae-example',
    'n_ae_latents': 9,
    'experiment_name': 'grid_search',
    'model_type': 'mlp',
    'n_max_lags': 8,
}

# the decoder-specific keys mirror the generic experiment name / model type
hparams['neural_ae_experiment_name'] = hparams['experiment_name']
hparams['neural_ae_model_type'] = hparams['model_type']

get_lab_example(hparams, 'musall', 'vistrained')

hparams['session_dir'], sess_ids = get_session_dir(hparams)
# store in hparams like the other cells do; keep the bare name for any later use
hparams['expt_dir'] = get_expt_dir(hparams)
expt_dir = hparams['expt_dir']


def _strip_edge_lags(arr, n_lags):
    """Drop `n_lags` time points from both ends of `arr`.

    Unlike ``arr[n_lags:-n_lags]``, this keeps the whole array when
    ``n_lags == 0`` (``arr[0:-0]`` would be empty).
    """
    return arr[n_lags:len(arr) - n_lags]


n_lags = hparams['n_max_lags']

## Chance performance: MSE of always predicting the mean training latents
_, latents_file = get_transforms_paths('ae_latents', hparams, sess_ids[sess_idx])
with open(latents_file, 'rb') as f:
    all_latents = pickle.load(f)
train_latents = np.concatenate([all_latents['latents'][i] for i in all_latents['trials']['train']])
mean_ae_latents = np.mean(train_latents, axis=0)

# use a single list of test trials so true latents and predictions stay aligned
test_trials = all_latents['trials']['test']
all_test_latents = np.concatenate(
    [_strip_edge_lags(all_latents['latents'][i], n_lags) for i in test_trials])
chance_ae_performance = np.mean((all_test_latents - mean_ae_latents) ** 2)

## Decoding performance: MSE of the neural decoder's latent predictions
_, latent_predictions_file = get_transforms_paths('neural_ae_predictions', hparams, sess_ids[sess_idx])
with open(latent_predictions_file, 'rb') as f:
    all_latent_predictions = pickle.load(f)
all_test_latent_predictions = np.concatenate(
    [_strip_edge_lags(all_latent_predictions['predictions'][i], n_lags) for i in test_trials])
decoding_ae_performance = np.mean((all_test_latents - all_test_latent_predictions) ** 2)



In [None]:
# Side-by-side bar charts: chance vs. decoder for ARHMM states (left) and AE latents (right).
fig, axes = plt.subplots(1, 2, figsize=(10, 10))

# one tuple per panel: (axis, chance value, decoder value, y-axis label, title)
panels = [
    (axes[0], chance_arhmm_performance, decoding_arhmm_performance,
     'Fraction correct', 'ARHMM State Decoding'),
    (axes[1], chance_ae_performance, decoding_ae_performance,
     'Mean Squared Error', 'AE Latent Decoding'),
]

for ax, chance_val, decoded_val, ylabel, title in panels:
    bars = ax.bar([0, 1], [chance_val, decoded_val])
    bars[0].set_color('#355C7D')  # chance
    bars[1].set_color('#F67280')  # decoder
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Chance', 'Decoding'])
    ax.set_ylabel(ylabel)
    ax.set_title(title)

plt.tight_layout()

## Plot true vs predicted latents

[Back to contents](#Contents)

In [None]:
def plot_real_vs_predicted(
        latents, latents_predicted, save_file=None, xtick_locs=None,  frame_rate=None, format='png'):
    """Plot true latents (black) overlaid with their predictions (green).

    Each latent dimension is drawn as its own trace, shifted vertically by a
    fixed offset; the true and predicted traces of a dimension share the same
    offset so they can be compared directly.

    Parameters
    ----------
    latents : :obj:`np.ndarray`
        true latents, shape (n_frames, n_latents)
    latents_predicted : :obj:`np.ndarray`
        predicted latents, shape (n_frames, n_latents)
    save_file : :obj:`str`, optional
        full save file (path and filename); the figure is not saved if None
    xtick_locs : :obj:`array-like`, optional
        tick locations in bin values for plot
    frame_rate : :obj:`float`, optional
        behavioral video framerate; to properly relabel xticks (only used when
        `xtick_locs` is given)
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'

    Returns
    -------
    None
        the figure is drawn on a new matplotlib figure (and saved if requested)
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 8))

    # vertical offset between consecutive latent dimensions; based on the largest
    # magnitude so strongly negative excursions do not overlap the trace below
    spc = 1.1 * np.abs(latents).max()
    n_latents = latents.shape[1]
    offsets = spc * np.arange(n_latents)
    plotting_latents = latents + offsets
    plotting_predicted_latents = latents_predicted + offsets

    # include the predicted traces in the limits so they are not clipped
    ymin = min(-spc - 1, np.min(plotting_latents), np.min(plotting_predicted_latents))
    ymax = max(spc * n_latents, np.max(plotting_latents), np.max(plotting_predicted_latents))
    ax.plot(plotting_latents, '-k', lw=3, label='AE Latents')
    ax.plot(plotting_predicted_latents, '-g', lw=3, label='Predicted AE latents')
    ax.set_ylim([ymin, ymax])

    ax.set_yticks([])

    ax.set_xlabel('Time (bins)')

    if xtick_locs is not None:
        ax.set_xticks(xtick_locs)
        if frame_rate is not None:
            ax.set_xticklabels((np.asarray(xtick_locs) / frame_rate).astype('int'))
            ax.set_xlabel('Time (sec)')

    # plotting a 2D array creates one line per column, all sharing the same label;
    # keep a single legend entry per label
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc='center left', bbox_to_anchor=(1, 0.5))

    if save_file is not None:
        make_dir_if_not_exists(save_file)
        # the legend sits outside the axes; 'tight' keeps it from being cropped
        fig.savefig(save_file, dpi=300, format=format, bbox_inches='tight')

In [None]:
import os

# Plot true vs. decoder-predicted AE latents for one trial of the chosen split.

# user params
get_best_version = True  # False when looking at multiple models w/in a tt expt
dtype = 'test'  # data type to draw trials from: 'train' | 'val' | 'test'
sess_idx = 0  # when using a multisession, this determines which session is used
max_frames = 200  # maximum number of frames plotted from the chosen trial

# set model info
hparams = {
    'data_dir': get_user_dir('data'),
    'save_dir': get_user_dir('save'),
    'experiment_name': 'grid_search',
    'model_class': 'neural-ae',
    'model_type': 'mlp',
    'ae_experiment_name': 'ae-example',
    'n_ae_latents': 9,
    'ae_model_type': 'conv',
}

# the decoder-specific keys mirror the generic experiment name / model type
hparams['neural_ae_experiment_name'] = hparams['experiment_name']
hparams['neural_ae_model_type'] = hparams['model_type']

get_lab_example(hparams, 'musall', 'vistrained')

xtick_locs = [0, 30, 60, 90, 120, 150, 180]
frame_rate = 30
n_trials = 20

hparams['session_dir'], sess_ids = get_session_dir(hparams)
hparams['expt_dir'] = get_expt_dir(hparams)

# select the decoder version whose predictions are shown
if get_best_version:
    version = get_best_model_version(
        hparams['expt_dir'], measure='val_loss', best_def='min')[0]
else:
    # imported here so the default (best-version) path has no extra dependency
    from behavenet.fitting.utils import experiment_exists
    _, version = experiment_exists(hparams, which_version=True)
# NOTE(review): presumably read by get_transforms_paths when locating the
# 'neural_ae_predictions' file -- confirm against behavenet.data.utils
hparams['neural_ae_version'] = version

# load true latents
_, latents_file = get_transforms_paths('ae_latents', hparams, sess_ids[sess_idx])
with open(latents_file, 'rb') as f:
    all_latents = pickle.load(f)

# load decoder predictions of the latents
_, latent_predictions_file = get_transforms_paths(
    'neural_ae_predictions', hparams, sess_ids[sess_idx])
with open(latent_predictions_file, 'rb') as f:
    predicted_latents = pickle.load(f)

# choose which trials to plot: sample actual trial indices from the requested
# split (not positions within it), capped at the number of trials available
np.random.seed(0)
dtype_trials = np.asarray(all_latents['trials'][dtype])
trial_vec = np.random.choice(
    dtype_trials, size=(min(n_trials, len(dtype_trials)),), replace=False)

model_name = 'latent_prediction_D=%02i' % hparams['n_ae_latents']

if save_outputs:
    save_file = os.path.join(
        get_user_dir('fig'), hparams['model_class'], model_name + '.' + format)
else:
    save_file = None

trial_to_plot = trial_vec[0]
plot_real_vs_predicted(
    all_latents['latents'][trial_to_plot][:max_frames],
    predicted_latents['predictions'][trial_to_plot][:max_frames],
    save_file=save_file, xtick_locs=xtick_locs, frame_rate=frame_rate, format=format)

