## Evaluate heuristic labels using hand labels

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from numpy import genfromtxt
import os
import pandas as pd
import seaborn as sns

from daart.eval import get_precision_recall
from daart_utils.paths import data_path

### Define data paths, load labels, and compute per-session metrics

In [None]:
dataset = 'ibl'

# Each dataset module provides the session ids to evaluate and the ordered label
# names; label index 0 is treated as the background class below (background=0).
if dataset == 'fly':
    from daart_utils.session_ids.fly import SESS_IDS_ALL as sess_ids
    from daart_utils.session_ids.fly import label_names
elif dataset == 'ibl':
    from daart_utils.session_ids.ibl import SESS_IDS_ALL as sess_ids
    from daart_utils.session_ids.ibl import label_names
else:
    raise ValueError(f"unsupported dataset '{dataset}'; expected 'fly' or 'ibl'")


def _load_states(labels_file):
    """Load a label csv and return the column index of the max value in each row.

    The first row and first column of the file are dropped (header row and
    leading index column). The remaining columns are presumably one indicator
    column per label in ``label_names`` order -- TODO confirm against the
    label-file writer.
    """
    # builtin `int`: the `np.int` alias was removed in NumPy 1.24
    labels = genfromtxt(labels_file, delimiter=',', dtype=int, encoding=None)
    labels = labels[1:, 1:]  # get rid of header row and index column
    return np.argmax(labels, axis=1)


metrics_rows = []
for sess_id in sess_ids:

    # hand labels
    states_hand = _load_states(os.path.join(
        data_path, dataset, 'labels-hand', sess_id + '_labels.csv'))

    # heuristic labels
    states_heuristic = _load_states(os.path.join(
        data_path, dataset, 'labels-heuristic', sess_id + '_labels.csv'))

    # compute precision and recall for each non-background behavior type
    scores = get_precision_recall(
        states_hand, states_heuristic, background=0, n_classes=len(label_names) - 1)

    # store one row per (session, behavior); score arrays are indexed from the
    # first non-background label, hence label_names[1:]
    for idx, label_name in enumerate(label_names[1:]):
        metrics_rows.append({
            'sess_id': sess_id,
            'label': label_name,
            'f1': scores['f1'][idx],
            'precision': scores['precision'][idx],
            'recall': scores['recall'][idx],
        })

# single construction gives a clean RangeIndex and also works for zero sessions
metrics_df = pd.DataFrame(
    metrics_rows, columns=['sess_id', 'label', 'f1', 'precision', 'recall'])

### Plot precision/recall

In [None]:
# Scatter precision vs. recall of the heuristic labels (scored against hand
# labels), one panel per behavior label and one color per session.
sns.set_context('talk')
sns.set_style('whitegrid')

g = sns.relplot(
    x='precision', y='recall', col='label', col_wrap=2, hue='sess_id', data=metrics_df)
g.fig.subplots_adjust(top=0.9)  # leave room above the panels for the suptitle
g.fig.suptitle('Heuristic label evaluation')
plt.show()