DDP optimization via graph-breaks in Dynamo #628
Conversation
torchdynamo/eval_frame.py (outdated)
with compile_lock:
    # TODO(whc) find a way to get these parameters from the DDP module
    # and configure the backend compiler and convert_frame correctly
    ddp_optimizer = DDPOptimizer(bucket_cap_mb=25, parameters_to_ignore=[], backend_compile_fn=BACKENDS["aot_autograd"])
@jansel any ideas how to achieve these few lines worth of stuff in a more sane way?
- how to detect 'inside ddp'
- how to access attributes of the DDP module object (such as its 'bucket_cap_mb' attr)
- how to get the user-provided backend compile_fn rather than hardcoding aot_autograd here
- whether the hijacked_callback approach is even correct?
| print("DDPOptimizer called with FX graph:") | ||
| gm.graph.print_tabular() | ||
| print() | ||
| # 1. Splitting the gm into small graphs, following DDP parameter bucketing |
That means this logic needs to be synchronized with the DDP bucketing logic, is that right?
Right:
- DDP first buckets based on reverse order of parameter declaration in the nn module structure
- then after the first run, DDP rebuckets (incurring tensor copies) using the actual execution order of the first backward, which can be different
- DDP's actual bucketing logic is pretty simple: given an order of parameters and a set of parameters to exclude, it fills buckets to capacity (see the sketch below)
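For concreteness, a minimal sketch of that fill-to-capacity idea. This is illustrative only, not the actual DDP implementation; the function name, the name-based exclusion convention, and the default cap are made up for the example (25MB matches the bucket_cap_mb=25 used elsewhere in this PR).

```python
from typing import Dict, List, Set


def fill_buckets(
    param_sizes: Dict[str, int],          # param name -> size in bytes
    reverse_declaration_order: List[str],
    params_to_exclude: Set[str],
    bucket_bytes_cap: int = 25 * 1024 * 1024,
) -> List[List[str]]:
    """Greedy fill-to-capacity bucketing over parameters in the given order."""
    buckets: List[List[str]] = [[]]
    bucket_bytes = 0
    for name in reverse_declaration_order:
        if name in params_to_exclude:
            continue
        if bucket_bytes >= bucket_bytes_cap:
            # Current bucket is full; start a new one.
            buckets.append([])
            bucket_bytes = 0
        buckets[-1].append(name)
        bucket_bytes += param_sizes[name]
    return buckets
```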
@wconstab now that you've gotten this working e2e, how long do you think we're going to keep using this code? It looks like it ended up being a bit complicated, and iirc, @jansel proposed a different strategy that specialized on whether or not we were the first use of a parameter (though, tbf, I don't remember how this proposal was going to work)
@ezyang I'd say this is barely working e2e, it's still a POC and needs some help to be viable (hence the questions on the PR). I think this approach could have 2 outcomes.
All of that depends on perf benchmark numbers too. It may be that the cost/benefit on benchmarks isn't there for investing in the harder approach, so we should check the perf and prioritize accordingly.
torchdynamo/eval_frame.py (outdated)
def catch_errors(frame, cache_size):
    try:
        # TODO(whc) move the ddp code below the bailouts. but make sure it doesn't cause DDP to be skipped
        ddp_module = DistributedDataParallel.get_active_ddp_module()
what is get_active_ddp_module?
I added that API in a diff on the pytorch side: pytorch/pytorch#83333
The idea is, I needed to do 2 things and I solved them both with one API:
- let torchdynamo know whether it's compiling a frame that is inside a DDP forward context
- give dynamo's DDPOptimizer access to the bucket size and other state vars for the current DDP module
Returning None for the active DDP module from get_active_ddp_module() is the equivalent of the first being false, and I can then extract the needed values off the module if it is true.
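Roughly how that would be consumed on the dynamo side. This is a sketch under the assumption that get_active_ddp_module() and the bucket_bytes_cap / parameters_to_ignore attributes exist as shown in the diffs in this PR; the helper name is made up here.

```python
from torch.nn.parallel import DistributedDataParallel


def maybe_get_ddp_config():
    """Return DDP bucketing config if we are inside a DDP forward, else None."""
    ddp_module = DistributedDataParallel.get_active_ddp_module()
    if ddp_module is None:
        # Not compiling a frame inside a DDP forward context.
        return None
    # Pull the state DDPOptimizer needs off the live module.
    return {
        "bucket_bytes_cap": ddp_module.bucket_bytes_cap,
        "parameters_to_ignore": ddp_module.parameters_to_ignore,
    }
```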
bucket_actual_sizes = []
node_splits = [[]]
for node in reversed(gm.graph.nodes):
    if bucket_bytes >= self.bucket_bytes_cap:
DDP might add a 1MB bucket to kick off allreduce sooner: https://github.com/pytorch/pytorch/blob/f81b4ae55cf4d9b44641178f31fd713f65d5af2e/torch/nn/parallel/distributed.py#L720
OK, that might be worth implementing here. If you have a preference on what the logic would look like, let me know.
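One possible shape for it, as a self-contained sketch: the helper name, the (name, nbytes) input format, and the ~1MB default are illustrative only; the real loop above walks fx graph nodes and sizes their parameters.

```python
from typing import List, Tuple


def split_into_buckets(
    items: List[Tuple[str, int]],         # (node name, parameter bytes), forward order
    bucket_bytes_cap: int,
    first_bucket_cap: int = 1024 * 1024,  # ~1MB, mirroring DDP's small first bucket
) -> List[List[str]]:
    splits: List[List[str]] = [[]]
    bucket_bytes = 0
    cap = first_bucket_cap
    for name, nbytes in reversed(items):
        if bucket_bytes >= cap:
            # Close the current bucket and start a new one. Only the first
            # bucket filled (the tail of forward, whose gradients are ready
            # first in backward) uses the smaller cap, so its allreduce can
            # kick off sooner.
            splits.insert(0, [])
            bucket_bytes = 0
            cap = bucket_bytes_cap
        splits[0].insert(0, name)
        bucket_bytes += nbytes
    return splits
```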
node_splits = [[]]
for node in reversed(gm.graph.nodes):
    if bucket_bytes >= self.bucket_bytes_cap:
        bucket_actual_sizes.insert(0, bucket_bytes)
Do you want us to add an API in DDP to align its bucketization with provided parameter-to-bucket mappings? Or do we need more exploration before moving on with that?
I think that is a good idea, just to make the behavior of the 'compiled' program seem more stable/predictable. But I'd rather treat that as an optimization to add after getting the critical parts working.
        self.parameters_to_ignore = parameters_to_ignore
        self.backend_compile_fn = backend_compile_fn

    def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
What happens with dynamic execution orders? Will re-compilation lead to new graph breaks? DDP might not be able to adapt to new bucketization accordingly, as 1) frequent rebucketization is expensive and 2) re-compilation is a one-time thing, but execution might switch between different orders multiple times.
Let's assume that DDP initialized its buckets with the 'wrong' order, i.e. an order that will not be observed at runtime.
Also note, I don't think it's possible under dynamo for there to be dynamic behavior within one compiled function. If something causes dynamism, it would force dynamo to recompile anyways. This could be OK if it happened infrequently but otherwise it would be a perf issue (regardless of DDP) and the user might need to change their model to avoid it if using dynamo.
So, what I think would happen is:
- DDP's initial buckets are chosen. They may be wrong, and are currently ignored by DDPOptimizer.
- DDPOptimizer partitions the graph into buckets according to the true order of execution of the forward graph.
- Then it runs AotAutograd and generates a backward graph, but bucket ordering is not updated respecting the backward graph.
- At runtime, the compiled fwd/bwd graph is run and hooks fire in-between the split subgraphs. DDP may observe a different execution order and update its bucketing scheme after the first iteration.
If the program is not dynamic, things would be stable after that. If there is dynamism, then I think each new version would incur dynamo compilation, and if execution order is different enough to make DDP reconfigure buckets, it seems that DDP would do so multiple times.
This seems to suggest that dynamo and DDP costs under dynamism are orthogonal: in eager, dynamism would potentially cause multiple re-configurations of DDP buckets, and that remains the same under dynamo. Without DDP, dynamo still incurs expensive recompilation under dynamism. Does that sound right to you @mrshenli?
Yep, this makes sense to me.
> in eager, dynamism would potentially cause multiple re-configurations of DDP buckets, and that remains the same under dynamo.
Today's DDP will only reconfigure buckets once after the first iteration. But I agree with your point that, in this case, Dynamo + DDP should be able to offer on-par perf with eager-DDP.
> If there is dynamism, then I think each new version would incur dynamo compilation, and if execution order is different enough to make DDP reconfigure buckets, it seems that DDP would do so multiple times.
My original concern was that, suppose there is one dynamic if-else clause, and the app switches between these two paths every 100 iterations. Would I be correct that dynamo will only trigger re-compile on the 1st and 101st iteration, but not the 201st iteration, as it is the same as the cached one for the 1st iteration? If that's the case, how can we tell DDP to update its buckets on the 201st iteration? Just OOC, I don't think this is a high-pri to address.
> Would I be correct that dynamo will only trigger re-compile on the 1st and 101st iteration, but not the 201st iteration, as it is the same as the cached one for the 1st iteration?
yes, this is correct. Dynamo would reuse cached programs for the if/else branches after they got compiled once each.
> how can we tell DDP to update its buckets on the 201st iteration?
Well, dynamo's DDPOptimizer could insert a node into its compiled graph which calls into the DDP module and notifies it which bucket strategy it is about to use. This could happen on every iteration if we could make it fast enough and DDP mostly ignores it or checks its hash against the most recent one; or we could try to make dynamo actually pay attention to when its cached program changed from the last cached program, and only on those boundary conditions issue the notification to DDP.
I'm not sure how well this would work, but it seems like a possibility. It might make more sense to invest in a new DDP approach that traces comm ops into the graph or implements DDP in FX though. I see this graph-break solution as potentially a stop-gap.
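To make that notification idea concrete, here is a purely hypothetical sketch. It is written as a Python wrapper around the cached compiled callable rather than an inserted fx node, just to keep it short, and notify_bucket_plan is not a real DDP API; it stands in for whatever hook DDP would need to expose.

```python
def wrap_with_bucket_notification(compiled_fn, ddp_module, bucket_plan):
    # Precompute a cheap fingerprint of the bucket plan chosen for this
    # particular compiled program.
    plan_hash = hash(tuple(tuple(bucket) for bucket in bucket_plan))

    def run(*args, **kwargs):
        # Fires every iteration; DDP would compare plan_hash against the most
        # recent one it saw and only rebucket when it actually changed.
        ddp_module.notify_bucket_plan(plan_hash, bucket_plan)  # hypothetical API
        return compiled_fn(*args, **kwargs)

    return run
```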
for node in nodes:
    partition_map[node] = p
split_gm = fx.passes.split_module.split_module(
    gm, None, lambda node: partition_map[node]
@mrshenli I've just rewritten most of this DDPOptimizer to use the fx.passes.split_module pass. That pass accepts a 'partition function' which takes an fx.Node as input and returns an int representing which partition it should be placed in by the pass.
Since I was iterating on Jiewen's POC, I still used his bucketing logic (above) to build a map, then I made my partition function a simple lambda that reads from the map.
I would propose instead that we put a 'DDP partition function' inside pytorch core in distributed.py, and I can just call that from here and delete the above code. This way the logic lives close to DDP and can more easily be maintained by your team. Wdyt?
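For reference, a self-contained toy example of that split_module flow. The module, the cut-in-half-by-node-index partition assignment, and the print are just for illustration; the real partition_map comes from the bucketing logic above.

```python
import torch
import torch.fx as fx
from torch.fx.passes.split_module import split_module


class TwoLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(8, 8)
        self.b = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.b(self.a(x))


gm = fx.symbolic_trace(TwoLinear())

# Build a node -> partition-id map; here we simply cut the graph in half by
# node index, whereas DDPOptimizer fills it from DDP-style bucket sizes.
nodes = list(gm.graph.nodes)
partition_map = {node: (0 if i < len(nodes) // 2 else 1) for i, node in enumerate(nodes)}

split_gm = split_module(gm, None, lambda node: partition_map[node])
print(split_gm)  # a module calling submod_0 / submod_1, one per partition
```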
SGTM, it might be easier to keep it consistent with the rest of DDP.
Question: would I be correct that we are not planning to change existing DDP to use this partition function any time soon though? Because there is no guarantee that we can get an fx.Graph from the input module? Besides, today's DDP already has an execution-order-based solution to re-organize buckets?
> we are not planning to change existing DDP to use this partition function any time soon though
if you mean rewriting core DDP to rely on this FX pass, I wasn't planning on it.
> there is no guarantee that we can get an fx.Graph from the input module
Right, I think there could be cases where legacy DDP works but FX-DDP couldn't get a single whole graph. I'm not sure how often this would happen in practice, or if we could make do with applying bucketization to the partial graphs that were captured in those cases.
torchdynamo/eval_frame.py (outdated)
ddp_optimizer = DDPOptimizer(
    bucket_bytes_cap=ddp_module.bucket_bytes_cap,
    parameters_to_ignore=ddp_module.parameters_to_ignore,
    backend_compile_fn=my_compiler,
@anijain2305 @voznesenskym Any idea how I can get the 'my_compiler' function that the user provided in their 'torchdynamo.optimize' call? I could probably figure out a way to hack it together, but I am hoping for a clean way.
Not in front of it right now, I can check definitively later, but is that not callback in this func?
Originally I assumed callback was the thing, and I used it. But I have been facing some (possibly unrelated) bugs, and when I read the code I thought callback is actually wrapped in some context which I am not sure I wanted to reuse.
Oh, that might be convert_frame and friends. Try adding this line to convert_frame before the return:
_convert_frame._compiler_fn = compiler_fn
and then see if your callback has a _compiler_fn? I can look in a moment as well.
Attempts to interpose between the unmodified DDP program and user-provided dynamo backend, adding a special 'graph break' stage that breaks up graphs to be compiled by dynamo according to the DDP bucketing strategy. Adds test_distributed.py for DDPOptimizer.
- mark failing ones that regressed
- add new baseline tests for ddp/fsdp not blowing up (without using optimizer)
- use gloo instead of nccl as nccl isn't in our pytorch build
- avoid running the optimizer tests that depend on pytorch changes unless those changes are present
Pairs up with torchdynamo PR pytorch/torchdynamo#628. Exposes a new API that lets torchdynamo know when it is compiling the 'forward' of a module that is inside a DDP module. Pull Request resolved: #83333. Approved by: https://github.com/mrshenli
Attempts to interpose between the unmodified DDP program and user-provided dynamo backend, adding a special 'graph break' stage that breaks up graphs to be compiled by dynamo according to the DDP bucketing strategy.
Detects DDP by using a new API (get_active_ddp_module) exposed by torch DDP, and only enables itself inside the 'run_ddp_forward' method. Also grabs bucket size and other metadata from the active DDP module.
TODO
Builds on prototype from @alanwaketan #489
cc @ezyang @mrshenli @anijain2305 @jansel