
Conversation

Collaborator

@emcastillo commented Sep 11, 2020

Retake on #40493 after all the feedback from @albanD

This PR implements the generic Lazy mechanism and a sample LazyLinear layer with the UninitializedParameter.

The main differences from the previous PR are twofold:
1. torch.nn.Module remains untouched.
2. We no longer require an explicit initialization or a dummy forward pass before starting training or inference of the actual module, which makes this much simpler to use from the user's side.
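
For illustration, a minimal, hedged usage sketch (assuming the LazyLinear API added by this PR):

```python
import torch

# LazyLinear only needs the output size; in_features is inferred from the first input.
layer = torch.nn.LazyLinear(out_features=10)

x = torch.randn(32, 20)
y = layer(x)                # parameters are materialized here with in_features=20
print(layer.weight.shape)   # torch.Size([10, 20])
```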

As we discussed offline, there was a suggestion not to use a mixin but instead to change the __class__ attribute of LazyLinear so that it becomes Linear once it is completely initialized. While this can be useful, for the time being we need LazyLinear to be a torch.nn.Module subclass, since many checks rely on modules being instances of torch.nn.Module.
This can cause problems when we create complex modules such as:

```python
class MyNetwork(torch.nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.conv = torch.nn.Conv2d(20, 4, 2)
        self.linear = torch.nn.LazyLinear(10)

    def forward(self, x):
        y = self.conv(x).clamp(min=0)
        return self.linear(y)
```

Here, when __setattr__ is called at the time LazyLinear is registered, it would not be added to the child modules of MyNetwork, so we would have to add it manually later; but currently there is no way to do that, since we can't access the parent module from LazyLinear once it becomes the Linear module. (We can add a workaround for this if needed.)
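
As a hedged illustration of the registration issue (NotAModule is a hypothetical stand-in for a lazy layer that is not an nn.Module subclass, not code from this PR), Module.__setattr__ only tracks nn.Module instances as child modules:

```python
import torch

class NotAModule:
    """Hypothetical placeholder that is not a torch.nn.Module subclass."""
    pass

net = torch.nn.Module()
net.sub = torch.nn.Linear(2, 3)   # registered as a child module by Module.__setattr__
net.other = NotAModule()          # stored as a plain attribute only

print(list(net._modules))         # ['sub'] -> 'other' is not tracked as a child
```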

TODO:

- Add convolutions once the design is OK
- Fix docstrings

@emcastillo requested a review from apaszke as a code owner September 11, 2020 06:07
@facebook-github-bot added the oncall: jit (Add this issue/PR to JIT oncall triage queue) label Sep 11, 2020
Collaborator Author

This needs more debugging and it might not be needed.

Collaborator

What's the plan to check this? Shall we ask someone from the jit side to take a look?

Collaborator

cc @suo

Collaborator Author

would be nice 😅

Member

This seems fine—from what I understand, after the first forward we'll have fully initialized the parameters and the LazyModule will be indistinguishable from a regular nn.Module, right? If so, then it's fine to just error and ask the user to run forward.

@dr-ci

dr-ci bot commented Sep 11, 2020

💊 CI failures summary and remediations

As of commit 13cd65b (more details on the Dr. CI page):


  • 3/3 failures possibly* introduced in this PR
    • 1/3 non-CircleCI failure(s)

2 failures not recognized by patterns:

  • CircleCI pytorch_linux_xenial_py3_clang7_onnx_ort_test1 (step: Run tests)
  • CircleCI pytorch_linux_xenial_py3_clang7_onnx_ort_test2 (step: Run tests)

ci.pytorch.org: 1 failed



@ngimel added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label and removed the oncall: jit (Add this issue/PR to JIT oncall triage queue) label Sep 15, 2020
@dzhulgakov self-requested a review September 15, 2020 21:31
Collaborator

@albanD left a comment

Thanks for updating the PR!

We had a chance to sit down with @mruberry and @ngimel to discuss the PR.
We do think it is much "safer" than the previous one and is looking quite good. See the inline comments below.

One question I have is whether it would be possible to remove the LazyModuleMixin from the __class__ after we have run the hook and everything is initialized. After this point, we don't need the custom loading logic at all, right?
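
A hedged sketch of the idea being asked about; the helper name and the cls_to_become argument are illustrative, not necessarily what this PR implements:

```python
import torch

def maybe_drop_lazy_class(module, cls_to_become):
    # Once every parameter is materialized, reassign __class__
    # (e.g. LazyLinear -> Linear) so no custom loading logic remains.
    still_lazy = any(
        isinstance(p, torch.nn.parameter.UninitializedParameter)
        for p in module.parameters()
    )
    if not still_lazy:
        module.__class__ = cls_to_become
```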

Collaborator

The location for this docstring might not be great as users won't check this doc in general but only the LazyLinear one (or corresponding Module they use).

If we only do LazyLinear for now, we should just move this to that docstring.
When we add convolution later, maybe this shared part can be moved to a note in the doc that each lazy module we implement refers to.

Also, an important part of this doc should be when a dummy forward pass is needed and when it is not.
In particular, it is not required before creating optimizers or serializing, but it is unclear what should be done when the user wants to do custom initialization of their weights, or use DataParallel, DistributedDataParallel, or reparametrizations like nn.utils.weight_norm().
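
For example, a hedged sketch of the custom-initialization case mentioned above, where a dummy forward pass materializes the parameters first:

```python
import torch

m = torch.nn.LazyLinear(10)
m(torch.randn(2, 20))                     # dummy forward pass materializes weight/bias
torch.nn.init.xavier_uniform_(m.weight)   # custom initialization is now well-defined
```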

Contributor

Agreed, having the doc on when we need to do a dummy forward pass is super important for setting expectations :)

@emcastillo
Collaborator Author

Thanks for the awesome review as always, will work on it and send changes soon 😇

@yf225
Contributor

yf225 commented Sep 25, 2020

Thanks a lot for working on this @emcastillo! This would be really useful, as we can take advantage of this lazy-init concept for one of our internal PyTorch libraries as well.

One question I have is: do you know how it would interact with the Module.to() API? Would it error out if the parameters are not initialized yet?

@emcastillo
Collaborator Author

We allow moving the parameters to a different device before materializing them, so it is safe to call .to() before initialization :)
They will be instantiated on the correct device.
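
A minimal sketch of that behavior (a hedged example assuming the LazyLinear API from this PR and an available CUDA device):

```python
import torch

m = torch.nn.LazyLinear(10)
m.to('cuda')                               # safe to call before materialization
y = m(torch.randn(4, 20, device='cuda'))   # parameters are materialized on CUDA here
print(m.weight.device)                     # cuda:0
```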

@emcastillo
Collaborator Author

I think I addressed all the comments.
We now forbid calling param.shape, fixed the tests and docs, and also make the module change its class to Linear once it is initialized.
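
A hedged sketch of the behavior described above (assuming this PR's LazyLinear semantics):

```python
import torch

m = torch.nn.LazyLinear(4)
# m.weight.shape              # would raise: uninitialized parameters expose no shape
m(torch.randn(1, 3))          # the first forward pass materializes the parameters
print(type(m))                # the module now reports itself as torch.nn.Linear
```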

Collaborator

@albanD left a comment

This looks quite good to me.
One question I still have beyond the inline comments is how this interacts with DataParallel and DistributedDataParallel. If they don't work, we should raise an error, or at least document it very clearly if raising an error is tricky.

Collaborator

@dzhulgakov left a comment

This looks really good! Some really minor comments, but I think it's close to landing

@emcastillo
Collaborator Author

emcastillo commented Oct 1, 2020

Hi, I've been debugging the use of DataParallel with the module and I found that we require an explicit dummy forward pass.
This is because when the module is replicated, all the parameters are copied as tensors, so we lose the ability to magically convert an uninitialized parameter into a regular parameter; it remains an empty tensor that can't be materialized.

https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/replicate.py#L144-L146

Solutions to this are:

  1. Error when a Lazy module is passed to DataParallel and ask the user to do a dummy forward pass first (this requires the user to know how the input tensor will be split across the replicas). A sketch of this workaround follows the list.

  2. Do the dummy forward pass in DataParallel itself before replicating the module: if the module is lazy and has uninitialized parameters, we do this in the forward call. (Con: tricky and too magical.)

  3. Add an UninitializedBuffer (Tensor) class that wraps tensors that were originally UninitializedParameters when calling the replicate function. This would solve the issue and make DataParallel work without any other changes.
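
A hedged sketch of the workaround in option 1, a user-side dummy forward pass before wrapping (assuming the LazyLinear API from this PR and a multi-GPU machine):

```python
import torch

m = torch.nn.LazyLinear(10).to('cuda')
m(torch.randn(8, 20, device='cuda'))   # dummy forward pass materializes the parameters
dp = torch.nn.DataParallel(m)          # replication now copies ordinary Parameters
out = dp(torch.randn(8, 20, device='cuda'))
```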

@albanD
Collaborator

albanD commented Oct 1, 2020

Option 3 sounds tricky because you would need to keep the autograd link between the parameter and the buffer.
I think that option 1 is a safe alternative for now.

@mruberry
Collaborator

Hey @emcastillo! You're so fast I'm afraid you resolved some of my comments before I could even finish editing them. Please take a close look because I've made some updates and, in a few places, you'll need to be more careful than I was about formatting.

I believe I've reviewed all the documentation now. I also updated some other conversations with our most recent thinking for clarity. I think after another round of docs revisions we should be in good shape. I'm looking forward to seeing this PR land.

Please let me know if you have additional questions and ping me when you've had an opportunity to review and update the docs.

@emcastillo
Collaborator Author

Haha, I will give it another pass :).
Thanks a lot for taking all this time to correct these things, I really appreciate all the help.

Contributor

@facebook-github-bot left a comment

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Collaborator

@mruberry left a comment

Looking forward to having the start of lazy modules in PyTorch!

Contributor

@facebook-github-bot left a comment

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@albanD merged this pull request in d38a71d.

@emcastillo
Collaborator Author

Thanks a lot everyone for all the help and bringing this to completion!
I will start adding convolutions and batch normalization!

facebook-github-bot pushed a commit that referenced this pull request Dec 1, 2020
Summary:
This PR implements LazyConvXd and LazyConvTransposeXd based on #44538. (cc. emcastillo and albanD)

Pull Request resolved: #47350

Reviewed By: ejguan

Differential Revision: D25220645

Pulled By: albanD

fbshipit-source-id: b5e2e866d53761a3415fd762d05a81920f8b16c3
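
A hedged usage sketch of the lazy convolutions added there (assuming the LazyConv2d API from #47350):

```python
import torch

conv = torch.nn.LazyConv2d(out_channels=8, kernel_size=3)
y = conv(torch.randn(1, 5, 32, 32))   # in_channels=5 is inferred on the first forward
print(conv.weight.shape)              # torch.Size([8, 5, 3, 3])
```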
@asi1024 mentioned this pull request Feb 2, 2021
facebook-github-bot pushed a commit that referenced this pull request Feb 5, 2021
Summary:
This PR implements UninitializedBuffer and LazyBatchnormXd based on #44538. (cc. emcastillo and albanD)

Pull Request resolved: #51548

Reviewed By: zhangguanheng66

Differential Revision: D26276903

Pulled By: albanD

fbshipit-source-id: 0ac706974178363f8af075e59b41d5989418922f
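
A hedged usage sketch of the lazy batch norm added there (assuming the LazyBatchNorm2d API from #51548):

```python
import torch

bn = torch.nn.LazyBatchNorm2d()
y = bn(torch.randn(2, 7, 8, 8))    # num_features=7 is inferred on the first forward
print(bn.running_mean.shape)       # torch.Size([7])
```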
facebook-github-bot pushed a commit that referenced this pull request Feb 18, 2021
Summary:
Some minor improvements for lazy modules introduced in #44538, #47350 and #51548.

This PR mainly turns the bias into an `UninitializedParameter`; instead of creating empty tensors like
```python
self.bias = Parameter(torch.Tensor(0))
self.bias = UninitializedParameter()
```
I think it would be better to
```python
self.register_parameter('bias', None)
self.bias = UninitializedParameter()
```

In addition, I change the constructor of the `LazyBatchNorm` from
```python
self.running_mean = UninitializedBuffer()
```
to
```python
self.register_buffer('running_mean', UninitializedBuffer())
```
as the original one would not change the underlying `self._buffers`.

Thank you for your time on reviewing this PR :).

Gently ping albanD, mruberry

Pull Request resolved: #52212

Reviewed By: jbschlosser

Differential Revision: D26504508

Pulled By: albanD

fbshipit-source-id: 7094d0bb4fa9e2a40a07b79d350ea12a6ebfd080
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021