In [1]:
import os
import pickle
from tqdm import tqdm

In [2]:
def load_data(path: str, keep_output_type: str | None = None, verbose: bool = False, errors: str = 'raise') -> dict[str, list]:
    """
    Load data from a file or directory of files

    Parameters
    ----------
    path : str
        The path to the file or directory of files
    keep_output_type : str, optional
        The type of the output, by default None (all types are kept)
    verbose : bool, optional
        Whether to show a progress bar, by default False

    Returns
    -------
    dict
        A dictionary of the data
    """
    equivalent_keys = {
        'trajectories': ['traj'],
        'videos': ['render'],
    }

    if os.path.isfile(path):
        if verbose:
            print('Loading data from file...')
        with open(path, 'rb') as file:
            data = pickle.load(file)
    else:
        data = {}
        pbar = tqdm(desc='Loading data for directory', disable=not verbose, total=sum([len(files) for _, _, files in os.walk(path)]))
        for root, _, files in os.walk(path):
            for file in sorted(files):
                pbar.set_postfix(file=file)
                # Read the file
                with open(os.path.join(root, file), 'rb') as f:
                    file_data = pickle.load(f)
    
                    # Rename the keys to their canonical names
                    for key, equivalent in equivalent_keys.items():
                        for e in equivalent:
                            if e in file_data:
                                file_data[key] = file_data.pop(e)

                    # Add the data to the dictionary
                    for key, value in file_data.items():
                        if key not in data:
                            data[key] = []
                        data[key].extend(value)
                pbar.update(1)

    # Keep only the specified output type (given that it is a key of equivalent_keys)
    if keep_output_type is not None and keep_output_type in equivalent_keys:
        for key in equivalent_keys.keys():

            # Remove every key that is not specified
            if key != keep_output_type and key in data:
                data.pop(key)

    # Check if all values have equal length
    value_lengths = [len(v) for v in data.values()]
    if len(set(value_lengths)) != 1:
        if errors == 'raise':
            raise ValueError(f'All values of the key "{key}" must have the same length')
        elif errors in ['print', 'warn']:
            print(f'Warning: All values of the key "{key}" must have the same length')
        elif errors == 'ignore':
            pass

    return data

In [3]:
from bcnf.utils import get_dir

In [4]:
# Load the 'fixed_data_render_2s_15FPS' dataset with a progress bar, keeping only
# the 'trajectories' output type (the 'videos' entry is dropped by load_data).
# NOTE(review): get_dir presumably resolves the path data/bcnf-data/<dataset>
# inside the project - confirm against bcnf.utils.
data = load_data(get_dir('data', 'bcnf-data', 'fixed_data_render_2s_15FPS'), keep_output_type='trajectories', verbose=True)

Loading data for directory: 100%|██████████| 5/5 [00:22<00:00,  4.40s/it, file=fixed_data_render_2s_15FPS_5.pkl]
