Support factory kwargs in torch.nn modules #54508
Conversation
💊 CI failures summary and remediations
As of commit 044ff1d (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns. The following CI failure does not appear to be due to upstream breakages: pytorch_windows_vs2019_py36_cuda10.1_test1 (1/1), step "Test".
Looks like you got some merge conflicts.
torch/__init__.py
Outdated
@@ -472,6 +472,24 @@ def is_warn_always_enabled():
    """
    return _C._get_warnAlways()

def factory_kwargs(kwargs):
Need a docblock. Maybe something like:
"""
Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed
to factory functions like torch.empty, or errors if unrecognized kwargs are present.

This function makes it simple to write code like this::

    class MyModule(nn.Module):
        def __init__(self, **kwargs):
            factory_kwargs = torch.factory_kwargs(kwargs)
            self.weight = Parameter(torch.empty(10, **factory_kwargs))

Why should you use this function instead of just passing `kwargs` along directly?

1. This function does error validation, so if there are unexpected kwargs we will
   immediately report an error, instead of deferring it to the factory call.
2. This function supports a special `factory_kwargs` argument, which can be used to
   explicitly specify a kwarg to be used for factory functions, in the event one of
   the factory kwargs conflicts with an already existing argument in the signature
   (e.g., in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for
   factory functions, as distinct from the dtype argument, by saying
   ``f(dtype1, factory_kwargs={"dtype": dtype2})``).
"""
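A helper matching that docstring could be sketched like this (an illustrative implementation, not necessarily the exact code the PR landed):

```python
def factory_kwargs(kwargs):
    """Canonicalize factory kwargs per the docstring above (illustrative sketch)."""
    if kwargs is None:
        return {}
    simple_keys = {"device", "dtype", "memory_format"}
    expected_keys = simple_keys | {"factory_kwargs"}
    if not set(kwargs) <= expected_keys:
        raise TypeError(f"unexpected kwargs {set(kwargs) - expected_keys}")

    # Split off the plain factory kwargs, then merge in the explicit override dict,
    # rejecting conflicts between the two.
    r = {k: v for k, v in kwargs.items() if k in simple_keys}
    explicit = kwargs.get("factory_kwargs", {})
    if set(r) & set(explicit):
        raise TypeError(f"conflicting kwargs: {set(r) & set(explicit)}")
    r.update(explicit)
    return r
```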
Note that the block needs to be inside the function to be recognized as the doc for the function, IIRC.
Whoops, fixed!
torch/nn/modules/activation.py
Outdated
-    def _reset_parameters(self):
+    def reset_parameters(self):
Nit: idly wondering if we should define `_reset_parameters` calling to `reset_parameters` for BC, ha!
Haha, I thought the same, but decided against it originally because `_reset_parameters` is technically private? Happy to add it though!
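The backward-compatibility shim floated in that nit might look like this (a standalone illustrative class, not the real `nn.MultiheadAttention`):

```python
class MultiheadAttention:
    """Stand-in class illustrating the suggested BC alias."""

    def __init__(self):
        self.initialized = False
        # Qualified call, as discussed below: always run this class's version.
        MultiheadAttention.reset_parameters(self)

    def reset_parameters(self):
        # New public name for parameter (re)initialization.
        self.initialized = True

    def _reset_parameters(self):
        # Backward-compatible alias: the old private name delegates to the new one.
        self.reset_parameters()
```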
torch/nn/modules/adaptive.py
Outdated
        )

        self.tail.append(projection)

        if reset_parameters:
            self.reset_parameters()
This seems wrong; you don't have to explicitly reset parameters here because the propagated `reset_parameters` arguments should have handled it already.
Yep, good catch! Fixed.
torch/nn/modules/transformer.py
Outdated
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        if reset_parameters:
            self.reset_parameters()
Probably better not to have this one either
Fixed this one and TransformerDecoderLayer
I don't want to hold this too much on tests, but there may be some simple things we can do in the generic nn Module tests; e.g., instead of constructing the Module with just the known arguments, try also passing in device cpu.
torch/nn/modules/activation.py
Outdated
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

-        self._reset_parameters()
+        MultiheadAttention.reset_parameters(self)
This is ugly, and apparently necessary for any base class. With `self.reset_parameters()` instead, instantiating a subclass errors out:

1. `subclass.__init__()` is called
2. `super().__init__()` is called
3. In the base class constructor, `self.reset_parameters()` calls the overriding `subclass.reset_parameters()`
4. `subclass.reset_parameters()` tries to reset parameters or buffers that haven't been created yet - error

Maybe this syntax should be made universal so we don't have to decide which classes can be base classes. Or some other workaround?
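A minimal, PyTorch-free repro of that failure mode (all names here are illustrative):

```python
class Base:
    def __init__(self, qualified_call=True):
        self.weight = None
        if qualified_call:
            # Qualified call, as in the PR: always runs Base's version.
            Base.reset_parameters(self)
        else:
            # Unqualified call: dispatches to the subclass override too early.
            self.reset_parameters()

    def reset_parameters(self):
        self.weight = 1.0


class Sub(Base):
    def __init__(self, qualified_call=True):
        super().__init__(qualified_call)
        self.extra_size = 3          # subclass state created after super().__init__
        self.reset_parameters()      # safe here: all state now exists

    def reset_parameters(self):
        super().reset_parameters()
        # Reads subclass state that doesn't exist yet if reached from Base.__init__:
        self.extra = [0.0] * self.extra_size
```

`Sub()` succeeds, while `Sub(qualified_call=False)` raises `AttributeError` during the base constructor, which is exactly the sequence described in the comment.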
# Returns a database of args & kwargs that can be used to construct each module.
# Each entry is in class -> (args, kwargs) format.
# Example: torch.nn.Linear -> ([10, 5], {})
def build_constructor_arg_db():
Alternatively, there's a similar DB in `common_nn.py` or something. It didn't have constructor args for every module, and in some cases, I need to pick different args to ensure paths are taken that create params / buffers. Whenever `ModuleInfo` (analogous to the new `OpInfo`) comes about, this should be dealt with there.
+1 for ModuleInfo (just imagine if there was already a basic version!)
Why is this a function that returns a dict instead of just being a dict?
The function is left over from when I started with the `common_nn.py` `new_module_tests` and only filled in entries for the missing modules. Could make it just a dict, but I kinda liked the mildly functional style.
I'll rebase tomorrow :)
Thanks for working on this Joel!
The changes for factory kwargs look good to me.
Also the generic tests look quite good!
I am not as convinced by the reset_parameters changes though.
I feel like it is somewhere between being a new API on nn.Module and just some internal convention we use.
If we want it to be the second, then my thinking is:
- We should not modify the base nn.Module
- We shouldn't need to add empty methods everywhere (check if the child Module has that method or not?)
- Make that method internal so that it does not lead to issues if users implement another method with the same name
- You can use `.apply()` as well to run this method on all the Modules that have that method, without the need to implement custom functions on every single container.
If we actually decide to go with the first one, I think we should go all in and get all the benefits from such a significant change.
In particular, we want to make it clear to the user which type of Module they are implementing and using. And there should be a clear benefit for implementing a structured version. Being able to call the default initialization independently is one, but why can't we leverage that to automatically call it during initialization as well?
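The `.apply()`-based approach suggested in that review could be sketched as follows (a PyTorch-free illustration; in real code `Module` would be `torch.nn.Module`, whose `.apply()` already recurses over submodules):

```python
class Module:
    """Minimal stand-in exposing nn.Module's recursive .apply() interface."""

    def __init__(self):
        self._children = []

    def apply(self, fn):
        # Apply fn to every submodule, then to self (as nn.Module.apply does).
        for child in self._children:
            child.apply(fn)
        fn(self)
        return self


def reset_all_parameters(module):
    # Call reset_parameters on every submodule that defines it; modules
    # without parameters need no empty stub method.
    def fn(m):
        reset = getattr(m, "reset_parameters", None)
        if callable(reset):
            reset()
    module.apply(fn)
```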
torch/__init__.py
Outdated
    class MyModule(nn.Module):
        def __init__(self, **kwargs):
            factory_kwargs = torch.factory_kwargs(kwargs)
Why is this in `torch.` and not `torch.nn.`?
Good point - I'll move it to torch.nn
torch/nn/parameter.py
Outdated
@@ -145,8 +145,9 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):

    cls_to_become = Parameter

-    def __new__(cls, requires_grad=True):
-        data = torch.Tensor()
+    def __new__(cls, requires_grad=True, **kwargs) -> None:
Should we update some doc to mention that this takes more arguments than the regular `nn.Parameter()`?
I think so - the `UninitializedParameter` / `UninitializedBuffer` docs seem like the right place to do it.
@@ -55,18 +55,20 @@ class MultiheadAttention(nn.MultiheadAttention):
    def __init__(self, embed_dim: int, num_heads: int,
                 dropout: float = 0., bias: bool = True,
                 add_bias_kv: bool = False, add_zero_attn: bool = False,
-                kdim: int = None, vdim: int = None):
+                kdim: int = None, vdim: int = None, **kwargs) -> None:
+        factory_kwargs = torch.factory_kwargs(kwargs)
Why do you unpack them here? Why not just pass them through as is? You do this in a couple other places.
Auto-pilot updates :) I'll remove the unnecessary unpacking for subclasses
Adding @mruberry for the new testing infra.
This is not good because you cannot distinguish if a Module has parameters but forgot to implement `reset_parameters`.
While this seems reasonable, I am torn because I don't want a default implementation of this. Maybe a helper function for recursively calling reset parameters on all submodules is the doctor's order.
Can't - we historically didn't specify whether you call the super constructor first or last in the constructor, so I don't think we can assume that it is called after the parameters are set up.
Force-pushed from 7de5112 to 0912b88.
@@ -0,0 +1,330 @@
import inspect
To run the tests, add the file name here (line 31 in d4045e9): TESTS = [
And verify the tests appear in the PR CI output.
Thanks!
test/test_module_init.py
Outdated
# Example: torch.nn.Linear -> ([10, 5], {})
def build_constructor_arg_db():
    return {
        torch.nn.AdaptiveAvgPool1d: ([5], {}),
If someone adds a new module, will they know all the lists they need to add it to?
Any module in `__all__` for torch.nn modules (and the various quantization modules) will be tested here, and an error will be thrown indicating that an entry should be added here.
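An illustrative sketch (assumed, not the PR's exact code) of how a missing DB entry can surface as a clear error during lookup:

```python
def get_example_args(module_cls, constructor_arg_db):
    # Look up the constructor (args, kwargs) entry for a module class,
    # failing loudly when a newly added module has no entry yet.
    if module_cls not in constructor_arg_db:
        raise RuntimeError(
            f"No constructor args registered for {module_cls!r}; "
            "please add an entry to build_constructor_arg_db()")
    return constructor_arg_db[module_cls]
```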
test/test_module_init.py
Outdated
        torch.nn.AdaptiveAvgPool2d: ([5], {}),
        torch.nn.AdaptiveAvgPool3d: ([5], {}),
        torch.nn.AdaptiveLogSoftmaxWithLoss: ([100, 20, [5, 10, 15]], {}),
        torch.nn.AdaptiveMaxPool1d: ([5], {}),
Style nit (would not block PR on this or make this change if it's a nontrivial amount of work): prefer tuples to lists
Any particular reason? I don't mind doing the work to make the change, but it makes it harder to read imo
test/test_module_init.py
Outdated
# Instantiates the given class with the given args, kwargs, optionally on a given device.
def instantiate_class(cls, args, kwargs, device=None):
This is interesting. Why is device not part of kwargs?
`args` / `kwargs` are used to instantiate each type of module, and it simplified the calls to just add in device separately when it's needed. But you're right that I could alternatively add device to `kwargs`.
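A sketch of that helper under the design described (illustrative; the real test code may differ):

```python
def instantiate_class(cls, args, kwargs, device=None):
    # Instantiate cls with the DB-provided args/kwargs, mixing in a device
    # kwarg only when the test requests one.
    kwargs = dict(kwargs)          # don't mutate the shared DB entry
    if device is not None:
        kwargs["device"] = device
    return cls(*args, **kwargs)
```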
Whatever's easiest.
# Returns a function that calls the real implementation of a method
# in addition to passing args to a mock object.
def mock_wrapper(method):
I couldn't get `wraps` to work for this specific use case where I don't have a particular object. This wrapper is used below to hook into all parameter creations or buffer registrations. If you know how to use `wraps` for this, let me know and I'll change it.
Interesting. I suppose this can wait until it's an issue (and it may never be an issue).
# Returns a function that calls the real implementation of a method
# in addition to passing args to a mock object.
def mock_wrapper(method):
Why call the mock?
Calling the mock logs that the method was called
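One way such a wrapper might be written (an assumed sketch using `unittest.mock`, not necessarily the PR's exact code):

```python
from functools import wraps
from unittest import mock


def mock_wrapper(method):
    # Return a wrapper that records every call on an attached MagicMock while
    # still running the real implementation and returning its result.
    mock_obj = mock.MagicMock()

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        mock_obj(*args, **kwargs)            # log the call for later assertions
        return method(self, *args, **kwargs)

    wrapper.mock = mock_obj                  # expose the recorder to the test
    return wrapper
```

A test could then patch a method such as parameter registration with the wrapped version and assert on `wrapper.mock` afterwards.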
test/test_module_init.py
Outdated
    if module_creates_params_or_buffers and module_cls not in MODULES_WITHOUT_KWARGS_SUPPORT:
        args, kwargs = get_example_args(module_cls, constructor_arg_db, device=device)

        # if module_cls in LAZY_MODULES:
What's up with this comment?
This is testing logic for LazyTensors that should be put back in once meta tensor functionality is expanded. Not sure whether this PR should wait for that so the tests will be better?
I probably wouldn't wait (although you know better). I would just remove this comment (possibly replacing it with a TODO) or add a meta-comment explaining why this section is commented out.
Luckily enough, the support is in now so I uncommented it!
test/test_module_init.py
Outdated
    module_cls = getattr(mod_namespace, module_name)
    if module_cls in MODULES_TO_SKIP: continue

    # Create a function to run the test and setattr it onto the test class.
I'm not saying this should be the PR that creates ModuleInfo... but MAYBE this (or a follow-up) should be the PR that does it. In particular, generators like this that are lengthy and complicated was one of the reasons we wanted to switch to decorating test templates that look like more typical tests.
I for one am excited for `ModuleInfo`! As discussed offline, we'll meet sometime shortly after this PR to design what that could look like.
Also adding a TODO in the code to refactor the constructor arg DB into `ModuleInfo`.
test/test_module_init.py
Outdated
    test_name = f'test_{namespace_basename}_{module_name}'
    setattr(TestModuleInit, test_name, run_test)

instantiate_device_type_tests(test_cls, globals())
I would put this call on the current line 327 for readability and consistency with other test files
    return args, kwargs


def generate_tests(test_cls, constructor_arg_db):
Add a comment describing what this test tests
Test logic is actually below in `run_test()`, and its checks are commented.
I split the logic out for generating tests for a single module, so the separation between that and iterating over all the modules should be clearer now
Giving approval modulo testing; please look at mruberry's comments
Reverting (again) because it broke ASAN: https://app.circleci.com/pipelines/github/pytorch/pytorch/304608/workflows/d07158c1-b75a-49a1-a3c6-a41371d0bce3/jobs/12561730
This pull request has been reverted by 92d24e3.
Force-pushed from 60431cf to 5ee4848.
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from 5ee4848 to 6dec5bf.
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
ASAN is failing in the same way on the PR.
Yep, sorry for the mess. I thought the land would have stopped if the problem was unfixed, but it went right through :/ Unlanding now, and will make sure it's actually fixed before re-landing.
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Continuation of pytorch#53144 Pull Request resolved: pytorch#54508 Reviewed By: mrshenli Differential Revision: D27600457 Pulled By: jbschlosser fbshipit-source-id: b58bfee61c3917524b4622f63ef216c27a588eb1
Summary: Continuation of pytorch#53144 Pull Request resolved: pytorch#54508 Reviewed By: bdhirsh Differential Revision: D27855386 Pulled By: jbschlosser fbshipit-source-id: dabd505d2a04208e74b158570fb2859c736eea2c
Summary: Continuation of pytorch#53144 Pull Request resolved: pytorch#54508 Reviewed By: malfet Differential Revision: D27909732 Pulled By: jbschlosser fbshipit-source-id: d8684b2403ab7eb336371d118799146a2520bd76
Summary: Continuation of pytorch#53144 Pull Request resolved: pytorch#54508 Reviewed By: albanD Differential Revision: D27939544 Pulled By: jbschlosser fbshipit-source-id: 4bf517e5f74f093e27ca38a85e732da65e44d805
Needed since PyTorch v1.9 (pytorch/pytorch#54508)
Continuation of #53144