[aota] Allow some mutations in backward #128409
Conversation
requires_grad: bool
keep_input_mutations: bool
# JointFn mutation info; filled in only after jointFn tracing.
joint_mutates_data: Optional[bool] = None
We only actually need to support joint_mutates_data
- if any of the other 3 types of mutation below happen during the backward, we should raise an error during tracing and say we don't support it (we technically could support it, but metadata mutation / storage mutation are all kind of niche, and supporting them in the backward adds complexity that we probably don't care about right now)
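A minimal sketch of that tracing-time guard; `info` stands in for a per-input metadata record, and the flag names below are assumptions drawn from this discussion, not the PR's actual code:

```python
# Sketch only: data mutations in the backward are allowed; any other
# mutation kind raises immediately during joint tracing.
def reject_unsupported_backward_mutations(idx: int, info) -> None:
    unsupported = {
        "metadata mutation": info.mutates_metadata,
        "storage metadata mutation": info.mutates_storage_metadata,
        "mutation hidden from autograd": info.mutations_hidden_from_autograd,
    }
    for kind, happened in unsupported.items():
        if happened:
            raise RuntimeError(
                f"Input {idx} saw a {kind} during the backward; only data "
                "mutations are supported in the backward."
            )
```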
If these are only going to be set when we need autograd (i.e. trace_joint = True), what do you think about keeping them in an optional list of indices or bools instead of directly on InputInfo? That way you can avoid needing to assert that the joint_mutates_* is not None each time.
Agreed, let's just make it a (non-optional?) list of indices, similar to indices_of_inps_to_detach (code).
This is actually pretty important for runtime perf: models can have 100+ inputs to the forward and backward graph, and looping through all 100 of them can have noticeable runtime overhead.
We can add index i to the list only if input i gets a backward mutation and it requires grad. That way, our assertion at runtime only has to check if the list is non-empty to raise an error (but it can still print the actual indices for a good error message).
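A sketch of the trace-time construction this describes; `input_infos` and the flag names are assumptions standing in for the real metadata:

```python
def compute_bw_mutation_error_indices(input_infos):
    # An input lands in the list only if it is mutated in the backward
    # (and not already in the forward) and requires grad.
    return [
        i
        for i, info in enumerate(input_infos)
        if info.joint_mutates_data and not info.mutates_data and info.requires_grad
    ]
```

The runtime hot path then reduces to a single emptiness check; the stored indices are used only to build a helpful error message on the failure path.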
# Backward with forward input mutations is not supported in double backward.
if torch.is_grad_enabled() and any(
i.joint_mutates_data and not i.mutates_data
hmm, two things about this assertion:
(1) if we have a graph that mutates a tensor that does not require grad during the backward, this assert will fail incorrectly (even if is_grad_enabled() == True during the backward and we see an input mutation, as long as that input does not require grad it is ok to keep it in the graph)
(2) this assert will never fire if we have a tensor that is mutated in both the forward and the backward, but that seems wrong.
About (2): potentially, in the future we could retrace backward_module after partitioning and use the version counter to determine whether the backward mutated a tensor that was also mutated in the forward.
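A sketch of the version-counter idea, assuming we snapshot each input's internal `_version` after the forward and compare after the backward (illustrative only, not the PR's code):

```python
import torch

def snapshot_versions(tensors):
    # torch bumps a tensor's internal `_version` on every in-place mutation.
    return [t._version for t in tensors]

def indices_mutated_since(tensors, snapshot):
    return [
        i for i, (t, v) in enumerate(zip(tensors, snapshot)) if t._version != v
    ]

# Example: an in-place add bumps the version counter.
x = torch.zeros(3)
snap = snapshot_versions([x])
x.add_(1)
assert indices_mutated_since([x], snap) == [0]
```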
# This class tells us info about user inputs.
@dataclass(frozen=True)
Could we keep this frozen and use dataclasses.replace to add these values instead? Unless the copy is going to be too much of a perf hit, the benefit of keeping the data structure immutable will be really helpful. Most pre-compile wrappers create a new copy of ViewAndMutationData anyway.
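A minimal sketch of the pattern, using an illustrative frozen dataclass (class and field names assumed):

```python
import dataclasses
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class InputInfoSketch:
    requires_grad: bool
    joint_mutates_data: Optional[bool] = None

info = InputInfoSketch(requires_grad=True)
# info.joint_mutates_data = True   # would raise dataclasses.FrozenInstanceError
info = dataclasses.replace(info, joint_mutates_data=True)
assert info.joint_mutates_data
```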
Thanks, I did not know about dataclasses.replace - applied.
Moved the fields from InputAliasInfo to a field on ViewAndMutationMeta, keeping InputAliasInfo immutable.
return tuple(out)

# Backward with forward input mutations is not supported in double backward.
if torch.is_grad_enabled() and any(
this still has the runtime performance issue from before: we need to loop through the entire list at runtime (where the number of entries scales with the number of inputs to the backward).
I would just make a new field on ViewAndMutationMeta that is something like indices_of_inputs_that_require_grad_with_mutations_in_bw: List[int], that only contains the indices of inputs that we might need to error on.
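A sketch of the resulting runtime guard, where `meta` stands in for a ViewAndMutationMeta instance and the field name comes from this review:

```python
import torch

def check_backward_mutations_allowed(meta) -> None:
    # Cheap hot path: a truthiness check on a (usually empty) list.
    bad = meta.indices_of_inputs_that_require_grad_with_mutations_in_bw
    if torch.is_grad_enabled() and bad:
        raise RuntimeError(
            "Mutations in the backward on inputs that require grad are not "
            "supported when grad mode is enabled (create_graph=True); "
            f"offending input indices: {bad}"
        )
```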
tokens: Dict[Any, torch.Tensor] = field(default_factory=dict)

# Filled in after jointFn tracing.
# Kept for runtime checks of whether those mutations are allowed.
Can you add more to this comment - something like:
Only filled in if/when we trace the joint function.
If an input requires grad and is mutated in the backward, it is only safe to keep the mutation in the graph if gradients are disabled while the backward runs (grad mode is disabled by default when users run the backward, but can be turned on with create_graph=True).
At runtime during the backward, we use this list of indices to error properly if we find out that it was not safe to include a backward mutation in the graph.
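As a sketch, the field with the expanded comment might read as follows (class and field names assumed from this review):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ViewAndMutationMetaSketch:
    # Only filled in if/when we trace the joint function.
    # If an input requires grad and is mutated in the backward, it is only
    # safe to keep the mutation in the graph if gradients are disabled while
    # the backward runs (grad mode is off by default during .backward(), but
    # can be turned on with create_graph=True).
    # At runtime during the backward, this list of indices lets us raise a
    # proper error if keeping a backward mutation turns out to be unsafe.
    indices_of_inputs_that_require_grad_with_mutations_in_bw: List[int] = field(
        default_factory=list
    )
```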
AssertionError, "input that requires_grad and was mutated in the backward"
):
self.verify_aot_autograd(f, inp_grad, test_mutation=True)
self.verify_aot_autograd(f, inp_grad, test_mutation=True)
nit: can you add a quick comment in the test explaining that we can properly handle keeping the backward mutation in the graph in this test because the backward is running under no_grad? (someone is gonna look at this test in a year and have to squint really hard to figure that out)
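A sketch of how the requested comment could read; the exact shape of the test (which call sits inside the assertRaises context, what `f` and `inp_grad` are) is assumed here:

```python
# Sketch: `f`, `inp_grad`, and verify_aot_autograd come from the
# surrounding test file; the split between the two calls is illustrative.
with self.assertRaisesRegex(
    AssertionError, "input that requires_grad and was mutated in the backward"
):
    # With create_graph=True, grad mode stays on during the backward, so a
    # mutation on an input that requires grad must error at runtime.
    self.verify_aot_autograd(f, inp_grad, test_mutation=True)

# By default .backward() runs under no_grad, so here it is safe to keep the
# backward input mutation in the compiled graph and the test passes.
self.verify_aot_autograd(f, inp_grad, test_mutation=True)
```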
nice!
@pytorchbot merge
Stack from ghstack (oldest at bottom):
#127572
Allow mutations in backward on forward inputs, if:
1/ the mutation is not a metadata mutation
Enforced at compilation time.
2/ if create_graph=True: the mutated input does not require_grad
Enforced at runtime, where create_graph mode can be detected by checking torch.is_grad_enabled().
Adding input_joint_info to track mutations of inputs during the joint function.
Created as a separate field in ViewAndMutationMeta, as it is filled in only after joint fn tracing.
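For illustration, an eager-mode sketch of the pattern this PR makes traceable: a backward that performs a data mutation on a forward input (all names here are illustrative, not from the PR):

```python
import torch

class ScaleAndCount(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, counter):
        ctx.save_for_backward(counter)
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        (counter,) = ctx.saved_tensors
        counter.add_(1)  # data mutation of a forward input, in the backward
        return grad_out * 2, None

x = torch.randn(3, requires_grad=True)
counter = torch.zeros(())  # plain buffer; does not require grad

ScaleAndCount.apply(x, counter).sum().backward()
assert counter.item() == 1.0
```

Because grad mode is off during a default .backward(), the mutation on `counter` is safe to keep in the compiled backward graph; per the checks described above, a mutated input that requires grad combined with create_graph=True raises at runtime instead.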