## Fit and analyze ARHMMs
The next step of the BehaveNet pipeline is to model the low-dimensional representation of behavior using an autoregressive hidden Markov model (ARHMM).

<br>

### Contents
* [Fitting ARHMMs](#Fitting-ARHMMs)
* [Plot validation log probability as a function of discrete states](#Plot-validation-log-probability-as-a-function-of-discrete-states)
* [Visualize state segmentations for multiple trials](#Visualize-state-segmentations-for-multiple-trials)
* [Plot inferred and generated latents and states](#Plot-inferred-and-generated-latents-and-states)
* [Make real vs generated movies](#Make-real-vs-generated-movies)
* [Make syllable movies](#Make-syllable-movies)

In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.io as sio
import seaborn as sns

from behavenet import get_user_dir, make_dir_if_not_exists
from behavenet.fitting.utils import get_best_model_version
from behavenet.fitting.utils import get_expt_dir
from behavenet.fitting.utils import get_lab_example
from behavenet.fitting.utils import get_session_dir
from behavenet.plotting.arhmm_utils import *

# output switches read by the plotting cells further down in this notebook
save_outputs = True  # true to save figures/movies to user's figure directory
# NOTE(review): the name `format` shadows the Python builtin `format()` for the rest of the session
format = 'png'  # figure format ('png' | 'jpeg' | 'pdf'); movies saved as mp4

## Fitting ARHMMs

[Back to contents](#Contents)

In [1]:
# Note that plots are saved

## Plot validation log probability as a function of discrete states
The number of discrete ARHMM states K is a hyperparameter of the model. Though there is no one best way to choose K, a helpful diagnostic is to look at the log probability of the trained ARHMM on held-out validation data as a function of K.

[Back to contents](#Contents)

In [None]:
from behavenet.fitting.utils import get_subdirs

# define which arhmm states to plot (must already be fit)
n_arhmm_states = [2, 4, 8, 16, 32]

# set model info
hparams = {
    'data_dir': get_user_dir('data'),
    'save_dir': get_user_dir('save'),
    'experiment_name': 'arhmm-example',
    'model_class': 'arhmm',
    'model_type': None,
    'noise_type': 'gaussian',
    'n_arhmm_lags': 1,
    'kappa': 0,
    'ae_experiment_name': 'ae-example',
    'n_ae_latents': 9,
}
get_lab_example(hparams, 'musall', 'vistrained')

# collect one dataframe per completed model version; each holds the logged test
# losses for that version, tagged with the number of ARHMM states
metrics_frames = []
for n_states in n_arhmm_states:
    hparams['n_arhmm_states'] = n_states
    hparams['session_dir'], _ = get_session_dir(hparams)
    expt_dir = get_expt_dir(hparams)
    # gather all versions
    try:
        versions = get_subdirs(expt_dir)
    except Exception:
        print('No models in %s; skipping' % expt_dir)
        continue
    # load csv files with model metrics (saved out from test tube)
    for version in versions:
        model_dir = os.path.join(expt_dir, version)
        # a version without a readable metrics file has nothing to contribute
        try:
            metrics = pd.read_csv(os.path.join(model_dir, 'metrics.csv'))
        except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError):
            continue
        with open(os.path.join(model_dir, 'meta_tags.pkl'), 'rb') as f:
            hp_new = pickle.load(f)
        if not hp_new['training_completed']:
            continue
        if 'test_loss' not in metrics.columns:
            continue
        # keep only rows where a test loss was actually logged; checking
        # `'test_loss' in row` is true for every row once the column exists,
        # including rows whose value is NaN
        logged = metrics.loc[metrics['test_loss'].notna(), ['epoch', 'test_loss']]
        if logged.empty:
            continue
        metrics_frames.append(pd.DataFrame({
            'epoch': logged['epoch'].values,
            'loss': logged['test_loss'].values,
            'n_states': n_states}))

if not metrics_frames:
    raise ValueError(
        'No completed ARHMM models with logged test losses were found for '
        'n_arhmm_states=%s' % n_arhmm_states)
# ignore_index gives every row a unique index label
metrics_df = pd.concat(metrics_frames, ignore_index=True)

In [None]:
# plot the per-model losses in `metrics_df` against the number of ARHMM states
sns.set_style('white')
sns.set_context('talk')

splt = sns.relplot(data=metrics_df, x='n_states', y='loss', hue=None, kind='line')

# log-scaled x axis with one labeled tick per fitted model size
ax = splt.ax
ax.set_xscale('log')
ax.set_xticks(n_arhmm_states)
ax.set_xticklabels(n_arhmm_states)
ax.set_xlabel('ARHMM states')
ax.set_ylabel('Log prob per frame')

if save_outputs:
    filename = os.path.join(get_user_dir('fig'), 'arhmm', 'll_vs_states')
    make_dir_if_not_exists(filename)
    plt.savefig('%s.%s' % (filename, format), dpi=300, format=format)

plt.show()

## Visualize state segmentations for multiple trials
Another useful visualization is to plot ARHMM state segmentations over multiple trials, where each time point is colored based on the discrete state assigned by the ARHMM.

[Back to contents](#Contents)

In [None]:
def plot_segmentations_by_trial(
        states, xtick_locs=None, xticklabel_offset=0, frame_rate=None, save_file=None,
        title=None, cmap='tab20b'):
    """Plot discrete state sequences for multiple trials, one image row per trial.

    Parameters
    ----------
    states : list of array-like
        one 1D sequence of state values per trial; each is drawn as a single-row image
    xtick_locs : array-like, optional
        x tick locations (in bins) for the bottom row; only used when `frame_rate` is
        also provided
    xticklabel_offset : float, optional
        added to `xtick_locs` before the tick labels are converted to seconds
    frame_rate : float, optional
        tick labels are `(xticklabel_offset + xtick_locs) / frame_rate`, cast to int
    save_file : str, optional
        if given, the figure is saved to this path (transparent background) and closed;
        otherwise the figure is shown
    title : str, optional
        figure-level title
    cmap : str, optional
        colormap passed to `imshow`

    Returns
    -------
    matplotlib figure handle

    Raises
    ------
    ValueError
        if `states` contains no trials

    """
    n_trials = len(states)
    if n_trials == 0:
        raise ValueError('`states` must contain at least one trial to plot')

    fig = plt.figure(figsize=(10, n_trials / 4))
    gs_bottom_left = plt.GridSpec(n_trials, 1, top=0.85, right=1)
    for i_trial, trial_states in enumerate(states):
        # asarray so plain lists can be indexed with [None, :] as well
        trial_states = np.asarray(trial_states)
        axes = plt.subplot(gs_bottom_left[i_trial, 0])
        axes.imshow(
            trial_states[None, :], aspect='auto',
            extent=(0, len(trial_states), 0, 1), cmap=cmap, alpha=0.8)
        axes.set_xticks([])
        axes.set_yticks([])
        axes.set_frame_on(False)

    # `axes` is now the bottom row; it is the only one that gets x ticks/labels
    if xtick_locs is not None and frame_rate is not None:
        axes.set_xticks(xtick_locs)
        axes.set_xticklabels(
            ((xticklabel_offset + np.asarray(xtick_locs)) / frame_rate).astype('int'))
        axes.set_xlabel('Time (s)')
    else:
        axes.set_xlabel('Time (bins)')
    # put the shared y label on the middle row
    axes = plt.subplot(gs_bottom_left[int(np.floor(n_trials / 2)), 0])
    axes.set_ylabel('Trials')

    plt.suptitle(title)
    plt.tight_layout()

    if save_file is not None:
        make_dir_if_not_exists(save_file)
        fig.savefig(save_file, transparent=True, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()
    return fig

In [None]:
# user params
get_best_version = True  # False when looking at multiple models w/in a tt expt
dtype = 'test'  # data type to draw trials from: 'train' | 'val' | 'test'
sess_idx = 0  # when using a multisession, this determines which session is used

# define which arhmm states to plot (must already be fit)
n_arhmm_states = [2, 4]

# set model info
hparams = {
    'data_dir': get_user_dir('data'),
    'save_dir': get_user_dir('save'),
    'experiment_name': 'arhmm-example',
    'model_class': 'arhmm',
    'model_type': None,
    'noise_type': 'gaussian',
    'n_arhmm_lags': 1,
    'kappa': 0,
    # optional entries, left disabled as in the original cell:
    #     'rng_seed': 0,
    #     'train_frac': 1.0,
    #     'rng_seed_model': 0,
    'ae_experiment_name': 'ae-example',
    'ae_model_type': 'conv',
    'n_ae_latents': 9,
}
get_lab_example(hparams, 'musall', 'vistrained')

xtick_locs = [0, 30, 60, 90, 120, 150, 180]
frame_rate = 30
n_trials = 20

# NOTE(review): `experiment_exists` and `get_transforms_paths` are not imported by name
# in this notebook; presumably they arrive through the star import of
# behavenet.plotting.arhmm_utils - confirm, or import them explicitly.
for n_states in n_arhmm_states:

    hparams['n_arhmm_states'] = n_states
    hparams['session_dir'], sess_ids = get_session_dir(hparams)
    hparams['expt_dir'] = get_expt_dir(hparams)

    # get version/model
    if get_best_version:
        version = get_best_model_version(
            hparams['expt_dir'], measure='val_loss', best_def='max')[0]
    else:
        _, version = experiment_exists(hparams, which_version=True)

    # load model
    model_file = os.path.join(
        hparams['expt_dir'], 'version_%i' % version, 'best_val_model.pt')
    with open(model_file, 'rb') as f:
        hmm = pickle.load(f)

    # load latents
    _, latents_file = get_transforms_paths('ae_latents', hparams, sess_ids[sess_idx])
    with open(latents_file, 'rb') as f:
        all_latents = pickle.load(f)
    trial_idxs = all_latents['trials'][dtype]

    # choose which trials to plot; never ask for more trials than exist, since
    # sampling without replacement would raise
    n_plot = min(n_trials, len(trial_idxs))
    if n_plot == 0:
        print('No %s trials for K=%i; skipping' % (dtype, n_states))
        continue
    np.random.seed(0)
    trial_vec = np.random.choice(
        np.arange(0, len(trial_idxs)), size=(n_plot,), replace=False)

    # infer states for the selected trials only
    states_by_trial = [
        np.asarray(hmm.most_likely_states(all_latents['latents'][trial_idxs[j]]))
        for j in trial_vec]

    # pad every trial with nans up to the longest selected trial so rows share a length
    n_frames_max = max(len(s) for s in states_by_trial)
    states_padded = []
    for states_tmp in states_by_trial:
        padded = np.full((n_frames_max,), fill_value=np.nan)
        padded[:len(states_tmp)] = states_tmp
        states_padded.append(padded)

    model_name = str(
        'multitrial_segmentation_D=%02i_K=%02i' % (
        hparams['n_ae_latents'], hparams['n_arhmm_states']))

    if save_outputs:
        save_file = os.path.join(
            get_user_dir('fig'), 'arhmm', model_name + '.' + format)
    else:
        save_file = None

    # keyword arguments keep the call aligned with the function signature
    fig = plot_segmentations_by_trial(
        states_padded, xtick_locs=xtick_locs, frame_rate=frame_rate,
        save_file=save_file, title=model_name)

## Plot inferred and generated latents and states

[Back to contents](#Contents)

In [None]:
# import from behavenet.plotting.arhmm_utils, the module the rest of this notebook uses
from behavenet.plotting.arhmm_utils import get_model_latents_states
from behavenet.plotting.arhmm_utils import plot_states_overlaid_with_latents

# user params
sess_idx = 0  # when using a multisession, this determines which session is used
version = ''  # test-tube version; 'best' finds the version with the lowest mse

# set model info
hparams = {
    'data_dir': get_user_dir('data'),
    'save_dir': get_user_dir('save'),
    'experiment_name': 'arhmm-example',
    'model_class': 'arhmm',
    'model_type': None,
    'n_arhmm_states': 4,
    'noise_type': 'gaussian',
    'n_arhmm_lags': 1,
    'kappa': 0,
    'ae_experiment_name': 'ae-example',
    'ae_model_type': 'conv',
    'n_ae_latents': 9,
}
get_lab_example(hparams, 'musall', 'vistrained')

# collect inferred latents/states along with generated latents/states;
# the numpy seed is fixed before the call so the generated samples are repeatable
np.random.seed(101)
rdict = get_model_latents_states(
    hparams, version, sess_idx=sess_idx, return_samples=20, cond_sampling=False)

In [None]:
# plot data
sns.set_context('talk')
sns.set_style('white')

fig, axes = plt.subplots(2, 1, figsize=(10, 8))

# which test trial (j) and which generated sample (i) to display
j = 8
plot_latents = rdict['latents']['test'][j]
plot_states = rdict['states']['test'][j]
i = 9
plot_latents_gen = rdict['latents_gen'][i]
plot_states_gen = rdict['states_gen'][i]

xtick_locs = [0, 30, 60, 90, 120, 150, 180]
frame_rate = 30
max_frames = 200  # only the first `max_frames` time points are plotted

# plot inferred latents and states
axes[0] = plot_states_overlaid_with_latents(
    plot_latents[:max_frames, :], plot_states[:max_frames],
    ax=axes[0], xtick_locs=xtick_locs, frame_rate=frame_rate)
# the two panels share the same time axis, so only the bottom one keeps ticks/label
axes[0].set_xticks([])
axes[0].set_xlabel('')
axes[0].set_title('Inferred latents and states')

# plot generated latents and states
axes[1] = plot_states_overlaid_with_latents(
    plot_latents_gen[:max_frames, :], plot_states_gen[:max_frames],
    ax=axes[1], xtick_locs=xtick_locs, frame_rate=frame_rate)
axes[1].set_title('Generated latents and states')

if save_outputs:
    # save under the user's figure directory, like the other figures in this notebook
    filename = os.path.join(
        get_user_dir('fig'), 'arhmm',
        'inf_vs_gen__D=%02i_K=%02i' % (
            hparams['n_ae_latents'], hparams['n_arhmm_states']))
    make_dir_if_not_exists(filename)
    plt.savefig(filename + '.' + format, dpi=300, format=format)

plt.show()

## Make real vs generated movies

[Back to contents](#Contents)

In [None]:
# import from behavenet.plotting.arhmm_utils, the module the rest of this notebook uses
from behavenet.plotting.arhmm_utils import get_model_latents_states
from behavenet.plotting.arhmm_utils import make_real_vs_generated_movies
from behavenet.data.utils import get_data_generator_inputs
from behavenet.data.data_generator import ConcatSessionsGenerator

# user params
dtype = 'test'  # data type to draw trials from: 'train' | 'val' | 'test'
sess_idx = 0  # when using a multisession, this determines which session is used
version = ''  # test-tube version; 'best' finds the version with the lowest mse

# set model info
hparams = {
    'data_dir': get_user_dir('data'),
    'save_dir': get_user_dir('save'),
    'experiment_name': 'arhmm-example',
    'model_class': 'arhmm',
    'model_type': None,
    'n_arhmm_states': 4,
    'noise_type': 'gaussian',
    'n_arhmm_lags': 1,
    'kappa': 0,
    'ae_experiment_name': 'ae-example',
    'ae_model_type': 'conv',
    'n_ae_latents': 9,
}

# programmatically fill out other hparams options
get_lab_example(hparams, 'musall', 'vistrained')
hparams['session_dir'], sess_ids = get_session_dir(hparams)
hparams['expt_dir'] = get_expt_dir(hparams)

# collect inferred latents/states along with generated latents/states;
# the numpy seed is fixed before the call so the generated samples are repeatable
np.random.seed(101)
rdict = get_model_latents_states(
    hparams, version, sess_idx=sess_idx, return_samples=20, cond_sampling=False)

In [None]:
# plot data
sns.set_context('talk')
sns.set_style('white')

xtick_locs = [0, 30, 60, 90, 120, 150, 180]
hparams['frame_rate'] = 30

# movies go under the same figure directory used for the other outputs of this notebook
fig_save_dir = os.path.join(get_user_dir('fig'), 'arhmm')

# NOTE(review): `data_generator` is not assigned in any cell visible in this notebook.
# `get_data_generator_inputs` and `ConcatSessionsGenerator` are imported in the cell
# above but never called, so presumably the generator should be built from them -
# confirm and add the construction before running this cell.
make_real_vs_generated_movies(
    fig_save_dir, hparams, rdict['model'], rdict['latents'][dtype],
    rdict['states'][dtype], data_generator, sess_idx=sess_idx, ptype=1,
    xtick_locs=xtick_locs)

## Make syllable movies

[Back to contents](#Contents)

In [None]:
from behavenet.plotting.arhmm_utils import make_syllable_movies_wrapper

# user params
dtype = 'train'  # data type to draw trials from: 'train' | 'val' | 'test'
sess_idx = 0  # when using a multisession, this determines which session is used

# set model info
hparams = {
    'data_dir': get_user_dir('data'),
    'save_dir': get_user_dir('save'),
    'experiment_name': 'arhmm-example',
    'model_class': 'arhmm',
    'model_type': None,
    'n_arhmm_states': 4,
    'noise_type': 'gaussian',
    'n_arhmm_lags': 1,
    'kappa': 0,
    'ae_experiment_name': 'ae-example',
    'ae_model_type': 'conv',
    'n_ae_latents': 9,
}

# programmatically fill out other hparams options, as every other cell in this
# notebook does before using `hparams`
get_lab_example(hparams, 'musall', 'vistrained')

# output path encodes the latent dimensionality (D) and number of ARHMM states (K)
save_file = os.path.join(
    get_user_dir('fig'), 'syllable-movies_D=%02i_K=%02i' %
    (hparams['n_ae_latents'], hparams['n_arhmm_states']))
make_syllable_movies_wrapper(hparams, save_file, sess_idx=sess_idx, n_rows=6, dtype=dtype)