In [110]:
from datasets.hand_pose_dataset import HandPoseDataset
from os.path import join, isdir
from os import listdir
import yaml
import pandas as pd

In [111]:
checkpoints_path = join("checkpoints", "ablation")
subject_ids = HandPoseDataset._get_subject_ids()
data = {}
metrics = []
for checkpoint_name in sorted(listdir(checkpoints_path)):
    # check if the checkpoint has complete experiments
    subjects_in_checkpoint = []
    for folder in listdir(join(checkpoints_path, checkpoint_name)):
        if folder in subject_ids and isdir(join(checkpoints_path, checkpoint_name, folder)) and "metrics.yaml" in listdir(join(checkpoints_path, checkpoint_name, folder)):
            subjects_in_checkpoint.append(folder)
    # skips empty checkpoints
    if len(subjects_in_checkpoint) <= 1:
        continue
    # loads the cfg used
    with open(join(checkpoints_path, checkpoint_name, "cfg.yaml")) as f:
        cfg = yaml.safe_load(f)
    # initialize the data dict for this checkpoint
    metrics_per_run = []
    # loops over each subject
    for subject_id in subjects_in_checkpoint:
        # load the metrics for this subject
        with open(
            join(checkpoints_path, checkpoint_name, subject_id, "metrics.yaml")
        ) as f:
            metrics_per_run.append(yaml.safe_load(f)[0])
    # compute the mean metrics for this checkpoint
    metrics_per_run = pd.DataFrame(metrics_per_run)
    metrics_per_run_mean, metrics_per_run_std = (
        metrics_per_run.mean(),
        metrics_per_run.std(),
    )
    # parses some numbers
    for col in metrics_per_run.columns:
        for df in [metrics_per_run_mean, metrics_per_run_std]:
            if col.startswith("cls_"):
                df[col] = (df[col] * 100).round(3)
            elif col.startswith("num_params"):
                df[col] = (df[col] / 1e3).astype(int)
            elif col.startswith("time_test"):
                df[col] = (df[col] * 1e3).round(3)
    metrics_per_run = (
        metrics_per_run_mean.astype(str) + " ± " + metrics_per_run_std.astype(str)
    )
    metrics_per_run["image_backbone"] = cfg["image_backbone_name"]
    metrics_per_run["landmarks_backbone"] = cfg["landmarks_backbone_name"]
    # append to the metrics list
    metrics.append(metrics_per_run)

# parses the metrics into a DataFrame
metrics = pd.DataFrame(metrics)
metrics = metrics[
    [
        "landmarks_backbone",
        "image_backbone",
        "cls_acc_test",
        "cls_f1_test",
        "cls_prec_test",
        "cls_rec_test",
        "cls_loss_test",
        "num_params_test",
        "time_test",
    ]
]
metrics = metrics.rename(
    columns={
        "image_backbone": "Image backbone",
        "landmarks_backbone": "Landmarks backbone",
        "cls_acc_test": "Accuracy",
        "cls_f1_test": "F1",
        "cls_prec_test": "Precision",
        "cls_rec_test": "Recall",
        "cls_loss_test": "Loss",
        "num_params_test": "# Params (k)",
        "time_test": "Time (ms)",
    }
)
metrics = metrics.sort_values(by=["Landmarks backbone", "Image backbone"])
metrics = metrics.reset_index(drop=True)
print(metrics.to_latex())

\begin{tabular}{llllllllll}
\toprule
 & Landmarks backbone & Image backbone & Accuracy & F1 & Precision & Recall & Loss & # Params (k) & Time (ms) \\
\midrule
0 & linear & clip-b & 77.049 ± 7.865 & 75.661 ± 8.868 & 77.294 ± 8.584 & 77.049 ± 7.865 & 86.194 ± 35.94 & 91012.0 ± 0.0 & 2.509 ± 0.025 \\
1 & linear & convnextv2-t & 84.153 ± 5.037 & 83.522 ± 5.724 & 84.44 ± 5.791 & 84.153 ± 5.037 & 50.039 ± 19.381 & 31422.0 ± 0.0 & 2.999 ± 0.038 \\
2 & linear & NaN & 76.359 ± 8.393 & 75.039 ± 8.935 & 77.125 ± 8.894 & 76.359 ± 8.393 & 82.447 ± 33.256 & 1982.0 ± 0.0 & 0.307 ± 0.004 \\
3 & mlp & convnextv2-t & 85.544 ± 6.34 & 84.666 ± 7.154 & 85.841 ± 6.855 & 85.544 ± 6.34 & 51.301 ± 30.109 & 33616.0 ± 0.0 & 3.036 ± 0.024 \\
4 & mlp & resnet18 & 79.085 ± 8.776 & 77.505 ± 9.746 & 79.18 ± 9.029 & 79.085 ± 8.776 & 84.801 ± 45.585 & 16402.0 ± 0.0 & 1.334 ± 0.004 \\
5 & mlp & NaN & 78.835 ± 9.721 & 77.677 ± 10.06 & 80.411 ± 8.823 & 78.835 ± 9.721 & 84.346 ± 41.366 & 4176.0 ± 0.0 & 0.357 ± 0.004 \\
6 &