New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
module
will ignore params with requires_grad=False
#1779
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix!
tests/params/test_module.py
Outdated
def forward(self, s): | ||
pass | ||
|
||
with warnings.catch_warnings(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I would suggest changing this to with pytest.warns(RuntimeWarning):
which will explicitly check that this warning is raised.
if param_value._cdata != returned_param._cdata: | ||
target_state_dict[param_name] = returned_param | ||
else: | ||
warnings.warn("{} was not registered in the param store because".format(param_name) + |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: space between because
and requires_grad
in the warning message.
The ss-vae example seems to be failing for some reason, but it doesn't seem related. |
Addresses #1778