In [None]:
import matplotlib.pyplot as plt
import numpy as np
from numpy import genfromtxt
import os
import pandas as pd
import pickle
import seaborn as sns
import torch

from daart.data import DataGenerator
from daart.eval import get_precision_recall, run_lengths
from daart.io import get_expt_dir, find_experiment
from daart.models import Segmenter
from daart.transforms import ZScore
from daart_scratch.plotting import plot_heatmaps

from daart_scratch.session_ids_ibl import EXPT_IDS_TRAIN_5, EXPT_IDS_TEST

# Compare segmentation models (across model types and loss weightings) against hand labels and the heuristic baseline

In [None]:
# ---------------------------------------------------------------------------
# run options
# ---------------------------------------------------------------------------
save_states = True           # write each model's predicted states to '<session>_states.npy'
overwrite_states = False     # recompute predictions even when a saved states file exists
compute_state_stats = False  # also compute behavior ratios and median bout durations

# ---------------------------------------------------------------------------
# locations and sessions
# ---------------------------------------------------------------------------
base_dir = '/media/mattw/ibl/segmentation'
tt_expt_dir = 'batch=2000_lr=1e-4'  # alternative: 'grid-search'

# sessions the models were trained on
# (alternatives: EXPT_IDS_TRAIN_10[0], EXPT_IDS_TRAIN_3[0], or an explicit list of ids)
expt_ids = EXPT_IDS_TRAIN_5[0]

# held-out sessions the models are evaluated on
# (alternatives: EXPT_IDS_TRAIN_5[0], or an explicit list such as
#  ['cortexlab_KS020_2020-02-06-001', 'danlab_DY_009_2020-02-27-001',
#   'mrsicflogellab_SWC_058_2020-12-11-001'])
expt_ids_test = EXPT_IDS_TEST

classifier_root = os.path.join(base_dir, 'classifiers')
expt_dir = get_expt_dir(classifier_root, expt_ids)
print(expt_dir)

# ---------------------------------------------------------------------------
# models and loss weightings
# ---------------------------------------------------------------------------
# other model types used in the past: 'temporal-mlp', 'gru', 'lstm'
model_types = ['dtcn']
# the evaluation loop looks up one trained model per (lambda_weak, lambda_pred) pair
lambda_weak = [0, 0.1, 0.5, 1, 5]
lambda_strong = 1
lambda_pred = [0, 0.1, 0.5, 1, 5]

label_names = ['still', 'move', 'wheel-turn', 'groom']
device = 'cuda'
batch_size = 2000
trial_splits = '9;1;0;0'

# base hparams passed to find_experiment; the evaluation loop fills in
# lambda_weak, lambda_pred, activation, model_type and tt_expt_dir per model
hparams = dict(
    rng_seed_train=0,
    rng_seed_model=0,
    trial_splits=trial_splits,
    train_frac=1,
    batch_size=batch_size,
    learning_rate=1e-4,
    n_hid_layers=2,
    n_hid_units=32,
    n_lags=4,
    l2_reg=0,
    lambda_strong=lambda_strong,
    bidirectional=True,
    device=device,
    dropout=0.1,
    batch_pad=24,
)
    
# Evaluate the trained model(s) for every (lambda_weak, lambda_pred) pair on every held-out
# session, scoring predictions (and the heuristic labels) against the hand labels.
# Result: `metrics_df`, one row per (session, label, lambda_weak, lambda_pred).
metrics_df = []

for expt_id_test in expt_ids_test:

    # DLC markers
    markers_file = os.path.join(base_dir, 'markers', expt_id_test + '_labeled.npy')
    # hand labels
    hand_labels_file = os.path.join(
        base_dir, 'labels_deepethogram', 'DATA', expt_id_test, expt_id_test + '_labels.csv')
    # heuristic labels
    labels_file = os.path.join(
        base_dir, 'labels-heuristic', expt_id_test + '_labels.pkl')

    # define data generator signals
    signals = ['markers']
    transforms = [ZScore()]
    paths = [markers_file]

    # build data generator
    data_gen_test = DataGenerator(
        [expt_id_test], [signals], [transforms], [paths], device=device, batch_size=batch_size,
        trial_splits=trial_splits, batch_pad=hparams['batch_pad'])
    print('----------------------------')
    print(data_gen_test)
    print('----------------------------')
    print('\n')

    # load hand labels: drop the first row and first column, then take the argmax of each row
    # as the frame's state. Builtin `int` is used because `np.int` was removed in NumPy 1.24.
    labels = genfromtxt(hand_labels_file, delimiter=',', dtype=int, encoding=None)
    labels = labels[1:, 1:]  # get rid of headers, etc.
    states = np.argmax(labels, axis=1)
    # keep only a whole number of batches
    cutoff = (states.shape[0] // batch_size) * batch_size
    states = states[:cutoff]

    # load heuristic labels (pickle executes code on load -- only use trusted files)
    with open(labels_file, 'rb') as f:
        states_heuristic = pickle.load(f)['states']
    states_heuristic = states_heuristic[:cutoff]

    # heuristic scores do not depend on the lambda values: compute once per session
    scores_heuristic = get_precision_recall(
        states, states_heuristic, background=0, n_classes=len(label_names))

    # -----------------------------------------
    # collect results
    # -----------------------------------------
    for lw in lambda_weak:
        for lp in lambda_pred:

            hparams['lambda_weak'] = lw
            hparams['lambda_pred'] = lp

            # per-label precision, recall, f1; stays None when no trained model is found
            scores = {m: None for m in model_types}
            # ratio of time spent in each state
            ratios = {m: None for m in model_types}
            # median bout length for each state
            bout_lens = {m: None for m in model_types}

            for model_type in model_types:

                activation = 'lrelu' if model_type == 'dtcn' else 'tanh'

                hparams['activation'] = activation
                hparams['model_type'] = model_type
                hparams['tt_expt_dir'] = os.path.join(expt_dir, model_type, tt_expt_dir)

                # locate the trained model whose saved hparams match the current ones
                version_int = find_experiment(hparams)
                if version_int is None:
                    continue
                version_str = str('version_%i' % version_int)
                version_dir = os.path.join(hparams['tt_expt_dir'], version_str)

                # reuse previously saved predictions unless asked to recompute them
                states_file = os.path.join(version_dir, '%s_states.npy' % expt_id_test)
                if os.path.exists(states_file) and not overwrite_states:
                    predictions = np.load(states_file)
                else:
                    # load model
                    model_file = os.path.join(version_dir, 'best_val_model.pt')
                    arch_file = os.path.join(version_dir, 'hparams.pkl')
                    with open(arch_file, 'rb') as f:
                        hparams_new = pickle.load(f)
                    hparams_new['device'] = hparams.get('device', 'cpu')
                    model = Segmenter(hparams_new)
                    model.load_state_dict(torch.load(
                        model_file, map_location=lambda storage, loc: storage))
                    model.to(hparams_new['device'])
                    model.eval()

                    # compute predictions: argmax over the model's label outputs per frame
                    predictions = model.predict_labels(data_gen_test)['labels']
                    predictions = np.argmax(np.vstack(predictions[0]), axis=1)

                    # save predictions
                    if save_states:
                        print('saving states to %s' % states_file)
                        np.save(states_file, predictions)

                # compute precision and recall for each behavior type (model)
                scores[model_type] = get_precision_recall(
                    states, predictions, background=0, n_classes=len(label_names))

                if compute_state_stats:
                    # minlength guarantees one entry per state (background + labels) even
                    # when the model never predicts the highest-numbered states; without
                    # it the `[label_idx + 1]` lookup below can raise IndexError
                    ratios[model_type] = np.bincount(
                        predictions, minlength=len(label_names) + 1) / len(predictions)
                    bouts = run_lengths(predictions)
                    # NOTE(review): indexed below by state (label_idx + 1); assumes run_lengths
                    # returns one entry per state, in state order -- verify for unpredicted states
                    bout_lens[model_type] = [np.median(a) for _, a in bouts.items()]
                else:
                    ratios[model_type] = [None] * (len(label_names) + 1)
                    bout_lens[model_type] = [None] * (len(label_names) + 1)

            # skip this lambda pair unless every model type produced scores; checking only
            # the last model's `version_int` let a missing earlier model crash below
            if any(scores[m] is None for m in model_types):
                continue

            scores['heuristic'] = scores_heuristic

            # store one row per behavior label
            for label_idx, label_name in enumerate(label_names):
                df_dict = {
                    'expt_id': expt_id_test,
                    'label': label_name,
                    'lambda_weak': lw,
                    'lambda_pred': lp,
                    'precision_heur': scores_heuristic['precision'][label_idx],
                    'recall_heur': scores_heuristic['recall'][label_idx],
                    'f1_heur': scores_heuristic['f1'][label_idx],
                }
                for model_type in model_types:
                    model_name = 'mlp' if model_type == 'temporal-mlp' else model_type
                    model_scores = scores[model_type]
                    df_dict['precision_%s' % model_name] = model_scores['precision'][label_idx]
                    df_dict['recall_%s' % model_name] = model_scores['recall'][label_idx]
                    df_dict['f1_%s' % model_name] = model_scores['f1'][label_idx]
                    # state 0 is background, so label `label_idx` is state `label_idx + 1`
                    df_dict['ratio_%s' % model_name] = ratios[model_type][label_idx + 1]
                    df_dict['bout_len_%s' % model_name] = bout_lens[model_type][label_idx + 1]
                metrics_df.append(pd.DataFrame(df_dict, index=[0]))

metrics_df = pd.concat(metrics_df)

### Plot results: heatmaps for a single model

In [None]:
# Plot heatmaps of one metric column of `metrics_df` (e.g. 'f1_dtcn') across the
# lambda_weak / lambda_pred grid, optionally saving one pdf per plot type.
sns.set_context('talk')
sns.set_style('white')

# True: write each heatmap to a pdf; False: only display inline
save_figs = False
# column of metrics_df to plot, named '<score>_<model>' (e.g. 'f1_dtcn')
metric = 'f1_dtcn'

# subplot grid size (not used by the plot_heatmaps call below)
n_rows = 2
n_cols = 2

# heatmap kinds passed to plot_heatmaps:
# 1: by dataset/label
# 2: by label, avg over datasets
# 3: avg over dataset/label
plot_types = [2, 3]

fig_save_path = '/home/mattw/Dropbox/research-text/papers/2021-daart/figs'
annot = True

for plot_type in plot_types:
    if save_figs:
        # one file per plot type: previously every plot type received the same
        # '<metric>.pdf' path, so the later figure overwrote the earlier one
        fig_save_file = os.path.join(fig_save_path, '%s_kind=%i.pdf' % (metric, plot_type))
    else:
        fig_save_file = None
    plot_heatmaps(
        df=metrics_df, metric=metric, expt_ids=expt_ids_test,
        title=metric.split('_')[-1].upper(), kind=plot_type, vmin=0.7, vmax=1,
        annot=annot, cmaps=['Oranges_r', 'Greens_r', 'Reds_r', 'Purples_r'],
        save_file=fig_save_file)
