# Summary
Load a single fly dataset, preprocess the markers, and fit an ARHMM using several initialization methods, each with multiple random restarts.

In [None]:
import os
import numpy as np

import matplotlib.pyplot as plt
import ssm
import seaborn as sns
import pandas as pd

from behavior.data import preprocess_and_split_data, shuffle_data

In [None]:
from ssm.plots import gradient_cmap, white_to_color_cmap
# Global seaborn appearance for the figures produced below
sns.set_context('talk')
sns.set_style('white')

# xkcd color names, in the order they are indexed when plotting
color_names = ["windows blue", "red", "amber", "faded green", "dusty purple", "orange"]

# Discrete palette built from the names, plus a gradient colormap derived from it
colors = sns.xkcd_palette(color_names)
cmap = gradient_cmap(colors)

# 1. Load and preprocess data

In [None]:
# options
expt_ids = ['2019_08_08_fly1']
# expt_ids = ['2019_08_08_fly1_1']

# preprocessing directives
preprocess_list = {
    # 'filter': {'type': 'median', 'window_size': 3},
    'filter': {'type': 'savgol', 'window_size': 5, 'order': 2},
    # 'standardize': {}, # zscore labels
    'unitize': {}, # scale labels in [0, 1]
}

marker_obj = preprocess_and_split_data(
    expt_ids, preprocess_list, algo='dgp', load_from='h5')

# Pull the train/val/test splits from the object created above (`marker_obj`);
# the third return value of shuffle_data is not used in this notebook.
datas_tr, tags_tr, _ = shuffle_data(marker_obj, dtype='train')
datas_val, tags_val, _ = shuffle_data(marker_obj, dtype='val')
datas_test, tags_test, _ = shuffle_data(marker_obj, dtype='test')

# Size of the second axis of the first training array of the first experiment
D = marker_obj[0].markers_dict['train'][0].shape[1]

# 2. Fit ARHMM with EM

In [None]:
from behavior.ssmutils import get_expt_dir, get_model_name, fit_with_random_restarts
from behavior.paths import RESULTS_PATH

# model / fitting options
K = 8
lags = 1
obs = 'ar'
num_restarts = 3
num_iters = 100
method = 'em'  # 'em' | 'stochastic_em_adam' | 'stochastic_em_sgd' (non-conjugate)

# results of fit_with_random_restarts, keyed by initialization type
em_models = {}
em_lps = {}
em_model_all = {}
em_lps_all = {}

init_types = ['random', 'kmeans', 'kmeans-diff']

# The arguments to get_expt_dir do not change across init types, so resolve it once
expt_dir = get_expt_dir(RESULTS_PATH, expt_ids)

for it in init_types:
    save_path = os.path.join(
        RESULTS_PATH, expt_dir, 'single-session_%s-init_%s' % (it, method))
    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(save_path, exist_ok=True)
    em_models[it], em_lps[it], em_model_all[it], em_lps_all[it] = fit_with_random_restarts(
        K, D, obs, lags, datas_tr, num_restarts=num_restarts, num_iters=num_iters,
        method=method, save_path=save_path, init_type=it)

### Plot training curves

In [None]:
plt.figure()

# True: draw one curve per restart; False: draw only the single curve kept in em_lps
plot_all_restarts = True

for i, init_type in enumerate(init_types):
    if plot_all_restarts:
        for j, restart in enumerate(em_lps_all[init_type]):
            # label only the first restart so each init type appears once in the legend
            label = init_type if j == 0 else None
            # color is given by keyword only; a 'k' format string would conflict with it
            plt.plot(restart, label=label, color=colors[i])
    else:
        # index by the current loop variable, not a leftover name from the fitting loop
        plt.plot(em_lps[init_type], label=init_type, color=colors[i])

plt.xlabel('Epoch')
plt.ylabel('Log probability')
plt.legend(bbox_to_anchor=(1.1, 1.05))
# hand-picked y-limits per experiment id
if expt_ids[0] == '2019_08_08_fly1':
    # plt.ylim([3540000, 3560000])
    plt.ylim([3450000, 3470000])
elif expt_ids[0] == '2019_08_08_fly1_1':
    plt.ylim([4580000, 4630000])
# plt.yscale('log')
plt.show()

# 3. Create a syllable movie

In [None]:
from behavior.data import load_video
import behavior.plotting as plotting

# data split whose markers are decoded into states
dtype = 'train'

# pick one of the fitted models and decode each marker array of the chosen split
arhmm = em_models['kmeans-diff']
states = list(map(arhmm.most_likely_states, marker_obj[0].markers_dict[dtype]))

# video for the first experiment id
video = load_video(expt_ids[0])

In [None]:
# Output path for the syllable movie, named after the first experiment id
movie_name = '%s_syllable-video.mp4' % expt_ids[0]
save_file = os.path.join(RESULTS_PATH, 'figs', movie_name)

# Keyword options forwarded unchanged to make_syllable_movie
movie_kwargs = dict(
    single_state=None, min_threshold=10, n_pre_frames=0, n_buffer=10, plot_n_frames=500)
plotting.make_syllable_movie(
    save_file, states, video, marker_obj[0].idxs_dict[dtype], **movie_kwargs)

# 4. Create a labeled movie clip

In [None]:
# get markers/states for all time points; must match up with time dim of video data
markers_all = marker_obj[0].get_marker_array()
states_all = arhmm.most_likely_states(markers_all)

# name each state; default to non-descriptive labels for now
state_mapping = {i: 'state %i' % i for i in range(K)}

save_file = os.path.join(RESULTS_PATH, 'figs', '%s_labeled-video.mp4' % expt_ids[0])
idxs = np.arange(0, 500)
plotting.make_labeled_movie_wmarkers(
    save_file, states_all, video, markers_all, idxs, state_mapping)