[ca][aot] use GraphModule CodeGen instead of GmWrapper for inputs flattening #141641
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141641
Note: links to docs will display an error until the docs builds have completed.
✅ You can merge normally! (2 unrelated failures) As of commit dd3e09f with merge base 78491d6:
BROKEN TRUNK - the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - the following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
def mark_compiled_autograd_gm(gm: torch.fx.GraphModule):
    assert "_compiled_autograd" not in gm.meta
    gm.meta["_compiled_autograd"] = True
A bit uncertain why this is needed
CA previously skipped try_get_metadata_from_dynamo because the lookup logic couldn't find the graph module fields on the GmWrapper. With this PR, we pass the try_get_metadata_from_dynamo checks, but that's a problem because the lookup doesn't understand that the dynamo graph's first input is boxed while the aot graph's first inputs are unpacked.
Other than splitting out try_get_metadata_from_dynamo support for CA, we have a couple of places in the stack using in_compiled_autograd_region,
like skipping the aot backward's cache, actualizing the aot backward's lazy module, etc., which should be moved to the graph level to work well with reentrant autograd.
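To make the mismatch concrete, here is a minimal pure-Python sketch (all names invented; no real torch types) of why a positional metadata lookup misaligns when the dynamo graph boxes its first input but the AOT graph takes inputs unpacked:

```python
# Hypothetical sketch: the dynamo graph has ONE placeholder (a boxed list),
# while the equivalent AOT graph has one placeholder per tensor.

def dynamo_graph(inputs):      # single boxed placeholder
    x0, x1 = inputs            # unpacked inside the graph body
    return x0 + x1

def aot_graph(x0, x1):         # two unpacked placeholders
    return x0 + x1

# Metadata recorded per dynamo placeholder: one entry for `inputs`.
dynamo_placeholder_meta = ["inputs"]
aot_placeholders = ["x0", "x1"]

# A naive positional pairing misaligns: only the first AOT placeholder
# gets metadata, and it's metadata for the whole boxed list.
paired = list(zip(dynamo_placeholder_meta, aot_placeholders))
```

Both graphs compute the same value, but a metadata lookup keyed by placeholder position cannot be reused across the two calling conventions without special handling.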
assert type(gm.graph._codegen) is torch.fx.graph.CodeGen
assert gm.graph._codegen._body_transformer is None
boxed_inputs_count = len(self.example_inputs()[0])
gm.graph._codegen = torch.fx.graph._CompiledAutogradCodeGen(
Design-wise, is it true that there should be no mutations to the CodeGen object once it is set?
I guess I would have expected something like: instead of mark_compiled_autograd_gm, just test whether the codegen is _CompiledAutogradCodeGen.
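A sketch of that alternative, with illustrative stand-in classes rather than the real torch.fx types: detect a CA graph by the type of its codegen instead of a separate meta flag.

```python
# Stand-in classes; in the real codebase these would be the fx CodeGen
# hierarchy and GraphModule. The point is the isinstance check at the end.

class CodeGen:
    pass

class _CompiledAutogradCodeGen(CodeGen):
    pass

class _Graph:
    def __init__(self, codegen):
        self._codegen = codegen

class _GraphModule:
    def __init__(self, codegen):
        self.graph = _Graph(codegen)

def is_compiled_autograd_gm(gm):
    # No flag to keep in sync with the codegen: the codegen type IS the marker.
    return isinstance(gm.graph._codegen, _CompiledAutogradCodeGen)

ca_gm = _GraphModule(_CompiledAutogradCodeGen())
plain_gm = _GraphModule(CodeGen())
```

The design upside is that the marker can never disagree with the codegen that actually rewrites the graph body; the discussion below covers where this is and isn't sufficient.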
That works for passing the graph from dynamo to aotdispatch.
I'm not sure it works from CA to dynamo. My current approach is to let dynamo tracing know about the boxed input so that it generates the post-graph bytecode properly for the grad mutations; _CompiledAutogradCodeGen would flatten the inputs while we're tracing.
# inputs is a list of activations and params
gradients = graph(inputs, ...)
# inputs is now an empty list
inputs[i].grad = gradients[i]  # need this to work
and currently we have some aliases generated in the bytecode:
# inputs is a list of activations and params
inputs_ref_0 = inputs[0]
gradients = graph(inputs, ...)
# inputs is now an empty list
inputs_ref_0.grad = gradients[i]
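The aliasing trick above can be demonstrated with a runnable pure-Python sketch (FakeTensor and boxed_graph are invented stand-ins): the boxed call steals and clears the inputs list, so indexing into it afterwards would fail, but an alias taken before the call keeps the object reachable for the .grad mutation.

```python
# Minimal stand-in for a tensor with a mutable .grad attribute.
class FakeTensor:
    def __init__(self):
        self.grad = None

def boxed_graph(inputs):
    stolen = list(inputs)
    inputs.clear()                          # boxed convention: callee empties the list
    return [FakeTensor() for _ in stolen]   # stand-in "gradients"

inputs = [FakeTensor(), FakeTensor()]
inputs_ref_0 = inputs[0]       # alias generated in the post-graph bytecode
gradients = boxed_graph(inputs)
# inputs is now empty, so `inputs[0].grad = ...` would raise IndexError;
# the alias still reaches the original tensor:
inputs_ref_0.grad = gradients[0]
```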
# https://github.com/pytorch/pytorch/pull/122535/files#r1560096481
if isinstance(
    mod, torch._dynamo.utils.GmWrapper
) or torch._dynamo.utils.is_compiled_autograd_gm(mod):
what's the new invariant now? Are these two conditions always guaranteed to be both true / both false?
They're not related. GmWrapper is used any time there are non-flat ("bumpy") inputs, e.g. from a non-dynamo frontend. The only time dynamo itself allows bumpy inputs is with a compiled autograd graph.
# However, we still want compile-time analysis to be done
# on unpacked inputs as we don't have first class support
# for lists. Hence, we unflatten the inputs here.
return (args[: self._boxed_inputs_count], *args[self._boxed_inputs_count :])
Why aren't all inputs just boxed, wouldn't that be simpler?
It's easier to handle a list of only tensors in dynamo variables; the rest of the inputs are symints, python callables, etc.
Another Q, for my information: do you end up with an actual ListVariable in Dynamo when you trace the boxed input?
yep, but we don't use the ListVariable after we unpack it into TensorVariables
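The re-boxing step from the snippet above is easy to sketch in isolation (the function name is illustrative; the real logic lives in the CodeGen body transformer): the first `boxed_inputs_count` args are regrouped into a list, while the trailing symints and callables stay as separate positional args.

```python
def unflatten_args(args, boxed_inputs_count):
    # First N entries become the boxed list; the rest pass through unchanged.
    return (args[:boxed_inputs_count], *args[boxed_inputs_count:])

# Three stand-in "tensors", then a symint and a callable name.
flat = ["t0", "t1", "t2", 7, "hook_fn"]
reboxed = unflatten_args(flat, 3)
```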
- if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+ if is_compiled_autograd_gm(gm):
This is more of an "educate me about what the code used to do" situation. Suppose I graph break in the middle of a compiled autograd region. When I resume compilation on the resumption function, I imagine I wouldn't have a compiled autograd gm anymore, right? So is it OK to not go into this condition? Maybe the argument is that the boxed arguments only occur on entry to the very top of the compiled autograd graph? But then, when I graph break and resume, I will have a lot of intermediate stack entries that get fed in as non-boxed inputs; will these get promptly deallocated?
Like maybe something like #122512 ?
> When I resume compilation on the resumption function, I imagine I wouldn't have a compiled autograd gm anymore, right? So is it OK to not go into this condition?

The resume function takes non-boxed inputs, so we don't call flatten_graph_inputs at all. Even if we did change the resume function to take boxed inputs, it shouldn't be possible for a graph break to happen before we unpack the boxed inputs (the first nodes after the placeholders in the graph).

> When I graph break and resume, I will have a lot of intermediate stack entries that will get fed in as non-boxed inputs, will these get promptly deallocated?

No, these won't deallocate until the end of that graph.
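The deallocation contrast can be sketched in pure Python (names invented; assumes CPython's reference counting): a boxed callee pops entries out of the shared list as it consumes them, so nothing else keeps the activations alive, whereas non-boxed arguments stay pinned by the caller's frame until the call returns.

```python
import gc
import weakref

class Activation:
    pass

def boxed_call(inputs):
    grads = []
    while inputs:
        act = inputs.pop(0)      # drop the list's reference as soon as it's consumed
        grads.append(object())   # stand-in gradient "computed from" act
        del act                  # no lingering local reference either
    return grads

acts = [Activation(), Activation()]
refs = [weakref.ref(a) for a in acts]
grads = boxed_call(acts)
gc.collect()
freed = [r() is None for r in refs]   # both activations were freed mid-call
```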
# For overhead reasons, this is not the default wrapper, see comment:
# https://github.com/pytorch/pytorch/pull/122535/files#r1560096481
if isinstance(
    mod, torch._dynamo.utils.GmWrapper
I guess you're going to delete GmWrapper later?
We can't unless we deprecate non-dynamo frontends that produce graphs with inputs that aren't flat.
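What such a shim does can be sketched in pure Python (GmWrapperSketch and its helper are illustrative, not the real torch._dynamo.utils.GmWrapper API): it flattens nested ("bumpy") inputs at the boundary and calls the inner module with a single flat list.

```python
def flatten(nested):
    # Recursively flatten lists/tuples into a single flat list of leaves.
    flat = []
    for item in nested:
        if isinstance(item, (list, tuple)):
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat

class GmWrapperSketch:
    def __init__(self, inner):
        self.inner = inner

    def __call__(self, *nested_args):
        # The inner graph only ever sees a flat list of inputs.
        return self.inner(flatten(nested_args))

inner = lambda flat_inputs: sum(flat_inputs)
wrapped = GmWrapperSketch(inner)
result = wrapped([1, 2], 3, (4, [5]))
```

This boundary-flattening is exactly what frontends with non-flat inputs rely on, which is why GmWrapper can't simply be deleted.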
gm.graph._codegen = torch.fx.graph._CompiledAutogradCodeGen(
    boxed_inputs_count
)
mark_compiled_autograd_gm(gm)
I'm not entirely sure what's going on in this block of code.
This seems fine but this PR has mostly told me I don't really understand how the CA specific Dynamo carveouts work lol
Stack from ghstack (oldest at bottom):
cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @chauhang @amjames @yf225