
Conversation

@wconstab
Contributor

@wconstab wconstab commented Jul 21, 2022

Attempts to interpose between the unmodified DDP program and the user-provided dynamo backend, adding a special 'graph break' stage that splits the graphs to be compiled by dynamo according to the DDP bucketing strategy.

Detects DDP by using a new API (get_active_ddp_module) exposed by torch DDP, and only enables itself inside the 'run_ddp_forward' method. Also grabs bucket size and other metadata from the active DDP module.
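For orientation, a rough sketch (with hypothetical names) of that interposition: split the dynamo-captured graph at DDP bucket boundaries, then hand each sub-graph to the user's backend the first time it runs, using its real inputs. This is only a sketch of the idea, not the DDPOptimizer code in this PR.

import torch


class LazySubmodCompiler(torch.nn.Module):
    # Wraps one split sub-graph; compiles it with the user's backend on first call,
    # using the real inputs seen at runtime, then reuses the compiled callable.
    def __init__(self, submod, user_backend_fn):
        super().__init__()
        self.submod = submod
        self.user_backend_fn = user_backend_fn
        self.compiled_fn = None

    def forward(self, *args):
        if self.compiled_fn is None:
            self.compiled_fn = self.user_backend_fn(self.submod, list(args))
        return self.compiled_fn(*args)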

TODO

  • validate/benchmark

Builds on prototype from @alanwaketan #489

cc @ezyang @mrshenli @anijain2305 @jansel

with compile_lock:
    # TODO(whc) find a way to get these parameters from the DDP module
    # and configure the backend compiler and convert_frame correctly
    ddp_optimizer = DDPOptimizer(bucket_cap_mb=25, parameters_to_ignore=[], backend_compile_fn=BACKENDS["aot_autograd"])
Contributor Author

@jansel any ideas how to achieve these few lines worth of stuff in a more sane way?

Contributor Author

  • how to detect 'inside ddp'
  • how to access attributes of the DDP module object (such as its 'bucket_cap_mb' attr)
  • how to get the user-provided backend compile_fn rather than hardcoding aot_autograd here
  • whether the hijacked_callback approach is even correct?

print("DDPOptimizer called with FX graph:")
gm.graph.print_tabular()
print()
# 1. Splitting the gm into small graphs, following DDP parameter bucketing
Contributor

That means this logic needs to be synchronized with DDP's bucketing logic, is that right?

Contributor Author

Right.

  1. DDP first buckets based on the reverse order of parameter declaration in the nn module structure.
  2. Then, after the first run, DDP rebuckets (incurring tensor copies) using the actual execution order of the first backward, which can be different.
  3. DDP's actual bucketing logic is pretty simple: given an order of parameters and a set of parameters to exclude, it fills buckets to capacity (roughly as sketched below).
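A minimal sketch of that fill-to-capacity bucketing, assuming the caller supplies parameters in the desired (e.g. reversed) order; the names here are illustrative, not DDP's internals.

from typing import Iterable, List, Set, Tuple

import torch


def greedy_buckets(named_params: Iterable[Tuple[str, torch.nn.Parameter]],
                   ignore_names: Set[str],
                   cap_bytes: int) -> List[List[str]]:
    # Fill each bucket until it reaches cap_bytes, then start a new one.
    buckets: List[List[str]] = []
    current: List[str] = []
    current_bytes = 0
    for name, p in named_params:
        if name in ignore_names or not p.requires_grad:
            continue
        current.append(name)
        current_bytes += p.numel() * p.element_size()
        if current_bytes >= cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
    if current:
        buckets.append(current)
    return buckets


# e.g. greedy_buckets(reversed(list(model.named_parameters())), set(), 25 * 2**20)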

@ezyang
Contributor

ezyang commented Jul 21, 2022

@wconstab now that you've gotten this working e2e, how long do you think we're going to keep using this code? It looks like it ended up being a bit complicated, and iirc, @jansel proposed a different strategy that specialized on whether or not we were the first use of a parameter (though, tbf, I don't remember how this proposal was going to work)

@wconstab
Contributor Author

@wconstab now that you've gotten this working e2e, how long do you think we're going to keep using this code? It looks like it ended up being a bit complicated, and iirc, @jansel proposed a different strategy that specialized on whether or not we were the first use of a parameter (though, tbf, I don't remember how this proposal was going to work)

@ezyang I'd say this is barely working e2e; it's still a POC and needs some help to be viable (hence the questions on the PR). I think this approach could have two outcomes:

  • best case (for it): we don't find much value in larger-graph optimizations and/or it is really hard to fix AotAutograd tracing collectives, so we keep using this
  • worst case (for it): we only use this for a really short time before we find it necessary to do the tracing approach

All of that depends on perf benchmark numbers too. It may be that the cost/benefit on benchmarks isn't there to justify investing in the harder approach, so we should check the perf and prioritize accordingly.

def catch_errors(frame, cache_size):
    try:
        # TODO(whc) move the ddp code below the bailouts. but make sure it doesn't cause DDP to be skipped
        ddp_module = DistributedDataParallel.get_active_ddp_module()


what is get_active_ddp_module?

Contributor Author

I added that API in a diff on the pytorch side: pytorch/pytorch#83333

The idea is, I needed to do two things and I solved them both with one API:

  1. let torchdynamo know whether it is compiling a frame that is inside a DDP forward context
  2. give dynamo's DDPOptimizer access to the bucket size and other state vars for the current DDP module

Returning None for the active DDP module from get_active_ddp_module() is the equivalent of #1 being false, and if #1 is true I can extract the needed values off the returned module.
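A sketch of how both needs are served by the one call (names as used in this PR; see pytorch/pytorch#83333 for the API that actually landed):

from torch.nn.parallel import DistributedDataParallel

ddp_module = DistributedDataParallel.get_active_ddp_module()
if ddp_module is None:
    # Need 1 is false: not inside a DDP forward, so leave the user's backend untouched.
    pass
else:
    # Need 2: read bucketing state off the active module to configure DDPOptimizer.
    bucket_cap = ddp_module.bucket_bytes_cap
    ignored_params = ddp_module.parameters_to_ignore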

bucket_actual_sizes = []
node_splits = [[]]
for node in reversed(gm.graph.nodes):
    if bucket_bytes >= self.bucket_bytes_cap:

Contributor Author

OK, that might be worth implementing here. If you have a preference on what the logic should look like, let me know.

node_splits = [[]]
for node in reversed(gm.graph.nodes):
    if bucket_bytes >= self.bucket_bytes_cap:
        bucket_actual_sizes.insert(0, bucket_bytes)


Do you want us to add an API in DDP to align its bucketization with provided parameter-to-bucket mappings? Or do we need more exploration before moving on with that?

Contributor Author

I think that is a good idea, just to make the behavior of the 'compiled' program more stable/predictable. But I'd rather treat that as an optimization to add after getting the critical parts working.

        self.parameters_to_ignore = parameters_to_ignore
        self.backend_compile_fn = backend_compile_fn

    def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):


What happens with dynamic execution orders? Will re-compilation lead to new graph breaks? DDP might not be able to adapt to the new bucketization accordingly, as 1) frequent rebucketization is expensive, and 2) re-compilation is a one-time thing, but execution might switch between different orders multiple times.

Contributor Author

Let's assume that DDP initialized its buckets with the 'wrong' order, i.e. an order that will not be observed at runtime.

Also note, I don't think it's possible under dynamo for there to be dynamic behavior within one compiled function. If something causes dynamism, it forces dynamo to recompile anyway. That could be OK if it happened infrequently, but otherwise it would be a perf issue (regardless of DDP), and the user might need to change their model to avoid it if using dynamo.

So, what I think would happen is:

  1. DDP's initial buckets are chosen. They may be wrong, and are currently ignored by DDPOptimizer.
  2. DDPOptimizer partitions the graph into buckets according to the true order of execution of the forward graph.
  3. Then it runs AotAutograd and generates a backward graph, but the bucket ordering is not updated to respect the backward graph.
  4. At runtime, the compiled fwd/bwd graph is run and hooks fire in-between the split subgraphs. DDP may observe a different execution order and update its bucketing scheme after the first iteration.

If the program is not dynamic, things would be stable after that. If there is dynamism, then I think each new version would incur dynamo compilation, and if execution order is different enough to make DDP reconfigure buckets, it seems that DDP would do so multiple times.

This seems to suggest that dynamo and DDP costs under dynamism are orthogonal: in eager, dynamism would potentially cause multiple re-configurations of DDP buckets, and that remains the same under dynamo. Without DDP, dynamo still incurs expensive recompilation under dynamism. Does that sound right to you @mrshenli?


Yep, this makes sense to me.

in eager, dynamism would potentially cause multiple re-configurations of DDP buckets, and that remains the same under dynamo.

Today's DDP will only reconfigure buckets once after the first iteration. But I agree with your point that, in this case, Dynamo + DDP should be able to offer on-par perf with eager-DDP.

If there is dynamism, then I think each new version would incur dynamo compilation, and if execution order is different enough to make DDP reconfigure buckets, it seems that DDP would do so multiple times.

My original concern was that, suppose there is one dynamic if-else clause and the app switches between the two paths every 100 iterations. Would I be correct that dynamo will only trigger a re-compile on the 1st and 101st iterations, but not the 201st, as it is the same as the cached one for the 1st iteration? If that's the case, how can we tell DDP to update its buckets on the 201st iteration? Just OOC, I don't think this is high-pri to address.
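For concreteness, a minimal hypothetical example of the kind of branch-switching model being described; dynamo compiles each branch once, guards on the flag, and reuses the cached program whenever an already-seen branch recurs:

import torch


class FlipFlop(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(8, 8)
        self.b = torch.nn.Linear(8, 8)

    def forward(self, x, use_a: bool):
        # Branching on a Python bool creates a guard; each value of use_a is
        # compiled once and then served from dynamo's cache on later iterations.
        return self.a(x) if use_a else self.b(x)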

Contributor Author

Would I be correct that dynamo will only trigger a re-compile on the 1st and 101st iterations, but not the 201st, as it is the same as the cached one for the 1st iteration?

Yes, this is correct. Dynamo would reuse the cached programs for the if/else branches after they got compiled once each.

how can we tell DDP to update its buckets on the 201st iteration?

Well, (1) dynamo's DDPOptimizer could insert a node into its compiled graph which calls into the DDP module and notifies it which bucket strategy it is about to use. (2) This could happen on every iteration if we could make it fast enough and DDP mostly ignores it or checks its hash against the most recent one; or we could try to make dynamo actually pay attention to when its cached program changes from the last cached program, and only on those boundary conditions issue the notification to DDP.

I'm not sure how well this would work, but it seems like a possibility. It might make more sense to invest in a new DDP approach that traces comm ops into the graph or implements DDP in FX though. I see this graph-break solution as potentially a stop-gap.
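A hedged sketch of the 'notify every iteration, dedupe by hash' half of (2); the listener and its notify hook are hypothetical, and nothing like this exists in DDP today:

import hashlib
import json


class BucketPlanListener:
    """Hypothetical DDP-side receiver for bucket-plan notifications."""

    def __init__(self):
        self._last_digest = None

    def notify(self, bucket_plan):
        # bucket_plan: list of lists of parameter names, as chosen by DDPOptimizer.
        digest = hashlib.sha1(json.dumps(bucket_plan).encode()).hexdigest()
        if digest == self._last_digest:
            return  # same layout as the last notification; nothing for DDP to do
        self._last_digest = digest
        # ...here DDP would rebuild its buckets to match bucket_plan...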

for node in nodes:
    partition_map[node] = p
split_gm = fx.passes.split_module.split_module(
    gm, None, lambda node: partition_map[node]
Contributor Author

@mrshenli I've just rewritten most of this DDPOptimizer to use the fx.passes.split_module pass. That pass accepts a 'partition function' which takes an fx.Node as input and returns an int representing which partition that node should be placed in by the pass.

Since I was iterating on Jiewen's POC, I still used his bucketing logic (above) to build a map, and then made my partition function a simple lambda that reads from the map.

I would propose instead that we put a 'DDP partition function' inside pytorch core in distributed.py, and I can just call that from here and delete the above code. This way the logic lives close to DDP and can more easily be maintained by your team. Wdyt?
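A minimal sketch of the split_module call described here, assuming partition_map is the node-to-bucket-index dict built by the bucketing loop above:

import torch.fx as fx
from torch.fx.passes.split_module import split_module


def split_by_buckets(gm: fx.GraphModule, partition_map) -> fx.GraphModule:
    # split_module invokes the callback once per node and groups nodes returning
    # the same int into one submodule (submod_0, submod_1, ...).
    return split_module(gm, None, lambda node: partition_map[node])

Each resulting submodule can then be compiled separately by the user's backend, which is what produces the graph breaks at DDP bucket boundaries.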


SGTM, it might be easier to keep it consistent with the rest of DDP.

Question: would I be correct that we are not planning to change the existing DDP to use this partition function any time soon though? Because there is no guarantee that we can get an fx.Graph from the input module? Besides, today's DDP already has an execution-order-based solution to re-organize buckets?

Contributor Author

we are not planning to change existing DDP to use this partition function any time soon though

If you mean rewriting core DDP to rely on this FX pass, I wasn't planning on it.

there is no guarantee that we can get an fx.Graph from the input module

Right, I think there could be cases where legacy DDP works but FX-DDP couldn't get a single whole graph. I'm not sure how often this would happen in practice, or whether we could make do by applying bucketization to the partial graphs that were captured in those cases.

ddp_optimizer = DDPOptimizer(
    bucket_bytes_cap=ddp_module.bucket_bytes_cap,
    parameters_to_ignore=ddp_module.parameters_to_ignore,
    backend_compile_fn=my_compiler,
Contributor Author

@anijain2305 @voznesenskym Any idea how I can get the 'my_compiler' function that the user provided in their 'torchdynamo.optimize' call? I could probably figure out a way to hack it together, but I am hoping for a clean way.

Contributor

I'm not in front of it right now; I can check definitively later, but is that not callback in this func?

Contributor Author

Originally I assumed callback was the thing, and I used it. But I have been facing some (possibly unrelated) bugs, and when I read the code I thought callback is actually wrapped in some context which I'm not sure I want to reuse.

Contributor

Oh, that might be convert_frame and friends. Try adding this line to convert_frame before the return:

_convert_frame._compiler_fn = compiler_fn

And then see if your callback has a _compiler_fn? I can look in a moment as well.
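A hedged sketch of that suggestion, with the surrounding structure paraphrased rather than copied from torchdynamo's actual convert_frame:

def convert_frame(compiler_fn):
    def _convert_frame(frame, cache_size):
        ...  # normal dynamo frame conversion, eventually calling compiler_fn

    # Stash the user's backend on the callback so code that only sees the
    # callback (e.g. the DDPOptimizer setup in catch_errors) can recover it.
    _convert_frame._compiler_fn = compiler_fn
    return _convert_frame


# Later, wherever only the callback is available:
# backend_compile_fn = getattr(callback, "_compiler_fn", None)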


Adds test_distributed.py for DDPOptimizer
- mark failing tests that regressed
- add new baseline tests for ddp/fsdp not blowing up (without using the optimizer)
- use gloo instead of nccl, as nccl isn't in our pytorch build
- avoid running the optimizer tests that depend on pytorch changes unless those changes are present
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Sep 17, 2022
Pairs up with torchdynamo PR pytorch/torchdynamo#628

Exposes a new API that lets torchdynamo know when it is compiling the 'forward' of a module that is inside a DDP module.
Pull Request resolved: #83333
Approved by: https://github.com/mrshenli
@wconstab wconstab merged commit 0b0accb into main Sep 17, 2022
mehtanirav pushed a commit to pytorch/pytorch that referenced this pull request Oct 4, 2022