
Support for recursive parameter assignment on module #107

Open
ottonemo opened this issue Nov 1, 2017 · 8 comments

@ottonemo
Member

ottonemo commented Nov 1, 2017

It would be helpful to be able to set parameters beyond the module level (e.g. on sub-components of the module):

import torch.nn as nn
from skorch import NeuralNet

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

class Encoder(nn.Module):
    def __init__(self, num_hidden=100):
        super().__init__()
        self.num_hidden = num_hidden
        self.lin = nn.Linear(1, num_hidden)

# AttentionEncoderRNN and DecoderRNN are defined elsewhere
net = NeuralNet(
    module=Seq2Seq(encoder=AttentionEncoderRNN, decoder=DecoderRNN),
    module__encoder__num_hidden=23,
)

I would expect module.encoder.num_hidden to be set to 23. This should be robust with respect to the initialization of the sub-module: if the encoder has elements that depend on the initialized value, those elements should be updated as well. In the given example, I would expect not only module.encoder.num_hidden to be updated to 23 but also module.encoder.lin.out_features (e.g. by re-initializing the whole module).
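The expectation above can be illustrated with a framework-free sketch (`Encoder` here is a stand-in, and `reinit_with` is a hypothetical helper, not skorch API): re-initializing the object is what keeps derived attributes consistent with the changed argument.

```python
class Encoder:
    def __init__(self, num_hidden=100):
        self.num_hidden = num_hidden
        # Stand-in for nn.Linear(1, num_hidden).out_features: an
        # attribute derived from the constructor argument.
        self.out_features = num_hidden

def reinit_with(obj, **overrides):
    # Rebuild the object from scratch so that every attribute that
    # depends on the changed argument is updated too.
    params = {'num_hidden': obj.num_hidden}
    params.update(overrides)
    return type(obj)(**params)

enc = reinit_with(Encoder(), num_hidden=23)
# Both the stored argument and the derived attribute now reflect 23.
```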

@ottonemo
Member Author

With #117 this can be solved:

For example:

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__()
        # params_for returns a dict, so it must be unpacked with **
        self.encoder = encoder(**skorch.utils.params_for('encoder', kwargs))
        self.decoder = decoder(**skorch.utils.params_for('decoder', kwargs))

This would allow us to write

net = NeuralNet(
    module=Seq2Seq(encoder=AttentionEncoderRNN, decoder=DecoderRNN),
    module__encoder__num_hidden=23,
)

which is exactly what we want.
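For context, `skorch.utils.params_for` essentially filters the kwargs by prefix and strips that prefix. A minimal pure-Python sketch of that behaviour (not the actual implementation):

```python
def params_for(prefix, kwargs):
    # Select keys starting with "<prefix>__" and strip the prefix, so
    # 'encoder__num_hidden' becomes 'num_hidden' for the encoder.
    prefix = prefix + '__'
    return {key[len(prefix):]: val
            for key, val in kwargs.items()
            if key.startswith(prefix)}

kwargs = {'encoder__num_hidden': 23, 'decoder__num_layers': 2}
encoder_params = params_for('encoder', kwargs)
# encoder_params == {'num_hidden': 23}
```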

@ottonemo
Member Author

ottonemo commented Dec 1, 2017

It would be also nice to support nn.Sequential.

@ottonemo ottonemo moved this from Open to Open for Release in Add basic functionality Dec 14, 2017
@benjamin-work benjamin-work added this to To Do in Release 0.2 Dec 20, 2017
@ottonemo ottonemo added r0.3.0 and removed r0.2.0 labels Apr 26, 2018
@ottonemo
Member Author

Still open for debate and there's no clear road. Postponing for r0.3.0.

@ottonemo ottonemo removed this from To Do in Release 0.2 Apr 26, 2018
@ottonemo ottonemo added this to To do in Release 0.3.0 Apr 26, 2018
@benjamin-work
Contributor

Also up for discussion: not only allow setting parameters on sub-modules, but also on arbitrary attributes. This would, e.g., allow us to do things like:

net.set_params(module__encoder__embeddings__weight__requires_grad=False)
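One way this could work is to resolve the dunder-separated key as a chain of attribute lookups and set the final attribute. A hedged sketch with plain-object stand-ins (`set_attr_by_path` is a hypothetical helper, not a skorch function):

```python
from functools import reduce

def set_attr_by_path(root, key, value):
    # Walk 'a__b__c' as nested attribute accesses, then set the final
    # attribute, e.g. requires_grad on a parameter tensor.
    *path, leaf = key.split('__')
    target = reduce(getattr, path, root)
    setattr(target, leaf, value)

# Plain-object stand-ins for the module/parameter hierarchy:
class Weight:
    requires_grad = True

class Embeddings:
    def __init__(self):
        self.weight = Weight()

class Encoder:
    def __init__(self):
        self.embeddings = Embeddings()

class Module:
    def __init__(self):
        self.encoder = Encoder()

m = Module()
set_attr_by_path(m, 'encoder__embeddings__weight__requires_grad', False)
```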

@thomasjpfan
Member

thomasjpfan commented Jul 1, 2018

I propose the following utility or helper functions:

from operator import methodcaller

def set_params_in_module(module, **kwargs):
    for k, v in kwargs.items():
        set_param_in_module(module, k, v)

def set_param_in_module(module, param, value):
    # Split off the trailing method name, e.g.
    # 'conv__weight__requires_grad' -> ('conv__weight', 'requires_grad')
    name, key = param.rsplit('__', 1)
    # named_parameters() uses '.' as the separator in parameter names
    name = name.replace('__', '.')
    for n, p in module.named_parameters():
        if n.startswith(name):
            # Call the in-place variant, e.g. p.requires_grad_(value)
            methodcaller(f'{key}_', value)(p)

This would allow support for the following syntax:

import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 10)

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(10, 1, 1)
        self.encoder = Encoder()

module = MyModule()
set_params_in_module(module, conv__weight__requires_grad=False)
set_params_in_module(module, conv__weight__copy=torch.ones((1, 10, 1, 1)))
set_params_in_module(module, conv__weight__add=torch.ones((1, 10, 1, 1)))
set_params_in_module(module, encoder__requires_grad=False)
set_params_in_module(module,
                     conv__requires_grad=False,
                     encoder__embeddings__weight__requires_grad=False)

Integrating this into NeuralNet is tricky, because keywords prefixed with module__ are passed into the module's __init__ function.

@BenjaminBossan
Collaborator

I believe there could be a way: when a parameter passed to module doesn't contain a __, proceed normally; otherwise, proceed as you suggested (but the call must be recursive if there are several __).

@thomasjpfan
Member

If module__inner__linear is set on NeuralNet, then inner__linear will be passed, as a keyword, to the module's __init__ function. This enables @ottonemo's use case of using skorch.utils.params_for during the module's __init__.

The set_params_in_module function is used after the module has successfully called __init__.
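The forwarding described above can be sketched in a few lines (a hedged illustration of the behaviour, not skorch's actual code): keys prefixed with `module__` lose that prefix and go to the module's `__init__`, while everything else belongs to the net itself.

```python
def strip_module_prefix(kwargs):
    # Keys like 'module__inner__linear' are forwarded to the module's
    # __init__ as 'inner__linear'; keys without the prefix (e.g. 'lr')
    # are consumed by the net and not forwarded.
    module_kwargs = {}
    for key, val in kwargs.items():
        if key.startswith('module__'):
            module_kwargs[key[len('module__'):]] = val
    return module_kwargs

forwarded = strip_module_prefix({'module__inner__linear': 10, 'lr': 0.1})
# forwarded == {'inner__linear': 10}
```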

@ottonemo
Member Author

This is way more complicated than anticipated and we should schedule this for 0.4.0 rather than delay 0.3.0.

@ottonemo ottonemo removed this from To do in Release 0.3.0 Jul 26, 2018
@ottonemo ottonemo removed the r0.5.0 label Dec 17, 2018