In [None]:
import cv2
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import SubplotSpec
from scipy.stats import pearsonr
import seaborn as sns
from tqdm import tqdm

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from diagnostics.paper_utils import get_trace_mask, plot_traces_and_metrics
from lightning_pose.utils.frame_selection import get_frames_from_idxs

In [None]:
sns.set_style('white')

# --- configuration ---------------------------------------------------------------------
dataset_name = "mirror-fish"
# training-set size key used to select rows in the results tables (values used below: "75", "1")
train_frames = "75"

# Root directories default to the original author's machine; set the LITPOSE_RESULTS_DIR /
# LITPOSE_DATA_DIR environment variables to run this notebook elsewhere. With the defaults the
# resulting paths are identical to the previously hard-coded ones.
results_root = os.environ.get(
    "LITPOSE_RESULTS_DIR", "/home/mattw/Dropbox/shared/litpose_results")
data_root = os.environ.get(
    "LITPOSE_DATA_DIR", "/media/mattw/behavior/pose-estimation-data-final")
df_save_path = os.path.join(results_root, dataset_name)
vid_dir = os.path.join(data_root, dataset_name, "videos_new")

# model types (values of the results tables' model_type) compared in every plot below
models_to_compare = ['baseline', 'semi-super context']

In [None]:
# Load the video prediction and video metric tables saved for this dataset.
# (The labeled-frame tables "<dataset>_labeled_preds.pqt" / "<dataset>_labeled_metrics.pqt"
# are not needed by the cells below and are intentionally not loaded.)
video_preds_file = os.path.join(df_save_path, "%s_video_preds.pqt" % dataset_name)
video_metrics_file = os.path.join(df_save_path, "%s_video_metrics.pqt" % dataset_name)

df_video_preds = pd.read_parquet(video_preds_file)
df_video_metrics = pd.read_parquet(video_metrics_file)

In [None]:
# Choose what to plot for the selected dataset:
#   vid_name           -- video name as stored in the results tables
#   vid_name_load      -- basename of the .mp4 in `vid_dir` (differs from vid_name for the
#                         fly "_sample-N" entries)
#   keypoint           -- keypoint whose traces / markers are shown
#   time_window        -- (start, end) window passed to the trace plot
#   time_window_frames -- (first, last) frames, inclusive, shown as snapshots in the frame cell
# Every branch defines all five names so later cells never run on stale values left in the
# kernel by a previous dataset.
rng_seed = "0"

if dataset_name == 'mirror-mouse':
    # alternative: keypoint = 'paw1LH_top', time_window = (300, 900)
    vid_name = '180609_004'
    vid_name_load = '180609_004'
    keypoint = 'paw2LF_top'
    time_window = (2800, 3600)
    time_window_frames = (4771, 4774)

elif dataset_name == 'mirror-fish':
    # alternative: vid_name = '20210202_Sean', keypoint = 'dorsal_main' (or 'caudal_v_main' /
    # 'caudal_d_main'), time_window = (500, 3000), time_window_frames = (641, 644) for "75"
    vid_name = '20210129_Quin'
    vid_name_load = vid_name
    keypoint = 'fork_main'
    time_window = (0, 1000)
    # the snapshot frames depend on the training-set size
    frames_by_train_frames = {"75": (641, 644), "1": (221, 224)}
    if train_frames not in frames_by_train_frames:
        raise ValueError(
            "no snapshot frames configured for %s with train_frames=%r (expected one of %s)"
            % (dataset_name, train_frames, sorted(frames_by_train_frames)))
    time_window_frames = frames_by_train_frames[train_frames]

elif dataset_name == 'fly':
    # alternatives:
    #   same video, time_window = (50, 250) / frames (115, 118) for "75";
    #                time_window = (0, 400) / frames (250, 253) for "1"
    #   vid_name = '2022_01_14_fly1_sample-2', vid_name_load = '2022_01_14_fly1',
    #   keypoint = 'mid-bot', time_window = (0, 700)
    vid_name = '2022_01_05_fly2_sample-1'
    vid_name_load = '2022_01_05_fly2'
    keypoint = 'mid-top'
    time_window = (100, 400)
    # Previously left undefined, which broke (or silently reused a stale value in) the frame
    # cell. Reuses the snapshot window from the earlier config for this same video/keypoint;
    # TODO confirm these are the frames of interest.
    time_window_frames = (115, 118)

else:
    raise ValueError("no plotting configuration for dataset_name=%r" % dataset_name)

assert vid_name in df_video_metrics.video_name.unique(), (
    "video %r not found in the video metrics table for %r" % (vid_name, dataset_name))
vid_file = os.path.join(vid_dir, '%s.mp4' % vid_name_load)

# Trace plots for video data

In [None]:
# Names in the top level of the prediction table's column MultiIndex (keypoints, presumably
# alongside metadata columns such as model_type / video_name).
top_level_columns = df_video_preds.columns.levels[0]
top_level_columns

In [None]:
# Videos present in the metrics table; `vid_name` in the configuration cell must be one of these.
df_video_metrics['video_name'].unique()

In [None]:
# Disabled exploratory sweep: plots traces for every video in the metrics table, restricted to
# keypoint 'paw2LF_top', at both train_frames '75' and '1' over time_window (0, 5000).
# NOTE(review): if re-enabled, the loop rebinds the globals `vid_name`, `keypoint` and
# `time_window` chosen in the configuration cell, so later cells would plot whatever the loop
# left behind -- rename the loop variables before turning this back on.
# save_file = None
# for vid_name in df_video_metrics.video_name.unique():
#     print(vid_name)
#     for keypoint in df_video_preds.columns.levels[0]:
# #         if keypoint in ["model_path", "model_type", "rng_seed_data_pt", "train_frames", "video_name"]:
# #             continue
#         if keypoint != 'paw2LF_top':
#             continue
#         else:
#             time_window = (0, 5000)
#             plot_traces_and_metrics(
#                 df_video_metrics=df_video_metrics, df_video_preds=df_video_preds, 
#                 models_to_compare=models_to_compare, keypoint=keypoint, vid_name=vid_name, 
#                 train_frames='75', rng_seed=rng_seed, time_window=time_window, 
#                 save_file=save_file)
#             plot_traces_and_metrics(
#                 df_video_metrics=df_video_metrics, df_video_preds=df_video_preds, 
#                 models_to_compare=models_to_compare, keypoint=keypoint, vid_name=vid_name, 
#                 train_frames='1', rng_seed=rng_seed, time_window=time_window, 
#                 save_file=save_file)
#             print("\n\n\n\n\n")

In [None]:
    # NOTE(review): disabled save snippet. `save_figs`, `base_fig_dir`, `fig_traces` and `slc`
    # are not defined in the cells above, and the `%i` conversion would raise TypeError for the
    # string-valued `train_frames` (use `%s`). To save the trace figure, pass a path as
    # `save_file` to `plot_traces_and_metrics` below instead.
    # if save_figs:
    #     fig_dir = os.path.join(base_fig_dir, 'fig3_semi-supervised')
    #     if not os.path.isdir(fig_dir):
    #         os.makedirs(fig_dir)
    #     fig_traces.write_image(os.path.join(
    #         fig_dir, 
    #         'traces_%s_%s_%i-%i_tf=%i.pdf' % (
    #             dataset_name, keypoint, slc[0], slc[1], train_frames)))
# Trace comparison of `models_to_compare` for `keypoint` in `vid_name` over `time_window`.
# Every selection comes from the configuration cells above. `train_frames` is deliberately NOT
# reassigned here: the configuration cell used its value to choose `time_window` /
# `time_window_frames`, and the frame-snapshot cell below reads it as well, so overriding it in
# this cell would make the plots disagree with that choice. (Alternative video/keypoint choices
# are listed as comments in the configuration cell.)
# `save_file` is forwarded to plot_traces_and_metrics; presumably a path makes it write the
# figure to disk -- see diagnostics.paper_utils.
save_file = None
plot_traces_and_metrics(
    df_video_metrics=df_video_metrics, df_video_preds=df_video_preds,
    models_to_compare=models_to_compare, keypoint=keypoint, vid_name=vid_name,
    train_frames=train_frames, rng_seed=rng_seed, time_window=time_window,
    save_file=save_file)

### Plot a series of video frames with each model's predicted keypoint marked

In [None]:
# For each frame in `time_window_frames` (inclusive), show the raw video frame with every
# compared model's predicted `keypoint` location overlaid, one color per model
# (colors follow the order of `models_to_compare`).
colors = px.colors.qualitative.Plotly

# The row selection for a model does not depend on the frame, so mask the prediction table
# once per model here instead of twice per frame inside the loop.
preds_per_model = {
    model_type: df_video_preds[get_trace_mask(
        df_video_preds, video_name=vid_name,
        train_frames=train_frames, model_type=model_type, rng_seed=rng_seed)]
    for model_type in models_to_compare
}

cap = cv2.VideoCapture(vid_file)
if not cap.isOpened():
    raise OSError("could not open video file %s" % vid_file)
try:
    for idx_time in range(time_window_frames[0], time_window_frames[1] + 1):
        print(idx_time)
        frame = get_frames_from_idxs(cap, [idx_time])
        fig, ax = plt.subplots(figsize=(4, 4))

        # plot frame; frame[0, 0] presumably selects the single requested frame's first
        # (grayscale) channel -- confirm against get_frames_from_idxs
        ax.imshow(frame[0, 0], cmap='gray', vmin=0, vmax=255)

        # plot predictions. Assumes each model's masked rows are ordered by frame index starting
        # at frame 0, so positional `.iloc[idx_time]` is this frame -- TODO confirm. The first
        # two values of the keypoint's sub-columns are used as the (x, y) marker position.
        for color, (model_type, preds) in zip(colors, preds_per_model.items()):
            xy = preds.iloc[idx_time][keypoint].to_numpy()
            ax.plot(xy[0], xy[1], '.', markersize=15, color=color, label=model_type)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

        # To save the snapshots (`save_figs` / `base_fig_dir` are not defined in the cells above):
        # if save_figs:
        #     fig_dir = os.path.join(base_fig_dir, 'fig3_semi-supervised')
        #     os.makedirs(fig_dir, exist_ok=True)
        #     fig.savefig(os.path.join(
        #         fig_dir,
        #         'frames_%s_%s_%i_tf=%s.png' % (dataset_name, keypoint, idx_time, train_frames)),
        #         bbox_inches='tight', pad_inches=0)
        plt.show()
        plt.close(fig)
finally:
    # release the video handle even if a frame read or a plot call fails
    cap.release()