
Feature request: ModuleDict, like ModuleList #4048

@rkaplan


Self-explanatory. Currently there is no easy way to keep a dictionary of Modules as an attribute of a Module: Module does not look inside a plain dict, so the modules stored in it are never registered. A ModuleDict class that works like ModuleList but exposes a dictionary interface would solve this. It has the following practical benefits for users:

  1. Checkpointing is more robust with a ModuleDict keyed by strings than with a ModuleList. I have run into problems after changing the order of modules in a ModuleList and then loading old checkpoints: ModuleList checkpoint keys contain each module's index, so they change when the order changes, and the old checkpoints no longer load.

  2. It combines the advantages of a ModuleList (dynamic assignment of Modules to a class) with the advantages of using named attributes (documentation / readability / sane module names in checkpoints).
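Both points can be illustrated with a short sketch. It assumes a PyTorch release that ships `torch.nn.ModuleDict` (current releases do) and uses that built-in in place of the class proposed below; `loads_cleanly` and the `config` layer sizes are hypothetical. The ModuleList checkpoint stops loading once a layer is inserted, while the named entries keep stable keys even though they are built dynamically from a dict.

```python
import torch.nn as nn


def loads_cleanly(module, state):
    """Hypothetical helper: True if `state` loads with strict key matching."""
    try:
        module.load_state_dict(state)  # strict=True is the default
        return True
    except RuntimeError:  # raised on missing or unexpected keys
        return False


# Point 1: ModuleList keys are positional, so inserting a layer shifts them.
old_list = nn.ModuleList([nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)])
new_list = nn.ModuleList([nn.Dropout(0.1), nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)])
print(list(old_list.state_dict()))  # ['0.weight', '0.bias', '2.weight', '2.bias']
print(list(new_list.state_dict()))  # ['1.weight', '1.bias', '3.weight', '3.bias']
print(loads_cleanly(new_list, old_list.state_dict()))  # False

# Point 2: modules are still created dynamically (here from a config dict),
# but each one is stored under a readable name that becomes its checkpoint key.
config = {"encoder": (4, 8), "head": (8, 2)}  # hypothetical layer sizes
old_dict = nn.ModuleDict({name: nn.Linear(i, o) for name, (i, o) in config.items()})
new_dict = nn.ModuleDict({name: nn.Linear(i, o) for name, (i, o) in config.items()})
new_dict["dropout"] = nn.Dropout(0.1)  # a parameter-free layer adds no keys
print(sorted(old_dict.state_dict()))  # ['encoder.bias', 'encoder.weight', 'head.bias', 'head.weight']
print(loads_cleanly(new_dict, old_dict.state_dict()))  # True
```

With the ModuleList, the old checkpoint's keys would have to be remapped by hand; with named entries, adding the new layer leaves every existing key unchanged.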

Another advantage is that this is very simple to implement. Here is a first pass; if there is support for this feature, I will clean it up and submit it as a PR:

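from collections import abc

from torch.nn import Module


class ModuleDict(Module):
    r"""Holds submodules in a dict.

    ModuleDict can be indexed like a regular Python dict, but the modules it
    contains are properly registered and will be visible to all Module methods.

    Arguments:
        modules (dict, optional): a mapping of names (str) to modules to add

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.layers = ModuleDict({'conv': nn.Conv2d(1, 20, 5),
                                          'pool': nn.MaxPool2d(2)})

            def forward(self, x):
                return self.layers['pool'](self.layers['conv'](x))
    """

    def __init__(self, modules=None):
        super(ModuleDict, self).__init__()
        if modules is not None:
            self.update(modules)

    def __getitem__(self, key):
        return self._modules[key]

    def __setitem__(self, key, module):
        # add_module checks that the key is a usable module name (e.g. it
        # does not shadow an existing attribute such as `keys`) and then
        # registers the module under it
        self.add_module(key, module)

    def __delitem__(self, key):
        del self._modules[key]

    def __len__(self):
        return len(self._modules)

    def __iter__(self):
        return iter(self._modules)

    def __contains__(self, key):
        return key in self._modules

    def keys(self):
        return self._modules.keys()

    def items(self):
        return self._modules.items()

    def values(self):
        return self._modules.values()

    def get(self, key, default=None):
        # dict.get already returns None when no default is given
        return self._modules.get(key, default)

    def update(self, modules):
        r"""Updates modules from a mapping of names to modules.

        Arguments:
            modules (dict): mapping of names (str) to modules to add
        """
        if not isinstance(modules, abc.Mapping):
            raise TypeError("ModuleDict.update should be called with a "
                            "mapping, but got " + type(modules).__name__)
        for key, module in modules.items():
            self.add_module(key, module)
        return self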
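For comparison, current PyTorch releases ship `torch.nn.ModuleDict` with a similar dict-style interface. The sketch below uses that built-in (not the class proposed above) to show what "properly registered" buys over a plain dict attribute, and that keys pass through `Module.add_module`, which rejects names containing `.` or names that shadow an existing attribute. `PlainDictNet` and `ModuleDictNet` are hypothetical example classes.

```python
import torch.nn as nn


class PlainDictNet(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain dict is stored as an ordinary attribute; nn.Module does not
        # look inside it, so these layers are NOT registered.
        self.layers = {"conv": nn.Conv2d(1, 20, 5), "pool": nn.MaxPool2d(2)}


class ModuleDictNet(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleDict registers every value under its key.
        self.layers = nn.ModuleDict({"conv": nn.Conv2d(1, 20, 5), "pool": nn.MaxPool2d(2)})


print(len(list(PlainDictNet().parameters())))   # 0: an optimizer would see nothing
print(len(list(ModuleDictNet().parameters())))  # 2: conv weight and bias
print(list(ModuleDictNet().state_dict()))       # ['layers.conv.weight', 'layers.conv.bias']

layers = ModuleDictNet().layers
layers["act"] = nn.ReLU()  # registered like any other submodule
print(list(layers.keys()), "act" in layers)  # ['conv', 'pool', 'act'] True

# Names containing '.' or shadowing an attribute (the method `keys`) raise KeyError.
rejected = []
for bad_key in ["block.1", "keys"]:
    try:
        layers[bad_key] = nn.ReLU()
    except KeyError:
        rejected.append(bad_key)
print(rejected)  # ['block.1', 'keys']
```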
