Introducing PyroModuleList, because torch.nn.ModuleList reinitializes modules when slice-indexing #3339

Merged
merged 8 commits into from
Mar 17, 2024

@MartinBubel MartinBubel commented Mar 11, 2024

Fixes #3341

Hi,

While using "nested" PyroModules, i.e. a PyroModule with a PyroModule[torch.nn.ModuleList] attribute that contains another such module (see the test implementation in the PR for a more detailed example), I encountered errors caused by multiple sample sites sharing the same name.

I spent some time tracing the issue and found that the RuntimeError was caused by an unlucky combination of PyroModule and torch.nn.ModuleList.
First: why use torch.nn.ModuleList? I wanted to model a PyroModule that owns a list of sub-PyroModules. My idea was to replace linear = PyroModule[Linear](5, 2) from the modules example with PyroModule[torch.nn.ModuleList](...).
That works fine if there is a single such module:

class MyPyroModule(PyroModule):
    def __init__(self, some_module: PyroModule):
        super().__init__()  # must run before submodules are assigned
        self.my_submodule = PyroModule[torch.nn.ModuleList]([some_module for _ in range(3)])

but can fail if there is a nested structure, like

my_nested_pyro_module = MyPyroModule(MyPyroModule(...))

Whether it fails depends on how the self.my_submodule attribute is accessed:

  • integer indexing, like self.my_submodule[0](...), works fine, even for nested modules
  • slice indexing, like self.my_submodule[:-1](...), only works if there is a single MyPyroModule() and not a MyPyroModule(MyPyroModule(...))

The cause is this method in the torch.nn.ModuleList class:

@_copy_to_script_wrapper
def __getitem__(self, idx: Union[int, slice]) -> Union[Module, 'ModuleList']:
    if isinstance(idx, slice):
        return self.__class__(list(self._modules.values())[idx])
    else:
        return self._modules[self._get_abs_string_index(idx)]

For a slice, this calls self.__class__, which means that for an object of type PyroModule[torch.nn.ModuleList] it runs PyroModule.__init__ again, without the context of the parent module. This overwrites the names of the module's submodules and their ._pyro_name attributes, so sample sites may no longer be unique during sampling.
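To make the mechanism concrete, here is a minimal torch-only sketch. It does not use Pyro: StampingList, adopt, and the stamp attribute are hypothetical stand-ins for PyroModule[torch.nn.ModuleList], the parent's naming step, and ._pyro_name. The only assumption about torch is that torch.nn.ModuleList builds slices via self.__class__, as in the snippet quoted above.

```python
import torch


class StampingList(torch.nn.ModuleList):
    """Hypothetical stand-in: the constructor (re)names every child it receives."""

    def __init__(self, modules=None):
        super().__init__(modules)
        for i, child in enumerate(self):
            # "stamp" plays the role of a name attribute such as _pyro_name.
            child.stamp = str(i)


def adopt(parent_name, module_list):
    """Hypothetical stand-in for a parent prefixing its children's names."""
    for i, child in enumerate(module_list):
        child.stamp = f"{parent_name}.{i}"


layers = StampingList([torch.nn.Linear(2, 2) for _ in range(3)])
adopt("outer.layers", layers)
before = [child.stamp for child in layers]

# Integer indexing returns the stored module and leaves names alone.
_ = layers[0]
unchanged = [child.stamp for child in layers]

# Slicing constructs a new StampingList, so __init__ runs again and
# re-stamps the sliced children without the parent prefix.
_ = layers[:-1]
after = [child.stamp for child in layers]

print(before)
print(after)
```

In this toy version the first two children lose their "outer.layers." prefix after the slice, while the third keeps it. With two sibling lists that are both sliced, children from different lists would end up with identical names, which mirrors the duplicate-site error described above.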

I see a few options:

  1. I have introduced a pyro.nn.PyroModuleList class in this PR that inherits from torch.nn.ModuleList and overrides the problematic __getitem__ method
  2. We could add an example or documentation on this issue, explicitly warning against using PyroModule[torch.nn.ModuleList] in combination with slice indexing (feels a little unsafe to me)
  3. Maybe someone has another (potentially better) idea of how to deal with this?
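The idea behind option 1 could look roughly like the following torch-only sketch. The class name SliceSafeModuleList is hypothetical, and this is not necessarily the code merged in this PR (which additionally has to be a PyroModule). It only illustrates the override: slices are wrapped in a plain torch.nn.ModuleList instead of self.__class__, so the subclass constructor is not re-run on the children.

```python
import torch


class SliceSafeModuleList(torch.nn.ModuleList):
    """Hypothetical sketch: slicing does not re-run the subclass constructor."""

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            # Wrap the selected children in a plain ModuleList rather than
            # calling self.__class__(...), so no subclass __init__ logic runs.
            return torch.nn.ModuleList(list(self._modules.values())[idx])
        # Integer indexing keeps the stock behaviour.
        return super().__getitem__(idx)


layers = SliceSafeModuleList([torch.nn.Linear(2, 2) for _ in range(3)])
head = layers[:-1]

print(type(head).__name__, len(head))
```

The sliced result still holds the very same module objects (head[0] is layers[0]), so calling or iterating over the slice behaves as before. The trade-off is that the slice is no longer an instance of the subclass.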

I know that this is kind of a "special purpose" usage and may not affect basic pyro usage, but in particular the way the modules example motivates the usage of PyroModule[torch.nn.Something] could, imho, quickly lure other users into this (and it took me quite some time to find the root issue).

Please let me know what you think about this PR, and whether it needs updates or clarification.

Best regards
Martin

@fritzo fritzo left a comment

Thanks for this subtle fix, thanks for writing thorough tests, and thanks for your patience with our review! I have just a few comments.

fritzo commented Mar 17, 2024

BTW lint errors can be detected by running make lint and can often be fixed by running make format.

@MartinBubel

Linting is applied in b23177f.

Thanks @fritzo for the helpful review! I hope I included your suggestions as intended. If not, please let me know.

fritzo commented Mar 17, 2024

LGTM. I think you just need to run make format to fix the lint error.

@MartinBubel

Seems like something got lost in the linting commit, should all be good now.

@fritzo fritzo merged commit 8869834 into pyro-ppl:dev Mar 17, 2024
9 checks passed
Successfully merging this pull request may close these issues.

Using PyroModule with torch.nn.ModuleList fails for nested modules