In [None]:
import cv2
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import SubplotSpec
from scipy.stats import pearsonr
import seaborn as sns
from tqdm import tqdm

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from lightning_pose.utils.frame_selection import get_frames_from_idxs

In [None]:
sns.set_style("white")

# --- what to analyze ---------------------------------------------------------
dataset_name = "mirror-mouse"
# kept as a string: it is compared against the literals "75" / "1" further down
# and against the `train_frames` column of the loaded tables
train_frames = "75"
# model types drawn in each figure, in plotting (and color) order
models_to_compare = ["baseline", "semi-super context"]

# --- where to find the inputs ------------------------------------------------
# directory holding the "<dataset_name>_*.pqt" result tables
df_save_path = "/media/mattw/behavior/results/pose-estimation/mirror-mouse"
# directory holding the "<video name>.mp4" test videos
vid_dir = "/media/mattw/behavior/pose-estimation-data-final/mirror-mouse/videos_test"

In [None]:
# Only the video tables are used in this notebook. The labeled-frame tables
# ("<dataset_name>_labeled_preds.pqt" / "<dataset_name>_labeled_metrics.pqt" in the
# same directory) are intentionally not loaded.
video_preds_file = os.path.join(df_save_path, "{}_video_preds.pqt".format(dataset_name))
video_metrics_file = os.path.join(df_save_path, "{}_video_metrics.pqt".format(dataset_name))

df_video_preds = pd.read_parquet(video_preds_file)
df_video_metrics = pd.read_parquet(video_metrics_file)

In [None]:
def get_trace_mask(df, video_name, train_frames, model_type, rng_seed, metric_name=None):
    """Build a boolean row mask selecting one model's rows for one video.

    A row is selected when its `train_frames`, `rng_seed_data_pt`, `model_type`
    and `video_name` columns equal the corresponding arguments. When
    `metric_name` is given, the row's `metric` column must match as well.

    Args:
        df: dataframe exposing the columns named above.
        video_name: value to match in `df.video_name`.
        train_frames: value to match in `df.train_frames`.
        model_type: value to match in `df.model_type`.
        rng_seed: value to match in `df.rng_seed_data_pt`.
        metric_name: optional value to match in `df.metric`; skipped if None.

    Returns:
        Boolean mask aligned with the rows of `df`.
    """
    same_frames = df.train_frames == train_frames
    same_seed = df.rng_seed_data_pt == rng_seed
    same_model = df.model_type == model_type
    same_video = df.video_name == video_name

    selected = same_frames & same_seed & same_model & same_video
    if metric_name is None:
        return selected
    return selected & (df.metric == metric_name)

In [None]:
rng_seed = "0"

# Example segment to visualize for each (dataset_name, train_frames) combination.
#   vid_name:           video file stem inside `vid_dir`
#   keypoint:           keypoint whose traces/markers are plotted
#   time_window:        (start, stop) frame slice for the trace plots (stop excluded)
#   time_window_frames: (first, last) frames rendered as images (last included)
example_segments = {
    # alternative mirror-mouse segment used for fig 4:
    #   keypoint='paw2LF_top', time_window=(4500, 4750), time_window_frames=(4530, 4533)
    ('mirror-mouse', "75"): dict(
        vid_name='180609_004', keypoint='paw1LH_top',
        time_window=(4700, 4950), time_window_frames=(4771, 4774)),
    ('mirror-mouse', "1"): dict(
        vid_name='180609_004', keypoint='paw1LH_top',
        time_window=(4700, 4950), time_window_frames=(4778, 4782)),
    ('mirror-fish', "75"): dict(
        vid_name='20210129_Quin', keypoint='stripeP_main',
        time_window=(0, 2500), time_window_frames=(641, 644)),
    ('mirror-fish', "1"): dict(
        vid_name='20210129_Quin', keypoint='stripeP_main',
        time_window=(0, 2500), time_window_frames=(221, 224)),
    ('fly', "75"): dict(
        vid_name='2022_01_14_fly2', keypoint='hind-bot',
        time_window=(0, 400), time_window_frames=(115, 118)),
    ('fly', "1"): dict(
        vid_name='2022_01_14_fly2', keypoint='hind-bot',
        time_window=(0, 400), time_window_frames=(250, 253)),
}

# Fail loudly on an unknown combination instead of leaving the variables below
# undefined (or, in a live kernel, silently stale from a previous run).
segment_key = (dataset_name, train_frames)
if segment_key not in example_segments:
    raise ValueError(
        "no example segment defined for dataset_name=%r, train_frames=%r; known combinations: %s"
        % (dataset_name, train_frames, sorted(example_segments)))

segment = example_segments[segment_key]
vid_name = segment['vid_name']
keypoint = segment['keypoint']
time_window = segment['time_window']
time_window_frames = segment['time_window_frames']

vid_file = os.path.join(vid_dir, '%s.mp4' % vid_name)

# Trace plots for video data

In [None]:
def plot_traces_and_metrics(
        df_video_metrics, df_video_preds, models_to_compare, keypoint, vid_name, rng_seed,
        time_window, save_file=None):
    """Plot per-frame metrics, x/y traces and confidences for one keypoint.

    Builds a single-column plotly figure with a shared x axis (frame number).
    From top to bottom it contains one row for each of the metrics
    "temporal_norm", "pca_multiview_error" and "pca_singleview_error" that is
    present in `df_video_metrics.metric`, followed by the x coordinate, the y
    coordinate and the likelihood ("confidence") of `keypoint`. Every row holds
    one line per entry of `models_to_compare`.

    NOTE: `train_frames` is not a parameter; it is read from the notebook's
    global namespace when the row masks are built.

    Args:
        df_video_metrics: metrics table; must expose the columns used by
            `get_trace_mask` (including `metric`) and a column named `keypoint`.
        df_video_preds: predictions table; must expose the columns used by
            `get_trace_mask` and the columns `(keypoint, "x")`,
            `(keypoint, "y")` and `(keypoint, "likelihood")`.
        models_to_compare: model type names to draw, in color order.
        keypoint: name of the keypoint to plot.
        vid_name: video whose rows are selected.
        rng_seed: value matched against `rng_seed_data_pt`.
        time_window: `(start, stop)` positional frame range to plot; `stop` is
            excluded.
        save_file: optional output path; when given, the figure is also written
            there with `write_image` (parent directory is created if needed).
    """
    colors = px.colors.qualitative.Plotly
    frame_idxs = np.arange(time_window[0], time_window[1])
    window = slice(*time_window)

    # optional metric rows, listed in the order they are drawn (top to bottom)
    candidate_metrics = [
        ("temporal_norm", "temporal<br>norm"),
        ("pca_multiview_error", "pca multi<br>error"),
        ("pca_singleview_error", "pca single<br>error"),
    ]
    available = df_video_metrics.metric.unique()
    metric_rows = [(name, label) for name, label in candidate_metrics if name in available]

    # short rows for metrics, tall rows for x/y traces, short row for confidence
    row_heights = [0.75] * len(metric_rows) + [2, 2, 0.75]

    fig_traces = make_subplots(
        rows=len(row_heights), cols=1,
        shared_xaxes=True,
        x_title="Frame number",
        row_heights=row_heights,
        vertical_spacing=0.03,
    )

    def add_line(values, model_idx, row, showlegend=False):
        """Draw `values` restricted to the time window as one model's line in `row`."""
        fig_traces.add_trace(
            go.Scatter(
                name=models_to_compare[model_idx],
                x=frame_idxs,
                # explicit positional slice so it lines up with `frame_idxs`
                y=values.iloc[window],
                mode='lines',
                # wrap around the palette if there are more models than colors
                line=dict(color=colors[model_idx % len(colors)]),
                showlegend=showlegend,
            ),
            row=row, col=1
        )

    yaxis_labels = {}
    row = 1

    # plot metrics (each row queries its own metric name)
    for metric_name, label in metric_rows:
        for c, model_type in enumerate(models_to_compare):
            mask = get_trace_mask(
                df_video_metrics, video_name=vid_name, metric_name=metric_name,
                train_frames=train_frames, model_type=model_type, rng_seed=rng_seed)
            add_line(df_video_metrics[mask][keypoint], c, row)
        yaxis_labels['yaxis%i' % row] = label
        row += 1

    # select each model's prediction rows once; reused by the trace and
    # confidence rows so that every line is drawn from its own model's data
    preds_per_model = [
        df_video_preds[get_trace_mask(
            df_video_preds, video_name=vid_name,
            train_frames=train_frames, model_type=model_type, rng_seed=rng_seed)]
        for model_type in models_to_compare
    ]

    # plot traces
    for coord in ["x", "y"]:
        for c, preds in enumerate(preds_per_model):
            # only the y row contributes legend entries, to avoid duplicates
            add_line(preds.loc[:, (keypoint, coord)], c, row, showlegend=(coord == "y"))
        yaxis_labels['yaxis%i' % row] = "%s coordinate" % coord
        row += 1

    # plot likelihoods
    for c, preds in enumerate(preds_per_model):
        add_line(preds.loc[:, (keypoint, "likelihood")], c, row)
    yaxis_labels['yaxis%i' % row] = "confidence"
    row += 1

    for k, v in yaxis_labels.items():
        fig_traces["layout"][k]["title"] = v
    fig_traces.update_layout(
        width=800, height=np.sum(row_heights) * 125,
        title_text="Timeseries of %s" % keypoint
    )

    if save_file is not None:
        save_dir = os.path.dirname(save_file)
        # a bare filename has no directory component; makedirs('') would fail
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig_traces.write_image(save_file)

    fig_traces.show()

In [None]:
    # if save_figs:
    #     fig_dir = os.path.join(base_fig_dir, 'fig3_semi-supervised')
    #     if not os.path.isdir(fig_dir):
    #         os.makedirs(fig_dir)
    #     fig_traces.write_image(os.path.join(
    #         fig_dir, 
    #         'traces_%s_%s_%i-%i_tf=%i.pdf' % (
    #             dataset_name, keypoint, slc[0], slc[1], train_frames)))
# forwarded as the optional output path; None means the figure is only displayed
save_file = None
plot_traces_and_metrics(
    df_video_metrics=df_video_metrics,
    df_video_preds=df_video_preds,
    models_to_compare=models_to_compare,
    keypoint=keypoint,
    vid_name=vid_name,
    rng_seed=rng_seed,
    time_window=time_window,
    save_file=save_file,
)

### Plot a series of video frames with predicted markers

In [None]:
colors = px.colors.qualitative.Plotly

# The row selection does not depend on the frame index, so select each model's
# prediction rows once instead of rebuilding the masks for every frame.
preds_per_model = [
    df_video_preds[get_trace_mask(
        df_video_preds, video_name=vid_name,
        train_frames=train_frames, model_type=model_type, rng_seed=rng_seed)]
    for model_type in models_to_compare
]

cap = cv2.VideoCapture(vid_file)
try:
    # `time_window_frames` is (first, last) with the last frame included
    for idx_time in np.arange(time_window_frames[0], time_window_frames[1] + 1):
        print(idx_time)
        frame = get_frames_from_idxs(cap, [idx_time])
        plt.figure(figsize=(4, 4))

        # plot frame
        # NOTE(review): indexing [0, 0] presumably picks the first frame and first
        # channel of a (frames, channels, height, width) array - confirm against
        # get_frames_from_idxs
        plt.imshow(frame[0, 0], cmap='gray', vmin=0, vmax=255)

        # plot predictions: one marker per model, colored in model order
        for c, preds in enumerate(preds_per_model):
            marker = preds.iloc[idx_time][keypoint].to_numpy()
            plt.plot(
                marker[0], marker[1], '.', markersize=15, color=colors[c % len(colors)])

        plt.xticks([])
        plt.yticks([])
        plt.axis('off')

        # to export a frame, call plt.savefig(<path>, bbox_inches='tight', pad_inches=0)
        # here, before plt.show()
        plt.show()
finally:
    # always free the video handle, even if reading or plotting fails
    cap.release()