Support factory kwargs in torch.nn modules #54508

Closed
wants to merge 3 commits

Conversation

jbschlosser (Contributor)

Continuation of #53144

facebook-github-bot (Contributor) commented Mar 23, 2021

💊 CI failures summary and remediations

As of commit 044ff1d (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test1 (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "C:\Users\circleci\project\build\win_tmp\build\torch\cuda\__init__.py", line 170, in _lazy_init
    torch._C._cuda_init()
RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx


ezyang (Contributor) commented Mar 23, 2021

looks like you got some merge conflicts

@@ -472,6 +472,24 @@ def is_warn_always_enabled():
"""
return _C._get_warnAlways()


def factory_kwargs(kwargs):
Contributor

Need a docblock. Maybe something like:

"""
Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed
to factory functions like torch.empty, or errors if unrecognized kwargs are present.

This function makes it simple to write code like this::

    class MyModule(nn.Module):
        def __init__(self, **kwargs):
            factory_kwargs = torch.factory_kwargs(kwargs)
            self.weight = Parameter(torch.empty(10, **factory_kwargs))

Why should you use this function instead of just passing `kwargs` along directly?

1. This function does error validation, so if there are unexpected kwargs we will immediately report an error, instead of deferring it to the factory call
2. This function supports a special `factory_kwargs` argument, which can be used to explicitly specify a kwarg to be used for factory functions, in the event one of the factory kwargs conflicts with an already existing argument in the signature (e.g., in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory functions, as distinct from the dtype argument, by saying ``f(dtype1, factory_kwargs={"dtype": dtype2})``)
"""

albanD (Collaborator)

Note that the block needs to be inside the function to be recognized as the doc for the function, IIRC.

jbschlosser (Contributor Author)

whoops fixed!


-    def _reset_parameters(self):
+    def reset_parameters(self):
ezyang (Contributor) commented Mar 23, 2021

nit: idly wondering if we should define _reset_parameters to call reset_parameters for BC, ha!

jbschlosser (Contributor Author)

Haha, I thought the same, but decided against it originally because _reset_parameters is technically private?

Happy to add it though!
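For reference, the BC shim would just be a forwarding alias, e.g.:

    def _reset_parameters(self):
        # Deprecated alias kept for backward compatibility.
        self.reset_parameters()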

)

self.tail.append(projection)

if reset_parameters:
self.reset_parameters()
Contributor

This seems wrong; you don't have to explicitly reset parameters here because the propagated reset_parameters argument should have handled it already.

jbschlosser (Contributor Author)

Yep, good catch! fixed

self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)

self.activation = _get_activation_fn(activation)

if reset_parameters:
self.reset_parameters()
Contributor

Probably better not to have this one either

jbschlosser (Contributor Author)

Fixed this one and TransformerDecoderLayer

ezyang (Contributor) commented Mar 23, 2021

I don't want to hold this up too much on tests, but there may be some simple things we can do in the generic nn Module tests; e.g., instead of constructing the Module with just the known arguments, try also passing in device='cpu'.

else:
self.bias_k = self.bias_v = None

self.add_zero_attn = add_zero_attn

-        self._reset_parameters()
+        MultiheadAttention.reset_parameters(self)
jbschlosser (Contributor Author)

This is ugly, and apparently necessary for any base class. With self.reset_parameters() instead, instantiating a subclass errors out:

  1. subclass.__init__() is called
  2. super().__init__() is called
  3. In the base class constructor, self.reset_parameters() calls the overriding subclass.reset_parameters()
  4. subclass.reset_parameters() tries to reset parameters or buffers that haven't been created yet - error

Maybe this syntax should be made universal so we don't have to decide which classes can be base classes. Or some other workaround?
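A minimal repro of that failure mode (hypothetical Base/Sub classes; calling Base.reset_parameters(self) explicitly in Base.__init__ sidesteps the virtual dispatch in step 3):

    import torch
    from torch import nn

    class Base(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(4))
            self.reset_parameters()  # dispatches to Sub.reset_parameters

        def reset_parameters(self):
            nn.init.zeros_(self.weight)

    class Sub(Base):
        def __init__(self):
            super().__init__()  # Sub.reset_parameters runs in here...
            self.extra = nn.Parameter(torch.empty(4))

        def reset_parameters(self):
            super().reset_parameters()
            nn.init.zeros_(self.extra)  # ...but self.extra doesn't exist yet

    Sub()  # raises AttributeError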

# Returns a database of args & kwargs that can be used to construct each module.
# Each entry is in class -> (args, kwargs) format.
# Example: torch.nn.Linear -> ([10, 5], {})
def build_constructor_arg_db():
jbschlosser (Contributor Author)

Alternatively, there's a similar DB in common_nn.py or something. It didn't have constructor args for every module, and in some cases, I need to pick different args to ensure paths are taken that create params / buffers.

Whenever ModuleInfo (analogous to the new OpInfo) comes about, this should be dealt with there.

Collaborator

+1 for ModuleInfo (just imagine if there was already a basic version!)

Why is this a function that returns a dict instead of just being a dict?

jbschlosser (Contributor Author)

The function is left over from when I started with the common_nn.py new_module_tests and only filled in entries for the missing modules. Could make it just a dict but I kinda liked the mildly functional style

jbschlosser (Contributor Author)

I'll rebase tomorrow :)

albanD (Collaborator) left a comment

Thanks for working on this Joel!

The changes for factory kwargs look good to me.
Also the generic tests look quite good!

I am not as convinced by the reset_parameters changes though.
I feel like it is somewhere between being a new API on nn.Module and just some internal convention we use.

If we want it to be the second, then my thinking is:

  • We should not modify the base nn.Module
  • We shouldn't need to add empty methods everywhere (check if the child Module has that method or not?)
  • Make that method internal so that it does not lead to issues if users implement another method with the same name
  • You can use .apply() as well to run this method on all the Modules that have that method, without the need to implement custom functions on every single container.

If we actually decide to go with the first one, I think we should go all in and get all the benefits from such a significant change.
In particular, we want to make it clear to the user which type of Module they are implementing and using. And there should be a clear benefit for implementing a structured version. Being able to call the default initialization independently is one, but why can't we leverage that to automatically call it during initialization as well?


class MyModule(nn.Module):
    def __init__(self, **kwargs):
        factory_kwargs = torch.factory_kwargs(kwargs)
albanD (Collaborator)

Why is this in torch. and not torch.nn. ?

jbschlosser (Contributor Author)

Good point - I'll move it to torch.nn

@@ -145,8 +145,9 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):

cls_to_become = Parameter

-    def __new__(cls, requires_grad=True):
-        data = torch.Tensor()
+    def __new__(cls, requires_grad=True, **kwargs) -> None:
albanD (Collaborator)

Should we update some doc to mention that this takes more arguments than the regular nn.Parameter()?

jbschlosser (Contributor Author)

I think so - the UninitializedParameter / UninitializedBuffer docs seem like the right place to do it.

@@ -55,18 +55,20 @@ class MultiheadAttention(nn.MultiheadAttention):
     def __init__(self, embed_dim: int, num_heads: int,
                  dropout: float = 0., bias: bool = True,
                  add_bias_kv: bool = False, add_zero_attn: bool = False,
-                 kdim: int = None, vdim: int = None):
+                 kdim: int = None, vdim: int = None, **kwargs) -> None:
+        factory_kwargs = torch.factory_kwargs(kwargs)
albanD (Collaborator)

Why do you unpack them here? Why not just pass them through as is?

You do this in a couple other places.

jbschlosser (Contributor Author)

Auto-pilot updates :) I'll remove the unnecessary unpacking for subclasses

ezyang requested a review from mruberry on March 26, 2021 at 17:40
ezyang (Contributor) commented Mar 26, 2021

adding @mruberry for the new testing infra

ezyang (Contributor) commented Mar 26, 2021

> We shouldn't need to add empty methods everywhere (check if the child Module has that method or not?)

This is not good because you cannot distinguish if a Module has parameters but forgot to implement reset_parameters, or if it has no parameters and is supposed to not have it. This is one of the reasons why adding the empty reset_parameters is a good idea: it makes it explicit that "yes, we checked, and no, the reset_parameters here is a noop".

> You can use .apply() as well to run this method on all the Modules that have that method, without the need to implement custom functions on every single container.

While this seems reasonable, I am torn because I don't want a default implementation of this. Maybe a helper function for recursively calling reset parameters on all submodules is the doctor's order.
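For concreteness, such a helper might look like this (a sketch; the name reset_all_parameters is an assumption):

    def reset_all_parameters(module: torch.nn.Module) -> None:
        # Recursively call reset_parameters on every submodule that defines it.
        def fn(m):
            if callable(getattr(m, 'reset_parameters', None)):
                m.reset_parameters()
        module.apply(fn)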

> why can't we leverage that to automatically call it during initialization as well?

Can't; we historically didn't specify whether you call the super constructor first or last in the constructor, so I don't think we can assume that it is called after the parameters are set up.

@@ -0,0 +1,330 @@
import inspect
Collaborator

To run the tests add the file name here:

TESTS = [

And verify the tests appear in the PR CI output.

jbschlosser (Contributor Author)

Thanks!

# Example: torch.nn.Linear -> ([10, 5], {})
def build_constructor_arg_db():
return {
torch.nn.AdaptiveAvgPool1d: ([5], {}),
Collaborator

If someone adds a new module, will they know all the lists they need to add it to?

jbschlosser (Contributor Author)

Any module in __all__ for the torch.nn modules (and the various quantization modules) will be tested here, and an error will be thrown indicating that an entry should be added (see the sketch below).
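Roughly, the enforcement works along these lines (a sketch of the mechanism, not the PR's exact code):

    for module_name in torch.nn.modules.__all__:
        module_cls = getattr(torch.nn, module_name)
        if module_cls not in constructor_arg_db:
            raise RuntimeError(f'Please add constructor args for {module_cls} to the DB')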

torch.nn.AdaptiveAvgPool2d: ([5], {}),
torch.nn.AdaptiveAvgPool3d: ([5], {}),
torch.nn.AdaptiveLogSoftmaxWithLoss: ([100, 20, [5, 10, 15]], {}),
torch.nn.AdaptiveMaxPool1d: ([5], {}),
Collaborator

Style nit (would not block PR on this or make this change if it's a nontrivial amount of work): prefer tuples to lists

jbschlosser (Contributor Author)

Any particular reason? I don't mind doing the work to make the change, but it makes it harder to read imo



# Instantiates the given class with the given args, kwargs, optionally on a given device.
def instantiate_class(cls, args, kwargs, device=None):
Collaborator

This is interesting. Why is device not part of kwargs?

jbschlosser (Contributor Author)

args / kwargs are used to instantiate each type of module, and it simplified the calls to just add in device separately when it's needed. But you're right that I could alternatively add device to kwargs
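For illustration, the shape of that helper (a sketch; the device merge is the only interesting part):

    def instantiate_class(cls, args, kwargs, device=None):
        # Merge the optional device into the constructor kwargs when given.
        if device is not None:
            kwargs = dict(kwargs, device=device)
        return cls(*args, **kwargs)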

Collaborator

Whatever's easiest.


# Returns a function that calls the real implementation of a method
# in addition to passing args to a mock object.
def mock_wrapper(method):
Collaborator

jbschlosser (Contributor Author)

I couldn't get wraps to work for this specific use case where I don't have a particular object. This wrapper is used below to hook into all parameter creations or buffer registrations. If you know how to use wraps for this, let me know and I'll change it.

Collaborator

Interesting. I suppose this can wait until it's an issue (and it may never be an issue).


# Returns a function that calls the real implementation of a method
# in addition to passing args to a mock object.
def mock_wrapper(method):
Collaborator

Why call the mock?

jbschlosser (Contributor Author)

Calling the mock logs that the method was called
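A sketch of the pattern being described, built on unittest.mock (the PR's exact helper may differ):

    from unittest.mock import MagicMock

    def mock_wrapper(method):
        # Wrap `method` so every call is recorded on an attached mock
        # while still running the real implementation.
        mock = MagicMock()
        def wrapper(self, *args, **kwargs):
            mock(*args, **kwargs)  # record the call for later assertions
            return method(self, *args, **kwargs)
        wrapper.mock = mock
        return wrapper

A test can then patch, e.g., nn.Module.register_parameter with mock_wrapper(nn.Module.register_parameter) and assert on wrapper.mock afterwards.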

if module_creates_params_or_buffers and module_cls not in MODULES_WITHOUT_KWARGS_SUPPORT:
args, kwargs = get_example_args(module_cls, constructor_arg_db, device=device)

# if module_cls in LAZY_MODULES:
Collaborator

What's up with this comment?

jbschlosser (Contributor Author)

This is testing logic for LazyTensors that should be put back in once meta tensor functionality is expanded. Not sure whether this PR should wait for that so the tests will be better?

Collaborator

I probably wouldn't wait (although you know better). I would just remove this comment (possibly replacing it with a TODO) or add a meta-comment explaining why this section is commented out.

jbschlosser (Contributor Author)

Luckily enough, the support is in now so I uncommented it!

module_cls = getattr(mod_namespace, module_name)
if module_cls in MODULES_TO_SKIP: continue

# Create a function to run the test and setattr it onto the test class.
mruberry (Collaborator) commented Mar 29, 2021

I'm not saying this should be the PR that creates ModuleInfo... but MAYBE this (or a follow-up) should be the PR that does it. In particular, generators like this, which are lengthy and complicated, were one of the reasons we wanted to switch to decorating test templates that look like more typical tests.

jbschlosser (Contributor Author)

I for one am excited for ModuleInfo! As discussed offline, we'll meet sometime shortly after this PR to design what that could look like.

jbschlosser (Contributor Author)

Also adding a TODO in the code to refactor the constructor arg DB into ModuleInfo.

test_name = f'test_{namespace_basename}_{module_name}'
setattr(TestModuleInit, test_name, run_test)

instantiate_device_type_tests(test_cls, globals())
Collaborator

I would put this call on the current line 327 for readability and consistency with other test files

return args, kwargs


def generate_tests(test_cls, constructor_arg_db):
Collaborator

Add a comment describing what this test tests

jbschlosser (Contributor Author)

Test logic is actually below in run_test(), and its checks are commented

jbschlosser (Contributor Author)

I split the logic out for generating tests for a single module, so the separation between that and iterating over all the modules should be clearer now

ezyang (Contributor) left a comment

Giving approval modulo testing; please look at mruberry's comments

ngimel (Collaborator) commented Apr 20, 2021

facebook-github-bot (Contributor)
This pull request has been reverted by 92d24e3.

facebook-github-bot (Contributor)
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot (Contributor)
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

ngimel (Collaborator) commented Apr 21, 2021

ASAN is failing in the same way on PR

jbschlosser (Contributor Author) commented Apr 21, 2021

> ASAN is failing in the same way on PR

Yep, sorry for the mess. I thought the land would have stopped if the problem was unfixed, but it went right through :/

Unlanding now and will make sure it's actually fixed before re-landing.

jbschlosser reopened this Apr 21, 2021
facebook-github-bot (Contributor)
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021
Summary:
Continuation of pytorch#53144

Pull Request resolved: pytorch#54508

Reviewed By: mrshenli

Differential Revision: D27600457

Pulled By: jbschlosser

fbshipit-source-id: b58bfee61c3917524b4622f63ef216c27a588eb1
krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021
Summary:
Continuation of pytorch#53144

Pull Request resolved: pytorch#54508

Reviewed By: bdhirsh

Differential Revision: D27855386

Pulled By: jbschlosser

fbshipit-source-id: dabd505d2a04208e74b158570fb2859c736eea2c
krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021
Summary:
Continuation of pytorch#53144

Pull Request resolved: pytorch#54508

Reviewed By: malfet

Differential Revision: D27909732

Pulled By: jbschlosser

fbshipit-source-id: d8684b2403ab7eb336371d118799146a2520bd76
krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021
Summary:
Continuation of pytorch#53144

Pull Request resolved: pytorch#54508

Reviewed By: albanD

Differential Revision: D27939544

Pulled By: jbschlosser

fbshipit-source-id: 4bf517e5f74f093e27ca38a85e732da65e44d805
acairncross added a commit to MyrtleSoftware/myrtle-vision that referenced this pull request May 11, 2022
acairncross added a commit to MyrtleSoftware/myrtle-vision that referenced this pull request May 11, 2022