In [None]:
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import SubplotSpec
from scipy.stats import pearsonr
import seaborn as sns
from tqdm import tqdm

In [None]:
# Dataset identifier; used to build the result file names and the figure titles.
dataset_name = "mirror-mouse"
# Directory that holds the pre-computed result tables (*.pqt) for this dataset.
# NOTE(review): machine-specific absolute path - adjust when running elsewhere.
df_save_path = "/media/mattw/behavior/results/pose-estimation/mirror-mouse"

# Plain white seaborn background for every figure in this notebook.
sns.set_style("white")

In [None]:
# Only the metric tables are used by the plots in this notebook. The prediction
# tables are left here, disabled, in case they are needed:
# df_labeled_preds = pd.read_parquet(
#     os.path.join(df_save_path, "%s_labeled_preds.pqt" % dataset_name))
# df_video_preds = pd.read_parquet(
#     os.path.join(df_save_path, "%s_video_preds.pqt" % dataset_name))

labeled_metrics_file = os.path.join(df_save_path, "%s_labeled_metrics.pqt" % dataset_name)
video_metrics_file = os.path.join(df_save_path, "%s_video_metrics.pqt" % dataset_name)

df_labeled_metrics = pd.read_parquet(labeled_metrics_file)
df_video_metrics = pd.read_parquet(video_metrics_file)

In [None]:
def plot_scatters(
    df, metric_names, train_frames, split_set, distribution, model_types, keypoint, ax,
    add_diagonal=False, add_trendline=False,
):
    """Scatter one metric/model selection of ``df`` against another on ``ax``.

    Two row selections are built with ``get_scatter_mask``: the x values use
    ``metric_names[0]`` / ``model_types[0]`` and the y values use
    ``metric_names[1]`` / ``model_types[1]``; both share ``train_frames``,
    ``split_set`` and ``distribution``. The ``keypoint`` column of each
    selection supplies the plotted values. Points are drawn with one marker
    per unique value of ``rng_seed_data_pt`` in the x selection.

    Parameters
    ----------
    df : pandas.DataFrame
        Table providing the columns read by ``get_scatter_mask``, the
        ``keypoint`` column and ``rng_seed_data_pt``.
    metric_names : sequence of two str
        Metric for the x axis and metric for the y axis.
    train_frames, split_set, distribution
        Forwarded unchanged to ``get_scatter_mask`` for both selections.
    model_types : sequence of two str
        Model type for the x axis and model type for the y axis.
    keypoint : str
        Name of the column holding the values to plot.
    ax : matplotlib.axes.Axes
        Axis to draw on.
    add_diagonal : bool
        Draw the identity line between the pooled 1st and 99th percentiles of
        x and y, and set both axis limits around that range.
    add_trendline : bool
        Draw a least-squares line fit (dashed red) and compute the Pearson
        correlation between x and y.

    Returns
    -------
    tuple or None
        ``(r, p)`` as returned by ``scipy.stats.pearsonr`` when
        ``add_trendline`` is true, otherwise ``None``.

    Raises
    ------
    ValueError
        If the two selections do not have identical indices, i.e. the x and y
        values cannot be paired row by row.
    """
    markers = ['.', '+', '^', 's', 'o']
    mask_x = get_scatter_mask(
        df=df, metric_name=metric_names[0], train_frames=train_frames, split_set=split_set,
        distribution=distribution, model_type=model_types[0])
    mask_y = get_scatter_mask(
        df=df, metric_name=metric_names[1], train_frames=train_frames, split_set=split_set,
        distribution=distribution, model_type=model_types[1])
    df_x_rows = df[mask_x]
    df_xs = df_x_rows[keypoint]
    df_ys = df[mask_y][keypoint]

    # x and y are paired positionally, so both selections must be indexed identically.
    # The explicit length check avoids an opaque pandas error on unequal lengths.
    if len(df_xs) != len(df_ys) or not np.all(df_xs.index == df_ys.index):
        raise ValueError(
            'x and y selections are not aligned: %i rows for (%s, %s) vs %i rows for (%s, %s)'
            % (len(df_xs), metric_names[0], model_types[0],
               len(df_ys), metric_names[1], model_types[1]))

    xs = df_xs.to_numpy()
    ys = df_ys.to_numpy()
    rng_seed = df_x_rows.rng_seed_data_pt.to_numpy()

    for j, seed in enumerate(np.unique(rng_seed)):
        is_seed = rng_seed == seed
        # cycle through the markers so more seeds than markers does not raise IndexError
        ax.scatter(
            xs[is_seed], ys[is_seed], marker=markers[j % len(markers)], color='k',
            alpha=0.5, label='RNG seed %s' % seed)

    ret_vals = None
    if add_diagonal:
        # percentile bounds keep a few outliers from stretching the axes
        mn = np.min([np.percentile(xs, 1), np.percentile(ys, 1)])
        mx = np.max([np.percentile(xs, 99), np.percentile(ys, 99)])
        ax.plot([mn, mx], [mn, mx], 'k')
        ax.set_xlim(mn - 0.1 * mx, 1.1 * mx)
        ax.set_ylim(mn - 0.1 * mx, 1.1 * mx)
    if add_trendline:
        coeffs = np.polyfit(xs, ys, 1)
        r_val, p_val = pearsonr(xs, ys)
        # A straight line only needs its two end points. Drawing it through every
        # (unsorted) x retraces the line many times and smears the dash pattern.
        x_line = np.array([np.min(xs), np.max(xs)])
        ax.plot(x_line, np.polyval(coeffs, x_line), '--r')
        ret_vals = r_val, p_val
    return ret_vals
    

def get_scatter_mask(
        df, metric_name, train_frames, model_type, split_set=None, distribution=None):
    """Return a boolean row mask selecting one metric/model configuration of ``df``.

    Columns are read with bracket indexing rather than attribute access so that
    column names such as ``set`` can never be confused with DataFrame
    attributes or methods.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``metric``, ``train_frames`` and ``model_type``;
        ``set`` and ``distribution`` are only read when the matching filter is given.
    metric_name, train_frames, model_type
        Values the ``metric``, ``train_frames`` and ``model_type`` columns must equal.
    split_set : optional
        If not ``None``, additionally require ``df['set'] == split_set``.
    distribution : optional
        If not ``None``, additionally require ``df['distribution'] == distribution``.

    Returns
    -------
    pandas.Series
        Boolean mask aligned with ``df``'s index.
    """
    mask = ((df['metric'] == metric_name)
            & (df['train_frames'] == train_frames)
            & (df['model_type'] == model_type))
    if split_set is not None:
        mask = mask & (df['set'] == split_set)
    if distribution is not None:
        mask = mask & (df['distribution'] == distribution)
    return mask


def create_subtitle(fig: plt.Figure, grid: SubplotSpec, title: str):
    """Put a bold heading over the part of ``fig`` spanned by ``grid``.

    An extra, invisible axis is added at ``grid`` and only its title is shown.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives the heading.
    grid : matplotlib.gridspec.SubplotSpec
        Region of the figure the heading axis occupies.
    title : str
        Heading text.
    """
    heading_ax = fig.add_subplot(grid)
    # Keep the trailing newline in the title; presumably it lifts the heading
    # clear of the titles of the axes underneath.
    heading_ax.set_title(f'{title}\n', fontsize=14, fontweight='semibold')
    # only the title should remain visible: no frame, ticks or labels
    heading_ax.axis('off')
    heading_ax.set_frame_on(False)

# Plots on labeled data

### Basic metrics for each model type

In [None]:
# take mean over points, variability over seeds

# --- settings (keypoint, split_set, train_frames, labels_fontsize and
# train_frame_str are also read by later cells) ---
keypoint = 'mean'
split_set = 'test'  # 'test' is only value for which InD and OOD both have results
train_frames = '75'  # '75' | '1'
plots = {
    'pixel_error': 'Pixel error',
    'pca_singleview_error': 'Pose PCA',
    'pca_multiview_error': 'Multi-view PCA',
}
models_to_compare = ['baseline', 'semi-super context']

labels_fontsize = 12
# train_frames is stored as a string, so compare as strings; the previous
# comparison against the integer 1 could never select the "full" label.
train_frame_str = (
    'full train frames' if str(train_frames) == '1' else '%s train frames' % train_frames)

# row 1: barplots of metrics for all models
# row 2: scatterplots of metrics on OOD data for 2 models
n_rows = 2

# squeeze=False always returns a 2-D array of axes, so axes[row][col] works
# for any number of rows and columns
fig, axes = plt.subplots(
    n_rows, len(plots), figsize=(4 * len(plots), 3.5 * n_rows + 0.5), squeeze=False)

for i, (metric_name, title) in enumerate(plots.items()):
    ax_bar = axes[0][i]
    ax_scatter = axes[1][i]

    mask = ((df_labeled_metrics.set==split_set)
            & (df_labeled_metrics.metric==metric_name)
            & (df_labeled_metrics.train_frames==train_frames))

    # row 1: barplots of metrics for all models
    sns.barplot(
        x='distribution', y=keypoint, hue='model_type',
        hue_order=['baseline', 'context', 'semi-super', 'semi-super context'],
        data=df_labeled_metrics[mask],
        ax=ax_bar,
    )
    ax_bar.set_title(title)
    ax_bar.set_ylabel('Error (pix)', fontsize=labels_fontsize)
    ax_bar.set_xlabel('Distribution', fontsize=labels_fontsize)
    if i != 0:
        # one legend (on the first panel) is enough for the whole row
        ax_bar.get_legend().remove()

    # row 2: scatterplots of metrics on OOD data for 2 models
    plot_scatters(
        df=df_labeled_metrics, metric_names=[metric_name, metric_name],
        train_frames=train_frames, split_set=split_set, distribution='OOD',
        model_types=models_to_compare, keypoint=keypoint, ax=ax_scatter, add_diagonal=True)
    ax_scatter.set_title('%s (OOD data)' % title)
    ax_scatter.set_xlabel(
        '%s model error (pix)' % (models_to_compare[0].capitalize()), fontsize=labels_fontsize)
    ax_scatter.set_ylabel(
        '%s model error (pix)' % (models_to_compare[1].capitalize()), fontsize=labels_fontsize)

    if i == 0:
        ax_scatter.legend()

plt.subplots_adjust(top=0.95)
plt.suptitle('Labeled data results on %s dataset (%s)' % (dataset_name, train_frame_str), fontsize=14)
plt.tight_layout()
plt.show()

### Pixel error vs metric scatters

In [None]:
# One row of panels per model type, one column per metric; every panel plots
# pixel error (x) against that metric (y) on the OOD data and reports the
# Pearson correlation in its title.
# Reuses train_frames, split_set, keypoint, labels_fontsize and train_frame_str
# from the previous plotting cell.
plots = {
    'confidence': 'Confidence',
    'pca_singleview_error': 'Pose PCA (pix)',
    'pca_multiview_error': 'Multi-view PCA (pix)',
}
models_to_compare = ['baseline', 'context', 'semi-super', 'semi-super context']
n_rows = len(models_to_compare)
n_cols = len(plots)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3.5 * n_rows + 0.5))
# separate grid used only to position the per-row headings
grid = plt.GridSpec(n_rows, n_cols)

for row_idx, model_type in enumerate(models_to_compare):
    # heading spanning the full width of this row
    create_subtitle(fig, grid[row_idx, ::], model_type.capitalize())
    for col_idx, (metric_name, title) in enumerate(plots.items()):
        ax = axes[row_idx][col_idx]
        r_val, p_val = plot_scatters(
            df=df_labeled_metrics,
            metric_names=['pixel_error', metric_name],
            train_frames=train_frames,
            split_set=split_set,
            distribution='OOD',
            model_types=[model_type, model_type],
            keypoint=keypoint,
            ax=ax,
            add_trendline=True,
        )
        ax.set_title('r=%1.2f [p=%1.3f]' % (r_val, p_val))
        ax.set_xlabel('Pixel error', fontsize=labels_fontsize)
        ax.set_ylabel('%s' % title, fontsize=labels_fontsize)
        if col_idx == 0:
            # legend only on the first panel of each row
            ax.legend()

plt.subplots_adjust(top=0.9)
plt.suptitle(
    'Labeled data results on %s dataset (%s)' % (dataset_name, train_frame_str), fontsize=16)
plt.tight_layout()
plt.show()

# Plots on unlabeled data

In [None]:
# Collapse all rows that share the same metric / video / model identifiers into
# their mean, then index the result by video name.
video_group_cols = [
    'metric', 'video_name', 'model_path', 'rng_seed_data_pt', 'train_frames', 'model_type']
df_video_metrics_gr = (
    df_video_metrics
    .groupby(video_group_cols)
    # numeric_only=True: pandas >= 2.0 raises TypeError when a non-numeric,
    # non-key column is averaged; older versions silently dropped such columns
    .mean(numeric_only=True)
    .reset_index()
    .set_index('video_name')
)
# number of distinct videos, shown in the panel titles of the video figure
n_videos = df_video_metrics_gr.index.nunique()

In [None]:
# take mean over points, variability over seeds
# Reuses train_frames, keypoint, labels_fontsize and train_frame_str from the
# labeled-data plotting cell.

plots = {
    'temporal_norm': 'Temporal Norm',
    'pca_singleview_error': 'Pose PCA',
    'pca_multiview_error': 'Multi-view PCA',
    'confidence': 'Confidence',
}
models_to_compare = ['baseline', 'semi-super context']

# row 1: barplots of metrics for all models
# row 2: scatterplots of metrics for 2 models
n_rows = 2
n_cols = len(plots)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3.5 * n_rows + 0.5))
if n_rows == 1:
    # keep axes[row][col] indexing valid for a single-row figure
    axes = [axes]

for i, (metric_name, title) in enumerate(plots.items()):
    ax_bar, ax_scatter = axes[0][i], axes[1][i]

    # confidence is not measured in pixels, so it gets its own axis wording
    if metric_name == 'confidence':
        metric_str = 'confidence'
    else:
        metric_str = 'error (pix)'

    # row 1: barplots of metrics for all models
    is_metric = df_video_metrics_gr.metric == metric_name
    is_train_frames = df_video_metrics_gr.train_frames == train_frames
    mask = is_metric & is_train_frames
    sns.barplot(
        x='model_type',
        y=keypoint,
        order=['baseline', 'context', 'semi-super', 'semi-super context'],
        data=df_video_metrics_gr[mask],
        ax=ax_bar,
    )
    ax_bar.set_title(title)
    ax_bar.set_ylabel(metric_str.capitalize(), fontsize=labels_fontsize)
    ax_bar.set_xlabel('Model', fontsize=labels_fontsize)
    ax_bar.set_xticklabels(
        ['Baseline', 'Context', 'Semi-super', 'Semi-super\nContext'], fontsize=labels_fontsize)

    # row 2: scatterplots of metrics for 2 models (no split/distribution filter here)
    plot_scatters(
        df=df_video_metrics_gr,
        metric_names=[metric_name, metric_name],
        train_frames=train_frames,
        split_set=None,
        distribution=None,
        model_types=models_to_compare,
        keypoint=keypoint,
        ax=ax_scatter,
        add_diagonal=True,
    )
    ax_scatter.set_title('%s (%i videos)' % (title, n_videos))
    ax_scatter.set_xlabel(
        '%s model %s' % (models_to_compare[0].capitalize(), metric_str),
        fontsize=labels_fontsize)
    ax_scatter.set_ylabel(
        '%s model %s' % (models_to_compare[1].capitalize(), metric_str),
        fontsize=labels_fontsize)
    if i == 0:
        # legend only on the first scatter panel
        ax_scatter.legend()

plt.subplots_adjust(top=0.95)
plt.suptitle('Video results on %s dataset (%s)' % (dataset_name, train_frame_str), fontsize=14)
plt.tight_layout()
plt.show()