torch.nn.modules.LazyModuleMixin and torch.nn.LazyLinear (Shape Inference II) #44538
Conversation
torch/jit/_recursive.py
This needs more debugging, and it might not be needed.
What's the plan to check this? Shall we ask someone from the jit side to take a look?
cc @suo
would be nice 😅
This seems fine: from what I understand, after the first forward we'll have fully initialized the parameters and the LazyModule will be indistinguishable from a regular nn.Module, right? If so, then it's fine to just error and ask the user to run forward.
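For context, a small sketch of the behaviour described here, assuming the lazy-module API as it eventually landed (the exact error message is not taken from the PR):

```python
import torch
import torch.nn as nn

m = nn.LazyLinear(4)
try:
    torch.jit.script(m)           # expected to error while parameters are uninitialized
except Exception as err:
    print("scripting before forward:", err)

m(torch.randn(2, 3))              # first forward materializes weight (4, 3) and bias
scripted = torch.jit.script(m)    # fine now: m behaves like a plain nn.Linear
```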
Thanks for updating the PR!
We had a chance to sit down with @mruberry and @ngimel to discuss the PR. We do think it is much "safer" than the previous one and it is looking quite good. See the inline comments below.
One question I would have is whether it would be possible to remove the LazyModuleMixin from the __class__ after we have run the hook and everything is initialized. After this point, we don't need the custom loading logic at all, right?
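For illustration, a rough sketch of the idea being asked about (an assumption about the mechanism, not code from this PR):

```python
import torch
import torch.nn as nn

layer = nn.LazyLinear(8)
print(type(layer).__name__)        # LazyLinear: still carries the lazy machinery
layer(torch.randn(2, 3))           # pre-forward hook materializes weight and bias
# The swap under discussion is conceptually just a __class__ rebind once
# everything is initialized, e.g. layer.__class__ = nn.Linear, after which no
# custom loading logic remains attached to the instance.
print(type(layer).__name__)        # Linear (in the version that eventually landed)
```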
torch/nn/modules/lazy.py
The location for this docstring might not be great, as users won't check this doc in general but only the LazyLinear one (or the corresponding Module they use). If we only do LazyLinear for now, we should just move this to that docstring. When we add convolutions later, maybe this shared part can be moved to a note in the doc that each lazy module we implement refers to.
Also, an important part of this doc should be when we need to do a dummy forward pass or not. In particular, it is not required before creating optimizers or serializing, but it is unclear what should be done when the user wants to do custom initialization of their weights, use DataParallel or DistributedDataParallel, or use a reparametrization like nn.utils.weight_norm().
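As a point of reference for the custom-initialization case, one possible pattern is sketched below (this assumes that a dummy forward pass fully materializes the parameters; it is not something settled in this thread, and the shapes are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.LazyLinear(64), nn.ReLU(), nn.LazyLinear(10))
model(torch.zeros(1, 128))          # dummy batch infers in_features and materializes params

def init_weights(m):
    # After materialization the lazy layers behave like regular nn.Linear modules.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

model.apply(init_weights)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # parameters now have concrete shapes
```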
Agreed, having the doc on when we need to do a dummy forward pass is super important for setting expectations :)
Thanks for the awesome review as always, will work on it and send changes soon 😇
Thanks a lot for working on this @emcastillo! This would be really useful, as we can take advantage of this lazy-init concept for one of our internal PyTorch libraries as well. One question I have is: do you know how it would interact with the .to() call?
We allow moving the parameters to a different device before materializing them, so it is safe to call .to before the initialization :)
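A small usage sketch of what this enables (assuming the LazyLinear API as it landed and that a CUDA device is available; shapes are illustrative):

```python
import torch
import torch.nn as nn

layer = nn.LazyLinear(8)                        # in_features still unknown
layer = layer.to('cuda')                        # allowed before materialization
x = torch.randn(4, 16, device='cuda')
y = layer(x)                                    # weight and bias are created directly on CUDA
print(layer.weight.shape, layer.weight.device)  # torch.Size([8, 16]) cuda:0
```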
Force-pushed from 4b39db6 to 77fb058.
I think I addressed all the comments.
This looks quite good to me.
One question I still have beyond the inline comments is how this interacts with DataParallel and DistributedDataParallel. If they don't work together, we should raise an error, or at least document it very clearly if raising an error is tricky.
This looks really good! Some really minor comments, but I think it's close to landing
Hi, I've been debugging the use of DataParallel with the module and I saw that we require an explicit dummy forward pass (https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/replicate.py#L144-L146). Solutions to this are:
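For context, the explicit dummy forward pass referred to above looks roughly like this (a sketch of the workaround, not of any particular proposed solution; assumes a CUDA machine):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.LazyLinear(32), nn.ReLU(), nn.LazyLinear(10)).to('cuda')
model(torch.zeros(1, 128, device='cuda'))   # materialize parameters before replication
model = nn.DataParallel(model)              # replicate() now has concrete tensors to broadcast
out = model(torch.randn(8, 128, device='cuda'))
```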
Hey @emcastillo! You're so fast I'm afraid you resolved some of my comments before I could even finish editing them. Please take a close look because I've made some updates and, in a few places, you'll need to be more careful than I was about formatting. I believe I've reviewed all the documentation now. I also updated some other conversations with our most recent thinking for clarity. I think after another round of docs revisions we should be in good shape. I'm looking forward to seeing this PR land. Please let me know if you have additional questions and ping me when you've had an opportunity to review and update the docs.
hahah, I will give another pass :)
Force-pushed from 3ad42a7 to b0518f1.
Force-pushed from e368946 to 1d3d0ab.
@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Looking forward to having the start of lazy modules in PyTorch!
@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Thanks a lot everyone for all the help and bringing this to completion!
Summary: Some minor improvements for the lazy modules introduced in #44538, #47350 and #51548. This PR mainly turns the bias into an `UninitializedParameter`, and instead of creating empty tensors like
```python
self.bias = Parameter(torch.Tensor(0))
self.bias = UninitializedParameter()
```
I think it would be better to
```python
self.register_parameter('bias', None)
self.bias = UninitializedParameter()
```
In addition, I change the constructor of `LazyBatchNorm` from
```python
self.running_mean = UninitializedBuffer()
```
to
```python
self.register_buffer('running_mean', UninitializedBuffer())
```
as the original one would not change the underlying `self._buffers`. Thank you for your time on reviewing this PR :). Gently ping albanD, mruberry
Pull Request resolved: #52212
Reviewed By: jbschlosser
Differential Revision: D26504508
Pulled By: albanD
fbshipit-source-id: 7094d0bb4fa9e2a40a07b79d350ea12a6ebfd080
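A small illustration (not part of the PR) of the `self._buffers` point above, assuming the `UninitializedBuffer` API from #51548:

```python
import torch.nn as nn
from torch.nn.parameter import UninitializedBuffer

class WithRegister(nn.Module):
    def __init__(self):
        super().__init__()
        # Goes through nn.Module.register_buffer, so it ends up in self._buffers.
        self.register_buffer('running_mean', UninitializedBuffer())

class WithAssign(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute assignment: an UninitializedBuffer is not a Parameter,
        # so __setattr__ stores it as an ordinary attribute, not as a buffer.
        self.running_mean = UninitializedBuffer()

print('running_mean' in dict(WithRegister().named_buffers()))  # True
print('running_mean' in dict(WithAssign().named_buffers()))    # False
```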
Retake on #40493 after all the feedback from @albanD.
This PR implements the generic Lazy mechanism and a sample LazyLinear layer with the UninitializedParameter. The main differences with the previous PR are two:
1. Now torch.nn.Module remains untouched.
2. We don't require an explicit initialization or a dummy forward pass before starting training or inference of the actual module, making this much simpler to use from the user side.
As we discussed offline, there was the suggestion of not using a mixin, but changing the __class__ attribute of LazyLinear to become Linear once it is completely initialized. While this can be useful, for the time being we need LazyLinear to be a torch.nn.Module subclass, since there are many checks that rely on modules being instances of torch.nn.Module. This can cause problems when we create complex modules such as the one sketched below.
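A minimal sketch of the kind of module meant here (the exact snippet from the PR is not reproduced; layer sizes and attribute names other than MyNetwork are illustrative):

```python
import torch
import torch.nn as nn

class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # If LazyLinear were not an nn.Module subclass, __setattr__ would not
        # record it in MyNetwork's child modules at this point.
        self.fc1 = nn.LazyLinear(128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))
```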
Here, when the setattr function is called at the time LazyLinear is registered, it won't be added to the child modules of MyNetwork (since it would not be an nn.Module instance), so we would have to do it manually later. But currently there is no way to do such a thing, as we can't access the parent module from LazyLinear once it becomes the Linear module. (We can add a workaround for this if needed.)
TODO:
Add convolutions once the design is OK
Fix docstrings