Add typing support to ModuleList and ModuleDict #80821
Comments
ModuleListValue would only be useful for homogeneous module lists, but it's not clear to me how often you are actually going to have a homogeneous list. Maybe we need ModuleTuple lol
In the C++ API, this would be nice to have: if you know all the modules in the list will be of the same type, a typed … Of course, this is only relevant to the C++ API, which can check the types of the modules in the list at compile time, so maybe I should open a new issue.
I'd argue that generic ModuleList and ModuleDict are fairly useful. Personally, I often have a parent module containing multiple submodules of the same kind. It is very annoying that typing them as …

```python
class Parent(nn.Module):
    children: Mapping[str, SubModule]  # but really is a ModuleDict
```
@ssnl Maybe, as a better workaround, you can use …
An alternative could be to have some kind of … Here is a very schematic implementation:

```python
from collections.abc import Mapping
from typing import Generic, Type, TypeVar, overload, reveal_type  # reveal_type needs Python 3.11+


class Module:
    pass


# Module must be defined before it can be used as the TypeVar bound.
Tco = TypeVar("Tco", bound=Module, covariant=True)
T = TypeVar("T")


class TypedKey(Generic[T, Tco]):
    """A dict key that also records the type of the module stored under it."""

    def __init__(self, type_: Type[Tco], name: T):
        self.type_ = type_
        self.name = name


class ModuleDict:
    def __init__(self, data: Mapping[str | TypedKey[str, Module], Module]):
        self.data = data

    @overload
    def __getitem__(self, key: str) -> Module: ...
    @overload
    def __getitem__(self, key: TypedKey[str, Tco]) -> Tco: ...
    def __getitem__(self, key: str | TypedKey[str, Tco]) -> Module | Tco:
        return self.data[key]


class Mod1(Module):
    pass


class Mod2(Module):
    pass


if __name__ == "__main__":
    d = ModuleDict({"mod1": Mod1(), "mod2": Mod2()})
    reveal_type(d["mod1"])  # Revealed type is Module

    kmod1 = TypedKey(Mod1, "mod1")
    kmod2 = TypedKey(Mod2, "mod2")
    d2 = ModuleDict({kmod1: Mod1(), kmod2: Mod2()})
    reveal_type(d2[kmod1])  # Revealed type is Mod1
    reveal_type(d2[kmod2])  # Revealed type is Mod2
```

The advantage is that …
That is exactly what I need. Please consider integrating it.
🚀 The feature, motivation and pitch
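Currently, the containers `nn.ModuleList` and `nn.ModuleDict` are typing-unaware: given, for example, a `ModuleList` of `nn.Linear` layers stored in `self.my_modules`, `self.my_modules[i]` is treated as `nn.Module`, not as `nn.Linear`. As a result, VSCode complains about snippets like `function_that_expects_tensor(self.my_modules[i].weight)`, because `nn.Module.__getattr__` is annotated to return `Union[Tensor, Module]`, so it thinks that `.weight` can be either a `Tensor` or an `nn.Module`.

What I propose: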
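As an illustration of the idea (not the original proposal's code), here is a minimal sketch of a generic `ModuleList` parameterized by a `ModuleListValue` type variable, the name used in the discussion above. It assumes stand-in `Module` and `Linear` classes in place of `torch.nn.Module` and `nn.Linear` so that it runs without PyTorch; it is not PyTorch's actual API.

```python
from __future__ import annotations

from collections.abc import Iterable, Iterator
from typing import Generic, TypeVar, overload


class Module:
    """Stand-in for torch.nn.Module (assumption: no PyTorch needed)."""


class Linear(Module):
    """Stand-in for nn.Linear, reduced to a `weight` attribute."""

    def __init__(self, weight: float) -> None:
        self.weight = weight


# Element type of the list, bounded by the module base class.
ModuleListValue = TypeVar("ModuleListValue", bound=Module)


class ModuleList(Generic[ModuleListValue]):
    """Sketch of a ModuleList whose items all share one static type."""

    def __init__(self, modules: Iterable[ModuleListValue] = ()) -> None:
        self._modules: list[ModuleListValue] = list(modules)

    @overload
    def __getitem__(self, idx: int) -> ModuleListValue: ...
    @overload
    def __getitem__(self, idx: slice) -> ModuleList[ModuleListValue]: ...
    def __getitem__(self, idx: int | slice) -> ModuleListValue | ModuleList[ModuleListValue]:
        # Indexing returns the element type; slicing returns a typed sub-list.
        if isinstance(idx, slice):
            return ModuleList(self._modules[idx])
        return self._modules[idx]

    def __iter__(self) -> Iterator[ModuleListValue]:
        return iter(self._modules)

    def __len__(self) -> int:
        return len(self._modules)

    def append(self, module: ModuleListValue) -> None:
        self._modules.append(module)


layers = ModuleList([Linear(1.0), Linear(2.0)])  # inferred as ModuleList[Linear]
first_weight = layers[0].weight  # type-checks: layers[0] is Linear, not Module
```

With a homogeneous list, a type checker resolves `layers[0]` to `Linear`, so `.weight` is no longer ambiguous. A list mixing different module types would be inferred with a common base class or a union, depending on the checker, which is why this mainly helps the homogeneous case.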
For `nn.ModuleDict`, it is a bit more complicated, since there are two different patterns:

(A) `dict`-like: the set of keys is not fixed, and all values are modules of the same type (e.g. `nn.Linear`);
(B) `TypedDict`-like: the set of keys is fixed, and the values can be of different types (e.g. `{'linear': nn.Linear(...), 'relu': nn.ReLU()}`).

(A) can be implemented similarly to the previous example:
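One way to sketch pattern (A), again assuming stand-in `Module`, `Linear` and `ReLU` classes instead of the real `torch.nn` ones, is a `ModuleDict` that subclasses `collections.abc.MutableMapping` and is generic in a `ModuleDictValue` type variable. This is an illustration, not PyTorch's implementation.

```python
from __future__ import annotations

from collections.abc import Iterator, Mapping, MutableMapping
from typing import Generic, TypeVar


class Module:
    """Stand-in for torch.nn.Module (assumption: no PyTorch needed)."""


class Linear(Module):
    """Stand-in for nn.Linear, reduced to a `weight` attribute."""

    def __init__(self, weight: float) -> None:
        self.weight = weight


class ReLU(Module):
    """Stand-in for nn.ReLU."""


# Value type shared by all entries; bounded by the module base class.
ModuleDictValue = TypeVar("ModuleDictValue", bound=Module)


class ModuleDict(MutableMapping[str, ModuleDictValue], Generic[ModuleDictValue]):
    """Sketch of pattern (A): arbitrary string keys, one static value type."""

    def __init__(self, modules: Mapping[str, ModuleDictValue] | None = None) -> None:
        self._modules: dict[str, ModuleDictValue] = dict(modules or {})

    # The five methods below are the abstract methods MutableMapping requires;
    # get, items, update, pop, etc. are inherited as mixins.
    def __getitem__(self, key: str) -> ModuleDictValue:
        return self._modules[key]

    def __setitem__(self, key: str, module: ModuleDictValue) -> None:
        self._modules[key] = module

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    def __len__(self) -> int:
        return len(self._modules)


# Pattern (A): homogeneous values, so heads["a"] is typed as Linear.
heads: ModuleDict[Linear] = ModuleDict({"a": Linear(1.0), "b": Linear(2.0)})

# Limited pattern (B): mixed values, so lookups fall back to Module.
mixed: ModuleDict[Module] = ModuleDict({"linear": Linear(3.0), "relu": ReLU()})
```

Reusing the standard `MutableMapping` ABC keeps the sketch small, since only five methods have to be written. The `mixed` case shows the limitation: every lookup is typed as the base `Module`, so per-key types are lost.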
In fact, this can cover (B) as well, in a very limited way (by setting `ModuleDictValue=nn.Module`). It is unclear to me how to implement a fully functioning (B): it looks like we would need something like a `TypedMutableMapping`, but there is no such thing in `typing`. So I would start with `MutableMapping` and add a `TypedModuleDict` when it becomes technically possible.

Alternatives
No response
Additional context
No response
cc @ezyang @malfet @rgommers @xuzhao9 @gramster