Description
Self-explanatory. Currently there is no easy way to keep a dictionary as an attribute of a Module whose values are themselves Modules that need to be registered. Assigning a plain Python dict does not register its values, so their parameters are invisible to methods such as `parameters()` and `state_dict()`. A ModuleDict class that works like ModuleList but exposes a dictionary interface would resolve this. It has the following practical benefits for users:
- Checkpointing is much more robust with a ModuleDict that maps string keys to Modules than with a ModuleList. ModuleList checkpoint keys are built from each submodule's position (e.g. `layers.0.weight`), so they depend on order. I have personally run into this: after changing the order of modules in my ModuleList, old checkpoints no longer loaded, which broke my code.
- It combines the advantages of a ModuleList (dynamic assignment of Modules to a class) with the advantages of named attributes (documentation, readability, and sane module names in checkpoints).
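To make the ordering problem concrete, here is a small sketch that uses only existing PyTorch APIs (`nn.ModuleList`, `state_dict`, `load_state_dict`). The classes `ListNet` and `NamedNet`, the `swap` flag, and the layer sizes are hypothetical. It shows that ModuleList keys are positional, so reordering the list makes an old checkpoint fail to load, while named submodules are matched by name and load regardless of registration order.

```python
import torch.nn as nn


class ListNet(nn.Module):
    # Hypothetical model: submodules stored positionally in a ModuleList
    def __init__(self, swap=False):
        super().__init__()
        layers = [nn.Linear(4, 4), nn.Linear(4, 2)]
        if swap:
            layers.reverse()  # simulate reordering the list in newer code
        self.layers = nn.ModuleList(layers)


class NamedNet(nn.Module):
    # Hypothetical model: the same layers stored under names
    def __init__(self, swap=False):
        super().__init__()
        if swap:
            # Registered in the opposite order; keys are still "encoder"/"head"
            self.head = nn.Linear(4, 2)
            self.encoder = nn.Linear(4, 4)
        else:
            self.encoder = nn.Linear(4, 4)
            self.head = nn.Linear(4, 2)


def loads(old, new):
    """Return True if `new` accepts `old`'s state_dict, False on a RuntimeError."""
    try:
        new.load_state_dict(old.state_dict())
        return True
    except RuntimeError:
        return False


list_keys = list(ListNet().state_dict().keys())
named_keys = list(NamedNet().state_dict().keys())

print(list_keys)   # ['layers.0.weight', 'layers.0.bias', 'layers.1.weight', 'layers.1.bias']
print(named_keys)  # ['encoder.weight', 'encoder.bias', 'head.weight', 'head.bias']
# Reordered ModuleList: layers.0.weight is now (2, 4) instead of (4, 4)
print(loads(ListNet(), ListNet(swap=True)))    # False
print(loads(NamedNet(), NamedNet(swap=True)))  # True
```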
Another advantage is that this is very simple to implement. Here is a first pass; if there is support for this feature, I will clean it up and submit it as a PR:
```python
from torch.nn import Module


class ModuleDict(Module):
    r"""Holds submodules in a dict.

    ModuleDict can be indexed like a regular Python dict, but the modules it
    contains are properly registered, and will be visible to all Module methods.

    Arguments:
        modules (dict, optional): a dict mapping string keys to modules to add

    Example::
        TODO.
    """

    def __init__(self, modules=None):
        super(ModuleDict, self).__init__()
        if modules is not None:
            self.update(modules)

    def __getitem__(self, key):
        return self._modules[key]

    def __setitem__(self, key, module):
        # add_module validates the key (a non-empty string without ".",
        # not clashing with an existing attribute) and registers the module
        self.add_module(key, module)

    def __delitem__(self, key):
        del self._modules[key]

    def __len__(self):
        return len(self._modules)

    def __iter__(self):
        return iter(self._modules)

    def __contains__(self, key):
        return key in self._modules

    def keys(self):
        return self._modules.keys()

    def items(self):
        return self._modules.items()

    def values(self):
        return self._modules.values()

    def get(self, key, default=None):
        return self._modules.get(key, default)

    def update(self, modules):
        r"""Updates modules from a Python dict.

        Arguments:
            modules (dict): dict of modules to add; existing keys are overwritten
        """
        if not isinstance(modules, dict):
            raise TypeError("ModuleDict.update should be called with a "
                            "dict, but got " + type(modules).__name__)
        for key, module in modules.items():
            self.add_module(key, module)
        return self
```
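A sketch of how such a class would be used, and why a plain dict is not enough. It runs against `torch.nn.ModuleDict`, which current PyTorch releases ship with essentially the dict-style interface proposed here (assumption: a PyTorch version that includes it). `MultiHead`, `PlainDictHeads`, and the task names are hypothetical.

```python
import torch
import torch.nn as nn


class PlainDictHeads(nn.Module):
    # Hypothetical anti-pattern: a plain dict attribute is NOT registered,
    # so the Linear layers inside are invisible to parameters()/state_dict()
    def __init__(self):
        super().__init__()
        self.heads = {"classify": nn.Linear(16, 3), "regress": nn.Linear(16, 1)}


class MultiHead(nn.Module):
    # Hypothetical multi-task model: one head per task, looked up by name
    def __init__(self, task_dims):
        super().__init__()
        self.trunk = nn.Linear(8, 16)
        self.heads = nn.ModuleDict(
            {name: nn.Linear(16, dim) for name, dim in task_dims.items()}
        )

    def forward(self, x, task):
        return self.heads[task](torch.relu(self.trunk(x)))


model = MultiHead({"classify": 3, "regress": 1})
out = model(torch.randn(2, 8), "classify")

n_plain = sum(p.numel() for p in PlainDictHeads().parameters())
n_model = sum(p.numel() for p in model.parameters())
head_keys = sorted(k for k in model.state_dict() if k.startswith("heads."))

print(n_plain)    # 0: nothing registered
print(n_model)    # 212 = (8*16 + 16) + (16*3 + 3) + (16*1 + 1)
print(out.shape)  # torch.Size([2, 3])
print(head_keys)  # ['heads.classify.bias', 'heads.classify.weight', 'heads.regress.bias', 'heads.regress.weight']
```

Because the checkpoint keys come from the dict keys rather than positions, reordering the heads in the constructor does not change the saved key names.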