expanded weights without fast rules #70140
Conversation
CI Flow Status ⚛️

You can add a comment to the PR and tag @pytorchbot with the following commands:

```
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun
# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent
# to adding these labels manually and triggering the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow
```

For more information, please take a look at the CI Flow Wiki.
💊 CI failures summary and remediations

As of commit f6fa517 (more details on the Dr. CI page):

🕵️ 2 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
Need to run, but here are some initial comments.
```python
@property
def shape(self):
    return self.orig_weight.shape
```
What happens if you call `expanded_weight.size()`? Does that return the correct thing?
Added a `size()` function.
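For illustration, a minimal sketch of what the combined overrides might look like, stitched together from the snippets in this thread using the `_make_subclass` constructor pattern; this is not the PR's exact code:

```python
import torch

class ExpandedWeight(torch.Tensor):
    # sketch only: a Tensor subclass that reports the wrapped weight's geometry
    @staticmethod
    def __new__(cls, orig_weight, batch_size):
        return torch.Tensor._make_subclass(cls, orig_weight.detach())

    def __init__(self, orig_weight, batch_size):
        self.batch_size = batch_size
        self.orig_weight = orig_weight

    @property
    def shape(self):
        return self.orig_weight.shape

    def size(self, dim=None):
        # mirrors torch.Tensor.size(), which also accepts an optional dim
        if dim is None:
            return self.orig_weight.size()
        return self.orig_weight.size(dim)

w = ExpandedWeight(torch.randn(3, 4), batch_size=8)
print(w.shape, w.size(), w.size(0))  # all report the wrapped weight's geometry
```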
Follow-up question (maybe more for @albanD): is it user error to pass a `__torch_function__` tensor subclass into some C++ code that requires a Tensor? For the PyTorch frontend API the answer is probably no, because it should enter `__torch_function__`, but if users have things like custom C++ operators that they've pybind'ed into Python, do we consider this to be a user error?
It is not an error, and the code will use the underlying Tensor as-is. This is how `nn.Parameter` works today.

But it is a subtlety that the user needs to be aware of: as soon as you pass the frontend API binding or enter any other C++ function, only the C++ Tensor associated with your subclass will exist.
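To illustrate the subtlety, here is a minimal `__torch_function__` subclass (a hypothetical `LoggingTensor`, built with the standard subclassing pattern). Calls through the Python frontend are intercepted; any C++ code that receives the tensor would see only the underlying `Tensor`:

```python
import torch

class LoggingTensor(torch.Tensor):
    # hypothetical subclass for illustration only
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"intercepted: {func.__name__}")
        return super().__torch_function__(func, types, args, kwargs)

t = torch.Tensor._make_subclass(LoggingTensor, torch.randn(3))
torch.add(t, 1)  # the Python frontend enters __torch_function__ before reaching C++
```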
[Design Doc for Expanded Weights](https://gist.github.com/samdow/fa0a164fec7963f93ff45284989cfc55) <-- gives an overview of the design for Expanded Weights

Introduces the ExpandedWeights mechanism and user-facing API without any custom-implemented, faster rules.
- User-facing API is in `_stateless.py` (with documentation)
- Testing is in test_expanded_weights (tests in this version only test the fallback and one module, which will also call the fallback)
- The rest is the implementation of the slow fallback + the mechanism for registering faster per-sample-grad rules. None of the faster rules are implemented here, but they are all implemented in #70141
torch/nn/utils/_per_sample_grad.py (Outdated)

```python
# dependency on `functional_call` means that this can't be exposed in utils
# without creating a circular dependency
def per_sample_call(module, batch_size, args, kwargs=None):
```
nit: we might want to bikeshed this name some more. I would be really confused if I saw `per_sample_call(model)` in someone's code if I didn't already know about per-sample gradients.
Yeah, I definitely see that. What about `call_with_per_sample_gradients`?
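For context, a hypothetical usage sketch based on the signature in the diff above. The import path comes from the file shown earlier; whether `args` can be a single tensor and whether per-sample grads land in a `.grad_sample` attribute are assumptions about this design, not settled API:

```python
import torch
import torch.nn as nn
from torch.nn.utils._per_sample_grad import per_sample_call  # path from the diff

model = nn.Linear(4, 2)
batch = torch.randn(8, 4)

out = per_sample_call(model, batch.shape[0], batch)
out.sum().backward()

for p in model.parameters():
    # assumption: per-sample grads of shape (batch_size, *p.shape)
    print(p.grad_sample.shape)
```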
The code in this PR looks good to me, but I have a suggestion around organization and testing. This PR introduces:
1. The ExpandedWeight object
2. The per_sample_grad API
3. A lot of helper functions

Only (1) is being tested in test_expanded_weights.py. (2) and (3) are probably tested in the next PR in the stack (I haven't looked yet), but in general each PR in a stack should be able to stand on its own. Maybe we should include the per-sample-grad rule for a simple layer (like linear) in this PR so we can test the per_sample_call API as well as the helper functions and demonstrate how everything works end-to-end.

Regarding the ExpandedWeight object: is there a list of common Tensor attributes to override somewhere? I think we are missing `stride()`, `is_contiguous()`, and `memory_format()`, but there might be more. It might be good to find some prominent user of `__torch_function__` (I don't know of any) and see what attributes they override.
@zou3519 The remaining test failures look unrelated to the PR. I can try to look at CI and rebase if it's green. Added in this update:
TODO:
- check to see if `_make_wrapper_subclass` works
- (rzou): aux_output, num_true_outputs is a bit weird; check next stack up
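On the first TODO item, a minimal sketch of the `_make_wrapper_subclass` pattern (constructor only; a real wrapper would also need `__torch_function__`/`__torch_dispatch__` handling, which is elided here, and the class name is illustrative):

```python
import torch

class ExpandedWeightWrapper(torch.Tensor):
    # unlike _make_subclass, the wrapper allocates no storage of its own,
    # so the original weight's memory should not be duplicated
    @staticmethod
    def __new__(cls, orig_weight, batch_size):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            orig_weight.size(),
            dtype=orig_weight.dtype,
            device=orig_weight.device,
            requires_grad=orig_weight.requires_grad,
        )

    def __init__(self, orig_weight, batch_size):
        self.orig_weight = orig_weight
        self.batch_size = batch_size
```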
```python
results = []
# ...
results.append(grad_if_exists_for_input(
    input, lambda: grad_output.matmul(unpack_expanded_weight_or_tensor(weight))))
results.extend([None] * 3)  # weight and bias don't compute batched gradients
```
I'm confused: why 3? This means we're returning a total of 4 values from backward, right?
`input`, `weight`, `bias`, and `kwarg_names` are the inputs.
The update hopefully made this easier to understand: `kwarg_names` is now at the front, so the `None` for it gets added at the start instead of here.
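A toy `autograd.Function` (not the PR's code) illustrating the convention under discussion: `backward` must return exactly one value per `forward` input, with `None` for the non-differentiable ones, so putting `kwarg_names` first moves its `None` to the front:

```python
import torch

class ToyLinearStyle(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kwarg_names, input, weight, bias):
        ctx.save_for_backward(weight)
        out = input.matmul(weight.t())
        return out + bias if bias is not None else out

    @staticmethod
    def backward(ctx, grad_output):
        weight, = ctx.saved_tensors
        # four forward inputs -> four returned values: None for kwarg_names
        # (now at the front), a real grad for input, None for weight and bias
        return None, grad_output.matmul(weight), None, None

x = torch.randn(5, 3, requires_grad=True)
out = ToyLinearStyle.apply([], x, torch.randn(4, 3), torch.randn(4))
out.sum().backward()  # populates x.grad; weight/bias grads are handled elsewhere
```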
```python
class ExpandedWeight(torch.Tensor):
    def __init__(self, orig_weight, batch_size):
        self.batch_size = batch_size
        self.orig_weight = orig_weight
```
The ExpandedWeight object already has orig_weight's data, and then we're also assigning orig_weight to it as an attribute. So this effectively stores the parameter twice.
```python
expanded_args_without_kwargs = expanded_args[:2]
output, aux_outputs = forward_helper(F.linear, expanded_args_without_kwargs, expanded_kwargs, 1)
ctx.args = expanded_args_without_kwargs
ctx.kwargs = expanded_kwargs
ctx.aux_outputs = aux_outputs
```
Okay, the reason why I think these lines tripped me up is that they are generic, and that makes them difficult to read.

Roughly speaking, here's what we're doing in the forward pass of this (and all the autograd.Functions for per-sample-grads):
- run f(*unexpanded_args, **unexpanded_kwargs) (or some hacked version of f, if we need intermediate values)
- save values for backward: the required unexpanded arguments and intermediates
- return the output (as opposed to all intermediate values)

Now, the reason why this code is confusing is that it's not clear what is being saved. We are saving all the args and kwargs, but we're also saving "aux outputs", which turn out to be nothing in this case.

It would make sense for this to be generic if we planned to refactor all of the autograd.Function forward passes to look the same. Is that a good idea?

For F.linear there is actually no need to save bias for the backward pass, and if input does not require gradient, then there is no need to save the input! (This doesn't need to happen in this PR, but those are potential optimizations.) So we might not want all the autograd.Function forward passes to look similar, especially because there are already checks specific to F.linear here.

If we decide that we want this to look generic, I'd probably recommend aux_outputs be renamed to intermediates.

If we decide we don't want this to look generic, it could read better as:

```python
output, = forward_helper(F.linear, expanded_args, expanded_kwargs)
ctx.unexpanded_weight = expanded_args[1]
return output
```

For e.g. group_norm this could look like:

```python
output, mean, rstd = forward_helper(torch.ops.aten.native_group_norm, expanded_args, expanded_kwargs)
ctx.mean = mean
ctx.rstd = rstd
return output
```

The benefit of the non-generic form is that one doesn't have to go deep diving into the backward() part of the autograd.Function to see exactly what args, kwargs, and aux_outputs are.
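To make the linear case concrete, here is a standalone sketch of the per-sample gradient math for `F.linear` with a 2-D input; it is just the math under discussion, not the PR's actual rule:

```python
import torch

def linear_per_sample_grads(input, weight, grad_output):
    # input: (batch, in), weight: (out, in), grad_output: (batch, out)
    grad_input = grad_output.matmul(weight)                       # (batch, in)
    grad_weight = torch.einsum("bo,bi->boi", grad_output, input)  # (batch, out, in)
    grad_bias = grad_output                                       # (batch, out)
    return grad_input, grad_weight, grad_bias

# quick check against autograd on a batch of one
x, w = torch.randn(1, 3), torch.randn(2, 3, requires_grad=True)
torch.nn.functional.linear(x, w).sum().backward()
_, gw, _ = linear_per_sample_grads(x, w.detach(), torch.ones(1, 2))
assert torch.allclose(gw[0], w.grad)
```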
@zou3519 All the comments make sense! I moved the output unpacking to be function-specific and added the comments. Per offline discussion,
Cool!

We should probably file a follow-up issue to see if we actually duplicate the memory when using `_make_subclass`.
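A throwaway check sketch for that follow-up, using the `_make_subclass` constructor mentioned above (the subclass name is a stand-in): if the subclass shares storage with the original weight, the data pointers match.

```python
import torch

class EW(torch.Tensor):  # stand-in subclass for the check
    pass

w = torch.randn(3, 4)
ew = torch.Tensor._make_subclass(EW, w)
# True would mean the storage is shared rather than duplicated
print(ew.data_ptr() == w.data_ptr())
```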
@samdow has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
[Design Doc for Expanded Weights](https://gist.github.com/samdow/fa0a164fec7963f93ff45284989cfc55) <-- gives an overview of the design for Expanded Weights

Introduces the ExpandedWeights mechanism and user-facing API without any custom-implemented, faster rules.
- User-facing API is in `_stateless.py` (with documentation)
- Testing is in test_expanded_weights
- The rest is the implementation of the erroring fallback + the mechanism for registering faster per-sample-grad rules. Only linear is implemented here; the rest are implemented in #70141

Differential Revision: [D34350950](https://our.internmc.facebook.com/intern/diff/D34350950)
Summary:
Pull Request resolved: #70140

[Design Doc for Expanded Weights](https://gist.github.com/samdow/fa0a164fec7963f93ff45284989cfc55) <-- gives an overview of the design for Expanded Weights

Introduces the ExpandedWeights mechanism and user-facing API without any custom-implemented, faster rules.
- User-facing API is in `_stateless.py` (with documentation)
- Testing is in test_expanded_weights
- The rest is the implementation of the erroring fallback + the mechanism for registering faster per-sample-grad rules. Only linear is implemented here; the rest are implemented in #70141

Test Plan: Imported from OSS
Reviewed By: mikaylagawarecki
Differential Revision: D34350950
Pulled By: samdow
fbshipit-source-id: 69c664b0bc3dff6951358d79d7e5d94882f7aef2
Hey @samdow.
Added "not user facing" because we don't add release notes for prototype features. Proper tags to be added when this becomes beta.
Stack from ghstack:

[Design Doc for Expanded Weights](https://gist.github.com/samdow/fa0a164fec7963f93ff45284989cfc55) <-- gives an overview of the design for Expanded Weights

Introduces the ExpandedWeights mechanism and user-facing API without any custom-implemented, faster rules.
- User-facing API is in `_stateless.py` (with documentation)

Differential Revision: D34350950