Support empty batches for arbitrary dataset structures #534

Open
ffuuugor opened this issue Oct 31, 2022 · 0 comments


For context, see the discussion in #530 (and thanks @joserapa98 for pointing out the issue).

At the moment (to be precise, once #530 is merged), Opacus supports empty batches only for datasets with a simple structure: every record must be a tuple whose elements are each either a tensor or a primitive type.

For instance, datasets with records like (Tensor, int) or (Tensor, Tensor) are supported. However, datasets with records like (Tensor, (int, int)) are not.

PyTorch addresses a similar problem with the following piece of code:

if isinstance(elem, collections.abc.Mapping):
    try:
        return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
    except TypeError:
        # The mapping type may not support `__init__(iterable)`.
        return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
    return elem_type(*(collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
    # check to make sure that the elements in batch have consistent size
    it = iter(batch)
    elem_size = len(next(it))
    if not all(len(elem) == elem_size for elem in it):
        raise RuntimeError('each element in list of batch should be of equal size')
    transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.


    if isinstance(elem, tuple):
        return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
    else:
        try:
            return elem_type([collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]

We need to adapt it to our needs and make sure DPDataLoader can handle datasets of arbitrary structure.
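
One possible direction, as a rough sketch rather than a proposed implementation: walk a single sample record with the same type dispatch as the PyTorch snippet above, and replace every leaf with an "empty" placeholder. The names `empty_batch_like` and `make_leaf` are hypothetical and do not exist in Opacus. The sketch is pure Python so it runs without PyTorch; for a tensor leaf, `make_leaf` could return something like `torch.zeros((0, *value.shape), dtype=value.dtype)`.

```python
import collections.abc
from collections import namedtuple
from typing import Any, Callable


def empty_batch_like(elem: Any, make_leaf: Callable[[Any], Any]) -> Any:
    """Mirror the nested structure of one sample, replacing each leaf via make_leaf.

    Hypothetical helper: the dispatch order follows the PyTorch collate snippet.
    """
    if isinstance(elem, (str, bytes)):
        # Strings are Sequences too; treat them as leaves to avoid endless recursion.
        return make_leaf(elem)
    if isinstance(elem, collections.abc.Mapping):
        return {key: empty_batch_like(value, make_leaf) for key, value in elem.items()}
    if isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return type(elem)(*(empty_batch_like(value, make_leaf) for value in elem))
    if isinstance(elem, collections.abc.Sequence):
        # Like the collate code above, tuples and lists both come back as lists.
        return [empty_batch_like(value, make_leaf) for value in elem]
    return make_leaf(elem)


def list_leaf(value: Any) -> list:
    # Stand-in for a zero-length tensor; an assumption made for this illustration.
    return []


Point = namedtuple("Point", ["x", "y"])

# Floats and ints stand in for tensors and primitives here.
nested = empty_batch_like((1.5, (3, 4)), list_leaf)
mapped = empty_batch_like({"x": 1.5, "y": (3, 4)}, list_leaf)
named = empty_batch_like(Point(1, 2), list_leaf)
```

With a record shaped like (Tensor, (int, int)), the nested tuple is preserved in the result instead of being treated as a single leaf, which is the case the current approach cannot handle.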

Relevant code pointer:

def wrap_collate_with_empty(
