In [9]:
from datasets.hand_pose_dataset import HandPoseDataset
from os.path import join, isdir
from os import listdir
import yaml
import pandas as pd

In [10]:
checkpoints_path = join("checkpoints", "ablation")
subject_ids = HandPoseDataset._get_subject_ids()
data = {}
metrics = []
for checkpoint_name in sorted(listdir(checkpoints_path)):
    # check if the checkpoint has complete experiments
    subjects_in_checkpoint = []
    for folder in listdir(join(checkpoints_path, checkpoint_name)):
        if folder in subject_ids and isdir(join(checkpoints_path, checkpoint_name, folder)) and "metrics.yaml" in listdir(join(checkpoints_path, checkpoint_name, folder)):
            subjects_in_checkpoint.append(folder)
    # skips empty checkpoints
    if len(subjects_in_checkpoint) <= 1:
        continue
    # loads the cfg used
    with open(join(checkpoints_path, checkpoint_name, "cfg.yaml")) as f:
        cfg = yaml.safe_load(f)
    # initialize the data dict for this checkpoint
    metrics_per_run = []
    # loops over each subject
    for subject_id in subjects_in_checkpoint:
        # load the metrics for this subject
        with open(
            join(checkpoints_path, checkpoint_name, subject_id, "metrics.yaml")
        ) as f:
            metrics_per_run.append(yaml.safe_load(f)[0])
    # compute the mean metrics for this checkpoint
    metrics_per_run = pd.DataFrame(metrics_per_run)
    metrics_per_run_mean, metrics_per_run_std = (
        metrics_per_run.mean(),
        metrics_per_run.std(),
    )
    # parses some numbers
    for col in metrics_per_run.columns:
        for df in [metrics_per_run_mean, metrics_per_run_std]:
            if col.startswith("cls_"):
                df[col] = (df[col] * 100).round(3)
            elif col.startswith("num_params"):
                df[col] = (df[col] / 1e3).astype(int)
            elif col.startswith("time_test"):
                df[col] = (df[col] * 1e3).round(3)
    metrics_per_run = (
        metrics_per_run_mean.astype(str) + " ± " + metrics_per_run_std.astype(str)
    )
    metrics_per_run["image_backbone"] = cfg["image_backbone_name"]
    metrics_per_run["landmarks_backbone"] = cfg["landmarks_backbone_name"]
    # append to the metrics list
    metrics.append(metrics_per_run)

# parses the metrics into a DataFrame
metrics = pd.DataFrame(metrics)
metrics = metrics[
    [
        "landmarks_backbone",
        "image_backbone",
        "cls_acc_test",
        "cls_f1_test",
        "cls_prec_test",
        "cls_rec_test",
        "cls_loss_test",
        "num_params_test",
        "time_test",
    ]
]
metrics = metrics.rename(
    columns={
        "image_backbone": "Image backbone",
        "landmarks_backbone": "Landmarks backbone",
        "cls_acc_test": "Accuracy",
        "cls_f1_test": "F1",
        "cls_prec_test": "Precision",
        "cls_rec_test": "Recall",
        "cls_loss_test": "Loss",
        "num_params_test": "# Params (k)",
        "time_test": "Time (ms)",
    }
)
metrics = metrics.sort_values(by=["Landmarks backbone", "Image backbone"])
metrics = metrics.reset_index(drop=True)
print(metrics.to_latex())

\begin{tabular}{llllllllll}
\toprule
 & Landmarks backbone & Image backbone & Accuracy & F1 & Precision & Recall & Loss & # Params (k) & Time (ms) \\
\midrule
0 & linear & clip-b & 77.049 ± 7.865 & 75.661 ± 8.868 & 77.294 ± 8.584 & 77.049 ± 7.865 & 86.194 ± 35.94 & 91012.0 ± 0.0 & 2.509 ± 0.025 \\
1 & linear & convnextv2-b & 84.915 ± 6.148 & 84.24 ± 6.832 & 85.03 ± 6.775 & 84.915 ± 6.148 & 48.733 ± 21.113 & 91773.0 ± 0.0 & 4.659 ± 0.045 \\
2 & linear & convnextv2-t & 84.153 ± 5.037 & 83.522 ± 5.724 & 84.44 ± 5.791 & 84.153 ± 5.037 & 50.039 ± 19.381 & 31422.0 ± 0.0 & 2.999 ± 0.038 \\
3 & linear & dinov2-b & 82.302 ± 5.348 & 81.629 ± 5.766 & 83.216 ± 5.905 & 82.302 ± 5.348 & 60.533 ± 26.573 & 90136.0 ± 0.0 & 2.705 ± 0.017 \\
4 & linear & dinov2-s & 80.444 ± 6.022 & 78.871 ± 6.828 & 80.001 ± 6.891 & 80.444 ± 6.022 & 76.027 ± 34.229 & 24826.0 ± 0.0 & 2.742 ± 0.026 \\
5 & linear & resnet18 & 76.335 ± 8.5 & 74.99 ± 9.236 & 76.496 ± 8.772 & 76.335 ± 8.5 & 86.833 ± 39.585 & 14208.0 ± 0.0 & 1.3