Support for recursive parameter assignment on module #107
With #117 this can be solved. For example:

```python
from skorch.utils import params_for
from torch import nn


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__()
        # route 'encoder__*' / 'decoder__*' kwargs to the matching sub-module
        self.encoder = encoder(**params_for('encoder', kwargs))
        self.decoder = decoder(**params_for('decoder', kwargs))
```

This would allow us to write

```python
net = NeuralNet(
    module=Seq2Seq(encoder=AttentionEncoderRNN, decoder=DecoderRNN),
    module__encoder__num_hidden=23,
)
```

which is exactly what we want.
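As a rough, self-contained illustration of the routing above, the sketch below re-implements a helper that behaves like the `params_for` call in the example: it keeps only keys starting with `prefix__` and strips that prefix. It is a stand-in written for this illustration, not skorch's actual code, and `Encoder`/`Decoder` are hypothetical placeholder classes so that it runs without PyTorch.

```python
def params_for(prefix, kwargs):
    """Return the entries of `kwargs` whose key starts with `prefix__`,
    with that prefix stripped (stand-in for the skorch helper)."""
    marker = prefix + '__'
    return {
        key[len(marker):]: value
        for key, value in kwargs.items()
        if key.startswith(marker)
    }


class Encoder:  # hypothetical placeholder sub-module
    def __init__(self, num_hidden=10):
        self.num_hidden = num_hidden


class Decoder:  # hypothetical placeholder sub-module
    def __init__(self, num_hidden=10, dropout=0.0):
        self.num_hidden = num_hidden
        self.dropout = dropout


class Seq2Seq:
    def __init__(self, encoder, decoder, **kwargs):
        # ** unpacks the filtered dict into keyword arguments
        self.encoder = encoder(**params_for('encoder', kwargs))
        self.decoder = decoder(**params_for('decoder', kwargs))


model = Seq2Seq(Encoder, Decoder, encoder__num_hidden=23, decoder__dropout=0.5)
print(model.encoder.num_hidden)                         # 23
print(model.decoder.num_hidden, model.decoder.dropout)  # 10 0.5
```

Because only one prefix level is stripped, a key such as `encoder__rnn__num_layers` would reach the encoder as `rnn__num_layers`, so the same pattern could be repeated one level further down.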
It would also be nice to support
This is still open for debate and there is no clear path forward. Postponing to r0.3.0.
Also up for discussion: allow setting parameters not only on sub-modules but also on arbitrary attributes. This would, e.g., allow us to do things like:
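A minimal sketch of setting arbitrary (non-module) attributes, assuming the same `__` separator used for sub-module parameters. `set_nested_attr` and `get_nested_attr` are hypothetical helper names, and `SimpleNamespace` objects stand in for a real module and its sub-module:

```python
from functools import reduce
from types import SimpleNamespace


def get_nested_attr(obj, path):
    """Follow a `__`-separated path, e.g. 'encoder__dropout'."""
    return reduce(getattr, path.split('__'), obj)


def set_nested_attr(obj, path, value):
    """Set the attribute addressed by a `__`-separated path."""
    *parents, attr = path.split('__')
    target = reduce(getattr, parents, obj)  # obj itself if there are no parents
    setattr(target, attr, value)


# SimpleNamespace objects stand in for a module and its sub-module.
model = SimpleNamespace(encoder=SimpleNamespace(num_hidden=10, dropout=0.0))
set_nested_attr(model, 'encoder__dropout', 0.25)
print(get_nested_attr(model, 'encoder__dropout'))  # 0.25
```

Plain `setattr` like this does not recompute anything that was derived from the old value when the object was constructed.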
I propose the following utility or helper functions:

```python
from operator import methodcaller

import torch


def set_params_in_module(module, **kwargs):
    for k, v in kwargs.items():
        set_param_in_module(module, k, v)


def set_param_in_module(module, param, value):
    # 'conv__weight__requires_grad' -> name='conv.weight', key='requires_grad'
    name, key = param.rsplit('__', 1)
    name = name.replace('__', '.')
    for n, p in module.named_parameters():
        # match the exact parameter or everything below that sub-module,
        # without also matching siblings such as 'conv2' for 'conv'
        if n == name or n.startswith(name + '.'):
            # call the in-place method, e.g. p.requires_grad_(value);
            # no_grad allows in-place ops like copy_ on leaf tensors
            with torch.no_grad():
                methodcaller(f'{key}_', value)(p)
```

This would support the following syntax:

```python
import torch
from torch import nn


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 10)


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(10, 1, 1)
        self.encoder = Encoder()


module = MyModule()
set_params_in_module(module, conv__weight__requires_grad=False)
set_params_in_module(module, conv__weight__copy=torch.ones((1, 10, 1, 1)))
set_params_in_module(module, conv__weight__add=torch.ones((1, 10, 1, 1)))
set_params_in_module(module, encoder__requires_grad=False)
set_params_in_module(module,
                     conv__requires_grad=False,
                     encoder__embeddings__weight__requires_grad=False)
```

Integrating this into
I believe there could be a way. When a parameter is passed to
If the
This is far more complicated than anticipated; we should schedule it for 0.4.0 rather than delay 0.3.0.
It would be helpful to be able to set parameters below the module level (for example, on sub-components of the module):

I would expect `module.encoder.num_hidden` to be set to 23. This should be robust with respect to the initialization of the sub-module: if the encoder has elements that depend on the initialized value, those elements should be updated as well. In the given example, I would expect not only `module.encoder.num_hidden` to be updated to 23 but also `module.encoder.lin.out_features` (e.g. by re-initializing the whole module).
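One way to get the re-initialization behaviour described above is for the parent to remember the sub-module's class and constructor kwargs and rebuild it whenever one of them changes. The sketch below assumes that design; `Linear` is a plain-Python stand-in for `torch.nn.Linear` so it runs without PyTorch, and `set_params` here is a hypothetical method, not skorch's `NeuralNet.set_params`.

```python
class Linear:
    """Plain-Python stand-in for torch.nn.Linear (only stores its shape)."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features


class Encoder:
    def __init__(self, num_hidden=10):
        self.num_hidden = num_hidden
        # lin's shape is derived from num_hidden at construction time
        self.lin = Linear(5, num_hidden)


class Seq2Seq:
    def __init__(self, encoder=Encoder, **kwargs):
        self._encoder_cls = encoder
        self._encoder_kwargs = {}
        self.set_params(**kwargs)  # also builds self.encoder

    def set_params(self, **params):
        """Hypothetical: route `encoder__<name>` params, then rebuild the encoder."""
        for key, value in params.items():
            prefix, _, name = key.partition('__')
            if prefix != 'encoder' or not name:
                raise ValueError(f'unsupported parameter: {key!r}')
            self._encoder_kwargs[name] = value
        # Re-initializing keeps everything derived from the params consistent.
        self.encoder = self._encoder_cls(**self._encoder_kwargs)
        return self


model = Seq2Seq()
model.encoder.num_hidden = 23          # plain attribute assignment...
print(model.encoder.lin.out_features)  # 10 (lin is now stale)

model.set_params(encoder__num_hidden=23)
print(model.encoder.num_hidden, model.encoder.lin.out_features)  # 23 23
```

The trade-off is that rebuilding discards any state the old sub-module held, such as trained weights.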