In [None]:
import os
import glob
from typing import Dict

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import cycler

from py_utils import data_ops
from lib import misc

# Turn on grid lines by default for every axes drawn in this notebook.
plt.rc("axes", grid=True)

In [None]:
# Re-export every hand-labelled DeepLabCut file (CollectedData_UCT.h5) as a
# per-camera HDF5 file `cam<N>.h5` whose index is the integer frame number.
# All cameras of one sequence are written into the directory derived from that
# sequence's cam1 label path. The label root can be overridden with the
# GT_ROOT_DIR environment variable (defaults to the original location).
gt_root_dir = os.environ.get("GT_ROOT_DIR", "/Users/zico/msc/data/gt_labels")
labelled_files = sorted(glob.glob(os.path.join(gt_root_dir, "**/CollectedData_UCT.h5"), recursive=True))


def camera_index_from_dir(video_dir_name: str) -> int:
    """Return the camera number encoded as the last character of a video directory name."""
    return int(video_dir_name[-1])


def start_frame_from_label(first_label: str) -> int:
    """Return the frame number that follows "img" in a DLC image-path index entry.

    NOTE(review): only the first three characters after "img" are parsed, so this
    assumes 3-digit frame numbers in the image file names — confirm against the data.
    """
    return int(first_label.split("img")[1][0:3])


out_dir = None
for curr_file in labelled_files:
    # First path component below the root identifies the video/camera directory.
    vid = os.path.relpath(curr_file, gt_root_dir).split(os.sep)[0]
    cam_idx = camera_index_from_dir(vid)
    # Assuming that cam1 indicates the start of a new video sequence (the files are sorted above).
    if cam_idx == 1:
        out_dir = curr_file.split("cam1/CollectedData_UCT.h5")[0]
        os.makedirs(out_dir, exist_ok=True)
    if out_dir is None:
        # Without a preceding cam1 file there is no sequence directory to write into.
        print(f"Skipping {curr_file}: no cam1 labels seen yet, so the output directory is unknown.")
        continue
    try:
        curr_df = pd.read_hdf(curr_file)
    except ValueError as err:
        print(f"Skipping unreadable label file {curr_file}: {err}")
        continue
    start_frame = start_frame_from_label(curr_df.index[0])
    # NOTE(review): assumes the labelled frames are contiguous from start_frame.
    frame_index = range(start_frame, start_frame + len(curr_df.index))
    curr_df.index = frame_index
    curr_df.to_hdf(os.path.join(out_dir, f"cam{cam_idx}.h5"), "df_with_missing", format="table", mode="w")

In [3]:
def load_dlc_points_as_df(df_fpaths):
    """Load per-camera DeepLabCut label files into one long-format DataFrame.

    Each file name must end (before the extension) in the 1-based camera
    number, e.g. ``cam2.h5``; it is stored 0-based in the ``camera`` column.
    Each file's wide table is reshaped so there is one row per
    (frame, bodypart). NOTE(review): assumes the column levels are
    (scorer, bodyparts, coordinate) as written by DeepLabCut — confirm.

    Args:
        df_fpaths: Iterable of paths to per-camera ``.h5`` files.

    Returns:
        DataFrame with columns ``['frame', 'camera', 'marker', 'x', 'y']``.
        It is empty (with those columns) when ``df_fpaths`` is empty. Any
        other coordinate columns (e.g. a likelihood) are dropped.

    Raises:
        ValueError: if a file name does not end in a camera number.
    """
    out_cols = ['frame', 'camera', 'marker', 'x', 'y']
    per_camera = []
    for path in df_fpaths:
        stem = os.path.splitext(os.path.basename(path))[0]
        digits = stem[len(stem.rstrip("0123456789")):]
        if not digits:
            raise ValueError(f"cannot infer a camera number from file name {path!r}")
        wide_df = pd.read_hdf(path)
        # Drop the scorer level, then pivot so rows are (frame, bodypart) and
        # columns are the coordinate names; the unnamed frame level becomes 'level_0'.
        long_df = (
            wide_df.droplevel([0], axis=1)
            .swaplevel(0, 1, axis=1)
            .T.unstack().T
            .reset_index()
            .rename({'level_0': 'frame', 'bodyparts': 'marker'}, axis=1)
        )
        long_df.columns.name = ''
        long_df['camera'] = int(digits) - 1
        per_camera.append(long_df)

    if not per_camera:
        return pd.DataFrame(columns=out_cols)
    # One concat instead of growing a frame inside the loop (avoids quadratic copying
    # and the object-dtype contamination from concatenating with an empty frame).
    return pd.concat(per_camera, sort=True, ignore_index=True)[out_cols]

In [4]:
# Convert each sequence's per-camera HDF5 files (written by the re-indexing cell)
# into one long-format CSV named after the sequence directory.
DLC_SEQUENCES = (
    "2019_03_09LilyFlick",
    "2019_03_09JulesFlick2",
    "2017_12_16PhantomFlick2_1",
    "2017_09_03ZorroFlick1_1",
)
dlc_dirs = tuple(os.path.join(gt_root_dir, seq_name) for seq_name in DLC_SEQUENCES)
for dlc_dir in dlc_dirs:
    dlc_fpaths = sorted(glob.glob(os.path.join(dlc_dir, "*.h5")))
    points_df = load_dlc_points_as_df(dlc_fpaths)
    ret_name = os.path.basename(dlc_dir)
    print(f"Saving...{ret_name}")
    points_df.to_csv(os.path.join(dlc_dir, f"{ret_name}.csv"))

Saving...2019_03_09LilyFlick
Saving...2019_03_09JulesFlick2
Saving...2017_12_16PhantomFlick2_1
Saving...2017_09_03ZorroFlick1_1


In [None]:
def plot_reprojection_error(data: Dict, results_dir: str, show_plot=False) -> None:
    """Plot the weighted per-camera reprojection error of every marker over frames.

    Args:
        data: Mapping with ``"start_frame"`` (first frame number), ``"meas_err"``
            (axis 1 indexes the camera) and ``"meas_weight"`` (weights that are
            broadcast against ``meas_err`` after adding a trailing axis).
            NOTE(review): ``meas_err`` is presumably (frames, cams, markers, 2) — confirm.
        results_dir: Directory the PNG is written to when ``show_plot`` is False.
        show_plot: Show the figure interactively instead of saving it.

    Raises:
        ValueError: if ``meas_err`` has no cameras along axis 1.
    """
    start_frame = data["start_frame"]
    meas_err = data["meas_err"]
    meas_weight = data["meas_weight"]

    x_axis_range = range(start_frame, start_frame + len(meas_weight))

    num_cams = meas_err.shape[1]
    if num_cams == 0:
        raise ValueError("meas_err has no cameras along axis 1; nothing to plot")
    # Add a trailing axis so the weights broadcast against meas_err.
    meas_weight = np.expand_dims(meas_weight, 3) if len(meas_weight.shape) > 2 else np.expand_dims(meas_weight, 2)
    weighted_meas_err = meas_weight * meas_err
    # For each camera, average its weighted error over axis 2 of the per-camera slice.
    xy_filtered_meas_err = np.array([np.mean(weighted_meas_err[:, cam_idx], axis=2) for cam_idx in range(num_cams)])

    markers = misc.get_markers()
    marker_colors = cm.jet(np.linspace(0, 1, len(markers)))
    # Colour cycle is applied per axes rather than via the global rcParams, so
    # later plots in the session are not affected.
    color_cycle = cycler.cycler('color', marker_colors)

    fig = plt.figure(figsize=(16, 12), dpi=120)
    fig.suptitle("Reprojection Error (After Filtering and Scaling)", fontsize=14)
    # Two columns; at least three rows so <= 6 cameras keep the original 3x2 layout.
    num_rows = max(3, -(-num_cams // 2))
    plotted_values = []
    for idx in range(num_cams):
        cam_err = xy_filtered_meas_err[idx, :, :]
        ax = fig.add_subplot(num_rows, 2, idx + 1)
        ax.set_prop_cycle(color_cycle)
        ax.set_title(
            f"CAM {idx+1} (\u03BC: {np.mean(cam_err):.2f}, \u03C3: {np.std(cam_err):.2f})"
        )
        plotted_values = ax.plot(x_axis_range, cam_err, marker="o", markersize=2)

    # Set common labels; the legend reuses the line handles of the last camera.
    fig.legend(plotted_values, markers, loc=(0.91, 0.4))
    fig.text(0.5, 0.04, "Frame Number", ha='center', va='center')
    fig.text(0.06, 0.5, "Error [pixels]", ha='center', va='center', rotation='vertical')

    if show_plot:
        plt.show()
    else:
        fig.savefig(os.path.join(results_dir, "fte_meas_error_filtered.png"))
        plt.close(fig)