## Analyze AEs with matrix subspace projection loss
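This notebook is a template showcasing several ways to analyze autoencoders (AEs) that have been fit with the matrix subspace projection (MSP) loss. Replace each `?` placeholder with values for your own data and model before running the cells.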

 <br>
 
### Contents
* [Plot loss metrics as a function of epoch](#Plot-loss-metrics-as-a-function-of-epoch)
* [Plot true vs predicted labels](#Plot-true-vs-predicted-labels)
* [Evaluate orthogonality of projection matrix](#Evaluate-orthogonality-of-projection-matrix)
* [Explore label/latent space](#Explore-label/latent-space)
    * [Explore 2D label space](#Explore-2D-label-space)
    * [Explore 2D latent space](#Explore-2D-latent-space)

In [None]:
import copy
import os
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

from behavenet import get_user_dir
from behavenet import make_dir_if_not_exists
from behavenet.fitting.utils import get_expt_dir
from behavenet.fitting.utils import get_session_dir
from behavenet.fitting.utils import get_best_model_version
from behavenet.fitting.utils import get_lab_example

# Global output options read by the plotting cells below.
# - save_outputs: set to True to write figures/movies to the user's figure directory
# - format: figure file format, one of 'png' | 'jpeg' | 'pdf' (movies are always mp4)
# NOTE: `format` shadows the builtin of the same name; it is kept because later cells read it.
save_outputs, format = False, 'png'

# Plot loss metrics as a function of epoch

[Back to contents](#Contents)

In [None]:
from behavenet.plotting import load_metrics_csv_as_df

# set data info
# NOTE: every `?` below is a template placeholder and must be replaced before this cell runs
lab = ?  # lab name, passed to load_metrics_csv_as_df and get_lab_example
expt = ?  # experiment name, passed to load_metrics_csv_as_df and get_lab_example
n_labels = ?  # number of labels; a later cell overwrites this with model.n_labels

# set model info
n_ae_latents = ?  # n_labels will be added to this
tt_expt_name = ?  # test-tube experiment name the models were saved under

# hparams identifying which fitted models to read metrics from; the total latent
# dimensionality is the non-label latents plus one latent per label
hparams = {
    'data_dir': get_user_dir('data'),
    'save_dir': get_user_dir('save'),
    'experiment_name': tt_expt_name,
    'model_class': 'cond-ae-msp',
    'model_type': 'conv',
    'n_ae_latents': n_ae_latents + n_labels}

# metrics (dataframe columns) to load; the next cell plots each of these against epoch
metrics_list = ['loss', 'loss_mse', 'loss_msp', 'r2']
metrics_df = load_metrics_csv_as_df(hparams, lab, expt, metrics_list)

In [None]:
# plot each metric against epoch, one figure per metric
sns.set_style('white')
sns.set_context('talk')

# y-axis label and whether to use a log y-scale, per known metric name
metric_axis_info = {
    'loss': ('Total loss', True),
    'loss_mse': ('MSE per pixel', True),
    'loss_msp': ('MSE per label', True),
    'r2': ('Label $R^2$', False),
}

# keep epochs > 10 and rows that have a total loss; this filter does not depend on
# the metric being plotted, so compute it once instead of on every iteration
data_queried = metrics_df[(metrics_df.epoch > 10) & ~pd.isna(metrics_df.loss)]

for y in metrics_list:

    splt = sns.relplot(x='epoch', y=y, hue='dtype', kind='line', data=data_queried)
    splt.ax.set_xlabel('Epoch')
    if y in metric_axis_info:
        ylabel, log_scale = metric_axis_info[y]
        splt.ax.set_ylabel(ylabel)
        if log_scale:
            splt.ax.set_yscale('log')

    if save_outputs:
        # include the metric name so each figure gets its own file; previously every
        # metric was written to 'loss_vs_epoch', overwriting the earlier figures
        save_file = os.path.join(get_user_dir('fig'), 'ae', '%s_vs_epoch' % y)
        make_dir_if_not_exists(save_file)
        plt.savefig(save_file + '.' + format, dpi=300, format=format)

    plt.show()

# Plot true vs predicted labels

[Back to contents](#Contents)

In [None]:
from behavenet.fitting.utils import get_best_model_and_data
from behavenet.models import AEMSP

# which fitted model to load
version = 0  # test-tube version; the string 'best' selects the version with the lowest mse
sess_idx = 0  # session to use when the model was fit on multiple sessions
hparams = dict(
    data_dir=get_user_dir('data'),
    save_dir=get_user_dir('save'),
    experiment_name=tt_expt_name,
    model_class='cond-ae-msp',
    model_type='conv',
    n_ae_latents=n_ae_latents + n_labels,
)

# complete the remaining hparams entries for this lab/experiment (updates `hparams`)
get_lab_example(hparams, lab, expt)

# load the model along with its data
model, data_generator = get_best_model_and_data(
    hparams, AEMSP, load_data=True, version=version, data_kwargs=None)
n_labels = model.n_labels  # take the label count from the loaded model itself
print(data_generator)

trial_idxs = [1, 2, 3]  # indices into the test trials that get plotted

In [None]:
from behavenet.plotting.ae_utils import plot_neural_reconstruction_traces

# compare the recorded labels with the labels read out of the first `n_labels`
# transformed latents, one plot per selected test trial
dataset = data_generator.datasets[sess_idx]
for trial_idx in trial_idxs:
    trial = dataset.batch_idxs['test'][trial_idx]
    batch = dataset[trial]
    labels_pred = model.get_transformed_latents(batch['images'])[:, :n_labels]
    labels_og = batch['labels'].detach().cpu().numpy()
    plot = plot_neural_reconstruction_traces(labels_og, labels_pred, scale=2)

# Evaluate orthogonality of projection matrix

[Back to contents](#Contents)

In [None]:
# Pairwise dot products between the rows of the projection matrix U. An image that is
# (near-)diagonal means the rows are (near-)orthogonal.
# assumes U is a 2D weight matrix -- TODO confirm
U = model.U.weight.data.cpu().detach().numpy()
overlap = U @ U.T
m = np.abs(overlap).max()  # symmetric color limits so zero sits at the colormap center

plt.figure(figsize=(6, 6))
plt.imshow(overlap, cmap='RdBu', vmin=-m, vmax=m)
plt.colorbar()
plt.show()

# Explore label/latent space

[Back to contents](#Contents)

In [None]:
import torch

from behavenet.data.utils import get_data_generator_inputs

from behavenet.fitting.utils import get_best_model_and_data
from behavenet.fitting.eval import get_reconstruction

from behavenet.plotting.cond_ae_utils import get_crop
from behavenet.plotting.cond_ae_utils import get_input_range
from behavenet.plotting.cond_ae_utils import get_labels_2d_for_trial
from behavenet.plotting.cond_ae_utils import get_model_input
from behavenet.plotting.cond_ae_utils import interpolate_2d
from behavenet.plotting.cond_ae_utils import plot_2d_frame_array

### Setup: define model

In [None]:
from behavenet.models import AEMSP as Model

# --- interpolation / plotting settings ---
n_ae_latents = 2  # latents excluding the label-related ones
label_min_p, label_max_p = 15, 85  # min/max percentiles for latent/label space interpolation
n_frames = 3  # frames plotted along each manipulated dimension
trial_idx = 0  # trial that supplies the base frame
batch_idx = 0  # frame within that trial's batch used as the base frame
label_idxs = [5, 1]  # labels to manipulate: y label first, then x
latent_idxs = np.array([0, 1])  # latents to manipulate

show_markers = True

# --- model selection ---
version = 0  # test-tube version; 'best' selects the version with the lowest mse
sess_idx = 0  # session to use when the model was fit on multiple sessions
hparams = dict(
    data_dir=get_user_dir('data'),
    save_dir=get_user_dir('save'),
    experiment_name=tt_expt_name,
    model_class='cond-ae-msp',
    model_type='conv',
    n_ae_latents=n_ae_latents + n_labels,
    rng_seed_data=0,
    trial_splits='8;1;1;0',
    train_frac=1.0,
    rng_seed_model=0,
    conditional_encoder=False,
)

# complete hparams, then resolve the session and experiment directories
get_lab_example(hparams, lab, expt)
hparams['session_dir'], sess_ids = get_session_dir(hparams)
hparams['expt_dir'] = get_expt_dir(hparams)

# load the model and its data generator
model_ae, data_generator = get_best_model_and_data(hparams, Model, version=version)

# ranges each input type is interpolated over
latent_range = get_input_range(
    'latents', hparams, model=model_ae, data_gen=data_generator)
label_kwargs = dict(
    sess_ids=sess_ids, sess_idx=sess_idx, min_p=label_min_p, max_p=label_max_p)
label_range = get_input_range('labels', hparams, **label_kwargs)
label_sc_range = get_input_range('labels_sc', hparams, **label_kwargs)

## Explore 2D label space

[Back to contents](#Contents)

In [None]:
ims_pt, ims_np, latents_np, labels_pt, labels_np, labels_2d_pt, labels_2d_np = \
    get_model_input(
        data_generator, hparams, model_ae, trial_idx=trial_idx, compute_latents=True,
        compute_scaled_labels=False, compute_2d_labels=True)

# index that selects the base frame while keeping a leading singleton batch axis;
# identical to writing `x[None, batch_idx, :]`
base_sel = (None, batch_idx, slice(None))

# interpolation bounds for the two manipulated labels (y first, then x)
label_mins = [label_range['min'][i] for i in label_idxs[:2]]
label_maxes = [label_range['max'][i] for i in label_idxs[:2]]
label_sc_mins = [label_sc_range['min'][i] for i in label_idxs[:2]]
label_sc_maxes = [label_sc_range['max'][i] for i in label_idxs[:2]]

ims_label, markers_loc_label, ims_crop_label = interpolate_2d(
    'labels', model_ae, ims_pt[base_sel], latents_np[base_sel],
    labels_np[base_sel], labels_2d_np[base_sel],
    mins=label_mins, maxes=label_maxes,
    n_frames=n_frames, input_idxs=label_idxs, crop_type=None,
    mins_sc=label_sc_mins, maxes_sc=label_sc_maxes,
    crop_kwargs=None, ch=0)

In [None]:
# style of the markers drawn on each interpolated frame
marker_kwargs = {
    'markersize': 20, 'markeredgewidth': 3, 'markeredgecolor': [1, 1, 0],
    'fillstyle': 'none'}

# `crop_type` used to be referenced below without ever being assigned (the
# interpolation call passes `crop_type=None` inline), which raised NameError whenever
# save_outputs was True. Keep this value in sync with that interpolate_2d call.
crop_type = None

if save_outputs:
    save_file = os.path.join(
        get_user_dir('fig'),
        'ae', 'D=%02i_label-manipulation_%s_%s-crop.png' %
        (hparams['n_ae_latents'], hparams['session'], crop_type))
else:
    save_file = None

# pass the computed path through; it was previously discarded (save_file=None was hardcoded)
plot_2d_frame_array(
    ims_label, markers=markers_loc_label, marker_kwargs=marker_kwargs, save_file=save_file,
    figsize=(15, 15))

## Explore 2D latent space

[Back to contents](#Contents)

In [None]:
ims_pt, ims_np, latents_np, labels_pt, labels_np, labels_2d_pt, labels_2d_np = \
    get_model_input(data_generator, hparams, model_ae, trial=None, trial_idx=trial_idx,
    compute_latents=True, compute_scaled_labels=False, compute_2d_labels=True)

# The first `n_labels` dims are used to reconstruct labels, so offset the requested
# latent indices past them. Build a new array instead of updating `latent_idxs` in
# place (`+=`), which shifted the indices again every time this cell was re-run.
latent_idxs_shifted = latent_idxs + n_labels

ims_latent, markers_loc_latent_, ims_crop_latent = interpolate_2d(
    'latents', model_ae, ims_pt[None, batch_idx, :], latents_np[None, batch_idx, :],
    labels_np[None, batch_idx, :], labels_2d_np[None, batch_idx, :],
    mins=[latent_range['min'][latent_idxs_shifted[0]],
          latent_range['min'][latent_idxs_shifted[1]]],
    maxes=[latent_range['max'][latent_idxs_shifted[0]],
           latent_range['max'][latent_idxs_shifted[1]]],
    n_frames=n_frames, input_idxs=latent_idxs_shifted, crop_type=None,
    mins_sc=None, maxes_sc=None, crop_kwargs=None, marker_idxs=label_idxs, ch=0)

In [None]:
# style of the markers drawn on each interpolated frame
marker_kwargs = {
    'markersize': 20, 'markeredgewidth': 5, 'markeredgecolor': [1, 1, 0],
    'fillstyle': 'none'}

# `crop_type` used to be referenced below without ever being assigned (the
# interpolation call passes `crop_type=None` inline), which raised NameError whenever
# save_outputs was True. Keep this value in sync with that interpolate_2d call.
crop_type = None

if save_outputs:
    save_file = os.path.join(
        get_user_dir('fig'),
        'ae', 'D=%02i_latent-manipulation_%s_%s-crop.png' %
        (hparams['n_ae_latents'], hparams['session'], crop_type))
else:
    save_file = None

# pass the computed path through; it was previously discarded (save_file=None was hardcoded)
plot_2d_frame_array(
    ims_latent, markers=markers_loc_latent_, marker_kwargs=marker_kwargs,
    save_file=save_file, figsize=(15, 15))