In [1]:
import h5py
import pandas as pd

In [8]:
# Location of the HDF5 metrics file to inspect. The value below is a
# placeholder; replace it with the path to a real file before running.
h5_path = "/path/to/dreem-metrics.h5"

In [3]:
def extract_motmetrics(hdf5_path):
    """Extract the MOT summary metrics stored in an HDF5 results file.

    Every top-level group of the file is checked for a ``mot_summary``
    child; the attributes of that child are read as metric name/value pairs.

    Args:
        hdf5_path: Path to the HDF5 file. It is opened read-only.

    Returns:
        A DataFrame with the columns ``metric`` and ``value``, one row per
        attribute of ``mot_summary``. If several top-level groups carry a
        ``mot_summary``, the one visited last wins. If none does, an empty
        DataFrame with the same two columns is returned.
    """
    columns = ["metric", "value"]
    # Start from an empty frame so that a file without any MOT summary yields
    # an empty result instead of an UnboundLocalError at the return statement.
    df_motmetrics = pd.DataFrame(columns=columns)

    with h5py.File(hdf5_path, "r") as results_file:
        for vid_name in results_file.keys():
            vid_group = results_file[vid_name]
            if "mot_summary" not in vid_group:
                continue
            summary_attrs = vid_group["mot_summary"].attrs
            # Iterating the attribute collection yields the attribute names.
            df_motmetrics = pd.DataFrame(
                [(key, summary_attrs[key]) for key in summary_attrs],
                columns=columns,
            )

    return df_motmetrics


def extract_gta(hdf5_path):
    """Extract the global tracking accuracy stored in an HDF5 results file.

    Every top-level group of the file is checked for a
    ``global_tracking_accuracy`` child; the attributes of that child are read
    as metric name/value pairs.

    Args:
        hdf5_path: Path to the HDF5 file. It is opened read-only.

    Returns:
        A DataFrame with the columns ``metric`` and ``value``, one row per
        attribute of ``global_tracking_accuracy``. If several top-level
        groups carry that child, the one visited last wins. If none does, an
        empty DataFrame with the same two columns is returned.
    """
    columns = ["metric", "value"]
    # Start from an empty frame so that a file without the (optional) global
    # tracking accuracy yields an empty result instead of an
    # UnboundLocalError at the return statement.
    df_gta = pd.DataFrame(columns=columns)

    with h5py.File(hdf5_path, "r") as results_file:
        for vid_name in results_file.keys():
            vid_group = results_file[vid_name]
            if "global_tracking_accuracy" not in vid_group:
                continue
            gta_attrs = vid_group["global_tracking_accuracy"].attrs
            # Iterating the attribute collection yields the attribute names.
            df_gta = pd.DataFrame(
                [(key, gta_attrs[key]) for key in gta_attrs],
                columns=columns,
            )

    return df_gta


def extract_switch_frame_crops(hdf5_path):
    """Extract the per-instance crops stored for each frame group.

    Within a top-level (video) group, every child whose name starts with
    ``frame_`` is treated as a frame, and every child of that frame whose
    name starts with ``instance_`` and that contains a ``crop`` entry
    contributes one crop.

    Args:
        hdf5_path: Path to the HDF5 file. It is opened read-only.

    Returns:
        A dict mapping the frame's ``frame_id`` attribute to a list of crops,
        one per instance with a ``crop`` entry. Each crop is read in full,
        squeezed, and transposed. Frames without any crop map to an empty
        list. A file with no top-level groups yields an empty dict.

    NOTE(review): presumably only frames with identity switches are written
    to the file -- confirm against the code that produces it.
    """
    # Bound up front so that a file with no top-level groups returns {}
    # instead of raising UnboundLocalError at the return statement.
    frame_crop_dict = {}

    with h5py.File(hdf5_path, "r") as results_file:
        for vid_name in results_file.keys():
            vid_group = results_file[vid_name]
            # The mapping is rebuilt for every video group, so the result
            # holds the frames of the last group visited.
            # NOTE(review): confirm that this is intended for multi-video files.
            frame_crop_dict = {}
            for frame_key in vid_group.keys():
                if not frame_key.startswith("frame_"):
                    continue
                frame = vid_group[frame_key]
                frame_id = frame.attrs["frame_id"]
                # Distinct loop names avoid shadowing the frame-level key.
                frame_crop_dict[frame_id] = [
                    frame[instance_key]["crop"][:].squeeze().transpose()
                    for instance_key in frame.keys()
                    if instance_key.startswith("instance_")
                    and "crop" in frame[instance_key].keys()
                ]

    return frame_crop_dict

In [None]:
# Run each extractor on the metrics file configured in `h5_path`.
# (The previous name `hdf5_path` is only a parameter name inside the
# functions and is never defined at notebook scope.)
motmetrics = extract_motmetrics(h5_path)
gta = extract_gta(h5_path)
switch_frame_crops = extract_switch_frame_crops(h5_path)