## Plot train/test error across different models and numbers of training frames

In [None]:
import hydra
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import torch

from lightning_pose.utils.io import return_absolute_data_paths
from lightning_pose.utils.scripts import (
    get_imgaug_transform, get_dataset, get_data_module, get_loss_factories,
)

import sys
sys.path.append('/home/mattw/Dropbox/github/paninski-lab/tracking-diagnostics')
from diagnostics.handler import ModelHandler
from diagnostics.io import get_keypoint_names, get_base_config, update_loss_config
from diagnostics.visualizations import get_y_label

### define configuration

In [None]:
# Dataset to analyze. Alternative names that can be swapped in here:
#   "ibl-fingers", "ibl-pupil-2", "ibl-pupil-ks004", "ibl-paw-2",
#   "mirror-mouse-1.5", "mirror-fish"
dataset_name = "fly"

# Name of the model fits to analyze (alternative: 'grid-search-0').
model_name = 'hparam-search_07-22b'

# Machine-specific locations of the config files and of the saved results.
base_config_dir = "/home/mattw/Dropbox/research-code/pose-estimation/configs"
base_save_dir = "/media/mattw/behavior/results/pose-estimation/"

# Build the base config from the config directory and the per-dataset
# config name ("config_<dataset_name>").
cfg = get_base_config(base_config_dir, f"config_{dataset_name}")

In [None]:
# Read the ground-truth label csv named in the data config; the header rows
# to parse are taken from the config as well.
csv_file = os.path.join(cfg.data.data_dir, cfg.data.csv_file)
csv_data = pd.read_csv(csv_file, header=list(cfg.data.header_rows))

# Skip the first column, then view the remaining columns of every row as
# consecutive pairs: (n_rows, n_pairs, 2).
# NOTE(review): presumably each pair is the (x, y) coordinate of one keypoint -- confirm.
keypoints_gt = np.reshape(
    csv_data.iloc[:, 1:].to_numpy(),
    (len(csv_data), -1, 2),
)
keypoint_names = get_keypoint_names(csv_data, cfg.data.header_rows)

In [None]:
# Resolve the data and video directories from the data section of the config.
data_dir, video_dir = return_absolute_data_paths(data_cfg=cfg.data)

# Chain the lightning-pose helpers: augmentation transform -> dataset ->
# data module, then call the data module's setup().
imgaug_transform = get_imgaug_transform(cfg=cfg)
dataset = get_dataset(
    cfg=cfg,
    data_dir=data_dir,
    imgaug_transform=imgaug_transform,
)
data_module = get_data_module(
    cfg=cfg,
    dataset=dataset,
    video_dir=video_dir,
)
data_module.setup()

In [None]:
# Evaluate every (loss type, loss weight) model on each test video and gather
# the metric values into one wide dataframe: one column per model, plus the
# 'bodypart', 'rng_seed' and 'video' id columns.
save_dir = os.path.join(base_save_dir, dataset_name)

# define models
to_compute = "temporal_norm"  # pca_multiview | pca_singleview | unimodal_mse | temporal_norm
train_frame = 75
rng_seeds = [0]  # , 1, 2]
model_type = "heatmap"
do_context = False

# loss type -> loss weights to evaluate; the supervised baseline has no weight
loss_weight_dict = {
    'supervised': [None],
#     'temporal': [4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 8.0],
    'pca_singleview': [4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 8.0],
#     'pca_singleview': [3.0, 4.0, 4.5, 5.0, 5.5, 6.0],
#     'pca_multiview': [4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
}

pca_obj = None
pca_loss = None
datamodule = None

test_videos_directory = os.path.join(data_dir, 'videos_test')
video_names = sorted(os.listdir(test_videos_directory))

results_df = []
n_vids = 0
for video_name in video_names:  # [::2][:10]:
    if not video_name.endswith('.mp4'):
        continue
    n_vids += 1

    # both names depend only on the video, so build them once per video
    video_file = os.path.join(test_videos_directory, video_name)
    filename_pred = video_name.replace('.mp4', '.csv')

    for rng_seed in rng_seeds:
        # store results here; the i-th entry of every per-keypoint list
        # belongs to the i-th name in cols_collected
        metrics_collected = {bp: [] for bp in keypoint_names}
        cols_collected = []

        # loop over models and compute metric of interest
        for loss_type, loss_weights in loss_weight_dict.items():
            for loss_weight in loss_weights:

                # model_name = 'hparam-search_07-22c' if loss_type == "supervised" else 'hparam-search_07-22c_control'

                # find model checkpoint
                # NOTE(review): if cfg is an OmegaConf DictConfig, .copy() may share
                # nested nodes with cfg -- confirm the edits below do not leak across models.
                model_cfg = cfg.copy()
                model_cfg.model.do_context = do_context
                model_cfg.training.train_frames = train_frame
                model_cfg.training.rng_seed_data_pt = rng_seed
                model_cfg.training.rng_seed_model_pt = 0  # rng_seed
                model_cfg.model.model_name = model_name

                # put model-specific config info here
                if loss_type == 'supervised':
                    model_cfg.model.losses_to_use = []
                else:
                    model_cfg = update_loss_config(model_cfg, loss_type, loss_weight)

                try:
                    handler = ModelHandler(
                        save_dir, model_cfg, verbose=False,
                        keys_to_sweep=['limit_train_batches'])
                except FileNotFoundError:
                    print('did not find %s model for train_frames=%i' % (loss_type, train_frame))
                    continue

                saved_vid_preds_dir = os.path.join(handler.model_dir, 'video_preds')
                pred_csv_file = os.path.join(saved_vid_preds_dir, filename_pred)
                try:
                    # result is indexed below as result[:, b] with b the keypoint index
                    # NOTE(review): presumably shaped (n_frames, n_keypoints) -- confirm.
                    result, _ = handler.compute_metric(
                        to_compute, pred_csv_file,
                        video_file=video_file,
                        confidence_thresh=0.05,
                        pca_loss_obj=pca_loss,  # datamodule=datamodule,
                    )
                except FileNotFoundError:
                    print('could not find model predictions')
                    continue

                if loss_type == 'supervised':
                    # the supervised baseline is stored once per non-supervised loss
                    # type, under '<loss>_s', never under the name 'supervised'
                    col_names = [
                        '%s_s' % loss_type_
                        for loss_type_ in loss_weight_dict
                        if loss_type_ != 'supervised'
                    ]
                else:
                    col_names = ['%s_%.1f' % (loss_type, loss_weight)]

                for col_name in col_names:
                    cols_collected.append(col_name)
                    for b, bodypart in enumerate(keypoint_names):
                        metrics_collected[bodypart].append(result[:, b])

        # Nothing was found for this video/seed: dict_tmp would then hold only
        # scalars and pd.DataFrame(dict_tmp) raises ("If using all scalar values,
        # you must pass an index"), so skip instead of aborting the whole sweep.
        if not cols_collected:
            continue

        # collect results
        for bodypart in keypoint_names:
            dict_tmp = {
                'bodypart': bodypart,
                'rng_seed': rng_seed,
                # 'video': video_name.split('_')[0].split('-')[0],
                'video': video_name[:10],  # only the first 10 characters identify the video
            }
            dict_tmp.update(zip(cols_collected, metrics_collected[bodypart]))
            # dict_tmp['time_idx'] = ['%s_%i' % (video_name, n) for n in np.arange(metric.shape[0])]
            results_df.append(pd.DataFrame(dict_tmp))

results_df = pd.concat(results_df)

### scatterplots for a pair of models

In [None]:
# sns.set_context('talk')
# sns.set_style('whitegrid')

# results_tmp = results_df.copy()

# x = 'pca_multiview_s'
# # y = 'multi_temp_7.0'
# y = 'pca_multiview_7.0'

# eps = 5
# # results_tmp.loc[(results_tmp[x] < eps) & (results_tmp[y] < eps), x] = np.nan
# # results_tmp.loc[(results_tmp[x] < eps) & (results_tmp[y] < eps), y] = np.nan

# import plotly.express as px
# import plotly.graph_objects as go
# fig = px.scatter(
#     results_tmp, 
#     x=x, y=y,
    
# #     facet_col='bodypart',
# #     facet_row='video',

# #     color='video',  
#     facet_col='bodypart',
#     facet_col_wrap=4,
    
#     # hover_data=['time_idx'],
#     log_x=True,
#     log_y=True,
#     opacity=0.25,
#     title='Temporal norm (%s)' % dataset_name,
#     range_x=[eps - 1, 100],
#     range_y=[eps - 1, 100],
# #     trendline="ols",
# #     marginal_x='histogram',
# #     marginal_y='histogram',
# )
# fig.update_traces(marker={'size': 3})

# trace = go.Scatter(x=[eps, 100], y=[eps, 100], line_color="black", mode="lines")
# trace.update(legendgroup="trendline", showlegend=False)
# fig.add_trace(trace, row="all", col="all", exclude_empty_subplots=True)
# fig.update_layout(width=800, height=1200)
# fig.show()

### barplots across all models

In [None]:
# Reshape the wide results table (one column per model) into long format:
# the model column name goes to 'variable', its metric to 'value'.
id_vars = ['bodypart', 'video', 'rng_seed']
if 'time_idx' in results_df.columns:
    id_vars += ['time_idx']

# Take the model columns from the dataframe itself rather than from
# `cols_collected`: that list is a leftover of the last (video, seed)
# iteration of the evaluation loop, so it is incomplete (or empty) whenever
# that iteration was missing a model or its predictions.
value_vars = [col for col in results_df.columns if col not in id_vars]

df_tmp = pd.melt(
    results_df,
    id_vars=id_vars,
    value_vars=value_vars,
)
def add_loss_name_col(row):
    """Return the part of ``row['variable']`` before its last underscore.

    The result is the empty string when the value contains no underscore.
    """
    loss_name, _, _ = row['variable'].rpartition('_')
    return loss_name
def add_loss_val_col(row):
    """Return the part of ``row['variable']`` after its last underscore.

    The whole value is returned when it contains no underscore.
    """
    _, _, loss_val = row['variable'].rpartition('_')
    return loss_val
# Split every model column name into the text before its last underscore
# ('loss') and the text after it ('loss_weight'), then drop the combined name.
df_tmp['loss'] = df_tmp.apply(add_loss_name_col, axis=1)
df_tmp['loss_weight'] = df_tmp.apply(add_loss_val_col, axis=1)
df_tmp = df_tmp.drop('variable', axis=1)

# keep only metric values above the threshold (NaN entries never pass the comparison)
eps = 5
df_tmp = df_tmp[df_tmp['value'] > eps]

# take mean over keypoints, so that error bars are over rng seeds
# numeric_only=True: the string 'bodypart' column is still present here and
# makes a plain .mean() raise TypeError on pandas >= 2.0; older pandas dropped
# such columns silently, which is the behaviour kept here.
df_tmp = (
    df_tmp
    .groupby(['rng_seed', 'loss', 'loss_weight', 'video'])
    .mean(numeric_only=True)
    .reset_index()
)

In [None]:
# Bar plot of the metric as a function of loss weight, pooled over models.
sns.set(context='talk', style='whitegrid', font_scale=1)

y_label = get_y_label(to_compute)

# df_tmp_ = df_tmp.groupby(['bodypart', 'video', 'loss', 'loss_weight']).mean().reset_index()
# df_tmp_ = df_tmp.copy()
df_tmp_ = df_tmp  # [df_tmp.value > eps]

# count_plot=True plots the number of rows per loss weight instead of their values
count_plot = False

if count_plot:
    kind = 'count'
    y = None
    y_label_ = 'Number of violations'
else:
    kind = 'bar'
    y = 'value'
    y_label_ = y_label

# The former `ci=95` argument is omitted: 95% confidence intervals are the
# seaborn default, and the `ci` keyword is deprecated in seaborn >= 0.12.
g = sns.catplot(
    x='loss_weight', y=y,
#     log=True,
#     order=['s'] + [str(w) for w in loss_weight_dict['pca_multiview']],
    kind=kind,

#     col='loss',
#     col_wrap=np.min([len(df_tmp_.loss.unique()), 3]),
#     col_order=['temporal', 'pca_singleview'],

#     col='bodypart',
#     row='loss',
#     col_wrap=4,
    sharey=False,
    data=df_tmp_,  # [df_tmp_.loss=='pca_singleview'],
)
# g = sns.displot(
#     hue='loss_weight', x='value',
# #     kind='strip', dodge=True,
#     col='bodypart',
#     col_wrap=4,
# #     sharey=False,
#     data=df_tmp, #[df_tmp_.loss=='pca_singleview'],
# )

g.set_axis_labels('Loss weight', y_label_)
g.set_xticklabels(rotation=45, ha='center')

for ax in g.axes.flatten():
    ax.tick_params(axis='y', which='both', direction='out', length=4, left=True)
    # Pass the on/off flag positionally: the `b` keyword was renamed to
    # `visible` in Matplotlib 3.5, so `b=True` fails on current releases
    # while the positional form works on old and new versions alike.
    ax.grid(True, which='both', color='gray', linewidth=0.1)

# g.set(ylim=[0.01, 40])

g.fig.subplots_adjust(top=0.9)
g.fig.suptitle('%s violations (%i videos)' % (y_label, n_vids))
plt.tight_layout()
base_dir = '/home/mattw/Dropbox/research-text/posters/2022_naisys_litpose'
# plt.savefig(os.path.join(base_dir, 'rick-mouse_held-out_rmse_eps=5.pdf'))
plt.show()

### by bodypart

In [None]:
# Boxen plot of the metric as a function of loss weight, one panel per bodypart.
sns.set(context='talk', style='whitegrid', font_scale=1)

# df_tmp_ = df_tmp.groupby(['bodypart', 'video', 'loss', 'loss_weight']).mean().reset_index()
# df_tmp_ = df_tmp.copy()

# NOTE(review): the panels below facet on a 'bodypart' column. The preceding
# aggregation cell groups df_tmp by rng_seed/loss/loss_weight/video, which
# does not keep 'bodypart' -- confirm df_tmp still has that column (e.g. by
# skipping that aggregation) before running this cell.
df_tmp_ = df_tmp[df_tmp.value > eps]

# count_plot=True plots the number of rows per loss weight instead of their values
count_plot = False

if count_plot:
    kind = 'count'
    y = None
    y_label_ = 'Number of violations'
else:
    kind = 'boxen'
    y = 'value'
    y_label_ = y_label

# x-axis order: supervised baseline ('s') first, then every configured loss
# weight in increasing order. The weights are read from all non-supervised
# entries of loss_weight_dict instead of the hard-coded 'pca_multiview' key,
# which raises KeyError whenever that entry is not active. They are formatted
# with '%.1f', the same format used to build the model column names.
weights_in_use = sorted({
    weight
    for loss_type_, weights_ in loss_weight_dict.items()
    if loss_type_ != 'supervised'
    for weight in weights_
})
weight_order = ['s'] + ['%.1f' % weight for weight in weights_in_use]

# The former `ci=95` argument is omitted: 95% is the seaborn default and the
# `ci` keyword is deprecated in seaborn >= 0.12.
g = sns.catplot(
    x='loss_weight', y=y,
#     log=True,
    order=weight_order,
    kind=kind,

#     col='loss',
#     col_wrap=np.min([len(df_tmp_.loss.unique()), 3]),
#     col_order=['temporal', 'pca_singleview'],

#     row='loss',
    col='bodypart',
    col_wrap=4,
    sharey=False,
    data=df_tmp_,  # [df_tmp_.loss=='pca_singleview'],
)
# g = sns.displot(
#     hue='loss_weight', x='value',
# #     kind='strip', dodge=True,
#     col='bodypart',
#     col_wrap=4,
# #     sharey=False,
#     data=df_tmp, #[df_tmp_.loss=='pca_singleview'],
# )

g.set_axis_labels('Loss weight', y_label_)
g.set_xticklabels(rotation=45, ha='center')

for ax in g.axes.flatten():
    ax.tick_params(axis='y', which='both', direction='out', length=4, left=True)
    # Pass the on/off flag positionally: the `b` keyword was renamed to
    # `visible` in Matplotlib 3.5, so `b=True` fails on current releases
    # while the positional form works on old and new versions alike.
    ax.grid(True, which='both', color='gray', linewidth=0.1)

# g.set(ylim=[0.01, 40])

g.fig.subplots_adjust(top=0.9)
g.fig.suptitle('%s violations (%i videos)' % (y_label, n_vids))
plt.tight_layout()
plt.show()