Add typing support to ModuleList and ModuleDict #80821

Open
Yura52 opened this issue Jul 3, 2022 · 6 comments
Labels
module: typing Related to mypy type annotations triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@Yura52
Contributor

Yura52 commented Jul 3, 2022

🚀 The feature, motivation and pitch

Currently, the containers nn.ModuleList and nn.ModuleDict are typing-unaware. Given this:

import torch.nn as nn

class A(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.my_modules = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])

self.my_modules[i] is treated as nn.Module, not as nn.Linear. For example, VSCode complains about snippets like function_that_expects_tensor(self.my_modules[i].weight), because it infers that .weight could be either a Tensor or an nn.Module.
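Until the containers are generic, one way to recover the concrete type at a use site is to narrow it explicitly, either with an isinstance assertion (checked at runtime and understood by type checkers) or with typing.cast (trusted by the type checker, no runtime check). A minimal sketch, assuming hypothetical stdlib-only stand-ins `Module`, `Linear`, and `ModuleList` in place of the torch.nn classes so it runs without PyTorch:

```python
from __future__ import annotations

from typing import cast


class Module:
    """Hypothetical stand-in for torch.nn.Module."""


class Linear(Module):
    """Hypothetical stand-in for torch.nn.Linear with a `weight` attribute."""

    def __init__(self, in_features: int, out_features: int) -> None:
        # A nested list stands in for the weight Tensor.
        self.weight = [[0.0] * in_features for _ in range(out_features)]


class ModuleList:
    """Hypothetical non-generic stand-in: items come back typed as the base Module."""

    def __init__(self, modules: list[Module]) -> None:
        self._modules = list(modules)

    def __getitem__(self, idx: int) -> Module:
        return self._modules[idx]

    def __len__(self) -> int:
        return len(self._modules)


def first_weight_via_isinstance(layers: ModuleList) -> list[list[float]]:
    layer = layers[0]
    # isinstance narrows Module -> Linear for the type checker and fails loudly at runtime.
    assert isinstance(layer, Linear)
    return layer.weight


def first_weight_via_cast(layers: ModuleList) -> list[list[float]]:
    # cast() only informs the type checker; it performs no runtime check.
    return cast(Linear, layers[0]).weight


layers = ModuleList([Linear(1, 1) for _ in range(10)])
print(first_weight_via_isinstance(layers))  # [[0.0]]
print(first_weight_via_cast(layers))  # [[0.0]]
```

Both functions repeat the narrowing at every use site, which is the boilerplate a generic container would remove.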

What I propose:

from collections.abc import MutableSequence
from typing import TypeVar

from torch.nn import Module

ModuleListValue = TypeVar('ModuleListValue', bound=Module)

class ModuleList(Module, MutableSequence[ModuleListValue]):
    # now, some methods can be typed, e.g.:
    def __getitem__(self, idx: int) -> ModuleListValue:
        ...
    ...
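The proposal can be fleshed out into a self-contained container to check how a type checker treats it. This is a sketch under stated assumptions, not PyTorch's ModuleList: `TypedModuleList`, `Module`, and `Linear` are hypothetical stand-ins, and no submodule registration happens. It implements the five abstract methods MutableSequence requires (`__getitem__`, `__setitem__`, `__delitem__`, `__len__`, `insert`) and gets `append`, `__iter__`, etc. from the mixin:

```python
from __future__ import annotations

from collections.abc import Iterable, MutableSequence
from typing import Any, TypeVar, overload


class Module:
    """Hypothetical stand-in for torch.nn.Module."""


class Linear(Module):
    """Hypothetical stand-in for torch.nn.Linear."""

    def __init__(self) -> None:
        self.weight = [[0.0]]  # stands in for a weight Tensor


ModuleListValue = TypeVar("ModuleListValue", bound=Module)


class TypedModuleList(MutableSequence[ModuleListValue]):
    """Hypothetical generic list container (not torch.nn.ModuleList)."""

    def __init__(self, modules: Iterable[ModuleListValue] = ()) -> None:
        self._modules: list[ModuleListValue] = list(modules)

    @overload
    def __getitem__(self, idx: int) -> ModuleListValue: ...
    @overload
    def __getitem__(self, idx: slice) -> TypedModuleList[ModuleListValue]: ...
    def __getitem__(self, idx: int | slice) -> ModuleListValue | TypedModuleList[ModuleListValue]:
        if isinstance(idx, slice):
            # Like nn.ModuleList, slicing returns a container (here of the same element type).
            return TypedModuleList(self._modules[idx])
        return self._modules[idx]

    @overload
    def __setitem__(self, idx: int, value: ModuleListValue) -> None: ...
    @overload
    def __setitem__(self, idx: slice, value: Iterable[ModuleListValue]) -> None: ...
    def __setitem__(self, idx: Any, value: Any) -> None:
        self._modules[idx] = value

    def __delitem__(self, idx: int | slice) -> None:
        del self._modules[idx]

    def __len__(self) -> int:
        return len(self._modules)

    def insert(self, index: int, value: ModuleListValue) -> None:
        self._modules.insert(index, value)


layers = TypedModuleList([Linear() for _ in range(3)])
layers.append(Linear())  # append() comes from the MutableSequence mixin
w = layers[0].weight  # a type checker infers layers[0] as Linear
head = layers[:2]  # TypedModuleList[Linear]
```

The int/slice overloads mirror the ones MutableSequence itself declares, so the subclass stays compatible with its base for type checkers.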

For nn.ModuleDict, it is a bit more complicated, since there are two different patterns:

  • (A) dict-like: the set of keys is not fixed, all values are modules of the same type (e.g. nn.Linear)
  • (B) TypedDict-like: the set of keys is fixed, the values can be of different types (e.g. {'linear': nn.Linear(...), 'relu': nn.ReLU()}).

(A) can be implemented similarly to the previous example:

...
class ModuleDict(Module, MutableMapping[str, ModuleDictValue]):
    ...

In fact, this can cover (B) as well, in a very limited way (by setting ModuleDictValue=nn.Module). However, it is unclear to me how to implement fully functional (B): it looks like we would need something like a TypedMutableMapping, but there is no such thing in typing. So I would start with MutableMapping and add a TypedModuleDict once it becomes technically possible.
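A minimal sketch of pattern (A), assuming hypothetical stdlib stand-ins (`TypedModuleDict`, `Module`, `Linear`, `ReLU`) rather than the real torch classes. MutableMapping requires `__getitem__`, `__setitem__`, `__delitem__`, `__iter__`, and `__len__`; the str-key check mirrors torch.nn.ModuleDict, whose keys become submodule names and must be strings:

```python
from __future__ import annotations

from collections.abc import Iterator, Mapping, MutableMapping
from typing import TypeVar


class Module:
    """Hypothetical stand-in for torch.nn.Module."""


class Linear(Module):
    """Hypothetical stand-in for torch.nn.Linear."""


class ReLU(Module):
    """Hypothetical stand-in for torch.nn.ReLU."""


ModuleDictValue = TypeVar("ModuleDictValue", bound=Module)


class TypedModuleDict(MutableMapping[str, ModuleDictValue]):
    """Hypothetical pattern-(A) container: arbitrary str keys, one value type."""

    def __init__(self, modules: Mapping[str, ModuleDictValue] | None = None) -> None:
        self._modules: dict[str, ModuleDictValue] = {}
        if modules is not None:
            self.update(modules)  # mixin method; routes through __setitem__

    def __getitem__(self, key: str) -> ModuleDictValue:
        return self._modules[key]

    def __setitem__(self, key: str, value: ModuleDictValue) -> None:
        if not isinstance(key, str):
            raise TypeError(f"key must be str, got {type(key).__name__}")
        self._modules[key] = value

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    def __len__(self) -> int:
        return len(self._modules)


heads = TypedModuleDict({"a": Linear(), "b": Linear()})
heads["c"] = Linear()
first = heads["a"]  # a type checker infers Linear

# Pattern (B) only in the limited way described above: values are just Module.
mixed: TypedModuleDict[Module] = TypedModuleDict({"linear": Linear(), "relu": ReLU()})
```

With `mixed`, a lookup such as `mixed["relu"]` is typed as plain Module, which is exactly the limitation of covering (B) with a single value type.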

Alternatives

No response

Additional context

No response

cc @ezyang @malfet @rgommers @xuzhao9 @gramster

@Yura52 Yura52 changed the title Make ModuleList and ModuleDict generic Add typing support to ModuleList and ModuleDict Jul 3, 2022
@zou3519 zou3519 added module: typing Related to mypy type annotations triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Jul 5, 2022
@ezyang
Contributor

ezyang commented Jul 6, 2022

ModuleListValue would only be useful for homogeneous module lists, but it's not clear to me how often you are actually going to have a homogeneous list. Maybe we need ModuleTuple lol
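To make the ModuleTuple idea concrete: with current typing features, one way a heterogeneous container could keep per-position types is a fixed-arity generic class with Literal-index overloads. `ModuleTuple2` and the stand-in classes are hypothetical; this sketch covers exactly two modules, and each extra position would need its own type variable and overload:

```python
from __future__ import annotations

from typing import Generic, Literal, TypeVar, overload


class Module:
    """Hypothetical stand-in for torch.nn.Module."""


class Linear(Module):
    """Hypothetical stand-in for torch.nn.Linear."""


class ReLU(Module):
    """Hypothetical stand-in for torch.nn.ReLU."""


A = TypeVar("A", bound=Module)
B = TypeVar("B", bound=Module)


class ModuleTuple2(Generic[A, B]):
    """Hypothetical fixed-length, heterogeneous container with per-position types."""

    def __init__(self, first: A, second: B) -> None:
        self._items: tuple[A, B] = (first, second)

    @overload
    def __getitem__(self, idx: Literal[0]) -> A: ...
    @overload
    def __getitem__(self, idx: Literal[1]) -> B: ...
    def __getitem__(self, idx: int) -> A | B:
        return self._items[idx]

    def __len__(self) -> int:
        return 2


block = ModuleTuple2(Linear(), ReLU())
lin = block[0]  # a type checker infers Linear
act = block[1]  # a type checker infers ReLU
```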

@fx-carton

In the C++ API, this would be nice to have: if you know all the modules in the list will be of the same type, a typed ModuleList would give you the convenience of ModuleList registering the modules, without the trouble of implementing .forward() over a generic ModuleList (either trying every known type with .as<Type>() or using AnyModule).

Of course, this is only relevant to the C++ API, which can check the types of the modules in the list at compile time, so maybe I should open a new issue.

@ssnl
Collaborator

ssnl commented Oct 13, 2022

I'd argue that generic ModuleList and ModuleDict are fairly useful. At least for me personally, I often have a parent module containing multiple submodules of the same kind. It is very annoying that typing them as ModuleDict loses the type information. The ugly workaround I use is:

class Parent(nn.Module):
    # avoid naming this `children`: that would shadow the nn.Module.children() method
    blocks: Mapping[str, SubModule]  # but really is a ModuleDict
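Spelled out, this workaround also needs the assignment itself to be accepted, since nn.ModuleDict is a Module subclass rather than a Mapping subclass; one option is typing.cast at the assignment site. A sketch with hypothetical stdlib stand-ins for `Module`, `ModuleDict`, and `SubModule` (the `width` attribute is illustrative only):

```python
from __future__ import annotations

from collections.abc import Iterator, Mapping
from typing import cast


class Module:
    """Hypothetical stand-in for torch.nn.Module."""


class ModuleDict(Module):
    """Hypothetical stand-in for torch.nn.ModuleDict: mapping-like, not a Mapping subclass."""

    def __init__(self, modules: dict[str, Module]) -> None:
        self._modules = dict(modules)

    def __getitem__(self, key: str) -> Module:
        return self._modules[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    def __len__(self) -> int:
        return len(self._modules)


class SubModule(Module):
    """Hypothetical child module type."""

    def __init__(self, width: int) -> None:
        self.width = width


class Parent(Module):
    # Declared as a read-only Mapping so lookups are typed as SubModule.
    blocks: Mapping[str, SubModule]

    def __init__(self) -> None:
        super().__init__()
        # cast() silences the nominal-type mismatch; it checks nothing at runtime.
        self.blocks = cast(
            "Mapping[str, SubModule]",
            ModuleDict({"a": SubModule(4), "b": SubModule(8)}),
        )


parent = Parent()
print(parent.blocks["b"].width)  # 8
```

The trade-off is that the declared Mapping type hides ModuleDict-specific methods from the type checker, and nothing verifies that the values really are SubModule instances.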

@function2-llx

@ssnl Maybe as a better workaround, you can use Mapping[str, SubModule] | ModuleDict.

@bdvllrs

bdvllrs commented Feb 23, 2024

An alternative could be a TypedKey class that keeps track of the value types externally, with its instances used as ModuleDict keys.

Here is a very schematic implementation:

from collections.abc import Mapping
from typing import Generic, Type, TypeVar, overload, reveal_type  # reveal_type needs Python 3.11+


class Module:
    """Stand-in for torch.nn.Module; defined first because Tco uses it as a bound."""


Tco = TypeVar("Tco", bound=Module, covariant=True)
T = TypeVar("T")


class TypedKey(Generic[T, Tco]):
    # Carries the expected value type alongside the key name.
    def __init__(self, type_: Type[Tco], name: T):
        self.type_ = type_
        self.name = name


class ModuleDict:
    # Schematic: a real nn.ModuleDict only accepts str keys, so a TypedKey
    # would have to be resolved to its .name before storage.
    def __init__(self, data: Mapping[str | TypedKey[str, Module], Module]):
        self.data = data

    @overload
    def __getitem__(self, key: str) -> Module: ...

    @overload
    def __getitem__(self, key: TypedKey[str, Tco]) -> Tco: ...

    def __getitem__(self, key: str | TypedKey[str, Tco]) -> Module | Tco:
        # TypedKey defines no __eq__/__hash__, so lookups match by object identity.
        return self.data[key]


class Mod1(Module):
    pass


class Mod2(Module):
    pass


if __name__ == "__main__":
    d = ModuleDict({"mod1": Mod1(), "mod2": Mod2()})
    reveal_type(d["mod1"])  # Revealed type is Module

    kmod1 = TypedKey(Mod1, "mod1")
    kmod2 = TypedKey(Mod2, "mod2")
    d2 = ModuleDict({kmod1: Mod1(), kmod2: Mod2()})
    reveal_type(d2[kmod1])  # Revealed type is Mod1
    reveal_type(d2[kmod2])  # Revealed type is Mod2

The advantage is that ModuleDict can still contain different Module types and have correct type inference when we use a TypedKey.
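Because a real nn.ModuleDict stores its entries under str names, a variant of this idea could keep the storage str-keyed and let TypedKey carry only a name plus the expected type. A minimal sketch with hypothetical stand-ins (`NamedModuleDict`, `Module`, `Linear`, `ReLU`); the isinstance check in `__getitem__` is an optional addition, not part of the proposal above, that makes a mismatched key fail loudly at runtime:

```python
from __future__ import annotations

from typing import Generic, TypeVar, overload


class Module:
    """Hypothetical stand-in for torch.nn.Module."""


class Linear(Module):
    """Hypothetical stand-in for torch.nn.Linear."""


class ReLU(Module):
    """Hypothetical stand-in for torch.nn.ReLU."""


M_co = TypeVar("M_co", bound=Module, covariant=True)
M = TypeVar("M", bound=Module)


class TypedKey(Generic[M_co]):
    """Hypothetical key: a str name plus the module type expected under it."""

    def __init__(self, type_: type[M_co], name: str) -> None:
        self.type_ = type_
        self.name = name


class NamedModuleDict:
    """Hypothetical container that, like nn.ModuleDict, stores entries under str names."""

    def __init__(self) -> None:
        self._modules: dict[str, Module] = {}

    def __setitem__(self, key: str | TypedKey[Module], module: Module) -> None:
        name = key.name if isinstance(key, TypedKey) else key
        self._modules[name] = module

    @overload
    def __getitem__(self, key: str) -> Module: ...
    @overload
    def __getitem__(self, key: TypedKey[M]) -> M: ...
    def __getitem__(self, key: str | TypedKey[M]) -> Module:
        if isinstance(key, str):
            return self._modules[key]
        module = self._modules[key.name]
        # Optional runtime check that the stored module matches the key's type.
        if not isinstance(module, key.type_):
            raise TypeError(
                f"{key.name!r} holds {type(module).__name__}, "
                f"expected {key.type_.__name__}"
            )
        return module


LINEAR = TypedKey(Linear, "linear")
RELU = TypedKey(ReLU, "relu")

modules = NamedModuleDict()
modules[LINEAR] = Linear()
modules[RELU] = ReLU()
lin = modules[LINEAR]  # a type checker infers Linear
any_mod = modules["linear"]  # a plain str key is typed as Module
```

Since lookup goes through the name, a TypedKey created elsewhere with the same name finds the same entry, unlike the identity-based keys in the schematic version.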

@RickoNoNo3

That is exactly what I need. Please consider integrating it.

Projects
None yet
Development

No branches or pull requests

8 participants