## Plot test-video metrics (PCA reprojection error / temporal norm / unimodal MSE) across models and loss weights

In [None]:
import hydra
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import torch

from lightning_pose.utils.io import return_absolute_data_paths
from lightning_pose.utils.scripts import get_imgaug_transform, get_dataset, get_data_module

import sys
sys.path.append('/home/mattw/Dropbox/github/paninski-lab/tracking-diagnostics')
from diagnostics.handler import ModelHandler

### define configuration

In [None]:
dataset_name = "ibl-fingers"
# dataset_name = "ibl-pupil-2"
# dataset_name = "ibl-paw-2"
base_config_dir = "/home/mattw/Dropbox/research-code/pose-estimation/configs"
base_save_dir = "/media/mattw/behavior/results/pose-estimation/"

# Hydra keeps process-global state: a second `initialize_config_dir` call in the same
# kernel raises "GlobalHydra is already initialized". Clearing first makes this cell
# safe to re-run (e.g. after switching `dataset_name`).
from hydra.core.global_hydra import GlobalHydra
GlobalHydra.instance().clear()
hydra.initialize_config_dir(base_config_dir)
cfg = hydra.compose(config_name="config_%s" % dataset_name)
cfg.training.imgaug = "default"

# load ground truth labels; the first csv column is the frame path, the remaining
# columns alternate x/y per keypoint (hence the reshape to (n_frames, n_keypoints, 2))
csv_file = os.path.join(cfg.data.data_dir, cfg.data.csv_file)
csv_data = pd.read_csv(csv_file, header=list(cfg.data.header_rows))
keypoints_gt = csv_data.iloc[:, 1:].to_numpy().reshape(csv_data.shape[0], -1, 2)
# with 3 header rows the bodypart name is the 2nd column level, otherwise the 1st;
# take every other column so each keypoint appears once (x column only)
if len(cfg.data.header_rows) == 3:
    keypoint_names = [c[1] for c in csv_data.columns[1::2]]
else:
    keypoint_names = [c[0] for c in csv_data.columns[1::2]]


In [None]:
# Assemble the lightning-pose data pipeline in dependency order:
# absolute paths -> augmentation pipeline -> labeled dataset -> data module.
data_dir, video_dir = return_absolute_data_paths(data_cfg=cfg.data)
imgaug_transform = get_imgaug_transform(cfg=cfg)
dataset = get_dataset(
    cfg=cfg,
    data_dir=data_dir,
    imgaug_transform=imgaug_transform,
)
data_module = get_data_module(
    cfg=cfg,
    dataset=dataset,
    video_dir=video_dir,
)
# materialize the train/val/test splits before anything (e.g. the PCA fit) reads them
data_module.setup()

In [None]:
import copy

save_dir = os.path.join(base_save_dir, dataset_name)

# define models
to_compute = 'pca_reproj'  # pca_reproj | unimodal_mse | temporal_norm
model_name = 'grid-search-0'
train_frame = 1  # 000
rng_seed = 0

# loss type -> log-weights to evaluate; the supervised baseline has no weight
loss_weight_dict = {
    'supervised': [None],
    # 'unimodal_mse': [10.0, 1.0],
    'temporal': [4.0, 3.0, 2.0, 1.0, 0.0, -1.0, -2.0, -3.0, -4.0],
    'pca_singleview': [4.0, 3.0, 2.0, 1.0, 0.0, -1.0, -2.0, -3.0, -4.0],
}

pca_obj = None
# NOTE(review): this is passed to `compute_metric(datamodule=...)` below and stays None;
# the data module built above is named `data_module`. Confirm None is intended.
datamodule = None

if to_compute == 'rmse':
    raise NotImplementedError
elif to_compute == 'pca_reproj':
    y_label = 'PCA reprojection error'
    from lightning_pose.utils.pca import KeypointPCA
    pca_obj = KeypointPCA(loss_type='pca_singleview', data_module=data_module)
    pca_obj()
elif to_compute == 'unimodal_mse':
    y_label = 'Unimodal MSE'
elif to_compute == 'temporal_norm':
    y_label = 'Temporal norm'
else:
    # previously fell through and left `y_label` undefined
    raise ValueError('unsupported metric to_compute=%r' % to_compute)

# dataset -> subdirectory (of data_dir) holding the test videos
test_video_subdirs = {
    'ibl-paw-2': 'videos_og_short',
    'ibl-pupil-2': 'videos_test',
    'ibl-fingers': 'videos',
}
if dataset_name not in test_video_subdirs:
    raise ValueError('no test video directory configured for dataset %r' % dataset_name)
test_videos_directory = os.path.join(data_dir, test_video_subdirs[dataset_name])

# prediction/heatmap filenames are derived by replacing '.mp4', so only mp4 files
# can ever be matched to model outputs
video_names = sorted(f for f in os.listdir(test_videos_directory) if f.endswith('.mp4'))

results_df = []
n_vids = 0  # number of videos for which at least one model produced a metric
for video_name in video_names:

    # per-bodypart list of metric arrays, kept parallel to `cols_collected`
    metrics_collected = {bp: [] for bp in keypoint_names}
    cols_collected = []
    n_frames = None  # set from this video's own results (never a previous video's)

    # loop over models and compute metric of interest
    for loss_type, loss_weights in loss_weight_dict.items():
        for loss_weight in loss_weights:

            # find model checkpoint; deepcopy so nested edits below never leak into
            # the shared `cfg` (DictConfig.copy() does not guarantee independent
            # nested nodes)
            model_cfg = copy.deepcopy(cfg)
            model_cfg.training.train_frames = train_frame
            model_cfg.training.rng_seed_data_pt = rng_seed
            model_cfg.training.rng_seed_data_dali = rng_seed
            model_cfg.training.rng_seed_model_pt = rng_seed
            model_cfg.model.model_name = model_name

            # put model-specific config info here
            if loss_type == 'supervised':
                model_cfg.model.losses_to_use = []
            else:
                model_cfg.model.losses_to_use = [loss_type]
                model_cfg.losses[loss_type].log_weight = loss_weight

            try:
                handler = ModelHandler(save_dir, model_cfg, verbose=False)
            except FileNotFoundError:
                print('did not find %s model for train_frames=%i' % (loss_type, train_frame))
                continue

            filename_pred = video_name.replace('.mp4', '_predictions.csv')
            saved_vid_preds_dir = os.path.join(handler.model_dir, 'video_predictions')
            pred_csv_file = os.path.join(saved_vid_preds_dir, filename_pred)
            print(pred_csv_file)
            filename_heat = video_name.replace('.mp4', '_heatmaps.h5')
            saved_heat_dir = os.path.join(handler.model_dir, 'video_heatmaps')
            heat_h5_file = os.path.join(saved_heat_dir, filename_heat)
            try:
                result = handler.compute_metric(
                    to_compute, pred_csv_file,
                    pca_obj=pca_obj, datamodule=datamodule, heatmap_file=heat_h5_file)
            except FileNotFoundError:
                print('could not find model predictions')
                continue

            # result is indexed as result[:, keypoint]; its first axis gives the frames
            n_frames = result.shape[0]

            if loss_type == 'supervised':
                # no column under the 'supervised' name: instead register the baseline
                # once per other loss type ('<loss>_s') so each weight sweep has a reference
                new_cols = ['%s_s' % lt for lt in loss_weight_dict if lt != 'supervised']
            else:
                new_cols = ['%s_%.1f' % (loss_type, loss_weight)]
            for col_name in new_cols:
                cols_collected.append(col_name)
                for b, bodypart in enumerate(keypoint_names):
                    metrics_collected[bodypart].append(result[:, b])

    if n_frames is None:
        # previously this emitted rows built from a stale/undefined `result`
        print('no model predictions found for %s; skipping video' % video_name)
        continue
    n_vids += 1

    # collect results: one long-format frame per bodypart
    video_label = video_name.split('_')[0].split('-')[0]
    for bodypart in keypoint_names:
        dict_tmp = {'bodypart': bodypart, 'video': video_label}
        dict_tmp.update(zip(cols_collected, metrics_collected[bodypart]))
        dict_tmp['time_idx'] = np.arange(n_frames)
        results_df.append(pd.DataFrame(dict_tmp))

if not results_df:
    raise RuntimeError('no model predictions found for any video in %s' % test_videos_directory)
results_df = pd.concat(results_df)

### scatterplots for a pair of models

In [None]:
sns.set_context('talk')
sns.set_style('whitegrid')

# pair of metric columns to compare (any name in `results_df`, e.g. 'temporal_s'
# vs 'temporal_0.0', or 'unimodal_mse_s' vs 'unimodal_mse_1.0')
x_col = 'temporal_s'
y_col = 'pca_singleview_-4.0'

fig = px.scatter(
    results_df,
    x=x_col, y=y_col,
    facet_col='bodypart',
    facet_row='video',
    hover_data=['time_idx'],
    log_x=True,
    log_y=True,
    opacity=0.5,
    # previously hard-coded to 'PCA reprojection error' regardless of `to_compute`
    title='%s (%s)' % (y_label, dataset_name),
    # optional: trendline="ols", marginal_x='histogram', marginal_y='histogram'
)
fig.update_traces(marker={'size': 5})

# optional identity line:
# trace = go.Scatter(x=[0.1, 10], y=[0.1, 10], line_color="black", mode="lines")
# trace.update(legendgroup="trendline", showlegend=False)
# fig.add_trace(trace, row="all", col="all", exclude_empty_subplots=True)

# keep ~300px per facet row so multi-video plots are not squashed (1 video -> 300px)
fig.update_layout(width=800, height=300 * max(1, results_df['video'].nunique()))
fig.show()

### barplots across loss weights (one loss type, supervised baseline as 's')

In [None]:
# reshape to long format: one row per (bodypart, frame, video, model column)
id_cols = ['bodypart', 'time_idx', 'video']
# use every metric column present in results_df; `cols_collected` only reflects the
# last processed video and can miss columns that exist for other videos
value_cols = [c for c in results_df.columns if c not in id_cols]
df_tmp = pd.melt(
    results_df,
    id_vars=id_cols,
    value_vars=value_cols,
)


def add_loss_name_col(row):
    """Return the loss name from a melted row's 'variable' ('pca_singleview_-4.0' -> 'pca_singleview')."""
    return row['variable'].rpartition('_')[0]


def add_loss_val_col(row):
    """Return the loss-weight suffix from a melted row's 'variable' ('pca_singleview_-4.0' -> '-4.0')."""
    return row['variable'].rpartition('_')[2]


# split '<loss>_<weight>' on the last underscore in one vectorized pass (same result as
# the row-wise helpers above, but without a Python call per row)
name_parts = df_tmp['variable'].str.rpartition('_')
df_tmp['loss'] = name_parts[0]
df_tmp['loss_weight'] = name_parts[2]
df_tmp = df_tmp.drop('variable', axis=1)

In [None]:
sns.set(context='talk', style='whitegrid', font_scale=1)

# loss family whose weight sweep is plotted (its '<loss>_s' rows are the supervised baseline)
loss_to_plot = 'pca_singleview'

# average the metric over frames for each (bodypart, video, loss, weight)
df_tmp_ = (
    df_tmp.groupby(['bodypart', 'video', 'loss', 'loss_weight'])['value']
    .mean()
    .reset_index()
)

# x-order labels must match the '%.1f' formatting used when the columns were named
weight_order = ['s'] + ['%.1f' % w for w in loss_weight_dict[loss_to_plot]]

g = sns.catplot(
    x='loss_weight', y='value',
    log=True,
    order=weight_order,
    kind='bar',
    col='bodypart',
    sharey=False,
    data=df_tmp_[df_tmp_.loss == loss_to_plot],
)
g.set_axis_labels('Loss weight', y_label)
g.set_xticklabels(rotation=45, ha='center')

for ax in g.axes.flatten():
    ax.tick_params(axis='y', which='both', direction='out', length=4, left=True)
    # visibility passed positionally: the `b=` keyword is deprecated since Matplotlib 3.5
    ax.grid(True, which='both', color='gray', linewidth=0.1)

# per-(dataset, metric) y-limits; other combinations keep the automatic limits
ylims = {
    ('ibl-paw-2', 'temporal_norm'): [0.1, 1],
    ('ibl-pupil-2', 'unimodal_mse'): [1e-3, 1e-2],
    ('ibl-pupil-2', 'pca_reproj'): [0.08, 6],
    ('ibl-pupil-2', 'temporal_norm'): [0.05, 3],
}
if (dataset_name, to_compute) in ylims:
    g.set(ylim=ylims[(dataset_name, to_compute)])

g.fig.subplots_adjust(top=0.9)
g.fig.suptitle('%s (averaged across %i videos)' % (y_label, n_vids))
plt.tight_layout()
plt.show()