# Compare results across hyperparameters
Collect performance metrics (e.g. F1 scores) for a given model architecture while varying the objective-function hyperparameters that weight the weak and self-supervised terms, then plot the results in several different ways as heatmaps.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from numpy import genfromtxt
import os
import pandas as pd
import pickle
import seaborn as sns

from daart.data import DataGenerator
from daart.eval import get_precision_recall, run_lengths
from daart.io import get_expt_dir, find_experiment
from daart.models import Segmenter
from daart.transforms import ZScore
from daart.utils import compute_batch_pad

from daart_utils.paths import data_path, results_path
from daart_utils.plotting import plot_heatmaps

In [None]:
# -----------------------------------------
# experiment configuration
# -----------------------------------------
# This cell defines the module-level settings read by the cells below: which
# dataset/sessions to evaluate, which hyperparameter grid to sweep, and the
# model hyperparameters used to look up previously fit experiments.

# save predicted states from models
save_states = True
# overwrite predicted states from models
overwrite_states = False
# compute state statistics like median bout duration and behavior ratios
compute_state_stats = False

dataset = 'ibl'

# session ids and label names are defined per dataset in `daart_utils`
if dataset == 'fly':
    from daart_utils.session_ids.fly import SESS_IDS_TRAIN_5, SESS_IDS_TEST
    from daart_utils.session_ids.fly import label_names
elif dataset == 'ibl':
    from daart_utils.session_ids.ibl import SESS_IDS_TRAIN_5, SESS_IDS_TEST
    from daart_utils.session_ids.ibl import label_names
else:
    # fail here with a clear message rather than with a NameError on
    # SESS_IDS_TRAIN_5 a few lines further down
    raise ValueError(
        "Unsupported dataset '%s'; expected 'fly' or 'ibl'" % dataset)

# train datasets
sess_ids = SESS_IDS_TRAIN_5[0]
# test datasets
sess_ids_test = SESS_IDS_TEST

expt_dir = get_expt_dir(os.path.join(results_path, dataset), sess_ids)
print(expt_dir)

# ATTN: need to set these params to match models that were fit
lambda_weak = [1, 5]
lambda_strong = 1
lambda_pred = [1, 5]

model_type = 'dtcn'
device = 'cuda'
batch_size = 2000
tt_expt_dir = 'test'

# hyperparameters passed to `find_experiment` to locate a fitted model;
# `lambda_weak` and `lambda_pred` are filled in per grid point by the
# evaluation loop
hparams = {
    'rng_seed_train': 0,
    'rng_seed_model': 0,
    'trial_splits': '9;1;0;0',
    'train_frac': 1,
    'batch_size': batch_size,
    'model_type': model_type,
    'learning_rate': 1e-4,
    'n_hid_layers': 2,
    'n_hid_units': 32,
    'n_lags': 4,
    'l2_reg': 0,
    'lambda_strong': lambda_strong,
    'bidirectional': True,
    'device': device,
    'dropout': 0.1,
    'activation': 'lrelu'
}

# derived entries: batch padding computed from the hparams above, and the
# directory holding the fitted experiment versions
hparams['batch_pad'] = compute_batch_pad(hparams)
hparams['tt_expt_dir'] = os.path.join(expt_dir, model_type, tt_expt_dir)
model_types = [model_type]  # legacy, need to clean up

# -----------------------------------------
# evaluate every fitted model on every test session
# -----------------------------------------
# For each test session: build a data generator over its markers, load the
# hand and heuristic labels, then for every (lambda_weak, lambda_pred) pair
# locate the fitted model, obtain its predicted states (cached on disk when
# available) and score them against the hand labels. One single-row DataFrame
# is appended per (session, label, lambda_weak, lambda_pred); the rows are
# concatenated into `metrics_df` at the end.
metrics_df = []

for sess_id_test in sess_ids_test:

    # markers
    if dataset == 'ibl':
        markers_file = os.path.join(
            data_path, dataset, 'markers', sess_id_test + '_labeled.npy')
    else:
        markers_file = os.path.join(
            data_path, dataset, 'markers', sess_id_test + '_labeled.h5')

    # define data generator signals
    signals = ['markers']
    transforms = [ZScore()]
    paths = [markers_file]

    # build data generator
    data_gen_test = DataGenerator(
        [sess_id_test], [signals], [transforms], [paths], device=device,
        batch_size=batch_size, trial_splits=hparams['trial_splits'],
        batch_pad=hparams['batch_pad'])
    print('----------------------------')
    print(data_gen_test)
    print('----------------------------')
    print('\n')

    # load hand labels
    # NOTE: builtin `int` replaces `np.int`, which was only an alias for it
    # and has been removed from recent NumPy releases
    hand_labels_file = os.path.join(
        data_path, dataset, 'labels-hand', sess_id_test + '_labels.csv')
    labels = genfromtxt(hand_labels_file, delimiter=',', dtype=int, encoding=None)
    labels = labels[1:, 1:]  # get rid of headers, etc.
    states = np.argmax(labels, axis=1)
    # keep only a whole number of batches
    cutoff = int(np.floor(states.shape[0] / batch_size)) * batch_size
    states = states[:cutoff]

    # load heuristic labels
    heur_labels_file = os.path.join(
        data_path, dataset, 'labels-heuristic', sess_id_test + '_labels.csv')
    labels_h = genfromtxt(heur_labels_file, delimiter=',', dtype=int, encoding=None)
    labels_h = labels_h[1:, 1:]  # get rid of headers, etc.
    states_h = np.argmax(labels_h, axis=1)
    states_heuristic = states_h[:cutoff]

    # number of classes passed to the scorer (label_names minus one entry)
    n_classes = len(label_names) - 1

    # compute precision and recall for each behavior type (heuristic);
    # independent of the hyperparameters, so computed once per session
    scores_heuristic = get_precision_recall(
        states, states_heuristic, background=0, n_classes=n_classes)

    # -----------------------------------------
    # collect results
    # -----------------------------------------
    for lw in lambda_weak:
        for lp in lambda_pred:

            hparams['lambda_weak'] = lw
            hparams['lambda_pred'] = lp

            # f1, precision, accuracy
            scores = {m: None for m in model_types}
            # ratio of time spent in each state
            ratios = {m: None for m in model_types}
            # median bout length for each state
            bout_lens = {m: None for m in model_types}

            # load states/model; skip grid points with no fitted model
            version_int = find_experiment(hparams)
            if version_int is None:
                continue
            version_str = 'version_%i' % version_int
            version_dir = os.path.join(hparams['tt_expt_dir'], version_str)

            # check to see if states exist
            states_file = os.path.join(version_dir, '%s_states.npy' % sess_id_test)
            if os.path.exists(states_file) and not overwrite_states:
                predictions = np.load(states_file)
            else:
                # load model
                model_file = os.path.join(version_dir, 'best_val_model.pt')
                arch_file = os.path.join(version_dir, 'hparams.pkl')
                # NOTE(review): unpickling executes arbitrary code; only point
                # this at experiment directories you trust
                with open(arch_file, 'rb') as f:
                    hparams_new = pickle.load(f)
                hparams_new['device'] = hparams.get('device', 'cpu')
                model = Segmenter(hparams_new)
                model.load_parameters_from_file(model_file)
                model.to(hparams_new['device'])
                model.eval()

                # compute predictions
                predictions = model.predict_labels(data_gen_test)['labels']
                predictions = np.argmax(np.vstack(predictions[0]), axis=1)

                # save predictions
                if save_states:
                    print('saving states to %s' % states_file)
                    np.save(states_file, predictions)

            # compute precision and recall for each behavior type (model)
            scores[model_type] = get_precision_recall(
                states, predictions, background=0, n_classes=n_classes)

            if compute_state_stats:
                # minlength guarantees one entry per label even when the
                # highest-indexed classes are never predicted
                ratios[model_type] = np.bincount(
                    predictions, minlength=len(label_names)) / len(predictions)
                bouts = run_lengths(predictions)
                # assumes `run_lengths` returns one entry per class, in class
                # order and including the background class - TODO confirm
                bout_lens[model_type] = [np.median(a) for a in bouts.values()]
            else:
                ratios[model_type] = [None] * (len(label_names) + 1)
                bout_lens[model_type] = [None] * (len(label_names) + 1)

            scores['heuristic'] = scores_heuristic

            # store; the first entry of label_names is skipped, hence the
            # `idx + 1` offset into the per-state statistics
            for idx, label_name in enumerate(label_names[1:]):
                df_dict = {
                    'sess_id': sess_id_test,
                    'label': label_name,
                    'lambda_weak': lw,
                    'lambda_pred': lp,
                    'precision_heur': scores['heuristic']['precision'][idx],
                    'recall_heur': scores['heuristic']['recall'][idx],
                    'f1_heur': scores['heuristic']['f1'][idx],
                }
                # dedicated loop variable so the global `model_type` is not rebound
                for m_type in model_types:
                    model_name = 'mlp' if m_type == 'temporal-mlp' else m_type
                    df_dict['precision_%s' % model_name] = scores[m_type]['precision'][idx]
                    df_dict['recall_%s' % model_name] = scores[m_type]['recall'][idx]
                    df_dict['f1_%s' % model_name] = scores[m_type]['f1'][idx]
                    df_dict['ratio_%s' % model_name] = ratios[m_type][idx + 1]
                    df_dict['bout_len_%s' % model_name] = bout_lens[m_type][idx + 1]
                metrics_df.append(pd.DataFrame(df_dict, index=[0]))

# `pd.concat` raises an uninformative ValueError on an empty list; raise the
# same exception type with a message that says what actually went wrong
if not metrics_df:
    raise ValueError(
        'No fitted models were found under %s for the requested '
        'lambda_weak/lambda_pred values' % hparams['tt_expt_dir'])
metrics_df = pd.concat(metrics_df)

### Plot results: heatmaps for a single model

In [None]:
# seaborn appearance applied to the figures drawn below
sns.set_style('white')
sns.set_context('talk')

# column of `metrics_df` to display
metric = 'f1_dtcn'

n_rows = 2
n_cols = 2

# available plot types:
# 1: by dataset/label
# 2: by label, avg over datasets
# 3: avg over dataset/label
plot_types = [2, 3]

# figure title: the part of the metric name after the last underscore, upper-cased
heatmap_title = metric.split('_')[-1].upper()
# colormap names handed to `plot_heatmaps`
heatmap_cmaps = ['Oranges_r', 'Greens_r', 'Reds_r', 'Purples_r']

for plot_type in plot_types:
    # annotations are enabled for every plot type
    annot = True
    plot_heatmaps(
        df=metrics_df,
        metric=metric,
        sess_ids=sess_ids_test,
        title=heatmap_title,
        kind=plot_type,
        vmin=0.7,
        vmax=1,
        annot=annot,
        # fresh list per call, as in a literal written at the call site
        cmaps=list(heatmap_cmaps),
        save_file=None,
    )
